In [None]:
import torch
from ase.io import read
from torch.utils.data import Dataset, DataLoader
from pymatgen.io.ase import AseAtomsAdaptor
from chgnet.model import CHGNet
from chgnet.graph import CrystalGraphConverter
from chgnet.data.dataset import GraphData
from chgnet.trainer import Trainer

#class GraphDataset(Dataset):
    # def __init__(self, graphs, device):
    #     self.graphs = graphs
    #     self.device = device
    # def __len__(self):
    #     return len(self.graphs)
    # def __getitem__(self, idx):
    #     return self.graphs[idx].to(self.device)

# ===== File paths =====
train_extxyz = "/home/phanim/harshitrawat/summer/final_work/T1_chgnet_labeled.extxyz"
valid_extxyz = "/home/phanim/harshitrawat/summer/final_work/T2_chgnet_labeled.extxyz"

# ===== Hyperparameters =====
batch_size = 16
epochs = 30
lr = 1e-4
output_dir = "./chgnet_finetuned"
device = "cuda"

# ===== ASE → pymatgen Structure → CrystalGraph =====
adaptor = AseAtomsAdaptor()
converter = CrystalGraphConverter()

def load_extxyz_to_graphs(path):
    atoms_list = read(path, index=":")
    graphs = []
    for atoms in atoms_list:
        structure = adaptor.get_structure(atoms)
        graph = converter(structure)
        graphs.append(graph)
    return graphs

train_graphs = load_extxyz_to_graphs(train_extxyz)
val_graphs = load_extxyz_to_graphs(valid_extxyz)



In [33]:
class GraphDataset(Dataset):
    def __init__(self, graphs, device):
        self.graphs = graphs
        self.device = device

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx].to(self.device)


In [40]:
train_dataset = GraphDataset(train_graphs, device="cuda:0")
val_dataset = GraphDataset(val_graphs, device="cuda:0")
def collate_graphs(graphs):
    batch = {}
    for key in graphs[0].__dict__:
        batch[key] = [getattr(g, key) for g in graphs]
    return type(graphs[0]).collate(batch)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_graphs)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_graphs)


In [41]:
from chgnet.data.dataset import GraphData

def collate_graphs(graph_list):
    return GraphData.collate(graph_list)

train_loader = DataLoader(
    GraphDataset(train_graphs, device),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_graphs,
)

val_loader = DataLoader(
    GraphDataset(val_graphs, device),
    batch_size=batch_size,
    collate_fn=collate_graphs,
)


In [42]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # ← Use the first GPU
device = "cuda:0"
# ===== Model =====
model = CHGNet.load(use_device="cpu")
model.to("cuda:0")


# ===== Trainer =====
trainer = Trainer(
    model=model,
    targets="efs",  # energy + force + stress
    energy_loss_ratio=1.0,
    force_loss_ratio=30.0,
    stress_loss_ratio=0.0,  # set >0 if stress labels are present
    optimizer="Adam",
    scheduler="CosLR",
    criterion="MSE",
    epochs=epochs,
    learning_rate=lr,
    use_device=device,
    print_freq=1  # <== Print loss every batch
)

# ===== Train! =====
trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    save_dir=output_dir,
    save_test_result=False,
)

# ===== Save final model =====
trainer.save(f"{output_dir}/final_chgnet_model.pth.tar")
print("✅ Fine-tuning complete.")

CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu
Begin Training: using cuda:0 device
training targets: efs


AttributeError: type object 'GraphData' has no attribute 'collate'

In [39]:
from chgnet.graph.converter import collate_crystal_graphs


ImportError: cannot import name 'collate_crystal_graphs' from 'chgnet.graph.converter' (/home/phanim/harshitrawat/miniconda3/envs/mace_0.3.8/lib/python3.10/site-packages/chgnet/graph/converter.py)

In [7]:
import chgnet
print(chgnet.__version__)
print(chgnet.__file__)


0.3.8
/home/phanim/harshitrawat/miniconda3/envs/mace_0.3.8/lib/python3.10/site-packages/chgnet/__init__.py


In [8]:
ls $(python -c "import chgnet; print(chgnet.__path__[0])")


