In [1]:
import os
from tqdm import tqdm
import shutil
import sys
import copy
import warnings
import numpy as np
import pandas as pd
from scipy import spatial as spa
from scipy.optimize import minimize, Bounds
from rdkit import Chem
import Bio
import scipy.spatial as spa
from Bio.PDB import PDBParser
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs
from rdkit.Geometry import Point3D
from scipy import spatial
from scipy.special import softmax

In [7]:
# Module-level Biopython PDB parser; parse_pdb_from_path reuses this single instance for every file.
biopython_parser = PDBParser()


def safe_index(l, e):
    """Return the index of element ``e`` in list ``l``.

    If ``e`` is not present, return the last index (``len(l) - 1``) instead,
    so the final slot of ``l`` acts as a catch-all bucket.

    Args:
        l: sequence providing ``index`` and ``len``.
        e: element to look up.

    Returns:
        Position of the first occurrence of ``e``, or ``len(l) - 1`` when
        ``e`` is absent (which is ``-1`` for an empty sequence).
    """
    try:
        return l.index(e)
    except ValueError:
        # ``index`` reports a missing element with ValueError; any other error
        # (e.g. a broken ``__eq__``) is a real problem and must not be hidden.
        return len(l) - 1


def parse_receptor(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):
    """Load the receptor of complex ``pdbid``.

    Thin alias: every argument is forwarded unchanged to ``parsePDB`` and its
    result is returned as-is.
    """
    return parsePDB(pdbid, pdbbind_dir, use_full_size_file, use_original_protein_file)


def parsePDB(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):
    """Pick the receptor PDB file for ``pdbid`` and parse it.

    Three candidate files live in ``<pdbbind_dir>/<pdbid>/``; the choice is:

    * ``use_original_protein_file`` set -> ``<pdbid>_protein.pdb``, always.
    * otherwise, if ``use_full_size_file`` is set or the processed file is
      missing -> ``<pdbid>_protein_obabel_reduce.pdb`` when it exists, else
      ``<pdbid>_protein.pdb``.
    * otherwise -> ``<pdbid>_protein_processed.pdb``.

    Returns:
        Whatever ``parse_pdb_from_path`` returns for the chosen path.
    """
    processed_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_processed.pdb')
    reduced_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_obabel_reduce.pdb')
    original_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein.pdb')

    if use_original_protein_file:
        chosen_path = original_path
    elif use_full_size_file or not os.path.exists(processed_path):
        chosen_path = reduced_path if os.path.exists(reduced_path) else original_path
    else:
        chosen_path = processed_path

    return parse_pdb_from_path(chosen_path)

def parse_pdb_from_path(path):
    """Parse the PDB file at ``path`` and return the first model of the structure.

    ``PDBConstructionWarning`` is silenced while parsing; the warning filters
    are restored when the function returns.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=PDBConstructionWarning)
        # The structure id is a placeholder; only the first model is kept.
        first_model = biopython_parser.get_structure('random_id', path)[0]
    return first_model



def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """Load a single molecule from a ``.mol2``, ``.sdf``, ``.pdbqt`` or ``.pdb`` file.

    The file is always parsed without sanitization and with hydrogens kept;
    the optional post-processing steps below are applied afterwards.

    Args:
        molecule_file: path to the file; the format is chosen from its extension.
        sanitize: run ``Chem.SanitizeMol`` on the parsed molecule.
        calc_charges: sanitize, then try to compute Gasteiger charges
            (a failure there only emits a warning).
        remove_hs: strip hydrogens with ``Chem.RemoveHs``.

    Returns:
        The molecule, or ``None`` if parsing yielded nothing or any
        post-processing step failed.

    Raises:
        ValueError: if the file extension is not one of the supported formats.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        # Only the first record of the SDF file is used.
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            # Keep only the first 66 characters of every line before handing the
            # text to the PDB parser (presumably drops PDBQT-only trailing
            # columns - TODO confirm).
            pdb_block = ''.join('{}\n'.format(line[:66]) for line in file)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # Previously the exception object was *returned*, so callers received a
        # ValueError instance where they expected a molecule.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    if mol is None:
        # Nothing was parsed; every later step would fail on None anyway.
        return None

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)

        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')

        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception:
        # Best effort: a molecule that cannot be post-processed is reported as
        # missing so callers can fall back to another file.
        return None

    return mol


