## UFF implementation in RDKit for modeling conformational ensembles

In [None]:
# This cell contains a Python script that uses RDKit to generate a conformational ensemble of a small molecule bound to a protein.
# A colleague wrote this for small molecules.
# I'm considering using this to generate side-chain conformational ensembles for AF2 models.
# Not sure if this is the best way to do this.

from io import StringIO
from pathlib import Path
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdFMCS
from Bio.PDB import PDBParser, PDBIO, Structure, Model, Chain, Residue, Atom
from tqdm import tqdm


def trim_protein(protein, ligand, thresh=8.0, lig_conf_id=0):
    """Return a copy of ``protein`` reduced to the residues close to ``ligand``.

    A residue is kept when at least one of its atoms lies within ``thresh``
    of any ligand atom. Residues are identified by ``(chain id, residue
    number)`` only, so insertion codes are not distinguished. Atoms without
    PDB residue information are always removed.

    Parameters
    ----------
    protein : rdkit.Chem.Mol
        Protein molecule; its default conformer supplies the coordinates.
    ligand : rdkit.Chem.Mol
        Ligand molecule; conformer ``lig_conf_id`` supplies the coordinates.
    thresh : float, default=8.0
        Distance cutoff, in the same unit as the coordinates.
    lig_conf_id : int, default=0
        Id of the ligand conformer to measure distances against.

    Returns
    -------
    rdkit.Chem.Mol
        New molecule containing only the atoms of the retained residues.
    """

    def get_res_id(atom):
        # Atoms without PDB residue info cannot be assigned to a residue.
        info = atom.GetPDBResidueInfo()
        if info is None:
            return None
        return (info.GetChainId(), info.GetResidueNumber())

    # Loop-invariant data, fetched once instead of once per atom pair.
    ligand_xyz = ligand.GetConformer(lig_conf_id).GetPositions()
    protein_xyz = protein.GetConformer().GetPositions()
    res_ids = [get_res_id(atom) for atom in protein.GetAtoms()]

    # Minimum atom-to-ligand distance for every residue.
    residue_distances = {}
    for res_id, xyz in zip(res_ids, protein_xyz):
        if res_id is None:
            continue
        dist_to_ligand = float(np.linalg.norm(ligand_xyz - xyz, axis=1).min())
        closest = residue_distances.get(res_id)
        if closest is None or dist_to_ligand < closest:
            residue_distances[res_id] = dist_to_ligand

    # Delete from the highest index downwards so pending indices stay valid.
    trimmed_pocket = Chem.EditableMol(protein)
    for idx in reversed(range(len(res_ids))):
        res_id = res_ids[idx]
        if res_id is None or residue_distances[res_id] > thresh:
            trimmed_pocket.RemoveAtom(idx)

    return trimmed_pocket.GetMol()


# def assign_coords(mol, ref_mol):
#     AllChem.EmbedMolecule(mol)    
#     match = mol.GetSubstructMatch(ref_mol)
#     conf = mol.GetConformer()
#     ref_conf = ref_mol.GetConformer()

#     # initialize all coordinates in centroid
#     centroid = ref_mol.GetConformer().GetPositions().mean(axis=0)
#     for j in range(mol.GetNumAtoms()):
#         conf.SetAtomPosition(j, centroid)

#     # Transfer coordinates
#     for i, j in enumerate(match):
#         pos = ref_conf.GetAtomPosition(i)
#         conf.SetAtomPosition(j, pos)

#     return mol


