In [6]:
%cd ~/REVIVAL2
%load_ext autoreload
%autoreload 2
%load_ext blackcellmagic

/disk2/fli/REVIVAL2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
import os
import numpy as np
import tempfile
from Bio.PDB import PDBParser, MMCIFParser
from rdkit import Chem
from rdkit.Chem import AllChem

from REVIVAL.util import get_chain_structure


def get_hydrogen_position(structure_file, chain_id, atom):
    """
    Add hydrogens to one chain with RDKit and return the position of a
    hydrogen bonded to the given atom.

    The chain is written to a temporary PDB file with get_chain_structure,
    read by RDKit, and hydrogens are added with coordinates derived from the
    existing heavy-atom geometry, so the returned position is expressed in the
    coordinate frame of the file RDKit read.

    Args:
        structure_file (str): Path to the structure file.
        chain_id (str): Chain ID to extract for RDKit processing.
        atom: Biopython atom object for which the hydrogen position is needed.

    Returns:
        np.ndarray: Position of the first hydrogen found on the given atom.

    Raises:
        ValueError: If RDKit fails to parse the extracted chain, if the atom
            cannot be located in the RDKit molecule, or if no hydrogen is found.
    """
    # Largest accepted offset (in angstroms) between the Biopython atom and its
    # RDKit counterpart; it only has to absorb coordinate rounding.
    match_tolerance = 0.05

    # Reserve a temporary path; the handle is closed immediately so that
    # get_chain_structure can write to the path on its own.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_pdb:
        temp_pdb_name = temp_pdb.name

    try:
        get_chain_structure(structure_file, chain_id, temp_pdb_name)
        mol = Chem.MolFromPDBFile(temp_pdb_name, sanitize=False)
    finally:
        # The file was created with delete=False, so it has to be removed here.
        if os.path.exists(temp_pdb_name):
            os.remove(temp_pdb_name)

    if mol is None:
        raise ValueError(
            f"RDKit failed to parse chain {chain_id} extracted from {structure_file}"
        )

    # The molecule was read unsanitized, so valence information must be
    # computed before hydrogens can be added.
    mol.UpdatePropertyCache(strict=False)

    # addCoords=True places the new hydrogens relative to the existing heavy
    # atoms. Re-embedding the molecule would discard the experimental or
    # predicted coordinates and produce a position in an unrelated frame.
    mol_with_h = Chem.AddHs(mol, addCoords=True)
    positions = mol_with_h.GetConformer().GetPositions()

    # AddHs appends the new hydrogens after the atoms that were already there,
    # so the first `original_count` rows are the atoms read from the file.
    original_count = mol.GetNumAtoms()
    if original_count == 0:
        raise ValueError(
            f"RDKit found no atoms in chain {chain_id} extracted from {structure_file}"
        )

    # Locate the atom by coordinates rather than by serial number: serial
    # numbers are not guaranteed to equal RDKit indices + 1 once a chain has
    # been extracted. Assumes get_chain_structure keeps coordinates unchanged.
    target = np.asarray(atom.coord, dtype=float)
    offsets = np.linalg.norm(positions[:original_count] - target, axis=1)
    atom_idx = int(np.argmin(offsets))
    if offsets[atom_idx] > match_tolerance:
        raise ValueError(
            f"Atom {atom.get_id()} could not be located in chain {chain_id} "
            f"(closest RDKit atom is {offsets[atom_idx]:.3f} A away)."
        )

    for neighbor in mol_with_h.GetAtomWithIdx(atom_idx).GetNeighbors():
        if neighbor.GetSymbol() == "H":
            return np.array(positions[neighbor.GetIdx()])

    raise ValueError(f"No hydrogen found for atom {atom.get_id()}.")


