# Extended Subgraph Partitioning (orb-models)

In [6]:
import numpy as np
import torch

from tqdm import tqdm

from partitioner import part_graph_extended
import networkx as nx

## Partitioning Atoms

Load a sample atomic structure, expand it into a 2×2×2 supercell, and convert it into a graph.

In [7]:
from ase.io import read
from orb_models.forcefield.atomic_system import ase_atoms_to_atom_graphs
from ase.build import make_supercell

atoms = read("datasets/test.xyz")
# Replicate the cell 2x2x2 along each lattice vector to get a larger test system.
atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])

# Build connectivity with orb's own graph construction (instead of an ASE neighbor list),
# so the partitioner sees the same edges the model uses. The full graph is also kept to
# compare against the unpartitioned prediction later.
atom_graph = ase_atoms_to_atom_graphs(atoms)

senders = atom_graph.senders
receivers = atom_graph.receivers
edge_feats = atom_graph.edge_features

G = nx.Graph()
G.add_nodes_from(range(len(atoms)))

# Give each edge its OWN length as weight. The previous code attached the whole
# tensor of edge lengths to every edge. reshape(-1) tolerates an (E,) or (E, 1) layout.
# nx.Graph is undirected: a later (v, u) edge overwrites the weight of (u, v).
# NOTE(review): assumes both directions of an edge carry the same length — confirm.
edge_lengths = edge_feats["r"].detach().reshape(-1).tolist()
G.add_edges_from(
    (u, v, {"weight": r})
    for u, v, r in zip(senders.tolist(), receivers.tolist(), edge_lengths)
)

print("Number of atoms", len(atoms))

Number of atoms 3408


Partition the graph into the desired number of partitions, and extend each partition with the atoms within the specified neighborhood distance.

In [8]:
desired_partitions = 20
neighborhood_distance = 4

# `partitions[i]` holds the root nodes of partition i. `extended_partitions[i]`
# additionally holds the surrounding neighborhood atoms.
partitions, extended_partitions = part_graph_extended(G, desired_partitions, neighborhood_distance)
num_partitions = len(partitions)

print(f"Created {num_partitions} partitions")
# This average is over the EXTENDED partitions, so the neighborhood atoms are counted too.
print(f"Average partition size: {sum(map(len, extended_partitions)) / num_partitions}")

Created 20 partitions
Average partition size: 2744.25


Create an ASE `Atoms` object for each extended partition, and record how its atoms map back to the original structure.

In [9]:
from ase import Atoms

partitioned_atoms = []  # one Atoms object per extended partition
indices_map = []  # indices_map[p][k] is the index in `atoms` of the k-th atom of partitioned_atoms[p]

for part in extended_partitions:
    members = [atoms[atom_index] for atom_index in part]
    member_indices = [member.index for member in members]

    # Passing the parent cell and PBC is essential. Without them the partition would be
    # built with ASE's defaults (no cell, pbc=False) and lose its periodic images.
    partitioned_atoms.append(Atoms(members, cell=atoms.cell, pbc=atoms.pbc))
    indices_map.append(member_indices)

In [10]:
# Rebuild the full system atom by atom, the same way the partitions are built above.
# NOTE(review): presumably a sanity check that per-Atom construction preserves the structure.
reconstructed_atoms = Atoms(
    [atoms[atom_index] for atom_index in range(len(atoms))],
    cell=atoms.cell,
    pbc=atoms.pbc,
)

reconstructed_atoms

Atoms(symbols='C880H2208Ga64S64Si192', pbc=True, cell=[[23.096664428710938, 13.334883689880371, 23.9624080657959], [-23.09648895263672, 13.334826469421387, 23.962270736694336], [3.831999856629409e-05, -26.669597625732422, 23.962278366088867]])

## Inference

In [11]:
from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield import segment_ops

  from .autonotebook import tqdm as notebook_tqdm


Load the pretrained orb-v2 model

In [12]:
from orb_models.forcefield.pretrained import get_base, load_model_for_inference
from orb_models.forcefield.graph_regressor import GraphRegressor, EnergyHead, NodeHead, GraphHead

device = "cpu"  # or device="cuda"

# Backbone with 4 message-passing steps.
base = get_base(num_message_passing_steps=4)