def minimize_with_constraints(mol, fix_atoms, protein=None, max_iter=1000, fix_bb=True, fix_side_chains=True):
    """UFF-minimize every conformer of ``mol`` while holding selected atoms fixed.

    Hydrogens are added before minimization and removed again afterwards.
    If ``protein`` is given, it is merged with ``mol`` so that the force field
    covers both (inter-fragment interactions are enabled), and the two parts
    are separated again before returning.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule whose conformers are minimized. When ``protein`` is given,
        ``mol`` must be a single connected molecule.
    fix_atoms : iterable of int
        Indices (in ``mol``) of atoms to keep fixed. Only heavy atoms are
        actually fixed. The argument itself is not modified.
    protein : rdkit.Chem.Mol, optional
        Protein environment with PDB residue info on its heavy atoms.
        NOTE(review): the calling script gives the protein one conformer per
        ligand conformer before calling this; that appears to be required for
        the merged molecule to have coordinates for every conformer. Confirm
        against the RDKit ``CombineMols`` documentation.
    max_iter : int, default=1000
        Maximum number of minimizer iterations per conformer.
    fix_bb : bool, default=True
        Also fix protein heavy atoms named N, CA, C, O or OXT.
    fix_side_chains : bool, default=True
        Also fix all other protein heavy atoms.

    Returns
    -------
    rdkit.Chem.Mol or tuple of rdkit.Chem.Mol
        The minimized ``mol`` (hydrogens removed); each conformer carries the
        properties ``"energy"`` (double) and ``"converged"`` (``"True"`` or
        ``"False"``). If ``protein`` was given, ``(mol, protein)`` is returned.

    Raises
    ------
    ValueError
        If no UFF force field can be built, or if ligand and protein cannot
        be separated unambiguously after minimization.
    """
    backbone_names = {'N', 'CA', 'C', 'O', 'OXT'}

    # Work on a copy: protein indices are appended below and the caller's
    # list must not grow as a side effect.
    fix_atoms = list(fix_atoms)

    mol = Chem.AddHs(mol, addCoords=True)

    if protein is not None:
        protein = Chem.AddHs(protein, addCoords=True)
        # Protein atoms follow the (hydrogen-complete) ligand in the merged molecule.
        num_ligand_atoms = mol.GetNumAtoms()
        mol = Chem.CombineMols(mol, protein)
        Chem.SanitizeMol(mol)
        if fix_bb or fix_side_chains:
            for atom in protein.GetAtoms():
                if atom.GetSymbol() == 'H':
                    continue
                is_backbone = atom.GetPDBResidueInfo().GetName().strip() in backbone_names
                if (fix_bb and is_backbone) or (fix_side_chains and not is_backbone):
                    fix_atoms.append(num_ligand_atoms + atom.GetIdx())

    # The heavy-atom filter does not depend on the conformer, so apply it once.
    fixed_heavy_atoms = []
    for idx in fix_atoms:
        atom = mol.GetAtomWithIdx(idx)
        if atom.GetAtomicNum() > 1:
            fixed_heavy_atoms.append(atom.GetIdx())

    # Use the real conformer ids; they need not be 0..n-1.
    conformer_ids = [conf.GetId() for conf in mol.GetConformers()]

    energies = []
    is_converged = []
    for conf_id in tqdm(conformer_ids):
        ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id, ignoreInterfragInteractions=False)
        if ff is None:
            raise ValueError("Could not set up a UFF force field (missing UFF parameters for some atoms?).")
        for idx in fixed_heavy_atoms:
            ff.AddFixedPoint(idx)
        return_flag = ff.Minimize(maxIts=max_iter)
        energies.append(ff.CalcEnergy())
        # The script treats a return value of 0 from Minimize as "converged".
        is_converged.append(return_flag == 0)

    if protein is not None:
        frags = Chem.GetMolFrags(mol, asMols=True)
        if len(frags) < 2:
            raise ValueError("Expected a ligand fragment and at least one protein fragment after minimization.")
        # The ligand atoms come first in the merged molecule, so the ligand is
        # taken to be the first fragment; everything else is the protein.
        mol = frags[0]
        if mol.GetNumAtoms() != num_ligand_atoms:
            raise ValueError("The ligand must be a single connected molecule to be separated from the protein.")
        protein = frags[1]
        for frag in frags[2:]:
            protein = Chem.CombineMols(protein, frag)
        protein = Chem.RemoveHs(protein)
    mol = Chem.RemoveHs(mol)

    # Assign energies as properties
    for c, energy, conv in zip(mol.GetConformers(), energies, is_converged):
        c.SetDoubleProp("energy", energy)
        c.SetProp("converged", str(conv))

    return mol if protein is None else (mol, protein)


def argsort_conformers(mol, by, ascending=True):
    """Return the conformer ids of ``mol`` ordered by a numeric conformer property.

    Parameters
    ----------
    mol : object
        Molecule whose ``GetConformers()`` yields the conformers to rank.
    by : str
        Name of the conformer property to sort on; its value is converted
        with ``float``.
    ascending : bool, default=True
        Sort from smallest to largest value if true, otherwise the reverse.
        Conformers with equal values keep their original relative order.

    Returns
    -------
    list
        Conformer ids in sorted order.
    """

    def sort_value(conformer):
        return float(conformer.GetProp(by))

    ranked = sorted(mol.GetConformers(), key=sort_value, reverse=not ascending)
    return [conformer.GetId() for conformer in ranked]


