In [1]:
from __future__ import annotations

from pathlib import Path
from typing import Literal
import subprocess

import MDAnalysis as mda
from ase import Atoms
from ase.io import read as ase_read
from openff.units import unit
from tqdm import tqdm

import mace.calculators
# from mace.calculators import mace_omol
from mace.calculators import mace_off

# Multiplicative factor used to turn an energy expressed in eV into kcal/mol
EV_TO_KCALMOL = 23.0605


  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
def export_fixed_pdb(
    pdb_file: Path,
    exclude: set[str] | None = None,
    charge_a: int = 0,
    charge_b: int = 0,
) -> Path:
    """
    Export a cleaned/fixed PDB that drops unwanted residues and adds charge remarks.

    The structure is loaded with MDAnalysis, every residue whose name is in
    ``exclude`` is removed, the remainder is written next to the input as
    ``<stem>_fixed.pdb`` and two ``REMARK`` lines holding the charges are
    inserted at the very top of that file.

    Parameters
    ----------
    pdb_file : Path
        Input PDB file (a plain string path is accepted as well).
    exclude : set[str], optional
        Residue names to exclude (e.g., solvents and ions). ``None`` selects
        the built-in water/ion list; an empty set keeps every atom.
    charge_a, charge_b : int
        Charges for molecule a (receptor) and b (ligand).

    Returns
    -------
    fixed_pdb : Path
        Path to the cleaned PDB file with REMARK charge lines.
    """
    if exclude is None:
        exclude = {"HOH", "WAT", "NA", "K", "CL", "MG", "CA"}

    # Allow callers to hand in a string; the path helpers below need a Path.
    pdb_file = Path(pdb_file)

    u = mda.Universe(pdb_file)

    # "not resname" followed by nothing is not a valid selection, so an empty
    # exclusion set has to be expressed as "keep everything".
    if exclude:
        sel = "not resname " + " ".join(sorted(exclude))
    else:
        sel = "all"
    sel_atoms = u.select_atoms(sel)

    # Output goes next to the input file.
    fixed_pdb = pdb_file.with_name(pdb_file.stem + "_fixed.pdb")
    sel_atoms.write(fixed_pdb)

    # Prepend the charge remarks to whatever the writer produced.
    written = fixed_pdb.read_text()
    header = f"REMARK charge_a {charge_a}\nREMARK charge_b {charge_b}\n"
    fixed_pdb.write_text(header + written)

    print(f"[export_fixed_pdb] Fixed PDB saved to: {fixed_pdb}")
    return fixed_pdb


