# PyTorch geometric algebra attention networks

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Basic%20structure%20tasks%20with%20pytorch.ipynb)

This notebook formulates some deep learning layers using geometric algebra attention mechanisms.


In [None]:
%%sh
# Install the notebook's dependencies when running on Google Colab.
# The COLAB_GPU variable is used as the marker for Colab; when it is unset or
# empty (i.e. everywhere else) nothing is installed.
if [ -n "${COLAB_GPU}" ]; then
    for package in plato-draw git+https://github.com/klarh/geometric_algebra_attention; do
        pip install "${package}"
    done
fi

## Example PyTorch networks using these layers

In [None]:
import torch as pt
import geometric_algebra_attention.pytorch as gala

In [None]:
class GAANetClassifier(pt.nn.Module):
    """Classify point neighborhoods with geometric algebra attention layers.

    Permutation-covariant attention layers are stacked first; a final
    permutation-invariant attention layer (built with ``reduce=True``) follows,
    and its output is projected to class probabilities.

    Args:
        D_in: width of the per-point input values fed to the up-projection
        num_classes: number of output classes
        D: working width of the network
        depth: number of non-reducing attention layers; ``depth + 1`` layers are built in total
        dilation: the hidden layers of the internal MLPs have ``int(D*dilation)`` units
        residual: if True, add a skip connection around every attention layer except the last
        nonlinearities: if True, apply an extra MLP after each attention layer
        merge_fun: forwarded to the attention layers as ``merge_fun``
        join_fun: forwarded to the attention layers as ``join_fun``
        rank: rank given to the final attention layer; earlier layers use ``max(2, rank)``
        dropout: dropout rate used inside the internal MLPs; a falsy value disables dropout

    """
    def __init__(self, D_in, num_classes, D=32, depth=2, dilation=2., residual=True,
                 nonlinearities=True, merge_fun='mean', join_fun='mean', rank=2, dropout=0):
        super().__init__()

        # plain hyperparameters
        self.D_in, self.num_classes, self.D = D_in, num_classes, D
        self.depth, self.dilation = depth, dilation
        self.residual, self.nonlinearities = residual, nonlinearities
        self.rank, self.dropout = rank, dropout
        self.GAANet_kwargs = {'merge_fun': merge_fun, 'join_fun': join_fun}

        # input and output projections
        self.up_project = pt.nn.Linear(D_in, self.D)
        self.class_project = pt.nn.Linear(self.D, self.num_classes)

        self.make_attention_nets()

        # optional per-layer MLPs applied to the output of each attention layer
        if self.nonlinearities:
            mlps = []
            for _ in range(self.depth + 1):
                mlps.append(self.make_value_net(self.D))
            self.nonlin_mlps = pt.nn.ModuleList(mlps)
        else:
            self.nonlin_mlps = []

    def make_attention_nets(self):
        """Build the score networks, value networks and attention layers."""
        num_layers = self.depth + 1

        def value_width(index):
            # only the final layer of a rank-1 network takes a single input per tuple
            if index == self.depth and self.rank == 1:
                return 1
            return 2

        score_nets = []
        for _ in range(num_layers):
            score_nets.append(self.make_score_net())
        self.score_nets = pt.nn.ModuleList(score_nets)

        value_nets = []
        for index in range(num_layers):
            value_nets.append(self.make_value_net(value_width(index)))
        self.value_nets = pt.nn.ModuleList(value_nets)

        final_index = len(self.score_nets) - 1
        attention_layers = []
        for index, (score_net, value_net) in enumerate(zip(self.score_nets, self.value_nets)):
            is_final = index == final_index
            # intermediate layers always work with at least pairs of points
            layer_rank = self.rank if is_final else max(2, self.rank)
            attention_layers.append(gala.VectorAttention(
                self.D, score_net, value_net,
                reduce=is_final, rank=layer_rank, **self.GAANet_kwargs))
        self.att_nets = pt.nn.ModuleList(attention_layers)

    def make_score_net(self):
        """Return an MLP mapping ``D`` features to one scalar score."""
        hidden = int(self.D*self.dilation)
        maybe_dropout = [pt.nn.Dropout(self.dropout)] if self.dropout else []
        return pt.nn.Sequential(
            pt.nn.Linear(self.D, hidden),
            *maybe_dropout,
            pt.nn.ReLU(),
            pt.nn.Linear(hidden, 1),
        )

    def make_value_net(self, D_in, D_out=None):
        """Return a layer-normalized MLP mapping ``D_in`` features to ``D_out``.

        A falsy ``D_out`` (the default) selects the working width ``D``.
        """
        if not D_out:
            D_out = self.D
        hidden = int(self.D*self.dilation)
        maybe_dropout = [pt.nn.Dropout(self.dropout)] if self.dropout else []
        return pt.nn.Sequential(
            pt.nn.Linear(D_in, hidden),
            pt.nn.LayerNorm(hidden),
            *maybe_dropout,
            pt.nn.ReLU(),
            pt.nn.Linear(hidden, D_out),
        )

    def forward(self, x):
        """Return class probabilities for the input ``x = (r, v)``.

        ``r`` is handed unchanged to every attention layer, while ``v`` is
        up-projected and then transformed layer by layer. The result has been
        passed through a softmax over the last axis.
        """
        points, values = x

        hidden = self.up_project(values)

        final_index = len(self.att_nets) - 1
        for index, attention in enumerate(self.att_nets):
            skip = hidden
            hidden = attention((points, hidden))
            if self.nonlinearities:
                hidden = self.nonlin_mlps[index](hidden)
            # no skip connection around the final layer
            if self.residual and index != final_index:
                hidden = hidden + skip

        scores = self.class_project(hidden)
        normalize = pt.nn.Softmax(-1)
        return normalize(scores)

