In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
from rdkit import Chem
from data.utils.similarity_search import SimilaritySearch
from conf_ensemble import ConfEnsembleLibrary
from rankers.tfd_ranker_sim import TFD2SimRefMCSRanker
from data.utils.enzyme_connector import ENZYMEConnector
from data.utils.pdbbind import PDBbindMetadataProcessor
from data.utils.chembl_connector import ChEMBLConnector
from collections import defaultdict
from rdkit.Chem.rdFMCS import FindMCS
from rdkit.Chem.TorsionFingerprints import GetTFDMatrix
from rdkit.Chem.rdchem import Mol, EditableMol

In [None]:
# Directory that receives every figure saved by this notebook.
figures_dir = '../hdd/pdbbind_bioactive/figures/'
# Create it up front so the plt.savefig calls further down do not fail with
# FileNotFoundError when the notebook is run on a fresh checkout.
os.makedirs(figures_dir, exist_ok=True)

In [None]:
# Root of the processed data; both CSV files read below live in its
# 'pdb_conf_ensembles' sub-directory.
root = '../hdd/pdbbind_bioactive/data'
ensembles_dir = os.path.join(root, 'pdb_conf_ensembles')

# Ensemble table (provides 'ensemble_name' and 'smiles').
cel_df = pd.read_csv(os.path.join(ensembles_dir, 'ensemble_names.csv'))

# Per-PDB table, joined to the ensemble table on the ligand / ensemble name.
pdb_df = pd.read_csv(os.path.join(ensembles_dir, 'pdb_df.csv')).merge(
    cel_df, left_on='ligand_name', right_on='ensemble_name')

# PDBbind master table, joined to the table above on the PDB identifier.
pdbbind_df = PDBbindMetadataProcessor().get_master_dataframe().merge(
    pdb_df, left_on='PDB code', right_on='pdb_id')

In [None]:
# Number of distinct 'PDB code' values left after the merges above.
len(pdbbind_df['PDB code'].unique())

In [None]:
# Join the ChEMBL target table onto pdbbind_df by UniProt accession.
# No `how` is passed, so pandas performs an inner join: rows whose
# 'Uniprot ID' has no matching 'accession' are dropped.
# NOTE(review): pdbbind_df is overwritten in place, so running this cell twice
# merges the table a second time — restart the kernel before re-running.
cc = ChEMBLConnector()
chembl_table = cc.get_target_table(level=1)
pdbbind_df = pdbbind_df.merge(chembl_table, left_on='Uniprot ID', right_on='accession')

In [None]:
# Number of distinct 'PDB code' values that survive the ChEMBL merge.
len(pdbbind_df['PDB code'].unique())

In [None]:
# Hard-coded ratio typed in from earlier cell outputs — presumably
# (PDB codes after the ChEMBL merge) / (PDB codes before it).
# TODO confirm and recompute from the data so it cannot go stale.
9428 / 13460

In [None]:
# Row counts per 'level1' value (column presumably contributed by the ChEMBL
# target table — TODO confirm).
pdbbind_df['level1'].value_counts()

In [None]:
# Hard-coded ratio typed in from earlier cell outputs — presumably the share of
# one 'level1' class; TODO confirm what 7322 counts and recompute from pdbbind_df.
7322/9428

In [None]:
# Join the ENZYME table onto pdbbind_df by UniProt identifier.
# No `how` is passed, so this is an inner join: rows whose 'Uniprot ID' has no
# matching 'uniprot_id' are dropped.
# NOTE(review): pdbbind_df is overwritten in place, so running this cell twice
# merges the table a second time — restart the kernel before re-running.
ec = ENZYMEConnector()
enzyme_table = ec.get_table()
pdbbind_df = pdbbind_df.merge(enzyme_table, left_on='Uniprot ID', right_on='uniprot_id')

In [None]:
# Number of distinct 'PDB code' values that survive the ENZYME merge.
len(pdbbind_df['PDB code'].unique())

In [None]:
# Conformer-ensemble library; its `library` mapping (ligand name -> ensemble)
# is iterated in the per-ligand loop further down.
cel = ConfEnsembleLibrary()

