In [1]:
# Papermill parameters
# Default values for a standalone run.
# NOTE(review): supercell_scaling is derived from supercell_size in this same
# cell, so a value of supercell_size injected by papermill after this cell
# presumably will not change supercell_scaling — confirm how the notebook is
# parameterized.

supercell_size = 2
# Diagonal 3x3 matrix: the same replication factor on every diagonal entry.
supercell_scaling = [
    [supercell_size if row == col else 0 for col in range(3)]
    for row in range(3)
]
desired_partitions = 20  # second argument of part_graph_extended
num_message_passing = 3  # third argument of part_graph_extended

In [2]:
# Imports

import scrapbook as sb

import numpy as np
import torch

from tqdm import tqdm

from partitioner import part_graph_extended
import networkx as nx

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
from ase.io import read
from ase.build import make_supercell
from mattersim.datasets.utils.convertor import GraphConvertor
from torch_geometric.utils import to_networkx
 
# Read the structure from disk and replicate it with the configured scaling matrix.
atoms = make_supercell(read("datasets/test.xyz"), supercell_scaling)

# Graph builder for the "m3gnet" model type.
# NOTE(review): the positional arguments 5.0, True, 4.0 are presumably the
# two-body cutoff, the three-body switch and the three-body cutoff — confirm
# against GraphConvertor's signature.
converter = GraphConvertor("m3gnet", 5.0, True, 4.0)

length = len(atoms)
# A copy is handed to the converter, so `atoms` itself is not passed in.
# NOTE(review): the three None arguments are presumably optional labels
# (energy / forces / stress) — TODO confirm.
atom_graph = converter.convert(atoms.copy(), None, None, None)

# NetworkX version of the atom graph; this is what the partitioner receives.
G = to_networkx(atom_graph)

# Record the atom count in the notebook's scrapbook and echo it.
sb.glue("num_atoms", length)
print("Number of atoms", length)

Number of atoms 3408


In [4]:
# Split the atom graph; the partitioner returns two collections that are
# unpacked as `partitions` and `extended_partitions`.
# NOTE(review): the extended partitions presumably add the neighbours reachable
# within num_message_passing hops — confirm in partitioner.part_graph_extended.
partitions, extended_partitions = part_graph_extended(
    G, desired_partitions, num_message_passing
)

num_partitions = len(partitions)
extended_atom_total = sum(map(len, extended_partitions))

print(f"Created {num_partitions} partitions")
print(f"Average partition size: {extended_atom_total / num_partitions}")

Created 20 partitions
Average partition size: 2541.85


In [5]:
from ase import Atoms

# One Atoms object per extended partition, plus a parallel table that maps
# position j inside partition k back to the atom's index in `atoms`.
partitioned_atoms = []
indices_map = []

for part in extended_partitions:
    members = [atoms[atom_index] for atom_index in part]

    # Each Atom taken from `atoms` carries its index in the parent structure.
    indices_map.append([member.index for member in members])

    # It's important to pass atoms.cell and atoms.pbc here so every partition
    # is built with the cell and periodicity of the full structure.
    partitioned_atoms.append(Atoms(members, cell=atoms.cell, pbc=atoms.pbc))

In [6]:
# Rebuild the full structure atom by atom (the same construction used for the
# partitions) and display it.
reconstructed_atoms = Atoms(
    [atoms[atom_index] for atom_index in range(len(atoms))],
    cell=atoms.cell,
    pbc=atoms.pbc,
)

reconstructed_atoms

Atoms(symbols='C880H2208Ga64S64Si192', pbc=True, cell=[[23.096664428710938, 13.334883689880371, 23.9624080657959], [-23.09648895263672, 13.334826469421387, 23.962270736694336], [3.831999856629409e-05, -26.669597625732422, 23.962278366088867]])

In [7]:
from mattersim.forcefield.potential import Potential
from ase.units import GPa

