[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Molecular%20force%20regression%20using%20multivectors%20in%20pytorch.ipynb)

In [None]:
%%sh
# Install the notebook's dependencies when running on Google Colab.
# Outside of Colab (COLAB_GPU unset or empty) nothing is installed.
if [ -n "$COLAB_GPU" ]; then
    pip install flowws-keras-geometry keras-gtar
    pip install --force-reinstall git+https://github.com/klarh/flowws-keras-experimental
    pip install git+https://github.com/klarh/geometric_algebra_attention
fi

import torch as pt
# The memory report is only meaningful (and only expected to succeed) when a
# CUDA device is present, so skip it on CPU-only machines.
if pt.cuda.is_available():
    print(pt.cuda.memory_summary())

In [None]:
import flowws
from flowws_keras_geometry.data import RMD17

In [None]:
import torch as pt
import geometric_algebra_attention.pytorch as gala

class MomentumNorm(pt.nn.Module):
    """Normalize features using exponentially-averaged running statistics.

    Running estimates of the per-feature mean and standard deviation are
    kept as buffers. In training mode, each call first blends the
    statistics of the current input into the running estimates and then
    normalizes with the updated estimates; in evaluation mode the stored
    estimates are used unchanged. The statistics are treated as
    constants by autograd, so gradients only flow through `x` itself.

    :param n_dim: Size of the last (feature) axis of the inputs
    :param momentum: Weight given to the previous running estimate in each update

    """
    def __init__(self, n_dim, momentum=.99):
        super().__init__()
        self.n_dim = n_dim
        self.register_buffer('momentum', pt.as_tensor(momentum))
        self.register_buffer('mu', pt.zeros(n_dim))
        self.register_buffer('sigma', pt.ones(n_dim))

    def forward(self, x):
        """Return `x` shifted and scaled by the running mean and standard deviation.

        Statistics are reduced over every axis of `x` except the last.

        """
        if self.training:
            # The batch statistics are only needed to update the running
            # estimates, and they never take part in backpropagation, so
            # skip them entirely in eval mode and keep them out of the
            # autograd graph in training mode.
            axes = tuple(range(x.ndim - 1))
            with pt.no_grad():
                mu_calc = pt.mean(x, axes, keepdim=False)
                sigma_calc = pt.std(x, axes, keepdim=False)

                new_mu = self.momentum*self.mu + (1 - self.momentum)*mu_calc
                new_sigma = self.momentum*self.sigma + (1 - self.momentum)*sigma_calc

                self.mu[:] = new_mu
                self.sigma[:] = new_sigma

        # floor the scale to avoid dividing by zero
        sigma = pt.clamp(self.sigma, min=1e-7)

        return (x - self.mu.detach())/sigma.detach()

class LayerNorm(pt.nn.Module):
    """Parameter-free normalization over the last axis of the input.

    Each slice along the last axis is shifted by its mean and divided by
    its standard deviation (variance floored at 1e-7). Both statistics
    are detached, so autograd treats them as constants.

    """
    def forward(self, x):
        """Return `x` standardized along its last axis."""
        variance_floor = pt.as_tensor(1e-7)

        center = x.mean(-1, keepdim=True).detach()
        variance = x.var(-1, keepdim=True)
        scale = pt.maximum(variance, variance_floor).sqrt().detach()

        return (x - center)/scale

