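This notebook demonstrates how to train a TensorNet interatomic potential on the MatPES-PBE dataset. It expects the file `MatPES-PBE-2025.1.json` to be present in the working directory.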

In [None]:
from __future__ import annotations

import json
import os
import shutil
from functools import partial

import lightning as pl
import numpy as np
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from dgl.data.utils import split_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
from matgl.config import DEFAULT_ELEMENTS
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_pes
from matgl.models import TensorNet
from matgl.utils.training import PotentialLightningModule, xavier_init
from pymatgen.core import Structure
from torch.optim.lr_scheduler import CosineAnnealingLR

In [None]:
# Isolated-atom reference energies computed with PBE, keyed by element symbol.
# They are later turned into the per-element `energies_offsets` array that is passed to
# PotentialLightningModule as `element_refs`, so every element in the model's
# `element_types` must have an entry here (a missing symbol raises KeyError there).
# NOTE(review): values are presumably in eV, consistent with the MatPES energy labels — confirm.
isolated_energies_pbe = {
    "Ac": -0.24210133,
    "Ag": -0.19840574,
    "Al": -0.21672837,
    "Ar": -0.0235315,
    "As": -1.70136472,
    "Au": -0.18479218,
    "B": -0.2911712,
    "Ba": -1.35978407,
    "Be": -0.04262353,
    "Bi": -1.32462383,
    "Br": -0.22687512,
    "C": -1.26281801,
    "Ca": -0.02596217,
    "Cd": -0.01374787,
    "Ce": -1.43642821,
    "Cl": -0.25828681,
    "Co": -1.66614587,
    "Cr": -5.44620624,
    "Cs": -0.13452777,
    "Cu": -0.23745594,
    "Dy": -5.51640166,
    "Er": -3.03880565,
    "Eu": -9.48606277,
    "F": -0.43794547,
    "Fe": -3.30583367,
    "Ga": -0.19854295,
    "Gd": -8.11540027,
    "Ge": -0.77924665,
    "H": -1.11723232,
    "He": -0.00045595,
    "Hf": -3.49292389,
    "Hg": -0.0105212,
    "Ho": -4.30111439,
    "I": -0.18858477,
    "In": -0.19672488,
    "Ir": -1.42793567,
    "K": -0.17827125,
    "Kr": -0.02265396,
    "La": -0.62794477,
    "Li": -0.29734917,
    "Lu": -0.25255978,
    "Mg": -0.00994627,
    "Mn": -5.14592659,
    "Mo": -4.60213279,
    "N": -3.12555634,
    "Na": -0.22858276,
    "Nb": -2.53481909,
    "Nd": -3.6389801,
    "Ne": -0.01216023,
    "Ni": -0.28412403,
    "Np": -7.30273499,
    "O": -1.54690765,
    "Os": -2.88280809,
    "P": -1.88719667,
    "Pa": -2.03239022,
    "Pb": -0.63069886,
    "Pd": -1.47521138,
    "Pm": -5.08859903,
    "Pr": -2.12706895,
    "Pt": -0.50244445,
    "Pu": -10.39244586,
    "Rb": -0.16194042,
    "Re": -4.63436797,
    "Rh": -1.44016062,
    "Ru": -1.68884293,
    "S": -0.89091719,
    "Sb": -1.4302063,
    "Sc": -2.12966897,
    "Se": -0.78345919,
    "Si": -0.82583191,
    "Sm": -6.9970228,
    "Sn": -0.67963499,
    "Sr": -0.02823145,
    "Ta": -3.5659314,
    "Tb": -6.45686224,
    "Tc": -3.40289704,
    "Te": -0.6573123,
    "Th": -1.04419147,
    "Ti": -2.40532262,
    "Tl": -0.17939998,
    "Tm": -2.10513872,
    "U": -4.6443113,
    "V": -3.61232779,
    "W": -4.57101127,
    "Xe": -0.01020284,
    "Y": -2.25679622,
    "Yb": -1.84040717,
    "Zn": -0.01098351,
    "Zr": -2.23742918,
}
# Read the MatPES-PBE database. Each record must provide the "structure", "energy",
# "forces" and "stress" keys (all four are required; a missing key raises KeyError).
with open("MatPES-PBE-2025.1.json", encoding="utf-8") as f:
    data = json.load(f)

# Parallel label lists: index i of every list refers to the same MatPES record.
structures = []
energies = []
forces = []
stresses = []
for record in data:
    structures.append(Structure.from_dict(record["structure"]))
    energies.append(record["energy"])
    forces.append(record["forces"])
    # Expand the 6-component Voigt stress into a full 3x3 matrix, rescale by -0.1, and
    # store it as a nested list (the form the other label lists use).
    # NOTE(review): the -0.1 factor presumably converts VASP-style kBar to GPa and flips
    # the sign to the convention matgl expects — confirm against the MatPES documentation.
    stress_3x3 = voigt_6_to_full_3x3_stress(np.array(record["stress"])) * -0.1
    stresses.append(stress_3x3.tolist())

In [None]:
# Graph converter for periodic crystals. The element list chosen here is reused by the
# TensorNet model so that both agree on element ordering.
# cutoff is the neighbor cutoff radius (presumably in Å).
element_types = DEFAULT_ELEMENTS
cry_graph = Structure2Graph(cutoff=5.0, element_types=element_types)

In [None]:
# Gather the per-structure PES labels under the key names the dataset stores them by,
# then convert every structure into a graph with `cry_graph`.
labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}
dataset = MGLDataset(
    structures=structures,
    converter=cry_graph,
    labels=labels,
)

In [None]:
# 90 / 5 / 5 split into training, validation and test subsets; the shuffle is seeded
# so the split is reproducible.
training_set, validation_set, test_set = split_dataset(
    dataset,
    frac_list=[0.9, 0.05, 0.05],
    shuffle=True,
    random_state=42,
)

# Batch graphs together with their labels: stress labels are kept so stresses can be
# trained on, and line graphs are disabled.
collate_fn = partial(collate_fn_pes, include_stress=True, include_line_graph=False)

# Only train and validation data are passed, so the loader helper yields two loaders.
train_loader, val_loader = MGLDataLoader(
    train_data=training_set,
    val_data=validation_set,
    collate_fn=collate_fn,
    batch_size=32,
    num_workers=0,
)

In [None]:
# TensorNet potential. is_intensive=False because the target (total energy) is an
# extensive quantity.
# NOTE(review): the model's cutoff is left at TensorNet's default; confirm it matches the
# 5.0 cutoff given to cry_graph.
model = TensorNet(
    units=128,
    rbf_type="SphericalBessel",
    use_smooth=True,
    is_intensive=False,
    element_types=element_types,
)

In [None]:
# Force scale used to normalize training: the RMS of per-atom force magnitudes over the
# training split. Only the force labels are needed.
# Distinct names are used on purpose. The module-level `energies` / `forces` lists hold
# the full-dataset labels; overwriting them here would silently corrupt `labels` if the
# dataset cell were re-run afterwards.
train_forces = torch.cat([lbs["forces"] for _g, _lat, _attrs, lbs in training_set])
# assumes each lbs["forces"] is (n_atoms, 3) — summing over dim=1 gives |F|^2 per atom
rms_forces = torch.sqrt(torch.mean(torch.sum(train_forces**2, dim=1)))

# Weight initialization
xavier_init(model)

# AdamW (with amsgrad). A cosine schedule decays the LR towards 1% of its initial value.
learning_rate = 1.0e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1.0e-5, amsgrad=True)
# NOTE(review): T_max is counted in scheduler steps. 1000 * 10 exceeds the trainer's
# max_epochs=1000, so if the scheduler is stepped once per epoch the LR never reaches
# eta_min — confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=1000 * 10, eta_min=1.0e-2 * learning_rate)

In [None]:
# Per-element reference energies, ordered exactly like `element_types` (the tuple given
# to TensorNet and Structure2Graph), so offset i corresponds to the model's element i.
# Previously this used DEFAULT_ELEMENTS directly, which silently misaligns if
# `element_types` is changed.
missing_refs = [el for el in element_types if el not in isolated_energies_pbe]
if missing_refs:
    raise KeyError(f"No isolated-atom energy defined for elements: {missing_refs}")
energies_offsets = np.array([isolated_energies_pbe[el] for el in element_types])

# Wrap the model for Lightning training: energy offsets, force scale, L1 loss, and a
# 0.1 weight on the stress term.
lit_model = PotentialLightningModule(
    model=model,
    element_refs=energies_offsets,
    data_std=rms_forces,
    optimizer=optimizer,
    scheduler=scheduler,
    loss="l1_loss",
    stress_weight=0.1,
    include_line_graph=False,
)

# CSV logs are written under the current working directory.
path = os.getcwd()
logger = CSVLogger(save_dir=path)

# Keep only the single best checkpoint, judged by total validation loss.
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_Total_Loss",
    mode="min",
    filename="{epoch:04d}-best_model",
)

# Trainer: stop early after 200 epochs without validation improvement, clip gradients,
# and accumulate gradients over 4 batches per optimizer step.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[EarlyStopping(monitor="val_Total_Loss", mode="min", patience=200), checkpoint_callback],
    max_epochs=1000,
    accelerator="cpu",  # you can use gpu instead
    gradient_clip_val=2.0,
    accumulate_grad_batches=4,
    profiler="simple",
)

# fit!
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Export the trained network itself (the `model` held by the Lightning wrapper)
# to ./trained_model/.
model_export_path = "./trained_model/"
trained_model = lit_model.model
trained_model.save(model_export_path)

In [None]:
# Cleanup for this notebook: remove the directories it writes to the working directory.
# - "MGLDataset": the cached graph dataset.
# - "trained_model": the exported model.
# - "lightning_logs": CSVLogger's default output directory; presumably it also holds the
#   ModelCheckpoint files, since no dirpath is set.
# NOTE: this removes the whole lightning_logs directory, including any earlier runs.
# Directories that do not exist are skipped, so the cell is safe to re-run.
for artifact_dir in ("MGLDataset", "trained_model", "lightning_logs"):
    if os.path.isdir(artifact_dir):
        shutil.rmtree(artifact_dir)