In [3]:
def compute_pdb_charge(pdb_path, include_n_term=False, include_c_term=True):
    """
    Compute the net charge of a protein from a PDB file.

    Residues are gathered from the ATOM/HETATM records (fixed PDB columns)
    and charged with a simple residue-name lookup; terminal charges are
    optional and at most one N-terminus and one C-terminus are counted for
    the whole file.

    Defaults are chosen to match your manual convention:
      - sidechains only  -> net charge you counted (e.g. +3)
      - plus C-terminal  -> extra -1 for -COO- (e.g. +2)
      - N-terminus is NOT counted unless include_n_term=True

    Parameters
    ----------
    pdb_path : str or path-like
        PDB file to read.
    include_n_term : bool, default False
        Add +1 for a free N-terminus. If a neutral cap residue (e.g. ACE)
        comes before the first residue carrying a backbone ``N``, the
        terminus is treated as capped and contributes nothing.
    include_c_term : bool, default True
        Subtract 1 for a free C-terminus, recognised by an ``OXT`` atom.

    Returns
    -------
    int
        Net charge from sidechains plus the requested terminal contributions.
    """

    # Simple protonation model at ~pH 7
    SIDECHAIN_CHARGES = {
        "LYS": +1,
        "ARG": +1,
        "ASP": -1,
        "GLU": -1,
        "HIS": 0,
        "HID": 0,
        "HIE": 0,
        "HIP": +1,
    }

    # Neutral terminal caps
    NEUTRAL_CAPS = {"ACE", "NME", "BNC", "BCC", "BCB"}

    residues = {}

    # --- Parse PDB ---
    with open(pdb_path, "r") as f:
        for line in f:
            if not line.startswith(("ATOM", "HETATM")):
                continue
            # A truncated record cannot hold the residue-number columns;
            # skip it instead of failing on an out-of-range index.
            if len(line) < 26:
                continue

            atom_name = line[12:16].strip()
            res_name = line[17:20].strip()
            chain_id = line[21:22].strip()
            res_seq = line[22:26].strip()
            i_code = line[26:27].strip()

            key = (chain_id, res_seq, i_code)
            entry = residues.setdefault(key, {"name": res_name, "atoms": set()})
            entry["atoms"].add(atom_name)

    # --- Sort residues for terminal detection ---
    def sort_key(k):
        chain_id, res_seq, i_code = k
        try:
            res_seq_int = int(res_seq)
        except ValueError:
            # Non-numeric residue numbers are ordered as if they were 0.
            res_seq_int = 0
        return (chain_id, res_seq_int, i_code)

    sorted_keys = sorted(residues, key=sort_key)

    # --- Sidechain charges ---
    total = sum(SIDECHAIN_CHARGES.get(residues[key]["name"], 0) for key in sorted_keys)

    # --- N-terminus (+1), optional ---
    if include_n_term:
        for key in sorted_keys:
            if residues[key]["name"] in NEUTRAL_CAPS:
                # The chain starts with a cap, so there is no free amine:
                # stop here rather than charging the residue after the cap.
                break
            if "N" in residues[key]["atoms"]:
                total += 1
                break

    # --- C-terminus (-1), optional but ON by default ---
    if include_c_term:
        for key in reversed(sorted_keys):
            if residues[key]["name"] in NEUTRAL_CAPS:
                continue
            if "OXT" in residues[key]["atoms"]:  # free carboxylate terminus
                total -= 1
                break

    return total


In [4]:
def read_remarks(pdbfile: Path) -> dict[str, int | str]:
    """
    Collect the REMARK records of a PDB file into a dictionary.

    A line is used only if it starts with 'REMARK' and splits into exactly
    three whitespace-separated parts (the record name, a key, and the rest of
    the line as the value). Values whose key contains 'charge' are converted
    to ``int``; all other values stay strings. A later line with the same key
    replaces an earlier one.

    Parameters
    ----------
    pdbfile : Path
        Path to the PDB file.

    Returns
    -------
    remarks : dict[str, int | str]
        Dictionary mapping REMARK keys to their values.
    """
    parsed: dict[str, int | str] = {}
    with open(pdbfile) as handle:
        for record in handle:
            if not record.startswith("REMARK"):
                continue
            tokens = record.strip().split(maxsplit=2)
            if len(tokens) != 3:
                continue
            _, name, raw = tokens
            parsed[name] = int(raw) if "charge" in name else raw
    return parsed