def measure_bond_distance(
    structure_file,
    chain_id_1,
    res_id_1,
    atom_name_1,
    chain_id_2,
    res_id_2,
    atom_name_2,
    add_hydrogen_to_1=False,
    add_hydrogen_to_2=False,
):
    """
    Measure the distance between two atoms, or hydrogens attached to them, in a PDB or CIF file.

    Args:
        structure_file (str): Path to the PDB or CIF file.
        chain_id_1 (str): Chain ID where the first atom is located.
        res_id_1 (int or tuple): Residue sequence number of the first atom, or a
            (sequence number, insertion code) tuple, or a full Biopython
            (hetero flag, sequence number, insertion code) residue id.
        atom_name_1 (str): Name of the first atom.
        chain_id_2 (str): Chain ID where the second atom is located.
        res_id_2 (int or tuple): Residue identifier of the second atom, same
            forms as res_id_1.
        atom_name_2 (str): Name of the second atom.
        add_hydrogen_to_1 (bool): Measure from a hydrogen attached to atom_1
            instead of atom_1 itself.
        add_hydrogen_to_2 (bool): Measure from a hydrogen attached to atom_2
            instead of atom_2 itself.

    Returns:
        float: Distance between the specified atoms (or hydrogen atoms) in angstroms.

    Raises:
        ValueError: If the file extension is not 'pdb' or 'cif', or a residue
            is not found in its chain.
        KeyError: If an atom (or any tried spelling of its name) is not found.
    """
    file_format = os.path.splitext(structure_file)[1][1:].lower()

    if file_format == "pdb":
        parser = PDBParser(QUIET=True)
    elif file_format == "cif":
        parser = MMCIFParser(QUIET=True)
    else:
        raise ValueError("Unsupported file format. Use 'pdb' or 'cif'.")

    structure = parser.get_structure("protein", structure_file)

    # Only the first model of the structure is used.
    chain_1 = structure[0][chain_id_1]
    chain_2 = structure[0][chain_id_2]
    print(chain_1, chain_2)

    def find_residue_by_id(chain, target_res_id):
        """Return the first residue of `chain` whose id matches `target_res_id`."""
        for residue in chain.get_residues():
            # residue.id is (hetero flag, sequence number, insertion code).
            if isinstance(target_res_id, tuple):
                if len(target_res_id) == 3:
                    is_match = tuple(residue.id) == target_res_id
                else:
                    is_match = tuple(residue.id[1:]) == target_res_id
            else:
                is_match = residue.id[1] == target_res_id
            if is_match:
                return residue
        raise ValueError(f"Residue with ID {target_res_id} not found in chain.")

    residue_1 = find_residue_by_id(chain_1, res_id_1)
    residue_2 = find_residue_by_id(chain_2, res_id_2)
    print(residue_1, residue_2)

    def get_atom_with_variations(residue, atom_name):
        """
        Retrieve an atom from a residue, trying several spellings of the atom name.

        Args:
            residue: A Bio.PDB.Residue object.
            atom_name (str): The atom name to search for.

        Returns:
            Atom object if found.

        Raises:
            KeyError: If no matching atom name is found.
        """
        variations = [
            atom_name,
            atom_name.replace("_", ""),  # e.g. "C_5" -> "C5"
        ]
        if atom_name:
            # e.g. "C5" -> "C_5"; skipped for an empty name, which has no first character.
            variations.append(f"{atom_name[0]}_{atom_name[1:]}")

        for variation in variations:
            try:
                return residue[variation]
            except KeyError:
                continue
        raise KeyError(f"Atom {atom_name} or its variations not found in residue {residue}.")

    atom_1 = get_atom_with_variations(residue_1, atom_name_1)
    atom_2 = get_atom_with_variations(residue_2, atom_name_2)
    print(atom_1, atom_2)

    # Use either the atom itself or a hydrogen attached to it. The module-level
    # get_hydrogen_position needs the file and the chain in addition to the atom.
    if add_hydrogen_to_1:
        coord_1 = get_hydrogen_position(structure_file, chain_id_1, atom_1)
    else:
        coord_1 = atom_1.coord

    if add_hydrogen_to_2:
        coord_2 = get_hydrogen_position(structure_file, chain_id_2, atom_2)
    else:
        coord_2 = atom_2.coord

    # Euclidean distance between the two selected positions.
    distance = np.linalg.norm(coord_1 - coord_2)
    return distance

In [50]:
from REVIVAL.global_param import LIB_INFO_DICT

In [51]:
LIB_INFO_DICT["PfTrpB-4bromo"]["cofactor-distances"]

{'C-C': (('B', 1, 'LIG', 'C_5', False), ('B', 1, 'LIG', 'C_14', False)),
 'GLU-NH_1': (('A', 104, 'GLU', 'OE1', False), ('B', 1, 'LIG', 'N_1', True)),
 'GLU-NH_2': (('A', 104, 'GLU', 'OE2', False), ('B', 1, 'LIG', 'N_1', True))}

In [52]:
# Each entry is (chain ID, residue number, residue name, atom name, add-hydrogen flag).
atom1_info, atom2_info = LIB_INFO_DICT["PfTrpB-4bromo"]["cofactor-distances"]["C-C"]
(
    chain_id_1,
    res_id_1,
    res_name_1,
    atom_name_1,
    atom_h_1,
) = atom1_info
(
    chain_id_2,
    res_id_2,
    res_name_2,
    atom_name_2,
    atom_h_2,
) = atom2_info

