In [1]:
import MDAnalysis as mda
from pathlib import Path

def export_fixed_pdb(
    pdb_file: Path,
    exclude={"HOH", "WAT", "NA", "K", "CL", "MG", "CA"},
    charge_a=-1,
    charge_b=0
) -> Path:
    """
    Export a cleaned/fixed PDB that excludes solvent/ions and adds charge remarks.

    Parameters
    ----------
    pdb_file : Path
        Input PDB file.
    exclude : set[str]
        Residue names to exclude (e.g., solvents and ions).
    charge_a, charge_b : int
        Charges for molecule a (protein) and b (ligand).

    Returns
    -------
    fixed_pdb : Path
        Path to the cleaned PDB file.
    """
    u = mda.Universe(pdb_file)

    # Build a selection string excluding unwanted residues
    exclude_str = " ".join(exclude)
    sel = f"not resname {exclude_str}"

    # Select atoms excluding HOH, ions, etc.
    sel_atoms = u.select_atoms(sel)

    # Prepare output path
    fixed_pdb = pdb_file.with_name(pdb_file.stem + "_fixed.pdb")

    # Write new PDB
    sel_atoms.write(fixed_pdb)

    # Add REMARK lines for charge_a and charge_b
    with open(fixed_pdb, "r+") as f:
        content = f.read()
        f.seek(0, 0)
        f.write(f"REMARK charge_a {charge_a}\n")
        f.write(f"REMARK charge_b {charge_b}\n")
        f.write(content)

    print(f"Fixed PDB saved to: {fixed_pdb}")
    return fixed_pdb


In [2]:
from pathlib import Path

fixed = export_fixed_pdb(
    Path("6o6f/6o6f_complex.pdb"),
    exclude={"HOH", "WAT", "NA", "K", "CL", "MG", "CA"},
    charge_a=-1,
    charge_b=0
)


Fixed PDB saved to: 6o6f/6o6f_complex_fixed.pdb




In [3]:
from ase import Atoms
import pathlib
from openff.units import unit
import subprocess
import mace.calculators
from mace.calculators import mace_omol
from mace.calculators import mace_off
import MDAnalysis as mda
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import linregress
from typing import Literal
import torch

plt.style.use("ggplot")

STRUCTURES_DIR = pathlib.Path("6o6f")
# REF_ENERGIES_FILE = STRUCTURES_DIR / "reference_energies.txt"
EV_TO_KCALMOL = 23.0605

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [4]:
def read_remarks(pdbfile: pathlib.Path) -> dict[str, str]:
    """
    Get a dict of the REMARK lines in a PDB file.

    Parameters
    ----------
    pdbfile : str
        Path to the PDB file.

    Returns
    -------
    remarks : dict of str to str
        Dictionary mapping REMARK keys to their values.
    """
    remarks: dict[str, str] = {}
    with open(pdbfile) as f:
        for line in f:
            if line.startswith("REMARK"):
                parts = line.strip().split(maxsplit=2)
                if len(parts) == 3:
                    key = parts[1]
                    value = parts[2]
                    if "charge" in key:
                        value = int(value)
                    remarks[key] = value
    return remarks



def add_element_symbols(pdbfile: pathlib.Path) -> pathlib.Path:
    """
    Add element symbols to a PDB file using Open Babel.

    This function reads a PDB file, adds element symbols where missing,
    and writes the updated structure to a new PDB file.

    Parameters
    ----------
    pdbfile : pathlib.Path
        Path to the input PDB file.

    Returns
    -------
    new_pdbfile : pathlib.Path
        Path to the output PDB file with element symbols added.
    """
    new_pdbfile = pdbfile.with_name(pdbfile.stem + "_fixed.pdb")
    if not new_pdbfile.exists():
        subprocess.run(
            ["obabel", str(pdbfile), "-O", str(new_pdbfile), "--addelement"],
            check=True,
            stdout=subprocess.DEVNULL,  # suppress stdout
            stderr=subprocess.DEVNULL,  # suppress stderr
        )
    return new_pdbfile


def get_reference_energy(
    pdb_file: pathlib.Path, ref_energies_file: pathlib.Path
) -> unit.Quantity:
    """
    Retrieve the reference energy for a given PDB file from a reference energies file.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file for which to retrieve the reference energy.
    ref_energies_file : pathlib.Path
        Path to the file containing reference energies.

    Returns
    -------
    energy : unit.Quantity
        The reference energy in kilocalories per mole (kcal/mol).

    Raises
    ------
    ValueError
        If the PDB file is not found in the reference energies file.
    """
    # The reference file has energies like:
    # 01 10GS_01_bbn                                 -2.184
    # and the pdb_file is like 10GS_01.pdb.
    pdb_stem = pdb_file.stem  # e.g., "10GS_01"
    with open(ref_energies_file) as f:
        for line in f:
            if pdb_stem in line:
                parts = line.split()
                if len(parts) >= 2:
                    energy_value = float(parts[2])
                    return energy_value * unit.kilocalorie / unit.mole
    raise ValueError(
        f"Reference energy for {pdb_stem} not found in {ref_energies_file}"
    )


def mda_to_ase(mda_atoms: mda.core.groups.AtomGroup) -> Atoms:
    """Convert an MDAnalysis AtomGroup to an ASE Atoms object."""
    symbols = [atom.element for atom in mda_atoms]
    positions = mda_atoms.positions
    atoms = Atoms(symbols=symbols, positions=positions)
    return atoms


