In [1]:
import tempfile
import os
import pickle
import torch

from molecule_featurizer import MoleculeFeaturizer
from litschnet import LitSchNet
from ccdc_rdkit_connector import CcdcRdkitConnector
from ccdc.pharmacophore import Pharmacophore
from ccdc.descriptors import GeometricDescriptors
from ccdc.utilities import Colour
from ccdc import io
from torch_geometric.data import Batch

In [2]:
# Input structures for the 5jh6 docking run. The base directory can be
# overridden with the DOCKING_DIR environment variable instead of editing the
# notebook; the default is the original author's local path.
DOCKING_DIR = os.environ.get(
    'DOCKING_DIR',
    '/home/benoit/bioactive_conformation_predictor/gold_docking/3_all',
)
N_DOCKED_POSES = 100  # number of 3_all_{i}.mol2 ligand files to read

protein_file = os.path.join(DOCKING_DIR, '5jh6_protein.pdb')
native_ligand_file = os.path.join(DOCKING_DIR, '5jh6_ligand.mol2')
ligand_files = [
    os.path.join(DOCKING_DIR, f'3_all_{i}.mol2') for i in range(N_DOCKED_POSES)
]
# Combined multi-molecule file written in the next cell (native ligand first,
# then each ligand file as 'ligand_{i}').
all_ligand_files = '5jh6_generated_confs.mol2'

In [3]:
def _write_first_molecule(writer, source_file, identifier):
    """Copy the first molecule of ``source_file`` into ``writer`` under ``identifier``.

    The reader is used as a context manager so its file handle is released as
    soon as the molecule has been written, instead of one reader per input
    file staying open for the rest of the session.
    """
    with io.MoleculeReader(source_file) as reader:
        molecule = reader[0]
        molecule.identifier = identifier
        writer.write(molecule)


# Build one multi-molecule mol2: the native ligand first, then ligand file i
# as 'ligand_{i}', so later cells can map hit identifiers back to indexes.
with io.MoleculeWriter(all_ligand_files) as mol_writer:
    _write_first_molecule(mol_writer, native_ligand_file, 'native_ligand')
    for i, ligand_file in enumerate(ligand_files):
        _write_first_molecule(mol_writer, ligand_file, f'ligand_{i}')

In [4]:
# Load CCDC's pharmacophore feature definitions and keep all of them except
# the exit-vector, heavy-atom and hydrophobe types.
# NOTE(review): `feature_definitions` is not referenced by the cells shown
# here (later cells index Pharmacophore.feature_definitions directly).
Pharmacophore.read_feature_definitions()
excluded_feature_ids = {'exit_vector', 'heavy_atom', 'hydrophobe'}
feature_definitions = [
    definition
    for definition in Pharmacophore.feature_definitions.values()
    if definition.identifier not in excluded_feature_ids
]

In [5]:
# Read the native ligand as a crystal; the reader is closed once the first
# entry has been taken rather than being left open.
with io.CrystalReader(native_ligand_file) as crystal_reader:
    native_ligand = crystal_reader[0]

In [6]:
# Detect 'ring' pharmacophore features on the native ligand and report how many.
ring_feature_def = Pharmacophore.feature_definitions['ring']
ring_features = ring_feature_def.detect_features(native_ligand)
n_ring_features = len(ring_features)
print(n_ring_features)

4


In [7]:
# Detect 'donor_projected' pharmacophore features on the native ligand and report how many.
donor_proj_def = Pharmacophore.feature_definitions['donor_projected']
donor_proj_features = donor_proj_def.detect_features(native_ligand)
n_donor_proj_features = len(donor_proj_features)
print(n_donor_proj_features)

3


In [8]:
# Pharmacophore query built from the native ligand's ring and projected-donor features.
query_features = ring_features + donor_proj_features
query = Pharmacophore.Query(query_features)

