In [26]:
from chgnet.trainer import Trainer

In [27]:
from chgnet.model.model import CHGNet
from typing import Literal
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import os
from datetime import datetime
import torch


class TrainerV2(Trainer):
    def __init__(self, model: CHGNet | None = None, targets: Literal['ef', 'efs', 'efsm'] = "ef", energy_loss_ratio: float = 1, force_loss_ratio: float = 1, stress_loss_ratio: float = 0.1, mag_loss_ratio: float = 0.1, optimizer: str = "Adam", scheduler: str = "CosLR", criterion: str = "MSE", epochs: int = 50, starting_epoch: int = 0, learning_rate: float = 0.001, print_freq: int = 100, torch_seed: int | None = None, data_seed: int | None = None, use_device: str | None = None, **kwargs) -> None:
        super().__init__(model, targets, energy_loss_ratio, force_loss_ratio, stress_loss_ratio, mag_loss_ratio, optimizer, scheduler, criterion, epochs, starting_epoch, learning_rate, print_freq, torch_seed, data_seed, use_device, **kwargs)
        
    def train(self, train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader | None = None, save_dir: str | None = None, save_test_result: bool = False, train_composition_model: bool = False) -> None:
        """Train the model using torch data_loaders.

        Args:
            train_loader (DataLoader): train loader to update CHGNet weights
            val_loader (DataLoader): val loader to test accuracy after each epoch
            test_loader (DataLoader):  test loader to test accuracy at end of training.
                Can be None.
                Default = None
            save_dir (str): the dir name to save the trained weights
                Default = None
            save_test_result (bool): Whether to save the test set prediction in a JSON
                file. Default = False
            train_composition_model (bool): whether to train the composition model
                (AtomRef), this is suggested when the fine-tuning dataset has large
                elemental energy shift from the pretrained CHGNet, which typically comes
                from different DFT pseudo-potentials.
                Default = False
        """
        
        if self.model is None:
            raise ValueError("Model needs to be initialized")
        global best_checkpoint  # noqa: PLW0603
        if save_dir is None:
            save_dir = f"{datetime.now():%Y-%m-%d-%H%M%S}"
            self.save_dir=f"{datetime.now():%Y-%m-%d-%H%M%S}"
        os.makedirs(f"results/{save_dir}", exist_ok=True)
        writer = SummaryWriter(log_dir=f"log/{save_dir}")

        print(f"Begin Training: using {self.device} device")
        print(f"training targets: {self.targets}")
        self.model.to(self.device)

        # Turn composition model training on / off
        for param in self.model.composition_model.parameters():
            param.requires_grad = train_composition_model

        for epoch in range(self.starting_epoch, self.epochs):
            # train
            train_mae = self._train(train_loader, epoch)
            if "e" in train_mae and train_mae["e"] != train_mae["e"]:
                print("Exit due to NaN")
                break

            # val
            val_mae = self._validate(val_loader)
            for key in self.targets:
                self.training_history[key]["train"].append(train_mae[key])
                self.training_history[key]["val"].append(val_mae[key])

            if "e" in val_mae and val_mae["e"] != val_mae["e"]:
                print("Exit due to NaN")
                break

            self.save_checkpoint(epoch, val_mae, save_dir=save_dir)
            for key in self.targets:
                self.training_history[key]["train"].append(train_mae[key])
                self.training_history[key]["val"].append(val_mae[key])
                # writer.add_scalar(f"{key}_train_loss", train_loss / len(dataloader), epoch)
                writer.add_scalar(f"{key}_train_mae", train_mae[key], epoch)
                # writer.add_scalar(f"{key}_val_loss", train_loss / len(dataloader), epoch)
                writer.add_scalar(f"{key}_val_mae", val_mae[key], epoch)

        if test_loader is not None:
            # test best model
            print("---------Evaluate Model on Test Set---------------")
            for file in os.listdir(save_dir):
                if file.startswith("bestE_"):
                    test_file = file
                    best_checkpoint = torch.load(os.path.join(save_dir, test_file))

            self.model.load_state_dict(best_checkpoint["model"]["state_dict"])
            if save_test_result:
                test_mae = self._validate(
                    test_loader, is_test=True, test_result_save_path=save_dir
                )
            else:
                test_mae = self._validate(
                    test_loader, is_test=True, test_result_save_path=None
                )

            for key in self.targets:
                self.training_history[key]["test"] = test_mae[key]
            self.save(filename=os.path.join(save_dir, test_file))
        