class GAANetVectorRegressor(GAANetClassifier):
    """Regress a (geometric) vector from a neighborhood.

    ``depth`` attention layers built with ``reduce=False`` first update the
    values stored at each vertex. A final ``Vector2VectorAttention`` layer,
    built with ``reduce=True``, then produces the output; it is described by
    the original author as permutation-invariant and rotation-covariant.

    The constructor arguments are those of ``GAANetClassifier``.
    ``class_project`` always reads as ``None`` here and ``forward`` never
    applies a class projection.

    """
    def make_attention_nets(self):
        """Build the score/value networks, the scale network and the attention layers."""
        num_layers = self.depth + 1

        def value_width(index):
            # only the final layer of a rank-1 network takes a single input per tuple
            if index == self.depth and self.rank == 1:
                return 1
            return 2

        score_nets = []
        for _ in range(num_layers):
            score_nets.append(self.make_score_net())
        self.score_nets = pt.nn.ModuleList(score_nets)

        value_nets = []
        for index in range(num_layers):
            value_nets.append(self.make_value_net(value_width(index)))
        self.value_nets = pt.nn.ModuleList(value_nets)

        # maps D features to the single scalar handed to the final layer
        self.final_scale_net = self.make_value_net(self.D, 1)

        # all layers but the last keep one value per point
        hidden_rank = max(self.rank, 2)
        attention_layers = []
        for index in range(self.depth):
            attention_layers.append(gala.VectorAttention(
                self.D, self.score_nets[index], self.value_nets[index],
                reduce=False, rank=hidden_rank, **self.GAANet_kwargs))

        # the last layer turns the per-point values into the output vector
        attention_layers.append(gala.Vector2VectorAttention(
            self.D, self.score_nets[-1], self.value_nets[-1], self.final_scale_net,
            reduce=True, rank=self.rank, **self.GAANet_kwargs))
        self.att_nets = pt.nn.ModuleList(attention_layers)

    # NOTE(review): torch.nn.Module.__setattr__ appears to register Module values
    # directly in ``_modules`` without consulting property setters, so the Linear
    # layer assigned by the base constructor may still be registered (and show up
    # in ``parameters()``) even though this property hides it -- confirm.
    @property
    def class_project(self):
        """Always ``None``: the regressor has no class projection."""
        return None

    @class_project.setter
    def class_project(self, value):
        """Ignore the assigned value."""
        pass

    def forward(self, x):
        """Return the output of the final attention layer for ``x = (r, v)``.

        ``r`` is handed unchanged to every attention layer, while ``v`` is
        up-projected and then transformed layer by layer.
        """
        points, values = x

        hidden = self.up_project(values)

        final_index = len(self.att_nets) - 1
        for index, attention in enumerate(self.att_nets):
            skip = hidden
            hidden = attention((points, hidden))
            if index == final_index:
                # the output of the final layer is returned as-is:
                # no extra MLP and no skip connection
                continue
            if self.nonlinearities:
                hidden = self.nonlin_mlps[index](hidden)
            if self.residual:
                hidden = hidden + skip

        return hidden

# Example applications

In [None]:
import numpy as np
import itertools