In [57]:
# Distance for the "C-C" atom pair unpacked in the previous cell.
measure_bond_distance(
    structure_file="/disk2/fli/af3_inference/outputs/pftrpb-4bromo_joint-all/pftrpb-4bromo_joint-all_model.cif",
    chain_id_1=chain_id_1,
    res_id_1=res_id_1,
    atom_name_1=atom_name_1,
    chain_id_2=chain_id_2,
    res_id_2=res_id_2,
    atom_name_2=atom_name_2,
    add_hydrogen_to_1=atom_h_1,
    add_hydrogen_to_2=atom_h_2,
)

<Chain id=B> <Chain id=B>
<Residue LIG_B het=H_LIG_B resseq=1 icode= > <Residue LIG_B het=H_LIG_B resseq=1 icode= >
<Atom C5> <Atom C14>


3.1300173

In [14]:
import os
import numpy as np
from Bio.PDB import PDBParser, MMCIFParser
from rdkit import Chem
from rdkit.Chem import AllChem
from difflib import get_close_matches


def find_matching_residue(chain, input_res_name, input_atom_name):
    """
    Find the best-matching residue and atom in a chain based on input names.

    Names are compared in four tiers, best first: exact match, match after
    removing underscores, substring match, and fuzzy match via
    difflib.get_close_matches. The residue/atom pair with the best residue
    tier, then the best atom tier, is returned; ties go to the pair that comes
    first in the chain. Preferring better tiers prevents a loose fuzzy hit
    (e.g. "C_14" against "C1", or "GLU" against "GLN") from shadowing the
    intended atom.

    Residues are selected by name only, so if several residues share the
    matching name the first one in the chain wins.

    Args:
        chain: Bio.PDB.Chain object.
        input_res_name (str): Input residue name (e.g., LIG, LIG_B).
        input_atom_name (str): Input atom name (e.g., C_5, C5).

    Returns:
        tuple: Matching residue and atom objects.

    Raises:
        ValueError: If no residue/atom pair matches the input names.
    """

    def _strip_name(name):
        """Drop underscores and surrounding whitespace for a lenient comparison."""
        return name.replace("_", "").strip()

    def _match_tier(query, candidate):
        """Return 0-3 for how well `candidate` matches `query` (lower is better), or None."""
        if candidate == query:
            return 0
        if _strip_name(candidate) == _strip_name(query):
            return 1
        if query in candidate:
            return 2
        if get_close_matches(query, [candidate]):
            return 3
        return None

    best_key = None
    best_hit = None

    for residue in chain.get_residues():
        res_tier = _match_tier(input_res_name, residue.get_resname())
        if res_tier is None:
            continue

        for atom in residue.get_atoms():
            atom_name = atom.get_name()
            atom_tier = _match_tier(input_atom_name, atom_name)
            if atom_tier is None:
                continue

            key = (res_tier, atom_tier)
            # Strict '<' keeps the earliest pair among equally good matches.
            if best_key is None or key < best_key:
                best_key = key
                best_hit = (residue, atom_name)

        # An exact residue and atom match cannot be improved on.
        if best_key == (0, 0):
            break

    if best_hit is None:
        raise ValueError(f"Matching residue/atom not found for {input_res_name}/{input_atom_name}.")

    residue, atom_name = best_hit
    return residue, residue[atom_name]