# The energy, force and stress heads all use the same MLP settings.
head_mlp_kwargs = dict(latent_dim=256, num_mlp_layers=1, mlp_hidden_dim=256)

model = GraphRegressor(
    graph_head=EnergyHead(
        **head_mlp_kwargs,
        target="energy",
        node_aggregation="mean",
        reference_energy_name="vasp-shifted",
        train_reference=True,
        predict_atom_avg=True,
    ),
    node_head=NodeHead(
        **head_mlp_kwargs,
        target="forces",
        remove_mean=True,
    ),
    stress_head=GraphHead(
        **head_mlp_kwargs,
        target="stress",
        compute_stress=True,
    ),
    model=base,
)

checkpoint_url = 'https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v2-20241011.ckpt'
orbff = load_model_for_inference(model, checkpoint_url, device)

  state_dict = torch.load(local_path, map_location="cpu")


Run the model backbone on each extended partition, and keep only the node features of the partition's root atoms

In [13]:
# Each atom takes its features from the partition in which it is a root node.
# The remaining atoms of an extended partition only provide context for message passing.
aggregated_features = torch.zeros((len(atoms), 256), dtype=torch.float32, device=device)
assigned = torch.zeros(len(atoms), dtype=torch.bool, device=device)

for i, part in tqdm(enumerate(partitioned_atoms), total=num_partitions):
    input_graph = atomic_system.ase_atoms_to_atom_graphs(part)

    # Only the backbone's node features are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        batch = orbff.model(input_graph)
    feat = batch.node_features["feat"]

    # Build the root-node set once per partition. A membership test on a list per atom would be O(n^2).
    root_nodes = set(partitions[i])
    local_positions = [j for j, original_index in enumerate(indices_map[i]) if original_index in root_nodes]
    global_positions = [indices_map[i][j] for j in local_positions]

    local_idx = torch.tensor(local_positions, dtype=torch.long, device=feat.device)
    global_idx = torch.tensor(global_positions, dtype=torch.long, device=device)
    # One vectorized scatter instead of per-row assignment. .to() matches dtype and device.
    aggregated_features[global_idx] = feat[local_idx].to(aggregated_features)
    assigned[global_idx] = True

# Atoms never seen as a root node would silently keep all-zero features and skew the results.
assert bool(assigned.all()), f"{int((~assigned).sum())} atoms are not a root node of any partition"

aggregated_features

100%|██████████| 20/20 [00:44<00:00,  2.23s/it]


tensor([[-0.1757, -0.3925,  0.3263,  ...,  0.2231,  0.4786,  0.5496],
        [-0.0559, -0.3711,  0.3010,  ...,  0.2893,  0.1709,  0.5445],
        [-0.5957, -0.4528,  0.1756,  ...,  0.2427,  0.4110,  0.5453],
        ...,
        [-0.1440,  0.5298,  0.0155,  ..., -0.1544,  0.0446,  0.5463],
        [ 0.1561,  0.5317, -0.0433,  ..., -0.1388,  0.0721,  0.4435],
        [ 0.3817,  0.5695,  0.1271,  ...,  0.0142, -0.4501,  0.3296]])

## Prediction

Run the prediction on the original, unpartitioned graph to obtain a benchmark for the partitioned results.

In [14]:
# Benchmark: the full model applied to the unpartitioned graph built at the start.
result = orbff.predict(atom_graph)
benchmark_energy, benchmark_forces, benchmark_stress = (
    result[key] for key in ("graph_pred", "node_pred", "stress_pred")
)

In [15]:
from orb_models.forcefield.graph_regressor import ScalarNormalizer, LinearReferenceEnergy
from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES

# Per-element reference energies, taken from the same "vasp-shifted" table named in the energy head config.
# NOTE(review): the head was built with train_reference=True, so the checkpoint may carry
# fine-tuned reference weights that differ from this table. Confirm they match.
ref = REFERENCE_ENERGIES["vasp-shifted"]
reference = LinearReferenceEnergy(trainable=True, weight_init=ref.coefficients)

# The whole system is treated as a single graph that contains every atom.
n_node = torch.tensor([aggregated_features.shape[0]])

### Energy

