[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Molecular%20force%20regression%20using%20tied%20weights%20and%20multivectors%20in%20keras.ipynb)

In [None]:
%%sh
# Install the notebook's dependencies when running on Google Colab.
# COLAB_GPU is only set inside Colab, so this is a no-op elsewhere.
if [ -n "${COLAB_GPU}" ]; then
    # Released packages first
    pip install flowws-keras-geometry keras-gtar
    # Development versions straight from GitHub
    pip install --force-reinstall git+https://github.com/klarh/flowws-keras-experimental
    pip install git+https://github.com/klarh/geometric_algebra_attention
fi

In [None]:
import flowws
from flowws_keras_geometry.data import RMD17
from flowws_keras_experimental import InitializeTF, Train, Save
import geometric_algebra_attention.keras as gala

In [None]:
from flowws_keras_geometry.models.internal import GradientLayer, \
    NeighborhoodReduction, \
    PairwiseValueNormalization, PairwiseVectorDifference, \
    PairwiseVectorDifferenceSum

from geometric_algebra_attention import keras as gala
from geometric_algebra_attention.tensorflow.geometric_algebra import custom_norm

import flowws
from flowws import Argument as Arg
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Extra activation functions, keyed by the name used in stage arguments.
# Names found here are wrapped in a keras.layers.Lambda by the model
# builder; any other name is handed to keras by name instead.
LAMBDA_ACTIVATIONS = {
    # log(1 + swish(x))
    'log1pswish': lambda x: tf.math.log1p(tf.nn.swish(x)),
    'sin': tf.sin,
    # swish(x) - 0.01*swish(-x)
    'leakyswish': lambda x: tf.nn.swish(x) - 1e-2*tf.nn.swish(-x)
}

# Factories for normalization layers, keyed by the name used in stage
# arguments. Each factory takes one positional argument, which it
# ignores, and returns a newly-constructed layer on every call.
# NOTE(review): the .9 passed to the gala layers is presumably a
# momentum value -- confirm against the gala documentation.
NORMALIZATION_LAYERS = {
    'batch': lambda _: keras.layers.BatchNormalization(),
    'layer': lambda _: keras.layers.LayerNormalization(),
    'momentum': lambda _: gala.MomentumNormalization(.9),
    'momentum_layer': lambda _: gala.MomentumLayerNormalization(.9),
}

# Suffix for argument help strings listing the valid normalization names
NORMALIZATION_LAYER_DOC = ' (any of {})'.format(','.join(NORMALIZATION_LAYERS))

class NoisifyMultivector(keras.layers.Layer):
    """Add a small amount of Gaussian noise to the input tensor.

    A zero-mean normal sample with standard deviation `scale` is drawn
    on every call and added to the input. The sample has the shape of
    the input without its leading axis, so one sample is broadcast
    over that axis.

    :param scale: Standard deviation of the added noise

    """

    def __init__(self, scale=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, x):
        """Return `x` plus freshly-sampled noise."""
        # NOTE(review): the leading axis is dropped from the noise shape,
        # so every entry along it receives the same sample -- confirm
        # this sharing is intended.
        noise_shape = tf.shape(x)[1:]
        # Sample in the dtype of the input: tf.random.normal defaults to
        # float32, which cannot be added to float16/float64 tensors.
        noise = tf.random.normal(
            noise_shape, stddev=self.scale, dtype=x.dtype)
        return x + noise

    def get_config(self):
        """Return the layer configuration, including `scale`."""
        result = super().get_config()
        result['scale'] = self.scale
        return result