def permute_conformers(mol, order):
    """Return a copy of ``mol`` whose conformers are arranged as given by ``order``.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Source molecule; it is left untouched.
    order : sequence of int
        Conformer ids of ``mol`` in the desired output order. Must have one
        entry per conformer of ``mol``.

    Returns
    -------
    rdkit.Chem.Mol
        Copy of ``mol`` holding the listed conformers in that order; ids are
        reassigned on insertion (``assignId=True``).
    """
    assert len(order) == mol.GetNumConformers()

    # Start from an empty-conformer copy and refill it in the requested order.
    reordered = Chem.Mol(mol)
    reordered.RemoveAllConformers()
    for source_id in order:
        reordered.AddConformer(mol.GetConformer(source_id), assignId=True)

    return reordered


def unique_conformers(mol, rmsd_thresh=1.0):
    """Select conformers of ``mol`` that differ from each other by at least ``rmsd_thresh``.

    Conformers are visited in their stored order. A conformer is selected
    unless its RMSD to an already selected conformer is below
    ``rmsd_thresh``. ``mol`` itself is not modified; nothing is removed.

    The RMSD is computed by ``AllChem.GetConformerRMS`` with
    ``prealigned=True``, so the conformers are expected to share a common
    frame already.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule with one or more conformers.
    rmsd_thresh : float, default=1.0
        Conformers closer than this to a selected conformer are skipped.

    Returns
    -------
    list of int
        Ids of the selected conformers, in visiting order.
    """
    unique = []
    # Iterate over real conformer ids: GetConformerRMS expects ids, which
    # are not guaranteed to be 0..n-1.
    for conf_id in [conf.GetId() for conf in mol.GetConformers()]:
        is_duplicate = any(
            AllChem.GetConformerRMS(mol, conf_id, kept_id, prealigned=True) < rmsd_thresh
            for kept_id in unique
        )
        if not is_duplicate:
            unique.append(conf_id)

    return unique


# def deduplicate_conformers(mol, rmsd_thresh=1.0, energies=None):
#     """Remove conformers from `mol` with RMSD < `rmsd_thresh` to any previous conformer."""
#     unique = unique_conformers(mol, rmsd_thresh)

#     # Remove non-unique conformers (must be done in reverse order!)
#     all_ids = list(range(mol.GetNumConformers()))
#     to_delete = sorted(set(all_ids) - set(unique), reverse=True)
#     for cid in to_delete:
#         mol.RemoveConformer(cid)
#         if energies is not None:
#             energies.pop(cid)

#     # reset ids
#     for i in range(len(unique)):
#         mol.GetConformer(unique[i]).SetId(i)

#     return mol


class AtomNameResolver:
    """Hand out numbered atom names such as ``C1``, ``C2``, ``N1``.

    Each call with a given atom type returns that type followed by a running
    counter that starts at 1 and is kept separately per type.
    """

    def __init__(self):
        # Maps atom type -> last number handed out for that type.
        self.current_idx = {}

    def __call__(self, atom_type):
        """Return the next unused name for ``atom_type``."""
        next_number = self.current_idx.get(atom_type, 0) + 1
        self.current_idx[atom_type] = next_number
        return f'{atom_type}{next_number}'


# def merge_pdb_sdf(pdb_file, sdf_file, output_pdb_file):
#     # --- Load PDB structure ---
#     parser = PDBParser(QUIET=True)
#     structure = parser.get_structure("protein", pdb_file)
    
#     # --- Load SDF ligand using RDKit ---
#     ligand = Chem.SDMolSupplier(sdf_file, removeHs=False)[0]
#     conf = ligand.GetConformer()
    
#     # --- Convert RDKit ligand to Biopython Residue ---
#     ligand_chain = chr(ord(max(chain.id for chain in structure.get_chains())) + 1)
#     model = structure[0]
#     chain = Chain.Chain(ligand_chain)  # Use new chain for ligand
#     residue = Residue.Residue(("H_UNL", 1, " "), "UNL", "")

