<a href="https://colab.research.google.com/github/mdi-group/mace-field-tutorial/blob/main/MACE_Field_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MACE Field Tutorial




In this tutorial, we describe the changes we have made to the MACE model to incorporate an external perturbing electric field into its architecture. We then show how to use the field to compute response properties such as the macroscopic polarisation, Born effective charges (BECs) and polarisability.

To learn how the base MACE code works, we highly recommend you look at the [MACE theory tutorial](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU) developed by Will Baldwin and Ilyes Batatia.

## Installs

In [1]:
!rm -rf /usr/local/lib/python3.11/site-packages/mace
!git clone https://github.com/mdi-group/mace-field-tutorial.git
!git clone https://github.com/mdi-group/mace-field.git
%cd mace-field
!git switch field
!pip install .
!pip install --force-reinstall numpy==2.0
%cd ../mace-field-tutorial

Cloning into 'mace-field-tutorial'...
remote: Enumerating objects: 50, done.
remote: Counting objects: 100% (50/50), done.
remote: Compressing objects: 100% (46/46), done.
remote: Total 50 (delta 21), reused 8 (delta 1), pack-reused 0 (from 0)
Receiving objects: 100% (50/50), 2.27 MiB | 5.88 MiB/s, done.
Resolving deltas: 100% (21/21), done.
Cloning into 'mace-field'...
remote: Enumerating objects: 6184, done.
remote: Counting objects: 100% (58/58), done.
remote: Compressing objects: 100% (31/31), done.
remote: Total 6184 (delta 38), reused 29 (delta 27), pack-reused 6126 (from 3)
Receiving objects: 100% (6184/6184), 123.65 MiB | 13.74 MiB/s, done.
Resolving deltas: 100% (4643/4643), done.
Updating files: 100% (98/98), done.
/content/mace-field
Branch 'field' set up to track remote branch 'field' from 'origin'.
Switched to a new branch 'field'
Processing /content/mace-field
  Installing build dependencies ... done
  Getting requirements to build whee

In [2]:
import numpy as np
import torch

# Allow `slice` objects to be unpickled by torch.load(weights_only=True).
# This must run before importing e3nn, which calls torch.load on import.
torch.serialization.add_safe_globals([slice])

import torch.nn.functional
from e3nn import o3
from matplotlib import pyplot as plt

from ase.io import read, write

from mace import data, modules, tools
from mace.tools import torch_geometric
torch.set_default_dtype(torch.float64)

import warnings
warnings.filterwarnings("ignore")

from mace.cli.run_train import main as mace_run_train_main
import sys
import logging

In [3]:
from typing import Any, Dict, Optional
from mace.tools.scatter import scatter_sum
from mace.modules.blocks import (
    LinearReadoutBlock,
    ScaleShiftBlock,
)
from mace.modules.utils import (
    get_edge_vectors_and_lengths,
    get_symmetric_displacement,
)
from mace.modules.models import (
    MACE,
    ScaleShiftFieldMACE
)

## Write Config for the Model

In [11]:
# Keep only the first configuration of each split (a minimal demonstration dataset)
training_atoms = read("data/ferroelectric/ferroelectric_train_2040.xyz", ":")[:1]
validation_atoms = read("data/ferroelectric/ferroelectric_valid.xyz", ":")[:1]
test_atoms = read("data/ferroelectric/ferroelectric_test.xyz", ":")[:1]
write("data/training_dataset.xyz", training_atoms)
write("data/validation_dataset.xyz", validation_atoms)
write("data/test_dataset.xyz", test_atoms)

In [8]:
%%writefile train_mace_field.yml

name: "mace_field"
train_file: "data/training_dataset.xyz"
valid_file: "data/validation_dataset.xyz"
test_file: "data/test_dataset.xyz"
E0s: "average"
loss: "universal_field"
energy_weight: 0.0
forces_weight: 0.0
stress_weight: 0.0
bec_weight: 0.0
polarisability_weight: 0.0
polarisation_weight: 1e7
compute_field: True
eval_interval: 1
error_table: "PerAtomRMSEstressvirialsfield"
model: "ScaleShiftFieldMACE"
interaction_first: "RealAgnosticResidualInteractionBlock"
interaction: "RealAgnosticResidualInteractionBlock"
num_interactions: 2
correlation: 3
r_max: 6.0
max_L: 1
max_ell: 3
num_channels: 128
num_radial_basis: 10
MLP_irreps: "16x0e"
num_workers: 1
lr: 0.05
weight_decay: 1e-8
batch_size: 1
valid_batch_size: 1
max_num_epochs: 50
device: cuda
seed: 1

Overwriting train_mace_field.yml


## Training a MACE Field model

In [9]:
def train_mace_field(config_file_path):
    logging.getLogger().handlers.clear()
    sys.argv = ["program", "--config", config_file_path]
    mace_run_train_main()

In [10]:
# train_mace_field("train_mace_field.yml")

2025-04-16 08:39:04.167 INFO: MACE version: 0.3.10
2025-04-16 08:39:04.168 INFO: CUDA version: 12.4, CUDA device: 0
2025-04-16 08:39:04.174 INFO: Using heads: ['default']
2025-04-16 08:39:04.181 INFO: Training set [1 configs, 1 energy, 252 forces] loaded from 'data/ferroelectric/ferroelectric_train_1.xyz'
2025-04-16 08:39:04.185 INFO: Validation set [1 configs, 1 energy, 132 forces] loaded from 'data/ferroelectric/ferroelectric_valid_1.xyz'
2025-04-16 08:39:04.189 INFO: Test set (1 configs) loaded from 'data/ferroelectric/ferroelectric_test_1.xyz':
2025-04-16 08:39:04.190 INFO: Default_default: 1 configs, 1 energy, 120 forces
2025-04-16 08:39:04.192 INFO: Total number of configurations: train=1, valid=1, tests=[Default_default: 1],
2025-04-16 08:39:04.193 INFO: Atomic Numbers used: [np.int64(3), np.int64(6), np.int64(8), np.int64(17), np.int64(27), np.int64(30), np.int64(37)]
2025-04-16 08:39:04.195 INFO: Isolated Atomic Energies (E0s) not in training file, using command line argument


ValueError: np.int64(11) is not in list

In [12]:
train_configs = [data.config_from_atoms(atoms) for atoms in training_atoms]
valid_configs = [data.config_from_atoms(atoms) for atoms in validation_atoms]
atomic_numbers = []
for config in train_configs + valid_configs:
    atomic_numbers.extend(config.atomic_numbers)
atomic_numbers = list(set(atomic_numbers))
z_table = tools.AtomicNumberTable(atomic_numbers)
print(training_atoms)
print(validation_atoms)
print(z_table.zs)

[Atoms(symbols='Cl48Rb24Zn12', pbc=True, cell=[28.258334, 13.013901, 7.437371], REF_forces=...)]
[Atoms(symbols='Li8Co4C8O24', pbc=True, cell=[[7.346871, 0.0, -2.533357], [0.0, 9.693216, 0.0], [0.022785, 0.0, 7.725388]], REF_forces=...)]
[np.int64(3), np.int64(37), np.int64(6), np.int64(8), np.int64(17), np.int64(27), np.int64(30)]


In [1]:
# Set up default parameters: isolated-atom reference energies E0, keyed by atomic number Z
atomic_energy = {1: -3.4811272190365488, 3: -3.405789842096201, 4: -7.067932456053734, 5: -7.944493721785815, 6: -8.746951185675684, 7: -8.018275502146556, 8: -7.397118831026134, 9: -5.181586043491356, 11: -2.660663237507169, 12: -4.07183344068331, 13: -7.148953698251626, 14: -7.384133054595276, 15: -6.1087097551269816, 16: -4.679135128814314, 17: -2.9455352531524985, 19: -2.2568741341295846, 20: -4.576199362313644, 21: -9.582149061762781, 22: -12.103761788494054, 23: -9.01724147475884, 24: -9.254443628597073, 25: -8.442821202900108, 26: -6.159394604543429, 27: -5.41537230497698, 28: -3.1758801181240184, 29: -3.715754869753084, 30: -2.056319785008477, 31: -3.7276941073804197, 32: -5.120537129598759, 33: -4.154546322256404, 34: -3.811932781194901, 35: -2.2423125363058842, 37: -1.7320761004586025, 38: -4.281880504148956, 39: -11.452259451769779, 40: -13.166111399736801, 41: -12.060371264882782, 42: -9.973108695355801, 43: -10.188550035773947, 44: -7.74238222587903, 45: -6.274156197571513, 46: -4.851553253453126, 47: -2.04226071160437, 48: -1.1108172299063257, 49: -3.401692732338145, 50: -4.056216830037828, 51: -3.76555136817488, 52: -3.0582319023512547, 53: -1.9890885706191899, 54: 11.273107214700449, 55: -2.5803250049696214, 56: -4.690975145187065, 57: -10.024900097916152, 58: -11.657174020320685, 59: -9.630573241828774, 60: -9.69824155747603, 62: -9.688781605275459, 64: -18.976866365297035, 65: -9.46140901088781, 66: -9.42149220145286, 67: -9.423979979661791, 68: -9.300407919532761, 69: -9.311797683028653, 70: -4.400815224095631, 71: -10.180974069695909, 72: -14.835461901890257, 73: -15.627692535310869, 74: -9.290101166555921, 75: -12.722728657224483, 76: -9.251678424540348, 77: -8.837881571893774, 78: -5.635244693115368, 79: -1.64027184618186, 80: 0.6989666481252783, 81: -2.432059660659048, 82: -3.649430152103191, 83: -3.5907381591957708, 90: -12.198230223339193, 92: -19.079775230792716}
atomic_energies = np.array([atomic_energy.get(z) for z in atomic_numbers], dtype=float)

model_config = dict(
        num_elements=len(z_table),  # number of chemical elements
        atomic_energies=atomic_energies,  # atomic energies used for normalisation
        atomic_numbers=z_table.zs,
        avg_num_neighbors=8,  # avg number of neighbours of the atoms, used for internal normalisation of messages
        r_max=6.0,  # cutoff
        num_bessel=10,  # number of radial features
        num_polynomial_cutoff=6,  # smoothness of the radial cutoff
        max_ell=3,  # expansion order of spherical harmonic edge attributes
        num_interactions=2,  # number of layers, typically 2
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],  # interaction block of first layer
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],  # interaction block of subsequent layers
        hidden_irreps=o3.Irreps("16x0e+16x1o"),  # 16: number of embedding channels; 0e and 1o select which equivariant features are kept (here up to L_max = 1)
        correlation=3,  # correlation order of the messages (body order - 1)
        MLP_irreps=o3.Irreps("16x0e"),  # number of hidden dimensions of last layer readout MLP
        gate=torch.nn.functional.silu,  # nonlinearity used in last layer readout MLP
    )
model = ScaleShiftFieldMACE(**model_config, atomic_inter_scale=[1.0], atomic_inter_shift=[0.0])

NameError: name 'np' is not defined

In [None]:
%matplotlib inline
import pylab as pl
from IPython import display
import copy
import os

# Optimiser
param_options = dict(
        params=[
            {
                "name": "embedding",
                "params": model.node_embedding.parameters(),
                "weight_decay": 0,
            },
            {
                "name": "products",
                "params": model.products.parameters(),
                "weight_decay": 1e-7,
            },
            {
                "name": "readouts",
                "params": model.readouts.parameters(),
                "weight_decay": 0,
            },
        ],
        lr=0.005,
        amsgrad=True,
        betas=(0.9, 0.999),
    )
optimizer = torch.optim.AdamW(**param_options)

energy_train_trace = []
force_train_trace = []
polarisation_train_trace = []
bec_train_trace = []
polarisability_train_trace = []
loss_train_trace = []

energy_valid_trace = []
force_valid_trace = []
polarisation_valid_trace = []
bec_valid_trace = []
polarisability_valid_trace = []
loss_valid_trace = []

mse = torch.nn.MSELoss()

epoch = 0
best_epoch = 0
best_alpha = 0
best_loss = np.inf
best_model = copy.deepcopy(model)
max_num_epochs = 1000

# Load model
# model.load_state_dict(torch.load('water_bulk_all_64_epoch_41.model', weights_only=True), strict=False)

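# Training algorithm
# NOTE: `training_batches`, `validation_batches` and `loss_fn` are not defined in this notebook
# and must be created beforehand (data loaders over the datasets above and a field-aware loss).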
while epoch <= max_num_epochs:

    # Initialise training and validation arrays
    loss_train = []
    energies_train = []
    forces_train = []
    polarisations_train = []
    becs_train = []
    polarisabilities_train = []
    loss_valid = []
    energies_valid = []
    forces_valid = []
    polarisations_valid = []
    becs_valid = []
    polarisabilities_valid = []

    i=0

    # Validation
    for batch in validation_batches:
        for param in model.parameters():
            param.requires_grad = False
        output = model(batch.to_dict())
        loss = loss_fn(output, batch)
        loss_valid.append(loss.detach())
        energies_valid.append(torch.sqrt(mse(batch["energy"], output["energy"])))
        forces_valid.append(torch.sqrt(mse(batch["forces"], output["forces"])))
        polarisations_valid.append(torch.sqrt(mse(batch["polarisation"], output["polarisation"])))
        becs_valid.append(torch.sqrt(mse(batch["bec"], output["bec"])))
        polarisabilities_valid.append(torch.sqrt(mse(batch["polarisability"][0], output["polarisability"])))
        print(i, "| valid:", loss.detach().item())
        i+=1

    i=0

    # Training
    for batch in training_batches:
        for param in model.parameters():
            param.requires_grad = True
        optimizer.zero_grad(set_to_none=True)
        output = model(batch.to_dict(), training=True)  # keep the graph so derivative-based losses can be backpropagated
        loss = loss_fn(output, batch)
        loss.backward()
        optimizer.step()
        loss_train.append(loss.detach())
        energies_train.append(torch.sqrt(mse(batch["energy"], output["energy"])))
        forces_train.append(torch.sqrt(mse(batch["forces"], output["forces"])))
        polarisations_train.append(torch.sqrt(mse(batch["polarisation"], output["polarisation"])))
        becs_train.append(torch.sqrt(mse(batch["bec"], output["bec"])))
        polarisabilities_train.append(torch.sqrt(mse(batch["polarisability"][0], output["polarisability"])))
        print(i, "| train:", loss.detach().item())
        i+=1

    # Log traces
    loss_train_trace.append(torch.mean(torch.tensor(loss_train)))
    energy_train_trace.append(torch.mean(torch.tensor(energies_train)))
    force_train_trace.append(torch.mean(torch.tensor(forces_train)))
    polarisation_train_trace.append(torch.mean(torch.tensor(polarisations_train)))
    bec_train_trace.append(torch.mean(torch.tensor(becs_train)))
    polarisability_train_trace.append(torch.mean(torch.tensor(polarisabilities_train)))

    loss_valid_trace.append(torch.mean(torch.tensor(loss_valid)))
    energy_valid_trace.append(torch.mean(torch.tensor(energies_valid)))
    force_valid_trace.append(torch.mean(torch.tensor(forces_valid)))
    polarisation_valid_trace.append(torch.mean(torch.tensor(polarisations_valid)))
    bec_valid_trace.append(torch.mean(torch.tensor(becs_valid)))
    polarisability_valid_trace.append(torch.mean(torch.tensor(polarisabilities_valid)))


    # Plot training and validation loss ratio traces and property RMSE traces.
    # Clear the figure first so each epoch redraws the traces instead of stacking new lines on old ones.
    pl.clf()
    pl.plot(torch.tensor(polarisability_train_trace), color='blue', linestyle='-', label="alp train")
    pl.plot(torch.tensor(polarisability_valid_trace), color='blue', linestyle='--', label = 'alp valid')
    pl.plot(torch.tensor(bec_train_trace), color='green', linestyle='-', label="bec train")
    pl.plot(torch.tensor(bec_valid_trace), color='green', linestyle='--', label = 'bec valid')
    pl.plot(torch.tensor(polarisation_train_trace), color='orange', linestyle='-', label="pol train")
    pl.plot(torch.tensor(polarisation_valid_trace), color='orange', linestyle='--', label = 'pol valid')
    pl.plot(torch.tensor(energy_train_trace), color='red', linestyle='-', label='E train')
    pl.plot(torch.tensor(energy_valid_trace), color='red', linestyle='dotted', label='E valid')
    pl.plot(torch.tensor(force_train_trace), color='purple', linestyle='-', label='F train')
    pl.plot(torch.tensor(force_valid_trace), color='purple', linestyle='dotted', label='F valid')
    pl.plot(torch.tensor(loss_train_trace), color='black', linestyle='-', label='train')
    pl.plot(torch.tensor(loss_valid_trace), color='black', linestyle='--', label='valid')

    pl.autoscale()
    pl.yscale('symlog')
    pl.ylim(bottom=-1, top=None)

    # Legend, labels and guide lines must be re-added after clearing the figure
    pl.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    pl.xlabel('Epoch')
    pl.ylabel('RMSE Error & Loss Ratio')
    pl.axhline(y=1.0, color='gray', linestyle='dotted')
    pl.axhline(y=0.0, color='gray', linestyle='--')

    display.clear_output(wait=True)
    display.display(pl.gcf())

    # Checkpoint lowest valid loss
    if loss_valid_trace[-1] == torch.min(torch.tensor(loss_valid_trace)):
        best_model = copy.deepcopy(model)
        # try:
        #     os.remove(f"../test_models/water_bulk_epoch_{best_epoch}.model")
        # except OSError:
        #     pass
        best_epoch = epoch
        best_loss = loss_valid_trace[-1].item()
        best_alpha = polarisability_valid_trace[-1].item()
        # torch.save(obj=best_model.state_dict(), f=f"../test_models/water_bulk_epoch_{best_epoch}.model")
        print("Checkpoint for epoch", epoch)

    print("Epoch:", epoch, "| valid:", loss_valid_trace[-1].item(), "| train:", loss_train_trace[-1].item(), "| best epoch:", best_epoch, "| best loss:", best_loss)

    epoch += 1

    # Patience
    # if epoch - best_epoch == 50:
    #     break

## Understanding the MACE Field Architecture

### Theory

Our approach inherits most of the original MACE architecture. The main change is in the readout blocks, where we include an additional energy term, $-\Omega\ \mathbf{P} \cdot \mathcal{E}$, where $\Omega$ is the unit-cell volume, $\mathbf{P}$ is the macroscopic polarisation and $\mathcal{E}$ is an external electric field. This is analogous to the electric enthalpy functional used for finite electric fields in density functional theory and Density Functional Perturbation Theory (DFPT). See the [VASP wiki](https://www.vasp.at/wiki/index.php/Berry_phases_and_finite_electric_fields) for a good introduction to Berry phases and finite electric fields.

The molecular dipole moment is the sum of the ionic contributions, from ions with bare charge $e Z_\alpha$ (with $e > 0$ the elementary charge) at positions $\mathbf{R}_\alpha$, and an electronic contribution from the first moment of the electronic charge density $\rho(\mathbf{r})$:


\begin{equation}
\begin{aligned}
    \mathbf{p} &= \mathbf{p}_{\text{ion}} + \mathbf{p}_{\text{el}}\\
    \mathbf{p} &= e \sum_{\alpha} Z_\alpha \mathbf{R}_\alpha + \int d\mathbf{r}\ \mathbf{r}\ \rho(\mathbf{r})
\end{aligned}
\end{equation}


where the polarisation is this dipole moment per unit volume, $\mathbf{P} = \mathbf{p} / \Omega$.
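As a minimal illustration, assuming a purely classical point-charge model (no electronic density, so $\mathbf{p}_{\text{el}} = 0$) and units with $e = 1$, the ionic dipole and the polarisation of a cell can be computed as below. The function `point_charge_polarisation` is a hypothetical helper, not part of MACE.

```python
import numpy as np

def point_charge_polarisation(charges, positions, cell):
    """Ionic dipole p = sum_a Z_a R_a (e = 1) and polarisation P = p / Omega.

    charges:   (N,) ionic charges Z_a in units of e
    positions: (N, 3) Cartesian positions in Angstrom
    cell:      (3, 3) lattice vectors as rows, in Angstrom
    """
    dipole = np.einsum("a,ai->i", charges, positions)  # sum over atoms
    volume = abs(np.linalg.det(cell))                   # Omega
    return dipole, dipole / volume

# Toy cation/anion pair in a 4 Angstrom cubic cell
cell = 4.0 * np.eye(3)
charges = np.array([1.0, -1.0])
positions = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
p, P = point_charge_polarisation(charges, positions, cell)
print(p, P)  # p = [-2, 0, 0] e*A and P = p / 64 = [-0.03125, 0, 0] e/A^2
```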


The ionic dipole $\mathbf{p}_{\text{ion}}$ is already a sum of per-ion contributions $\mathbf{p}_{\text{ion}, \alpha}$. Just as we decompose the total energy of the system into contributions per ion, or per "node", we also decompose the electronic dipole into per-node contributions $\mathbf{p}_{\text{el}, \alpha}$.

According to the Modern Theory of Polarisation, the polarisation of an infinite periodic system is multivalued: it is only defined modulo a polarisation quantum. Strictly, the electronic dipole therefore cannot be decomposed into per-atom contributions, but, as with the total energy, we do so anyway. To compensate, the polarisation loss is defined only modulo the polarisation quantum, which accounts for this multivaluedness.
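The multivaluedness already appears in a classical point-charge model (an assumption made purely for illustration, with $e = 1$): translating one ion by a lattice vector $\mathbf{a}_1$ describes the same periodic crystal, yet it changes $\mathbf{P}$ by exactly one quantum $e\,\mathbf{a}_1/\Omega$. The helper name `ionic_polarisation` is hypothetical.

```python
import numpy as np

def ionic_polarisation(charges, positions, cell):
    # P = (1/Omega) * sum_a Z_a R_a, with e = 1
    return np.einsum("a,ai->i", charges, positions) / abs(np.linalg.det(cell))

cell = 4.0 * np.eye(3)
charges = np.array([1.0, -1.0])
positions = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])

# Same periodic crystal: translate the anion by the lattice vector a_1
shifted = positions.copy()
shifted[1] += cell[0]

dP = ionic_polarisation(charges, shifted, cell) - ionic_polarisation(charges, positions, cell)
quantum = cell[0] / abs(np.linalg.det(cell))  # e * a_1 / Omega with e = 1
n = dP[0] / quantum[0]                        # change measured in quanta
print(dP, quantum, n)                         # n = -1.0: exactly one quantum
```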

### Alterations to MACE

Each layer $1 \leq t \leq T$ of MACE contributes to the final energy readout $E_\alpha$ of node $\alpha$, on top of the isolated-atom reference energy $E_\alpha^{(0)}$:

\begin{equation}
    E_\alpha = E_\alpha^{(0)} + E_\alpha^{(1)} + \dots + E_\alpha^{(T)}
\end{equation}

In a $T$-layer MACE, the readout is altered to include an additional perturbing term for each node $\alpha$: the $K$ channels of a "total atomic dipole" feature $\mathbf{p}_{\alpha, k}$, dotted with the external electric field $\mathcal{E}$:

\begin{equation}
  E_\alpha^{(t)} = \mathcal{R}_t\left(\mathbf{h}_\alpha^{(t)}\right) =
    \begin{cases}
    \begin{aligned}
        &\sum_{\tilde{k}} W_{\text{readout}, \tilde{k}}^{(t)} \left[ h_{\alpha,\tilde{k} 0 0}^{e, (t)} - \mathbf{p}_{\alpha, \tilde{k}}^{(t)} \cdot \mathbf{\mathcal{E}} \right] \qquad\ \text{if}\ t<T, \\
        &\text{MLP}_{\text{readout}}^{(t)}\left( \left\{ h_{\alpha,k 0 0}^{e, (t)} - \mathbf{p}_{\alpha, k}^{(t)} \cdot \mathbf{\mathcal{E}} \right\}_{k} \right) \quad \text{if}\ t = T.
    \end{aligned}
    \end{cases}
\end{equation}

In the final layer, the field-dependent term therefore enters the non-linear readout MLP.
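A minimal sketch of the two readout cases, using random tensors in place of real MACE features (all names and sizes here are hypothetical): with a linear readout ($t<T$) the node energy is exactly linear in $\mathcal{E}$, so its field gradient is the constant $-\sum_k W_k \mathbf{p}_{\alpha,k}$, whereas an MLP readout ($t=T$) gives a field-dependent gradient, which is what allows a non-linear response.

```python
import torch

torch.manual_seed(0)
K = 4  # number of channels (small, hypothetical value)

# Hypothetical per-node features for a single node alpha
h_e = torch.randn(K, dtype=torch.float64)   # h^e_{alpha,k00}
p = torch.randn(K, 3, dtype=torch.float64)  # p_{alpha,k}, one 3-vector per channel
W = torch.randn(K, dtype=torch.float64)     # linear readout weights (t < T)
mlp = torch.nn.Sequential(                  # non-linear readout (t = T)
    torch.nn.Linear(K, 16), torch.nn.SiLU(), torch.nn.Linear(16, 1)
).double()

def linear_readout(x):
    return W @ x  # mixes the K channels

def field_gradient(readout, field_value):
    """Return dE_alpha/dE for the given readout at a given field."""
    field = torch.tensor(field_value, dtype=torch.float64, requires_grad=True)
    x = h_e - p @ field        # h^e_k - p_k . E, shape [K]
    energy = readout(x).sum()  # scalar node energy
    return torch.autograd.grad(energy, field)[0]

g0_lin = field_gradient(linear_readout, [0.0, 0.0, 0.0])
g1_lin = field_gradient(linear_readout, [0.1, -0.2, 0.3])
g0_mlp = field_gradient(mlp, [0.0, 0.0, 0.0])
g1_mlp = field_gradient(mlp, [0.1, -0.2, 0.3])

# Linear readout: gradient is field-independent and equals -sum_k W_k p_k
print(torch.allclose(g0_lin, g1_lin), torch.allclose(g0_lin, -(W @ p)))  # True True
# MLP readout: the gradient changes with the field
print(torch.allclose(g0_mlp, g1_mlp))
```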

After the higher body-order `node_feats` are produced from the standard `product` blocks in the MACE model, we linearly map them to two scalar features and one vector feature which we may relate to the local node energy "$e$", charge "$q$" and electronic dipole moment:

\begin{equation}
    \left[ h_{\alpha, k 0 0}^{e, (t)},\ h_{\alpha, k 0 0}^{q, (t)},\ h_{\alpha, k 1 m}^{(t)} \right] = \sum_{l \tilde{m}} W^{l \tilde{m}}_{0 0, 1 m} h^{(t)}_{\alpha,kl\tilde{m}}.
\end{equation}

Note that these are not yet complete readouts, as the $k$ channels have not been mixed.

To extract these features, we initialise a new readout in `__init__()`:

```
field_irreps = o3.Irreps("0e") + o3.Irreps.spherical_harmonics(1)
field_irreps_out = o3.Irreps(f"{num_channels * field_irreps}").sort()[0].simplify()

self.field_readout = LinearReadoutBlock(hidden_irreps, field_irreps_out)
```

which gives us two $l=0$ scalars and an $l=1$ vector per channel, *without* mixing the channels.

After obtaining `node_feats` from the `product` block, we extract these additional features with:

```
node_out = self.field_readout(node_feats, node_heads).reshape(node_feats.shape[0], -1, 5)
node_energy_feats, node_charge_feats, node_electronic_dipole_feats = node_out[:,:,0],  node_out[:,:,1], -node_out[:,:,2:]
```

We then define a total dipole moment feature as:

\begin{equation}
    \mathbf{p}^{(t)}_{\alpha,k} = h^{q,(t)}_{\alpha,k 0 0} \mathbf{R}_{\alpha} - \mathbf{h}^{(t)}_{\alpha,k 1},
\end{equation}

where the first term represents the ionic dipole contribution.

In the code we have:

```
node_atomic_dipole_feats = torch.einsum('ij,ik->ijk', node_charge_feats, data["positions"])
node_dipole_feats = node_atomic_dipole_feats + node_electronic_dipole_feats           
```
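To make the tensor shapes explicit, the sketch below (random tensors, hypothetical sizes) reproduces the `einsum` above and checks it against an equivalent broadcasting expression. `node_dipole_feats` has shape `[n_nodes, K, 3]`: one 3-vector per node and channel.

```python
import torch

n_nodes, K = 5, 4  # hypothetical sizes
node_charge_feats = torch.randn(n_nodes, K, dtype=torch.float64)                # h^q_{alpha,k00}
positions = torch.randn(n_nodes, 3, dtype=torch.float64)                        # R_alpha
node_electronic_dipole_feats = torch.randn(n_nodes, K, 3, dtype=torch.float64)  # -h_{alpha,k1}

# 'ij,ik->ijk': outer product of each node's K charge features with its position
node_atomic_dipole_feats = torch.einsum("ij,ik->ijk", node_charge_feats, positions)
node_dipole_feats = node_atomic_dipole_feats + node_electronic_dipole_feats

# Equivalent broadcasting form: q[:, :, None] * R[:, None, :]
reference = node_charge_feats[:, :, None] * positions[:, None, :] + node_electronic_dipole_feats
print(node_dipole_feats.shape, torch.allclose(node_dipole_feats, reference))
```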

This total dipole acts as an atomic decomposition of the macroscopic polarisation which, dotted with the external electric field, contributes to the total energy:

\begin{equation}
h_{\alpha,\tilde{k} 0 0}^{e, (t)} - \mathbf{p}_{\alpha, \tilde{k}}^{(t)} \cdot \mathbf{\mathcal{E}},
\end{equation}

which we then feed into the final energy `readout` block which mixes the $k$ channels:

```
node_energies = node_energy_feats - torch.einsum('ijk,k->ij', node_dipole_feats, data["electric_field"])
node_energies = readout(node_energies, node_heads)[num_atoms_arange, node_heads]        
```

Since the node features `node_feats` must always contain an $l=1$ piece, we need to alter the final `interaction` and `product` blocks of MACE, which originally keep only the $l=0$ piece in the last layer $T$.

Therefore, in `__init__()` in our new model `ScaleShiftFieldMACE` we also include:
```
self.interactions[-1].skip_tp = o3.FullyConnectedTensorProduct(
    hidden_irreps,
    self.interactions[-1].node_attrs_irreps,
    hidden_irreps,
    self.interactions[-1].cueq_config,
)
```
and
```
self.products[-1] = self.products[-2]
```

Finally, as our new `field_readout` is an intermediate step between the `product` block and the energy `readout`, we need to alter the `irreps_in` of the energy `readout` blocks:
```
for i in range(len(self.readouts)-1):
    self.readouts[i].linear = o3.Linear(f"{num_channels}x0e", f"{len(self.heads)}x0e")

self.readouts[-1].linear_1 = o3.Linear(f"{num_channels}x0e", f"{len(self.heads) * kwargs['MLP_irreps']}")   
```

### Deriving the Polarisation, BECs and Polarisability as Derivatives of the Energy

The predicted energy $E$ depends upon the atomic positions $\mathbf{R}_{\alpha}$ and the external electric field $\mathbf{\mathcal{E}}$ to arbitrary order,

\begin{equation}
    E(\{\mathbf{R}_\alpha\}, \mathbf{\mathcal{E}}) = E^{(0)}(\mathbf{\mathcal{E}}) + \sum_{\alpha=1}^N E_{\alpha}^{(0)}\left( \mathbf{R}_{\alpha};\  \mathbf{\mathcal{E}} \right) + \sum_{1 \leq \alpha \leq \beta \leq N}^N E_{\alpha, \beta}^{(1)}\left( \mathbf{R}_{\alpha}, \mathbf{R}_{\beta};\ \mathbf{\mathcal{E}} \right) + \dots + \mathcal{F}\left[m^{(T)}\left(\mathbf{R}_{\alpha_1}, \dots, \mathbf{R}_{\alpha_N};\  \mathbf{\mathcal{E}}\right)\right],
\end{equation}  

where $\mathcal{F}$ is a general, learnable non-linear term (here evaluated by an MLP) that accounts for the higher-order terms not written out explicitly.

Therefore, the model can also capture non-linear susceptibilities,

\begin{equation}
    \mathbf{P}(\{\mathbf{R}_\alpha\}, \mathbf{\mathcal{E}}) = -\frac{1}{\Omega} \mathbf{\nabla}_{\mathbf{\mathcal{E}}} E(\{\mathbf{R}_\alpha\}, \mathbf{\mathcal{E}}) = \mathbf{P}_0(\{\mathbf{R}_\alpha\}) + \chi^{(1)}(\{\mathbf{R}_\alpha\}) \cdot  \mathbf{\mathcal{E}} + \mathbf{\mathcal{E}} \cdot \chi^{(2)}(\{\mathbf{R}_\alpha\}) \cdot  \mathbf{\mathcal{E}} + \dots,
\end{equation}

where we see that the polarisability is the linear susceptibility term, $\chi^{(1)} \equiv \alpha$.
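The expansion can be checked on a toy energy that is truncated at second order in the field (an assumption for illustration; the real model also contains higher orders). Differentiating once gives $\mathbf{P}_0 + \alpha\,\mathcal{E}$, and differentiating again recovers $\alpha$. The values of `omega`, `P0` and `alpha` below are hypothetical.

```python
import torch

torch.manual_seed(0)
omega = 64.0  # cell volume Omega (hypothetical value)
P0 = torch.tensor([0.01, 0.0, -0.02], dtype=torch.float64)  # spontaneous polarisation
A = torch.randn(3, 3, dtype=torch.float64)
alpha = A @ A.T  # symmetric linear susceptibility

def energy(field):
    # Toy electric-enthalpy-like model: E(E) = -Omega * (P0 . E + 1/2 E . alpha . E)
    return -omega * (P0 @ field + 0.5 * field @ alpha @ field)

field = torch.tensor([0.05, -0.1, 0.02], dtype=torch.float64, requires_grad=True)
E = energy(field)

# P = -(1/Omega) dE/dE, keeping the graph so it can be differentiated again
P = -torch.autograd.grad(E, field, create_graph=True)[0] / omega
# alpha_ij = dP_i/dE_j, one row per polarisation component
alpha_pred = torch.stack([torch.autograd.grad(P[i], field, retain_graph=True)[0] for i in range(3)])

print(torch.allclose(P, P0 + alpha @ field.detach()))  # True
print(torch.allclose(alpha_pred, alpha))               # True
```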

The polarisation ($P_i$) is then derived from the derivative of the total energy with respect to the external electric field:

\begin{equation}
  P_i = -\frac{1}{\Omega} \frac{\partial E}{\partial \mathcal{E}_i},
\end{equation}

which we compute by reverse-mode automatic differentiation with `torch.autograd.grad`:

```python
import torch

def compute_polarisation(
    energy: torch.Tensor,
    electric_field: torch.Tensor,
    cell: torch.Tensor,
    training: bool = True,
) -> torch.Tensor:

    volume = torch.linalg.det(cell.view(-1, 3, 3)).abs()

    polarisation = torch.autograd.grad(
        outputs=energy,  # [n_graphs, ]
        inputs=electric_field,  # [3, ]
        grad_outputs=torch.ones_like(energy),
        retain_graph=training,  # Make sure the graph is not destroyed during training
        create_graph=training,  # Create graph for higher derivatives
        allow_unused=True,
    )[0]

    # P = -(1/Omega) dE/dE
    return -polarisation / volume
```
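The function above differentiates with respect to a single field tensor of shape `[3]`; if several graphs in a batch share that tensor, `autograd.grad` returns the sum of their gradients. Below is a minimal sketch of one way to obtain per-graph polarisations, assuming (hypothetically; this is not necessarily how mace-field batches data) one field row per graph, gathered onto the nodes through the batch index. Random per-node dipoles stand in for the model, and `index_add` plays the role of the `scatter_sum` imported earlier.

```python
import torch

torch.manual_seed(0)
n_graphs = 2
batch = torch.tensor([0, 0, 0, 1, 1])                      # graph index of each node
volumes = torch.tensor([64.0, 27.0], dtype=torch.float64)  # Omega per graph
node_dipoles = torch.randn(5, 3, dtype=torch.float64)      # stand-in for per-node dipoles

# One field vector per graph, shape [n_graphs, 3]
fields = torch.zeros(n_graphs, 3, dtype=torch.float64, requires_grad=True)

# Node energies -p_alpha . E_{graph(alpha)}, then summed per graph
node_energies = -(node_dipoles * fields[batch]).sum(dim=-1)
graph_energies = torch.zeros(n_graphs, dtype=torch.float64).index_add(0, batch, node_energies)

# Each graph only sees its own field row, so the gradient is resolved per graph
grad = torch.autograd.grad(graph_energies.sum(), fields)[0]  # [n_graphs, 3]
polarisation = -grad / volumes[:, None]

expected = torch.stack([node_dipoles[batch == g].sum(0) for g in range(n_graphs)]) / volumes[:, None]
print(torch.allclose(polarisation, expected))  # True
```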


The Born effective charges ($Z^*_{\alpha,ij}$) and the polarisability tensor ($\alpha_{ij}$) are then obtained from derivatives of the polarisation with respect to the atomic positions and the electric field, respectively,

\begin{equation}
    Z^*_{\alpha, ij} = -\frac{1}{e}\frac{\partial^2 E}{\partial \mathcal{E}_i\ \partial R_{\alpha,j}}\bigg\rvert_{\mathbf{\mathcal{E}}=\mathbf{0}}
    = \frac{1}{e} \frac{\partial F_{\alpha,j}}{\partial \mathcal{E}_i}\bigg\rvert_{\mathbf{\mathcal{E}}=\mathbf{0}} = \frac{\Omega}{e} \frac{\partial P_i}{\partial R_{\alpha, j}}\bigg\rvert_{\mathbf{\mathcal{E}}=\mathbf{0}},
\end{equation}

\begin{equation}
    \alpha_{ij} = - \frac{1}{\Omega} \frac{\partial^2 E}{\partial \mathcal{E}_i\ \partial \mathcal{E}_j} = \frac{\partial P_i}{\partial \mathcal{E}_j},
\end{equation}

which are computed by looping over the three Cartesian components of the polarisation vector:

```python
import torch

def compute_bec(
    polarisation: torch.Tensor,
    positions: torch.Tensor,
    cell: torch.Tensor,
    training: bool = True,
) -> torch.Tensor:

    volume = torch.linalg.det(cell.view(-1, 3, 3)).abs()

    bec_polar_list = []
    for d in range(3):  # Loop over Cartesian components of P
        polar_component = polarisation[d]
        gradient = torch.autograd.grad(
            outputs=polar_component,  # component d of P
            inputs=positions,  # [n_nodes, 3]
            grad_outputs=torch.ones_like(polar_component),
            retain_graph=training,  # Make sure the graph is not destroyed during training
            create_graph=training,  # Create graph for higher derivatives
            allow_unused=True,
        )[0]
        bec_polar_list.append(gradient)  # dP_d/dR_{alpha,j}, [n_nodes, 3]

    bec = torch.stack(bec_polar_list, dim=1)  # [n_nodes, 3 (i), 3 (j)]

    # Z* = (Omega/e) dP_i/dR_{alpha,j}, in units where e = 1
    return bec * volume
```
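The index convention and the equivalence of the expressions for $Z^*$ and $\alpha$ above can be checked on a toy energy (an assumption for illustration, in units where $e = 1$). The toy model has point charges $q_\alpha$ coupled to the field, harmonic springs to provide forces, and a field-quadratic term with a known $\alpha$. For this model $Z^*_{\alpha,ij} = q_\alpha \delta_{ij}$, and both the polarisation route and the force route should reproduce it. All names and values are hypothetical.

```python
import torch

torch.manual_seed(0)
omega, k = 64.0, 0.5  # hypothetical cell volume and spring constant
charges = torch.tensor([1.0, -1.0, 0.5], dtype=torch.float64)  # q_alpha in units of e
alpha = 1e-2 * torch.diag(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64))
positions = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)  # R_alpha
field = torch.zeros(3, dtype=torch.float64, requires_grad=True)        # evaluate at E = 0

def toy_energy(R, efield):
    coupling = -(charges[:, None] * R * efield).sum()     # -sum_a q_a R_a . E
    springs = 0.5 * k * (R ** 2).sum()                     # gives non-trivial forces
    quadratic = -0.5 * omega * efield @ alpha @ efield     # linear polarisability
    return coupling + springs + quadratic

energy = toy_energy(positions, field)

# P = -(1/Omega) dE/dE, keeping the graph for second derivatives
P = -torch.autograd.grad(energy, field, create_graph=True)[0] / omega

# Route 1: Z*_{a,ij} = Omega * dP_i/dR_{a,j}  -> [n_atoms, 3 (i), 3 (j)]
bec_from_P = omega * torch.stack(
    [torch.autograd.grad(P[i], positions, retain_graph=True)[0] for i in range(3)], dim=1
)

# Route 2: Z*_{a,ij} = dF_{a,j}/dE_i with F = -dE/dR
F = -torch.autograd.grad(energy, positions, create_graph=True)[0]
bec_from_F = torch.zeros(3, 3, 3, dtype=torch.float64)
for a in range(3):
    for j in range(3):
        bec_from_F[a, :, j] = torch.autograd.grad(F[a, j], field, retain_graph=True)[0]

# Polarisability alpha_ij = dP_i/dE_j
alpha_pred = torch.stack([torch.autograd.grad(P[i], field, retain_graph=True)[0] for i in range(3)])

expected = charges[:, None, None] * torch.eye(3, dtype=torch.float64)
print(torch.allclose(bec_from_P, expected), torch.allclose(bec_from_F, expected),
      torch.allclose(alpha_pred, alpha))  # True True True
```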

```python
import torch

def compute_polarisability(
    polarisation: torch.Tensor,
    electric_field: torch.Tensor,
    training: bool = True,
) -> torch.Tensor:

    # Second derivatives (BEC and polarisability) computed for each polarisation component.
    polarisability_list = []
    for d in range(3):
        polar_component = polarisation[d]
        grad_field = torch.autograd.grad(
            outputs=polar_component,  # component d of P
            inputs=electric_field,  # [3, ]
            grad_outputs=torch.ones_like(polar_component),
            retain_graph=training,  # Make sure the graph is not destroyed during training
            create_graph=training,  # Create graph for higher derivatives
            allow_unused=True,
        )[0]
        polarisability_list.append(grad_field)  # dP_d/dE_j

    # alpha_ij = dP_i/dE_j
    polarisability = torch.stack(polarisability_list, dim=1)  # [n_graphs, 3, 3]

    return polarisability
```
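As an alternative to the explicit loop (not what the mace-field code uses), PyTorch's `torch.autograd.functional.jacobian` returns the full matrix $\partial P_i / \partial \mathcal{E}_j$ in one call. The sketch below uses a hypothetical toy $\mathbf{P}(\mathcal{E})$ with a second-order term, so the Jacobian depends on the field, and checks that both approaches give the same polarisability at a non-zero field.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
chi2 = 1e-3 * torch.randn(3, 3, 3, dtype=torch.float64)  # hypothetical second-order term

def polarisation(field):
    # Toy non-linear response: P(E) = alpha E + E . chi2 . E, with alpha = A A^T
    return (A @ A.T) @ field + torch.einsum("ijk,j,k->i", chi2, field, field)

field = torch.tensor([0.1, 0.0, -0.2], dtype=torch.float64, requires_grad=True)

# Loop form (single graph): one backward pass per polarisation component
P = polarisation(field)
loop = torch.stack([torch.autograd.grad(P[d], field, retain_graph=True)[0] for d in range(3)], dim=0)

# Vectorised form: full Jacobian dP_i/dE_j, shape [3, 3]
jac = jacobian(polarisation, field)
print(torch.allclose(loop, jac))  # True
```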

### Additional Loss Term

In addition to the original loss terms for the energy, forces and stress, we include three terms for the polarisation, the BECs and the polarisability:

\begin{equation}
\begin{aligned}
    \Delta \mathcal{L} &= \frac{\lambda_P}{3 B} \sum_{b=1}^{B} \sum_{i=1}^3 \left(\left[-\frac{1}{\Omega}\frac{\partial E_b^{(\text{pred})}}{\partial \mathcal{E}_i} - P^{(\text{ref})}_{b,i}\right] \text{mod}\ \Delta P_{b,i} \right)^2 + \frac{\lambda_Z}{9 B N} \sum_{\alpha=1}^{B \cdot N} \sum_{i=1}^3 \sum_{j=1}^3 \left(-\frac{1}{e}\frac{\partial^2 E^{(\text{pred})}}{\partial \mathcal{E}_i\ \partial R_{\alpha,j}} - Z^{* (\text{ref})}_{\alpha, ij} \right)^2 \\
    &+ \frac{\lambda_\alpha}{9 B} \sum_{b=1}^{B} \sum_{i=1}^3 \sum_{j=1}^3 \left(-\frac{1}{\Omega}\frac{\partial^2 E_b^{(\text{pred})}}{\partial \mathcal{E}_i\ \partial \mathcal{E}_j} - \alpha^{(\text{ref})}_{b,ij} \right)^2 .
\end{aligned}
\end{equation}

Here $B$ is the batch size, $b$ the index of a graph in the batch and $N$ the number of atoms per graph. The weights $\lambda$ are automatically set to zero when the corresponding reference data are absent, so a mixed dataset, in which not every configuration has every property, can be used for training.

The expression above uses the MSE loss, but other losses can be used; the main code actually uses the Huber loss (`torch.nn.HuberLoss`) with $\delta = 0.1$.
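For reference, the Huber loss is quadratic for residuals smaller than $\delta$ and linear beyond it, so large errors (for example, an occasional badly predicted BEC) are penalised less steeply than under MSE. A quick check with $\delta = 0.1$, using only the public `torch.nn.HuberLoss` API:

```python
import torch

huber = torch.nn.HuberLoss(delta=0.1)  # default reduction="mean"
zero = torch.zeros(1, dtype=torch.float64)

# |x| < delta: 0.5 * x**2
small = huber(torch.tensor([0.05], dtype=torch.float64), zero)
# |x| >= delta: delta * (|x| - 0.5 * delta)
large = huber(torch.tensor([1.0], dtype=torch.float64), zero)
print(small.item(), large.item())  # 0.00125 and 0.095 (MSE would give 0.0025 and 1.0)
```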

The $\text{mod}\ \Delta P_{b,i}$ operation accounts for the polarisation being multivalued (a consequence of its Berry-phase definition), where $\Delta \mathbf{P}$ is the polarisation quantum:

\begin{equation}
   \Delta \mathbf{P} = \frac{e \mathbf{R}}{\Omega},
\end{equation}

where $\mathbf{R}$ here denotes a lattice vector (not an atomic position), so there is one quantum per lattice vector.
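A small sketch of the quanta for an orthorhombic toy cell, assuming lattice vectors stored as rows (the ASE convention) and $e = 1$, so the quanta are in $e/\text{Å}^2$ (1 $e/\text{Å}^2$ is about 16.02 C/m²). It uses the same construction as the `cell / det(cell)` step in the loss code. The helper `polarisation_quanta` is hypothetical.

```python
import numpy as np

def polarisation_quanta(cell, e=1.0):
    """Rows are Delta P_a = e * R_a / Omega for the lattice vectors R_a (rows of `cell`)."""
    cell = np.asarray(cell, dtype=float)
    return e * cell / abs(np.linalg.det(cell))

# Orthorhombic toy cell (Angstrom): Omega = 4 * 5 * 8 = 160
quanta = polarisation_quanta(np.diag([4.0, 5.0, 8.0]))
print(quanta)  # diagonal entries 0.025, 0.03125, 0.05 e/A^2
```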

To implement the modulo, we use `Tensor.repeat` to tile the reference and predicted polarisation vectors into $3\times3$ matrices (one row per lattice vector) and reduce their difference modulo the polarisation quanta, obtained from the $3\times3$ `Cell` of the ASE `Atoms` object divided by the cell volume:

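```python
import torch

# Excerpt from the loss module: `ref`/`pred` hold the reference and predicted batches,
# `num_atoms` the number of atoms and `self.huber_loss` a torch.nn.HuberLoss.

# Calculate the polarisation quantum: rows are lattice vectors divided by the cell volume
cell = ref["cell"].view(-1, 3, 3)
polarisation_quantum = cell / torch.linalg.det(cell).abs().view(-1, 1, 1)

# Replace zero components by a value larger than any possible |ref - pred|,
# so that fmod leaves those components unfolded and avoids dividing by 0
pol_max = torch.cat((ref["polarisation"], pred["polarisation"])).abs().max()
polarisation_quantum[polarisation_quantum == 0] = 2 * pol_max + 1.0

# Tile each polarisation vector into a 3x3 matrix (one copy per lattice vector): [B, 3, 3]
ref_polarisation = ref["polarisation"].view(-1, 1, 3).repeat(1, 3, 1)
pred_polarisation = pred["polarisation"].view(-1, 1, 3).repeat(1, 3, 1)

# Fold the difference modulo the quanta (any nan set to zero) and compare with zero
polarisation_loss = self.huber_loss(
    (ref_polarisation - pred_polarisation).fmod(polarisation_quantum).nan_to_num(nan=0) / num_atoms,
    torch.zeros_like(ref_polarisation),
)
```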

The [`Cell` object](https://wiki.fysik.dtu.dk/ase/ase/cell.html) is simply the three lattice vectors forming a parallelepiped.

The extra manipulation of the polarisation quantum only serves to avoid dividing by zero in the modulo and to leave unfolded any component for which a lattice vector has a zero entry. The folded difference is then compared with a $3\times3$ null matrix.
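One property of `torch.fmod` worth keeping in mind: it removes only whole multiples of the quantum and keeps the sign of the difference, so a difference of 0.96 quanta is left almost unchanged even though it lies very close to an equivalent branch. The sketch below contrasts it with a minimum-image fold, $d - \Delta P\, \mathrm{round}(d / \Delta P)$. This is shown as a possible alternative, not as what the mace-field loss does; the numbers are hypothetical.

```python
import torch

quantum = torch.tensor(0.0625, dtype=torch.float64)                     # e a / Omega, 4 A cubic cell
diffs = torch.tensor([0.0625, 0.03, 0.06, -0.06], dtype=torch.float64)  # ref - pred

folded_fmod = torch.fmod(diffs, quantum)                           # as in the loss above
folded_min_image = diffs - quantum * torch.round(diffs / quantum)  # nearest-branch alternative
print(folded_fmod)       # approximately [0.0, 0.03, 0.06, -0.06]
print(folded_min_image)  # approximately [0.0, 0.03, -0.0025, 0.0025]
```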