<a href="https://colab.research.google.com/github/mdi-group/mace-field-tutorial/blob/main/MACE_Field_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MACE Field Tutorial

In this tutorial, we walk through the changes we have made to the MACE architecture to incorporate an external perturbing electric field, and show how the resulting model can be used to compute derivative properties such as the macroscopic polarisation, the Born effective charges (BECs) and the polarisability.

To learn how the base MACE code works, we highly recommend you look at the [MACE theory tutorial](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU) developed by Will Baldwin and Ilyes Batatia.

## Installs, Includes & Imports

```
!git clone https://github.com/mdi-group/mace-field-tutorial.git
!git clone https://github.com/mdi-group/mace-field.git
!pip install torch==2.0.0 torchvision torchaudio
%cd mace-field
# Switch to the `field` branch *before* installing, so the installed package contains the field model
!git switch field
!pip install .
!pip uninstall -y numpy
!pip install numpy==1.26.4
```

```
import numpy as np
import torch
import torch.nn.functional
from e3nn import o3
from matplotlib import pyplot as plt
import ase.io

from mace import data, modules, tools
from mace.tools import torch_geometric
torch.set_default_dtype(torch.float64)

import warnings
warnings.filterwarnings("ignore")
```

```
from typing import Any, Dict, Optional
from mace.tools.scatter import scatter_sum
from mace.modules.blocks import (
    LinearReadoutBlock,
    ScaleShiftBlock,
)
from mace.modules.utils import (
    get_edge_vectors_and_lengths,
    get_symmetric_displacement,
)
from mace.modules.models import (
    MACE,
    ScaleShiftFieldMACE,
)
```

## Write Config for the Model

```
%%writefile train_polarisation.yml

name: "ferroelectrics"
train_file: "data/ferroelectric/ferroelectric_train_2040.xyz"
test_file: "data/ferroelectric/ferroelectric_test.xyz"
valid_file: "data/ferroelectric/ferroelectric_valid.xyz"
E0s: "average"
loss: "universal_field"
energy_weight: 0.0
forces_weight: 0.0
stress_weight: 0.0
bec_weight: 1e2
polarisability_weight: 0.0
polarisation_weight: 0.0
compute_field: True
eval_interval: 1
error_table: "PerAtomRMSEstressvirialsfield"
model: "ScaleShiftFieldMACE"
interaction_first: "RealAgnosticResidualInteractionBlock"
interaction: "RealAgnosticResidualInteractionBlock"
num_interactions: 2
correlation: 3
r_max: 6.0
max_L: 1
max_ell: 3
num_channels: 128
num_radial_basis: 10
MLP_irreps: "16x0e"
num_workers: 1
lr: 0.05
weight_decay: 1e-8
batch_size: 1
valid_batch_size: 1
max_num_epochs: 200
distributed: True
device: cuda
seed: 1
```

## Understanding the MACE Field Architecture

Our approach keeps most of the original MACE architecture. The main change is in the readout blocks, which include an additional energy term $-\Omega\, \mathbf{P} \cdot \mathcal{E}$, with $\Omega$ the unit-cell volume, $\mathbf{P}$ the macroscopic polarisation and $\mathcal{E}$ an external electric field. This is in direct analogy with the electric enthalpy functional used in Density Functional Perturbation Theory (DFPT). Please see the [VASP wiki](https://www.vasp.at/wiki/index.php/Berry_phases_and_finite_electric_fields) for a good introduction to Berry phases and finite electric fields in DFPT.
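To illustrate how derivative properties follow from an energy of this form, here is a minimal sketch using PyTorch autograd on a hypothetical toy enthalpy (harmonic ions with fixed point charges plus an isotropic electronic polarisability `alpha`); it is not the MACE-Field model. Writing the field term as $-\mathbf{p}\cdot\mathcal{E}$ with $\mathbf{p} = \Omega\mathbf{P}$, the dipole is $-\partial E/\partial\mathcal{E}$, the Born effective charges are $Z^*_{\alpha,ij} = \partial p_i/\partial R_{\alpha j}$ (equivalently $\partial F_{\alpha j}/\partial\mathcal{E}_i$), and the polarisability is $\partial p_i/\partial\mathcal{E}_j$. Units and prefactors are ignored.

```python
import torch

torch.set_default_dtype(torch.float64)


def toy_enthalpy(positions, field, charges, k=1.0, alpha=0.5):
    """Hypothetical toy enthalpy E(R, field) = U(R) - p(R).field - 0.5 * alpha * |field|^2.

    U is harmonic, p = sum_i q_i R_i uses fixed point charges, and alpha is an
    isotropic electronic polarisability. Illustration only, not MACE-Field.
    """
    u = 0.5 * k * (positions ** 2).sum()
    dipole = (charges[:, None] * positions).sum(dim=0)
    return u - dipole @ field - 0.5 * alpha * (field @ field)


charges = torch.tensor([2.0, -1.0, -1.0])
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True
)
field = torch.zeros(3, requires_grad=True)  # evaluate at zero field

energy = toy_enthalpy(positions, field, charges)

# First derivatives; create_graph=True so they can be differentiated again
dE_dfield, dE_dpos = torch.autograd.grad(energy, (field, positions), create_graph=True)
dipole = -dE_dfield  # p = -dE/dfield (P = p / volume)
forces = -dE_dpos    # F = -dE/dR

# Born effective charges: bec[atom, i, j] = d p_i / d R_{atom, j}
bec = torch.stack(
    [torch.autograd.grad(dipole[i], positions, retain_graph=True)[0] for i in range(3)],
    dim=1,
)

# Polarisability: polarisability[i, j] = d p_i / d field_j
polarisability = torch.stack(
    [torch.autograd.grad(dipole[i], field, retain_graph=True)[0] for i in range(3)]
)
```

For this toy model the dipole is $\sum_\alpha q_\alpha\mathbf{R}_\alpha$, each ion's BEC tensor is $q_\alpha$ times the identity, and the polarisability is `alpha` times the identity. The same pattern (one extra backward pass per field component) applies to any differentiable energy that depends on both positions and field.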

The molecular dipole moment is the sum of the ionic contributions, from ions with bare ionic charge $-e Z_\alpha$ (taking $e$ as the signed electron charge, so that $-eZ_\alpha$ is positive) at positions $\mathbf{R}_\alpha$, and an electronic contribution given by the first moment of the electronic charge density $\rho(\mathbf{r})$:


\begin{equation}
\begin{aligned}
    \mathbf{p} &= \mathbf{p}_{\text{ion}} + \mathbf{p}_{\text{el}}\\
    \mathbf{p} &= -e \sum_{\alpha} Z_\alpha \mathbf{R}_\alpha + \int d\mathbf{r}\ \mathbf{r}\ \rho(\mathbf{r})
\end{aligned}
\end{equation}


where the macroscopic polarisation is this dipole moment divided by the cell volume, $\mathbf{P} = \mathbf{p} / \Omega$.
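For the ionic part alone, the polarisation of a periodic cell follows directly from the positions, charges and lattice vectors. Below is a minimal NumPy sketch, assuming signed point ionic charges $q_\alpha = -eZ_\alpha$ in units of $|e|$ and Cartesian positions; the cell, positions and charges are made-up illustrative values, and the electronic term is omitted.

```python
import numpy as np


def ionic_polarisation(cell, positions, charges):
    """P_ion = sum_alpha q_alpha R_alpha / Omega for one periodic cell.

    cell: (3, 3) lattice vectors as rows; positions: (N, 3) Cartesian;
    charges: (N,) signed ionic charges. Returns (P_ion, Omega).
    """
    volume = abs(np.linalg.det(cell))                    # Omega = |a . (b x c)|
    dipole = (charges[:, None] * positions).sum(axis=0)  # p_ion = sum q R
    return dipole / volume, volume


cell = 4.0 * np.eye(3)                                     # hypothetical 4 A cubic cell
positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.2]])  # anion displaced along z
charges = np.array([1.0, -1.0])

P_ion, volume = ionic_polarisation(cell, positions, charges)
```

Because the charges here sum to zero, the result does not depend on the choice of origin.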


The ionic dipole is naturally a sum of per-ion contributions, $\mathbf{p}_{\text{ion}, \alpha} = -eZ_\alpha \mathbf{R}_\alpha$. Just as we decompose the total energy of the system into contributions per ion (or per "node"), we also decompose the electronic dipole into per-node contributions $\mathbf{p}_{\text{el}, \alpha}$.

According to the Modern Theory of Polarisation, the polarisation of an infinite periodic system is multivalued ("ill-defined"): it is only defined modulo a polarisation quantum. In principle, the electronic dipole therefore cannot be decomposed into atomic contributions in this way but, just as for the total energy, we make this decomposition anyway.
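A quick way to see the multivaluedness is that an ion and its periodic image describe the same crystal, yet representing one ion by an image shifted by a lattice vector $\mathbf{a}$ changes the naive dipole per volume by $q\,\mathbf{a}/\Omega$. The NumPy sketch below checks this for a hypothetical two-ion cubic cell; it only illustrates the ionic part of the ambiguity.

```python
import numpy as np

cell = 4.0 * np.eye(3)                                     # hypothetical cubic cell, rows = lattice vectors
charges = np.array([1.0, -1.0])
positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.2]])
volume = abs(np.linalg.det(cell))


def naive_polarisation(pos):
    """Dipole per volume for one particular choice of ion images."""
    return (charges[:, None] * pos).sum(axis=0) / volume


# Same crystal, but the anion is represented by its image shifted by lattice vector c
shifted = positions.copy()
shifted[1] += cell[2]

delta_P = naive_polarisation(shifted) - naive_polarisation(positions)
quantum = charges[1] * cell[2] / volume                    # q * c / Omega
```

Since the quantum is a constant, derivatives of $\mathbf{P}$ such as the Born effective charges and the polarisability do not depend on which branch is chosen.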

Each layer $1 \leq t \leq T$ of MACE contributes a term to the final energy readout $E_\alpha$ of node $\alpha$, on top of the atomic reference energy $E_\alpha^{(0)}$:

\begin{equation}
    E_\alpha = E_\alpha^{(0)} + E_\alpha^{(1)} + \dots + E_\alpha^{(T)}
\end{equation}
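In code, the per-layer contributions are summed for each node, and the node energies are then summed per structure in a batch. A minimal PyTorch sketch with made-up numbers, using `Tensor.index_add_` as a stand-in for the `scatter_sum` imported above (all shapes and values are hypothetical):

```python
import torch

# Hypothetical batch: 2 structures, 5 nodes in total, T = 2 layers
e0 = torch.tensor([-1.0, -2.0, -1.0, -3.0, -3.0])       # E_alpha^(0): reference energies
layer_terms = torch.tensor(
    [[0.1, 0.2], [0.0, -0.1], [0.3, 0.0], [0.2, 0.2], [-0.1, 0.1]]
)                                                        # columns: E_alpha^(1), E_alpha^(2)
batch = torch.tensor([0, 0, 0, 1, 1])                    # structure index of each node

node_energy = e0 + layer_terms.sum(dim=1)                # E_alpha = E^(0) + E^(1) + E^(2)

# Sum the node energies within each structure
total_energy = torch.zeros(2).index_add_(0, batch, node_energy)
```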

In a $T$-layer MACE, we alter the readout to include an additional perturbing term for each node $\alpha$: a $K$-channel "total atomic dipole" feature $\mathbf{p}_{\alpha, k}$ dotted with the external electric field $\mathcal{E}$:

\begin{equation}
  E_\alpha^{(t)} = \mathcal{R}_t\left(\mathbf{h}_\alpha^{(t)}\right) =
    \begin{cases}
        \sum_{\tilde{k}} W_{\text{readout}, \tilde{k}}^{(t)} \left[ h_{\alpha,\tilde{k} 0 0}^{e, (t)} - \mathbf{p}_{\alpha, \tilde{k}}^{(t)} \cdot \mathcal{E} \right] & \text{if}\ t<T, \\
        \text{MLP}_{\text{readout}}^{(t)}\left( \left\{ h_{\alpha,k 0 0}^{e, (t)} - \mathbf{p}_{\alpha, k}^{(t)} \cdot \mathcal{E} \right\}_{k} \right) & \text{if}\ t = T.
    \end{cases}
\end{equation}

The readouts for $t<T$ are therefore linear in $\mathcal{E}$, whereas in the final layer the field term enters the nonlinear MLP.
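The two cases of the field-dependent readout can be sketched in plain PyTorch. This is a hypothetical stand-in (an ordinary bias-free `torch.nn.Linear` and a small SiLU MLP with 16 hidden units, random features), not the MACE-Field `LinearReadoutBlock`/`NonLinearReadoutBlock` implementation; it only shows where $h^{e} - \mathbf{p}\cdot\mathcal{E}$ enters each readout.

```python
import torch

torch.manual_seed(0)
n_nodes, n_channels, mlp_hidden = 4, 8, 16

h_e = torch.randn(n_nodes, n_channels)      # stand-in for h^{e,(t)}_{alpha,k00}
p = torch.randn(n_nodes, n_channels, 3)     # stand-in for p^{(t)}_{alpha,k}
field = torch.tensor([0.0, 0.0, 0.1])       # external field E

# Per-channel field-shifted scalar: h^e_{alpha,k} - p_{alpha,k} . E
shifted = h_e - torch.einsum("nkc,c->nk", p, field)

# t < T: linear map over channels (W_readout), no bias
linear_readout = torch.nn.Linear(n_channels, 1, bias=False)

# t = T: small MLP over channels, so the energy becomes nonlinear in E
mlp_readout = torch.nn.Sequential(
    torch.nn.Linear(n_channels, mlp_hidden),
    torch.nn.SiLU(),
    torch.nn.Linear(mlp_hidden, 1),
)

e_linear = linear_readout(shifted).squeeze(-1)  # shape (n_nodes,)
e_mlp = mlp_readout(shifted).squeeze(-1)        # shape (n_nodes,)
```

Because `linear_readout` has no bias and the field enters only through `shifted`, doubling the field doubles the field-induced change in `e_linear`; the MLP readout does not have this property, which is what allows a non-zero second derivative of the energy with respect to the field.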

After the higher body-order `node_feats` are produced by the standard `product` blocks of MACE, we linearly map them to two scalar features and one vector feature per channel, which we interpret as the local node energy "$e$", charge "$q$" and electronic dipole moment:

\begin{equation}
    \left[ h_{\alpha, k 0 0}^{e, (t)},\ h_{\alpha, k 0 0}^{q, (t)},\ h_{\alpha, k 1 m}^{(t)} \right] = \sum_{l \tilde{m}} W^{l \tilde{m}}_{0 0, 1 m} h^{(t)}_{\alpha,kl\tilde{m}}.
\end{equation}

Note that these are not yet complete readouts, as the $k$ channels have not been mixed.

To extract these features, we initialise a new readout in `__init__()`:

```
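# Per channel: one extra scalar ("0e") plus the l <= 1 spherical-harmonic irreps ("1x0e+1x1o"),
# i.e. two scalars and one vector
field_irreps = o3.Irreps("0e") + o3.Irreps.spherical_harmonics(1)
# Repeat for every channel, then sort and merge: f"{2 * num_channels}x0e+{num_channels}x1o"
field_irreps_out = o3.Irreps(f"{num_channels * field_irreps}").sort()[0].simplify()

self.field_readout = LinearReadoutBlock(hidden_irreps, field_irreps_out)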
```

which gives us, for each channel, two $l=0$ scalars and an $l=1$ vector *without* mixing the channels.

We then define a total dipole moment feature as:

\begin{equation}
    \mathbf{p}^{(t)}_{\alpha,k} = h^{q,(t)}_{\alpha,k 0 0} \mathbf{R}_{\alpha} - \mathbf{h}^{(t)}_{\alpha,k 1},
\end{equation}

where the first term represents the ionic dipole contribution. This "total dipole" acts as an atomic decomposition of the total macroscopic polarisation and, dotted with the external electric field, contributes to the total energy.
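Assembling the total dipole feature is a broadcast over nodes and channels. Below is a minimal PyTorch sketch with random stand-in features; the shapes are hypothetical, and in the model `h_q` and `h_vec` would come from `field_readout`. For simplicity the vector features are treated as Cartesian $(x, y, z)$ components; e3nn's component ordering for $l=1$ is not addressed here.

```python
import torch

torch.manual_seed(0)
n_nodes, n_channels = 5, 8

positions = torch.randn(n_nodes, 3)          # R_alpha
h_q = torch.randn(n_nodes, n_channels)       # stand-in for h^{q,(t)}_{alpha,k00}
h_vec = torch.randn(n_nodes, n_channels, 3)  # stand-in for h^{(t)}_{alpha,k1}
field = torch.tensor([0.1, 0.0, 0.0])        # external field E

# p_{alpha,k} = h^q_{alpha,k} R_alpha - h_{alpha,k1}
p = h_q[:, :, None] * positions[:, None, :] - h_vec  # (n_nodes, n_channels, 3)

# Field term entering the readout for each node and channel: p_{alpha,k} . E
p_dot_field = p @ field                              # (n_nodes, n_channels)

# Summing over nodes gives a per-channel total dipole of the structure
p_total = p.sum(dim=0)                               # (n_channels, 3)
```

One property of the formula worth keeping in mind: under a rigid translation $\mathbf{R}_\alpha \to \mathbf{R}_\alpha + \mathbf{d}$, `p_total` shifts by $\left(\sum_\alpha h^{q}_{\alpha,k00}\right)\mathbf{d}$, just as the dipole of a charged system depends on the choice of origin.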

Since the node features, `node_feats`, must always contain an $l=1$ part, we need to alter the final `interaction` and `product` blocks of MACE, which in the original architecture keep only the $l=0$ part in the last layer $T$.

Therefore, in the `__init__()` of our new model `ScaleShiftFieldMACE`, we also include:
```
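# Keep the full hidden_irreps (including l = 1) in the last interaction's skip connection.
# e3nn's FullyConnectedTensorProduct has no cueq_config argument; passed positionally
# it would be read as irrep_normalization, so it is omitted here.
self.interactions[-1].skip_tp = o3.FullyConnectedTensorProduct(
    hidden_irreps,
    self.interactions[-1].node_attrs_irreps,
    hidden_irreps,
)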
```
and
```
self.products[-1] = self.products[-2]
```
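One PyTorch detail about the assignment above: placing the same module object at two positions of a `torch.nn.ModuleList` makes both entries share their parameters rather than copying them. A minimal sketch with stand-in `torch.nn.Linear` modules (not the MACE product blocks) shows the difference from `copy.deepcopy`, which would give the last layer its own copy with the same initial values:

```python
import copy

import torch

# Stand-ins for self.products: three independent layers
products = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(3)])

products[-1] = products[-2]                 # aliasing: same object at both positions
shared = products[-1] is products[-2]
n_params_shared = sum(p.numel() for p in products.parameters())  # shared params counted once

products[-1] = copy.deepcopy(products[-2])  # independent copy, trained separately
independent = products[-1] is not products[-2]
n_params_copied = sum(p.numel() for p in products.parameters())
```

Whether shared weights are intended for the last two layers is a modelling choice; the snippet only shows what each form does.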

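Finally, since our new `field_readout` is an intermediate step between the `product` block and the energy `readout`, the energy `readout` blocks now receive the `num_channels` field-shifted scalars $h^{e}_{\alpha,k00} - \mathbf{p}_{\alpha,k}\cdot\mathcal{E}$ instead of the full `node_feats`, so we alter their `irreps_in`: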
```
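# Linear readouts (layers t < T): num_channels scalars -> one energy per head
for i in range(len(self.readouts) - 1):
    self.readouts[i].linear = o3.Linear(f"{num_channels}x0e", f"{len(self.heads)}x0e")

# Nonlinear readout (layer T): the first MLP layer also takes num_channels scalars
self.readouts[-1].linear_1 = o3.Linear(f"{num_channels}x0e", f"{len(self.heads) * kwargs['MLP_irreps']}")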
```