In [9]:
# Describe the combined mol2 as a feature-database source; 0 and the colour
# are passed positionally to DatabaseInfo exactly as before.
db_colour = Colour(0, 255, 0, 255)
mol2_info = Pharmacophore.FeatureDatabase.DatabaseInfo(all_ligand_files, 0, db_colour)

In [10]:
# Structure-database path passed to StructureDatabase: same stem as the mol2,
# with a .csdsqlx extension. splitext swaps only the final extension
# (str.replace would rewrite any '.mol2' elsewhere in the path, and would
# return the mol2 path itself if the extension were missing).
csdsqlx = os.path.splitext(all_ligand_files)[0] + '.csdsqlx'
mol2_sdb = Pharmacophore.FeatureDatabase.Creator.StructureDatabase(
    mol2_info,
    use_crystal_symmetry=False,
    structure_database_path=csdsqlx,
)

In [11]:
# Feature-database creator; used in the next cell to build `db` from mol2_sdb.
creator_cls = Pharmacophore.FeatureDatabase.Creator
creator = creator_cls()

In [12]:
# Build the feature database from the single mol2-backed structure database.
structure_databases = [mol2_sdb]
db = creator.create(structure_databases)

In [13]:
# The database should hold one entry per molecule written to the combined
# mol2: the native ligand plus every ligand file.
n_db_entries = len(db)
print(n_db_entries)
assert n_db_entries == len(ligand_files) + 1, (
    f'expected {len(ligand_files) + 1} database entries, got {n_db_entries}'
)

101


In [14]:
# Pharmacophore searcher, constructed without arguments.
search_cls = Pharmacophore.Search
searcher = search_cls()

In [15]:
# Run the query against the feature database built above.
search_kwargs = {'database': db}
hits = searcher.search(query, **search_kwargs)

In [16]:
# Display the raw SearchHit objects.
list(hits)

[<ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eece550>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eecebd0>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eece910>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eecec90>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eececd0>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eeced10>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eeced50>,
 <ccdc.pharmacophore.Pharmacophore.SearchHit at 0x7fc65eecec50>]

In [17]:
# Print the identifier of every search hit, one per line.
hit_identifiers = [hit.identifier for hit in hits]
for identifier in hit_identifiers:
    print(identifier)

native_ligand
ligand_51
ligand_5
ligand_64
ligand_4
ligand_63
ligand_52
ligand_36


In [52]:
# Map search hits back to ligand-file indexes ('ligand_{i}' -> i), preserving
# hit order. Filtering on the identifier prefix, rather than dropping the first
# hit, stays correct when the native ligand is not ranked first or is not a hit
# at all. The positional slice would then drop a real ligand, and int('ligand')
# would fail on the native ligand's identifier.
LIGAND_PREFIX = 'ligand_'
indexes = [
    int(hit.identifier[len(LIGAND_PREFIX):])
    for hit in hits
    if hit.identifier.startswith(LIGAND_PREFIX)
]

In [18]:
# Load the fitted molecule encoders and the trained LitSchNet checkpoint.
data_dir = 'data/'
encoder_path = os.path.join(data_dir, 'molecule_encoders.p')
if not os.path.exists(encoder_path):
    # Fail here with a clear message rather than a NameError on mol_encoders.
    raise FileNotFoundError(f'Molecule encoders not found at {encoder_path}')
# NOTE: pickle.load can execute arbitrary code; only load encoder files you produced.
with open(encoder_path, 'rb') as f:
    mol_encoders = pickle.load(f)
mol_featurizer = MoleculeFeaturizer(mol_encoders)

experiment_name = 'random_split_0_new'
checkpoint_dir = os.path.join('lightning_logs', experiment_name, 'checkpoints')
if not os.path.isdir(checkpoint_dir):
    # Fail here rather than with a NameError on litschnet below.
    raise FileNotFoundError(
        f'No checkpoint directory for experiment {experiment_name!r}: {checkpoint_dir}'
    )
# os.listdir order is arbitrary; sorting makes the choice reproducible.
# NOTE(review): if several checkpoints exist, confirm which one should be used.
checkpoint_names = sorted(
    name for name in os.listdir(checkpoint_dir) if name.endswith('.ckpt')
)
if not checkpoint_names:
    raise FileNotFoundError(f'No .ckpt files found in {checkpoint_dir}')
checkpoint_name = checkpoint_names[0]
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
litschnet = LitSchNet.load_from_checkpoint(checkpoint_path=checkpoint_path)
litschnet.eval()

LitSchNet(
  (schnet): SchNet(hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0)
  (leaky_relu): LeakyReLU(negative_slope=0.01)
  (sigmoid): Sigmoid()
)

In [19]:
# dummy_mol = copy.deepcopy(ce.mol)
# dummy_mol.RemoveAllConformers()
# for conf_id in generated_ids :
#     dummy_mol.AddConformer(ce.mol.GetConformer(conf_id), assignId=True)

In [45]:
# Convert the combined mol2 into one RDKit molecule holding the docked
# conformers. Entry 0 (native_ligand) is skipped. Entry 1 (ligand_0) seeds
# the RDKit molecule, and entries 2.. (ligand_1, ...) are added as further
# conformers, presumably in file order, so conformer i should correspond to ligand_i.
connector = CcdcRdkitConnector()
# Read the file once and close the reader when the conversion is done.
with io.MoleculeReader(all_ligand_files) as reader:
    all_ccdc_mols = list(reader)
    mol = all_ccdc_mols[1]
    rdkit_mol = connector.ccdc_mol_to_rdkit_mol(mol)
    ccdc_mols = all_ccdc_mols[2:]
    connector.ccdc_mols_to_rdkit_mol_conformers(ccdc_mols, rdkit_mol)
# The index comparisons in later cells rely on one conformer per ligand file.
assert rdkit_mol.GetNumConformers() == len(ligand_files), (
    f'expected {len(ligand_files)} conformers, got {rdkit_mol.GetNumConformers()}'
)
data_list = mol_featurizer.featurize_mol(rdkit_mol)
batch = Batch.from_data_list(data_list)

In [46]:
from mol_viewer import MolViewer

# Interactive viewer over the conformers of rdkit_mol (slider selects conf_id).
viewer = MolViewer()
viewer.view(rdkit_mol)

interactive(children=(IntSlider(value=0, description='conf_id', max=99), Output()), _dom_classes=('widget-inte…

In [54]:
# Ligand indexes matched by the pharmacophore search, in hit order.
list(indexes)

[51, 5, 64, 4, 63, 52, 36]

In [56]:
# Baseline: pharmacophore hits among the first 20 ligands by index.
set(indexes) & set(range(20))

{4, 5}

In [63]:
# Score every conformer with the model and keep the 20 lowest scores.
# NOTE(review): ascending order assumes a lower prediction means "closer to
# bioactive"; confirm against what LitSchNet was trained to predict.
with torch.no_grad():
    raw_preds = litschnet(batch)
preds = raw_preds.cpu().numpy().reshape(-1)
top20_index = preds.argsort()[:20]

In [65]:
# Pharmacophore hits that also appear in the model's top 20.
set(indexes) & set(top20_index)

{4, 5, 51, 52}

In [66]:
# Rank conformers by ascending value of get_bioactive_rmsds and keep the first 20.
rmsds = mol_featurizer.get_bioactive_rmsds(rdkit_mol)
rmsd_order = rmsds.argsort()
top20_index = rmsd_order[:20]

In [68]:
# Pharmacophore hits that also appear in the 20 lowest-RMSD conformers.
set(indexes) & set(top20_index)

set()

In [71]:
import numpy as np

# Baseline: the 20 conformers with the lowest `energy` attribute.
energies = np.array([conf_data.energy for conf_data in data_list])
top20_index = energies.argsort()[:20]

In [72]:
# Pharmacophore hits that also appear in the 20 lowest-energy conformers.
set(indexes) & set(top20_index)

set()