class PartitionPotential(Potential):
    """Potential whose forward pass also returns per-atom energies.

    The forward pass asks the wrapped model for per-atom energies
    (``return_energies_per_atom=True``) and stores them under ``"energies_i"``
    next to the usual ``"energies"``, ``"forces"`` and ``"stresses"`` entries.
    """

    def forward(
        self,
        input: dict[str, torch.Tensor],
        include_forces: bool = True,
        include_stresses: bool = True,
        dataset_idx: int = -1,
    ):
        """Evaluate energies and, optionally, forces and stresses.

        Args:
            input: Batch dictionary. The entries ``"cell"``, ``"atom_pos"`` and
                ``"num_atoms"`` are read here and the whole dict is passed to
                ``self.model.forward``. The dict is modified in place:
                ``"atom_pos"`` is flagged as requiring gradients when forces
                are requested, and ``"cell"`` / ``"atom_pos"`` are replaced by
                their strained versions when stresses are requested.
            include_forces: When ``True``, differentiate the energies with
                respect to the atomic positions.
            include_stresses: When ``True`` (together with ``include_forces``),
                also differentiate with respect to the strain.
            dataset_idx: Forwarded unchanged to ``self.model.forward``.

        Returns:
            A dict that always holds ``"energies"`` and ``"energies_i"``. It
            also holds ``"forces"`` when ``include_forces`` is ``True`` and
            ``include_stresses`` is ``True`` or ``False``, and ``"stresses"``
            when both flags are ``True``. With ``include_forces=False`` no
            gradient is taken, so neither key is present.

        Raises:
            NotImplementedError: If ``self.model_name`` is ``"graphormer"`` or
                ``"geomformer"``.
        """
        output = {}
        if self.model_name == "graphormer" or self.model_name == "geomformer":
            raise NotImplementedError
        else:
            # Zero strain with the same shape as the cell; it only serves as a
            # variable to differentiate against.
            strain = torch.zeros_like(input["cell"], device=self.device)
            volume = torch.linalg.det(input["cell"])
            if include_forces is True:
                input["atom_pos"].requires_grad_(True)
            if include_stresses is True:
                strain.requires_grad_(True)
                # Apply (identity + strain) to the cell. The values are
                # unchanged because strain is zero, but the result now depends
                # on `strain` in the autograd graph.
                input["cell"] = torch.matmul(
                    input["cell"],
                    (torch.eye(3, device=self.device)[None, ...] + strain),
                )
                # Repeat each strain matrix input["num_atoms"] times along
                # dim 0 — presumably one copy per atom of the corresponding
                # structure; TODO confirm the layout of "num_atoms".
                strain_augment = torch.repeat_interleave(
                    strain, input["num_atoms"], dim=0
                )
                # Apply the same (identity + strain) transform to the positions.
                input["atom_pos"] = torch.einsum(
                    "bi, bij -> bj",
                    input["atom_pos"],
                    (torch.eye(3, device=self.device)[None, ...] + strain_augment),
                )
                # Recompute the volume from the (strained) cell.
                volume = torch.linalg.det(input["cell"])

            # Unlike a plain energy call, also request the per-atom energies.
            energies, energies_i = self.model.forward(input, dataset_idx, return_energies_per_atom=True)
            output["energies"] = energies
            output["energies_i"] = energies_i

            # Only take first derivative if only force is required
            if include_forces is True and include_stresses is False:
                grad_outputs: list[torch.Tensor] = [
                    torch.ones_like(
                        energies,
                    )
                ]
                # create_graph follows the model's training flag, so the
                # gradient itself is differentiable only while training.
                grad = torch.autograd.grad(
                    outputs=[
                        energies,
                    ],
                    inputs=[input["atom_pos"]],
                    grad_outputs=grad_outputs,
                    create_graph=self.model.training,
                )

                # Dump out gradient for forces
                force_grad = grad[0]
                if force_grad is not None:
                    # Force is the negative gradient of the energy w.r.t. position.
                    forces = torch.neg(force_grad)
                    output["forces"] = forces

            # Differentiate with respect to both the positions and the strain
            # if both forces and stresses are required
            if include_forces is True and include_stresses is True:
                grad_outputs: list[torch.Tensor] = [
                    torch.ones_like(
                        energies,
                    )
                ]
                grad = torch.autograd.grad(
                    outputs=[
                        energies,
                    ],
                    inputs=[input["atom_pos"], strain],
                    grad_outputs=grad_outputs,
                    create_graph=self.model.training,
                )

                # Dump out gradient for forces and stresses
                force_grad = grad[0]
                stress_grad = grad[1]

                if force_grad is not None:
                    forces = torch.neg(force_grad)
                    output["forces"] = forces

                if stress_grad is not None:
                    # Strain gradient divided by the cell volume and by ASE's
                    # GPa unit constant.
                    stresses = (
                        1 / volume[:, None, None] * stress_grad / GPa
                    )  # 1/GPa = 160.21766208
                    output["stresses"] = stresses

        return output

In [8]:
# Load the checkpoint without its training state and move the model to the
# selected device in one step.
potential = PartitionPotential.from_checkpoint(load_training_state=False).to(device)