#     get_atom_name = AtomNameResolver()
#     for atom in ligand.GetAtoms():
#         pos = conf.GetAtomPosition(atom.GetIdx())
#         atom_name = get_atom_name(atom.GetSymbol())
#         bp_atom = Atom.Atom(
#             atom_name, [pos.x, pos.y, pos.z], 0.0, 1.0, ' ', atom_name, atom.GetIdx() + 1, atom.GetSymbol()
#         )
#         residue.add(bp_atom)
    
#     chain.add(residue)
#     model.add(chain)
    
#     # --- Write merged structure to PDB ---
#     io = PDBIO()
#     io.set_structure(structure)
#     io.save(str(output_pdb_file))

def merge_protein_model_and_ligand(pdb_model, ligand_rdmol):
    """Add the ligand to ``pdb_model`` as a single ``UNL`` residue in a new chain.

    The new chain id is the character that follows the largest existing
    chain id of ``pdb_model``. Atom names are the element symbol plus a
    per-element counter. ``pdb_model`` is modified in place.

    Parameters
    ----------
    pdb_model : Bio.PDB.Model.Model
        Model to extend; must already contain at least one chain.
    ligand_rdmol : rdkit.Chem.Mol
        Ligand; coordinates are read from its default conformer.

    Returns
    -------
    Bio.PDB.Model.Model
        The same ``pdb_model`` object, now including the ligand chain.
    """
    highest_chain_id = max(chain.id for chain in pdb_model.get_chains())
    ligand_chain = Chain.Chain(chr(ord(highest_chain_id) + 1))
    ligand_residue = Residue.Residue(("H_UNL", 1, " "), "UNL", "")
    coords = ligand_rdmol.GetConformer()
    next_atom_name = AtomNameResolver()

    for rd_atom in ligand_rdmol.GetAtoms():
        atom_idx = rd_atom.GetIdx()
        element = rd_atom.GetSymbol()
        xyz = coords.GetAtomPosition(atom_idx)
        label = next_atom_name(element)
        # The label is passed twice (presumably name and full name) and the
        # last two arguments are a 1-based serial number and the upper-cased
        # element symbol; verify against the Bio.PDB Atom signature.
        ligand_residue.add(
            Atom.Atom(label, [xyz.x, xyz.y, xyz.z], 0.0, 1.0, ' ', label, atom_idx + 1, element.upper())
        )

    ligand_chain.add(ligand_residue)
    pdb_model.add(ligand_chain)

    return pdb_model


def combine_pdb_models(list_of_models):
    """Collect Biopython models into one multi-model structure named ``"combined"``.

    Models are renumbered in list order: ``id`` becomes 0, 1, 2, ... and
    ``serial_num`` becomes 1, 2, 3, ... The models are modified in place and
    re-parented to the new structure.

    Parameters
    ----------
    list_of_models : iterable of Bio.PDB.Model.Model
        Models to combine, e.g. the first model of several parsed structures.

    Returns
    -------
    Bio.PDB.Structure.Structure
        Structure containing all given models.
    """
    structure = Structure.Structure("combined")
    for model in list_of_models:
        new_id = len(structure)
        model.id = new_id
        # Models parsed from separate single-model files all carry the same
        # serial number. Biopython's writers emit serial_num as the model
        # number, so it has to be made unique (1-based) here as well.
        model.serial_num = new_id + 1
        structure.add(model)
    return structure