def read_sdf_or_mol2(sdf_fileName):
    """Read a ligand from ``sdf_fileName``, falling back to its ``mol2`` sibling.

    The fallback is only taken when ``read_molecule`` raises; the alternative
    path is built by replacing the last three characters of ``sdf_fileName``
    with ``"mol2"``.
    """
    try:
        return read_molecule(sdf_fileName)
    except Exception:
        return read_molecule(sdf_fileName[:-3] + "mol2")

def read_mols(pdbbind_dir, name, remove_hs=False):
    """Read every ligand stored as ``.sdf`` in ``<pdbbind_dir>/<name>``.

    Files whose name contains ``'rdkit'`` are skipped. Each ligand is read with
    sanitization enabled; when the ``.sdf`` cannot be read and a ``.mol2`` file
    with the same stem exists, that file is tried instead. Ligands that still
    cannot be read are left out.

    Returns:
        List of successfully read molecules, in directory-listing order.
    """
    complex_dir = os.path.join(pdbbind_dir, name)
    ligands = []
    for entry in os.listdir(complex_dir):
        if not entry.endswith(".sdf") or 'rdkit' in entry:
            continue

        mol = read_molecule(os.path.join(complex_dir, entry), remove_hs=remove_hs, sanitize=True)
        if mol is None:
            # read mol2 file if sdf file cannot be sanitized
            mol2_path = os.path.join(complex_dir, entry[:-4] + ".mol2")
            if os.path.exists(mol2_path):
                mol = read_molecule(mol2_path, remove_hs=remove_hs, sanitize=True)

        if mol is not None:
            ligands.append(mol)
    return ligands

def extract_receptor_structure(rec, lig, cutoff=10000, lm_embedding_chains=None):
    """Collect backbone and all-atom coordinates of the receptor chains near a ligand.

    ``rec`` is pruned in place: residues named ``HOH`` and residues lacking any
    of the ``N``, ``CA`` or ``C`` atoms are detached from their chain, and
    chains that are not selected are detached from ``rec``.

    A chain is selected when the smallest distance between any atom of its kept
    residues and any ligand atom is below ``cutoff``. When no chain qualifies,
    the single closest chain is selected instead.

    Args:
        rec: iterable of chains supporting ``detach_child`` (the first model of
            a Biopython structure elsewhere in this file).
        lig: ligand exposing ``GetConformer().GetPositions()``.
        cutoff: distance threshold, in the same units as the coordinates.
        lm_embedding_chains: optional sequence holding one embedding array per
            chain, indexed by the chain's position in ``rec``.

    Returns:
        ``(rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings)``
        where ``coords`` is a list with one ``[n_atoms, 3]`` array per kept
        residue, the three backbone arrays are ``[n_residues, 3]``, and
        ``lm_embeddings`` is the concatenation of the selected chains'
        embeddings (``None`` when ``lm_embedding_chains`` is ``None``).

    Raises:
        ValueError: if ``rec`` contains no chains, or a selected chain has no
            entry in ``lm_embedding_chains``.
    """
    lig_coords = lig.GetConformer().GetPositions()

    chain_ids = []
    min_distances = []
    coords = []
    c_alpha_coords = []
    n_coords = []
    c_coords = []
    lengths = []
    valid_chain_ids = []

    for chain in rec:
        chain_coords = []  # num_residues, num_atoms, 3
        chain_c_alpha_coords = []
        chain_n_coords = []
        chain_c_coords = []
        invalid_res_ids = []
        for residue in chain:
            if residue.get_resname() == 'HOH':
                invalid_res_ids.append(residue.get_id())
                continue
            residue_coords = []
            c_alpha, n, c = None, None, None
            for atom in residue:
                position = list(atom.get_vector())
                if atom.name == 'CA':
                    c_alpha = position
                if atom.name == 'N':
                    n = position
                if atom.name == 'C':
                    c = position
                residue_coords.append(position)

            if c_alpha is not None and n is not None and c is not None:
                # only append residue if it is an amino acid and not some weird molecule that is part of the complex
                chain_c_alpha_coords.append(c_alpha)
                chain_n_coords.append(n)
                chain_c_coords.append(c)
                chain_coords.append(np.array(residue_coords))
            else:
                invalid_res_ids.append(residue.get_id())

        # Detach only after the residue loop so the chain is not modified while iterating.
        for res_id in invalid_res_ids:
            chain.detach_child(res_id)

        if chain_coords:
            all_chain_coords = np.concatenate(chain_coords, axis=0)
            min_distance = spatial.distance.cdist(lig_coords, all_chain_coords).min()
        else:
            # A chain without any usable residue can never be within the cutoff.
            min_distance = np.inf

        chain_ids.append(chain.get_id())
        min_distances.append(min_distance)
        lengths.append(len(chain_c_alpha_coords))
        coords.append(chain_coords)
        c_alpha_coords.append(np.array(chain_c_alpha_coords))
        n_coords.append(np.array(chain_n_coords))
        c_coords.append(np.array(chain_c_coords))
        # this removes chains if they are not close enough to the ligand
        if min_distance < cutoff:
            valid_chain_ids.append(chain.get_id())

    if not chain_ids:
        raise ValueError('Receptor contains no chains')
    if not valid_chain_ids:
        # Fall back to the closest chain. argmin yields a *position*, which has
        # to be mapped back to the chain id used for the membership test below.
        closest = int(np.argmin(np.array(min_distances)))
        valid_chain_ids.append(chain_ids[closest])

    valid_coords = []
    valid_c_alpha_coords = []
    valid_n_coords = []
    valid_c_coords = []
    valid_lengths = []
    invalid_chain_ids = []
    valid_lm_embeddings = []
    for i, chain in enumerate(rec):
        if chain.get_id() in valid_chain_ids:
            valid_coords.append(coords[i])
            valid_c_alpha_coords.append(c_alpha_coords[i])
            if lm_embedding_chains is not None:
                if i >= len(lm_embedding_chains):
                    raise ValueError('Encountered valid chain id that was not present in the LM embeddings')
                valid_lm_embeddings.append(lm_embedding_chains[i])
            valid_n_coords.append(n_coords[i])
            valid_c_coords.append(c_coords[i])
            valid_lengths.append(lengths[i])
        else:
            invalid_chain_ids.append(chain.get_id())
    coords = [item for sublist in valid_coords for item in sublist]  # list with n_residues arrays: [n_atoms, 3]

    c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0)  # [n_residues, 3]
    n_coords = np.concatenate(valid_n_coords, axis=0)  # [n_residues, 3]
    c_coords = np.concatenate(valid_c_coords, axis=0)  # [n_residues, 3]
    lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None

    # Detach after the chain loop for the same reason as with residues above.
    for invalid_id in invalid_chain_ids:
        rec.detach_child(invalid_id)

    assert len(c_alpha_coords) == len(n_coords)
    assert len(c_alpha_coords) == len(c_coords)
    assert sum(valid_lengths) == len(c_alpha_coords)
    return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings


