[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Structure%20identification%20using%20keras.ipynb)

In [None]:
%%sh
# Install the notebook's dependencies, but only when running on Colab
# (COLAB_GPU is set there); on other machines this cell does nothing.
if [ -n "$COLAB_GPU" ]; then
    # Workflow stages, dataset generators and analysis helpers
    pip install flowws-keras-geometry flowws-keras-experimental pyriodic-aflow freud-analysis keras-gtar
    # The attention layers themselves, straight from the repository
    pip install git+https://github.com/klarh/geometric_algebra_attention
fi

In [None]:
# More colab-specific setup: make sure the pyriodic unit-cell database
# actually has rows before relying on it further down the notebook.
import pyriodic
for (cell_count,) in pyriodic.db.query('select count(name) from unit_cells'):
    if cell_count != 0:
        continue
    # An empty table means pyriodic was not initialized properly.
    raise RuntimeError("""The colab import machinery sometimes makes pyriodic not get
        initialized correctly. Restarting the runtime should make it work properly.""")

In [None]:
import flowws
from flowws import Argument as Arg
import numpy as np
from tensorflow import keras

from geometric_algebra_attention.keras import MomentumNormalization, VectorAttention

@flowws.add_stage_arguments
class CrystalStructureClassification(flowws.Stage):
    """Geometric-attention classifier for the structure identification task.

    Stacks ``VectorAttention`` blocks over a point cloud (``x_in``) and
    per-point values (``v_in``), then applies a final attention layer and a
    small MLP head with a softmax over ``scope['num_classes']`` classes.
    The goal is to classify local environments of crystal structures in a
    rotation-invariant manner.

    """

    ARGS = [
        Arg('rank', None, int, 2,
            help='Degree of correlations (n-vectors) to consider'),
        Arg('n_dim', '-n', int, 32,
            help='Working dimensionality of point representations'),
        Arg('dilation', None, float, 2,
            help='Working dimension dilation factor for MLP components'),
        Arg('merge_fun', '-m', str, 'mean',
            help='Method to merge point representations'),
        Arg('join_fun', '-j', str, 'mean',
            help='Method to join invariant and point representations'),
        Arg('dropout', '-d', float, 0,
            help='Dropout rate to use, if any'),
        Arg('n_blocks', '-b', int, 2,
            help='Number of deep blocks to use'),
        Arg('block_nonlinearity', None, bool, True,
            help='If True, add a nonlinearity to the end of each block'),
        Arg('residual', '-r', bool, True,
            help='If True, use residual connections within blocks'),
        Arg('activation', '-a', str, 'relu',
            help='Activation function to use inside the network'),
        Arg('final_activation', None, str, 'relu',
            help='Final activation function to use within the network'),
    ]

    def run(self, scope, storage):
        """Assemble the Keras graph and register it in ``scope``.

        Reads ``scope['x_train']`` (a pair whose first elements fix the two
        input shapes) and ``scope['num_classes']``. Writes
        ``input_symbol``, ``output``, ``loss``, ``attention_model`` and
        ``invariant_model``, and appends ``'accuracy'`` to ``metrics``.
        """
        args = self.arguments
        width = args['n_dim']
        hidden_width = int(np.round(width*args['dilation']))
        dropout_rate = args.get('dropout', 0)

        def mlp(out_width, normalize=False):
            # [MomentumNormalization] -> Dense -> activation -> [Dropout]
            # -> Dense(out_width). Layers are created in exactly this order
            # so that Keras' auto-generated layer names stay stable.
            stack = [MomentumNormalization()] if normalize else []
            stack.append(keras.layers.Dense(hidden_width))
            stack.append(keras.layers.Activation(args['activation']))
            if dropout_rate:
                stack.append(keras.layers.Dropout(dropout_rate))
            stack.append(keras.layers.Dense(out_width))
            return keras.models.Sequential(stack)

        def attention(reduce, **layer_kwargs):
            # The scalar score network is built before the value network,
            # keeping the same layer creation sequence.
            return VectorAttention(
                mlp(1), mlp(width, normalize=True), reduce,
                rank=args['rank'], join_fun=args['join_fun'],
                merge_fun=args['merge_fun'], **layer_kwargs)

        def block(block_in):
            # x_in is bound below, before this helper is ever called.
            out = attention(False)([x_in, block_in])
            if args['block_nonlinearity']:
                out = mlp(width, normalize=True)(out)
            if not args['residual']:
                return out
            return out + block_in

        (xs, ts) = scope['x_train']
        x_in = keras.layers.Input(xs[0].shape)
        v_in = keras.layers.Input(ts[0].shape)

        hidden = keras.layers.Dense(width)(v_in)
        for _ in range(args['n_blocks']):
            hidden = block(hidden)

        # Final attention layer (reduce flag set to True). It also returns
        # its invariants and attention weights, which feed the two
        # auxiliary inspection models registered below.
        (pooled, invariants, weights) = attention(True, name='final_attention')(
            [x_in, hidden], return_invariants=True, return_attention=True)

        head = keras.layers.Dense(hidden_width, name='final_mlp')(pooled)
        if dropout_rate:
            head = keras.layers.Dropout(dropout_rate)(head)
        head = keras.layers.Activation(args['final_activation'])(head)
        probabilities = keras.layers.Dense(
            scope['num_classes'], activation='softmax')(head)

        scope['input_symbol'] = [x_in, v_in]
        scope['output'] = probabilities
        scope['loss'] = 'sparse_categorical_crossentropy'
        scope['attention_model'] = keras.models.Model([x_in, v_in], weights)
        scope['invariant_model'] = keras.models.Model([x_in, v_in], invariants)
        scope.setdefault('metrics', []).append('accuracy')


In [None]:
import flowws
from flowws_keras_geometry.data import PyriodicDataset
from flowws_keras_experimental import InitializeTF, Train

# Crystal structures (pyriodic names) the classifier must tell apart.
crystal_structures = [
    "hP2-Mg",
    "cI2-W",
    "cF4-Cu",
    "cF8-C",
    "cF8-SZn",
    "cP46-Si",
    "cF136-Si",
    "cP2-ClCs",
]

# Stages run in order: TensorFlow setup, dataset generation, model
# construction, then training.
w = flowws.Workflow([
    InitializeTF(),
    PyriodicDataset(
        structures=crystal_structures,
        noise=[1e-3, 5e-2, 0.1],
        size=2048,
        num_neighbors=12,
        test_fraction=0.2,
        seed=13,
    ),
    CrystalStructureClassification(),
    Train(
        epochs=32,
        validation_split=0.25,
        reduce_lr=8,
        early_stopping=20,
        disable_tqdm=True,
    ),
])

scope = w.run()

In [None]:
import matplotlib.pyplot as pp

# Training history recorded by the Train stage.
hist = scope['log_quantities'][0][1]

# One figure per metric: (train key, validation key, y-axis label).
curves = [
    ('loss', 'val_loss', 'Loss'),
    ('accuracy', 'val_accuracy', 'Accuracy'),
]
for (panel, (train_key, val_key, ylabel)) in enumerate(curves):
    if panel > 0:
        pp.figure()
    pp.plot(hist[train_key], label='Train')
    pp.plot(hist[val_key], label='Val')
    pp.xlabel('Epoch')
    pp.ylabel(ylabel)
    pp.legend()