class GalaPotential(pt.nn.Module):
    """Calculate a potential using geometric algebra attention

    Stacks permutation-covariant attention blocks, then adds a permutation-invariant reduction layer.

    The network consists of `depth + 1` blocks. Each block applies a
    multivector-producing attention layer followed by a node
    value-producing attention layer; the value-producing layer of the
    last block is created with `reduce=True`. The final values are
    passed through an MLP, summed over the second-to-last axis and
    projected (without bias) to a single number.

    :param D_in: Size of the last axis of the per-point input values
    :param D: Working dimension of the network
    :param depth: Number of attention blocks preceding the final (reducing) block
    :param dilation: Width of MLP hidden layers, as a multiple of `D`
    :param residual: If True, use residual connections around blocks
    :param nonlinearities: If True, apply an MLP to the output of each block
    :param merge_fun: `merge_fun` argument forwarded to the attention layers
    :param join_fun: `join_fun` argument forwarded to the attention layers
    :param invariant_mode: `invariant_mode` argument forwarded to the attention layers
    :param rank: `rank` of the final block; earlier blocks use `max(2, rank)`
    :param invar_value_normalization: Normalization ('momentum', 'layer' or None) applied to the inputs of value networks
    :param value_normalization: Normalization applied to the hidden layer of value networks
    :param score_normalization: Normalization applied to the hidden layer of score networks
    :param block_normalization: Normalization applied to the output of each block

    """
    def __init__(self, D_in, D=32, depth=2, dilation=2., residual=True,
                 nonlinearities=True, merge_fun='mean', join_fun='mean',
                 invariant_mode='single', rank=2,
                 invar_value_normalization=None, value_normalization=None,
                 score_normalization=None, block_normalization=None):
        super().__init__()

        self.D_in = D_in
        self.D = D
        self.depth = depth
        self.dilation = dilation
        self.residual = residual
        self.nonlinearities = nonlinearities
        self.rank = rank
        self.invariant_mode = invariant_mode
        self.GAANet_kwargs = dict(merge_fun=merge_fun, join_fun=join_fun, invariant_mode=invariant_mode)

        self.invar_value_normalization = invar_value_normalization
        self.value_normalization = value_normalization
        self.score_normalization = score_normalization
        self.block_normalization = block_normalization

        self.vec2mv = gala.Vector2Multivector()
        # pairwise inputs concatenate the sum and difference of two
        # points' values, hence 2*D_in input features
        self.up_project = pt.nn.Linear(2*D_in, self.D)
        self.final_mlp = self.make_value_net(self.D)
        self.energy_projection = pt.nn.Linear(self.D, 1, bias=False)

        self.make_attention_nets()

        self.nonlin_mlps = []
        if self.nonlinearities:
            self.nonlin_mlps = pt.nn.ModuleList([
                self.make_value_net(self.D, within_network=False) for _ in range(self.depth + 1)])

        # holds one layer per block, or stays empty when block_normalization is unset
        self.block_norm_layers = pt.nn.ModuleList([])
        for _ in range(self.depth + 1):
            self.block_norm_layers.extend(self._get_normalization_layers(self.block_normalization, self.D))

    def make_attention_nets(self):
        """Create the score, value and scale MLPs and the attention layers of every block."""
        self.score_nets = pt.nn.ModuleList([])
        self.value_nets = pt.nn.ModuleList([])
        self.scale_nets = pt.nn.ModuleList([])
        self.eqvar_att_nets = pt.nn.ModuleList([])
        self.invar_att_nets = pt.nn.ModuleList([])

        for i in range(self.depth + 1):
            reduce = i == self.depth
            rank = max(2, self.rank) if not reduce else self.rank

            # Value networks are sized with the same rank that is handed
            # to the attention layer consuming them, so that the two
            # agree even when self.rank < 2.

            # rotation-equivariant (multivector-producing) networks
            self.score_nets.append(self.make_score_net())
            self.value_nets.append(self.make_value_net(gala.Multivector2MultivectorAttention.get_invariant_dims(
                rank, self.invariant_mode)))
            self.scale_nets.append(self.make_score_net())
            self.eqvar_att_nets.append(gala.Multivector2MultivectorAttention(
                self.D, self.score_nets[-1], self.value_nets[-1], self.scale_nets[-1],
                reduce=False, rank=rank, **self.GAANet_kwargs))

            # rotation-invariant (node value-producing) networks
            self.score_nets.append(self.make_score_net())
            self.value_nets.append(self.make_value_net(gala.MultivectorAttention.get_invariant_dims(
                rank, self.invariant_mode)))
            self.invar_att_nets.append(gala.MultivectorAttention(
                self.D, self.score_nets[-1], self.value_nets[-1],
                reduce=reduce, rank=rank, **self.GAANet_kwargs))

    def _get_normalization_layers(self, norm, n_dim):
        """Return a list holding the normalization layer named by `norm`.

        :param norm: 'momentum', 'layer', or a false value for no normalization
        :param n_dim: Feature dimension given to layers that need one
        :raises NotImplementedError: if `norm` is not a recognized name

        """
        if not norm:
            return []
        elif norm == 'momentum':
            return [MomentumNorm(n_dim)]
        elif norm == 'layer':
            return [LayerNorm()]
        else:
            raise NotImplementedError(norm)

    def make_score_net(self):
        """Build an MLP mapping `D` features to one value through a dilated hidden layer."""
        big_D = int(self.D*self.dilation)
        layers = [
            pt.nn.Linear(self.D, big_D),
        ]

        layers.extend(self._get_normalization_layers(self.score_normalization, big_D))

        layers.extend([
            pt.nn.SiLU(),
            pt.nn.Linear(big_D, 1),
        ])
        return pt.nn.Sequential(*layers)

    def make_value_net(self, D_in, D_out=None, within_network=True):
        """Build an MLP mapping `D_in` features to `D_out` through a dilated hidden layer.

        :param D_in: Number of input features
        :param D_out: Number of output features; defaults to `D`
        :param within_network: If True, prepend the normalization selected by `invar_value_normalization`

        """
        D_out = D_out or self.D
        big_D = int(self.D*self.dilation)
        layers = []

        if within_network:
            layers.extend(self._get_normalization_layers(self.invar_value_normalization, D_in))

        layers.append(pt.nn.Linear(D_in, big_D))

        layers.extend(self._get_normalization_layers(self.value_normalization, big_D))

        layers.extend([
            pt.nn.SiLU(),
            pt.nn.Linear(big_D, D_out),
        ])
        return pt.nn.Sequential(*layers)

    def forward(self, x):
        """Evaluate the potential for a pair `(r, v)`.

        :param x: Pair `(r, v)`. Pairwise differences of `r` and pairwise sums and differences of `v` are formed over their second-to-last axis; the last axis of `v` must have size `D_in`.
        :returns: Tensor whose last axis has size 1

        """
        (r, v) = x
        r = pt.as_tensor(r)
        v = pt.as_tensor(v)

        # entry [..., i, j, :] holds r_j - r_i
        neighbor_rij = r[..., None, :, :] - r[..., :, None, :]
        neighbor_rij = self.vec2mv(neighbor_rij)
        # symmetric and antisymmetric combinations of the pair's values
        vplus = v[..., None, :, :] + v[..., :, None, :]
        vminus = v[..., None, :, :] - v[..., :, None, :]
        neighbor_vij = pt.cat([vplus, vminus], axis=-1)

        last_r = neighbor_rij
        last = self.up_project(neighbor_vij)

        for i in range(self.depth + 1):
            residual = last
            residual_r = last_r

            last_r = self.eqvar_att_nets[i]((last_r, last))
            last = self.invar_att_nets[i]((last_r, last))
            if self.nonlinearities:
                last = self.nonlin_mlps[i](last)

            # no residual connection for the final block, which is the
            # one created with reduce=True
            if self.residual and i < self.depth:
                last = last + residual

            if self.block_norm_layers:
                last = self.block_norm_layers[i](last)

            if self.residual:
                last_r = last_r + residual_r
            # re-inject the input geometry after every block
            last_r = last_r + neighbor_rij

        last = self.final_mlp(last)
        last = pt.sum(last, -2)
        last = self.energy_projection(last)

        return last

