# Training the eMLP

In this tutorial, we will train a simple eMLP model on the first 10 molecules of the [eQM7 dataset](https://archive.materialscloud.org/record/2021.154). 

In [None]:
import os
# Silence TensorFlow's C++ log messages; this must be set before TensorFlow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
from glob import glob
import numpy as np
import tensorflow as tf

## Preprocessing data

[Extended xyz](https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#extxyz) files are the easiest format for storing and processing data for the eMLP. A single snapshot should look like this:

    10
    Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 energy=-1101.1872787241705 efield="-0.004991 0.000988 0.002404"
    C	0.998176	-0.002580	-0.004569	6	-0.013318	-0.000720	-0.000771
    H	2.142527	-0.002458	0.004496	1	-1.596604	0.002982	-0.005142
    H	0.616300	1.076179	0.004496	1	0.540086	-1.563694	-0.004782
    H	0.609230	-0.552938	0.920356	1	0.552016	0.805681	-1.339545
    H	0.624429	-0.531433	-0.948150	1	0.517872	0.755750	1.350292
    Es	0.998192	-0.002583	-0.004577	99	0.000000	0.000000	0.000000
    Es	0.753050	0.702150	-0.000837	99	0.000000	0.000000	0.000000
    Es	1.758252	-0.003409	-0.000813	99	0.000000	0.000000	0.000000
    Es	0.757369	-0.350852	-0.626559	99	0.000000	0.000000	0.000000
    Es	0.748588	-0.363237	0.597844	99	0.000000	0.000000	0.000000

This example describes methane, CH<sub>4</sub>, a molecule with 10 electrons, i.e. 5 electron pairs. The electron pairs are stored as *Einsteinium* atoms (atomic number 99) and are simply appended to the list of the other atoms. Note that in this example the core electron pair (the first `Es` entry, which sits on top of the carbon atom) is still included; it is filtered out later. Electric field information is included via the `efield` keyword in the comment line. In this tutorial, the data have already been split into a training and a validation set, stored in `data/train.xyz` and `data/validation.xyz`.
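To make the layout concrete, here is a minimal sketch that parses one snapshot in the format above and separates the nuclei from the electron pairs (Z = 99). It is not the eMLP reader: it assumes exactly the `Properties` column order shown above, and the helpers `parse_snapshot` and `filter_core_pairs` are hypothetical. The core-pair criterion (a center lying within 0.1 Å of a non-hydrogen nucleus) is also an assumption, chosen only to mimic what `filter_centers` is described as doing.

```python
import shlex
import numpy as np

CENTER_Z = 99  # electron pairs are stored as Einsteinium (Z = 99)

SNAPSHOT = '''10
Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 energy=-1101.1872787241705 efield="-0.004991 0.000988 0.002404"
C  0.998176 -0.002580 -0.004569  6 -0.013318 -0.000720 -0.000771
H  2.142527 -0.002458  0.004496  1 -1.596604  0.002982 -0.005142
H  0.616300  1.076179  0.004496  1  0.540086 -1.563694 -0.004782
H  0.609230 -0.552938  0.920356  1  0.552016  0.805681 -1.339545
H  0.624429 -0.531433 -0.948150  1  0.517872  0.755750  1.350292
Es 0.998192 -0.002583 -0.004577 99  0.000000  0.000000  0.000000
Es 0.753050  0.702150 -0.000837 99  0.000000  0.000000  0.000000
Es 1.758252 -0.003409 -0.000813 99  0.000000  0.000000  0.000000
Es 0.757369 -0.350852 -0.626559 99  0.000000  0.000000  0.000000
Es 0.748588 -0.363237  0.597844 99  0.000000  0.000000  0.000000
'''


def parse_snapshot(text):
    """Parse one extended-xyz frame with columns species, pos(3), Z, force(3)."""
    lines = text.strip().splitlines()
    n = int(lines[0])
    # shlex keeps the quoted efield vector together as one key=value token
    info = dict(item.split("=", 1) for item in shlex.split(lines[1]))
    rows = [line.split() for line in lines[2:2 + n]]
    pos = np.array([[float(x) for x in r[1:4]] for r in rows])
    z = np.array([int(r[4]) for r in rows])
    forces = np.array([[float(x) for x in r[5:8]] for r in rows])
    is_center = z == CENTER_Z
    return {
        "energy": float(info["energy"]),
        "efield": np.array([float(x) for x in info["efield"].split()]),
        "positions": pos[~is_center],
        "numbers": z[~is_center],
        "forces": forces[~is_center],
        "centers": pos[is_center],
    }


def filter_core_pairs(centers, positions, numbers, tol=0.1):
    """Drop centers within `tol` of a heavy nucleus (hypothetical core-pair criterion)."""
    heavy = positions[numbers > 1]
    if len(heavy) == 0:
        return centers
    dist = np.linalg.norm(centers[:, None, :] - heavy[None, :, :], axis=-1)
    return centers[dist.min(axis=1) > tol]


frame = parse_snapshot(SNAPSHOT)
valence = filter_core_pairs(frame["centers"], frame["positions"], frame["numbers"])
print(len(frame["centers"]), "centers,", len(valence), "valence pairs")
```

For real data files, ASE's `ase.io.read` handles general extended xyz input and is a more robust choice than hand parsing.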

The first step in training the eMLP is to convert all the data into a TensorFlow Record file (tfr file). This is easily done with the eMLP's `TFRWriter` class. Before doing so, we need to tell the eMLP which properties to look for in the extended xyz files. This is done via the `list_of_properties` variable:

In [None]:
list_of_properties = ['positions', 'numbers', 'centers', 'energy', 'forces', 'efield'] 

The `'centers'` entry denotes the electron pair positions. If the property `'center_forces'` is not given explicitly, the forces on the centers are assumed to be zero (as is the case in this example). Afterwards, we can convert both the training and the validation set.

In [None]:
from emlp.datasets import TFRWriter

writer = TFRWriter('train.tfr', list_of_properties = list_of_properties, reference = 'pbe0_aug-cc-pvtz', filter_centers = True)
writer.write_from_xyz('data/train.xyz')
writer.close()

writer = TFRWriter('validation.tfr', list_of_properties = list_of_properties, reference = 'pbe0_aug-cc-pvtz', filter_centers = True)
writer.write_from_xyz('data/validation.xyz')
writer.close()

Note that the eMLP prints the number of configurations stored in each tfr file, together with the total number of atoms and some energy statistics. Write down the number of configurations: this information is not stored in the tfr files themselves, but it is needed later when constructing the data sets.
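If the printed count gets lost, the number of frames can be recovered directly from the xyz file, since every frame starts with an atom-count line followed by a comment line and that many atom lines. A minimal sketch; the helper `count_frames` is hypothetical and not part of the eMLP:

```python
import io


def count_frames(handle):
    """Count frames in an (extended) xyz stream by skipping n_atoms + 1 lines per frame."""
    n_frames = 0
    for line in handle:
        if not line.strip():
            continue  # tolerate blank lines between frames
        n_atoms = int(line)
        for _ in range(n_atoms + 1):  # comment line + atom lines
            next(handle)
        n_frames += 1
    return n_frames


demo = io.StringIO("2\ncomment\nH 0 0 0\nH 0 0 0.74\n1\ncomment\nHe 0 0 0\n")
print(count_frames(demo))
```

On a real file this would be used as `with open('data/train.xyz') as f: count_frames(f)`. Note that the count must refer to the configurations actually written to the tfr file.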

The `TFRWriter` class has several arguments:
- `filter_centers`: (boolean) If True, filter out the core electron pairs.
- `reference`: If a number, subtract that constant value from all the energies. If a string, use the molecular fragments (H<sub>2</sub>, CH<sub>4</sub>, NH<sub>3</sub> and H<sub>2</sub>O) as a reference at the given level of theory, as described in eq. 32 of the [eMLP paper](https://doi.org/10.1021/acs.jctc.1c00978). So far, only `pbe0_aug-cc-pvtz` is supported.
- `per_atom_reference`: a number giving the reference energy per atomic core.
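For the numeric options, the subtraction is plain arithmetic. The sketch below illustrates it under two assumptions: `reference` is subtracted once per configuration, and `per_atom_reference` is multiplied by the number of nuclei (electron pairs excluded). The function name is hypothetical, and the fragment-based `pbe0_aug-cc-pvtz` option (eq. 32 of the paper) is not reproduced here.

```python
def referenced_energy(energy, numbers, reference=0.0, per_atom_reference=0.0):
    """Hypothetical illustration of the numeric reference options of TFRWriter."""
    n_nuclei = sum(1 for z in numbers if z != 99)  # exclude electron pairs (Z = 99)
    return energy - reference - per_atom_reference * n_nuclei


# methane-like composition: 1 C, 4 H, 2 electron-pair centers (toy values)
print(referenced_energy(-10.0, [6, 1, 1, 1, 1, 99, 99], reference=-2.0, per_atom_reference=-1.0))
```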

## Configuring the training procedure

First, one has to choose a TensorFlow [distributed training strategy](https://www.tensorflow.org/guide/distributed_training#types_of_strategies), which enables multi-GPU training. The `MirroredStrategy` is a good default and also works for single-GPU training.

In [None]:
strategy = tf.distribute.MirroredStrategy()

All the following objects should be created within the scope of this same strategy.

In [None]:
from emlp.longrange import LongrangeCoulomb, LongrangeEwald
with strategy.scope():
    longrange_compute = LongrangeCoulomb(cutoff = 16.5, sigma = 1.2727922061357857)

The `longrange_compute` object specifies the algorithm used to compute the electrostatic interactions within the eMLP. There are two possible choices. For nonperiodic systems, use the `LongrangeCoulomb` class with the arguments:
- `cutoff`: the maximum distance up to which the pairwise electrostatic interactions are computed.
- `sigma`: the width of the Gaussian charge of every particle. The default value of 1.272792 Å is the one used in the eMLP paper.
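For intuition on `sigma` and `cutoff`: two spherical Gaussian charge distributions with standard deviation σ, a distance r apart, interact through the damped potential q<sub>i</sub> q<sub>j</sub> erf(r / (2σ)) / r, which tends to the point-charge result at large r and stays finite as r → 0. The sketch below sums this over all pairs within `cutoff`. It assumes σ is the standard deviation of each Gaussian, omits self-energies, and uses units where the Coulomb constant is 1; these conventions may differ from the eMLP's internals.

```python
import math
import numpy as np


def gaussian_coulomb_energy(positions, charges, sigma=1.2727922061357857, cutoff=16.5):
    """Pairwise electrostatic energy of equal-width Gaussian charges (Coulomb constant = 1)."""
    positions = np.asarray(positions, dtype=float)
    energy = 0.0
    n = len(charges)
    for i in range(n):
        for j in range(i + 1, n):
            r = np.linalg.norm(positions[i] - positions[j])
            if r > cutoff:
                continue  # pairs beyond the cutoff are ignored
            if r < 1e-12:
                # analytic limit of erf(r / (2 sigma)) / r for r -> 0
                energy += charges[i] * charges[j] / (sigma * math.sqrt(math.pi))
            else:
                energy += charges[i] * charges[j] * math.erf(r / (2.0 * sigma)) / r
    return energy


print(gaussian_coulomb_energy([[0, 0, 0], [0, 0, 1.0]], [1.0, -1.0]))
```

The finite value at short range is what makes Gaussian charges convenient for electron pairs that can come close to nuclei.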

For periodic systems, use the `LongrangeEwald` class with the arguments:
- `real_space_cancellation`: (boolean) If True, the width of the screening charges takes the same value as the width of the Gaussian charges of the particles. In this case, the direct-space contribution of the Ewald summation vanishes, which increases the computational efficiency. We therefore recommend using this option.
- `sigma`: the width of the Gaussian charge of every particle. The default value of 1.272792 Å is the one used in the eMLP paper.
- `gcut`: the cutoff in reciprocal space. Defaults to 0.25.
- `cutoff`: the cutoff of the direct-space contribution. Defaults to 12 Å. Ignored when `real_space_cancellation` is True.
- `alpha`: the alpha parameter of the Ewald summation. Defaults to 4/15. Ignored when `real_space_cancellation` is True.
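Why `real_space_cancellation` removes the direct-space term can be checked numerically. Ewald summation splits the pair potential into a short-range direct-space part and a smooth part, erf(αr)/r, that is handled in reciprocal space. For Gaussian charges the pair potential is already erf(r/(2σ))/r (assuming σ is the standard deviation of each Gaussian), so choosing α = 1/(2σ) makes the direct-space remainder identically zero. This is a conceptual sketch, not the eMLP's Ewald implementation, and the relation between α and σ depends on that convention.

```python
import math


def direct_space_term(r, sigma, alpha):
    """Direct-space remainder (erf(r/(2 sigma)) - erf(alpha r)) / r of an Ewald split for Gaussian charges."""
    return (math.erf(r / (2.0 * sigma)) - math.erf(alpha * r)) / r


sigma = 1.2727922061357857
for r in (0.5, 2.0, 8.0):
    matched = direct_space_term(r, sigma, alpha=1.0 / (2.0 * sigma))  # screening width == charge width
    default = direct_space_term(r, sigma, alpha=4.0 / 15.0)           # independent alpha
    print(f"r = {r}: matched = {matched:.2e}, alpha = 4/15 -> {default:.4f}")
```

With matched widths only the reciprocal-space sum remains, which is why `cutoff` and `alpha` are ignored in that mode.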

In [None]:
from emlp.datasets import DataSet, DataAugmentation
with strategy.scope():
    train_data = DataSet(['train.tfr'], num_configs = 4000, cutoff = 4.0, longrange_compute = longrange_compute, batch_size = 64, float_type = 32, num_parallel_calls = 8, 
                          strategy = strategy, list_of_properties = list_of_properties, augment_data = DataAugmentation())
    validation_data = DataSet(['validation.tfr'], num_configs = 1000, cutoff = 4.0, longrange_compute = longrange_compute, batch_size = 64, float_type = 32, num_parallel_calls = 8, 
                              strategy = strategy, list_of_properties = list_of_properties, test = True)


Next, the training and validation data sets are initialized via the `DataSet` class from a list of tfr files. Here, it is important to specify the correct number of configurations stored in each data set via the `num_configs` argument (see the output printed when generating the tfr files); otherwise, the eMLP will not count the epochs correctly. The validation set needs the extra argument `test=True`. Other arguments include:
- `cutoff`: The cutoff distance of the short-range MLP part of the eMLP. Defaults to 4 Å. This value need not be equal to the cutoff radius of the long-range part.
- `longrange_compute`: the algorithm used for the long-range electrostatic computation. If `None`, no long-range interactions are computed.
- `batch_size`: the batch size used during training.
- `test`: (boolean) Set to True for the validation set and False for the training set.
- `num_parallel_calls`: The number of configurations preprocessed in parallel. A good default is the number of available CPU cores.
- `strategy`: The distributed strategy defined above.
- `list_of_properties`: The list of properties defined above.
- `augment_data`: The data augmentation instance. If None, no data augmentation is applied.

If data augmentation is needed, it can be enabled by creating a `DataAugmentation` instance, which has the following arguments:
- `delta_lower`: The minimum displacement of an electron pair. Defaults to 0.06 Å.
- `delta_upper`: The maximum displacement of an electron pair. Defaults to 0.12 Å.
- `k`: The *k* value in the eMLP paper (eq. 30), tuning the augmentation strength. Defaults to 2.
- `percentage`: The percentage of augmented configurations in every batch. Defaults to 0.1.
- `cutoff`: The maximum distance over which an augmented electron pair influences the forces on the other atoms. Defaults to 4 Å.
- `periodic`: (boolean) Set to True when working with periodic structures; a simple form of the minimum image convention is then applied. Defaults to False.
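A rough illustration of the displacement step: in a fraction of the configurations, one electron pair is moved in a uniformly random direction by a distance drawn between `delta_lower` and `delta_upper`. The uniform draw of the magnitude and the choice of a single displaced pair per configuration are assumptions; the sketch does not reproduce the *k*-dependent treatment of eq. 30, the force cutoff, or periodic images, and `augment_batch` is a hypothetical name.

```python
import numpy as np


def augment_batch(batch_centers, percentage=0.1, delta_lower=0.06, delta_upper=0.12, rng=None):
    """Displace one random electron pair in a fraction of the configurations of a batch.

    batch_centers: list of (n_centers, 3) arrays, one per configuration.
    Returns a new list and the indices of the augmented configurations.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_aug = int(round(percentage * len(batch_centers)))
    chosen = rng.choice(len(batch_centers), size=n_aug, replace=False)
    out = [c.copy() for c in batch_centers]  # leave the input untouched
    for idx in chosen:
        direction = rng.normal(size=3)
        direction /= np.linalg.norm(direction)  # uniform random direction on the sphere
        magnitude = rng.uniform(delta_lower, delta_upper)
        k = rng.integers(len(out[idx]))  # which electron pair to move
        out[idx][k] += magnitude * direction
    return out, chosen


batch = [np.zeros((4, 3)) for _ in range(20)]
augmented, indices = augment_batch(batch, rng=np.random.default_rng(0))
print("augmented configurations:", sorted(indices.tolist()))
```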

In [None]:
from emlp.schnet import SchNet
from emlp.reference import ConsistentFragmentsReference
with strategy.scope():
    model = SchNet(cutoff = 4., n_max = 32, num_layers = 4, start = 0.0, end = 4.0, num_filters = 128, num_features = 512, shared_W_interactions = False, float_type = 32, 
                   cutoff_transition_width = 0.5, reference = ConsistentFragmentsReference('pbe0_aug-cc-pvtz'), longrange_compute = longrange_compute)
    #model = SchNet.from_restore_file('model_dir/model_name_2.00', reference = ConsistentFragmentsReference('pbe0_aug-cc-pvtz'), longrange_compute = longrange_compute)

Next, the architecture of the MLP should be specified. Here, we choose the SchNet architecture but, in principle, any MLP is valid. One can tune the following arguments:
- `cutoff`: The cutoff of the short-range MLP. Should take the same value as the cutoff specified in the `DataSet` class.
- `n_max`: The number of radial features.
- `num_features`: The number of features for every particle.
- `num_filters`: The number of filters.
- `num_layers`: The number of layers, or interaction blocks.
- `start`: The distance of the first radial feature.
- `end`: The distance of the last radial feature.
- `cutoff_transition_width`: The transition width of the cutoff function.
- `shared_W_interactions`: Whether or not to share the same filter functions across all interaction blocks. Defaults to False.
- `longrange_compute`: the algorithm used for the long-range electrostatic computation. If None, no long-range interactions are computed.
- `reference`: If a number, use a constant reference during training. Otherwise, a reference instance can be given here to compute the reference energy (see below).
- `xla`: (boolean) Whether or not to use [XLA](https://www.tensorflow.org/xla) when training the model. Defaults to False.
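The radial arguments can be pictured with a common SchNet-style featurization: `n_max` Gaussians with centers evenly spaced from `start` to `end`, multiplied by a smooth cutoff that falls from 1 to 0 over the last `cutoff_transition_width` before `cutoff`. The Gaussian width (tied here to the grid spacing) and the cosine shape of the transition are assumptions for illustration; the eMLP's exact functional forms may differ.

```python
import numpy as np


def smooth_cutoff(r, cutoff=4.0, width=0.5):
    """1 below cutoff - width, 0 beyond cutoff, cosine switch in between."""
    r = np.asarray(r, dtype=float)
    x = np.clip((r - (cutoff - width)) / width, 0.0, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * x))


def radial_features(r, n_max=32, start=0.0, end=4.0, cutoff=4.0, width=0.5):
    """Gaussian radial features of distances r, returned with shape (..., n_max)."""
    centers = np.linspace(start, end, n_max)
    gamma = 1.0 / (centers[1] - centers[0]) ** 2  # width tied to the grid spacing (assumption)
    r = np.asarray(r, dtype=float)[..., None]
    return np.exp(-gamma * (r - centers) ** 2) * smooth_cutoff(r, cutoff, width)


features = radial_features([1.0, 3.9, 4.2])
print(features.shape, features[2].max())
```

The smooth switch ensures that features, and hence energies and forces, vary continuously as neighbors cross the cutoff.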

The reference instance can be one of the following:
- `ConstantReference(value = 0., per_atom = False)`: a constant reference energy (per atom when specified).
- `ConsistentFragmentsReference(label = 'pbe0_aug-cc-pvtz')`: molecular fragments are being used for the reference energy. See eq. 32 in the eMLP paper. Here, they are computed on the fly in every batch.
- `ConstantFragmentsReference(label = 'pbe0_aug-cc-pvtz')`: molecular fragments are being used for the reference energy. See eq. 32 in the eMLP paper. Here, they have a constant value.

When restarting from a pretrained model, or simply when resuming the training procedure, the MLP can be loaded as follows: 

    model = SchNet.from_restore_file('model_dir/model_name_2.00', reference = ConsistentFragmentsReference('pbe0_aug-cc-pvtz'), longrange_compute = longrange_compute) 
where `'model_dir/model_name_2.00'` is the location where the MLP is stored.

In [None]:
from emlp.learning_rate_manager import ExponentialDecayLearningRate
with strategy.scope():
    optimizer = tf.optimizers.Adam(3e-04)
    learning_rate_manager = ExponentialDecayLearningRate(initial_learning_rate = 3e-04, decay_rate = 0.5, decay_epochs = 300)

Here, the optimizer and learning rate scheduler are created. Any [TensorFlow optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) can be used. Two learning rate schedulers can be used:
- `ExponentialDecayLearningRate(initial_learning_rate = 3e-04, decay_rate = 0.5, decay_epochs = 300)`: An exponentially decaying learning rate. The learning rate starts at `initial_learning_rate` and decays by a factor `decay_rate` every `decay_epochs` epochs.
- `ConstantDecayLearningRate(initial_learning_rate = 1e-04, decay_factor = 0.5, min_learning_rate = 1e-07, decay_patience = 25)`: Starting from `initial_learning_rate`, the learning rate is multiplied by `decay_factor` whenever the validation loss has not decreased for `decay_patience` epochs. Training stops once the learning rate drops below `min_learning_rate`.
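Both schedules are easy to write down. The plain-Python sketch below is an illustration, not the eMLP classes: whether `ExponentialDecayLearningRate` decays continuously or in steps is an assumption (both variants are exposed through a hypothetical `staircase` flag), and the plateau logic assumes the patience counter resets after each decay.

```python
def exponential_decay(epoch, initial_learning_rate=3e-4, decay_rate=0.5, decay_epochs=300, staircase=False):
    """Learning rate after `epoch` epochs; staircase=True decays only at multiples of decay_epochs."""
    exponent = epoch // decay_epochs if staircase else epoch / decay_epochs
    return initial_learning_rate * decay_rate ** exponent


class PlateauDecay:
    """Illustrative plateau scheduler mirroring the arguments of ConstantDecayLearningRate."""

    def __init__(self, initial_learning_rate=1e-4, decay_factor=0.5, min_learning_rate=1e-7, decay_patience=25):
        self.lr = initial_learning_rate
        self.decay_factor = decay_factor
        self.min_learning_rate = min_learning_rate
        self.decay_patience = decay_patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, validation_loss):
        """Update after one epoch; returns False once training should stop."""
        if validation_loss < self.best:
            self.best = validation_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.decay_patience:
                self.lr *= self.decay_factor
                self.bad_epochs = 0  # restart the patience window after a decay
        return self.lr >= self.min_learning_rate


print(exponential_decay(300), exponential_decay(450, staircase=True))
```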

In [None]:
from emlp.losses import MSE, MAE
with strategy.scope():
    losses = [MSE('energy', scale_factor = 1., per_atom = True), MSE('forces', scale_factor = 1.), MSE('center_forces', scale_factor = 1.)]
    validation_losses = [MAE('energy', per_atom = True), MAE('forces', scale_factor = 1.), MAE('center_forces', scale_factor = 1.)]

To define the training and validation losses, one can use mean squared errors (MSE) or mean absolute errors (MAE). They are given as lists, where the `scale_factor` argument sets the relative weight of each term.
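As a sketch of how such a list could combine, assume that the total training loss is the `scale_factor`-weighted sum of the individual terms and that `per_atom=True` divides each configuration's energy error by its number of atoms before squaring. Both are assumptions about the eMLP's internals; the helpers below are hypothetical and use NumPy only.

```python
import numpy as np


def per_atom_energy_mse(e_pred, e_ref, n_atoms):
    """MSE of energy errors normalized by the number of atoms in each configuration."""
    err = (np.asarray(e_pred) - np.asarray(e_ref)) / np.asarray(n_atoms)
    return float(np.mean(err ** 2))


def force_mse(f_pred, f_ref):
    """MSE over all force components."""
    return float(np.mean((np.asarray(f_pred) - np.asarray(f_ref)) ** 2))


def total_loss(terms):
    """terms: iterable of (scale_factor, loss_value) pairs."""
    return sum(scale * value for scale, value in terms)


e_term = per_atom_energy_mse([-10.0, -20.0], [-9.0, -18.0], [5, 10])
f_term = force_mse(np.ones((2, 3)), np.zeros((2, 3)))
print(total_loss([(1.0, e_term), (0.5, f_term)]))
```

Normalizing energies per atom keeps molecules of different sizes on a comparable scale next to the per-component force errors.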

The `SaveHook` class specifies the save location and how frequently the model is saved.

In [None]:
from emlp.hooks import SaveHook
with strategy.scope():
    savehook = SaveHook(model, ckpt_name = 'model_dir/model_name', max_to_keep = 5, save_period = 1.0, history_period = 8.0,
                        npz_file = 'model_dir/model_name.npz')

The following arguments can be specified:
- `ckpt_name`: the location where the model (checkpoint files) is saved.
- `max_to_keep`: The maximum number of saves that are kept; older saves are deleted.
- `save_period`: The number of epochs after which the validation losses are calculated; **if the current validation losses are the lowest so far**, the model is saved.
- `history_period`: The number of epochs after which the model is saved irrespective of the current validation losses. Can be switched off by setting the value to `None`.
- `npz_file`: The location of an npz file in which all the losses, the time per epoch and other information are stored for producing graphs.
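The interplay of `save_period`, `history_period` and `max_to_keep` can be summarized in a few lines. The class below is a hypothetical re-implementation for illustration only: treating the periods as exact epoch multiples and applying `max_to_keep` only to the best-model saves are assumptions.

```python
from collections import deque


class SaveSchedule:
    """Hypothetical model of when SaveHook-like logic would write checkpoints."""

    def __init__(self, max_to_keep=5, save_period=1.0, history_period=8.0):
        self.save_period = save_period
        self.history_period = history_period
        self.best = float("inf")
        self.kept = deque(maxlen=max_to_keep)  # oldest best-model save drops out automatically
        self.history = []

    def on_epoch_end(self, epoch, validation_loss):
        """Record which saves would happen at the end of `epoch` (1-based)."""
        if epoch % self.save_period == 0 and validation_loss < self.best:
            self.best = validation_loss
            self.kept.append(epoch)
        if self.history_period is not None and epoch % self.history_period == 0:
            self.history.append(epoch)


schedule = SaveSchedule(max_to_keep=2, save_period=1.0, history_period=4.0)
for epoch, loss in enumerate([5.0, 4.0, 4.5, 3.0, 2.0, 2.5, 2.6, 2.7], start=1):
    schedule.on_epoch_end(epoch, loss)
print("best saves:", list(schedule.kept), "history saves:", schedule.history)
```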

Finally, the model can be trained.

In [None]:
from emlp.training import Trainer
with strategy.scope():
    trainer = Trainer(model, losses, train_data, validation_data, strategy = strategy, optimizer = optimizer, savehook = savehook, 
                      learning_rate_manager = learning_rate_manager, validation_losses = validation_losses)
    trainer.train(verbose = True, validate_first = False)