In [8]:
def align_prediction(smoothing_factor, pdbbind_calpha_coords, omegafold_calpha_coords, pdbbind_ligand_coords, return_rotation=False):
    """Rigidly align predicted C-alpha coordinates onto the reference ones.

    Each residue is weighted by ``exp(-smoothing_factor * d)``, where ``d`` is
    the distance from its reference C-alpha to the nearest ligand atom. Both
    coordinate sets are centered on their weighted centroids and the rotation
    mapping the predicted set onto the reference set is fitted with those
    weights.

    Args:
        smoothing_factor: decay rate of the residue weights.
        pdbbind_calpha_coords: reference C-alpha coordinates, ``[n_residues, 3]``.
        omegafold_calpha_coords: predicted C-alpha coordinates, same shape.
        pdbbind_ligand_coords: ligand atom coordinates, ``[n_atoms, 3]``.
        return_rotation: return the fitted transform instead of the score.

    Returns:
        If ``return_rotation`` is true: ``(rotation, pdbbind_calpha_centroid,
        omegafold_calpha_centroid)``. Otherwise the root-mean-square difference
        between the inverse residue-ligand distances of the reference and of
        the aligned prediction.
    """
    reference_to_ligand = spa.distance.cdist(pdbbind_calpha_coords, pdbbind_ligand_coords)
    weights = np.exp(-1 * smoothing_factor * reference_to_ligand.min(axis=1))
    weight_column = weights[:, np.newaxis]
    weight_total = np.sum(weights)

    pdbbind_calpha_centroid = np.sum(weight_column * pdbbind_calpha_coords, axis=0) / weight_total
    omegafold_calpha_centroid = np.sum(weight_column * omegafold_calpha_coords, axis=0) / weight_total
    reference_centered = pdbbind_calpha_coords - pdbbind_calpha_centroid
    predicted_centered = omegafold_calpha_coords - omegafold_calpha_centroid

    # The second return value (the fit residual) is not needed here.
    rotation, _ = spa.transform.Rotation.align_vectors(reference_centered, predicted_centered, weights)
    if return_rotation:
        return rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid

    # Score: compare residue-ligand distances in the reference frame.
    ligand_centered = pdbbind_ligand_coords - pdbbind_calpha_centroid
    predicted_to_ligand = spa.distance.cdist(rotation.apply(predicted_centered), ligand_centered)
    inverse_distance_gap = (1 / reference_to_ligand) - (1 / predicted_to_ligand)
    return np.sqrt(np.mean(inverse_distance_gap ** 2))


