In [67]:
%load_ext autoreload
%autoreload 2



# Fine-tuning foundation models

This notebook reviews methods for fine-tuning two large "foundation" models for neural network potentials: CHGNet and MACE-MP-0. The models are fine-tuned on data from a set of crystals in the toy database. The approach freezes a subset of each model's layers and trains the remaining layers on the new data. Other fine-tuning strategies exist, such as training the entire model with a smaller learning rate, but this notebook focuses on layer freezing.
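To make the two strategies concrete, here is a minimal sketch on a hypothetical toy network (the `backbone`/`readout` split, the layer sizes, and the `1e-5` learning rate are illustrative assumptions, not taken from MACE or CHGNet). It contrasts freezing part of the model with training everything at a smaller learning rate.

```python
import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)


def make_toy_model():
    # Hypothetical stand-in for a pre-trained model: a backbone and a readout
    return nn.ModuleDict(
        {
            "backbone": nn.Sequential(nn.Linear(8, 16), nn.SiLU()),
            "readout": nn.Linear(16, 1),
        }
    )


# Strategy 1 (used in this notebook): freeze the backbone, train only the readout
frozen_model = make_toy_model()
for param in frozen_model["backbone"].parameters():
    param.requires_grad = False
frozen_optimizer = Adam(
    [p for p in frozen_model.parameters() if p.requires_grad], lr=3e-4
)

# Strategy 2 (not used here): train every layer, but with a smaller learning rate
full_model = make_toy_model()
full_optimizer = Adam(full_model.parameters(), lr=1e-5)

n_trainable_frozen = sum(
    p.numel() for p in frozen_model.parameters() if p.requires_grad
)
n_trainable_full = sum(p.numel() for p in full_model.parameters() if p.requires_grad)
print(n_trainable_frozen, n_trainable_full)
```

With freezing, only the readout's 17 parameters remain trainable out of 161, which is why this strategy is cheap and less prone to overwriting what the pre-trained layers have learned.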

In [68]:
import os
import shutil

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from datetime import datetime
from mace.modules.models import ScaleShiftMACE
from sklearn.metrics import mean_absolute_error
from nff.train.transfer import freeze_parameters, unfreeze_readout
from nff.data import Dataset, split_train_validation_test, collate_dicts
from nff.train import Trainer, loss, hooks, metrics, evaluate
from nff.nn.models.mace import NFFMACEWrapper, DirectNffScaleMACEWrapper
from nff.nn.models.chgnet import CHGNetNFF

DEVICE = "cpu"
OUTDIR_MACE = "./sandbox/mace"
OUTDIR_CHGNET = "./sandbox/chgnet"
BATCH_SIZE = 1
# Raw strings keep the LaTeX backslashes from being read as escape sequences
UNITS = {
    "energy_grad": r"kcal (mol $\mathdefault{\AA}$)$^{\mathdefault{-1}}$",
    "energy": r"kcal mol$^{\mathdefault{-1}}$",
}

# Start from a clean MACE output directory, creating it if it does not exist yet
if os.path.exists(OUTDIR_MACE):
    shutil.rmtree(OUTDIR_MACE)
os.makedirs(OUTDIR_MACE)

# The tutorial writer does not like the default font
# so we change it to Arial
mpl.font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
mpl.font_manager.findfont("Arial")
plt.rcParams["figure.dpi"] = 100
plt.rcParams["font.family"] = "Arial"

## MACE-MP-0

### Accessing the model

Let's start by loading the `MACE-MP-0` model and looking at its architecture. The manuscript describing the MACE architecture is available [here](https://proceedings.neurips.cc/paper_files/paper/2022/file/4a36c3c51af11ed9f34615b81edb5bbc-Paper-Conference.pdf). If you would like to dig into the details of the model, the appendices are available [here](https://proceedings.neurips.cc/paper_files/paper/2022/file/4a36c3c51af11ed9f34615b81edb5bbc-Supplemental-Conference.pdf).

The code in the cell below assumes that you have installed `mace` in your home directory and you have downloaded the `MACE-MP-0` model (you can find instructions on installing MACE and downloading its pre-trained models using the package's built-in methods in `htvs/chemconfigs/ase/README_MACE.md`).

In [69]:
user = os.environ.get("USER")
mace_location = (
    f"/home/{user}/mace/mace/calculators/foundations_models/2023-12-03-mace-mp.model"
)

# Load the MACE model with the NFF wrapper
mace_model = DirectNffScaleMACEWrapper.from_file(mace_location, map_location="cpu")

No dtype selected, switching to float64 to match model dtype.


Let's look at the architecture of the `MACE-MP-0` model:

In [70]:
mace_model.modules

<bound method Module.modules of DirectNffScaleMACEWrapper(
  (node_embedding): LinearNodeEmbeddingBlock(
    (linear): Linear(89x0e -> 128x0e | 11392 weights)
  )
  (radial_embedding): RadialEmbeddingBlock(
    (bessel_fn): BesselBasis(r_max=6.0, num_basis=10, trainable=False)
    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=6.0)
  )
  (spherical_harmonics): SphericalHarmonics()
  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-3.6672, -1.3321, -3.4821, -4.7367, -7.7249, -8.4056, -7.3601, -7.2846, -4.8965, 0.0000, -2.7594, -2.8140, -4.8469, -7.6948, -6.9633, -4.6726, -2.8117, -0.0626, -2.6176, -5.3905, -7.8858, -10.2684, -8.6651, -9.2331, -8.3050, -7.0490, -5.5774, -5.1727, -3.2521, -1.2902, -3.5271, -4.7085, -3.9765, -3.8862, -2.5185, 6.7669, -2.5635, -4.9380, -10.1498, -11.8469, -12.1389, -8.7917, -8.7869, -7.7809, -6.8500, -4.8910, -2.0634, -0.6396, -2.7887, -3.8186, -3.5871, -2.8804, -1.6356, 9.8467, -2.7653, -4.9910, -8.9337, -8.7356, -8.0190, -8.2515, -7.5917, -8.1697, -13.592

The model has several groups of modules:
- `node_embedding`, which is a learned representation of individual atoms (nodes) in the graph
- `radial_embedding`, which expands the distance between two neighbors in a basis of Bessel functions (see [the DimeNet paper](https://arxiv.org/pdf/2003.03123.pdf) on which the approach is based) and applies a polynomial cutoff function that goes smoothly to zero at `r_max`
- `spherical_harmonics`, which contains spherical harmonics (orthogonal functions in spherical coordinates) that encode the direction between two neighbors
- `atomic_energies_fn`, which stores a fixed reference energy for each atom type; these are used to compute the energy of an arrangement of atoms (molecule or crystal)
- `interactions`, which contains two `RealAgnosticResidualInteractionBlock` modules. These modules perform the graph convolutions.
- `products`, which contains two `EquivariantProductBasisBlock` modules that perform tensor products of the learned representations of the atoms, which are described in more detail in the appendix of the MACE paper
- `readouts`, which convert these tensor products into our energy and force predictions.
- `scale_shift`, which applies a constant scale and shift to the energy predictions.
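
As a rough illustration of what the `radial_embedding` computes, the sketch below evaluates a DimeNet-style Bessel basis and a polynomial cutoff for a single distance, using the `r_max=6.0`, `num_basis=10`, and `p=5.0` values printed above. The function names are hypothetical, and the formulas are assumed to match the MACE blocks; check the MACE source if you need the exact implementation.

```python
import math

R_MAX = 6.0  # cutoff radius from the printed model (r_max=6.0)
NUM_BASIS = 10  # number of Bessel functions (num_basis=10)
P = 5  # exponent of the polynomial cutoff (p=5.0)


def bessel_basis(r, r_max=R_MAX, num_basis=NUM_BASIS):
    """Evaluate sqrt(2 / r_max) * sin(n * pi * r / r_max) / r for n = 1..num_basis."""
    prefactor = math.sqrt(2.0 / r_max)
    return [
        prefactor * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_basis + 1)
    ]


def polynomial_cutoff(r, r_max=R_MAX, p=P):
    """Smooth envelope that equals 1 at r = 0 and 0 at r >= r_max."""
    x = r / r_max
    if x >= 1.0:
        return 0.0
    return (
        1.0
        - 0.5 * (p + 1) * (p + 2) * x**p
        + p * (p + 2) * x ** (p + 1)
        - 0.5 * p * (p + 1) * x ** (p + 2)
    )


distance = 3.0  # an example interatomic distance in Angstrom
features = [b * polynomial_cutoff(distance) for b in bessel_basis(distance)]
print(len(features), polynomial_cutoff(distance))
```

Each neighbor pair is therefore described by `num_basis` radial features that fade out smoothly as the distance approaches the cutoff, so atoms entering or leaving the neighbor list do not cause jumps in the energy.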

In the next section, we'll go over how to freeze a subset of layers and train the remaining layers on a new dataset.

### Grabbing some data for fine-tuning

In [71]:
zeolite_data = Dataset.from_file("data/chabazite.pth.tar")
train, val, test = split_train_validation_test(
    zeolite_data, val_size=0.1, test_size=0.1
)

In [72]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, collate_fn=collate_dicts)

### Evaluate the model on the test data

Let's begin by assessing how well the model predicts energies and forces for these data before we do any fine-tuning.

In [73]:
def plot_hexbin(pred, targ, ax, key, scale="log", units: dict = UNITS):
    mae = mean_absolute_error(targ, pred)

    if scale == "log":
        pred = np.abs(pred) + 1e-8
        targ = np.abs(targ) + 1e-8

    lim_min = min(np.min(pred), np.min(targ)) * 1.1
    lim_max = max(np.max(pred), np.max(targ)) * 1.1

    extent = [lim_min, lim_max, lim_min, lim_max]

    hb = ax.hexbin(
        pred,
        targ,
        cmap="viridis",
        gridsize=60,
        bins="log",
        mincnt=1,
        edgecolors=None,
        linewidths=(0.1,),
        xscale=scale,
        yscale=scale,
        extent=extent,
    )

    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.set_aspect("equal")

    ax.plot(
        (lim_min, lim_max),
        (lim_min, lim_max),
        color="#000000",
        zorder=-1,
        linewidth=0.5,
    )

    ax.set_xlabel("predicted / %s" % (units[key]), fontsize=12, fontweight="bold")
    ax.set_ylabel("target / %s" % (units[key]), fontsize=12, fontweight="bold")

    ax.annotate(
        "MAE: %.3f %s" % (mae, units[key]),
        (0.03, 0.95),
        xycoords="axes fraction",
        fontsize=12,
        fontweight="bold",
        fontstyle="italic",
    )

    return ax, hb


def stack_cat(item):
    try:
        out = torch.stack(item, dim=0)
    except RuntimeError:
        out = torch.cat(item, dim=0)
    return out
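
The fallback in `stack_cat` matters because per-structure outputs do not always share a shape. The sketch below repeats the same logic under a hypothetical name so that it runs on its own, using made-up tensors: energies have one value per structure and can be stacked, while force arrays have one row per atom and must be concatenated when structures differ in size.

```python
import torch


def stack_or_cat(items):
    # Same logic as stack_cat above, repeated so this example is self-contained
    try:
        return torch.stack(items, dim=0)
    except RuntimeError:
        return torch.cat(items, dim=0)


# Two structures: one energy each, but 3 and 5 atoms respectively
energies = [torch.tensor([-1.0]), torch.tensor([-2.5])]
forces = [torch.zeros(3, 3), torch.zeros(5, 3)]

print(stack_or_cat(energies).shape)  # torch.Size([2, 1])
print(stack_or_cat(forces).shape)  # torch.Size([8, 3])
```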

In [74]:
# loss_fn = loss.build_mse_loss(loss_coef={"energy": 0.01, "energy_grad": 1})
loss_fn = loss.build_mse_loss(loss_coef={"energy": 1})

In [75]:
# results, targets, val_loss = evaluate(mace_model, test_loader, loss_fn, device=DEVICE)

In [76]:
# fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6))

# for ax, key in zip(ax_fig, UNITS.keys()):
#     pred = stack_cat(results[key]).detach().cpu().numpy().reshape(-1)
#     targ = stack_cat(targets[key]).detach().cpu().numpy().reshape(-1)

#     plot_hexbin(pred, targ, ax, key, scale="linear")

#     ax.set_title("%s: %s" % ("MACE", key.upper()), fontsize=14)

# plt.show()

The out-of-the-box performance on these data is pretty bad, but that's the whole reason we're fine-tuning the model. We'll tackle the fine-tuning in the next section.

### Doing the actual fine-tuning

We will freeze the early layers that comprise the learned representations and message passing: `node_embedding`, `interactions`, and `products` (the `radial_embedding`, `spherical_harmonics`, and `atomic_energies_fn` have no learned parameters). Then, we'll train the remaining layers (`readouts` and `scale_shift`) on new data.

In [77]:
freeze_layers = [
    mace_model.node_embedding,
    mace_model.interactions,
    mace_model.products,
]

for layer in freeze_layers:
    num_params = sum(p.numel() for p in layer.parameters())
    print(
        "Freezing params in layer: ",
        str(layer),
        " with ",
        num_params,
        " parameters.",
    )
    for param in layer.parameters():
        param.requires_grad = False

Freezing params in layer:  LinearNodeEmbeddingBlock(
  (linear): Linear(89x0e -> 128x0e | 11392 weights)
)  with  11392  parameters.
Freezing params in layer:  ModuleList(
  (0): RealAgnosticResidualInteractionBlock(
    (linear_up): Linear(128x0e -> 128x0e | 16384 weights)
    (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)
    (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]
    (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)
    (skip_tp): FullyConnectedTensorProduct(128x0e x 89x0e -> 128x0e+128x1o | 1458176 paths | 1458176 weights)
    (reshape): reshape_irreps()
  )
  (1): RealAgnosticResidualInteractionBlock(
    (linear_up): Linear(128x0e+128x1o -> 128x0e+128x1o | 32768 weights)
    (conv_tp): TensorProduct(128x0e+128x1o x 1x0e+1x1o+1x2e+1x3o -> 256x0e+384x1o+384x2e+256x3o | 1280 paths | 1280 weights)
    (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 1280]

First, we will need to set up all the parameters for training. We'll make variables to contain our training metrics, hooks for training, and the optimizer. Then, we'll train the model on the new data.

In [78]:
train_metrics = [
    metrics.MeanAbsoluteError("energy"),
    metrics.MeanAbsoluteError("energy_grad"),
]

In [79]:
trainable_params = filter(lambda p: p.requires_grad, mace_model.parameters())
optimizer = Adam(trainable_params, lr=3e-4)
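
A quick way to sanity-check this setup is to confirm that an optimizer step changes only the unfrozen weights. The sketch below does this on a hypothetical toy network with random data (not MACE and not the zeolite dataset), using the same pattern of filtering on `requires_grad` before building the optimizer.

```python
import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)

# Hypothetical toy network: model[0] plays the frozen part, model[2] the readout
model = nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 1))
for param in model[0].parameters():
    param.requires_grad = False

frozen_before = model[0].weight.clone()
readout_before = model[2].weight.clone()

# A list (rather than a one-shot filter object) can be inspected and reused
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=3e-4)

# One training step on random inputs and targets
x, y = torch.randn(16, 4), torch.randn(16, 1)
loss_value = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss_value.backward()
optimizer.step()

frozen_unchanged = torch.equal(model[0].weight, frozen_before)
readout_changed = not torch.equal(model[2].weight, readout_before)
print(frozen_unchanged, readout_changed)
```

The same comparison can be applied to any frozen block of the real model by cloning one of its parameters before training and comparing afterwards.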

In [80]:
# train_hooks = [
#     hooks.MaxEpochHook(10),
#     hooks.CSVHook(
#         OUTDIR_MACE,
#         metrics=train_metrics,
#     ),
#     hooks.PrintingHook(
#         OUTDIR_MACE,
#         metrics=train_metrics,
#         separator=" | ",
#         time_strf="%M:%S",
#         log_memory=False,
#     ),
#     hooks.ReduceLROnPlateauHook(
#         optimizer=optimizer,
#         patience=30,
#         factor=0.5,
#         min_lr=1e-7,
#         window_length=1,
#         stop_after_min=True,
#     ),
# ]

The model was already loaded through the `DirectNffScaleMACEWrapper` above, so it can interface with the `nff` package. The wrapper has a `forward` method that takes inputs as `AtomsBatch` objects, translates them to `torch_geometric` objects, and passes them through the model. We can now hand the model, loss function, and optimizer to the `Trainer`.

In [81]:
print(loss_fn)

<function build_general_loss.<locals>.loss_fn at 0x7fb59c0824c0>


In [82]:
T = Trainer(
    model_path=OUTDIR_MACE,
    model=mace_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    # hooks=train_hooks,
    # retain_graph=True,
)

In [83]:
T.train(device=DEVICE, n_epochs=10)

  0%|          | 0/1681 [00:00<?, ?it/s]


TypeError: cutoff of invalid type <class 'torch.Tensor'>

## CHGNet

Here, we're going to load the NFF wrapper around the pre-trained CHGNet architecture and try training it the same way as we did with MACE above.

In [84]:
chgnet_nff = CHGNetNFF.load("0.3.0")

CHGNet v0.3.0 initialized with 412,525 parameters


In [85]:
chgnet_nff.modules

<bound method Module.modules of CHGNetNFF(
  (composition_model): AtomRef(
    (fc): Linear(in_features=94, out_features=1, bias=False)
  )
  (graph_converter): CrystalGraphConverter(algorithm='fast', atom_graph_cutoff=6, bond_graph_cutoff=3)
  (atom_embedding): AtomEmbedding(
    (embedding): Embedding(94, 64)
  )
  (bond_basis_expansion): BondEncoder(
    (rbf_expansion_ag): RadialBessel(
      (smooth_cutoff): CutoffPolynomial()
    )
    (rbf_expansion_bg): RadialBessel(
      (smooth_cutoff): CutoffPolynomial()
    )
  )
  (bond_embedding): Linear(in_features=31, out_features=64, bias=False)
  (bond_weights_ag): Linear(in_features=31, out_features=64, bias=False)
  (bond_weights_bg): Linear(in_features=31, out_features=64, bias=False)
  (angle_basis_expansion): AngleEncoder(
    (fourier_expansion): Fourier()
  )
  (angle_embedding): Linear(in_features=31, out_features=64, bias=False)
  (atom_conv_layers): ModuleList(
    (0-3): 4 x AtomConv(
      (activation): SiLU()
      (twoB

In [86]:
for layer in [
    chgnet_nff.atom_embedding,
    chgnet_nff.bond_embedding,
    chgnet_nff.angle_embedding,
    chgnet_nff.bond_basis_expansion,
    chgnet_nff.angle_basis_expansion,
    chgnet_nff.atom_conv_layers[:-1],
    chgnet_nff.bond_conv_layers,
    chgnet_nff.angle_layers,
]:
    for param in layer.parameters():
        param.requires_grad = False

In [87]:
loss_fn = loss.build_mse_loss(loss_coef={"energy": 0.01, "energy_grad": 1})
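
The `loss_coef` dictionary sets the relative weight of each target in the total loss. The sketch below shows the idea with a hypothetical `weighted_mse_loss` function and made-up tensors; it is not the `nff` implementation, which may normalize the terms differently.

```python
import torch


def weighted_mse_loss(results, targets, loss_coef):
    """Sum the per-key mean squared errors, each scaled by its coefficient."""
    total = torch.tensor(0.0)
    for key, coef in loss_coef.items():
        diff = results[key] - targets[key]
        total = total + coef * torch.mean(diff**2)
    return total


# Made-up predictions and targets for two structures and four atoms
targets = {"energy": torch.tensor([1.0, 2.0]), "energy_grad": torch.zeros(4, 3)}
results = {"energy": torch.tensor([2.0, 4.0]), "energy_grad": torch.ones(4, 3)}

value = weighted_mse_loss(results, targets, {"energy": 0.01, "energy_grad": 1})
print(value.item())
```

Here the energy error (mean squared error of 2.5) contributes only 0.025 to the total, while the gradient error contributes 1.0, so the optimizer is pushed mostly toward matching the energy gradients.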

In [88]:
# Collect the trainable parameters of the CHGNet model (not the MACE model)
trainable_params = filter(lambda p: p.requires_grad, chgnet_nff.parameters())
optimizer = Adam(trainable_params, lr=3e-4)

In [89]:
T = Trainer(
    model_path=OUTDIR_CHGNET,
    model=chgnet_nff,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
)

In [91]:
T.train(device=DEVICE, n_epochs=10)

  0%|          | 0/1681 [00:00<?, ?it/s]

1 structures imported


  0%|          | 0/1681 [00:00<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'atomic_number'