In [1]:
import os
import pytorch_lightning as pl
import pandas as pd

from rdkit import Chem # safe import before ccdc imports
from torch_geometric.loader import DataLoader
from torch.utils.data import ConcatDataset

from conf_ensemble_dataset_in_memory import ConfEnsembleDataset
from litschnet import LitSchNet
from molsize_model import MolSizeModel
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm import tqdm

In [2]:
# Global seed for python/numpy/torch (and dataloader workers) so runs are reproducible
SEED = 42
pl.seed_everything(SEED, workers=True)

Global seed set to 42


42

# Data preparation

In [3]:
# Run once: instantiating ConfEnsembleDataset preprocesses the datasets and
# generates the chunk files (presumably the `pdbbind*` files that get_loaders
# counts under data/processed — confirm against ConfEnsembleDataset).
dataset = ConfEnsembleDataset()
# Alternative dataset source, kept for reference:
# dataset = ConfEnsembleDataset(dataset='platinum')

In [4]:
# Split strategies covered by the statistics cells below
# (get_loaders additionally supports 'ecfp')
splits = [
    'random',
    'scaffold',
    'protein',
]

In [5]:
def get_loaders(split, 
                iteration, 
                data_dir='data/',
                ) :
    
    pdbbind_chunks = [filename for filename in os.listdir(os.path.join(data_dir, 'processed')) if filename.startswith('pdbbind')]
    pdbbind_n_chunks = len(pdbbind_chunks)
    
    train_datasets = []
    val_datasets = []
    test_datasets = []
    
    if split in ['random', 'scaffold'] :
        
        with open(os.path.join(data_dir, f'ligand_{split}_splits', f'train_smiles_{split}_split_{iteration}.txt'), 'r') as f :
            train_smiles = f.readlines()
            train_smiles = [smiles.strip() for smiles in train_smiles]

        with open(os.path.join(data_dir, f'ligand_{split}_splits', f'val_smiles_{split}_split_{iteration}.txt'), 'r') as f :
            val_smiles = f.readlines()
            val_smiles = [smiles.strip() for smiles in val_smiles]

        with open(os.path.join(data_dir, f'ligand_{split}_splits', f'test_smiles_{split}_split_{iteration}.txt'), 'r') as f :
            test_smiles = f.readlines()
            test_smiles = [smiles.strip() for smiles in test_smiles]

        for chunk_number in tqdm(range(pdbbind_n_chunks)) :

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=train_smiles)
            train_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=val_smiles)
            val_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=test_smiles)
            test_datasets.append(dataset)
            
    elif split == 'ecfp' :
        
        with open(os.path.join(data_dir, f'ecfp_similarity_splits', f'train_smiles_{split}_similarity_split_{iteration}.txt'), 'r') as f :
            train_smiles = f.readlines()
            train_smiles = [smiles.strip() for smiles in train_smiles]

        with open(os.path.join(data_dir, f'ecfp_similarity_splits', f'val_smiles_{split}_similarity_split_{iteration}.txt'), 'r') as f :
            val_smiles = f.readlines()
            val_smiles = [smiles.strip() for smiles in val_smiles]

        with open(os.path.join(data_dir, f'ecfp_similarity_splits', f'test_smiles_{split}_similarity_split_{iteration}.txt'), 'r') as f :
            test_smiles = f.readlines()
            test_smiles = [smiles.strip() for smiles in test_smiles]

        for chunk_number in tqdm(range(pdbbind_n_chunks)) :

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=train_smiles)
            train_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=val_smiles)
            val_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          smiles_list=test_smiles)
            test_datasets.append(dataset)
        
            
    elif split == 'protein' : #protein split
        
        with open(os.path.join(data_dir, 'protein_similarity_splits', f'train_pdb_protein_similarity_split_{iteration}.txt'), 'r') as f :
            train_pdbs = f.readlines()
            train_pdbs = [pdb.strip() for pdb in train_pdbs]

        with open(os.path.join(data_dir, 'protein_similarity_splits', f'val_pdb_protein_similarity_split_{iteration}.txt'), 'r') as f :
            val_pdbs = f.readlines()
            val_pdbs = [pdb.strip() for pdb in val_pdbs]

        with open(os.path.join(data_dir, 'protein_similarity_splits', f'test_pdb_protein_similarity_split_{iteration}.txt'), 'r') as f :
            test_pdbs = f.readlines()
            test_pdbs = [pdb.strip() for pdb in test_pdbs]

        for chunk_number in tqdm(range(pdbbind_n_chunks)) :

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          pdb_ids_list=train_pdbs)
            train_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          pdb_ids_list=val_pdbs)
            val_datasets.append(dataset)

            dataset = ConfEnsembleDataset(loaded_chunk=chunk_number,
                                          pdb_ids_list=test_pdbs)
            test_datasets.append(dataset)

    train_dataset = ConcatDataset(train_datasets)
    val_dataset = ConcatDataset(val_datasets)
    test_dataset = ConcatDataset(test_datasets)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
    
    return train_loader, val_loader, test_loader

In [6]:
# Splits trained in this run; use `splits` (random/scaffold/protein) to cover the others
TRAIN_SPLITS = ['ecfp']
N_ITERATIONS = 5
MAX_EPOCHS = 20
LOG_DIR_NAME = 'lightning_logs'


def is_logged(experiment_name) :
    """Return True if a log directory already exists for `experiment_name`.

    Any existing directory counts as done, including one left behind by an
    interrupted run — delete it to force a re-run.
    """
    return os.path.isdir(os.path.join(os.getcwd(), LOG_DIR_NAME, experiment_name))