[0m[01;34mdata[0m/   __init__.py  [01;34mpretrained[0m/   py.typed  [01;34mutils[0m/
[01;34mgraph[0m/  [01;34mmodel[0m/       [01;34m__pycache__[0m/  [01;34mtrainer[0m/


In [16]:
from chgnet.trainer import Trainer
help(Trainer)


Help on class Trainer in module chgnet.trainer.trainer:

class Trainer(builtins.object)
 |  Trainer(model: 'CHGNet | None' = None, *, targets: 'TrainTask' = 'ef', energy_loss_ratio: 'float' = 1, force_loss_ratio: 'float' = 1, stress_loss_ratio: 'float' = 0.1, mag_loss_ratio: 'float' = 0.1, optimizer: 'str' = 'Adam', scheduler: 'str' = 'CosLR', criterion: 'str' = 'MSE', epochs: 'int' = 50, starting_epoch: 'int' = 0, learning_rate: 'float' = 0.001, print_freq: 'int' = 100, torch_seed: 'int | None' = None, data_seed: 'int | None' = None, use_device: 'str | None' = None, check_cuda_mem: 'bool' = True, **kwargs) -> 'None'
 |  
 |  A trainer to train CHGNet using energy, force, stress and magmom.
 |  
 |  Methods defined here:
 |  
 |  __init__(self, model: 'CHGNet | None' = None, *, targets: 'TrainTask' = 'ef', energy_loss_ratio: 'float' = 1, force_loss_ratio: 'float' = 1, stress_loss_ratio: 'float' = 0.1, mag_loss_ratio: 'float' = 0.1, optimizer: 'str' = 'Adam', scheduler: 'str' = 'CosLR',

In [18]:
help(CrystalGraphConverter)


Help on class CrystalGraphConverter in module chgnet.graph.converter:

class CrystalGraphConverter(torch.nn.modules.module.Module)
 |  CrystalGraphConverter(*, atom_graph_cutoff: 'float' = 6, bond_graph_cutoff: 'float' = 3, algorithm: "Literal['legacy', 'fast']" = 'fast', on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error', verbose: 'bool' = False) -> 'None'
 |  
 |  Convert a pymatgen.core.Structure to a CrystalGraph
 |  The CrystalGraph dataclass stores essential field to make sure that
 |  gradients like force and stress can be calculated through back-propagation later.
 |  
 |  Method resolution order:
 |      CrystalGraphConverter
 |      torch.nn.modules.module.Module
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, *, atom_graph_cutoff: 'float' = 6, bond_graph_cutoff: 'float' = 3, algorithm: "Literal['legacy', 'fast']" = 'fast', on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error', verbose: 'bool' = False) -> 'None'
 |    

In [43]:
#!/usr/bin/env python3
import os
import argparse

import torch
from torch.utils.data import Dataset, DataLoader

# Make sure you have torch_geometric installed
from torch_geometric.data import Batch

from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor

from chgnet.graph import CrystalGraphConverter
from chgnet.model import CHGNet
from chgnet.trainer import Trainer


ModuleNotFoundError: No module named 'torch_geometric'

In [45]:
#!/usr/bin/env python3
import os
import argparse
from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor
import torch

from chgnet.model import CHGNet
from chgnet.graph.converter import CrystalGraphConverter
from chgnet.graph.dataset import GraphDataset
from chgnet.trainer import Trainer


def read_extxyz_to_graphs(path, converter):
    """Convert .extxyz to list of CrystalGraph Data objects with energy & force labels."""
    atoms_list = read(path, index=":")
    graphs = []

    for atoms in atoms_list:
        struct = AseAtomsAdaptor.get_structure(atoms)
        graph = converter(struct)

        energy = atoms.info.get("energy", None)
        forces = atoms.arrays.get("forces", None)

        if energy is None or forces is None:
            raise ValueError("Missing energy or forces in extxyz entry.")

        graph.y = torch.tensor([energy], dtype=torch.float)  # shape [1]
        graph.forces = torch.tensor(forces, dtype=torch.float)  # shape [n_atoms, 3]
        graphs.append(graph)

    return graphs


def main():
    parser = argparse.ArgumentParser(description="Fine-tune CHGNet from extxyz")
    parser.add_argument("--pretrained_model", type=str, required=True)
    parser.add_argument("--train_extxyz", type=str, required=True)
    parser.add_argument("--valid_extxyz", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--cutoff", type=float, default=5.0)
    parser.add_argument("--max_neighbors", type=int, default=12)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    # Load model
    model = CHGNet().to(device)
    checkpoint = torch.load(args.pretrained_model, map_location=device)
    if "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)
    print("✅ Loaded pretrained model")

    # Set up converter and data
    converter = CrystalGraphConverter(
        cutoff=args.cutoff,
        max_neighbors=args.max_neighbors,
        use_canonize=True
    )

    train_graphs = read_extxyz_to_graphs(args.train_extxyz, converter)
    valid_graphs = read_extxyz_to_graphs(args.valid_extxyz, converter)
    print(f"✅ Loaded {len(train_graphs)} train and {len(valid_graphs)} val samples")

    train_dataset = GraphDataset(train_graphs)
    val_dataset = GraphDataset(valid_graphs)

    # Trainer
    trainer = Trainer(
        model=model,
        optimizer_args={"lr": args.lr},
        targets="efs",
        energy_weight=1.0,
        force_weight=30.0,
        stress_weight=0.0,
        device=device
    )

    # Training loop
    for epoch in range(1, args.epochs + 1):
        train_loss = trainer.train_epoch(train_dataset, batch_size=args.batch_size)
        val_loss = trainer.validate_epoch(val_dataset, batch_size=args.batch_size)
        print(f"[Epoch {epoch}/{args.epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "chgnet_finetuned.pth")
    print("💾 Saved fine-tuned model to chgnet_finetuned.pth")


if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'chgnet.graph.dataset'

In [44]:
!pip install torch_geometric

[0m^C