In [None]:
# Similarity search built over every SMILES string listed in cel_df; it is
# queried with find_closest_in_set in the per-ligand loop further down.
ss = SimilaritySearch(cel_df['smiles'].values)

In [None]:
def get_editable_mol_match(mol, match):
    """Build an ``EditableMol`` from ``mol`` that keeps only the atoms in ``match``.

    Args:
        mol: molecule whose atoms are filtered.
        match: container of atom indices to keep (anything supporting ``in``).

    Returns:
        The ``EditableMol`` created from ``mol`` after removing every atom whose
        index is not in ``match``.
    """
    editable = EditableMol(mol)
    unmatched = [atom.GetIdx() for atom in mol.GetAtoms()
                 if atom.GetIdx() not in match]
    # Remove from the highest index downwards — presumably so that the indices
    # still waiting to be removed are not shifted by earlier removals.
    for atom_idx in unmatched[::-1]:
        editable.RemoveAtom(atom_idx)
    return editable

def get_full_matrix_from_tril(tril_matrix, n):
    """Expand a flattened strict lower triangle into a full symmetric matrix.

    Args:
        tril_matrix: iterable of the ``n * (n - 1) / 2`` below-diagonal values,
            ordered row by row: (1, 0), (2, 0), (2, 1), (3, 0), ...
        n: side length of the square matrix to build.

    Returns:
        An ``(n, n)`` float array that is symmetric with a zero diagonal.

    Raises:
        ValueError: if the number of values does not match ``n``. The loop-based
            version silently returned a partly filled matrix when given too few
            values and failed with an opaque IndexError when given too many.
    """
    values = np.asarray(list(tril_matrix), dtype=float)
    expected = n * (n - 1) // 2
    if values.ndim != 1 or values.size != expected:
        raise ValueError(
            'Expected {} lower-triangle values for n={}, got an array of shape {}'
            .format(expected, n, values.shape))

    matrix = np.zeros((n, n))
    # tril_indices with k=-1 walks the strict lower triangle row by row, which
    # is exactly the order documented above.
    rows, cols = np.tril_indices(n, k=-1)
    matrix[rows, cols] = values
    matrix[cols, rows] = values
    return matrix

In [None]:
# For every ligand in the library: look up the closest SMILES in the similarity
# set, take the maximum common substructure (MCS) of the two molecules, and
# compute the torsion fingerprint deviation (TFD) between their conformers on
# that MCS. Results are keyed by ligand name.
min_tfds = {}                    # smallest TFD between any ligand / reference conformer pair
original_ec = defaultdict(list)  # 'level_4' values of rows matching the ligand SMILES
closest_ec = defaultdict(list)   # 'level_4' values of rows matching the closest SMILES
mcs_sizes = {}                   # heavy-atom count of the MCS
failed_ligands = {}              # repr of the exception for every ligand that was skipped