def fit_and_test(model, experiment_name, train_loader, val_loader, test_loader, max_epochs=MAX_EPOCHS) :
    """Fit `model` on train/val, then evaluate it on test, logging to lightning_logs/<experiment_name>."""
    logger = TensorBoardLogger(save_dir=os.getcwd(), version=experiment_name, name=LOG_DIR_NAME)
    trainer = pl.Trainer(logger=logger, max_epochs=max_epochs, gpus=1)
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)


for split in TRAIN_SPLITS :
    for iteration in range(N_ITERATIONS) :
        candidates = [(LitSchNet, f'{split}_split_{iteration}'),
                      (MolSizeModel, f'{split}_split_{iteration}_molsize')]
        pending = [(model_class, name) for model_class, name in candidates if not is_logged(name)]
        if not pending :
            # Both experiments already logged: skip the expensive loader construction
            continue

        train_loader, val_loader, test_loader = get_loaders(split, iteration)
        for model_class, experiment_name in pending :
            fit_and_test(model_class(), experiment_name, train_loader, val_loader, test_loader)

  0%|                                                     | 0/3 [00:32<?, ?it/s]


KeyboardInterrupt: 

In [7]:
from collections import defaultdict


def count_bioactive(dataset) :
    """Return how many items of `dataset` have ``rmsd == 0`` (full pass over the dataset)."""
    return sum(1 for data in dataset if data.rmsd == 0)


# Per-subset ('train'/'val'/'test') lists with one entry per (split, iteration),
# in loop order — kept for the display cells below.
n_conformations = defaultdict(list)
n_bioactive_conformations = defaultdict(list)
# Same counts as labelled records, so each value keeps its split and iteration
split_statistics_records = []

for split in splits :
    for iteration in range(5) :
        # Reuse get_loaders so the split files are read exactly as during training
        # (the previous copy-paste treated every non-ligand split as 'protein')
        loaders = get_loaders(split, iteration)
        for subset, loader in zip(('train', 'val', 'test'), loaders) :
            subset_dataset = loader.dataset
            n_total = len(subset_dataset)
            n_bio = count_bioactive(subset_dataset)
            n_conformations[subset].append(n_total)
            n_bioactive_conformations[subset].append(n_bio)
            split_statistics_records.append({
                'split' : split,
                'iteration' : iteration,
                'subset' : subset,
                'n_conformations' : n_total,
                'n_bioactive' : n_bio,
            })

split_statistics = pd.DataFrame(split_statistics_records)

100%|█████████████████████████████████████████████| 3/3 [04:39<00:00, 93.22s/it]
100%|█████████████████████████████████████████████| 3/3 [04:36<00:00, 92.04s/it]
100%|█████████████████████████████████████████████| 3/3 [04:44<00:00, 94.87s/it]
100%|█████████████████████████████████████████████| 3/3 [04:43<00:00, 94.44s/it]
100%|█████████████████████████████████████████████| 3/3 [04:44<00:00, 94.71s/it]
100%|████████████████████████████████████████████| 3/3 [05:35<00:00, 111.82s/it]
100%|█████████████████████████████████████████████| 3/3 [04:53<00:00, 97.93s/it]
100%|█████████████████████████████████████████████| 3/3 [04:52<00:00, 97.57s/it]
100%|█████████████████████████████████████████████| 3/3 [04:56<00:00, 98.79s/it]
100%|█████████████████████████████████████████████| 3/3 [04:53<00:00, 97.87s/it]
100%|████████████████████████████████████████████| 3/3 [08:15<00:00, 165.05s/it]
100%|████████████████████████████████████████████| 3/3 [08:09<00:00, 163.14s/it]
100%|███████████████████████

In [8]:
data_dir = 'data/'
pdbbind_chunks = list(filter(lambda filename: filename.startswith('pdbbind'),
                             os.listdir(os.path.join(data_dir, 'processed'))))
pdbbind_n_chunks = len(pdbbind_chunks)

# Totals over every pdbbind chunk: all conformations, and those with rmsd == 0
n_confs, n_bio = 0, 0
for chunk_number in range(pdbbind_n_chunks) :
    dataset = ConfEnsembleDataset(loaded_chunk=chunk_number)
    n_confs += len(dataset)
    n_bio += sum(1 for data in dataset if data.rmsd == 0)

print(n_confs)
print(n_bio)

952880
15514


In [9]:
# Conformation counts per subset ('train'/'val'/'test'); each list holds one
# entry per (split, iteration), in the order the statistics loop visited them
n_conformations

defaultdict(list,
            {'train': [766598,
              762302,
              761320,
              759701,
              762338,
              762491,
              765350,
              762653,
              758790,
              764836,
              767698,
              765616,
              773956,
              771798,
              774111],
             'val': [92863,
              95435,
              94898,
              95784,
              95187,
              92999,
              89046,
              96698,
              96347,
              87620,
              111143,
              104200,
              102109,
              104759,
              102038],
             'test': [93419,
              95143,
              96662,
              97395,
              95355,
              96219,
              97313,
              92358,
              96572,
              99253,
              101283,
              108352,
              102417,
              105115,
        

In [10]:
# Counts of conformations with rmsd == 0 per subset ('train'/'val'/'test');
# each list holds one entry per (split, iteration), in statistics-loop order
n_bioactive_conformations

defaultdict(list,
            {'train': [12429,
              12479,
              12424,
              12492,
              12411,
              12377,
              12284,
              12570,
              12445,
              12457,
              12341,
              12307,
              12459,
              12456,
              12399],
             'val': [1545,
              1550,
              1549,
              1498,
              1508,
              1570,
              1622,
              1588,
              1549,
              1588,
              1582,
              1618,
              1536,
              1493,
              1597],
             'test': [1540,
              1485,
              1541,
              1524,
              1595,
              1550,
              1591,
              1339,
              1503,
              1452,
              1610,
              1600,
              1532,
              1582,
              1533]})