# Introduction 

This notebook demonstrates the fitting of a TensorNet UMLIP using the MatPES v2025.1 PBE dataset. Fitting of other architectures in MatGL with either the PBE or r2SCAN datasets is similar.

In [None]:
from __future__ import annotations

import json
import os
import shutil
from functools import partial

import lightning as pl
import numpy as np
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from dgl.data.utils import split_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
from matgl.config import DEFAULT_ELEMENTS
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_pes
from matgl.models import TensorNet
from matgl.utils.training import PotentialLightningModule, xavier_init
from monty.io import zopen
from pymatgen.core import Structure
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from matpes.data import get_data

In [None]:
# Fetch the MatPES PBE dataset; `download_atoms=True` additionally fetches the
# companion atoms file (MatPES-PBE-atoms.json.gz) that is read in the next step.
data = get_data(
    "PBE",
    download_atoms=True,
)

Downloading from https://s3.us-east-1.amazonaws.com/materialsproject-contribs/MatPES_2025_1/MatPES-PBE-2025.1.json.gz...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 431M/431M [01:39<00:00, 4.33MB/s]


Downloading from https://s3.us-east-1.amazonaws.com/materialsproject-contribs/MatPES_2025_1/MatPES-PBE-atoms.json.gz...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8.39k/8.39k [00:00<00:00, 871kB/s]


We also need to load the isolated-atom energies, which serve as the per-element zero-energy reference during fitting.

In [None]:
# Read the isolated-atom reference records from the downloaded atoms file.
with zopen("MatPES-PBE-atoms.json.gz", "rt") as f:
    atom_records = json.load(f)
# `formula_pretty` labels some elements with a trailing count (e.g. "O2", "N2"),
# which makes a later lookup by bare element symbol fail with a KeyError.
# Strip the trailing digits so every key is a plain element symbol.
# NOTE(review): the stored energy is presumably still that of the single isolated
# atom (only the label carries the "2") -- confirm against the MatPES data docs.
isolated_energies_pbe = {d["formula_pretty"].rstrip("0123456789"): d["energy"] for d in atom_records}
isolated_energies_pbe