# Nearest neighbors for the NaCl structure type: all points of a 3x3x3 simple
# cubic grid, with particle types alternating between adjacent sites
pos = np.array(list(itertools.product([-1, 0, 1], repeat=3)))
# drop the central site at the origin (its type would be 0)
pos = pos[~np.all(pos == 0, axis=-1)]

# particle type of each site: parity of the sum of its integer coordinates
types = (pos.sum(axis=-1) % 2).astype(np.int32)
max_types = 2

In [None]:
# Optional: a very simple visualization of the neighborhood, colored by particle
# type. It needs the plato package, so don't worry if you don't have plato installed.
import plato
import plato.imp as imp

imp.spheres(
    positions=pos,
    colors=plato.cmap.cubehelix(.3 + .4*types.astype(np.float32)),
)
imp.show(backend='zdog', zoom=8, features={'ambient_light': .5})

## Task 1: Identify warped neighborhoods (rotation invariant)

This task randomly squishes some neighborhoods along the 3 coordinate axes; the network learns to distinguish these warped neighborhoods from ones that only have random noise added to them.

In [None]:
def encode_types(ti, tj, max_types):
    """Encode pairs of particle types as symmetric and antisymmetric one-hot features.

    Args:
        ti: integer type index (or indices) of the first particle of each pair
        tj: integer type index (or indices) of the second particle of each pair
        max_types: number of distinct types, i.e. the length of each one-hot vector

    Returns:
        Array whose last axis holds ``onehot(ti) - onehot(tj)`` followed by
        ``onehot(ti) + onehot(tj)``, so its length is ``2*max_types``. ``ti``
        and ``tj`` are broadcast against each other.
    """
    identity = np.eye(max_types)
    first = identity[ti]
    second = identity[tj]
    difference = first - second
    total = first + second
    return np.concatenate([difference, total], axis=-1)

# Build the Task 1 dataset: 512 noisy copies of the reference neighborhood, about
# half of which are additionally squished (label 1); the rest are left alone (label 0).
xs, ts, ys = [], [], []
for _ in range(512):
    x = pos.astype(np.float32)
    # Draw independent noise for every coordinate of every point. Without
    # `size`, a single scalar is drawn and the whole neighborhood is merely
    # shifted by the same amount along each axis.
    x += np.random.normal(scale=5e-2, size=x.shape)
    ti = 0
    tj = types.copy()
    y = 0 # unsquished

    # randomly swap the two particle types
    if np.random.rand() < .5:
        ti = 1 - ti
        tj = 1 - tj

    # randomly rescale each coordinate axis by its own factor
    if np.random.rand() < .5:
        squish = np.random.normal(1, .25, 3)
        x *= squish
        y = 1

    xs.append(x)
    ts.append(encode_types([ti], tj, max_types))
    ys.append(y)

xs = np.array(xs)
ts = np.array(ts)
ys = np.array(ys)

xs = pt.tensor(xs, dtype=pt.get_default_dtype())
ts = pt.tensor(ts, dtype=pt.get_default_dtype())
# class indices must be int64 for the classification losses; numpy's default
# integer type is platform-dependent, so request it explicitly
ys = pt.tensor(ys, dtype=pt.long)

In [None]:
# Model, optimizer and train/validation split for Task 1
batch_size = 32

model = GAANetClassifier(
    ts.shape[-1], 2, 16,
    depth=2, rank=1, merge_fun='mean', join_fun='mean', dropout=0.5)
optim = pt.optim.Adam(model.parameters())

# start index of every batch; the first quarter of the batches is held out for
# validation and the remainder is used for training
batches = np.arange(0, len(xs), batch_size)
holdout = len(batches)//4
val_batches = batches[:holdout]
train_batches = batches[holdout:]

def loop(batches, train=True):
    """Run one pass of the Task 1 classifier over the given batches.

    Works on the module-level ``model``, ``optim``, ``xs``, ``ts``, ``ys`` and
    ``batch_size``.

    Args:
        batches: array of batch start indices into the dataset; shuffled in place
        train: if True, put the model in training mode and take an optimizer
            step per batch; otherwise put it in evaluation mode and only measure

    Returns:
        List with one ``(loss, accuracy)`` tuple of floats per batch.
    """
    np.random.shuffle(batches)
    if train:
        model.train()
    else:
        model.eval()

    # The model already returns softmax probabilities, so the matching loss is
    # the negative log-likelihood of their logarithm. CrossEntropyLoss expects
    # unnormalized logits and would apply a second (log-)softmax on top.
    criterion = pt.nn.NLLLoss()

    stats = []
    # gradients are only needed when the weights are going to be updated
    with pt.set_grad_enabled(bool(train)):
        for batch_start in batches:
            batch = slice(batch_start, batch_start + batch_size)
            pred = model((xs[batch], ts[batch]))
            # clamp away exact zeros so that the logarithm stays finite
            log_pred = pred.clamp_min(pt.finfo(pred.dtype).tiny).log()
            loss = criterion(log_pred, ys[batch])

            if train:
                optim.zero_grad()
                loss.backward()
                optim.step()

            acc = pred.argmax(-1).eq(ys[batch]).double().mean(dim=0).item()
            stats.append((loss.item(), acc))
    return stats

