In [None]:
from ase import Atoms
import pathlib
from openff.units import unit
import subprocess
import mace.calculators
from mace.calculators import mace_omol
import MDAnalysis as mda
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import linregress
from typing import Literal
import torch

# Apply the ggplot look to every figure drawn in this notebook.
plt.style.use("ggplot")

# Factor used to convert calculator energies from eV to kcal/mol.
EV_TO_KCALMOL = 23.0605

# Input data: the directory of PDB structures and the reference-energy table
# stored inside it.
STRUCTURES_DIR = pathlib.Path("ci9b01171_si_001")
REF_ENERGIES_FILE = STRUCTURES_DIR / "reference_energies.txt"

# MACE-OMOL calculator ("extra_large" model) evaluated on the CPU.
CALC = mace_omol("extra_large", device="cpu")

## Utility functions

In [None]:
def read_remarks(pdbfile: pathlib.Path) -> dict[str, "str | int"]:
    """
    Get a dict of the REMARK lines in a PDB file.

    Only lines of the form ``REMARK <key> <value>`` are kept; REMARK lines
    with fewer than three whitespace-separated fields are ignored. If a key
    appears more than once, the last occurrence wins.

    Parameters
    ----------
    pdbfile : pathlib.Path
        Path to the PDB file.

    Returns
    -------
    remarks : dict of str to (str or int)
        Dictionary mapping REMARK keys to their values. Values are strings,
        except for keys containing "charge", whose values are converted to int.

    Raises
    ------
    ValueError
        If the value of a key containing "charge" cannot be parsed as an int.
    """
    remarks: dict[str, "str | int"] = {}
    with open(pdbfile) as f:
        for line in f:
            if not line.startswith("REMARK"):
                continue
            # maxsplit=2 keeps any whitespace inside the value intact.
            parts = line.strip().split(maxsplit=2)
            if len(parts) != 3:
                continue
            _, key, value = parts
            remarks[key] = int(value) if "charge" in key else value
    return remarks