def measure_bond_distance(
    structure_file,
    chain_id_1,
    input_res_name_1,
    input_atom_name_1,
    chain_id_2,
    input_res_name_2,
    input_atom_name_2,
    add_hydrogen_to_1=False,
    add_hydrogen_to_2=False,
):
    """
    Measure the distance between two atoms, or hydrogens attached to them, in a PDB or CIF file.

    Atoms are located by residue name and atom name through
    find_matching_residue. RDKit is only used when a hydrogen position is
    requested; because the file is then read with RDKit's PDB parser, hydrogen
    placement is only available for '.pdb' input.

    Args:
        structure_file (str): Path to the PDB or CIF file.
        chain_id_1 (str): Chain ID where the first atom is located.
        input_res_name_1 (str): Residue name of the first atom.
        input_atom_name_1 (str): Name of the first atom.
        chain_id_2 (str): Chain ID where the second atom is located.
        input_res_name_2 (str): Residue name of the second atom.
        input_atom_name_2 (str): Name of the second atom.
        add_hydrogen_to_1 (bool): Measure from a hydrogen attached to atom_1
            instead of atom_1 itself.
        add_hydrogen_to_2 (bool): Measure from a hydrogen attached to atom_2
            instead of atom_2 itself.

    Returns:
        float: Distance between the specified atoms (or hydrogen atoms) in angstroms.

    Raises:
        ValueError: If the file extension is not 'pdb' or 'cif', if a
            residue/atom cannot be matched, if a hydrogen is requested for a
            file RDKit cannot read as PDB, or if no hydrogen is found.
    """
    # Largest accepted offset (in angstroms) between a Biopython atom and its
    # RDKit counterpart; it only has to absorb coordinate rounding.
    match_tolerance = 0.05

    # Determine file format
    file_format = os.path.splitext(structure_file)[1][1:].lower()

    if file_format == "pdb":
        parser = PDBParser(QUIET=True)
    elif file_format == "cif":
        parser = MMCIFParser(QUIET=True)
    else:
        raise ValueError("Unsupported file format. Use 'pdb' or 'cif'.")

    structure = parser.get_structure("protein", structure_file)

    # Locate chains (only the first model is used)
    chain_1 = structure[0][chain_id_1]
    chain_2 = structure[0][chain_id_2]

    # Match residues and atoms
    _, atom_1 = find_matching_residue(chain_1, input_res_name_1, input_atom_name_1)
    _, atom_2 = find_matching_residue(chain_2, input_res_name_2, input_atom_name_2)

    mol_with_h = None
    positions = None
    original_count = 0

    # Build the RDKit molecule only when a hydrogen position is needed, so a
    # plain atom-to-atom distance does not depend on RDKit reading the file.
    if add_hydrogen_to_1 or add_hydrogen_to_2:
        if file_format != "pdb":
            raise ValueError(
                "Hydrogen placement reads the file with RDKit's PDB parser "
                "and therefore needs a '.pdb' structure file."
            )

        mol = Chem.MolFromPDBFile(structure_file, sanitize=False)
        if mol is None:
            raise ValueError(f"RDKit failed to parse the structure file: {structure_file}")

        # The molecule was read unsanitized, so valence information must be
        # computed before hydrogens can be added.
        mol.UpdatePropertyCache(strict=False)

        # addCoords=True places hydrogens relative to the existing atoms.
        # Re-embedding would replace the file's coordinates with an unrelated
        # conformer and make the distance to Biopython coordinates meaningless.
        mol_with_h = Chem.AddHs(mol, addCoords=True)
        positions = mol_with_h.GetConformer().GetPositions()

        # AddHs appends new hydrogens after the atoms read from the file.
        original_count = mol.GetNumAtoms()

    def get_hydrogen_position(atom):
        """
        Find the position of a hydrogen atom attached to the given atom.
        """
        if original_count == 0:
            raise ValueError(f"RDKit found no atoms in the structure file: {structure_file}")

        # Match by coordinates: PDB serial numbers are not guaranteed to equal
        # RDKit indices + 1 (e.g. when TER records consume serial numbers).
        target = np.asarray(atom.coord, dtype=float)
        offsets = np.linalg.norm(positions[:original_count] - target, axis=1)
        atom_idx = int(np.argmin(offsets))
        if offsets[atom_idx] > match_tolerance:
            raise ValueError(
                f"Atom {atom.get_id()} could not be located in the RDKit molecule "
                f"(closest atom is {offsets[atom_idx]:.3f} A away)."
            )

        for neighbor in mol_with_h.GetAtomWithIdx(atom_idx).GetNeighbors():
            if neighbor.GetSymbol() == "H":
                return np.array(positions[neighbor.GetIdx()])
        raise ValueError(f"No hydrogen found for atom {atom.get_id()}.")

    # Determine coordinates for atom_1 and atom_2
    if add_hydrogen_to_1:
        coord_1 = get_hydrogen_position(atom_1)
    else:
        coord_1 = atom_1.coord

    if add_hydrogen_to_2:
        coord_2 = get_hydrogen_position(atom_2)
    else:
        coord_2 = atom_2.coord

    # Calculate the distance
    distance = np.linalg.norm(coord_1 - coord_2)
    return distance


In [None]:

# # Example usage
# pdb_file = "example.pdb"
# chain_id = "A"
# res_id_1 = (10, " ")  # Residue 10, no insertion code
# atom_name_1 = "CA"  # Alpha carbon
# res_id_2 = (20, " ")  # Residue 20, no insertion code
# atom_name_2 = "CB"  # Beta carbon

# bond_distance = measure_bond_distance(pdb_file, chain_id, res_id_1, atom_name_1, res_id_2, atom_name_2)
# print(f"Bond distance: {bond_distance:.2f} Å")