# Train for 32 epochs; after each training pass, evaluate on the held-out
# batches and print the mean loss and accuracy of both passes
for epoch in range(32):
    for label, split, training in (
            ('train loss, accuracy', train_batches, True),
            ('val loss, accuracy', val_batches, False)):
        stats = loop(split, training)
        print(label, *np.mean(stats, axis=0))


## Task 2: Fill in the blanks (rotation covariant)

This network learns to predict a vector that has been randomly removed from the input set of neighboring points.

In [None]:
# Build the Task 2 dataset: 512 noisy copies of the reference neighborhood, each
# with one randomly chosen point removed; the removed point is the regression target.
xs, ts, ys = [], [], []
for _ in range(512):
    x = pos.astype(np.float32)
    # Draw independent noise for every coordinate of every point. Without
    # `size`, a single scalar is drawn and the whole neighborhood is merely
    # shifted by the same amount along each axis.
    x += np.random.normal(scale=5e-2, size=x.shape)
    ti = 0
    tj = types.copy()
    # the upper bound of np.random.randint is exclusive, so pass the number of
    # points itself to make every point (including the last) eligible
    index_to_take = np.random.randint(0, pos.shape[0])

    # randomly swap the two particle types
    if np.random.rand() < .5:
        ti = 1 - ti
        tj = 1 - tj

    # take out the selected point together with its type encoding
    y = x[index_to_take].copy()
    x = np.concatenate([x[:index_to_take], x[index_to_take + 1:]], axis=0)
    t = encode_types([ti], tj, max_types)
    t = np.concatenate([t[:index_to_take], t[index_to_take + 1:]], axis=0)

    xs.append(x)
    ts.append(t)
    ys.append(y)

xs = np.array(xs)
ts = np.array(ts)
ys = np.array(ys)

xs = pt.tensor(xs, dtype=pt.get_default_dtype())
ts = pt.tensor(ts, dtype=pt.get_default_dtype())
ys = pt.tensor(ys, dtype=pt.get_default_dtype())

In [None]:
# Model, optimizer and train/validation split for Task 2
batch_size = 32

model = GAANetVectorRegressor(
    ts.shape[-1], 2, 32,
    depth=2, rank=1, merge_fun='mean', join_fun='mean')
optim = pt.optim.Adam(model.parameters())

# start index of every batch; the first quarter of the batches is held out for
# validation and the remainder is used for training
batches = np.arange(0, len(xs), batch_size)
holdout = len(batches)//4
val_batches = batches[:holdout]
train_batches = batches[holdout:]

def loop(batches, train=True):
    """Run one pass of the Task 2 regressor over the given batches.

    Works on the module-level ``model``, ``optim``, ``xs``, ``ts``, ``ys`` and
    ``batch_size``.

    Args:
        batches: array of batch start indices into the dataset; shuffled in place
        train: if True, put the model in training mode and take an optimizer
            step per batch; otherwise put it in evaluation mode and only measure

    Returns:
        List with the mean squared error of each batch, as floats.
    """
    np.random.shuffle(batches)
    if train:
        model.train()
    else:
        model.eval()

    criterion = pt.nn.MSELoss()

    stats = []
    # gradients are only needed when the weights are going to be updated
    with pt.set_grad_enabled(bool(train)):
        for batch_start in batches:
            batch = slice(batch_start, batch_start + batch_size)
            pred = model((xs[batch], ts[batch]))
            loss = criterion(pred, ys[batch])

            if train:
                optim.zero_grad()
                loss.backward()
                optim.step()

            stats.append(loss.item())
    return stats

# Train for 16 epochs; after each training pass, evaluate on the held-out
# batches and print the mean loss of both passes
for epoch in range(16):
    for label, split, training in (
            ('train', train_batches, True),
            ('val', val_batches, False)):
        stats = loop(split, training)
        print(label, np.mean(stats, axis=0))