def add_element_symbols(pdbfile: pathlib.Path) -> pathlib.Path:
    """
    Produce a copy of a PDB file with element symbols filled in by Open Babel.

    The output is written next to the input as ``<stem>_fixed.pdb``. If that
    file already exists it is reused and Open Babel is not invoked again.

    Parameters
    ----------
    pdbfile : pathlib.Path
        Path to the input PDB file.

    Returns
    -------
    new_pdbfile : pathlib.Path
        Path to the ``<stem>_fixed.pdb`` file.

    Raises
    ------
    subprocess.CalledProcessError
        If the ``obabel`` command exits with a non-zero status.
    """
    fixed_path = pdbfile.with_name(f"{pdbfile.stem}_fixed.pdb")

    # A previous run already produced the output; nothing to do.
    if fixed_path.exists():
        return fixed_path

    obabel_cmd = ["obabel", str(pdbfile), "-O", str(fixed_path), "--addelement"]
    # Both output streams of obabel are discarded.
    subprocess.run(
        obabel_cmd,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return fixed_path


def get_reference_energy(
    pdb_file: pathlib.Path, ref_energies_file: pathlib.Path
) -> unit.Quantity:
    """
    Retrieve the reference energy for a given PDB file from a reference energies file.

    The first line of the reference file that contains the PDB file's stem and
    has at least three whitespace-separated fields is used; its third field is
    parsed as the energy.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file for which to retrieve the reference energy.
    ref_energies_file : pathlib.Path
        Path to the file containing reference energies.

    Returns
    -------
    energy : unit.Quantity
        The reference energy in kilocalories per mole (kcal/mol).

    Raises
    ------
    ValueError
        If no usable line for the PDB file is found in the reference energies
        file, or if the energy field cannot be parsed as a float.
    """
    # The reference file has energies like:
    # 01 10GS_01_bbn                                 -2.184
    # and the pdb_file is like 10GS_01.pdb.
    pdb_stem = pdb_file.stem  # e.g., "10GS_01"
    with open(ref_energies_file) as f:
        for line in f:
            # NOTE(review): this is a substring match, so a stem that is a
            # prefix of another entry's name would match that entry too --
            # confirm the naming scheme rules this out.
            if pdb_stem not in line:
                continue
            parts = line.split()
            # The energy is the third field, so at least three are required;
            # shorter matching lines are skipped instead of raising IndexError.
            if len(parts) >= 3:
                energy_value = float(parts[2])
                return energy_value * unit.kilocalorie / unit.mole
    raise ValueError(
        f"Reference energy for {pdb_stem} not found in {ref_energies_file}"
    )


def mda_to_ase(mda_atoms: mda.core.groups.AtomGroup) -> Atoms:
    """Build an ASE Atoms object from the elements and positions of an MDAnalysis AtomGroup."""
    # Collect each atom's element attribute, in AtomGroup order.
    element_symbols = []
    for atom in mda_atoms:
        element_symbols.append(atom.element)
    return Atoms(symbols=element_symbols, positions=mda_atoms.positions)


def get_ase_atoms(pdb_file: pathlib.Path) -> dict[Literal["a", "b"], Atoms]:
    """
    Read a PDB file and split it into two molecules, returning ASE Atoms objects for each.

    Molecule "b" is every atom whose residue name equals the file's
    ``REMARK selection_b`` value (with any ":" removed); molecule "a" is
    every other atom.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file. It must contain a ``REMARK selection_b <value>`` line.

    Returns
    -------
    chains : dict
        A dictionary with keys 'a' and 'b' and values as ASE Atoms objects for each molecule.

    Raises
    ------
    ValueError
        If the PDB file has no ``selection_b`` REMARK.
    """
    remark_info = read_remarks(pdb_file)
    selection_b = remark_info.get("selection_b")
    if selection_b is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(f"No 'selection_b' REMARK found in {pdb_file}")

    # Remove the ":" if present
    resname_b = selection_b.replace(":", "")
    u = mda.Universe(pdb_file)
    mol_a = u.select_atoms(f"not resname {resname_b}")
    mol_b = u.select_atoms(f"resname {resname_b}")

    return {"a": mda_to_ase(mol_a), "b": mda_to_ase(mol_b)}


def get_mlp_energy(
    atoms: Atoms, calc: mace.calculators.mace.MACECalculator, total_charge: int
) -> unit.Quantity:
    """
    Compute the energy for given ASE Atoms using a MACE model.

    Note that ``atoms`` is modified in place: its ``info`` dict receives the
    charge and spin multiplicity, and ``calc`` is attached as its calculator.

    Parameters
    ----------
    atoms : Atoms
        An ASE Atoms object representing the molecular structure.

    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.

    total_charge : int
        The total charge of the system.

    Returns
    -------
    energy : unit.Quantity
        The computed energy in kilocalories per mole (kcal/mol).
    """
    # Add the information to the atoms object
    atoms.info["charge"] = total_charge

    # Assume the lowest spin multiplicity allowed by the electron count:
    # singlet for an even number of electrons, doublet for an odd number.
    # The parity must come from the electron count (sum of atomic numbers
    # minus the total charge), not from the charge alone: a closed-shell ion
    # such as an ammonium or carboxylate group has an odd charge but an even
    # number of electrons.
    n_electrons = int(atoms.get_atomic_numbers().sum()) - total_charge
    # NOTE(review): confirm the calculator reads the "spin_multiplicity" key;
    # MACE examples appear to use atoms.info["spin"] instead.
    atoms.info["spin_multiplicity"] = 1 if n_electrons % 2 == 0 else 2

    # Set the calculator
    atoms.calc = calc

    # Convert the calculator energy to kcal/mol and attach units
    energy = atoms.get_potential_energy() * EV_TO_KCALMOL * unit.kilocalorie / unit.mole

    torch.cuda.empty_cache()

    return energy.to(unit.kilocalorie / unit.mole)


def get_mlp_interaction_energy(
    pdb_file: pathlib.Path, calc: mace.calculators.mace.MACECalculator
) -> unit.Quantity:
    """
    Compute the interaction energy between the two molecules in a PDB file.

    The structure is first passed through ``add_element_symbols``, then split
    into molecules "a" and "b". The charges come from the ``charge_a`` and
    ``charge_b`` REMARK entries of the original PDB file, and the complex is
    assigned the sum of the two.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file for which to compute the interaction energy.

    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.

    Returns
    -------
    interaction_energy : unit.Quantity
        E(complex) - (E(a) + E(b)), in kilocalories per mole (kcal/mol).
    """
    # Element-annotated copy of the structure, split into its two molecules.
    fragments = get_ase_atoms(add_element_symbols(pdb_file))

    # Charges are read from the REMARK lines of the original file.
    remarks = read_remarks(pdb_file)

    # Energies of the isolated molecules, evaluated in the order a, b.
    monomer_energies = {
        label: get_mlp_energy(
            atoms=fragments[label],
            calc=calc,
            total_charge=remarks[f"charge_{label}"],
        )
        for label in ("a", "b")
    }

    # Energy of the combined system, carrying the summed charge.
    complex_energy = get_mlp_energy(
        atoms=fragments["a"] + fragments["b"],
        calc=calc,
        total_charge=remarks["charge_a"] + remarks["charge_b"],
    )

    return complex_energy - (monomer_energies["a"] + monomer_energies["b"])

## Calculate the energies and plot

In [None]:
# Get all the pdb files. Make sure to exclude anything with "fixed" in the name
# (those are the element-annotated copies written by add_element_symbols).
# glob() yields paths in arbitrary order, so sort them to keep the order of the
# result lists reproducible between runs.
pdb_files = sorted(f for f in STRUCTURES_DIR.glob("*.pdb") if "fixed" not in f.name)

# Reference and predicted interaction energies, index-aligned with pdb_files.
ref_energies = [get_reference_energy(f, REF_ENERGIES_FILE) for f in pdb_files]
mlp_energies = [get_mlp_interaction_energy(f, CALC) for f in tqdm(pdb_files)]

# Clear GPU memory
torch.cuda.empty_cache()

In [None]:
# Plot the predicted vs reference energies (both should be in kcal/mol)
kcal_per_mol = unit.kilocalorie / unit.mole
ref_energies_array = np.array([e.m_as(kcal_per_mol) for e in ref_energies])
mlp_energies_array = np.array([e.m_as(kcal_per_mol) for e in mlp_energies])

fig, ax = plt.subplots(figsize=(6, 6))
sns.scatterplot(x=ref_energies_array, y=mlp_energies_array, ax=ax)

# Dashed y = x line spanning the full range of both data sets
max_energy = max(ref_energies_array.max(), mlp_energies_array.max())
min_energy = min(ref_energies_array.min(), mlp_energies_array.min())
ax.plot([min_energy, max_energy], [min_energy, max_energy], "k--", alpha=0.5)

ax.set_xlabel("Reference energy (kcal/mol)")
ax.set_ylabel("MLP energy (kcal/mol)")
ax.set_title("MACE-OMOL Interaction Energies vs Reference")

# Calculate RMSE and R^2 and show on the plot.
# R^2 here is the squared correlation coefficient of a linear fit of the
# predicted energies against the reference energies.
rmse = np.sqrt(np.mean((ref_energies_array - mlp_energies_array) ** 2))
r_squared = linregress(ref_energies_array, mlp_energies_array).rvalue ** 2
ax.text(
    0.05,
    0.95,
    f"RMSE: {rmse:.2f} kcal/mol\nR²: {r_squared:.3f}",
    transform=ax.transAxes,
    verticalalignment="top",
    bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
);  # trailing semicolon suppresses the Text(...) repr in the cell output