# Extended Subgraph Partitioning Mattertune
Now I will use the partitioning algorithm for inference using Mattertune

In [19]:
import numpy as np
import torch

import ase

### Extended Partitioning

In [20]:
import networkx as nx
import metis

from collections import deque

def part_graph_extended(G, desired_partitions, distance=None):
    def descendants_at_distance_multisource(G, sources, distance=None):
        if sources in G:
            sources = [sources]

        queue = deque(sources)
        depths = deque([0 for _ in queue])
        visited = set(sources)

        for source in queue:
            if source not in G:
                raise nx.NetworkXError(f"The node {source} is not in the graph.")

        while queue:
            node = queue[0]
            depth = depths[0]

            if distance is not None and depth > distance: return

            yield queue[0]

            queue.popleft()
            depths.popleft()

            for child in G[node]:
                if child not in visited:
                    visited.add(child)
                    queue.append(child)
                    depths.append(depth + 1)

    _, parts = metis.part_graph(G, desired_partitions, objtype="cut")
    partition_map = {node: parts[i] for i, node in enumerate(G.nodes())}
    num_partitions = desired_partitions

    # Find indices of nodes in each partition
    partitions = [set() for _ in range(desired_partitions)]

    for i, node in enumerate(G.nodes()):
        partitions[partition_map[i]].add(node)

    # Find boundary nodes (vertices adjacent to vertex not in partition)
    boundary_nodes = [set(map(lambda uv: uv[0], nx.edge_boundary(G, partitions[i]))) for i in range(num_partitions)]

    # Perform BFS on boundary_nodes to find extended neighbors up to a certain distance
    extended_neighbors = [set(descendants_at_distance_multisource(G, boundary_nodes[i], distance=distance)) for i in range(num_partitions)]

    extended_partitions = [p.union(a) for p, a in zip(partitions, extended_neighbors)]

    return partitions, extended_partitions

## Partitioning Atoms

Loading a sample atomic dataset and converting it into a graph

In [21]:
from ase.io import read
from orb_models.forcefield.atomic_system import ase_atoms_to_atom_graphs
from ase.build import make_supercell
 
atoms = read("datasets/test.xyz")
atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])

# Instead of using neighborlist, I use the ase_atoms_to_atom_graphs provided by orb. Hopefully this will provide better results
atom_graph = ase_atoms_to_atom_graphs(atoms) # Keep this to compare results later

senders = atom_graph.senders
receivers = atom_graph.receivers
edge_feats = atom_graph.edge_features

G = nx.Graph()
G.add_nodes_from(range(len(atoms)))

for i, u in enumerate(senders):
    G.add_edge(u.item(), receivers[i].item(), weight=edge_feats['r'])

print("Number of atoms", len(atoms))

Number of atoms 3408


Partition the computational graph into the number of desired partitions with the specified neighborhood distance

In [22]:
desired_partitions = 2
neighborhood_distance = 10
partitions, extended_partitions = part_graph_extended(G, desired_partitions, neighborhood_distance)

num_partitions = len(partitions)

Create the ASE atoms object for each partition

In [23]:
from ase import Atoms

partitioned_atoms = []
indices_map = [] # Table mapping each atom in each partition back to its index in the original atoms object

for part in extended_partitions:
    current_partition = []
    current_indices_map = []
    for atom_index in part:
        current_partition.append(atoms[atom_index])
        current_indices_map.append(atoms[atom_index].index)

    partitioned_atoms.append(Atoms(current_partition, cell=atoms.cell, pbc=atoms.pbc))
    indices_map.append(current_indices_map)


In [24]:
reconstructed_atoms = []
for atom_index in range(len(atoms)):
    reconstructed_atoms.append(atoms[atom_index])
reconstructed_atoms = Atoms(reconstructed_atoms, cell=atoms.cell, pbc=atoms.pbc)

reconstructed_atoms

