# Extended Subgraph Partitioning Mattertune
Now I will use the partitioning algorithm for inference using Mattertune

## Functions

### Extended Partitioning

In [14]:
import networkx as nx
import metis

from collections import deque

def part_graph_extended(G, desired_partitions, distance=None):
    def descendants_at_distance_multisource(G, sources, distance=None):
        if sources in G:
            sources = [sources]

        queue = deque(sources)
        depths = deque([0 for _ in queue])
        visited = set(sources)

        for source in queue:
            if source not in G:
                raise nx.NetworkXError(f"The node {source} is not in the graph.")

        while queue:
            node = queue[0]
            depth = depths[0]

            if distance is not None and depth > distance: return

            yield queue[0]

            queue.popleft()
            depths.popleft()

            for child in G[node]:
                if child not in visited:
                    visited.add(child)
                    queue.append(child)
                    depths.append(depth + 1)

    _, parts = metis.part_graph(G, desired_partitions, objtype="cut")
    partition_map = {node: parts[i] for i, node in enumerate(G.nodes())}
    num_partitions = desired_partitions

    # Find indices of nodes in each partition
    partitions = [set() for _ in range(desired_partitions)]

    for i, node in enumerate(G.nodes()):
        partitions[partition_map[i]].add(node)

    # Find boundary nodes (vertices adjacent to vertex not in partition)
    boundary_nodes = [set(map(lambda uv: uv[0], nx.edge_boundary(G, partitions[i]))) for i in range(num_partitions)]

    # Perform BFS on boundary_nodes to find extended neighbors up to a certain distance
    extended_neighbors = [set(descendants_at_distance_multisource(G, boundary_nodes[i], distance=distance)) for i in range(num_partitions)]

    extended_partitions = [p.union(a) for p, a in zip(partitions, extended_neighbors)]

    return partitions, extended_partitions

## Partitioning Atoms

Loading a sample atomic dataset and converting it into a graph

In [15]:
if False:
    from mattertune.data.mptraj import MPTrajDatasetConfig, MPTrajDataset
    from ase import io

    mptraj = MPTrajDataset(MPTrajDatasetConfig(min_num_atoms=300,split="val"))
    atoms_list = mptraj.atoms_list
    atoms_list.sort(key=lambda x: len(x))
    io.write("test.xyz", atoms_list[-1])

In [75]:
from ase.io import read
from orb_models.forcefield.atomic_system import ase_atoms_to_atom_graphs
from ase.build import make_supercell
 
atoms = read("test.xyz")
atoms = make_supercell(atoms, [[3, 0, 0], [0, 3, 0], [0, 0, 3]])

# Instead of using neighborlist, I use the ase_atoms_to_atom_graphs provided by orb. Hopefully this will provide better results
atom_graph = ase_atoms_to_atom_graphs(atoms)
senders = atom_graph.senders
receivers = atom_graph.receivers
edge_feats = atom_graph.edge_features

G = nx.Graph()
G.add_nodes_from(range(len(atoms)))

for i, u in enumerate(senders):
    G.add_edge(u.item(), receivers[i].item(), weight=edge_feats['r'])

# layout = nx.random_layout(G, seed=1)
# nx.draw(G, pos=layout, with_labels=True)

print("Number of atoms", len(atoms))
print("Diameter of graph:", nx.diameter(G))

Number of atoms 11502


KeyboardInterrupt: 

Performing the partitioning

In [69]:
desired_partitions = 20colors = cm.get_cmap("Accent", lut=num_partitions)
# color_map = [None for _ in range(len(G.nodes()))]
# for i, part in enumerate(partitions):
#     for u in part:
#         color_map[u] = colors(i)
# nx.draw(G, pos=layout, with_labels=True, node_color=color_map)
neighborhood_distance = 4
partitions, extended_partitions = part_graph_extended(G, desired_partitions, neighborhood_distance)

num_partitions = len(partitions)

# Visualization
import matplotlib.cm as cm

