# Machine Learning Interatomic Potentials with Metatensor

This notebook shows how to train a very simple machine learning model predicting energies and forces,
built on top of metatensor and metatomic.

> **WARNING:**
>
> To ensure a reasonable run time for this tutorial, the dataset size, model complexity and training procedure have been reduced to their absolute minimum.
> If you want to train a model for your own research, you should make sure that the dataset size and model architecture are sufficient to get a stable and precise model.

In general, when training a machine learning model we need the following ingredients:

- a dataset, containing structures and their energies (and optionally forces, virials or stresses). Here we
  will use a dataset of conformers of 2-Propen-1-ol, with the energies and forces
  computed using [DFTB](https://dftbplus.org/).
- a model architecture, which defines how our machine learning model makes its predictions. Here
  we'll use a perceptron-based neural network taking the rotationally invariant SOAP power
  spectrum as input. This should be very close to the first generation of
  Behler-Parrinello NNs.
- an optimizer and a loss function, used inside the training loop to optimize the NN
  weights and ensure the model predictions match the DFTB calculations, at least at the training points.

In [None]:
from typing import Dict, List, Optional

import numpy as np
import matplotlib.pyplot as plt
import torch

torch.manual_seed(123456)

import ase.io  # read the dataset


In [None]:
# Read the data and extract energies and forces from ASE
frames = ase.io.read("../propenol_conformers_dftb.xyz", ":500")

energies = np.array([[f.info["dftb_energy_eV"]] for f in frames])
forces = np.vstack([f.arrays["dftb_forces_eV_per_Ang"] for f in frames])


## 2 - The machine learning model

In [None]:
from featomic.torch import SoapPowerSpectrum

from metatensor.torch import Labels, TensorBlock, TensorMap

import metatomic.torch as mta
import metatensor.torch as mts

from metatomic.torch import System

As is customary when using PyTorch, our model will be a class inheriting from
`torch.nn.Module`. In the `forward` function, we'll take a list of
`metatomic.torch.System`, compute the SOAP power spectrum for all atoms in each
system and then send this representation through a neural network. This will
give us per-atom energies, which we will then sum to get the total energy
of each system.

$$
E = \sum_i NN(\langle \alpha_1 \alpha_2 n_1 n_2 l |
\rho_i^2 \rangle)
$$

The same NN will be used regardless of the central atom species (this will be the
first possible improvement of this model later!).
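
The sum over atoms can be pictured with plain Python. This is only a conceptual sketch: `sum_per_system` and `per_atom_energies` are hypothetical names, the numbers are invented, and the notebook itself performs this reduction with `mts.sum_over_samples`.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def sum_per_system(per_atom_energies: List[Tuple[int, float]]) -> Dict[int, float]:
    """Sum per-atom energies into one total energy per system.

    Each entry is (system_index, atomic_energy); the values are illustrative.
    """
    totals: Dict[int, float] = defaultdict(float)
    for system_index, atomic_energy in per_atom_energies:
        totals[system_index] += atomic_energy
    return dict(totals)


# two toy systems: the first with three atoms, the second with two
per_atom_energies = [(0, -1.0), (0, -2.5), (0, -0.5), (1, -3.0), (1, -1.0)]
total_energies = sum_per_system(per_atom_energies)
```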

Featomic outputs the SOAP power spectrum in a maximally sparse format, where
each central species, species of the first neighbor $\alpha_1$, and species of
the second neighbor $\alpha_2$ are stored separately, minimizing the memory
usage and enabling varied treatments of different blocks. Here, we will simply
treat the neighbor species as [one-hot encodings](https://en.wikipedia.org/wiki/One-hot#Machine_learning_and_statistics), and move the central species into the
samples, so that all central atoms are handled identically.
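
To make the resulting feature size concrete, the sketch below counts the neighbor-species pairs and the length of the flattened power spectrum. It assumes the values used further down in this notebook (species H, C and O, with `max_angular` and `max_radial` both set to 5); the variable names are illustrative.

```python
from itertools import combinations_with_replacement

atomic_types = [1, 6, 8]  # H, C, O
max_angular = 5
max_radial = 5

# unordered pairs of neighbor species (i <= j), one group of features per pair
neighbor_pairs = list(combinations_with_replacement(atomic_types, 2))

# (max_angular + 1) angular channels times (max_radial + 1)^2 radial pairs,
# repeated for every pair of neighbor species
n_features = (max_angular + 1) * (max_radial + 1) ** 2 * len(neighbor_pairs)
```

With these settings there are 6 species pairs, which is the same count the model computes as `n_soap` for the input size of the NN.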

In [None]:
class SOAPModel(torch.nn.Module):
    def __init__(self, soap_parameters, atomic_types, energy_offset):
        super().__init__()

        self.energy_offset = torch.tensor(energy_offset)
        self.atomic_types = atomic_types

        self.soap_calculator = SoapPowerSpectrum(**soap_parameters)
        self.neighbor_atom_types = Labels(
            ["neighbor_1_type", "neighbor_2_type"], 
            torch.tensor([(i, j) for i in atomic_types for j in atomic_types if i <= j])
        )

        # Number of features produced by the SOAP calculator,
        # i.e. size of the input of the NN
        n_soap = (
            (soap_parameters["basis"]["max_angular"] + 1)
            * (soap_parameters["basis"]["radial"]["max_radial"] + 1) ** 2
            * len(self.neighbor_atom_types)
        )

    
        # we are using utilities from `metatensor-learn` to define the NN in a metatensor-compatible way
        # https://docs.metatensor.org/latest/learn/reference/nn/index.html#metatensor.learn.nn.ModuleMap
        self.soap_nn = mts.learn.nn.ModuleMap(
            in_keys=Labels("_", torch.tensor([[0]])),
            modules=[torch.nn.Sequential(
                # Definition of our NN: two hidden layers,
                # SiLU activation, 128-sized latent space
                torch.nn.Linear(
                    in_features=n_soap, out_features=128, bias=False, dtype=torch.float64
                ),
                torch.nn.SiLU(),
                torch.nn.Linear(
                    in_features=128, out_features=128, bias=False, dtype=torch.float64
                ),
                torch.nn.SiLU(),
                torch.nn.Linear(
                    in_features=128, out_features=1, bias=True, dtype=torch.float64
                ),
            )]
        )

    def forward(
        self,
        systems: List[System],
        selected_atoms: Optional[Labels] = None,
    ) -> torch.Tensor:        
        soap = self.soap_calculator(systems, selected_samples=selected_atoms)
        soap = soap.keys_to_properties(self.neighbor_atom_types)
        soap = soap.keys_to_samples("center_type")

        energies_per_atom = self.soap_nn(soap)
        energy = mts.sum_over_samples(energies_per_atom, ["atom", "center_type"])
        energy = energy.block().values

        return energy + self.energy_offset


Let's create our model!

To simplify the task of the NN, we will enforce a constant energy offset
corresponding to some arbitrary energy baseline (here, the mean energy of the
training set). 
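
The effect of the offset can be seen on a few numbers. The energies below are invented values of the same order of magnitude as the dataset, not actual DFTB results: once the mean is removed, the targets left for the NN are centered on zero.

```python
from statistics import mean

# illustrative total energies in eV (invented for this sketch)
toy_energies = [-290.5, -290.0, -289.5]

energy_offset = mean(toy_energies)

# what the NN effectively has to learn, since the offset is added back in `forward`
residuals = [energy - energy_offset for energy in toy_energies]
```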

| ![TASK](../img/clipboard.png) | Go back to the class definition above, and add the energy offset to the prediction |
|-------------------------------|------------------------------------------------------------------------------------|

In [None]:
SOAP_PARAMETERS = {
    "cutoff": {
        "radius": 3.5,
        "smoothing": {
            "type": "ShiftedCosine",
            "width": 0.2
        }
    },
    "density": {
        "type": "Gaussian",
        "width": 0.3
    },
    "basis": {
        "type": "TensorProduct",
        "max_angular": 5,
        "radial": {
            "type": "Gto",
            "max_radial": 5
        }
    }
}

energy_offset = energies.mean()
model = SOAPModel(
    SOAP_PARAMETERS,
    atomic_types=[1, 6, 8],
    energy_offset=energy_offset,
)


first_energy = model(mta.systems_to_torch(frames[:1], dtype=torch.float64))
if torch.abs(first_energy + 290) > 10:
    raise Exception(
        f"energy of the first structure is {first_energy.item()}, should be around -290. "
        "Please modify the forward function above! Hint: you can use self.energy_offset"
    )


Now we can create the tools to train the model:

In [None]:
# Let's start with the inputs (systems) and expected outputs (reference_energies) of our model
systems = mta.systems_to_torch(frames, dtype=torch.float64)

reference_energies = torch.tensor(energies)

# We'll need a loss function to compare the predictions of the model to the
# reference energies; let's use the mean squared error loss
mse_loss = torch.nn.MSELoss()

# the optimizer updates the weights of the model according to the gradients;
# a learning rate of 0.003 lets the model learn fast enough while preventing it
# from jumping around in parameter space
optimizer = torch.optim.AdamW(model.parameters(), lr=0.003)
epoch = -1


We can now run the training phase! We might have to run the loop multiple times to
reach a high enough accuracy. As a starting point, we'll stop once the
loss drops below 0.03, but feel free to come back and try to get the loss even
lower!
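
Since the energies are in eV, the mean squared error is in eV², so a loss of 0.03 corresponds to a root mean square error of roughly 0.17 eV per structure. The short sketch below spells this out; `mse` is a hypothetical helper written for illustration, and the toy numbers are invented.

```python
import math
from typing import Sequence


def mse(predicted: Sequence[float], reference: Sequence[float]) -> float:
    """Mean squared error, the quantity `torch.nn.MSELoss` computes by default."""
    assert len(predicted) == len(reference)
    return sum((p - r) ** 2 for p, r in zip(predicted, reference)) / len(reference)


# target loss from the text, in eV^2, converted to an RMSE in eV
target_rmse = math.sqrt(0.03)

# toy check of the helper on invented numbers
toy_loss = mse([1.0, 2.0], [1.0, 4.0])
```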


| ![TASK](../img/clipboard.png) | Run the training loop until the loss is below 0.03 |
|----------------------------|----------------------------------------------------|

In [None]:
start = epoch + 1

for epoch in range(start, start + 100):
    optimizer.zero_grad()  # reset the gradients of all parameters to zero

    predicted_energies = model(systems)  # run the model once

    loss = mse_loss(predicted_energies, reference_energies)  # compute the loss
    print(f"loss at epoch {epoch} is", loss.item())

    loss.backward()  # backpropagate from the loss, computing the gradients of all parameters
    optimizer.step()  # run one optimizer step, updating the parameters based on the gradients


In [None]:
if loss.item() > 0.03:
    raise Exception(
        "loss is still too high, please continue running the training loop"
    )


We can now check the energy prediction we are making against the reference
values.

In an actual research setting, you would also want to check the predictions your
model is making on a validation/hold-out set of structures, to prevent your model
from over-fitting to your training set.
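
A minimal sketch of such a split, assuming a simple random 80/20 partition of the structure indices. The helper name, the fraction and the seed are arbitrary choices made for illustration, and this split is not used in the rest of the notebook.

```python
import random
from typing import List, Tuple


def train_validation_split(
    n_structures: int, validation_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Randomly partition structure indices into training and validation sets."""
    indices = list(range(n_structures))
    random.Random(seed).shuffle(indices)
    n_validation = round(n_structures * validation_fraction)
    return indices[n_validation:], indices[:n_validation]


train_indices, validation_indices = train_validation_split(500)
```

The model would then be trained on the frames selected by `train_indices` only, with the loss on the frames selected by `validation_indices` monitored separately.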

In [None]:
predicted_energy = model(systems)

plt.scatter(energies, predicted_energy.detach().numpy())

x = [np.min(energies), np.max(energies)]
plt.plot(x, x, c="grey")

plt.title("energies")
plt.xlabel("reference / eV")
plt.ylabel("predicted / eV")
plt.show()


All the code in featomic, metatomic and metatensor is fully integrated with the torch
automatic differentiation framework, which allows us to compute the gradients of
any output with respect to any input. In particular, we can use this to also
predict the forces acting on the system:

In [None]:
# Convert to `metatomic.torch.System`, but now tracking
# gradients with respect to positions
systems_positions_grad = mta.systems_to_torch(frames, positions_requires_grad=True, dtype=torch.float64)

# make a new prediction
predicted_energy = model(systems_positions_grad)

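# extract the gradient of the prediction with backward propagation
# using `torch.autograd.grad`; forces are minus the gradient of the energy
# with respect to positions, hence the minus sign in `grad_outputs`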
predicted_forces = torch.autograd.grad(
    outputs=predicted_energy,
    inputs=[s.positions for s in systems_positions_grad],
    grad_outputs=-torch.ones_like(predicted_energy),
    create_graph=False,
    retain_graph=False,
)
predicted_forces = torch.vstack(predicted_forces)

plt.scatter(forces.flatten(), predicted_forces.detach().numpy().flatten())

x = [np.min(forces.flatten()), np.max(forces.flatten())]
plt.plot(x, x, c="grey")

plt.title("forces")
plt.xlabel("reference / eV/Å")
plt.ylabel("predicted / eV/Å")
plt.show()
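
The minus sign passed through `grad_outputs` reflects the definition of forces as the negative gradient of the energy with respect to positions. One way to sanity-check this convention without autograd is a central finite difference. The sketch below does this for an invented 1D harmonic energy, not for the SOAP model; all names are illustrative.

```python
from typing import Callable


def harmonic_energy(x: float, k: float = 2.0, x0: float = 1.0) -> float:
    """Toy energy E(x) = k/2 (x - x0)^2, used only for illustration."""
    return 0.5 * k * (x - x0) ** 2


def finite_difference_force(
    energy: Callable[[float], float], x: float, delta: float = 1e-5
) -> float:
    """Force F = -dE/dx estimated with a central finite difference."""
    return -(energy(x + delta) - energy(x - delta)) / (2 * delta)


x = 1.5
numerical_force = finite_difference_force(harmonic_energy, x)
analytical_force = -2.0 * (x - 1.0)  # -k (x - x0) with k = 2 and x0 = 1
```

The same kind of comparison could be made against the autograd forces of a real model by displacing one atomic coordinate at a time, at the cost of two energy evaluations per coordinate.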


> **NOTE:**
> 
> The forces predicted by this model are poor, for several reasons: the dataset is very small, the model architecture is minimal, and the model has been trained on energies only, not on forces. We will address some of these shortcomings in future tutorials.

## Exporting the model

Now that we have a reasonable model, let's export it! We'll need to define some
metadata about our model as well, so the MD engine knows which unit
conversions to make and what the model can do.

In [None]:
from metatomic.torch import (
    AtomisticModel, 
    System, 
    ModelOutput, 
    ModelMetadata, 
    ModelCapabilities, 
    ModelEvaluationOptions,
)


We'll need a class conforming to the interface expected by `AtomisticModel` (see [`ModelInterface`](https://docs.metatensor.org/metatomic/latest/torch/reference/models/export.html#metatomic.torch.ModelInterface)). In this API, the model receives as input a list of systems, the set of outputs requested by the engine,
and an optional selection of atoms. The model should then return these outputs as a dictionary of `TensorMap`.

In [None]:
class ExportWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        # model we are wrapping
        self.model = model

    def forward(
        self, 
        systems: List[System], 
        outputs: Dict[str, ModelOutput], 
        selected_atoms: Optional[Labels]
    ) -> Dict[str, TensorMap]:
        # check if the energy was even required
        if "energy" not in outputs:
            return {}
        
        # Run the model
        energy = self.model(systems, selected_atoms)

        # Return our prediction in a Dict[str, TensorMap]. Here there isn't much
        # metadata to attach to the output, but this will change if we are returning
        # per-atom energy, or more complex outputs (dipole moments, electronic density,
        # etc.)
        block = TensorBlock(
            values=energy.reshape(-1, 1),
            samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)),
            components=[],
            properties=Labels("energy", torch.tensor([[0]])),
        )
        
        return {
            "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block])
        }


In [None]:
# Wrap the model in our export wrapper
wrapper = ExportWrapper(model)
wrapper = wrapper.eval()


The final step before exporting our model is to define its capabilities: what
this model can compute, what inputs it expects, etc.

In [None]:
# our model has a single output: the energy
energy_output = ModelOutput(
    quantity="energy",
    # energy is returned in eV
    unit="eV",
    # energy is returned globally, not per-atom
    per_atom=False,
)

# overall capabilities of the model
capabilities = ModelCapabilities(
    # expected unit for the positions and cell vectors
    length_unit="angstrom",
    # how far atoms are interacting together in this model
    interaction_range=SOAP_PARAMETERS["cutoff"]["radius"],
    # which atomic types can this model work with
    atomic_types=[1, 6, 8],
    # which torch devices this model can handle
    supported_devices=["cpu"],
    # which torch dtype is used for inputs and outputs
    dtype="float64",
    # the outputs this model supports
    outputs={
        "energy": energy_output,
    },
)

metadata = ModelMetadata(
    name="A simple SOAP NN model",
    description="...",
    authors=["John Doe"],
    references={
        "implementation": [],
        "architecture": [],
        "model": [],
    }
)


| ![TASK](../img/clipboard.png) | Define the atomic types this model can handle in the capabilities above, and add your name in the authors list |
|----------------------------|-------------------------------------------------------------------------|

In [None]:
if len(capabilities.atomic_types) == 0:
    raise Exception("missing atomic types in the capabilities")

if len(metadata.authors) == 0 or (len(metadata.authors) == 1 and metadata.authors[0] == "you"):
    raise Exception("please add your name to the authors list")

Finally, we can export our model and its capabilities as a new
`AtomisticModel`, which will run a couple of checks on the model and
handle all the unit conversions.

In [None]:
metatensor_model = AtomisticModel(wrapper, metadata, capabilities)
metatensor_model.save("propenol-model.pt", collect_extensions="extensions")


Let's now go to the next notebook, and run some Molecular Dynamics with our model!