class GalaForce(pt.nn.Module):
    """Compute the gradient of a `GalaPotential` with respect to its coordinates.

    All constructor arguments are forwarded to `GalaPotential`.

    """
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.potential = GalaPotential(*args, **kwargs)

    def forward(self, x):
        """Return d(sum of potential)/dr for a pair `x = (r, v)`.

        The result has the same shape as `r` and is the plain gradient
        (no sign flip is applied). The autograd graph of the gradient is
        retained (`create_graph=True`) so that a loss on the result can
        be backpropagated to the model parameters.

        """
        (r, v) = x
        r = pt.as_tensor(r)
        v = pt.as_tensor(v)

        # The gradient needs autograd even if the caller disabled it
        # (e.g. evaluation inside `pt.no_grad()`).
        with pt.enable_grad():
            if not r.requires_grad:
                # differentiate with respect to a detached alias rather
                # than flipping requires_grad on the caller's tensor
                r = r.detach().requires_grad_(True)

            potential = self.potential((r, v)).sum()
            (gradient,) = pt.autograd.grad(potential, r, create_graph=True)
        return gradient


In [None]:
import flowws
from flowws import Argument as Arg

@flowws.add_stage_arguments
class MoleculeForceRegression(flowws.Stage):
    """Build a geometric attention network for the molecular force regression task.

    The stage describes a network that maps the coordinates and types
    of the atoms in a molecule to per-atom forces. Forces are
    conservative by construction, because they are obtained as the
    gradient of a scalar output.

    """

    ARGS = [
        Arg('rank', None, int, 2,
            help='Degree of correlations (n-vectors) to consider'),
        Arg('n_dim', '-n', int, 32,
            help='Working dimensionality of point representations'),
        Arg('dilation', None, float, 2,
            help='Working dimension dilation factor for MLP components'),
        Arg('merge_fun', '-m', str, 'concat',
            help='Method to merge point representations'),
        Arg('join_fun', '-j', str, 'concat',
            help='Method to join invariant and point representations'),
        Arg('n_blocks', '-b', int, 2,
            help='Number of deep blocks to use'),
        Arg('block_nonlinearity', None, bool, True,
            help='If True, add a nonlinearity to the end of each block'),
        Arg('residual', '-r', bool, True,
            help='If True, use residual connections within blocks'),
        Arg('invariant_mode', None, str, 'single',
            help='Attention invariant_mode to use'),
        Arg('invar_value_normalization', None, str, 'momentum',
            help='Normalization applied to rotation-invariant attributes'),
        Arg('value_normalization', None, str, 'momentum',
            help='Normalization applied to value function hidden layer'),
        Arg('score_normalization', None, str, 'momentum',
            help='Normalization applied to score function hidden layer'),
        Arg('block_normalization', None, str, 'momentum',
            help='Normalization applied to post-block values'),
    ]

    def run(self, scope, storage):
        """Create the force model from the stage arguments and store it in `scope['model']`.

        The input dimension is read from `scope['num_types']`.

        """
        args = self.arguments

        # Spell out which stage argument feeds which model parameter.
        scope['model'] = GalaForce(
            D_in=scope['num_types'],
            D=args['n_dim'],
            depth=args['n_blocks'],
            dilation=args['dilation'],
            residual=args['residual'],
            nonlinearities=args['block_nonlinearity'],
            merge_fun=args['merge_fun'],
            join_fun=args['join_fun'],
            invariant_mode=args['invariant_mode'],
            rank=args['rank'],
            invar_value_normalization=args['invar_value_normalization'],
            value_normalization=args['value_normalization'],
            score_normalization=args['score_normalization'],
            block_normalization=args['block_normalization'],
        )


In [None]:
import flowws
from flowws import Argument as Arg
import numpy as np

def tensify(arrs):
    """Convert an array, or a nested list/tuple of arrays, into float32 tensors.

    Lists and tuples are converted element by element and always come
    back as tuples; anything else is cast with `astype(np.float32)` and
    wrapped in a tensor.

    """
    if not isinstance(arrs, (list, tuple)):
        as_float = arrs.astype(np.float32)
        return pt.as_tensor(as_float)
    return tuple(map(tensify, arrs))

# Lower-cased class name -> optimizer class, for every concrete optimizer
# exported by torch.optim (the abstract `Optimizer` base class is excluded).
OPTIMIZER_NAME_DICT = {
    name.lower(): obj for (name, obj) in vars(pt.optim).items()
    if isinstance(obj, type) and issubclass(obj, pt.optim.Optimizer) and name != 'Optimizer'
}

@flowws.add_stage_arguments
class Train(flowws.Stage):
    """Train a model using pytorch.

    Reads `x_train`, `y_train`, `validation_data`, `y_scale` and `model`
    from the scope. For every epoch the mean MSE loss and the mean
    absolute error (multiplied by `y_scale`) are printed for the
    training and validation sets.
    """

    ARGS = [
        Arg('batch_size', '-b', int, 32,
           help='Batch size to use for training'),
        Arg('optimizer', '-o', str, 'adam',
           help='Optimizer to use'),
        Arg('epochs', '-e', int, 16,
           help='Number of epochs to run'),
        Arg('seed', '-s', int, 13,
           help='Random seed to use for training'),
        Arg('accumulate_gradients', '-g', int, 1,
           help='Accumulate gradients over the given number of batches'),
        Arg('verbose', '-v', bool, False,
           help='Print more info during training'),
        Arg('gpu', None, bool, True,
           help='Try to use the GPU'),
        # NOTE(security): values are converted with `eval`; only pass
        # trusted input for this argument.
        Arg('optimizer_kwargs', None, [(str, eval)],
           help='Keyword arguments to be passed to optimizer initialization')
    ]

    def run(self, scope, storage):
        """Run the training and validation loops for the configured number of epochs."""
        # honor the documented `seed` argument
        pt.manual_seed(self.arguments['seed'])

        if self.arguments['gpu'] and pt.cuda.is_available():
            device = pt.device('cuda')
        else:
            device = pt.device('cpu')
        # pinned host memory only helps (and is only available) for CUDA transfers
        pin_memory = device.type == 'cuda'

        train_data = list(scope['x_train']) + [scope['y_train']]
        val_data = list(scope['validation_data'][0]) + [scope['validation_data'][1]]

        batch_size = self.arguments['batch_size']
        train_data = pt.utils.data.DataLoader(
            pt.utils.data.TensorDataset(*tensify(train_data)),
            batch_size=batch_size, pin_memory=pin_memory)
        val_data = pt.utils.data.DataLoader(
            pt.utils.data.TensorDataset(*tensify(val_data)),
            batch_size=batch_size, pin_memory=pin_memory)

        model = scope['model'].to(device)

        Optimizer = OPTIMIZER_NAME_DICT[self.arguments['optimizer'].lower()]
        kwargs = dict(self.arguments.get('optimizer_kwargs', []))
        opt = Optimizer(model.parameters(), **kwargs)

        loss_fn = pt.nn.MSELoss()
        accumulate = self.arguments['accumulate_gradients']
        verbose = self.arguments['verbose']

        def loop(dataset, train=True):
            """Run one pass over `dataset`; return a list of (loss, scaled MAE) per batch."""
            if train:
                model.train()
                opt.zero_grad()
            else:
                model.eval()

            stats = []
            pending = 0  # batches whose gradients have not been applied yet
            for i, d in enumerate(dataset):
                (r, v, y) = [a.to(device) for a in d]
                pred = model((r, v))
                loss = loss_fn(pred, y)

                if train:
                    # scale so that an accumulated group averages its batches
                    (loss/accumulate).backward()
                    pending += 1

                    if pending == accumulate:
                        opt.step()
                        opt.zero_grad()
                        pending = 0

                scaled_mae = pt.mean(pt.abs(pred.detach() - y))*scope['y_scale']
                stats.append((loss.item(), scaled_mae.item()))

                if verbose:
                    print('batch', i, stats[-1])

            if train and pending:
                # Apply the gradients of a trailing, incomplete
                # accumulation group instead of discarding them when the
                # next epoch zeroes the gradients.
                opt.step()
                opt.zero_grad()
            return stats

        for epoch in range(self.arguments['epochs']):
            stats = loop(train_data)
            print('epoch', epoch, 'train stats', *np.mean(stats, axis=0))
            stats = loop(val_data, False)
            print('epoch', epoch, 'val stats', *np.mean(stats, axis=0))