SyntaxError: incomplete input (1646131108.py, line 99)

In [15]:
from pymatgen.core.structure import Structure
import numpy as
try:
    from chgnet import ROOT

    lmo = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
except Exception:
    from urllib.request import urlopen

    url = "https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif"
    cif = urlopen(url).read().decode("utf-8")
    lmo = Structure.from_str(cif, fmt="cif")

structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []
from chgnet.model import CHGNet
chgnet = CHGNet.load()

for _ in range(100):
    structure = lmo.copy()
    # stretch the cell by a small amount
    structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))
    # perturb all atom positions by a small amount
    structure.perturb(0.1)
    # print(structure)

    pred = chgnet.predict_structure(structure)

    structures.append(structure)
    energies_per_atom.append(pred["e"] + np.random.uniform(-0.1, 0.1, size=1))
    forces.append(pred["f"] + np.random.uniform(-0.01, 0.01, size=pred["f"].shape))
    stresses.append(
        pred["s"] * -10 + np.random.uniform(-0.05, 0.05, size=pred["s"].shape)
    )
    magmoms.append(pred["m"] + np.random.uniform(-0.03, 0.03, size=pred["m"].shape))

CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cuda:0


In [16]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader
dataset = StructureData(
    structures=structures,
    energies=energies_per_atom,
    forces=forces,
    stresses=stresses,  # can be None
    magmoms=magmoms,  # can be None
)
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, batch_size=32, train_ratio=0.9, val_ratio=0.05
)

100 structures imported


In [17]:
# Define Trainer
trainer = TrainerV2(
    model=chgnet,
    targets="efsm",
    optimizer="Adam",
    scheduler="CosLR",
    criterion="MSE",
    epochs=10,
    learning_rate=1e-2,
    # use_device="cuda",
    print_freq=2,
)

In [18]:
trainer.train(train_loader, val_loader, test_loader)

Begin Training: using cuda:0 device
training targets: efsm
Epoch: [0][1/3] | Time (1.079)(0.158) | Loss 0.0039(0.0039) | MAE e 0.053(0.053)  f 0.005(0.005)  s 0.002(0.002)  m 0.014(0.014)  
Epoch: [0][2/3] | Time (0.943)(0.138) | Loss 0.9551(0.4795) | MAE e 0.388(0.220)  f 0.157(0.081)  s 1.347(0.675)  m 0.547(0.280)  
*   e_MAE (0.183) 	f_MAE (0.074) 	s_MAE (1.109) 	m_MAE (0.218) 	
Epoch: [1][1/3] | Time (0.725)(0.002) | Loss 0.4176(0.4176) | MAE e 0.207(0.207)  f 0.071(0.071)  s 1.101(1.101)  m 0.214(0.214)  
Epoch: [1][2/3] | Time (0.752)(0.001) | Loss 0.8245(0.6211) | MAE e 0.095(0.151)  f 0.116(0.094)  s 1.360(1.230)  m 0.375(0.295)  
*   e_MAE (0.101) 	f_MAE (0.150) 	s_MAE (1.135) 	m_MAE (0.193) 	
Epoch: [2][1/3] | Time (0.745)(0.001) | Loss 0.3914(0.3914) | MAE e 0.070(0.070)  f 0.115(0.115)  s 1.021(1.021)  m 0.167(0.167)  
Epoch: [2][2/3] | Time (0.801)(0.001) | Loss 1.0637(0.7276) | MAE e 0.108(0.089)  f 0.174(0.145)  s 1.237(1.129)  m 0.062(0.115)  
*   e_MAE (0.117) 	f_MAE 

In [24]:
# trainer.get_best_model()