if __name__ == '__main__':
    # Command-line workflow:
    #   1. embed conformers of the new molecule,
    #   2. superimpose them on the reference ligand and copy the reference
    #      coordinates onto the matched atoms,
    #   3. relax them with the matched atoms fixed (first alone, then inside
    #      the protein pocket),
    #   4. sort by energy, drop near-duplicates and write SDF/PDB output.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--original_mol_sdf', type=Path, required=True)
    parser.add_argument('--new_mol_smiles', type=str, required=True)
    parser.add_argument('--allow_partial_match', action='store_true')
    parser.add_argument('--protein_pdb', type=Path, required=True)
    parser.add_argument('--outdir', type=Path, default=None, help='Provide this to write separate PDB and SDF files for each model.')
    parser.add_argument('--out_sdf', type=Path, default=None)
    parser.add_argument('--out_pdb', type=Path, default=None)
    parser.add_argument('--num_confs', type=int, default=50)
    parser.add_argument('--rmsd_thresh', type=float, default=1.0)
    parser.add_argument('--pocket_cutout', type=float, default=8.0)
    parser.add_argument('--max_iter', type=int, default=1000)
    parser.add_argument('--relax_side_chains', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Relaxed side chains are only visible in PDB output; warn if none is requested.
    if args.relax_side_chains and args.outdir is None and args.out_pdb is None:
        print("[WARNING] Cannot write output PDB structures if outdir (or out_pdb) is not provided.")

    # Load inputs (RDKit readers return None instead of raising on parse failures)
    ref_mol = Chem.SDMolSupplier(str(args.original_mol_sdf))[0]
    if ref_mol is None:
        raise ValueError(f"Could not read a molecule from {args.original_mol_sdf}")
    mol_to_model = Chem.MolFromSmiles(args.new_mol_smiles)
    if mol_to_model is None:
        raise ValueError(f"Could not parse SMILES: {args.new_mol_smiles}")
    protein = Chem.MolFromPDBFile(str(args.protein_pdb), sanitize=False)
    if protein is None:
        raise ValueError(f"Could not read a protein structure from {args.protein_pdb}")


    # Generate random conformers as starting points
    rdmol = Chem.AddHs(mol_to_model, addCoords=True)
    AllChem.EmbedMultipleConfs(rdmol, numConfs=args.num_confs, randomSeed=args.seed)
    rdmol = Chem.RemoveHs(rdmol)
    if rdmol.GetNumConformers() == 0:
        raise RuntimeError("Conformer embedding failed: no conformers were generated.")
    # Report what was actually generated; embedding can yield fewer than requested.
    print(f"Created {rdmol.GetNumConformers()} conformers.")


    # Align each with the reference and assign original coordinates
    if args.allow_partial_match:
        mcs_res = rdFMCS.FindMCS([ref_mol, rdmol], ringMatchesRingOnly=True, completeRingsOnly=True)
        if mcs_res.canceled:
            raise RuntimeError("Maximum common substructure search was canceled.")
        mcs = Chem.MolFromSmarts(mcs_res.smartsString)
        match_mol_inds = rdmol.GetSubstructMatch(mcs)
        match_ref_inds = ref_mol.GetSubstructMatch(mcs)
    else:
        match_mol_inds = rdmol.GetSubstructMatch(ref_mol)
        match_ref_inds = list(range(len(match_mol_inds)))

    if not match_mol_inds:
        raise ValueError(
            "The new molecule has no substructure in common with the reference ligand; "
            "consider --allow_partial_match."
        )

    ref_conf = ref_mol.GetConformer()
    # AlignMol expects (probe atom, reference atom) pairs.
    atom_map = [(mol_idx, ref_idx) for ref_idx, mol_idx in zip(match_ref_inds, match_mol_inds)]
    for cid in [conf.GetId() for conf in rdmol.GetConformers()]:
        AllChem.AlignMol(rdmol, ref_mol, prbCid=cid, atomMap=atom_map)

        # Overwrite the matched atoms with the exact reference coordinates.
        conf = rdmol.GetConformer(cid)
        for ref_idx, mol_idx in zip(match_ref_inds, match_mol_inds):
            conf.SetAtomPosition(mol_idx, ref_conf.GetAtomPosition(ref_idx))
    print("Assigned fixed coordinates.")


    # Relax starting point while keeping known substructure fixed
    rdmol = minimize_with_constraints(rdmol, fix_atoms=list(match_mol_inds))
    print("Starting positions are ready.")


    # Remove unnecessary parts from protein
    pocket = Chem.RemoveHs(protein)
    if not ((args.pocket_cutout is None) or np.isinf(args.pocket_cutout)):
        pocket = trim_protein(pocket, ref_mol, thresh=args.pocket_cutout)
        print("Protein successfully trimmed.")


    # Repeat conformers for the protein (one pocket conformer per ligand conformer)
    for _ in range(rdmol.GetNumConformers() - 1):
        pocket.AddConformer(Chem.Conformer(pocket.GetConformer(0)), assignId=True)  # create copy


    # Run energy minimization
    rdmol, pocket = minimize_with_constraints(
        rdmol,
        fix_atoms=list(match_mol_inds),
        protein=pocket,
        max_iter=args.max_iter,
        fix_side_chains=not args.relax_side_chains,
    )
    print("Relaxation done.")


    # Sort conformers
    conformer_order = argsort_conformers(rdmol, by='energy')
    rdmol = permute_conformers(rdmol, order=conformer_order)
    if args.relax_side_chains:
        # Pocket conformers only differ from each other when side chains were relaxed.
        pocket = permute_conformers(pocket, order=conformer_order)


    # Remove close duplicates
    unique_inds = unique_conformers(rdmol, rmsd_thresh=args.rmsd_thresh)
    print(f"{len(unique_inds)} conformer(s) left after deduplication.\n")


    # Write the output files
    list_of_rdmols = []
    for cid in unique_inds:
        new_mol = Chem.Mol(rdmol)
        new_mol.RemoveAllConformers()
        # copy conformer properties to mol for saving
        conformer = rdmol.GetConformer(cid)
        for prop, val in conformer.GetPropsAsDict().items():
            new_mol.SetProp(prop, str(val))
        new_mol.AddConformer(conformer)
        list_of_rdmols.append(new_mol)

    # A separate name keeps the argparse `parser` above intact.
    pdb_parser = PDBParser(QUIET=True)
    list_of_pdb_models = []
    for cid, ligand_mol in zip(unique_inds, list_of_rdmols):
        pdb_stream = StringIO(Chem.MolToPDBBlock(pocket, confId=cid))
        pdb_model = pdb_parser.get_structure("protein", pdb_stream)[0]

        combined_model = merge_protein_model_and_ligand(pdb_model, ligand_mol)
        list_of_pdb_models.append(combined_model)

    # Save a single SDF file
    if args.out_sdf is not None:
        with Chem.SDWriter(str(args.out_sdf)) as w:
            for mol in list_of_rdmols:
                w.write(mol)

    # Save a single PDB file
    if args.out_pdb is not None:
        combined_structure = combine_pdb_models(list_of_pdb_models)
        io = PDBIO()
        io.set_structure(combined_structure)
        io.save(str(args.out_pdb))

    # Save separate SDF files as well as combined PDB files (format required for MaSIF)
    if args.outdir is not None:

        assert len(list_of_rdmols) == len(list_of_pdb_models)

        args.outdir.mkdir(parents=True, exist_ok=True)
        for i, (mol, pdb_model) in enumerate(zip(list_of_rdmols, list_of_pdb_models)):

            file_stem = Path(args.outdir, f'model_{i}')

            # Save SDF
            out_sdf = file_stem.with_suffix('.sdf')
            with Chem.SDWriter(str(out_sdf)) as w:
                w.write(mol)

            # Save PDB
            out_pdb = file_stem.with_suffix('.pdb')
            io = PDBIO()
            io.set_structure(pdb_model)
            io.save(str(out_pdb))

ModuleNotFoundError: No module named 'rdkit'

In [None]:
def get_protein_ensemble(
    struct,
    resi_start,
    resi_end,
    out_cif,
    num_confs=50,
    rmsd_thresh=1.0,
    max_iter=1000,
    seed=42
):
    """
    Generate conformational ensemble for a flexible protein region.

    Conformers are embedded with the static region tied to its input
    coordinates, superimposed on the input structure, UFF-minimized with the
    static heavy atoms fixed, sorted by energy, deduplicated by RMSD and
    written as a multi-model mmCIF file.

    Parameters:
    -----------
    struct : str or Path
        Path to PDB file containing the protein structure
    resi_start : int
        First residue number (inclusive) of the flexible region
    resi_end : int
        Last residue number (inclusive) of the flexible region.
        Residue numbers are matched in every chain; atoms outside the range
        or without residue information are treated as static.
    out_cif : str or Path
        Output path for mmCIF file containing all conformers
    num_confs : int, default=50
        Number of conformers to request from the embedding step
    rmsd_thresh : float, default=1.0
        RMSD threshold for deduplication
    max_iter : int, default=1000
        Maximum iterations for UFF minimization
    seed : int, default=42
        Random seed for conformer generation

    Returns:
    --------
    Path
        ``out_cif`` as a Path object

    Raises:
    -------
    ValueError
        If the structure file cannot be parsed
    RuntimeError
        If conformer embedding yields no conformers
    """
    from Bio.PDB import MMCIFIO

    # Load protein structure
    protein = Chem.MolFromPDBFile(str(struct), sanitize=False)
    if protein is None:
        raise ValueError(f"Could not read a protein structure from {struct}")
    # Strip hydrogens (this also sanitizes) BEFORE collecting atom indices:
    # the hydrogens are removed again after embedding, and indices taken on a
    # molecule that still contains them would then point at the wrong atoms.
    protein = Chem.RemoveHs(protein)
    print(f"Loaded protein with {protein.GetNumAtoms()} atoms")

    # Identify flexible and static region atoms
    flexible_atoms = []
    static_atoms = []
    for atom in protein.GetAtoms():
        info = atom.GetPDBResidueInfo()
        if info is not None and resi_start <= info.GetResidueNumber() <= resi_end:
            flexible_atoms.append(atom.GetIdx())
        else:
            # Outside the range, or no residue info available: treat as static
            static_atoms.append(atom.GetIdx())

    print(f"Flexible region: {len(flexible_atoms)} atoms (residues {resi_start}-{resi_end})")
    print(f"Static region: {len(static_atoms)} atoms")

    # Generate multiple conformers, asking the embedding to reproduce the
    # input coordinates of the static atoms.
    reference_conf = protein.GetConformer()
    coord_map = {idx: reference_conf.GetAtomPosition(idx) for idx in static_atoms}
    protein_with_hs = Chem.AddHs(protein, addCoords=True)
    AllChem.EmbedMultipleConfs(
        protein_with_hs, numConfs=num_confs, randomSeed=seed, coordMap=coord_map
    )
    protein_multi_conf = Chem.RemoveHs(protein_with_hs)
    num_generated = protein_multi_conf.GetNumConformers()
    if num_generated == 0:
        raise RuntimeError("Conformer embedding failed: no conformers were generated.")
    print(f"Generated {num_generated} initial conformers")

    # Put every conformer into the frame of the input structure and restore
    # the exact input coordinates of the static atoms. Without this step the
    # minimization below would hold the static region at embedded positions
    # rather than at the input positions.
    if static_atoms:
        atom_map = [(idx, idx) for idx in static_atoms]
        for conf_id in [conf.GetId() for conf in protein_multi_conf.GetConformers()]:
            AllChem.AlignMol(protein_multi_conf, protein, prbCid=conf_id, atomMap=atom_map)
            conf = protein_multi_conf.GetConformer(conf_id)
            for idx in static_atoms:
                conf.SetAtomPosition(idx, reference_conf.GetAtomPosition(idx))

    # Minimize with constraints (keep static region fixed)
    # Note: protein=None makes minimize_with_constraints treat the input as a
    # single molecule, so fix_bb / fix_side_chains are not used.
    protein_minimized = minimize_with_constraints(
        protein_multi_conf,
        fix_atoms=static_atoms.copy(),
        protein=None,
        max_iter=max_iter,
        fix_bb=False,  # We're manually specifying atoms to fix
        fix_side_chains=False
    )
    print("Energy minimization completed")

    # Sort conformers by energy
    conformer_order = argsort_conformers(protein_minimized, by='energy')
    protein_sorted = permute_conformers(protein_minimized, order=conformer_order)

    # Remove duplicate conformers
    unique_inds = unique_conformers(protein_sorted, rmsd_thresh=rmsd_thresh)
    print(f"{len(unique_inds)} unique conformer(s) after deduplication")

    # Convert each conformer to a Biopython model
    pdb_parser = PDBParser(QUIET=True)
    list_of_pdb_models = []
    for i, cid in enumerate(unique_inds):
        pdb_stream = StringIO(Chem.MolToPDBBlock(protein_sorted, confId=cid))
        pdb_model = pdb_parser.get_structure("protein", pdb_stream)[0]
        pdb_model.id = i  # Set model ID
        # The mmCIF writer takes the model number from serial_num; give each
        # model its own 1-based number so they do not all share one.
        pdb_model.serial_num = i + 1
        list_of_pdb_models.append(pdb_model)

    # Combine all models into a single structure
    combined_structure = combine_pdb_models(list_of_pdb_models)

    # Write to mmCIF format
    io = MMCIFIO()
    io.set_structure(combined_structure)
    io.save(str(out_cif))
    print(f"Saved {len(list_of_pdb_models)} conformers to {out_cif}")

    return Path(out_cif)


In [None]:
# Bare name, no call: evaluating it lets the notebook echo the function object.
get_protein_ensemble