Atoms(symbols='C880H2208Ga64S64Si192', pbc=True, cell=[[23.096664428710938, 13.334883689880371, 23.9624080657959], [-23.09648895263672, 13.334826469421387, 23.962270736694336], [3.831999856629409e-05, -26.669597625732422, 23.962278366088867]])

## Inference

In [25]:
from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield import segment_ops

Load model

In [26]:
device = "cpu"  # or device="cuda"

orbff = pretrained.orb_v2(device=device)

  state_dict = torch.load(local_path, map_location="cpu")


Run inference on each partition

In [27]:
aggregated_features = np.zeros((len(atoms), 256), dtype=np.float32)

for i, part in enumerate(partitioned_atoms):
    input_graph = atomic_system.ase_atoms_to_atom_graphs(part)

    batch = orbff.model(input_graph)

    feat = batch.node_features["feat"]

    for j, node in enumerate(part):
        original_index = indices_map[i][j]
        if original_index in partitions[i]: # If the node is a root node of the partition
            aggregated_features[original_index] = feat[j]

aggregated_features = torch.from_numpy(aggregated_features)
aggregated_features

  aggregated_features[original_index] = feat[j]


tensor([[-0.3629,  0.6121, -0.2159,  ..., -0.1818,  0.5214,  0.2927],
        [-0.1475,  0.8231,  0.1479,  ..., -0.0515,  0.2717, -0.3499],
        [-0.5161,  0.6599, -0.2023,  ...,  0.0230,  0.3217, -0.8838],
        ...,
        [ 0.4505, -0.0583,  0.0022,  ...,  0.0957, -0.1776, -0.0512],
        [ 0.7821,  0.0331,  0.2452,  ..., -0.0250, -0.0294, -0.4117],
        [ 0.8261, -0.0955,  0.0826,  ...,  0.1450, -0.5848, -0.3101]])

In [28]:
# Build the MLP
from orb_models.forcefield.nn_util import build_mlp
from orb_models.forcefield.graph_regressor import ScalarNormalizer, LinearReferenceEnergy
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

normalizer = ScalarNormalizer()

ref = REFERENCE_ENERGIES["vasp-shifted"]
reference = LinearReferenceEnergy(
    weight_init=ref.coefficients, trainable=True
)

n_node = torch.tensor([aggregated_features.shape[0]])

mlp = build_mlp(
    input_size=256,
    hidden_layer_sizes=[256] * 1,
    output_size=1,
)

input = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean"
)

pred = orbff.graph_head.mlp(input)
pred = orbff.graph_head.normalizer.inverse(pred).squeeze(-1)
pred = pred * n_node
pred = pred + reference(atom_graph.atomic_numbers, atom_graph.n_node)
pred

tensor([-17688.2305], grad_fn=<AddBackward0>)

In [32]:
result = orbff.predict(atom_graph)
print(result)

{'graph_pred': tensor([-17688.2305]), 'stress_pred': tensor([[ 1.1084e-02,  1.1161e-02,  1.2405e-02,  8.0175e-06, -1.8519e-06,
         -2.5269e-06]]), 'node_pred': tensor([[ 0.0349, -0.0312,  0.0005],
        [-0.0958, -0.0764, -0.0108],
        [-0.0572,  0.0509,  0.0085],
        ...,
        [ 0.0123, -0.1079,  0.0407],
        [-0.1089,  0.0325,  0.0415],
        [ 0.0791,  0.0703,  0.0156]])}


In [30]:
real_feats = batch.node_features["feat"]
torch.mean((aggregated_features - real_feats) / real_feats)

tensor(0.)

In [31]:
mse = torch.mean((forces_from_partition - forces_from_original) ** 2)
mae = torch.mean(abs(forces_from_partition - forces_from_original))
mape = 100 * torch.mean(abs(forces_from_partition - forces_from_original) / forces_from_original)
mse, mae, mape

NameError: name 'forces_from_partition' is not defined