In [9]:
def get_alignment_rotation(pdb_id, pdbbind_protein_path, omegafold_protein_path, pdbbind_path):
    """Fit the rigid transform that aligns a predicted receptor onto its PDBBind structure.

    The smoothing factor of ``align_prediction`` is optimised within ``[0, 1]``
    (starting from ``0.1``) to minimise its score, and the transform fitted with
    the optimal factor is returned.

    Args:
        pdb_id: complex identifier; also the ligand sub-directory of ``pdbbind_path``.
        pdbbind_protein_path: path of the reference receptor PDB file.
        omegafold_protein_path: path of the predicted receptor PDB file.
        pdbbind_path: root directory holding one folder per complex.

    Returns:
        ``(rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid)``, or
        ``(None, None, None)`` when the complex is skipped because no ligand
        could be read or the two receptors have differently shaped C-alpha arrays.
    """
    pdbbind_rec = parse_pdb_from_path(pdbbind_protein_path)
    omegafold_rec = parse_pdb_from_path(omegafold_protein_path)

    ligands = read_mols(pdbbind_path, pdb_id, remove_hs=True)
    if not ligands:
        # Skip instead of raising IndexError so one bad complex does not abort a batch run.
        print(f'No readable ligand found for PDB ID {pdb_id} - Skipping')
        return None, None, None
    pdbbind_ligand = ligands[0]

    # Index 2 of the returned tuple is the C-alpha coordinate array.
    pdbbind_calpha_coords = extract_receptor_structure(pdbbind_rec, pdbbind_ligand)[2]
    omegafold_calpha_coords = extract_receptor_structure(omegafold_rec, pdbbind_ligand)[2]
    pdbbind_ligand_coords = pdbbind_ligand.GetConformer().GetPositions()

    if pdbbind_calpha_coords.shape != omegafold_calpha_coords.shape:
        print(f'Receptor structures differ for PDB ID {pdb_id} - Skipping', pdbbind_calpha_coords.shape, omegafold_calpha_coords.shape)
        return None, None, None

    res = minimize(
        align_prediction,
        [0.1],
        bounds=Bounds([0.0],[1.0]),
        args=(
            pdbbind_calpha_coords,
            omegafold_calpha_coords,
            pdbbind_ligand_coords
        ),
        tol=1e-8
    )

    # res.x is the optimised smoothing factor (a length-1 array); refit with it
    # to obtain the transform rather than the score.
    return align_prediction(
        res.x,
        pdbbind_calpha_coords,
        omegafold_calpha_coords,
        pdbbind_ligand_coords,
        return_rotation=True
    )

In [None]:
from biopandas.pdb import PandasPdb

# Align every predicted structure found in data/esmfold_structures onto its PDBBind
# counterpart and write the transformed copy into the complex's PDBBind_processed folder.
# NOTE(review): the variable names say "omega"/"omegafold", but the files read and
# written here are the *_esmfold* structures.
for f in tqdm(os.listdir("data/esmfold_structures")):
    # The PDB ID is taken to be everything before the first underscore of the file name.
    pdb_id = f.split("_")[0]
    
    omega_protein_filename = f"data/esmfold_structures/{pdb_id}_protein_esmfold.pdb"
    omega_protein_output_filename = f"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_esmfold_aligned_tr.pdb"
    
    # Fit the rotation and the two weighted centroids against the processed PDBBind receptor.
    rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid = get_alignment_rotation(pdb_id, f"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_processed.pdb", 
                           omega_protein_filename, "data/PDBBind_processed/")
    
    # get_alignment_rotation returns None values when it decided to skip the complex.
    if rotation is None:
        continue
    
    # Re-read the predicted structure with biopandas and pull its ATOM coordinates as float32.
    ppdb_omegafold = PandasPdb().read_pdb(omega_protein_filename)
    ppdb_omegafold_pre_rot = ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
    # Center on the predicted centroid, rotate, then shift onto the reference centroid.
    ppdb_omegafold_aligned = rotation.apply(ppdb_omegafold_pre_rot - omegafold_calpha_centroid) + pdbbind_calpha_centroid
    
    
    # Overwrite the coordinates in the loaded frame and write only the ATOM records (gz=False).
    ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ppdb_omegafold_aligned
    ppdb_omegafold.to_pdb(path=omega_protein_output_filename, records=['ATOM'], gz=False)
    