for ligand_name, ce in tqdm(cel.library.items()):
    try:
        mol = ce.mol
        smiles = Chem.MolToSmiles(mol)

        # Reference molecule = first entry returned by the similarity search.
        closest_smiles_list, sim = ss.find_closest_in_set(smiles)
        closest_smiles = closest_smiles_list[0]
        closest_name = cel_df.loc[cel_df['smiles'] == closest_smiles,
                                  'ensemble_name'].values[0]
        ref_mol = cel.library[closest_name].mol

        mcs = FindMCS([ref_mol, mol],
                      timeout=5,
                      matchChiralTag=True)
        mcs_mol = Chem.MolFromSmarts(mcs.smartsString)

        # Cut the reference molecule down to the MCS atoms and renumber them
        # into the atom order of the MCS query.
        ref_mol_match = ref_mol.GetSubstructMatch(mcs_mol)
        new_ref_mol = get_editable_mol_match(ref_mol, ref_mol_match).GetMol()
        new_ref_mol = Chem.RenumberAtoms(new_ref_mol,
                                         new_ref_mol.GetSubstructMatch(mcs_mol))

        # Same reduction for the ligand itself.
        mol_match = mol.GetSubstructMatch(mcs_mol)
        new_mol = get_editable_mol_match(mol, mol_match).GetMol()
        new_mol = Chem.RenumberAtoms(new_mol,
                                     new_mol.GetSubstructMatch(mcs_mol))

        # Collect all conformers on the MCS molecule: reference ones first,
        # then the ligand's own.
        ref_conf_ids = [mcs_mol.AddConformer(conf, assignId=True)
                        for conf in new_ref_mol.GetConformers()]
        ligand_conf_ids = [mcs_mol.AddConformer(conf, assignId=True)
                           for conf in new_mol.GetConformers()]

        Chem.SanitizeMol(mcs_mol)
        tfd_matrix = get_full_matrix_from_tril(GetTFDMatrix(mcs_mol),
                                               n=mcs_mol.GetNumConformers())

        # Rows = reference conformers, columns = ligand conformers.
        # NOTE(review): this slicing assumes the conformers sit in the matrix
        # in insertion order (reference first) — confirm.
        n_ref_confs = len(ref_conf_ids)
        tfds = tfd_matrix[:n_ref_confs, n_ref_confs:]
        min_tfd = tfds.min(0)  # best reference conformer for each ligand conformer

        if len(min_tfd) == mol.GetNumConformers():
            mcs_sizes[ligand_name] = mcs_mol.GetNumHeavyAtoms()
            min_tfds[ligand_name] = min_tfd.min()

            for query_smiles, ec_store in ((smiles, original_ec),
                                           (closest_smiles, closest_ec)):
                ec_numbers = pdbbind_df.loc[pdbbind_df['smiles'] == query_smiles,
                                            'level_4'].values
                # Append one value at a time: the defaultdict key must only be
                # created when at least one value exists, because later cells
                # test `ligand_name in closest_ec`.
                for ec_number in ec_numbers:
                    ec_store[ligand_name].append(ec_number)
    except Exception as exc:
        # Best effort: skip ligands that cannot be processed, but keep a record
        # instead of discarding the error. KeyboardInterrupt is no longer
        # swallowed, so the loop can be interrupted.
        failed_ligands[ligand_name] = repr(exc)

print('{} ligands skipped because of errors (see failed_ligands)'.format(len(failed_ligands)))

In [None]:
# Sort every ligand that has 'level_4' values into one of three lists,
# depending on whether its closest reference molecule shares one of them.
equals = []          # at least one value in common with the closest molecule
not_in_closest = []  # closest molecule has values, but none in common
no_closest = []      # closest molecule has no values at all
for ligand_name, ligand_ecs in original_ec.items():
    if ligand_name not in closest_ec:
        no_closest.append(ligand_name)
        continue

    reference_ecs = closest_ec[ligand_name]
    if any(ligand_ec in reference_ecs for ligand_ec in ligand_ecs):
        equals.append(ligand_name)
    else:
        not_in_closest.append(ligand_name)

In [None]:
# Ligands sharing at least one 'level_4' value with their closest reference molecule.
len(equals)

In [None]:
# Ligands sharing no 'level_4' value with their closest reference molecule.
len(not_in_closest)

In [None]:
# Fraction of classified ligands that share at least one 'level_4' value with
# their closest reference molecule. Computed from the lists themselves; it was
# previously hard-coded as 4494 / (4494 + 1503), presumably the two lengths
# pasted from the cells above, which goes stale whenever the data changes.
len(equals) / (len(equals) + len(not_in_closest))

In [None]:
# Ligands whose closest reference molecule has no 'level_4' value recorded.
len(no_closest)

In [None]:
# One record per ligand that has a TFD value.
# Sets make the membership tests O(1); testing against the lists directly
# rescans them for every ligand.
equals_set = set(equals)
not_in_closest_set = set(not_in_closest)

