# Extended Subgraph Partitioning (mattersim)

In [107]:
# Papermill parameters
# NOTE(review): papermill injects overrides in a cell placed *after* this one, so
# overriding only `supercell_size` presumably leaves `supercell_scaling` at the
# value derived from the default below — confirm how this notebook is invoked.

supercell_size = 3  # repetitions of the cell along each lattice vector
# Diagonal 3x3 scaling matrix: `supercell_size` on the diagonal, zeros elsewhere.
supercell_scaling = [
    [supercell_size if row == col else 0 for col in range(3)] for row in range(3)
]
desired_partitions = 20  # number of partitions requested from the partitioner
num_message_passing = 3  # neighborhood distance used when extending each partition

In [108]:
# Imports

import scrapbook as sb

import numpy as np
import torch

from tqdm import tqdm

from partitioner import part_graph_extended
import networkx as nx

## Partitioning Atoms

Loading a sample atomic dataset and converting it into a graph

In [109]:
from ase.io import read
from ase.build import make_supercell
from mattersim.datasets.utils.convertor import GraphConvertor
from torch_geometric.utils import to_networkx
 
# Load the sample structure and replicate it into the requested supercell.
atoms = make_supercell(read("datasets/H2O.xyz"), supercell_scaling)
length = len(atoms)

# NOTE(review): the positional arguments are presumably the model type, a
# two-body cutoff, a three-body flag and a three-body cutoff — confirm against
# the GraphConvertor signature.
converter = GraphConvertor("m3gnet", 5.0, True, 4.0)
# Convert a copy of the structure so `atoms` itself is not handed to the converter.
atom_graph = converter.convert(atoms.copy(), None, None, None)

# networkx view of the same graph; this is what the partitioner consumes.
G = to_networkx(atom_graph)

sb.glue("num_atoms", length)
print("Number of atoms", length)

Number of atoms 426


Partition the computational graph into the number of desired partitions with the specified neighborhood distance

In [110]:
# Split G into partitions plus their extended counterparts (each partition
# grown by `num_message_passing` neighborhood steps).
partitions, extended_partitions = part_graph_extended(
    G, desired_partitions, num_message_passing
)
num_partitions = len(partitions)

# Total node count over the *extended* partitions; the average below is
# therefore the mean extended-partition size.
extended_node_total = sum(map(len, extended_partitions))

print(f"Created {num_partitions} partitions")
print(f"Average partition size: {extended_node_total / num_partitions}")

Created 20 partitions
Average partition size: 426.0


Create the ASE atoms object for each partition

In [111]:
from ase import Atoms

partitioned_atoms = []  # one ASE Atoms object per extended partition
indices_map = []  # Table mapping each atom in each partition back to its index in the original atoms object

for part in extended_partitions:
    # Pull the member atoms out of the full structure, in partition order.
    members = [atoms[atom_index] for atom_index in part]

    # It's important to pass atoms.cell and atoms.pbc here
    partitioned_atoms.append(Atoms(members, cell=atoms.cell, pbc=atoms.pbc))
    # Row j of this partition corresponds to atom `indices_map[p][j]` in `atoms`.
    indices_map.append([member.index for member in members])

In [22]:
# Rebuild the full structure atom by atom, keeping the original cell and
# periodic boundary conditions.
reconstructed_atoms = Atoms(
    [atoms[i] for i in range(len(atoms))], cell=atoms.cell, pbc=atoms.pbc
)

reconstructed_atoms

Atoms(symbols='H128O64', pbc=True, cell=[8.65320864, 15.05202152, 14.13541336])

## Inference

In [113]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

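The M3Gnet model is subclassed below so that the per-atom (node) features can be intercepted before they are passed into the final MLP. This lets the features computed on each partition be gathered first and the final MLP applied once to the aggregated result.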

In [114]:
from typing import Dict

import os

from mattersim.forcefield.m3gnet.m3gnet import M3Gnet
from torch_runstats.scatter import scatter
from mattersim.utils.download_utils import download_checkpoint

class M3GnetModified(M3Gnet):
    """M3Gnet variant whose forward pass stops after the graph convolutions.

    ``forward`` returns the atom features produced by the last ``graph_conv``
    layer instead of continuing to an energy readout, so the caller can gather
    features from several sub-graphs and apply the readout separately.
    """

    def forward(
        self,
        input: Dict[str, torch.Tensor],
        dataset_idx: int = -1,
    ) -> torch.Tensor:
        """Compute atom features for a (batched) graph.

        Args:
            input: Dictionary of graph tensors. The keys read here are
                ``atom_pos``, ``cell``, ``pbc_offsets``, ``atom_attr``,
                ``edge_index``, ``three_body_indices``, ``num_three_body``,
                ``num_bonds``, ``num_triple_ij``, ``num_atoms``,
                ``num_graphs`` and ``batch``.
            dataset_idx: Not used by this override.

        Returns:
            The atom feature tensor after the last graph convolution layer
            (indexed per atom by the caller).
        """
        # Extract data from the input dictionary
        pos = input["atom_pos"]
        cell = input["cell"]
        pbc_offsets = input["pbc_offsets"].float()
        atom_attr = input["atom_attr"]
        edge_index = input["edge_index"].long()
        three_body_indices = input["three_body_indices"].long()
        num_three_body = input["num_three_body"]
        num_bonds = input["num_bonds"]
        num_triple_ij = input["num_triple_ij"]
        num_atoms = input["num_atoms"]
        # `num_graphs` and `batch` are read but not used further in this method.
        num_graphs = input["num_graphs"]
        batch = input["batch"]

        # -------------------------------------------------------------#
        # Exclusive prefix sum of the per-graph bond counts, repeated once per
        # three-body entry of each graph, then added to the three-body indices.
        # NOTE(review): presumably this turns per-graph edge indices into
        # indices into the batched edge list — confirm.
        cumsum = torch.cumsum(num_bonds, dim=0) - num_bonds
        index_bias = torch.repeat_interleave(  # noqa: F501
            cumsum, num_three_body, dim=0
        ).unsqueeze(-1)
        three_body_indices = three_body_indices + index_bias

        # === Refer to the implementation of M3GNet,        ===
        # === we should re-compute the following attributes ===
        # edge_length, edge_vector(optional), triple_edge_length, theta_jik
        # Graph index of every atom: i repeated num_atoms[i] times.
        atoms_batch = torch.repeat_interleave(repeats=num_atoms)
        # Graph index of every edge, taken from the edge's first endpoint.
        edge_batch = atoms_batch[edge_index[0]]
        # pos[edge_index[0]] minus the second endpoint shifted by its periodic
        # offset expressed in the cell of the edge's graph.
        edge_vector = pos[edge_index[0]] - (
            pos[edge_index[1]]
            + torch.einsum("bi, bij->bj", pbc_offsets, cell[edge_batch])
        )
        edge_length = torch.linalg.norm(edge_vector, dim=1)
        # Vectors and lengths of the two edges referenced by each three-body entry.
        vij = edge_vector[three_body_indices[:, 0].clone()]
        vik = edge_vector[three_body_indices[:, 1].clone()]
        rij = edge_length[three_body_indices[:, 0].clone()]
        rik = edge_length[three_body_indices[:, 1].clone()]
        # Cosine of the angle between the two edges.
        cos_jik = torch.sum(vij * vik, dim=1) / (rij * rik)
        # eps = 1e-7 avoid nan in torch.acos function
        cos_jik = torch.clamp(cos_jik, min=-1.0 + 1e-7, max=1.0 - 1e-7)
        triple_edge_length = rik.view(-1)
        edge_length = edge_length.unsqueeze(-1)
        atomic_numbers = atom_attr.squeeze(1).long()

        # featurize
        atom_attr = self.atom_embedding(self.one_hot_atoms(atomic_numbers))
        edge_attr = self.rbf(edge_length.view(-1))
        edge_attr_zero = edge_attr  # e_ij^0
        edge_attr = self.edge_encoder(edge_attr)
        three_basis = self.sbf(triple_edge_length, torch.acos(cos_jik))

        # Main Loop
        # Each convolution layer returns updated atom and edge attributes.
        for idx, conv in enumerate(self.graph_conv):
            atom_attr, edge_attr = conv(
                atom_attr,
                edge_attr,
                edge_attr_zero,
                edge_index,
                three_basis,
                three_body_indices,
                edge_length,
                num_bonds,
                num_triple_ij,
                num_atoms,
            )

        return atom_attr  # atom features after the last conv layer, not an energy
    
def load_modified_from_checkpoint(
    load_path: str = None,
    *,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Build an ``M3GnetModified`` model from a checkpoint and put it in eval mode.

    Args:
        load_path: ``None`` or one of the known pre-trained names
            (``mattersim-v1.0.0-1M`` / ``mattersim-v1.0.0-5M``, with or without
            ``.pth``, case-insensitive) selects a file inside
            ``~/.local/mattersim/pretrained_models``, downloading it first when
            it is not present. ``None`` selects the 1M checkpoint. Any other
            value is used as a path to a checkpoint file as given.
        device: Device passed to ``torch.load`` as ``map_location`` and used
            for the model.

    Returns:
        The loaded model, moved to ``device`` and switched to ``eval()``.

    Raises:
        AssertionError: If the resolved checkpoint file does not exist.
    """
    pretrained_dir = os.path.expanduser("~/.local/mattersim/pretrained_models")
    os.makedirs(pretrained_dir, exist_ok=True)

    # Lower-cased aliases -> canonical file name of the pre-trained checkpoint.
    known_checkpoints = {
        "mattersim-v1.0.0-1m.pth": "mattersim-v1.0.0-1M.pth",
        "mattersim-v1.0.0-1m": "mattersim-v1.0.0-1M.pth",
        "mattersim-v1.0.0-5m.pth": "mattersim-v1.0.0-5M.pth",
        "mattersim-v1.0.0-5m": "mattersim-v1.0.0-5M.pth",
    }
    if load_path is None:
        checkpoint_name = "mattersim-v1.0.0-1M.pth"
    else:
        checkpoint_name = known_checkpoints.get(load_path.lower())

    if checkpoint_name is None:
        # Not a known pre-trained name: treat the argument as a file path.
        print("Loading the model from %s" % load_path)
    else:
        load_path = os.path.join(pretrained_dir, checkpoint_name)
        if not os.path.exists(load_path):
            print(
                "The pre-trained model is not found locally, "
                "attempting to download it from the server."
            )
            download_checkpoint(checkpoint_name, save_folder=pretrained_dir)
        print(f"Loading the pre-trained {os.path.basename(load_path)} model")
    assert os.path.exists(load_path), f"Model file {load_path} not found"

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    saved = torch.load(load_path, map_location=device)

    network = M3GnetModified(device=device, **saved["model_args"]).to(device)
    # strict=False: keys that do not match between checkpoint and model are
    # ignored rather than raising.
    network.load_state_dict(saved["model"], strict=False)
    network.eval()

    return network

In [25]:
# Load the default pre-trained checkpoint (no load_path given) into the
# modified model on `device`; the returned model is already in eval mode.
model = load_modified_from_checkpoint(device=device)

Loading the pre-trained mattersim-v1.0.0-1M.pth model


  checkpoint = torch.load(load_path, map_location=device)


Run inference on each partition

In [26]:
from mattersim.forcefield.potential import batch_to_dict
from torch_geometric.loader import DataLoader

# Per-atom outputs gathered from all partitions; row i belongs to atom i of `atoms`.
aggregated_atomic_numbers = torch.zeros((len(atoms), 1), dtype=torch.float32, device=device)
# NOTE(review): 128 is assumed to be the feature width produced by the loaded
# checkpoint — confirm if a different checkpoint is used.
aggregated_features = torch.zeros((len(atoms), 128), dtype=torch.float32, device=device)
# Marks atoms whose features have been written, so gaps in the partitioning are
# detected instead of silently contributing all-zero rows to the energy.
atom_is_filled = torch.zeros(len(atoms), dtype=torch.bool, device=device)

# The DataLoader defaults are relied on here (one graph per batch, list order
# preserved), so `part_idx` lines up with `indices_map` and `partitions`.
dataloader = DataLoader([converter.convert(part.copy(), None, None, None) for part in partitioned_atoms])

for part_idx, input_graph in tqdm(enumerate(dataloader), total=num_partitions):
    input_graph = input_graph.to(device)
    input_dict = batch_to_dict(input_graph)
    atomic_numbers = input_dict["atom_attr"]

    with torch.no_grad():
        feat = model.forward(input_dict)

    original_indices = indices_map[part_idx]
    # Row j of `feat` is mapped back through `indices_map`, so the sizes must agree.
    assert feat.shape[0] == len(original_indices), (
        f"Partition {part_idx}: model returned {feat.shape[0]} rows "
        f"for {len(original_indices)} atoms"
    )

    # Only root nodes of the partition are kept; the remaining atoms of the
    # extended partition are context only. A set makes each lookup O(1).
    root_atoms = set(partitions[part_idx])
    local_rows = [
        row for row, original_index in enumerate(original_indices)
        if original_index in root_atoms
    ]
    local_rows_t = torch.tensor(local_rows, dtype=torch.long, device=device)
    global_rows_t = torch.tensor(
        [original_indices[row] for row in local_rows], dtype=torch.long, device=device
    )

    # One indexed write per partition instead of one tensor write per atom.
    aggregated_features[global_rows_t] = feat[local_rows_t].to(aggregated_features.dtype)
    aggregated_atomic_numbers[global_rows_t] = atomic_numbers[local_rows_t].to(
        aggregated_atomic_numbers.dtype
    )
    atom_is_filled[global_rows_t] = True

    del input_graph, input_dict, atomic_numbers, feat
    torch.cuda.empty_cache()

assert bool(atom_is_filled.all()), (
    f"{int((~atom_is_filled).sum())} atoms received no features; "
    "the partitions do not cover every atom"
)

100%|██████████| 20/20 [00:05<00:00,  3.77it/s]


In [117]:
# Apply the readout to the aggregated per-atom features.
# NOTE(review): presumably these are the steps the unmodified model applies
# after its graph convolutions — confirm against M3Gnet.forward.
atomic_numbers = aggregated_atomic_numbers.squeeze(1).long()
# All atoms belong to one structure, so every batch index is 0.
batch = torch.zeros(len(atoms), dtype=torch.int64, device=device)

per_atom_energy = model.normalizer(
    model.final(aggregated_features).view(-1), atomic_numbers
)
# Reduce the per-atom values into a single entry for the one structure.
energy = scatter(per_atom_energy, batch, dim=0, dim_size=1)

sb.glue("partition_energy", energy.item())
energy

tensor([-952.7580], grad_fn=<ScatterAddBackward0>)

## Prediction

Run the prediction on the original, unpartitioned graph to obtain a benchmark for our results

In [118]:
from mattersim.forcefield import MatterSimCalculator

# Reference calculation on the full, unpartitioned structure.
calculator = MatterSimCalculator(device=device)
atoms.calc = calculator

benchmark_energy = atoms.get_potential_energy()
benchmark_forces = atoms.get_forces()

sb.glue("benchmark_energy", benchmark_energy.item())
(benchmark_energy, benchmark_forces)

(-952.758,
 array([[-1.66473910e-08, -1.85820535e-02, -4.08356488e-01],
        [-3.58969828e-07,  1.85802449e-02, -4.08347845e-01],
        [-7.40692485e-09, -4.23084080e-01,  1.11127883e-01],
        [-1.27110980e-07,  4.23082590e-01,  1.11129329e-01],
        [ 3.46074581e-01, -2.11748719e-01, -1.41425997e-01],
        [-3.46071154e-01, -2.11747557e-01, -1.41425401e-01],
        [-3.46075475e-01,  2.11747527e-01, -1.41425848e-01],
        [ 3.46075416e-01,  2.11747408e-01, -1.41425774e-01],
        [ 7.04312697e-09, -1.85810234e-02, -4.08356190e-01],
        [ 1.10594556e-08,  1.85800772e-02, -4.08348352e-01],
        [-5.51226549e-08, -4.23080891e-01,  1.11128017e-01],
        [ 1.34969014e-07,  4.23084199e-01,  1.11130156e-01],
        [ 3.46076488e-01, -2.11750761e-01, -1.41427398e-01],
        [-3.46075267e-01, -2.11750418e-01, -1.41427100e-01],
        [-3.46069723e-01,  2.11745262e-01, -1.41423717e-01],
        [ 3.46073300e-01,  2.11746454e-01, -1.41424656e-01],
        [ 1.4

### Energy

In [16]:
# Difference between the benchmark and the partitioned prediction, computed once.
# `energy` still carries autograd history, so it is detached; converting the
# benchmark to a Python float keeps the arithmetic purely in torch.
benchmark_value = float(benchmark_energy)
energy_diff = benchmark_value - energy.detach()

energy_error_abs = torch.abs(energy_diff).item()
energy_error_pct = torch.abs(energy_diff / benchmark_value).item() * 100
# Maximum of the error magnitude: taking the max of the signed difference would
# report a negative "maximum error" whenever the prediction exceeds the benchmark.
energy_error_max = torch.max(torch.abs(energy_diff)).item()

sb.glue("energy_error_abs", energy_error_abs)
sb.glue("energy_error_pct", energy_error_pct)
sb.glue("energy_error_max", energy_error_max)

print(f"Absolute error: {energy_error_abs}")
print(f"Percent error: {energy_error_pct}%")
print(f"Maximum error: {energy_error_max}")