In [None]:
# Assemble the three-stage workflow: a dataset stage (RMD17), the stage
# that builds the force model, and the training stage.
w = flowws.Workflow(
    [
        # Dataset stage imported from flowws_keras_geometry.
        # NOTE(review): the meaning of these arguments is defined by that
        # package and is not visible in this notebook — confirm there.
        RMD17(
            seed=13,
            cache_dir="/tmp",
            n_train=1000,
            n_val=1000,
            y_scale_reduction=4,
            x_scale_reduction=-1,
            molecules=[
                "malonaldehyde",
            ],
            no_keras=True,
        ),
        # Model stage: overrides the stage defaults for n_blocks (2),
        # merge_fun/join_fun ('concat') and three of the four
        # normalization arguments ('momentum').
        MoleculeForceRegression(
            n_dim=32,
            n_blocks=3,
            invariant_mode='single',
            merge_fun='mean',
            join_fun='mean',
            invar_value_normalization='momentum',
            value_normalization='layer',
            score_normalization='layer',
            block_normalization='layer',
            residual=True,
        ),
        # Training stage: batches of 4 with gradients accumulated over
        # 32 batches, i.e. up to 128 samples per optimizer step.
        Train(
            epochs=40,
            batch_size=4,
            accumulate_gradients=32,
        ),
    ],
    storage=flowws.DirectoryStorage("/tmp"),
)

# Run all stages in order and keep the value returned by the workflow.
scope = w.run()