# Extended Subgraph Partitioning Mattertune
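This notebook uses the extended partitioning algorithm to run inference for Mattertune one partition at a time, here with the pretrained Orb v2 backbone from `orb_models`, and compares the results against inference on the unpartitioned graph.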

In [42]:
import numpy as np
import torch

from tqdm import tqdm

### Extended Partitioning

In [43]:
import networkx as nx
import metis

from collections import deque

def part_graph_extended(G, desired_partitions, distance=None):
    """Partition ``G`` with METIS and extend every part with a halo of nearby nodes.

    Parameters
    ----------
    G : graph
        Graph to partition. It is passed to ``metis.part_graph`` and
        ``nx.edge_boundary`` and must support ``G.nodes()``, ``G[node]`` and
        ``node in G``. Node labels may be any hashable values.
    desired_partitions : int
        Number of parts requested from METIS.
    distance : int or None, optional
        Maximum number of hops to walk outwards from each part's boundary
        nodes. ``None`` places no limit on the walk.

    Returns
    -------
    partitions : list[set]
        ``desired_partitions`` disjoint sets of nodes (the "root" nodes of each part).
    extended_partitions : list[set]
        For each part, its root nodes plus every node within ``distance`` hops
        of one of its boundary nodes.
    """

    def descendants_at_distance_multisource(G, sources, distance=None):
        """Yield, in BFS order, the nodes within ``distance`` hops of any source."""
        # Accept a single node as well as an iterable of nodes.
        if sources in G:
            sources = [sources]

        queue = deque((source, 0) for source in sources)
        visited = set(sources)

        for source, _ in queue:
            if source not in G:
                raise nx.NetworkXError(f"The node {source} is not in the graph.")

        while queue:
            node, depth = queue.popleft()

            # Depths are non-decreasing in BFS order, so the first node that is
            # too deep means every remaining node is too deep as well.
            if distance is not None and depth > distance:
                return

            yield node

            for child in G[node]:
                if child not in visited:
                    visited.add(child)
                    queue.append((child, depth + 1))

    _, parts = metis.part_graph(G, desired_partitions, objtype="cut")

    # parts[k] is the part id of the k-th node in G.nodes() order. Pair them
    # directly so that arbitrary (non 0..n-1) node labels are handled correctly.
    partitions = [set() for _ in range(desired_partitions)]
    for node, part_id in zip(G.nodes(), parts):
        partitions[part_id].add(node)

    # Boundary nodes: members of a part with at least one neighbor outside it.
    boundary_nodes = [
        {inner for inner, _ in nx.edge_boundary(G, part)} for part in partitions
    ]

    # Grow each part by a BFS from its boundary nodes.
    extended_partitions = [
        part | set(descendants_at_distance_multisource(G, boundary, distance=distance))
        for part, boundary in zip(partitions, boundary_nodes)
    ]

    return partitions, extended_partitions

## Partitioning Atoms

Load a sample atomic structure, expand it into a 2×2×2 supercell, and convert it into a graph.

In [44]:
from ase.io import read
from orb_models.forcefield.atomic_system import ase_atoms_to_atom_graphs
from ase.build import make_supercell
 
# Read the sample structure and replicate it 2x along every cell vector.
atoms = read("datasets/test.xyz")
atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])

# Build the neighbor graph with orb's own ase_atoms_to_atom_graphs instead of an
# ASE neighbor list, so the partitioning sees the same edges as the model.
atom_graph = ase_atoms_to_atom_graphs(atoms) # Keep this to compare results later

senders = atom_graph.senders
receivers = atom_graph.receivers
edge_feats = atom_graph.edge_features

G = nx.Graph()
G.add_nodes_from(range(len(atoms)))

# Attach each edge's OWN entry of edge_feats['r'] as its weight; previously the
# whole tensor was stored on every edge.
# NOTE(review): assumes edge_feats['r'] holds one scalar per edge, aligned with
# senders/receivers — TODO confirm against orb_models.
edge_lengths = edge_feats['r']
for i, (u, v) in enumerate(zip(senders.tolist(), receivers.tolist())):
    G.add_edge(u, v, weight=edge_lengths[i].item())

print("Number of atoms", len(atoms))

Number of atoms 3408


Partition the computational graph into the number of desired partitions with the specified neighborhood distance

In [45]:
# Partitioning settings: number of METIS parts and halo size (in graph hops).
desired_partitions, neighborhood_distance = 20, 3

partitions, extended_partitions = part_graph_extended(
    G,
    desired_partitions,
    distance=neighborhood_distance,
)

num_partitions = len(partitions)

Create an ASE `Atoms` object for each extended partition, and record which original atom each of its entries corresponds to.

In [46]:
from ase import Atoms

partitioned_atoms = []
indices_map = [] # Table mapping each atom in each partition back to its index in the original atoms object

for part in extended_partitions:
    # Pull the member atoms once; position j in both lists below refers to the same atom.
    members = [atoms[atom_index] for atom_index in part]

    partitioned_atoms.append(Atoms(members, cell=atoms.cell, pbc=atoms.pbc))
    indices_map.append([member.index for member in members])


In [47]:
# Rebuild the full system atom by atom, using the same cell and periodic
# boundary conditions as the original structure.
reconstructed_atoms = Atoms(
    [atoms[atom_index] for atom_index in range(len(atoms))],
    cell=atoms.cell,
    pbc=atoms.pbc,
)

reconstructed_atoms

Atoms(symbols='C880H2208Ga64S64Si192', pbc=True, cell=[[23.096664428710938, 13.334883689880371, 23.9624080657959], [-23.09648895263672, 13.334826469421387, 23.962270736694336], [3.831999856629409e-05, -26.669597625732422, 23.962278366088867]])

## Inference

In [48]:
from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield import segment_ops

Load model

In [49]:
# Device on which the model is loaded and on which the aggregated feature
# tensor is allocated later in the notebook.
device = "cpu"  # or device="cuda"

# Pretrained Orb v2 force field; its backbone (`orbff.model`) and prediction
# heads are used separately in the cells below.
orbff = pretrained.orb_v2(device=device)

  state_dict = torch.load(local_path, map_location="cpu")


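Run inference on each extended partition, keeping the resulting node features only for the atoms that belong to the partition itself (its root nodes), not for the halo atoms.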

In [50]:
# Per-atom features gathered from all partitions. Allocated on the first
# iteration so the feature width comes from the model output rather than a
# hard-coded constant.
aggregated_features = None

# Inference only: no autograd graph is needed for the partition forward passes.
with torch.no_grad():
    for i, part in tqdm(enumerate(partitioned_atoms), total=num_partitions):
        input_graph = atomic_system.ase_atoms_to_atom_graphs(part)

        batch = orbff.model(input_graph)

        feat = batch.node_features["feat"]

        if aggregated_features is None:
            aggregated_features = torch.zeros(
                (len(atoms), feat.shape[-1]), dtype=torch.float32, device=device
            )

        # (row in this partition, index in the original system) for every root
        # node of the partition; halo nodes are skipped because their features
        # are taken from the partition that owns them.
        root_pairs = [
            (j, original_index)
            for j, original_index in enumerate(indices_map[i])
            if original_index in partitions[i]
        ]
        local_rows = torch.tensor(
            [j for j, _ in root_pairs], dtype=torch.long, device=feat.device
        )
        global_rows = torch.tensor(
            [g for _, g in root_pairs], dtype=torch.long, device=aggregated_features.device
        )

        # Copy all root-node rows in one indexed assignment.
        aggregated_features[global_rows] = feat[local_rows].to(
            device=aggregated_features.device, dtype=aggregated_features.dtype
        )

aggregated_features

100%|██████████| 20/20 [01:51<00:00,  5.59s/it]


tensor([[-0.3705,  0.6076, -0.2275,  ..., -0.1806,  0.5268,  0.2838],
        [-0.1514,  0.8388,  0.1421,  ..., -0.0490,  0.2647, -0.3448],
        [-0.5362,  0.6659, -0.2096,  ...,  0.0278,  0.3660, -0.8769],
        ...,
        [ 0.4537, -0.0585,  0.0067,  ...,  0.0984, -0.1770, -0.0481],
        [ 0.7884,  0.0315,  0.2478,  ..., -0.0237, -0.0325, -0.4072],
        [ 0.8343, -0.0955,  0.0849,  ...,  0.1475, -0.5810, -0.3088]])

## Prediction

Run the prediction on the original, unpartitioned graph to obtain a benchmark for our results

In [60]:
# Reference prediction on the full, unpartitioned graph.
result = orbff.predict(atom_graph)

# Unpack the three predicted quantities in one go.
benchmark_energy, benchmark_forces, benchmark_stress = (
    result[key] for key in ("graph_pred", "node_pred", "stress_pred")
)

In [52]:
from orb_models.forcefield.graph_regressor import ScalarNormalizer, LinearReferenceEnergy
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

# Linear reference-energy term initialised from the "vasp-shifted" coefficients.
ref = REFERENCE_ENERGIES["vasp-shifted"]
reference = LinearReferenceEnergy(weight_init=ref.coefficients, trainable=True)

# The aggregated features form a single system, so n_node has exactly one
# entry: the total number of atoms.
n_node = torch.tensor([len(aggregated_features)])

### Energy

In [53]:
# Mean-pool the per-atom features into one vector per system.
# (Named `pooled_features` rather than `input` to avoid shadowing the builtin.)
pooled_features = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean"
)

# Apply the graph head step by step: MLP, inverse normalisation, scaling by the
# atom count, then the linear reference-energy term.
energy = orbff.graph_head.mlp(pooled_features)
energy = orbff.graph_head.normalizer.inverse(energy).squeeze(-1)
energy = energy * n_node
energy = energy + reference(atom_graph.atomic_numbers, atom_graph.n_node)
energy

tensor([-17688.4375], grad_fn=<AddBackward0>)

In [54]:
# Compare the partitioned energy with the unpartitioned benchmark.
energy_abs_error = torch.abs(benchmark_energy - energy).item()
energy_pct_error = torch.abs((benchmark_energy - energy) / benchmark_energy).item() * 100

print(f"Absolute error: {energy_abs_error}")
print(f"Percent error: {energy_pct_error}%")

Absolute error: 0.20703125
Percent error: 0.0011704463759087957%


### Forces

In [55]:
# Raw per-atom outputs of the node head.
raw_forces = orbff.node_head.mlp(aggregated_features)

# Per-system mean of the raw outputs, broadcast back to one row per atom.
system_means = segment_ops.aggregate_nodes(raw_forces, n_node, reduction="mean")
node_broadcasted_means = torch.repeat_interleave(system_means, n_node, dim=0)

# Subtract the mean, then undo the head's normalisation.
forces = orbff.node_head.normalizer.inverse(raw_forces - node_broadcasted_means)
forces

tensor([[ 0.0061, -0.0097,  0.0199],
        [-0.0710, -0.0580, -0.0067],
        [-0.0459, -0.0113,  0.0353],
        ...,
        [ 0.0153, -0.1068,  0.0393],
        [-0.1074,  0.0386,  0.0392],
        [ 0.0803,  0.0699,  0.0097]])

In [56]:
# Element-wise difference between the benchmark and partitioned forces.
force_residual = benchmark_forces - forces

mae = force_residual.abs().mean()
mape = 100 * (force_residual / benchmark_forces).abs().mean()

print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape}%")

Mean absolute error: 0.0007215619552880526
Mean absolute percent error: 2.961912155151367%


### Stress

In [64]:
# Mean-pool the per-atom features into one vector per system.
# (Named `pooled_features` rather than `input` to avoid shadowing the builtin.)
pooled_features = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean",
)

# Apply the stress head step by step: MLP, output activation, inverse normalisation.
stress = orbff.stress_head.mlp(pooled_features)
stress = stress.squeeze(-1)
stress = orbff.stress_head.output_activation(stress)
stress = orbff.stress_head.normalizer.inverse(stress)
stress

tensor([[ 1.1176e-02,  1.1250e-02,  1.2496e-02,  8.0223e-06, -1.8117e-06,
         -2.5154e-06]])

In [66]:
# Element-wise difference between the benchmark and partitioned stress.
stress_residual = benchmark_stress - stress

mae = stress_residual.abs().mean()
mape = 100 * (stress_residual / benchmark_stress).abs().mean()

print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape}%")

Mean absolute error: 4.527105193119496e-05
Mean absolute percent error: 0.839568018913269%