In [16]:
# Rebuild the energy from the stitched node features:
# mean-pool the nodes -> energy MLP -> de-normalize -> scale by atom count -> add reference energy.
# Named graph_features rather than `input` to avoid shadowing the builtin.
graph_features = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean"
)

energy = orbff.graph_head.mlp(graph_features)
energy = orbff.graph_head.normalizer.inverse(energy).squeeze(-1)
# The head is configured with predict_atom_avg=True, so scale up by the number of atoms.
energy = energy * n_node
energy = energy + reference(atom_graph.atomic_numbers, atom_graph.n_node)
energy

tensor([-17974.6777], grad_fn=<AddBackward0>)

In [17]:
# Difference between the full-graph benchmark and the partitioned reconstruction.
energy_error = benchmark_energy - energy
print(f"Absolute error: {torch.abs(energy_error).item()}")
print(f"Percent error: {torch.abs(energy_error / benchmark_energy).item() * 100}%")
# Take the magnitude: the max of the signed difference under-reports a negative deviation.
print(f"Maximum error: {torch.max(torch.abs(energy_error)).item()}")

Absolute error: 0.0
Percent error: 0.0%
Maximum error: 0.0


### Forces

In [18]:
# Apply the force head by hand: MLP, subtract the per-system mean force
# (the head is configured with remove_mean=True), then de-normalize.
raw_forces = orbff.node_head.mlp(aggregated_features)
per_system_mean = segment_ops.aggregate_nodes(raw_forces, n_node, reduction="mean")
forces = raw_forces - torch.repeat_interleave(per_system_mean, n_node, dim=0)
forces = orbff.node_head.normalizer.inverse(forces)
forces

tensor([[-0.0688,  0.0882,  0.2663],
        [ 0.0824,  0.1349,  0.2663],
        [ 0.0129, -0.0198,  0.2584],
        ...,
        [ 0.0539, -0.0196,  0.1435],
        [-0.0231,  0.0474,  0.1652],
        [ 0.0963,  0.0796,  0.1533]])

In [19]:
force_diff = benchmark_forces - forces
mae = torch.mean(torch.abs(force_diff))
# Skip zero-valued reference components: dividing by them yields inf/nan.
nonzero = benchmark_forces != 0
mape = 100 * torch.mean(torch.abs(force_diff[nonzero] / benchmark_forces[nonzero]))
# Named max_error so the builtin `max` stays usable for the rest of the notebook.
max_error = torch.max(torch.abs(force_diff))

print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape.item()}%")
print(f"Maximum error: {max_error.item()}")

Mean absolute error: 3.5829430089506786e-08
Mean absolute percent error: 0.00018414562509860843%
Maximum error: 3.501772880554199e-07


### Stress

In [20]:
# Rebuild the stress from the stitched node features:
# mean-pool the nodes -> stress MLP -> output activation -> de-normalize.
# Named graph_features rather than `input` to avoid shadowing the builtin.
graph_features = segment_ops.aggregate_nodes(
    aggregated_features,
    n_node,
    reduction="mean",
)
stress = orbff.stress_head.mlp(graph_features)
# No-op for this (1, 6) output; kept as in the original sequence of operations.
stress = stress.squeeze(-1)
stress = orbff.stress_head.output_activation(stress)
stress = orbff.stress_head.normalizer.inverse(stress)
stress

tensor([[-1.4797e-02, -1.7188e-02, -1.2462e-02,  2.4029e-05,  3.1426e-05,
          5.2786e-07]])

In [21]:
stress_diff = benchmark_stress - stress
mae = torch.mean(torch.abs(stress_diff))
# Some stress components are close to or exactly zero. Exclude exact zeros to avoid inf/nan.
nonzero = benchmark_stress != 0
mape = 100 * torch.mean(torch.abs(stress_diff[nonzero] / benchmark_stress[nonzero]))
# Named max_error so the builtin `max` stays usable for the rest of the notebook.
max_error = torch.max(torch.abs(stress_diff))
print(f"Mean absolute error: {mae.item()}")
print(f"Mean absolute percent error: {mape.item()}%")
print(f"Max error: {max_error.item()}")

Mean absolute error: 3.152914873627566e-10
Mean absolute percent error: 2.1992624169797637e-05%
Max error: 1.862645149230957e-09
