## MEPIN inference example

### Imports

In [1]:
import ase
import ase.io
import numpy as np
import torch

from mepin.model.modules import TripleCrossPaiNNModule
from mepin.tools.frechet import frechet_distance
from mepin.tools.inference import create_reaction_batch

### Model loading

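`use_geodesic` selects the model variant: `True` loads MEPIN-G, which uses geodesic path initialization, and `False` loads MEPIN-L, which does not. `dataset` selects which training set's checkpoint to load (`"t1x"` or `"cyclo"`).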

In [2]:
# Model variant: True -> MEPIN-G (geodesic initialization), False -> MEPIN-L.
use_geodesic = True
# Which dataset's checkpoint to load.
dataset = "t1x"  # or "cyclo"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
variant = "G" if use_geodesic else "L"
ckpt = f"ckpt/{dataset}_{variant}.ckpt"
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine. (Assumes TripleCrossPaiNNModule is a
# PyTorch Lightning module, whose load_from_checkpoint accepts map_location.)
model = TripleCrossPaiNNModule.load_from_checkpoint(ckpt, map_location=device)
model = model.eval().to(device)

### Inference

Predict the reaction path for a given reaction ID (`rxn_id`) and save it to `predicted_path.xyz`.

In [3]:
# Reaction to evaluate, plus its reactant and product end-point structures
rxn_id = "C2H2N2O_rxn3122"

reactant = ase.io.read(f"data/{dataset}_xtb/xyz/{rxn_id}_R.xyz")
product = ase.io.read(f"data/{dataset}_xtb/xyz/{rxn_id}_P.xyz")
# Every frame (index ":") of the precomputed geodesic interpolation
interp_traj = ase.io.read(f"data/{dataset}_xtb/geodesic/{rxn_id}.xyz", ":")

# Build a batch of num_images path images and run the model without autograd
batch = create_reaction_batch(
    reactant,
    product,
    interp_traj,
    use_geodesic=use_geodesic,
    num_images=101,
).to(device)
with torch.no_grad():
    flat_positions = model(batch)
    # One (n_atoms, 3) coordinate block per graph in the batch
    output_positions = flat_positions.reshape(batch.num_graphs, -1, 3).cpu().numpy()

# Wrap each predicted frame in an ase.Atoms that reuses the reactant's atomic
# numbers, cell and PBC flags, then write the whole path as a multi-frame XYZ
trajectory = [
    ase.Atoms(
        numbers=reactant.get_atomic_numbers(),
        positions=frame_positions,
        cell=reactant.cell,
        pbc=reactant.pbc,
    )
    for frame_positions in output_positions
]
ase.io.write("predicted_path.xyz", trajectory)

Compare the predicted path with the ground-truth IRC path using the Fréchet distance.

In [None]:
# Load every frame of the reference IRC path for the same reaction
irc_trajectory = ase.io.read(f"data/{dataset}_xtb/irc/{rxn_id}.xyz", ":")
# np.stack raises immediately on frames with differing atom counts
irc_positions = np.stack([atoms.get_positions() for atoms in irc_trajectory])

# frechet_distance receives raw Cartesian coordinates, so presumably both paths
# must list the same atoms in the same order. Fail loudly on a mismatch instead
# of reporting a meaningless number.
assert irc_positions.shape[1:] == output_positions.shape[1:], (
    f"per-frame shape mismatch: IRC {irc_positions.shape[1:]} vs. "
    f"predicted {output_positions.shape[1:]}"
)
assert np.array_equal(
    irc_trajectory[0].get_atomic_numbers(), reactant.get_atomic_numbers()
), "IRC atom ordering differs from the reactant used for prediction"

# Calculate the Frechet distance
dist = frechet_distance(output_positions, irc_positions)
print(f"Frechet distance vs. IRC: {dist:.2f} Å")

Frechet distance vs. IRC: 0.23 Å