@flowws.add_stage_arguments
class MoleculeForceRegression(flowws.Stage):
    """Build a geometric attention network for the molecular force regression task.

    This module specifies the architecture of a network to calculate
    atomic forces given the coordinates and types of atoms in a
    molecule. Conservative forces are computed by calculating the
    gradient of a scalar.

    When run, this stage reads `neighborhood_size` and `num_types`
    from the workflow scope and writes `input_symbol`, `output`,
    `loss`, `attention_model`, and `invariant_model` back into it.

    """

    ARGS = [
        Arg('rank', None, int, 2,
            help='Degree of correlations (n-vectors) to consider'),
        Arg('n_dim', '-n', int, 32,
            help='Working dimensionality of point representations'),
        Arg('dilation', None, float, 2,
            help='Working dimension dilation factor for MLP components'),
        Arg('merge_fun', '-m', str, 'concat',
            help='Method to merge point representations'),
        Arg('join_fun', '-j', str, 'concat',
            help='Method to join invariant and point representations'),
        Arg('dropout', '-d', float, 0,
            help='Dropout rate to use, if any'),
        Arg('mlp_layers', None, int, 1,
            help='Number of hidden layers for score/value MLPs'),
        Arg('n_blocks', '-b', int, 2,
            help='Number of deep blocks to use'),
        Arg('block_nonlinearity', None, bool, True,
            help='If True, add a nonlinearity to the end of each block'),
        Arg('residual', '-r', bool, True,
            help='If True, use residual connections within blocks'),
        Arg('activation', '-a', str, 'swish',
            help='Activation function to use inside the network'),
        Arg('final_activation', None, str, 'swish',
            help='Final activation function to use within the network'),
        Arg('score_normalization', None, [str], [],
            help=('Normalizations to apply to score (attention) function' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('value_normalization', None, [str], [],
            help=('Normalizations to apply to value function' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('block_normalization', None, [str], [],
            help=('Normalizations to apply to the output of each attention block' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('invariant_value_normalization', None, [str], [],
            help=('Normalizations to apply to value function, before MLP layers' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('equivariant_value_normalization', None, [str], [],
            help=('Normalizations to apply to equivariant results' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('invariant_mode', None, str, 'single',
           help='Attention invariant_mode to use'),
        Arg('covariant_mode', None, str, 'single',
           help='Multivector2MultivectorAttention covariant_mode to use'),
        Arg('include_normalized_products', None, bool, False,
           help='Also include normalized geometric product terms'),
        Arg('normalize_equivariant_values', None, bool, False,
           help='If True, divide equivariant values by magnitude of inputs after each attention step'),
    ]

    def run(self, scope, storage):
        """Construct the network graph and publish it through `scope`.

        :param scope: Workflow scope; `neighborhood_size` and
            `num_types` are read from it, and the model inputs, output,
            loss name, and two auxiliary models are written to it.
        :param storage: Not used by this stage.

        """
        rank = self.arguments['rank']

        # `activation_layer` and `final_activation_layer` are zero-argument
        # factories so that every use site gets its own layer object.
        # Names listed in LAMBDA_ACTIVATIONS are wrapped in a Lambda
        # layer; all other names are passed to keras by name.
        if self.arguments['activation'] in LAMBDA_ACTIVATIONS:
            activation_layer = lambda: keras.layers.Lambda(
                LAMBDA_ACTIVATIONS[self.arguments['activation']])
        else:
            activation_layer = lambda: keras.layers.Activation(
                self.arguments['activation'])

        if self.arguments['final_activation'] in LAMBDA_ACTIVATIONS:
            final_activation_layer = lambda: keras.layers.Lambda(
                LAMBDA_ACTIVATIONS[self.arguments['final_activation']])
        else:
            final_activation_layer = lambda: keras.layers.Activation(
                self.arguments['final_activation'])

        n_dim = self.arguments['n_dim']
        # Hidden width of the MLPs: n_dim scaled by the dilation factor,
        # rounded to the nearest integer
        dilation_dim = int(np.round(n_dim*self.arguments['dilation']))
        
        def make_layer_inputs(x, v):
            """Return the list of `rank` (x, v) pairs fed to an attention layer.

            With `normalize_equivariant_values` set, the first pair uses
            `x` unchanged and the remaining `rank - 1` pairs share one
            layer-normalized copy of `x`; otherwise every pair is the
            same un-normalized (x, v) tuple.
            """
            nonnorm = (x, v)
            if self.arguments['normalize_equivariant_values']:
                xnorm = keras.layers.LayerNormalization()(x)
                norm = (xnorm, v)
                return [nonnorm] + (rank - 1) * [norm]
            else:
                return rank * [nonnorm]

        def make_scorefun():
            """Build a new MLP ending in a single output unit.

            Each of the `mlp_layers` hidden layers is a Dense layer of
            width `dilation_dim`, followed by the requested score
            normalizations, the activation, and dropout if enabled.
            """
            layers = []

            for _ in range(self.arguments['mlp_layers']):
                layers.append(keras.layers.Dense(dilation_dim))

                # The NORMALIZATION_LAYERS factories ignore their argument
                for name in self.arguments['score_normalization']:
                    layers.append(NORMALIZATION_LAYERS[name](rank))

                layers.append(activation_layer())

                if self.arguments.get('dropout', 0):
                    layers.append(keras.layers.Dropout(self.arguments['dropout']))

            layers.append(keras.layers.Dense(1))
            return keras.models.Sequential(layers)

        def make_valuefun(n_dim=n_dim, in_network=True, activation=None):
            """Build a new MLP ending in `n_dim` output units.

            :param n_dim: Width of the final Dense layer; defaults to
                the stage's `n_dim` argument.
            :param in_network: If True, prepend the layers named in
                `invariant_value_normalization`.
            :param activation: Optional activation for the final Dense
                layer; names in LAMBDA_ACTIVATIONS are applied through
                a separate Lambda layer.
            """
            layers = []

            if in_network:
                for name in self.arguments['invariant_value_normalization']:
                    layers.append(NORMALIZATION_LAYERS[name](rank))

            for _ in range(self.arguments['mlp_layers']):
                layers.append(keras.layers.Dense(dilation_dim))

                for name in self.arguments['value_normalization']:
                    layers.append(NORMALIZATION_LAYERS[name](rank))

                layers.append(activation_layer())

                if self.arguments.get('dropout', 0):
                    layers.append(keras.layers.Dropout(self.arguments['dropout']))

            if activation in LAMBDA_ACTIVATIONS:
                layers.append(keras.layers.Dense(n_dim))
                layers.append(keras.layers.Lambda(LAMBDA_ACTIVATIONS[activation]))
            else:
                layers.append(keras.layers.Dense(n_dim, activation=activation))
            return keras.models.Sequential(layers)

        def make_block(x_last, last):
            """Apply one attention block and return the updated (x_last, last).

            :param x_last: Multivector (geometric) representation
            :param last: Per-point value representation
            """
            # Keep the block inputs for the optional residual connection
            residual_in = last
            residual_in_x = x_last

            # Positional arguments: a score network, a value network, a
            # second value network with a single output unit, and a flag.
            # NOTE(review): the flag is presumably `reduce=False` (keep
            # per-point outputs) -- confirm against the gala API.
            (x_last, last) = gala.TiedMultivectorAttention(
                make_scorefun(), make_valuefun(), make_valuefun(1), False, rank=rank,
                join_fun=self.arguments['join_fun'],
                merge_fun=self.arguments['merge_fun'],
                invariant_mode=self.arguments['invariant_mode'],
                covariant_mode=self.arguments['covariant_mode'],
                include_normalized_products=self.arguments['include_normalized_products'],
            )(make_layer_inputs(x_last, last))

            # Extra MLP on the values, without the leading invariant
            # normalization layers
            if self.arguments['block_nonlinearity']:
                last = make_valuefun(in_network=False)(last)

            if self.arguments['residual']:
                last = last + residual_in
                x_last = x_last + residual_in_x

            for name in self.arguments['equivariant_value_normalization']:
                x_last = NORMALIZATION_LAYERS[name](rank)(x_last)

            for name in self.arguments.get('block_normalization', []):
                last = NORMALIZATION_LAYERS[name](rank)(last)

            return x_last, last

        # Model inputs: 3 coordinate components and `num_types` type
        # channels for each of `neighborhood_size` entries
        x_in = keras.layers.Input((scope['neighborhood_size'], 3))
        v_in = keras.layers.Input((scope['neighborhood_size'], scope['num_types']))

        # NOTE(review): presumably these form pairwise combinations of
        # the coordinates and type vectors -- confirm against
        # flowws_keras_geometry.
        delta_x = PairwiseVectorDifference()(x_in)
        delta_v = PairwiseVectorDifferenceSum()(v_in)

        # Perturb the geometric input with noise of scale 1e-7
        delta_x = NoisifyMultivector(1e-7)(delta_x)

        x_last = gala.Vector2Multivector()(delta_x)

        # Project the type-derived values up to the working width, then
        # apply the stack of attention blocks
        last = keras.layers.Dense(n_dim)(delta_v)
        for _ in range(self.arguments['n_blocks']):
            x_last, last = make_block(x_last, last)

        # Final attention layer; it is also asked for its invariants and
        # attention weights so they can be exposed as separate models.
        # NOTE(review): the positional True is presumably `reduce` --
        # confirm against the gala API.
        (last, ivs, att) = gala.MultivectorAttention(
            make_scorefun(), make_valuefun(), True, name='final_attention',
            rank=rank,
            join_fun=self.arguments['join_fun'],
            merge_fun=self.arguments['merge_fun'],
            invariant_mode=self.arguments['invariant_mode'],
            include_normalized_products=self.arguments['include_normalized_products'],
        )(
            make_layer_inputs(x_last, last), return_invariants=True, return_attention=True)

        # Output head: MLP down to one bias-free scalar, reduced by
        # NeighborhoodReduction and then passed with the input
        # coordinates to GradientLayer.
        # NOTE(review): per the class docstring this is the gradient of
        # a scalar; the reduction and sign convention live in the
        # imported layers -- confirm there.
        last = keras.layers.Dense(dilation_dim, name='final_mlp')(last)
        last = final_activation_layer()(last)
        last = keras.layers.Dense(1, name='energy_projection', use_bias=False)(last)
        last = NeighborhoodReduction()(last)
        last = GradientLayer()((last, x_in))

        # Publish the graph endpoints and auxiliary models to the scope
        scope['input_symbol'] = [x_in, v_in]
        scope['output'] = last
        scope['loss'] = 'mse'
        scope['attention_model'] = keras.models.Model([x_in, v_in], att)
        scope['invariant_model'] = keras.models.Model([x_in, v_in], ivs)

In [None]:
# Build each stage separately, in the order the workflow runs them.
tf_stage = InitializeTF()

# RMD17 data stage, restricted to benzene
data_stage = RMD17(
    seed=13,
    cache_dir="/tmp",
    n_train=1000,
    n_val=1000,
    y_scale_reduction=16,
    x_scale_reduction=16,
    molecules=["benzene"],
)

# Network architecture defined by the stage above
model_stage = MoleculeForceRegression(
    n_dim=32,
    n_blocks=3,
    invariant_mode='single',
    covariant_mode='single',
    activation='swish',
    merge_fun='mean',
    join_fun='mean',
    block_normalization=['layer'],
    score_normalization=['layer'],
    value_normalization=['layer'],
    invariant_value_normalization=['momentum'],
    equivariant_value_normalization=['momentum_layer'],
    residual=True,
    include_normalized_products=True,
    normalize_equivariant_values=True,
    mlp_layers=2,
    dropout=.1,
)

# Training configuration
train_stage = Train(
    epochs=40,
    reduce_lr=25,
    early_stopping=70,
    batch_size=4,
    validation_split=0,
    reduce_lr_factor=0.8,
    early_stopping_best=1,
    accumulate_gradients=32,
    catch_keyboard_interrupt=True,
)

w = flowws.Workflow(
    [tf_stage, data_stage, model_stage, train_stage],
    storage=flowws.DirectoryStorage("/tmp"),
)

# Run every stage; the returned scope holds the logged training metrics
scope = w.run()

In [None]:
import matplotlib.pyplot as pp

# Plot the logged mean absolute error for the training and validation
# sets, joining the per-entry series in scope['log_quantities'].
for key, label in (
        ('mean_absolute_error', 'train set'),
        ('val_mean_absolute_error', 'val set'),
):
    y = np.concatenate(
        [series[1][key] for series in scope['log_quantities']])
    pp.plot(y, label=label)

pp.gca().set_yscale('log')
pp.xlabel('Epoch')
pp.ylabel('MAE')
# Trailing semicolon suppresses the notebook's repr of the legend object
pp.legend();