{'Zn': -0.01098351,
 'Y': -2.25679622,
 'Tl': -0.17939998,
 'Ti': -2.40532262,
 'Sr': -0.02823145,
 'Ta': -3.5659314,
 'Te': -0.6573123,
 'V': -3.61232779,
 'W': -4.57101127,
 'Si': -0.82583191,
 'Sn': -0.67963499,
 'Se': -0.78345919,
 'Sb': -1.4302063,
 'Sc': -2.12966897,
 'Zr': -2.23742918,
 'Rh': -1.44016062,
 'Re': -4.63436797,
 'O2': -1.54690765,
 'Tc': -3.40289704,
 'Os': -2.88280809,
 'Rb': -0.16194042,
 'P': -1.88719667,
 'S': -0.89091719,
 'Th': -1.04419147,
 'Ni': -0.28412403,
 'Pb': -0.63069886,
 'Pd': -1.47521138,
 'Pt': -0.50244445,
 'Ru': -1.68884293,
 'Pa': -2.03239022,
 'Pu': -10.39244586,
 'Ne': -0.01216023,
 'Na': -0.22858276,
 'Nb': -2.53481909,
 'N2': -3.12555634,
 'Mn': -5.14592659,
 'Mg': -0.00994627,
 'U': -4.6443113,
 'Li': -0.29734917,
 'Np': -7.30273499,
 'Lu': -0.25255978,
 'Mo': -4.60213279,
 'Kr': -0.02265396,
 'In': -0.19672488,
 'I': -0.18858477,
 'He': -0.00045595,
 'Hf': -3.49292389,
 'Hg': -0.0105212,
 'K': -0.17827125,
 'Ir': -1.42793567,
 'Ge': -0.77

In [None]:
# Gather the structures together with their energy, force and stress labels.
structures, energies, forces, stresses = [], [], [], []
for entry in tqdm(data):
    structures.append(Structure.from_dict(entry["structure"]))
    energies.append(entry["energy"])
    forces.append(entry["forces"])
    # Expand the 6-component Voigt stress to a full 3x3 matrix and rescale by -0.1.
    # NOTE(review): the -0.1 factor presumably converts kBar to GPa and flips the
    # sign convention -- confirm against the dataset documentation.
    stress_matrix = voigt_6_to_full_3x3_stress(np.array(entry["stress"])) * -0.1
    # Store plain nested lists rather than numpy arrays.
    stresses.append(stress_matrix.tolist())

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 434712/434712 [00:57<00:00, 7595.45it/s]


# Loading the data into the MGLDataset

In [None]:
# Structure -> graph converter for periodic systems.
# `element_types` is kept as its own name because the model setup reuses it.
element_types = DEFAULT_ELEMENTS
cry_graph = Structure2Graph(
    cutoff=5.0,
    element_types=element_types,
)

In [None]:
# Bundle every PES property under the label name it is passed to MGLDataset with.
labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

# Build the dataset from the structures, using the converter defined above.
dataset = MGLDataset(
    structures=structures,
    converter=cry_graph,
    labels=labels,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 434712/434712 [03:08<00:00, 2301.06it/s]


In [None]:
# Shuffle and split 90 % / 5 % / 5 % into training, validation and test subsets.
# The fixed random_state keeps the split reproducible across runs.
split_fractions = [0.9, 0.05, 0.05]
training_set, validation_set, test_set = split_dataset(
    dataset,
    split_fractions,
    random_state=42,
    shuffle=True,
)

# Collate function for PES batches: stresses included, no line graph.
collate_fn = partial(
    collate_fn_pes,
    include_line_graph=False,
    include_stress=True,
)

# Only the training and validation loaders are created here.
loaders = MGLDataLoader(
    train_data=training_set,
    val_data=validation_set,
    collate_fn=collate_fn,
    batch_size=32,
    num_workers=0,
)
train_loader, val_loader = loaders

# Model Setup

Here, we demonstrate the initialization of the TensorNet architecture. You may use any of the other architectures implemented in MatGL.

In [None]:
# TensorNet configured for the same element set as the graph converter.
model = TensorNet(
    units=128,
    rbf_type="SphericalBessel",
    use_smooth=True,
    is_intensive=False,
    element_types=element_types,
)

In [None]:
# Calculate the scaling factor for training: the root-mean-square of the
# training-set forces.
# Only the force labels are needed, so nothing else is collected. The tensor is
# kept under its own name so the notebook-level `energies` / `forces` lists built
# earlier are not overwritten, and re-running the dataset cell stays safe.
train_forces = torch.concatenate([lbs["forces"] for _g, _lat, _attrs, lbs in training_set])
# Sum the squared components along dim=1, average over rows, then take the root.
# NOTE(review): assumes each force tensor is (n_atoms, 3) -- confirm.
rms_forces = torch.sqrt(torch.mean(torch.sum(train_forces**2, dim=1)))

# weight initialization
xavier_init(model)

# set up the optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-3, weight_decay=1.0e-5, amsgrad=True)
# Cosine decay from the initial learning rate down to 1 % of it (1.0e-2 * 1.0e-3)
# over T_max = 1000 * 10 scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=1000 * 10, eta_min=1.0e-2 * 1.0e-3)

# Setup the potential lightning module

Note that `max_epochs` is set to 2 here for demonstration purposes only. In a real fit, this number should be much larger (probably > 1000).

In [None]:
# Element reference energies, ordered like DEFAULT_ELEMENTS. Each symbol is looked
# up in `isolated_energies_pbe`; a missing symbol raises KeyError.
reference_energies = [isolated_energies_pbe[symbol] for symbol in DEFAULT_ELEMENTS]
energies_offsets = np.array(reference_energies)

# Wrap the model, optimizer and scheduler in the potential Lightning module.
lit_model = PotentialLightningModule(
    model=model,
    element_refs=energies_offsets,
    data_std=rms_forces,
    optimizer=optimizer,
    scheduler=scheduler,
    loss="l1_loss",
    stress_weight=0.1,
    include_line_graph=False,
)

# Log metrics as CSV under the current working directory.
path = os.getcwd()
logger = CSVLogger(save_dir=path)

# Keep only the single checkpoint with the lowest validation total loss.
checkpoint_callback = ModelCheckpoint(
    filename="{epoch:04d}-best_model",
    monitor="val_Total_Loss",
    mode="min",
    save_top_k=1,
)

# Stop early once the validation total loss has not improved for 200 checks.
early_stopping = EarlyStopping(monitor="val_Total_Loss", mode="min", patience=200)

# Assemble the trainer.
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",  # you can use gpu instead
    logger=logger,
    callbacks=[early_stopping, checkpoint_callback],
    gradient_clip_val=2.0,
    accumulate_grad_batches=4,
    profiler="simple",
)

KeyError: 'H'

# Run the fit

In [None]:
# Train on the training loader, validating against the validation loader.
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Save the model held by the Lightning module to the export directory.
model_export_path = "./trained_model/"
trained_model = lit_model.model
trained_model.save(model_export_path)

# Cleanup

In [None]:
# Housekeeping only: delete the directories this notebook leaves behind.
for leftover_dir in ("MGLDataset", "trained_model"):
    shutil.rmtree(leftover_dir)