In [5]:
def add_element_symbols(pdbfile: Path) -> Path:
    """
    Add element symbols to a PDB file using Open Babel.

    The ``obabel`` command line tool is run on the input and its result is
    written next to it as ``<stem>_el.pdb``. A non-empty output file from an
    earlier call is reused without running the tool again.

    Parameters
    ----------
    pdbfile : pathlib.Path
        Path to the input PDB file (a plain string path is accepted as well).

    Returns
    -------
    new_pdbfile : pathlib.Path
        Path to the output PDB file with element symbols added.

    Raises
    ------
    subprocess.CalledProcessError
        If ``obabel`` exits with a non-zero status; its captured stderr is
        printed first and is available on the exception.
    RuntimeError
        If ``obabel`` exits cleanly but the output file is missing or empty.
    """
    pdbfile = Path(pdbfile)
    new_pdbfile = pdbfile.with_name(pdbfile.stem + "_el.pdb")

    # A zero-byte file is what a failed earlier conversion leaves behind; it
    # must not be mistaken for a usable cached result.
    if new_pdbfile.exists() and new_pdbfile.stat().st_size > 0:
        print(f"[add_element_symbols] Using existing {new_pdbfile}")
        return new_pdbfile

    try:
        subprocess.run(
            ["obabel", str(pdbfile), "-O", str(new_pdbfile), "--addelement"],
            check=True,
            stdout=subprocess.DEVNULL,  # suppress stdout
            stderr=subprocess.PIPE,  # keep stderr quiet, but available on failure
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # Do not leave a partial file that a later call would treat as cached.
        if new_pdbfile.exists():
            new_pdbfile.unlink()
        if err.stderr:
            print(f"[add_element_symbols] obabel failed:\n{err.stderr}")
        raise

    # Guard against a run that reports success without producing any output.
    if not new_pdbfile.exists() or new_pdbfile.stat().st_size == 0:
        raise RuntimeError(
            f"[add_element_symbols] obabel produced no output for {pdbfile}"
        )

    print(f"[add_element_symbols] Wrote {new_pdbfile}")
    return new_pdbfile


In [6]:
def mda_to_ase(mda_atoms: mda.core.groups.AtomGroup) -> Atoms:
    """Build an ASE Atoms object from an MDAnalysis AtomGroup.

    Symbols are taken from each atom's ``element`` attribute and coordinates
    from the group's ``positions``; nothing else is carried over.
    """
    return Atoms(
        symbols=[member.element for member in mda_atoms],
        positions=mda_atoms.positions,
    )


In [7]:
def get_ase_atoms_from_files(
    receptor_pdb: Path,
    ligand_sdf: Path,
) -> dict[Literal["a", "b"], Atoms]:
    """
    Load a receptor PDB and a ligand SDF as two ASE Atoms objects.

    Returns a dictionary with two entries:

    'a' = receptor (read with MDAnalysis, then converted via ``mda_to_ase``)
    'b' = ligand (read directly with ASE)
    """
    # Every atom of the receptor file is used.
    receptor_atoms = mda_to_ase(mda.Universe(receptor_pdb).atoms)

    # No explicit format is passed to ASE, so it is left to work it out from
    # the file itself.
    # NOTE(review): for an SDF holding several structures only one is
    # returned - the original author expected the first; confirm against ASE.
    ligand_atoms = ase_read(str(ligand_sdf))

    return {"a": receptor_atoms, "b": ligand_atoms}


In [8]:
def get_mlp_energy(
    atoms: Atoms,
    calc: mace.calculators.mace.MACECalculator,
    total_charge: int,
) -> unit.Quantity:
    """
    Evaluate the potential energy of a structure with a MACE calculator.

    The passed ``atoms`` object is modified: the charge is stored under
    ``atoms.info["charge"]`` and ``calc`` is attached as its calculator.

    Parameters
    ----------
    atoms : Atoms
        An ASE Atoms object representing the molecular structure.
    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.
    total_charge : int
        The total charge of the system.

    Returns
    -------
    energy : unit.Quantity
        The computed energy in kilocalories per mole (kcal/mol).
    """
    kcal_per_mol = unit.kilocalorie / unit.mole

    # Only the charge is recorded on the structure; no spin multiplicity is set.
    atoms.info["charge"] = total_charge
    atoms.calc = calc

    # The calculator result is treated as eV and scaled to kcal/mol.
    energy = atoms.get_potential_energy() * EV_TO_KCALMOL * kcal_per_mol
    return energy.to(kcal_per_mol)

In [9]:
def get_mlp_interaction_energy_from_files(
    receptor_pdb: Path,
    ligand_sdf: Path,
    calc: mace.calculators.mace.MACECalculator,
) -> unit.Quantity:
    """
    Receptor-ligand interaction energy from a receptor PDB and a ligand SDF.

    Three energies are evaluated with the same calculator - receptor alone,
    ligand alone, and both sets of atoms combined - and the result is
    E(complex) - (E(receptor) + E(ligand)).

    Parameters
    ----------
    receptor_pdb : pathlib.Path
        Path to the (fixed) receptor PDB file (with REMARK charge_a / charge_b).
    ligand_sdf : pathlib.Path
        Path to the ligand SDF file.
    calc : mace.calculators.mace.MACECalculator
        An instance of a MACECalculator.

    Returns
    -------
    interaction_energy : unit.Quantity
        The computed interaction energy in kilocalories per mole (kcal/mol).
    """
    structures = get_ase_atoms_from_files(receptor_pdb, ligand_sdf)

    # Both charges are stored as REMARK lines in the receptor file.
    remarks = read_remarks(receptor_pdb)
    charge_a = int(remarks["charge_a"])
    charge_b = int(remarks["charge_b"])

    e_receptor = get_mlp_energy(
        atoms=structures["a"], calc=calc, total_charge=charge_a
    )
    e_ligand = get_mlp_energy(
        atoms=structures["b"], calc=calc, total_charge=charge_b
    )

    # The complex is the receptor atoms followed by the ligand atoms, carrying
    # the sum of both charges.
    e_complex = get_mlp_energy(
        atoms=structures["a"] + structures["b"],
        calc=calc,
        total_charge=charge_a + charge_b,
    )

    difference = e_complex - (e_receptor + e_ligand)
    return difference.to(unit.kilocalorie / unit.mole)


In [10]:
if __name__ == "__main__":
    # --- Input locations: receptor and ligands share one directory ---
    STRUCTURES_DIR = Path("3QTU")
    RECEPTOR_PDB = STRUCTURES_DIR / "receptor.pdb"
    LIGANDS_DIR = STRUCTURES_DIR

    # Net receptor charge, estimated from the unmodified receptor file.
    receptor_charge = compute_pdb_charge(str(RECEPTOR_PDB))
    print(f"Receptor net charge: {receptor_charge}")

    # Step 1: make sure the receptor file carries element symbols.
    receptor_with_elements = add_element_symbols(RECEPTOR_PDB)

    # Step 2: drop the listed ion residue names and stamp the charges as
    # REMARK lines. HOH/WAT are not in this set; add them to drop waters too.
    # The ligand charge is fixed at 0 here (adjust if needed).
    receptor_fixed = export_fixed_pdb(
        receptor_with_elements,
        exclude={"NA", "K", "CL", "MG", "CA"},
        charge_a=receptor_charge,
        charge_b=0,
    )

    # Step 3: build the MACE calculator.
    # Alternative kept for reference: mace_omol("extra_large", device="cpu")
    calc = mace_off("medium", device="cpu")

    # Step 4: gather every ligand SDF, in sorted order.
    ligand_files = sorted(LIGANDS_DIR.glob("*.sdf"))
    if not ligand_files:
        raise FileNotFoundError(f"No .sdf ligands found in {LIGANDS_DIR}")

    # Step 5: one interaction energy per ligand, kept in the same order.
    print(f"Found {len(ligand_files)} ligands in {LIGANDS_DIR}\n")
    mlp_energies: list[unit.Quantity] = []

    for lig in tqdm(ligand_files, desc="Computing interaction energies"):
        e_int = get_mlp_interaction_energy_from_files(receptor_fixed, lig, calc)
        mlp_energies.append(e_int)

    # Step 6: report.
    print("\nInteraction energies (kcal/mol):")
    for lig, e in zip(ligand_files, mlp_energies):
        print(f"{lig.name:30s}  {e}")

Receptor net charge: 2
[add_element_symbols] Using existing 3QTU/receptor_el.pdb


  torch.load(f=model_path, map_location=device)


[export_fixed_pdb] Fixed PDB saved to: 3QTU/receptor_el_fixed.pdb
Using MACE-OFF23 MODEL for MACECalculator with /home/campus.ncl.ac.uk/c2033567/.cache/mace/MACE-OFF23_medium.model
Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization.
Using head Default out of ['Default']
Found 1 ligands in 3QTU



Computing interaction energies: 100%|███████| 1/1 [00:19<00:00, 19.13s/it]


Interaction energies (kcal/mol):
ligand.sdf                      -93.13865050300956 kilocalorie / mole



