# Extended Subgraph Partitioning (mattersim)

In [1]:
import numpy as np
import torch

from tqdm import tqdm

from partitioner import part_graph_extended
import networkx as nx

## Partitioning Atoms

Loading a sample atomic dataset and converting it into a graph

In [2]:
from ase.io import read
from ase.build import make_supercell
from mattersim.datasets.utils.convertor import GraphConvertor
from torch_geometric.utils import to_networkx
 
# Read the sample structure and expand it with a diagonal 2x2x2 supercell matrix.
supercell_matrix = [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
atoms = make_supercell(read("datasets/test.xyz"), supercell_matrix)
length = len(atoms)

# NOTE(review): the positional arguments are presumably (model type, two-body
# cutoff, three-body flag, three-body cutoff) -- confirm against GraphConvertor.
converter = GraphConvertor("m3gnet", 5.0, True, 4.0)

# The converter receives a copy of `atoms`, not the original object.
atom_graph = converter.convert(atoms.copy(), None, None, None)

# networkx view of the converted graph; this is what the partitioner consumes.
G = to_networkx(atom_graph)

print("Number of atoms", length)

Number of atoms 3408


Partition the computational graph into the number of desired partitions with the specified neighborhood distance

In [3]:
# Partitioning parameters: the number of partitions to request and the
# neighborhood distance handed to the partitioner.
desired_partitions = 5
neighborhood_distance = 4

partitions, extended_partitions = part_graph_extended(
    G, desired_partitions, neighborhood_distance
)
num_partitions = len(partitions)

# The reported average is taken over the *extended* partitions.
# NOTE(review): in the recorded run this average equals the total atom count
# (3408), i.e. every extended partition spans the whole structure -- confirm
# that this is intended for the chosen neighborhood distance.
extended_sizes = [len(extended) for extended in extended_partitions]
average_partition_size = sum(extended_sizes) / num_partitions

print(f"Created {num_partitions} partitions")
print(f"Average partition size: {average_partition_size}")

Created 5 partitions
Average partition size: 3408.0


Create the ASE atoms object for each partition

In [4]:
from ase import Atoms

partitioned_atoms = []  # one ASE Atoms object per extended partition
indices_map = []  # indices_map[p][j] is the index in `atoms` of atom j of partition p

for part in extended_partitions:
    members = [atoms[atom_index] for atom_index in part]
    member_indices = [member.index for member in members]

    # The cell and the periodic boundary flags of the full structure must be
    # passed on to every partition.
    partitioned_atoms.append(Atoms(members, cell=atoms.cell, pbc=atoms.pbc))
    indices_map.append(member_indices)

In [5]:
# Rebuild the full structure atom by atom, reusing the cell and the periodic
# boundary flags of the original object.
reconstructed_atoms = Atoms(
    [atoms[atom_index] for atom_index in range(len(atoms))],
    cell=atoms.cell,
    pbc=atoms.pbc,
)

reconstructed_atoms

Atoms(symbols='C880H2208Ga64S64Si192', pbc=True, cell=[[23.096664428710938, 13.334883689880371, 23.9624080657959], [-23.09648895263672, 13.334826469421387, 23.962270736694336], [3.831999856629409e-05, -26.669597625732422, 23.962278366088867]])

## Inference

In [6]:
# Use the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

The M3Gnet model is subclassed below so that the node features can be intercepted before they are passed into the final MLP

In [7]:
from typing import Dict

import os

from mattersim.forcefield.m3gnet.m3gnet import M3Gnet
from torch_runstats.scatter import scatter
from mattersim.utils.download_utils import download_checkpoint

class M3GnetModified(M3Gnet):
    """M3Gnet variant whose forward pass stops after the graph convolutions.

    ``forward`` runs the embedding, the basis expansions and every layer in
    ``self.graph_conv`` and then returns the atom attributes directly, so the
    caller can intercept the node features.
    """

    def forward(
        self,
        input: Dict[str, torch.Tensor],
        dataset_idx: int = -1,
    ) -> torch.Tensor:
        """Compute node features for a (batched) graph.

        Args:
            input: Dictionary of graph tensors. The keys read here are
                ``atom_pos``, ``cell``, ``pbc_offsets``, ``atom_attr``,
                ``edge_index``, ``three_body_indices``, ``num_three_body``,
                ``num_bonds``, ``num_triple_ij`` and ``num_atoms``.
            dataset_idx: Not used by this implementation.

        Returns:
            The atom attributes produced by the last graph convolution
            (presumably one row per atom -- TODO confirm the feature width).
        """
        # Extract data from input_dictionary
        pos = input["atom_pos"]
        cell = input["cell"]
        pbc_offsets = input["pbc_offsets"].float()
        atom_attr = input["atom_attr"]
        edge_index = input["edge_index"].long()
        three_body_indices = input["three_body_indices"].long()
        num_three_body = input["num_three_body"]
        num_bonds = input["num_bonds"]
        num_triple_ij = input["num_triple_ij"]
        num_atoms = input["num_atoms"]

        # -------------------------------------------------------------#
        # Shift the three-body edge indices of every graph by the number of
        # bonds in the graphs that precede it in the batch.
        cumsum = torch.cumsum(num_bonds, dim=0) - num_bonds
        index_bias = torch.repeat_interleave(
            cumsum, num_three_body, dim=0
        ).unsqueeze(-1)
        three_body_indices = three_body_indices + index_bias

        # === Refer to the implementation of M3GNet,        ===
        # === we should re-compute the following attributes ===
        # edge_length, edge_vector(optional), triple_edge_length, theta_jik
        atoms_batch = torch.repeat_interleave(repeats=num_atoms)
        edge_batch = atoms_batch[edge_index[0]]
        edge_vector = pos[edge_index[0]] - (
            pos[edge_index[1]]
            + torch.einsum("bi, bij->bj", pbc_offsets, cell[edge_batch])
        )
        edge_length = torch.linalg.norm(edge_vector, dim=1)
        vij = edge_vector[three_body_indices[:, 0].clone()]
        vik = edge_vector[three_body_indices[:, 1].clone()]
        rij = edge_length[three_body_indices[:, 0].clone()]
        rik = edge_length[three_body_indices[:, 1].clone()]
        cos_jik = torch.sum(vij * vik, dim=1) / (rij * rik)
        # eps = 1e-7 avoid nan in torch.acos function
        cos_jik = torch.clamp(cos_jik, min=-1.0 + 1e-7, max=1.0 - 1e-7)
        triple_edge_length = rik.view(-1)
        edge_length = edge_length.unsqueeze(-1)
        atomic_numbers = atom_attr.squeeze(1).long()

        # featurize
        atom_attr = self.atom_embedding(self.one_hot_atoms(atomic_numbers))
        edge_attr = self.rbf(edge_length.view(-1))
        edge_attr_zero = edge_attr  # e_ij^0
        edge_attr = self.edge_encoder(edge_attr)
        three_basis = self.sbf(triple_edge_length, torch.acos(cos_jik))

        # Main Loop
        for conv in self.graph_conv:
            atom_attr, edge_attr = conv(
                atom_attr,
                edge_attr,
                edge_attr_zero,
                edge_index,
                three_basis,
                three_body_indices,
                edge_length,
                num_bonds,
                num_triple_ij,
                num_atoms,
            )

        # Node features after the last graph convolution (no readout applied).
        return atom_attr
    
def load_modified_from_checkpoint(
    load_path: str = None,
    *,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Build an ``M3GnetModified`` model from a checkpoint file.

    Args:
        load_path: ``None`` or one of the pre-trained model names
            (``mattersim-v1.0.0-1M`` / ``mattersim-v1.0.0-5M``, with or without
            the ``.pth`` suffix, matched case-insensitively), or the path of a
            checkpoint file. ``None`` selects the 1M model. Pre-trained models
            are looked up in ``~/.local/mattersim/pretrained_models`` and
            fetched with ``download_checkpoint`` when missing.
        device: Device used both as ``map_location`` for ``torch.load`` and as
            the device the model is moved to.

    Returns:
        The model, with the checkpoint weights loaded non-strictly and
        switched to evaluation mode.

    Raises:
        AssertionError: If the resolved checkpoint file does not exist.
    """
    checkpoint_folder = os.path.expanduser("~/.local/mattersim/pretrained_models")
    os.makedirs(checkpoint_folder, exist_ok=True)

    # Lower-cased aliases of the pre-trained models -> checkpoint file name.
    pretrained_files = {
        "mattersim-v1.0.0-1m.pth": "mattersim-v1.0.0-1M.pth",
        "mattersim-v1.0.0-1m": "mattersim-v1.0.0-1M.pth",
        "mattersim-v1.0.0-5m.pth": "mattersim-v1.0.0-5M.pth",
        "mattersim-v1.0.0-5m": "mattersim-v1.0.0-5M.pth",
    }
    if load_path is None:
        pretrained_name = "mattersim-v1.0.0-1M.pth"
    else:
        pretrained_name = pretrained_files.get(load_path.lower())

    if pretrained_name is None:
        # Anything that is not a known alias is treated as a file path.
        print("Loading the model from %s" % load_path)
    else:
        load_path = os.path.join(checkpoint_folder, pretrained_name)
        if not os.path.exists(load_path):
            print(
                "The pre-trained model is not found locally, "
                "attempting to download it from the server."
            )
            download_checkpoint(pretrained_name, save_folder=checkpoint_folder)
        print(f"Loading the pre-trained {os.path.basename(load_path)} model")
    assert os.path.exists(load_path), f"Model file {load_path} not found"

    checkpoint = torch.load(load_path, map_location=device)

    model = M3GnetModified(device=device, **checkpoint["model_args"]).to(device)
    # strict=False: mismatching keys between checkpoint and model are tolerated.
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()

    return model

In [8]:
# load_path is left at its default (None), which selects the
# mattersim-v1.0.0-1M.pth checkpoint (see the recorded output below).
model = load_modified_from_checkpoint(device=device)

Loading the pre-trained mattersim-v1.0.0-1M.pth model


  checkpoint = torch.load(load_path, map_location=device)


Run inference on each partition

In [9]:
from mattersim.forcefield.potential import batch_to_dict
from torch_geometric.loader import DataLoader

# One feature row per atom of the full structure.
# NOTE(review): 128 is hard-coded and presumably matches the width of the
# node features returned by the model -- confirm against the checkpoint.
aggregated_features = torch.zeros((len(atoms), 128), dtype=torch.float32, device=device)

# batch_size=1 and shuffle=False are spelled out because the loop below relies
# on step `part_idx` yielding exactly the graph of partitioned_atoms[part_idx].
dataloader = DataLoader(
    [converter.convert(part.copy(), None, None, None) for part in partitioned_atoms],
    batch_size=1,
    shuffle=False,
)

for part_idx, input_graph in tqdm(enumerate(dataloader), total=num_partitions):
    input_graph = input_graph.to(device)
    input_dict = batch_to_dict(input_graph)

    # Call the module itself rather than .forward() so registered hooks run.
    feat = model(input_dict)

    # Only root nodes of the partition contribute their features; the remaining
    # atoms of the extended partition are skipped. A set gives O(1) membership
    # tests instead of scanning the partition once per atom.
    root_nodes = set(partitions[part_idx])
    local_rows = []
    original_rows = []
    for local_index, original_index in enumerate(indices_map[part_idx]):
        if original_index in root_nodes:
            local_rows.append(local_index)
            original_rows.append(original_index)

    # Copy all selected rows with a single indexed assignment.
    local_rows = torch.as_tensor(local_rows, dtype=torch.long, device=device)
    original_rows = torch.as_tensor(original_rows, dtype=torch.long, device=device)
    aggregated_features[original_rows] = feat[local_rows]

aggregated_features

100%|██████████| 5/5 [00:01<00:00,  3.04it/s]


tensor([[ 0.0671,  0.1683,  0.1333,  ...,  0.1727,  0.0219,  0.0545],
        [ 0.0671,  0.1682,  0.1331,  ...,  0.1725,  0.0219,  0.0545],
        [ 0.0671,  0.1683,  0.1332,  ...,  0.1726,  0.0219,  0.0545],
        ...,
        [-0.0388, -0.3789, -0.0089,  ...,  0.0009, -0.0034,  0.1832],
        [-0.0388, -0.3789, -0.0089,  ...,  0.0009, -0.0034,  0.1832],
        [-0.0388, -0.3789, -0.0089,  ...,  0.0009, -0.0034,  0.1832]],
       device='cuda:0', grad_fn=<CopySlices>)

## Prediction

Run the prediction on the original, unpartitioned graph to obtain a benchmark for our results

In [10]:
# NOTE(review): `orbff` is not defined in any cell shown in this notebook, and
# the recorded output of this cell is a NameError. It presumably stems from an
# orb-models variant of the notebook -- it has to be defined (or replaced by
# the mattersim equivalent) before this cell can run.
result = orbff.predict(atom_graph)
# Benchmark values used by the error-report cells further down.
benchmark_energy = result["graph_pred"]
benchmark_forces = result["node_pred"]
benchmark_stress = result["stress_pred"]

NameError: name 'orbff' is not defined

In [None]:
from orb_models.forcefield.graph_regressor import ScalarNormalizer, LinearReferenceEnergy
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

# NOTE(review): these helpers come from orb_models although the notebook is
# about mattersim -- confirm that they are meant to be used here.
ref = REFERENCE_ENERGIES["vasp-shifted"]
reference = LinearReferenceEnergy(
    weight_init=ref.coefficients, trainable=True
)

# A single graph: its node count is the number of aggregated feature rows.
n_node = torch.tensor([aggregated_features.shape[0]])

### Energy

In [None]:
# NOTE(review): neither `segment_ops` nor `orbff` is defined or imported in the
# cells shown in this notebook, so this cell cannot run as-is. The name `input`
# also shadows the Python built-in.
# Mean of the node features over the nodes of the graph.
input = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean"
)

# Graph head MLP, inverse normalisation, scaling by the node count, and
# addition of the reference energy.
energy = orbff.graph_head.mlp(input)
energy = orbff.graph_head.normalizer.inverse(energy).squeeze(-1)
energy = energy * n_node
# NOTE(review): assumes atom_graph exposes `atomic_numbers` and `n_node` --
# TODO confirm for the graph object built by the mattersim converter.
energy = energy + reference(atom_graph.atomic_numbers, atom_graph.n_node)
energy

In [None]:
# Compare the energy from the aggregated partition features with the benchmark.
# The maximum is taken over the absolute error, as in the force and stress
# reports; the signed difference would understate a negative deviation.
energy_abs_error = torch.abs(benchmark_energy - energy)

print(f"Absolute error: {energy_abs_error.item()}")
print(f"Percent error: {torch.abs((benchmark_energy - energy) / benchmark_energy).item() * 100}%")
print(f"Maximum error: {torch.max(energy_abs_error).item()}")

### Forces

In [None]:
# NOTE(review): neither `orbff` nor `segment_ops` is defined or imported in the
# cells shown in this notebook, so this cell cannot run as-is.
forces = orbff.node_head.mlp(aggregated_features)
# Subtract the per-graph mean of the node predictions from every node.
system_means = segment_ops.aggregate_nodes(
    forces, n_node, reduction="mean"
)
node_broadcasted_means = torch.repeat_interleave(
    system_means, n_node, dim=0
)
forces = forces - node_broadcasted_means
forces = orbff.node_head.normalizer.inverse(forces)
forces

In [None]:
# Error statistics of the forces against the benchmark.
# NOTE(review): the percent error divides by benchmark_forces element-wise and
# becomes inf/nan wherever a benchmark component is exactly zero.
force_abs_error = torch.abs(benchmark_forces - forces)

mae = torch.mean(force_abs_error)
mape = 100 * torch.mean(torch.abs((benchmark_forces - forces) / benchmark_forces))
# Named `max_error` so that the built-in `max` is not shadowed in later cells.
max_error = torch.max(force_abs_error)

# .item() prints plain numbers instead of tensor reprs.
print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape.item()}%")
print(f"Maximum error: {max_error.item()}")

### Stress

In [None]:
# NOTE(review): neither `segment_ops` nor `orbff` is defined or imported in the
# cells shown in this notebook, so this cell cannot run as-is. The name `input`
# also shadows the Python built-in.
# Mean of the node features over the nodes of the graph.
input = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean",
)
# Stress head MLP, output activation, then inverse normalisation.
stress = orbff.stress_head.mlp(input)
stress = stress.squeeze(-1)
stress = orbff.stress_head.output_activation(stress)
stress = orbff.stress_head.normalizer.inverse(stress)
stress

In [None]:
# Error statistics of the stress against the benchmark.
# NOTE(review): the percent error divides by benchmark_stress element-wise and
# becomes inf/nan wherever a benchmark component is exactly zero.
stress_abs_error = torch.abs(benchmark_stress - stress)

mae = torch.mean(stress_abs_error)
mape = 100 * torch.mean(torch.abs((benchmark_stress - stress) / benchmark_stress))
# Named `max_error` so that the built-in `max` is not shadowed in later cells.
max_error = torch.max(stress_abs_error)

# .item() prints plain numbers instead of tensor reprs.
print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape.item()}%")
print(f"Max error: {max_error.item()}")