[32m2025-02-26 14:30:58.941[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m884[0m - [1mLoading the pre-trained mattersim-v1.0.0-1M.pth model[0m


In [9]:
from mattersim.forcefield.potential import batch_to_dict
from torch_geometric.loader import DataLoader

# One graph per extended partition. batch_size=1 with the loader's default
# (unshuffled) order keeps batch k aligned with partitioned_atoms[k] and
# indices_map[k].
partition_graphs = [
    converter.convert(part.copy(), None, None, None) for part in partitioned_atoms
]
dataloader = DataLoader(partition_graphs, batch_size=1)

# Per-atom results are scattered back into arrays indexed like `atoms`;
# stresses are kept per partition.
energies_parts = np.zeros(len(atoms))
forces_parts = np.zeros((len(atoms), 3))
stress_parts = np.zeros((len(partitions), 3, 3))

for part_idx, input_graph in tqdm(enumerate(dataloader), total=num_partitions):
    input_graph = input_graph.to(device)
    input_dict = batch_to_dict(input_graph)

    output = potential(input_dict, include_forces=True, include_stresses=True)

    energies_i = output["energies_i"].detach().cpu().numpy()
    forces = output["forces"].detach().cpu().numpy()
    stress = output["stresses"].detach().cpu().numpy()

    # Build the membership set once per partition. Testing
    # `index in partitions[part_idx]` for every atom is a linear scan each
    # time if the partition is a list, i.e. quadratic in the partition size.
    core_indices = set(partitions[part_idx])

    # Keep results only for atoms that belong to the (non-extended) partition;
    # the remaining atoms of the extended partition are discarded.
    for j, original_index in enumerate(indices_map[part_idx]):
        if original_index in core_indices:
            energies_parts[original_index] = energies_i[j]
            forces_parts[original_index] = forces[j]

    # This is the stress returned for the whole extended-partition graph,
    # including the atoms whose energies and forces are discarded above.
    stress_parts[part_idx] = stress

    torch.cuda.empty_cache()

100%|██████████| 20/20 [00:09<00:00,  2.02it/s]


In [10]:
from mattersim.forcefield.potential import batch_to_dict
from torch_geometric.loader import DataLoader

# Reference calculation: the whole structure evaluated as a single graph.
dataloader = DataLoader(
    [converter.convert(atoms.copy(), None, None, None)], batch_size=1
)

# Placeholders, overwritten by the model output inside the loop.
energy = 0
forces = np.zeros((len(atoms), 3))
stress = np.zeros((3, 3))

# The loader holds a single graph, so the body runs once.
for input_graph in dataloader:
    input_graph = input_graph.to(device)
    input_dict = batch_to_dict(input_graph)

    output = potential(input_dict, include_forces=True, include_stresses=True)

    # First entry of the energies array, forces as rows of three components,
    # stress as a single 3x3 matrix.
    energy = output["energies"].detach().cpu().numpy()[0]
    forces = output["forces"].detach().cpu().numpy().reshape(-1, 3)
    stress = output["stresses"].detach().cpu().numpy().reshape(3, 3)

In [11]:
def _mean_abs_error(predicted: torch.Tensor, reference: torch.Tensor) -> float:
    """Return the mean absolute difference of two tensors as a Python float."""
    return torch.nn.L1Loss()(predicted, reference).item()


# Totals assembled from the partitioned run.
energy_part = np.sum(energies_parts)
# NOTE(review): every entry of stress_parts comes from a whole extended
# partition, and the extended partitions overlap (their sizes add up to far
# more than len(atoms)), so this plain sum presumably over-counts — the large
# stress error printed below is consistent with that. Confirm before relying
# on the partitioned stress.
stress_part = np.sum(stress_parts, axis=0)

num_atoms_total = len(atoms)

# Energies are compared per atom; forces and stresses element-wise.
energy_mae = _mean_abs_error(
    torch.tensor(energy_part) / num_atoms_total,
    torch.tensor(energy) / num_atoms_total,
)
forces_mae = _mean_abs_error(torch.tensor(forces_parts), torch.tensor(forces))
stress_mae = _mean_abs_error(torch.tensor(stress_part), torch.tensor(stress))

print(f"Energy MAE: {energy_mae} eV/atom")
print(f"Forces MAE: {forces_mae} eV/Å")
print(f"Stress MAE: {stress_mae} GPa")

Energy MAE: 1.0468380551742484e-05 eV/atom
Forces MAE: 8.11114269895377e-06 eV/Å
Stress MAE: 9.01010038473113 GPa
