# Introduction 

This notebook demonstrates how to fit a TensorNet universal machine learning interatomic potential (UMLIP) on the MatPES v2025.1 PBE dataset. Fitting the other architectures in MatGL, or fitting on the r2SCAN dataset, follows the same steps.

**Important Note**: The training data sizes and maximum number of epochs in this notebook are deliberately small so that it runs in a reasonably short time on a single CPU (~5-10 mins) for demonstration purposes. **The resulting model is not expected to be production quality**. To train a real model, use the entire dataset and a much larger number of epochs.

In [24]:
from __future__ import annotations

import collections
import json
import os
import shutil
from functools import partial

import lightning as pl
import numpy as np
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from dgl.data.utils import split_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
from matgl.config import DEFAULT_ELEMENTS
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_pes
from matgl.models import TensorNet
from matgl.utils.training import PotentialLightningModule, xavier_init
from monty.io import zopen
from pymatgen.core import Structure
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from matpes.data import get_data

In [2]:
data = get_data("PBE", download_atoms=True)

Next, we load the energies of the isolated atoms, which serve as the per-element zero reference for the total energy.

In [3]:
with zopen("MatPES-PBE-atoms.json.gz", "rt") as f:
    isolated_energies_pbe = json.load(f)
isolated_energies_pbe = {d["elements"][0]: d["energy"] for d in isolated_energies_pbe}
len(isolated_energies_pbe)

89
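The isolated-atom energies act as per-element offsets: roughly speaking, the network only has to learn what remains after the free-atom energies of all constituent atoms are subtracted from the total energy. The sketch below illustrates that bookkeeping with made-up numbers; `reference_energy` and all energies are hypothetical, and it does not reproduce how `PotentialLightningModule` applies `element_refs` internally.

```python
from collections import Counter


def reference_energy(species, atom_energies):
    """Sum of isolated-atom energies over every atom in a structure."""
    counts = Counter(species)
    return sum(n * atom_energies[el] for el, n in counts.items())


# Hypothetical isolated-atom energies in eV/atom (NOT the MatPES values).
atom_energies = {"Li": -0.3, "O": -1.5}
species = ["Li", "Li", "O"]  # e.g. one Li2O formula unit
total_energy = -14.0         # hypothetical DFT total energy in eV

e_ref = reference_energy(species, atom_energies)
residual = total_energy - e_ref  # the part left for the network to learn
print(e_ref, residual)
```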

In [4]:
# Initialize the lists for storing structures with energies, forces, and optional stresses
structures = []
labels = collections.defaultdict(list)

for d in tqdm(data):
    structures.append(Structure.from_dict(d["structure"]))
    labels["energies"].append(d["energy"])
    labels["forces"].append(d["forces"])
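    # Expand the 6-component Voigt stress to a full 3x3 tensor; the -0.1 factor converts
    # kBar to GPa (assuming VASP-style kBar values) and flips the sign convention.
    labels["stresses"].append((voigt_6_to_full_3x3_stress(np.array(d["stress"])) * -0.1).tolist())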

100%|██████████| 434712/434712 [00:54<00:00, 8021.74it/s]
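The stress line in the cell above expands a 6-component Voigt vector into a full 3×3 tensor and multiplies it by -0.1. Assuming the MatPES stresses are VASP-style values in kBar, the factor 0.1 converts kBar to GPa, and the minus sign presumably switches to the sign convention MatGL expects. The pure-Python sketch below reproduces that transformation using ASE's Voigt component order (xx, yy, zz, yz, xz, xy); `voigt_to_full` and `kbar_to_gpa_flipped` are hypothetical stand-ins, not library functions.

```python
def voigt_to_full(voigt):
    """Expand (xx, yy, zz, yz, xz, xy) into a symmetric 3x3 nested list."""
    xx, yy, zz, yz, xz, xy = voigt
    return [[xx, xy, xz],
            [xy, yy, yz],
            [xz, yz, zz]]


def kbar_to_gpa_flipped(voigt_kbar):
    """Mirror the notebook's `* -0.1`: kBar -> GPa with the sign reversed."""
    return [[-0.1 * v for v in row] for row in voigt_to_full(voigt_kbar)]


# Hypothetical Voigt stress in kBar: (xx, yy, zz, yz, xz, xy)
stress = kbar_to_gpa_flipped([10.0, 20.0, 30.0, 1.0, 2.0, 3.0])
print(stress)
```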


# Loading the data into the MGLDataset

In [5]:
# define the graph converter for periodic systems
element_types = DEFAULT_ELEMENTS
cry_graph = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(structures=structures, converter=cry_graph, labels=labels)

100%|██████████| 434712/434712 [03:07<00:00, 2323.23it/s]
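`Structure2Graph` turns each structure into a graph whose edges connect pairs of atoms closer than `cutoff` (5.0 Å here), including periodic images. The brute-force sketch below counts the neighbors of the single atom in a hypothetical simple-cubic crystal with a 3 Å lattice constant. It only illustrates the idea of a periodic radius graph; it is not how MatGL builds its graphs.

```python
import itertools
import math


def periodic_neighbors(lattice_constant, cutoff, max_image=2):
    """Distances from the atom at the origin to its periodic images within cutoff.

    Assumes a simple-cubic cell holding one atom, so every neighbor is an image
    of that atom. max_image * lattice_constant must be >= cutoff for completeness.
    """
    distances = []
    for i, j, k in itertools.product(range(-max_image, max_image + 1), repeat=3):
        if (i, j, k) == (0, 0, 0):
            continue  # skip the atom itself
        d = lattice_constant * math.sqrt(i * i + j * j + k * k)
        if d < cutoff:
            distances.append(d)
    return sorted(distances)


neighbors = periodic_neighbors(lattice_constant=3.0, cutoff=5.0)
print(len(neighbors))  # 18 edges: 6 at 3.0 Å and 12 at ~4.24 Å
```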


# Data split

For demonstration purposes, we use only 2% of the data for training and 0.1% for validation; the remaining 97.9% is left unused. In a real training run, use a split such as 90%:10% or 95%:5%.

In [19]:
training_set, validation_set, _ = split_dataset(dataset, [0.02, 0.001, 0.979], random_state=42, shuffle=True)
# define the proper collate function for MGLDataLoader
collate_fn = partial(collate_fn_pes, include_line_graph=False, include_stress=True)
# initialize dataloader for training and validation
train_loader, val_loader = MGLDataLoader(
    train_data=training_set,
    val_data=validation_set,
    collate_fn=collate_fn,
    batch_size=32,
    num_workers=0,
)
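A fractional split like the one above shuffles the indices once and cuts them into consecutive chunks. The sketch below shows the idea with Python's `random` module; `fractional_split` is a hypothetical helper, its RNG differs from DGL's so the actual membership will differ, and truncating each size to an integer with the last chunk absorbing the remainder is our assumption about how `split_dataset` sizes the subsets.

```python
import random


def fractional_split(n_items, fractions, seed=42):
    """Shuffle indices and cut them into consecutive chunks by fraction.

    Chunk sizes are truncated to integers; the last chunk absorbs the remainder.
    """
    sizes = [int(n_items * f) for f in fractions[:-1]]
    sizes.append(n_items - sum(sizes))
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    chunks, start = [], 0
    for size in sizes:
        chunks.append(indices[start:start + size])
        start += size
    return chunks


train_idx, val_idx, unused_idx = fractional_split(434712, [0.02, 0.001, 0.979])
print(len(train_idx), len(val_idx), len(unused_idx))
```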

# Model Setup

Here, we demonstrate the initialization of the TensorNet architecture. You may use any of the other architectures implemented in MatGL.

In [20]:
model = TensorNet(
    element_types=element_types,
    is_intensive=False,
    rbf_type="SphericalBessel",
    use_smooth=True,
    units=128,
)
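`rbf_type="SphericalBessel"` expands each interatomic distance in a radial basis before it enters the network. As a rough illustration, the sketch below evaluates the simplest zeroth-order spherical Bessel basis, e_n(r) = sqrt(2/r_c) · sin(nπr/r_c) / r, which vanishes at the cutoff. This is an assumption for illustration only: with `use_smooth=True`, MatGL uses (as we understand it) a smoother variant that is not reproduced here, and `bessel_basis` is a hypothetical helper.

```python
import math


def bessel_basis(r, cutoff, n_basis):
    """Zeroth-order spherical Bessel radial basis: sqrt(2/rc) * sin(n*pi*r/rc) / r."""
    prefactor = math.sqrt(2.0 / cutoff)
    return [prefactor * math.sin(n * math.pi * r / cutoff) / r for n in range(1, n_basis + 1)]


features = bessel_basis(r=2.5, cutoff=5.0, n_basis=4)
at_cutoff = bessel_basis(r=5.0, cutoff=5.0, n_basis=4)  # all ~0 because sin(n*pi) = 0
print(features)
```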

In [21]:
# compute the RMS of per-atom force magnitudes in the training set; it is passed as data_std below
forces = []
for _g, _lat, _attrs, lbs in training_set:
    forces.append(lbs["forces"])
forces = torch.concatenate(forces)
rms_forces = torch.sqrt(torch.mean(torch.sum(forces**2, dim=1)))
# weight initialization
xavier_init(model)
# set up the optimizer and a cosine learning-rate schedule that decays from 1e-3 to 1e-5 over T_max steps
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-3, weight_decay=1.0e-5, amsgrad=True)
scheduler = CosineAnnealingLR(optimizer, T_max=1000 * 10, eta_min=1.0e-2 * 1.0e-3)
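The cell above computes two quantities worth unpacking. `rms_forces` is the root-mean-square of the per-atom force magnitudes and is used as `data_std`, and `CosineAnnealingLR` decays the learning rate from 1.0e-3 toward `eta_min = 1.0e-5` over `T_max = 10,000` scheduler steps. The pure-Python sketch below reproduces both formulas on made-up forces; `rms_force` and `cosine_lr` are hypothetical helpers, the cosine expression is the closed form given in the PyTorch documentation, and whether the scheduler steps per batch or per epoch depends on how `PotentialLightningModule` configures it, which is not shown here.

```python
import math


def rms_force(per_structure_forces):
    """RMS of per-atom force magnitudes: sqrt(mean_i |F_i|^2)."""
    squared = [fx * fx + fy * fy + fz * fz
               for forces in per_structure_forces for fx, fy, fz in forces]
    return math.sqrt(sum(squared) / len(squared))


def cosine_lr(step, lr_max=1.0e-3, lr_min=1.0e-5, t_max=10_000):
    """Closed-form cosine annealing from lr_max down to lr_min over t_max steps."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * step / t_max))


# Hypothetical forces (eV/Å) for two one-atom structures.
forces = [[[3.0, 4.0, 0.0]], [[0.0, 0.0, 0.0]]]
print(rms_force(forces))  # sqrt((25 + 0) / 2)
print(cosine_lr(0), cosine_lr(5_000), cosine_lr(10_000))
```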

# Set up the potential Lightning module

Note that `max_epochs` is set to 2 here for demonstration purposes. In a real fit, this number should be much larger (likely > 1000).

In [22]:
# setup element_refs
energies_offsets = np.array([isolated_energies_pbe[element] for element in DEFAULT_ELEMENTS])
# initialize the potential lightning module
lit_model = PotentialLightningModule(
    model=model,
    element_refs=energies_offsets,
    data_std=rms_forces,
    optimizer=optimizer,
    scheduler=scheduler,
    loss="l1_loss",
    stress_weight=0.1,
    include_line_graph=False,
)
# setup loggers
path = os.getcwd()
logger = CSVLogger(save_dir=path)
# setup checkpoints
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_Total_Loss",
    mode="min",
    filename="{epoch:04d}-best_model",
)
# setup trainer
trainer = pl.Trainer(
    logger=logger,
    callbacks=[EarlyStopping(monitor="val_Total_Loss", mode="min", patience=200), checkpoint_callback],
    max_epochs=2,
    accelerator="cpu",  # you can use gpu instead
    gradient_clip_val=2.0,
    accumulate_grad_batches=4,
    profiler="simple",
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
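With `batch_size=32` and `accumulate_grad_batches=4`, gradients from four mini-batches are accumulated before each optimizer step, so the effective batch size is 128 structures. The sketch below estimates optimizer steps per epoch; `optimizer_steps_per_epoch` is a hypothetical helper, the training-set size of 8,694 assumes 2% of 434,712 rounded down, and the ceiling divisions assume the final partial batch and the final partial accumulation window are both kept.

```python
import math


def optimizer_steps_per_epoch(n_train, batch_size, accumulate):
    """Estimate optimizer steps per epoch when gradients are accumulated."""
    batches = math.ceil(n_train / batch_size)  # last batch may be partial
    return math.ceil(batches / accumulate)     # last accumulation window may be partial


batch_size, accumulate = 32, 4
effective_batch = batch_size * accumulate  # structures contributing to each optimizer step
steps = optimizer_steps_per_epoch(8694, batch_size, accumulate)
print(effective_batch, steps)
```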


# Run the fit

In [23]:
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


  | Name  | Type              | Params | Mode 
----------------------------------------------------
0 | mae   | MeanAbsoluteError | 0      | train
1 | rmse  | MeanSquaredError  | 0      | train
2 | model | Potential         | 837 K  | train
----------------------------------------------------
837 K     Trainable params
0         Non-trainable params
837 K     Total params
3.352     Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode



`Trainer.fit` stopped: `max_epochs=2` reached.
FIT Profiler Report

------------------------------------------------------------------------------------
|  Action  |  Mean duration (s)  |  Num calls  |  Total time (s)  |  Percentage %  |
------------------------------------------------------------------------------------
|  Total

In [None]:
# save trained model
model_export_path = "./trained_model/"
lit_model.model.save(model_export_path)

# Cleanup

In [None]:
# This code just performs cleanup for this notebook.
shutil.rmtree("MGLDataset")
shutil.rmtree("trained_model")