rows = []
for ligand_name, tfd in min_tfds.items():
    # 'Different' is checked first so that it wins if a name were ever in both
    # lists, as it did when the two tests were independent `if` statements.
    if ligand_name in not_in_closest_set:
        enzyme_class = 'Different to closest reference molecule'
    elif ligand_name in equals_set:
        enzyme_class = 'Same as closest reference molecule'
    else:
        enzyme_class = None  # ligand was not classified
    rows.append({
        'Ligand name': ligand_name,
        'TFD': tfd,
        'Enzyme class': enzyme_class,
        'MCS size': mcs_sizes[ligand_name],
    })

In [None]:
# Columns: 'Ligand name', 'TFD', 'Enzyme class', 'MCS size' (one row per ligand).
df = pd.DataFrame(rows)

In [None]:
def custom_agg(series) :
    """Summarise a numeric series as ``'mean ± std'``, rounded to two decimals.

    Args:
        series: numeric pandas Series.

    Returns:
        ``'<mean> ± <std>'`` when both statistics are defined, ``'<mean>'`` when
        the standard deviation is NaN (e.g. a single value), and ``'NA'`` when
        the mean is NaN as well (empty or all-NaN input).
    """
    # np.round works on NumPy scalars and on plain Python floats alike. pandas
    # may return a bare Python ``nan`` for empty / all-NaN input, and that has
    # no ``.round`` method, so calling ``.round(2)`` on it raised AttributeError
    # before the NaN checks below could run.
    mean = np.round(series.mean(), 2)
    std = np.round(series.std(), 2)

    result = ''
    if not np.isnan(mean) :
        result = result + str(mean)
    if not np.isnan(std) :
        result = result + ' ± ' + str(std)

    if result == '' :
        result = 'NA'
    return result

In [None]:
# Mean ± std of the numeric columns per enzyme-class label. The columns are
# selected explicitly: the string column 'Ligand name' cannot be averaged, and
# recent pandas versions raise on it instead of silently dropping it.
df.groupby('Enzyme class')[['TFD', 'MCS size']].agg(custom_agg)

In [None]:
# Median of the numeric columns per enzyme-class label. The columns are selected
# explicitly: the string column 'Ligand name' has no median, and recent pandas
# versions raise on it instead of silently dropping it.
df.groupby('Enzyme class')[['TFD', 'MCS size']].median()

In [None]:
# Distribution of MCS sizes, one histogram per enzyme-class label; each label
# is normalised on its own (common_norm=False).
mcs_ax = sns.histplot(data=df,
                      x='MCS size',
                      hue='Enzyme class',
                      stat='proportion',
                      common_norm=False)
mcs_ax.set_xlabel('Size of the MCS to the closest reference molecule \n (number of heavy atoms)')
mcs_ax.figure.savefig(os.path.join(figures_dir, 'TFD_MCS_distribution_hist.png'),
                      dpi=300,
                      bbox_inches='tight')

In [None]:
# Distribution of the minimum TFD, one histogram per enzyme-class label; each
# label is normalised on its own (common_norm=False).
tfd_ax = sns.histplot(data=df,
                      x='TFD',
                      hue='Enzyme class',
                      stat='proportion',
                      common_norm=False)
tfd_ax.set_xlabel('TFD of the MCS to the closest reference molecule')
tfd_ax.figure.savefig(os.path.join(figures_dir, 'TFD_EC_distribution_hist.png'),
                      dpi=300,
                      bbox_inches='tight')

In [None]:
# Kernel density estimate of the minimum TFD, one curve per enzyme-class label.
sns.kdeplot(data=df,
            x='TFD',
            hue='Enzyme class',
            common_norm=False)
plt.xlabel('TFD of the MCS to the closest reference molecule')
# Saved under its own name: this cell used to write to
# 'TFD_EC_distribution_hist.png' and overwrote the histogram saved just before.
plt.savefig(os.path.join(figures_dir, 'TFD_EC_distribution_kde.png'),
            dpi=300,
            bbox_inches='tight')

In [None]:
# Flat list of every ligand's minimum TFD, in dictionary order.
mt = list(min_tfds.values())

In [None]:
# Empirical cumulative distribution of the minimum TFD over all ligands.
sns.ecdfplot(data=mt)

In [None]:
# Histogram of the minimum TFD over all ligands (matplotlib default binning).
plt.hist(mt)