# colors = cm.get_cmap("Accent", lut=num_partitions)
# color_map = [None for _ in range(len(G.nodes()))]
# for i, part in enumerate(partitions):
#     for u in part:
#         color_map[u] = colors(i)
# nx.draw(G, pos=layout, with_labels=True, node_color=color_map)

for x in extended_partitions:
    print(len(x))

2867
2670
2781
2808
2775
2640
2827
2653
2813
2873
2736
2703
2693
2597
2785
2666
2651
2815
2624
2908


Create the ASE atoms object for each partition

In [70]:
from ase import Atoms

partitioned_atoms = []
indices_map = [] # Table mapping each atom in each partition back to its index in the original atoms object

for part in extended_partitions:
    current_partition = []
    current_indices_map = []
    for atom_index in part:
        current_partition.append(atoms[atom_index])
        current_indices_map.append(atoms[atom_index].index)

    partitioned_atoms.append(Atoms(current_partition))
    indices_map.append(current_indices_map)

partitioned_atoms


[Atoms(symbols='C752H1850Ga52S53Si160', pbc=False),
 Atoms(symbols='C700H1714Ga52S50Si154', pbc=False),
 Atoms(symbols='C728H1787Ga53S52Si161', pbc=False),
 Atoms(symbols='C732H1806Ga55S56Si159', pbc=False),
 Atoms(symbols='C717H1791Ga54S54Si159', pbc=False),
 Atoms(symbols='C694H1690Ga51S49Si156', pbc=False),
 Atoms(symbols='C735H1824Ga53S55Si160', pbc=False),
 Atoms(symbols='C695H1707Ga52S50Si149', pbc=False),
 Atoms(symbols='C732H1812Ga55S58Si156', pbc=False),
 Atoms(symbols='C745H1858Ga54S56Si160', pbc=False),
 Atoms(symbols='C709H1769Ga55S54Si149', pbc=False),
 Atoms(symbols='C704H1745Ga52S53Si149', pbc=False),
 Atoms(symbols='C706H1743Ga44S45Si155', pbc=False),
 Atoms(symbols='C680H1668Ga51S49Si149', pbc=False),
 Atoms(symbols='C726H1798Ga52S52Si157', pbc=False),
 Atoms(symbols='C699H1712Ga53S51Si151', pbc=False),
 Atoms(symbols='C691H1706Ga52S53Si149', pbc=False),
 Atoms(symbols='C735H1821Ga52S52Si155', pbc=False),
 Atoms(symbols='C687H1681Ga51S49Si156', pbc=False),
 Atoms(symbo

In [71]:
whole_graph = []
for atom_index in range(len(atoms)):
    whole_graph.append(atoms[atom_index])
whole_graph = Atoms(whole_graph)

whole_graph

Atoms(symbols='C880H2208Ga64S64Si192', pbc=False)

## Inference

In [51]:
import ase
from ase.build import bulk

from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield.base import batch_graphs

import torch

Load model

In [52]:
device = "cpu"  # or device="cuda"
orbff = pretrained.orb_v2(device=device)

  state_dict = torch.load(local_path, map_location="cpu")


Run inference on each partition

In [72]:
forces_from_partition = [None for _ in range(len(atoms))]

for i, part in enumerate(partitioned_atoms):
    input_graph = atomic_system.ase_atoms_to_atom_graphs(part)
    result = orbff.predict(input_graph)
    for j, node in enumerate(part):
        original_index = indices_map[i][j]
        if original_index in partitions[i]: # If the node is a root node of the partition
            forces_from_partition[original_index] = result["node_pred"][j]

forces_from_partition = torch.stack(forces_from_partition)

Run inference on the whole, original atom graph

In [73]:
input_graph = atomic_system.ase_atoms_to_atom_graphs(whole_graph)
result = orbff.predict(input_graph)
forces_from_original = result["node_pred"]

In [79]:
mse = torch.mean((forces_from_partition - forces_from_original) ** 2)
mae = torch.mean(abs(forces_from_partition - forces_from_original))
mape = 100 * torch.mean(abs(forces_from_partition - forces_from_original) / forces_from_original)
mse, mae, mape

(tensor(7.8128e-07), tensor(0.0007), tensor(0.7463))