# 2: ``PyTorch`` Model Training

* Time to run the cells: ~ 

## Define training settings

In [None]:
import os
import numpy as np
import torch
import rholearn.io
from rholearn.features import lambda_soap_vector

root_dir = "/Users/joe.abbott/Documents/phd/code/qml/rho_learn/docs/example/azoswitch"
data_dir = os.path.join(root_dir, "data/")
run_dir = os.path.join(root_dir, "simulations/")
rholearn.io.check_or_create_dir(run_dir)

In [None]:
settings = {
    "io": {
        "data_dir": os.path.join(data_dir, "partitions/"),
        "run_dir": os.path.join(run_dir, "01_linear"),
        "coulomb": os.path.join(data_dir, "coulomb_matrices.npz"),
    },
    "data_partitions": {
        "n_exercises": 2,
        "n_subsets": 3,
    },
    "torch": {
        "requires_grad": True,
        "dtype": torch.float64,
        "device": torch.device("cpu"),
    },
    "model": {
        "type": "linear",
        "args": {
            # "hidden_layer_widths": [8, 8, 8],
            # "activation_fn": "SiLU"
        },
    },
    "optimizer": {
        "algorithm": torch.optim.LBFGS,
        "args": {
            "lr": 0.25,
        },
    },
    "loss": {
        "fn": "CoulombLoss",
        "args": {
            # "reduction": "sum",
        },
    },
    "training": {
        "n_epochs": 30,
        "save_interval": 15,
        "restart_epoch": None,
    },
}

Before going any further, it is important to set the torch default dtype. This
needs to be consistent across all operations, otherwise dtype mismatches between
tensors can cause errors. Typically, we want the precision of 64-bit floats, but
the torch out-of-box default is 32-bit, so we need to explicitly set it to 64-bit
here.
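
As a quick illustration (a minimal sketch, not part of ``rholearn``), the
snippet below shows how the default dtype affects newly created floating-point
tensors. It restores the previously active default at the end so that it has no
side effects.

```python
import torch

previous_default = torch.get_default_dtype()

# With a 32-bit default, new floating-point tensors are float32
torch.set_default_dtype(torch.float32)
x32 = torch.zeros(3)

# After changing the default, new floating-point tensors are float64
torch.set_default_dtype(torch.float64)
x64 = torch.zeros(3)

print(x32.dtype, x64.dtype)

# Restore whatever default was active before this sketch
torch.set_default_dtype(previous_default)
```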

In [None]:
# IMPORTANT!
torch.set_default_dtype(settings["torch"]["dtype"])

## Linear Model

In [None]:
import os
import rholearn.io
from rholearn.models import EquiModelGlobal

in_train = rholearn.io.load_tensormap_to_torch(
    os.path.join(settings["io"]["data_dir"], "in_train.npz"), **settings["torch"]
)
out_train = rholearn.io.load_tensormap_to_torch(
    os.path.join(settings["io"]["data_dir"], "out_train.npz"), **settings["torch"]
)

EquiModelGlobal?

In [None]:
linear_model = EquiModelGlobal(
    "linear",
    keys=in_train.keys,
    in_feature_labels={key: block.properties for key, block in in_train},
    out_feature_labels={key: block.properties for key, block in out_train},
)
list(linear_model.models.items())[:2]

## Nonlinear Model

![nonlinear_model_forward](../figures/nonlinear_architecture.png)

In addition to a linear model, a nonlinear model is also implemented in
``rholearn``. As in the global linear model, the global nonlinear model is a
collection of individual local models applied to each block. Each local model
makes a prediction on a given equivariant (i.e. either invariant or covariant)
block of the input TensorMap, indexed by a key.

The nonlinear local model architecture is shown in the figure above. Predictions
are made on an equivariant (blue) block, using its associated invariant to act
as a nonlinear multiplier. For instance, the equivariant block for carbon,
$\lambda = 3$, is passed to the forward method along with the invariant
($\lambda = 0$) block for carbon. The equivariant is passed through a linear
model and the invariant through a neural network of arbitrary architecture.
Then, element-wise multiplication of the two blocks is performed before passing
the result through a final linear output layer to get the electron density
prediction.

Applying nonlinear transformations to only the invariant ensures that
equivariance isn't broken. Upon element-wise multiplication, the component
vectors of the equivariant block are multiplied by a vector of constant size
(thanks to the h-stacking of the invariant, bottom-right of the figure), thus
retaining equivariance.

Performing such operations within a single custom PyTorch ``forward()`` method
allows the operations to be tracked, and therefore the gradients to be
calculated. This means that model training involves optimization of all the
weights and biases shown in the figure: in the linear input layer applied to the
equivariant, all layers of the neural network applied to the invariant, and the
linear output layer applied to the mixed block.
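
To make the data flow concrete, here is a minimal sketch of such a ``forward()``
method. It is not the ``rholearn`` implementation: the class name
``SketchLocalModel``, its argument names, and the tensor layout (samples,
components, features) are assumptions made purely for illustration. A random
orthogonal matrix stands in for the Wigner D-matrix when checking equivariance.

```python
import torch


class SketchLocalModel(torch.nn.Module):
    """Hypothetical local model: linear(equivariant) * NN(invariant) -> linear."""

    def __init__(self, in_features, in_invariant_features, width, out_features):
        super().__init__()
        kwargs = {"dtype": torch.float64}
        # No bias on the layers acting on the equivariant block
        self.input_layer = torch.nn.Linear(in_features, width, bias=False, **kwargs)
        # Nonlinear transformations are applied to the invariant only
        self.invariant_nn = torch.nn.Sequential(
            torch.nn.Linear(in_invariant_features, width, **kwargs),
            torch.nn.SiLU(),
            torch.nn.Linear(width, width, **kwargs),
        )
        self.output_layer = torch.nn.Linear(width, out_features, bias=False, **kwargs)

    def forward(self, equivariant, invariant):
        # equivariant: (n_samples, 2 * lambda + 1, in_features)
        # invariant:   (n_samples, 1, in_invariant_features)
        multiplier = self.invariant_nn(invariant)  # broadcasts over components
        mixed = self.input_layer(equivariant) * multiplier
        return self.output_layer(mixed)


def mix_components(matrix, tensor):
    """Apply a matrix to the component axis of a (samples, components, features) tensor."""
    return torch.einsum("ij,njf->nif", matrix, tensor)


torch.manual_seed(0)
model = SketchLocalModel(in_features=5, in_invariant_features=4, width=8, out_features=2)
equivariant = torch.randn(3, 7, 5, dtype=torch.float64)  # lambda = 3 -> 7 components
invariant = torch.randn(3, 1, 4, dtype=torch.float64)

# Random orthogonal matrix standing in for a Wigner D-matrix
mixing, _ = torch.linalg.qr(torch.randn(7, 7, dtype=torch.float64))

prediction = model(equivariant, invariant)
rotated_prediction = model(mix_components(mixing, equivariant), invariant)
is_equivariant = torch.allclose(rotated_prediction, mix_components(mixing, prediction))
print(prediction.shape, is_equivariant)
```

Because the linear layers act only on the feature axis and the multiplier is the
same for every component, mixing the components before or after the model gives
the same result.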

Let's build a global model by hand and look at just 2 of the individual local
models. As the keys of the TensorMap are ``('spherical_harmonics_l',
'species_center')``, there exists one invariant block for each chemical species
(i.e. ``species_center``: H (1), C (6), N (7), O (8), S (16)). As explained
above, these invariants are used as nonlinear multipliers to the equivariant
blocks, so the size of their features needs to be passed to the model via the
``in_invariant_features`` argument during initialization. In the figure above,
these values correspond to $q_{\text{in}}^{\text{inv}}$ (bottom left of the
figure).

The model architecture can also be controlled. The ``activation_fn`` applied
between consecutive linear layers can be specified, choosing from "Tanh",
"GELU", or "SiLU". The length of the list arg ``hidden_layer_widths`` controls
the number of (nonlinear, linear) layer pairs after the first input layer
(i.e. all the hidden layers), whilst the values in the list control their
widths.

In the cell below, we initialize the neural network with 3 pairs of layers, of
widths 8, 8, and 16. Run the cell and look for the ``(invariant_nn)``
``Sequential`` layer, containing alternating Linear and SiLU functions. Note
also how the ``EquiLocalModel`` for key
``('spherical_harmonics_l', 'species_center')`` == ``(0, 1)`` uses a bias on
the input and output linear layers, but the local model for ``(1, 1)`` doesn't.
This is because applying a bias breaks covariance for covariant blocks
($\lambda > 0$), but not for invariants. A bias is, however, applied **in the
neural network layers** of every local model, as only the invariant blocks
passed to the forward method go through the NN.

In [None]:
in_invariant_features_by_species = {
    specie: len(in_train.block(spherical_harmonics_l=0, species_center=specie).properties)
    for specie in np.unique(in_train.keys["species_center"])
}
in_invariant_features = {
    key: in_invariant_features_by_species[key[in_train.keys.names.index("species_center")]]
    for key in in_train.keys
}
nonlinear_model = EquiModelGlobal(
    "nonlinear",
    keys=in_train.keys,
    in_feature_labels={key: block.properties for key, block in in_train},
    out_feature_labels={key: block.properties for key, block in out_train},
    in_invariant_features=in_invariant_features,
    hidden_layer_widths=[8, 8, 16],
    activation_fn="SiLU",
)

list(nonlinear_model.models.items())[:2]

## Checking the Equivariance Condition

Before going any further, it is paramount that we check that our structural
representations and machine learning models are equivariant.

In order for a structural representation to be equivariant, the irreducible
spherical components that comprise it must transform like spherical harmonics.

Spherical harmonics have known behaviour under rotations, such that any spherical
component $\mu$ of order $\lambda$ transforms into new component $\mu'$ according to the
action of the Wigner D-Matrix of order $\lambda$, $D^{\lambda}_{\mu\mu'}$, which
is constructed for a given arbitrary rotation matrix in Cartesian space.

In order to check that our $\lambda$-SOAP feature vector is equivariant, we can
run the following test:

1. Take an ``.xyz`` file of a given structure in the training set
2. Build an ASE frame of this structure
3. Generate a random rotation matrix
4. Rotate the structure using the random rotation matrix, storing the new
   structure in a new ASE frame
5. Generate a $\lambda$-SOAP representation for the unrotated and rotated
   structures
6. For each $\lambda$ channel of the representation of the unrotated structure,
   extract a selection of $(2\lambda + 1)$-sized irreducible spherical component (ISC)
   vectors.
7. Rotate each of these ISC vectors using the Wigner D-Matrix constructed from
   the same random Euler angles of rotation
8. Check for exact equivalence between the rotated ISC vectors of the unrotated
   structure and the corresponding ISC vectors of the rotated structure.
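
The logic of steps 3 to 8 can be sketched on a toy example with plain NumPy.
This is not the ``rholearn`` check: it uses a made-up $\lambda = 1$ descriptor
(``toy_lambda1_feature``, a hypothetical helper), for which the real Wigner
D-matrix is simply the Cartesian rotation matrix with its rows and columns
reordered from $(x, y, z)$ to $(y, z, x)$. It also compares to numerical
precision rather than exactly.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_rotation_matrix(rng):
    """Random proper rotation (det = +1) from the QR decomposition of a random matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of the QR decomposition
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1  # flip one axis to turn a reflection into a rotation
    return q


def toy_lambda1_feature(positions, center=0, width=1.0):
    """Hypothetical lambda = 1 descriptor: Gaussian-weighted sum of neighbour
    vectors around one atom, ordered as (y, z, x) to mimic m = -1, 0, +1."""
    vectors = np.delete(positions, center, axis=0) - positions[center]
    weights = np.exp(-np.sum(vectors**2, axis=1) / (2 * width**2))
    cartesian = weights @ vectors  # shape (3,), ordered (x, y, z)
    return cartesian[[1, 2, 0]]


positions = rng.normal(size=(6, 3))  # toy "structure" of 6 atoms
rotation = random_rotation_matrix(rng)
rotated_positions = positions @ rotation.T  # rotate every atom

# lambda = 1 Wigner D-matrix in the real basis: reorder rows/columns to (y, z, x)
order = [1, 2, 0]
wigner_d1 = rotation[np.ix_(order, order)]

feature_unrotated = toy_lambda1_feature(positions)
feature_rotated = toy_lambda1_feature(rotated_positions)

# Rotating the descriptor of the unrotated structure should reproduce the
# descriptor of the rotated structure
is_equi = np.allclose(wigner_d1 @ feature_unrotated, feature_rotated)
print(is_equi)
```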

First, let's load a random structure from the training set, construct rotated
and unrotated ASE frames, and visualize them using chemiscope.

In [None]:
import ase.io
import numpy as np
import chemiscope
from rholearn import spherical

# Pick a random index between 0 and 9 (inclusive) as the test structure
structure_idx = np.random.randint(0, 10)

# Load the xyz file corresponding to this structure
with open(os.path.join(data_dir, "molecule_list.dat"), "r") as molecule_list:
    structure_xyz = molecule_list.read().splitlines()[structure_idx]

# Read xyz file into an ASE frame
unrotated = ase.io.read(os.path.join(data_dir, "xyz", structure_xyz))

# Generate a randomly rotated copy of the ASE frame
rotated, (alpha, beta, gamma) = spherical.rotate_ase_frame(unrotated)

# Visualize the frames using chemiscope
cs = chemiscope.show([unrotated, rotated], mode="structure")
display(cs)

Now generate $\lambda$-SOAP representations of the rotated and unrotated structures.

In [None]:
# Rascaline hypers
rascal_hypers = {
    "cutoff": 5.0,  # Angstrom
    "max_radial": 6,  # Exclusive
    "max_angular": 5,  # Inclusive
    "atomic_gaussian_width": 0.2,
    "radial_basis": {"Gto": {}},
    "cutoff_function": {"ShiftedCosine": {"width": 0.5}},
    "center_atom_weight": 1.0,
}

# Generate lambda-SOAP descriptors. We want to do this individually for the
# rotated and unrotated structures to keep the structure indices consistent
lsoap_unrotated = lambda_soap_vector(
    [unrotated], rascal_hypers, neighbor_species=[1, 6, 7, 8, 16]
)
lsoap_rotated = lambda_soap_vector(
    [rotated], rascal_hypers, neighbor_species=[1, 6, 7, 8, 16]
)

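# Convert tensors to torch
# NOTE: ``utils`` is not imported in the cells above; it is assumed here to be
# the ``rholearn.utils`` module
from rholearn import utils

lsoap_unrotated = utils.tensor_to_torch(lsoap_unrotated, **settings["torch"])
lsoap_rotated = utils.tensor_to_torch(lsoap_rotated, **settings["torch"])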

# Perform the equivariance check - this returns a bool
is_equi = spherical.check_equivariance(
    lsoap_unrotated,
    lsoap_rotated,
    lmax=rascal_hypers["max_angular"],
    alpha=alpha,
    beta=beta,
    gamma=gamma,
    n_checks_per_block=5000,
)
if is_equi:
    print("Our lambda-SOAP is equivariant!")
else:
    print("Oops, our lambda-SOAP is not equivariant...")

We also need to check that our model is equivariant. When passing input tensors
through a model, certain tensor operations contained within the model
architecture (no matter how simple or complex) can break equivariance. It would
be a shame to go through the effort of generating an equivariant representation
if something as simple as applying a bias in our linear model breaks
equivariance!

In order for our model to be equivariant, it must satisfy the equivariance
condition:

$\hat{R} \tilde{y}(A) = \tilde{y}(\hat{R} A)$

where $\hat{R}$ is an arbitrary rotation in the SO(3) group, $A$ is a trial
structural representation, and $\tilde{y}(A)$ is the output property (i.e.
electron density) predicted by the model on structure $A$.

In plainer words, the condition states (under the assumption that our structural
representation is equivariant) that our model is equivariant if predicting the
property (electron density) on an unrotated structure and then rotating it gives
**exactly** the same result as predicting on the rotated structure, provided the
same rotation is used in both cases. That is, the output property of the model
must transform equivariantly with the input structure.

Or, written in another way:

$\tilde{y}(A) = \hat{R}^{-1} \tilde{y}(\hat{R} A)$

If the equivariance condition does not hold, it means that an operation
somewhere in the workflow breaks equivariance. This could be either in the
generation of the structural representation or in the model prediction. Since we
just want to test our ML model for equivariance, we will take $A$ as the
$\lambda$-SOAP representation of our validation structure, but to check
equivariance of the whole workflow one could take $A$ as the starting ``.xyz``
file of the validation structure.


We can check this by doing the following:

1. Take the xyz file of the validation structure and rotate it with a
   random rotation matrix, $\hat{R}$
2. Generate a $\lambda$-SOAP representation, $\hat{R} A$, of the rotated structure
3. Use the model to make a prediction on the $\lambda$-SOAP of the rotated
   structure, getting a predicted electron density, $\tilde{y}(\hat{R} A)$
4. Perform the inverse rotation on this electron density to get
   $\hat{R}^{-1} \tilde{y}(\hat{R} A)$
5. Check that this quantity is exactly equivalent to the electron density
   prediction of the unrotated structure, $\tilde{y}(A)$
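
As a toy illustration of steps 3 to 5 (a sketch with made-up models, not the
``rholearn`` ones), consider a single $\lambda = 1$ block of shape (components,
features). A linear map acting on the feature axis commutes with a rotation
acting on the component axis, whereas adding a bias does not. Here a random
orthogonal matrix stands in for the Wigner D-matrix, and all function names are
hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)

n_components, n_in, n_out = 3, 4, 2  # lambda = 1 -> 3 components
weights = rng.normal(size=(n_in, n_out))
bias = rng.normal(size=n_out)


def linear_no_bias(block):
    return block @ weights


def linear_with_bias(block):
    return block @ weights + bias


# A random orthogonal matrix standing in for the Wigner D-matrix of the rotation
rotation, _ = np.linalg.qr(rng.normal(size=(n_components, n_components)))

block = rng.normal(size=(n_components, n_in))  # representation A
rotated_block = rotation @ block  # representation R A


def passes_check(model):
    # Compare y(A) with R^{-1} y(R A); for an orthogonal matrix R^{-1} = R^T
    return np.allclose(model(block), rotation.T @ model(rotated_block))


print(passes_check(linear_no_bias), passes_check(linear_with_bias))
```

The bias adds the same constant to every component, and a constant vector is
generally not preserved by a rotation, which is why the second model fails.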




First, we need to do some cleaning (similar to the first notebook) and padding of the
TensorMaps to make the dimensions consistent - don't worry too much about this.

In [None]:
from azoswitch_utils import clean_azoswitch_lambda_soap

# Pad with empty blocks
lsoap_rotated = utils.pad_with_empty_blocks(
    clean_azoswitch_lambda_soap(lsoap_rotated), in_train
)
lsoap_unrotated = utils.pad_with_empty_blocks(
    clean_azoswitch_lambda_soap(lsoap_unrotated), in_train
)

# Make predictions with the linear model on both structures
out_pred_linear_unrot = linear_model(lsoap_unrotated)
out_pred_linear_rot = linear_model(lsoap_rotated)

# Make predictions with the nonlinear model on both structures
out_pred_nonlin_unrot = nonlinear_model(lsoap_unrotated)
out_pred_nonlin_rot = nonlinear_model(lsoap_rotated)

Perform the equivariance check on both the linear and nonlinear model.

In [None]:
# Perform the equivariance check on the linear and nonlinear models
for i, (unrot, rot) in enumerate(
    [
        (out_pred_linear_unrot, out_pred_linear_rot),
        (out_pred_nonlin_unrot, out_pred_nonlin_rot),
    ]
):
    is_equi = spherical.check_equivariance(
        unrot,
        rot,
        lmax=rascal_hypers["max_angular"],
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        n_checks_per_block=None,  # None checks on all ISC vectors
    )
    if is_equi:
        print(f"Our {['linear', 'nonlinear'][i]} model is equivariant!")
    else:
        print(
            f"Oops, something in our {['linear', 'nonlinear'][i]} model is breaking equivariance..."
        )

Good stuff! Now that we've performed those checks, let's do some model training.

## Construct ``torch`` objects used in training

**ML Model**

PyTorch model training is based on torch tensor operations. In order to
interface with equistore and allow tracking of all the metadata useful in
atomistic ML, custom model classes have been built in ``rholearn`` to allow
predictions to be made on TensorMaps as a whole. The class ``EquiModelGlobal``
stores individual models for each input/output block in the data.

**Loss Function**

A function that calculates a difference metric between a predicted (or 'input')
and reference (or 'target') tensor. At the torch-equistore interface this is a
custom torch module that calculates this difference on the TensorMap level.

Currently implemented are the ``MSELoss`` (otherwise called L2 loss) and the
``CoulombLoss`` metrics. 

As detailed in the paper [__"Impact of quantum-chemical metrics on the machine
learning prediction of electron
density"__](https://aip.scitation.org/doi/10.1063/5.0055393), use of a
physically-inspired loss function such as the Coulomb repulsion metric can lead
to better model performance when predicting properties derived from the electron
density.
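
For intuition, the two metrics can be sketched on flat coefficient vectors with
NumPy. This is an illustrative assumption, not the ``rholearn`` implementation
(which works on TensorMaps): the Coulomb metric is written here as the quadratic
form $\Delta c^{T} J \Delta c$, where $\Delta c$ is the difference between
predicted and reference coefficients and $J$ is a matrix of Coulomb integrals
between basis functions. The function names and the random stand-in for $J$ are
hypothetical.

```python
import numpy as np


def squared_error_loss(c_pred, c_ref):
    """Sum of squared coefficient errors (L2 metric, 'sum' reduction)."""
    delta = c_pred - c_ref
    return float(delta @ delta)


def coulomb_loss(c_pred, c_ref, coulomb_matrix):
    """Quadratic form of the coefficient error with a Coulomb matrix."""
    delta = c_pred - c_ref
    return float(delta @ coulomb_matrix @ delta)


rng = np.random.default_rng(0)
c_ref = rng.normal(size=5)
c_pred = c_ref + 0.1 * rng.normal(size=5)

# Random symmetric positive-definite stand-in for a real Coulomb matrix
a = rng.normal(size=(5, 5))
coulomb_matrix = a @ a.T + 5 * np.eye(5)

print(squared_error_loss(c_pred, c_ref), coulomb_loss(c_pred, c_ref, coulomb_matrix))
```

With the identity matrix in place of $J$, the Coulomb form reduces to the
summed squared error, which makes clear that $J$ acts as a physically motivated
weighting of the coefficient errors.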

**Optimizer**

An algorithm, such as stochastic gradient descent (SGD) or L-BFGS, that uses the
gradients of the loss with respect to the model parameters to iteratively
minimize it.
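
One practical detail worth knowing: unlike most ``torch`` optimizers,
``torch.optim.LBFGS`` needs a closure passed to ``step()`` that re-evaluates the
loss and its gradients. The sketch below fits a toy linear model; the data and
model are made up for illustration and are unrelated to the ``rholearn``
training loop.

```python
import torch

torch.manual_seed(0)

# Toy regression problem: y = x @ true_weights
x = torch.randn(20, 3, dtype=torch.float64)
true_weights = torch.tensor([[1.0], [-2.0], [0.5]], dtype=torch.float64)
y = x @ true_weights

model = torch.nn.Linear(3, 1, bias=False, dtype=torch.float64)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.25)


def closure():
    # LBFGS may call this several times per step to re-evaluate the loss
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss


initial_loss = loss_fn(model(x), y).item()
for epoch in range(30):
    optimizer.step(closure)
final_loss = loss_fn(model(x), y).item()

print(initial_loss, final_loss)
```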

In [None]:
from rholearn.pretraining import construct_torch_objects

construct_torch_objects(settings)

Inspect the directory structure in the ``simulations/`` folder - it mirrors the
nested directory structure of the ``data/`` folder, but contains only torch
objects corresponding to the torch model ``model.pt``, the coulomb loss function
object ``loss_fn.pt`` and that of the test loss function ``loss_fn_test.pt``.

## Training a Linear Model

Though we have shuffled and partitioned our data ready for 3 learning exercises,
each with 4 training subsets, for computational brevity we will train on only
the number of exercises and subsets specified under ``"data_partitions"`` in the
``settings`` dict. As all training subsets (whether belonging to the same or
different learning exercises) are independent, model training can be performed
separately and, in principle, in parallel. Here we will perform subset training
sequentially.
In [None]:
# Define the range of exercises and subsets to learn on
exercises = range(settings["data_partitions"]["n_exercises"])
subsets = range(settings["data_partitions"]["n_subsets"])

As the test data is not dependent on the training subset, this can be loaded
first. The torch settings (i.e. requires_grad, device, dtype) from the settings
dict are used to load the TensorMaps to torch.

In [None]:
from rholearn.io import load_tensormap_to_torch

# Load the test data, which is independent of the training subdirectory
in_test = load_tensormap_to_torch(
    os.path.join(settings["io"]["data_dir"], "in_test.npz"), **settings["torch"]
)
out_test = load_tensormap_to_torch(
    os.path.join(settings["io"]["data_dir"], "out_test.npz"), **settings["torch"]
)

Now we iterate over the exercises and subsets and train models for each subset
sequentially. With the settings above (30 epochs, 2 exercises, 3 subsets),
training the linear model should take roughly 35 minutes.

In [None]:
# Runtime for 30 epochs, 2 exercises, 3 subsets:
# linear ~ 35 min
# nonlinear ~
import time
import numpy as np

from rholearn.pretraining import load_training_objects
from rholearn.training import train

for exercise in exercises:
    for subset in subsets:

        # Start timer
        t0 = time.time()

        # Define the training subdirectory
        train_dir = os.path.join(
            settings["io"]["run_dir"], f"exercise_{exercise}", f"subset_{subset}"
        )

        # Load training data and torch objects
        in_train, out_train, model, loss_fn, optimizer = load_training_objects(
            settings, exercise, subset, settings["training"]["restart_epoch"]
        )

        print(f"\nTraining in subdirectory {train_dir}")

        # Execute model training
        train(
            in_train=in_train,
            out_train=out_train,
            in_test=in_test,
            out_test=out_test,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            n_epochs=settings["training"]["n_epochs"],
            save_interval=settings["training"]["save_interval"],
            save_dir=train_dir,
            restart=settings["training"]["restart_epoch"],
        )

        # Report on timings
        dt = time.time() - t0
        num_epochs_run = (
            settings["training"]["n_epochs"]
            if settings["training"]["restart_epoch"] is None
            else settings["training"]["n_epochs"]
            - settings["training"]["restart_epoch"]
        )
        # TODO: write timing to log.txt
        print(
            f"\nTraining finished in {np.round(dt, 2)} s = {np.round(dt / num_epochs_run, 2)} s per epoch"
            + f"\n(Timed over {num_epochs_run} epochs, perhaps since restart)"
        )

## Training a Nonlinear Model

Now we can use the same notebook to train a nonlinear model, by changing only a
few lines of code. Make the following changes:

In the ``settings`` dict:

1. In the nested dict under key ``"io"``, change the run directory to "02_nonlinear":

    ``"run_dir": os.path.join(run_dir, "02_nonlinear"),``


2. Under key ``"model"``, change the type to "nonlinear" and uncomment the neural
   network args (here with hidden layer widths of 32):

    ```
    "type": "nonlinear",
    "args": {
        "hidden_layer_widths": [32, 32, 32],
        "activation_fn": "SiLU"
    },
    ```

Then run all the notebook cells again, in order.

After doing so, we can observe the model architecture for one of the blocks: