In [1]:
import os
import pandas as pd

from rdkit import Chem # safe import before ccdc imports
from torch_geometric.loader import DataLoader
from torch.utils.data import ConcatDataset

from conf_ensemble_dataset_in_memory import ConfEnsembleDataset
from litschnet import LitSchNet
from molsize_model import MolSizeModel
from rmsd_predictor_evaluator import RMSDPredictorEvaluator
from tqdm import tqdm

# Data preparation

In [2]:
# Run once: building the default dataset preprocesses it and writes the processed chunks.
# Platinum variant (noted as 16G): dataset = ConfEnsembleDataset(dataset='platinum')
dataset = ConfEnsembleDataset()

In [3]:
def _read_stripped_lines(path):
    """Return the non-empty, whitespace-stripped lines of the text file at ``path``."""
    with open(path, 'r') as f :
        return [line.strip() for line in f if line.strip()]


def get_test_dataset(split,
                     iteration,
                     data_dir='data/',
                     ) :
    """Build the PDBbind test set for one split iteration and return the training SMILES.

    Parameters
    ----------
    split : str
        'random' or 'scaffold' (test set selected by ligand SMILES from
        ``ligand_{split}_splits``) or 'protein' (test set selected by PDB id from
        ``protein_similarity_splits``).
    iteration : int
        Index of the split files to read.
    data_dir : str
        Root data directory containing ``processed/`` chunk files, the split
        directories and ``smiles_df.csv``.

    Returns
    -------
    test_dataset : torch.utils.data.ConcatDataset
        One ConfEnsembleDataset per ``pdbbind*`` chunk in ``processed/``, each
        restricted to the test SMILES / test PDB ids, concatenated together.
    train_smiles : list of str
        SMILES of the training molecules for this split iteration.

    Raises
    ------
    ValueError
        If ``split`` is not 'random', 'scaffold' or 'protein'.
    FileNotFoundError
        If ``processed/`` contains no ``pdbbind`` chunk files.
    """
    if split not in ('random', 'scaffold', 'protein') :
        # Previously any unknown value silently fell through to the protein split.
        raise ValueError(f"Unknown split {split!r}; expected 'random', 'scaffold' or 'protein'")

    processed_dir = os.path.join(data_dir, 'processed')
    pdbbind_n_chunks = sum(1 for filename in os.listdir(processed_dir)
                           if filename.startswith('pdbbind'))
    if pdbbind_n_chunks == 0 :
        # ConcatDataset rejects an empty list with an opaque error; fail clearly instead.
        raise FileNotFoundError(f'No pdbbind chunk files found in {processed_dir}')

    # NOTE(review): ConfEnsembleDataset is not given data_dir below; confirm it reads
    # its chunks from the same directory as processed_dir.
    test_datasets = []

    if split in ('random', 'scaffold') :
        split_dir = os.path.join(data_dir, f'ligand_{split}_splits')
        train_smiles = _read_stripped_lines(
            os.path.join(split_dir, f'train_smiles_{split}_split_{iteration}.txt'))
        test_smiles = _read_stripped_lines(
            os.path.join(split_dir, f'test_smiles_{split}_split_{iteration}.txt'))

        for chunk_number in tqdm(range(pdbbind_n_chunks)) :
            test_datasets.append(ConfEnsembleDataset(loaded_chunk=chunk_number,
                                                     smiles_list=test_smiles))

    else : # protein split
        split_dir = os.path.join(data_dir, 'protein_similarity_splits')
        train_pdbs = _read_stripped_lines(
            os.path.join(split_dir, f'train_pdb_protein_similarity_split_{iteration}.txt'))
        test_pdbs = _read_stripped_lines(
            os.path.join(split_dir, f'test_pdb_protein_similarity_split_{iteration}.txt'))

        # Map training PDB ids to their ligand SMILES. Read from data_dir (this was
        # hardcoded to 'data/'), and return a list like the ligand-split branch does.
        smiles_df = pd.read_csv(os.path.join(data_dir, 'smiles_df.csv'))
        train_smiles = smiles_df.loc[smiles_df['id'].isin(train_pdbs), 'smiles'].tolist()

        for chunk_number in tqdm(range(pdbbind_n_chunks)) :
            test_datasets.append(ConfEnsembleDataset(loaded_chunk=chunk_number,
                                                     pdb_ids_list=test_pdbs))

    test_dataset = ConcatDataset(test_datasets)

    return test_dataset, train_smiles

In [4]:
def evaluate_model(experiment_name,
                   test_dataset,
                   training_smiles,
                   tasks = ('all', 'easy', 'hard'),
                   logs_dir = 'lightning_logs') :
    """Load an experiment's checkpoint, evaluate it on ``test_dataset`` and print reports.

    The model class is MolSizeModel when ``experiment_name`` contains 'molsize',
    otherwise LitSchNet. Results are stored under ``experiment_name + '_pdbbind'``
    and any previous evaluation with that name is overwritten.

    Parameters
    ----------
    experiment_name : str
        Sub-directory of ``logs_dir`` holding a ``checkpoints/`` folder.
    test_dataset : Dataset
        Dataset passed to ``RMSDPredictorEvaluator.evaluate``.
    training_smiles : sequence of str
        Training-set SMILES handed to the evaluator.
    tasks : iterable of str
        Task names passed one at a time to ``evaluation_report(task=...)``.
    logs_dir : str
        Root directory of the Lightning logs (defaults to 'lightning_logs').

    Raises
    ------
    FileNotFoundError
        If the checkpoint directory contains no ``.ckpt`` file.
    """
    checkpoint_dir = os.path.join(logs_dir, experiment_name, 'checkpoints')
    # os.listdir order is arbitrary and the folder may hold non-checkpoint files;
    # keep only .ckpt files and sort them so the same checkpoint is picked every run.
    checkpoint_names = sorted(name for name in os.listdir(checkpoint_dir)
                              if name.endswith('.ckpt'))
    if not checkpoint_names :
        raise FileNotFoundError(f'No .ckpt checkpoint found in {checkpoint_dir}')
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_names[0])

    model_class = MolSizeModel if 'molsize' in experiment_name else LitSchNet
    model = model_class.load_from_checkpoint(checkpoint_path=checkpoint_path)

    # Use the training_smiles argument. The previous code read the global
    # `train_smiles`, which only worked because a notebook loop leaked that name.
    evaluator = RMSDPredictorEvaluator(model=model,
                                       evaluation_name=experiment_name + '_pdbbind',
                                       training_smiles=training_smiles)
    evaluator.evaluate(test_dataset, overwrite=True)
    for task in tasks :
        evaluator.evaluation_report(task=task)
    # NOTE: the platinum evaluation was disabled (platinum_dataset is no longer a parameter).

In [5]:
# Root data directory; same value as the data_dir default of get_test_dataset.
data_dir = 'data/'
# Platinum chunk count, only needed by the disabled platinum evaluation:
# platinum_n_chunks = sum(1 for filename in os.listdir(os.path.join(data_dir, 'processed')) if filename.startswith('platinum'))

In [6]:
# platinum_datasets = []
# for chunk_number in tqdm(range(platinum_n_chunks)) :
#     dataset = ConfEnsembleDataset(dataset='platinum', loaded_chunk=chunk_number)
#     platinum_datasets.append(dataset)
# platinum_dataset = ConcatDataset(platinum_datasets)

In [7]:
splits = 'random scaffold protein'.split()  # every split kind get_test_dataset understands

In [None]:
%%time

# Evaluate the SchNet model and the molecule-size baseline on every iteration of the split.
# Only the protein split is run here; loop over `splits` instead to evaluate all of them.
for split in ['protein'] :
    for iteration in range(5) :
        test_dataset, train_smiles = get_test_dataset(split, iteration)
        # '' -> SchNet experiment, '_molsize' -> molecule-size baseline (same order as before)
        for suffix in ('', '_molsize') :
            experiment_name = f'{split}_split_{iteration}{suffix}'
            evaluate_model(experiment_name, test_dataset, train_smiles)

100%|█████████████████████████████████████████████| 3/3 [01:44<00:00, 34.88s/it]


Computing training set fingerprints
Grouping data by smiles


  0%|                                                  | 0/1412 [00:00<?, ?it/s]

Starting evaluation


100%|███████████████████████████████████████| 1412/1412 [01:24<00:00, 16.76it/s]


Computing training set fingerprints
Grouping data by smiles


  0%|                                          | 3/1412 [00:00<01:01, 22.96it/s]

Starting evaluation


100%|███████████████████████████████████████| 1412/1412 [01:03<00:00, 22.33it/s]
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
100%|█████████████████████████████████████████████| 3/3 [01:52<00:00, 37.51s/it]


Computing training set fingerprints
Grouping data by smiles


  0%|                                          | 2/1429 [00:00<01:27, 16.26it/s]

Starting evaluation


100%|███████████████████████████████████████| 1429/1429 [01:26<00:00, 16.59it/s]


Computing training set fingerprints
Grouping data by smiles


  0%|                                          | 3/1429 [00:00<01:10, 20.14it/s]

Starting evaluation


100%|███████████████████████████████████████| 1429/1429 [01:10<00:00, 20.20it/s]
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
100%|█████████████████████████████████████████████| 3/3 [01:55<00:00, 38.63s/it]


Computing training set fingerprints
Grouping data by smiles


  0%|                                          | 2/1373 [00:00<01:54, 11.93it/s]

Starting evaluation


100%|███████████████████████████████████████| 1373/1373 [01:31<00:00, 14.99it/s]


In [10]:
len(test_dataset)  # number of items in the concatenated test set left over from the last loop iteration

101506

In [12]:
import importlib
import rmsd_predictor_evaluator

# Pick up on-disk edits to the evaluator without restarting the kernel. importlib.reload
# returns the (same) module object; rebind the class name that evaluate_model looks up.
RMSDPredictorEvaluator = importlib.reload(rmsd_predictor_evaluator).RMSDPredictorEvaluator

In [13]:
# Re-run one evaluation. NOTE: relies on split, iteration, test_dataset and train_smiles
# left in the namespace by the evaluation loop cell (hidden state on a fresh kernel).
experiment_name = '{}_split_{}'.format(split, iteration)
evaluate_model(experiment_name, test_dataset, train_smiles)

Computing training set fingerprints
Grouping data by smiles


  0%|                                          | 1/1418 [00:00<03:16,  7.22it/s]

Starting evaluation


  3%|█▏                                       | 40/1418 [00:03<01:42, 13.38it/s]

> [0;32m/home/benoit/bioactive_conformation_predictor/rmsd_predictor_evaluator.py[0m(260)[0;36mmol_evaluation[0;34m()[0m
[0;32m    258 [0;31m[0;34m[0m[0m
[0m[0;32m    259 [0;31m            [0;31m# Generated stats[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 260 [0;31m            [0mgenerated_targets[0m [0;34m=[0m [0mmol_targets[0m[0;34m[[0m[0mis_generated[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    261 [0;31m            [0mgenerated_preds[0m [0;34m=[0m [0mmol_preds[0m[0;34m[[0m[0mis_generated[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    262 [0;31m            [0mgenerated_energies[0m [0;34m=[0m [0mmol_energies[0m[0;34m[[0m[0mis_generated[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  bioactive_preds


array([], dtype=float32)


ipdb>  smiles


'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1'


  3%|█▏                                       | 40/1418 [00:17<01:42, 13.38it/s]

ipdb>  smiles_data_list


[Data(x=[43, 1], edge_index=[2, 92], pos=[43, 3], data_id='Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_0', energy=[1], n_heavy_atoms=[1], n_rotatable_bonds=[1], rmsd=[1]), Data(x=[43, 1], edge_index=[2, 92], pos=[43, 3], data_id='Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_1', energy=[1], n_heavy_atoms=[1], n_rotatable_bonds=[1], rmsd=[1]), Data(x=[43, 1], edge_index=[2, 92], pos=[43, 3], data_id='Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_2', energy=[1], n_heavy_atoms=[1], n_rotatable_bonds=[1], rmsd=[1]), Data(x=[43, 1], edge_index=[2, 92], pos=[43, 3], data_id='Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_3', energy=[1], n_heavy_atoms=[1], n_rotatable_bonds=[1], rmsd=[1]), Data(x=[43, 1], edge_index=[2, 92], pos=[43, 3], data_id='Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_4', energy=[1], n_heavy_atoms=[1], n_rotata

ipdb>  [d.data_id for d in smiles_data_list]


['Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_0', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_1', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_2', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_3', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_4', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_5', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_6', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_7', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_8', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_9', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_10', 'Cc1ccccc1S(=O)(=O)N(CCN(Cc1cncn1C)c1ccc(C#N)cc1)CC1CCN(C(=O)OC(C)(C)C)CC1_Gen_11', '

ipdb>  len([d.data_id for d in smiles_data_list])


100


ipdb>  exit


  3%|█                                     | 40/1418 [25:02<14:22:36, 37.56s/it]


BdbQuit: 