def get_ase_atoms(pdb_file: pathlib.Path) -> dict[Literal["a", "b"], Atoms]:
    """
    Read a PDB file and split it into two chains, returning ASE Atoms objects for each molecule.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file.

    Returns
    -------
    chains : dict
        A dictionary with keys 'a' and 'b' and values as ASE Atoms objects for each molecule.
    """
    remark_info = read_remarks(pdb_file)
    u = mda.Universe(pdb_file)
    resname_b = "LIG"
    print(resname_b)
    mol_a = u.select_atoms(f"not resname {resname_b}")
    mol_b = u.select_atoms(f"resname {resname_b}")

    return {"a": mda_to_ase(mol_a), "b": mda_to_ase(mol_b)}




def get_mlp_energy(
    atoms: Atoms, calc: mace.calculators.mace.MACECalculator, total_charge: int
) -> unit.Quantity:
    """
    Compute the energy for given ASE Atoms using a MACE model.

    Parameters
    ----------
    atoms : Atoms
        An ASE Atoms object representing the molecular structure.

    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.

    total_charge : int
        The total charge of the system.

    Returns
    -------
    energy : unit.Quantity
        The computed energy in kilocalories per mole (kcal/mol).
    """
    # Add the information to the atoms object
    atoms.info["charge"] = total_charge

    # Assume spin multiplicity based on total charge
    atoms.info["spin_multiplicity"] = 1 if total_charge % 2 == 0 else 2

    # Set the calculator
    atoms.calc = calc

    # Return the energy in kcal/mol
    energy = atoms.get_potential_energy() * EV_TO_KCALMOL * unit.kilocalorie / unit.mole

    return energy.to(unit.kilocalorie / unit.mole)






def get_mlp_interaction_energy(
    pdb_file: pathlib.Path, calc: mace.calculators.mace.MACECalculator
) -> unit.Quantity:
    """
    Compute the interaction energy using the MACE-OMOL model.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the PDB file for which to compute the interaction energy.

    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.

    Returns
    -------
    interaction_energy : unit.Quantity
        The computed interaction energy in kilocalories per mole (kcal/mol).
    """
    # Get the fixed pdb file with element symbols
    fixed_pdb_file = add_element_symbols(pdb_file)

    # Get the ASE Atoms for each molecule
    mols = get_ase_atoms(fixed_pdb_file)
    print(mols["b"].positions)

    # Get the energies for each molecule individually, and for the complex
    pdb_remarks = read_remarks(fixed_pdb_file)
    energies = {}

    for key in ["a", "b"]:
        energies[key] = get_mlp_energy(
            atoms=mols[key], calc=calc, total_charge=pdb_remarks[f"charge_{key}"]
        )

    energies["complex"] = get_mlp_energy(
        atoms=mols["a"] + mols["b"],
        calc=calc,
        total_charge=pdb_remarks["charge_a"] + pdb_remarks["charge_b"],
    )

    # Interaction energy = E(complex) - (E(a) + E(b))
    interaction_energy = energies["complex"] - (energies["a"] + energies["b"])

    return interaction_energy

In [5]:
torch.cuda.empty_cache()

In [None]:
# Get all the pdb files. Make sure to exclude anything with "fixed" in the name.
calc = mace_omol("extra_large", device="cpu"
                 )
# calc = mace_off("medium", device="cuda")
pdb_files = [f for f in STRUCTURES_DIR.glob("*.pdb") if "fixed" not in f.name]
# ref_energies = [get_reference_energy(f, REF_ENERGIES_FILE) for f in pdb_files]
mlp_energies = [get_mlp_interaction_energy(f, calc) for f in tqdm(pdb_files)]
mlp_energies

Using float64 for MACECalculator, recommended for geometry optimization.


  torch.load(f=model_path, map_location=device)


Using head omol out of ['omol']


  0%|          | 0/1 [00:00<?, ?it/s]

LIG
[[-34.71500015 -11.65200043 -16.63100052]
 [-22.93400002  -6.34000015 -20.04100037]
 [-24.33099937  -5.34000015 -21.4810009 ]
 [-28.00900078 -10.00199986 -19.11700058]
 [-28.83300018  -7.50500011 -20.46299934]
 [-28.91799927 -10.39299965 -20.13999939]
 [-29.22200012  -6.4289999  -21.36800003]
 [-29.93199921  -8.14000034 -19.75200081]
 [-30.08499908  -3.25900006 -20.08399963]
 [-32.19599915 -10.82199955 -16.96199989]
 [-25.78899956  -9.21700001 -19.07799911]
 [-30.01799965  -4.70499992 -19.49799919]
 [-29.63199997  -3.87400007 -21.44599915]
 [-24.77499962  -8.32800007 -19.43000031]
 [-31.09700012 -10.34200001 -17.69499969]
 [-32.47000122  -9.93500042 -21.11100006]
 [-33.61600113 -10.73499966 -18.96199989]
 [-26.43499947  -6.94700003 -20.57699966]
 [-30.96199989  -9.81799984 -21.4090004 ]
 [-24.01099968  -6.2670002  -20.54500008]
 [-33.42100143 -11.01299953 -17.59300041]
 [-27.12400055  -8.9989996  -19.46500015]
 [-27.48699951  -7.83300018 -20.2159996 ]
 [-29.04599953  -5.06099987 -2