In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
import pandas as pd
import copy

from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset, ConcatDataset
from rdkit import Chem
from conf_ensemble_dataset_in_memory import ConfEnsembleDataset
from rdkit.Chem import AllChem #needed for rdForceFieldHelpers
from collections import defaultdict
from litschnet import LitSchNet
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from rmsd_predictor_evaluator import RMSDPredictorEvaluator
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fix the global seeds (including dataloader workers) so runs are reproducible.
pl.seed_everything(seed=42, workers=True)

Global seed set to 42


42

# Data preparation

In [3]:
# Root folder for the `processed/` dataset chunks and the `*_splits/` SMILES files.
data_dir = "data/"

In [4]:
# Run once to preprocess the raw datasets and write the chunk files under data/processed/
# dataset = ConfEnsembleDataset()
# dataset = ConfEnsembleDataset(dataset='platinum') # needs ~16G (NOTE(review): RAM or disk? confirm)

In [5]:
# Count the preprocessed platinum chunk files; chunks are later loaded by index 0..n-1.
processed_files = os.listdir(os.path.join(data_dir, 'processed'))
platinum_chunks = list(filter(lambda name: name.startswith('platinum'), processed_files))
platinum_n_chunks = len(platinum_chunks)

In [6]:
# Load every platinum chunk and expose them as one concatenated dataset.
platinum_datasets = [ConfEnsembleDataset(dataset='platinum', loaded_chunk=chunk_index)
                     for chunk_index in tqdm(range(platinum_n_chunks))]
platinum_dataset = ConcatDataset(platinum_datasets)

100%|█████████████████████████████████████████████| 2/2 [00:07<00:00,  3.92s/it]


In [7]:
# Count the preprocessed pdbbind chunk files; chunks are later loaded by index 0..n-1.
processed_files = os.listdir(os.path.join(data_dir, 'processed'))
pdbbind_chunks = list(filter(lambda name: name.startswith('pdbbind'), processed_files))
pdbbind_n_chunks = len(pdbbind_chunks)

In [8]:
# Splitting strategies to evaluate, and the tasks passed to `evaluation_report`.
splits = 'random scaffold'.split()
tasks = 'all easy hard'.split()

In [9]:
# Number of split iterations per strategy; assumes test_smiles_<split>_split_<i>.txt
# exists for i in 0..N_ITERATIONS-1.
N_ITERATIONS = 5
LOGS_DIR = 'lightning_logs'


def _load_test_smiles(root, split, iteration):
    """Read the test-set SMILES of one split iteration, one per line.

    Blank lines (e.g. a trailing newline) are dropped so they are not passed
    on as empty SMILES strings.
    """
    smiles_path = os.path.join(root, f'{split}_splits', f'test_smiles_{split}_split_{iteration}.txt')
    with open(smiles_path, 'r') as f :
        return [line.strip() for line in f if line.strip()]


def _find_checkpoint(experiment_name, logs_dir=LOGS_DIR):
    """Return the checkpoint path to evaluate for ``experiment_name``, or None.

    None is returned when ``<logs_dir>/<experiment_name>/checkpoints`` does not
    exist or contains no ``.ckpt`` file.
    """
    checkpoint_dir = os.path.join(logs_dir, experiment_name, 'checkpoints')
    if not os.path.isdir(checkpoint_dir) :
        return None
    checkpoint_paths = [os.path.join(checkpoint_dir, name)
                        for name in os.listdir(checkpoint_dir)
                        if name.endswith('.ckpt')]
    if not checkpoint_paths :
        return None
    # os.listdir order is arbitrary; pick the most recently written file so the
    # choice is reproducible. NOTE(review): if several checkpoints are kept,
    # confirm the latest one is the intended (e.g. best) model.
    return max(checkpoint_paths, key=os.path.getmtime)


def _evaluate_and_report(model, dataset, evaluation_name, report_tasks):
    """Evaluate ``model`` on ``dataset`` and produce a report for each task.

    Returns the ``RMSDPredictorEvaluator`` so it stays inspectable afterwards.
    """
    evaluator = RMSDPredictorEvaluator(model=model, evaluation_name=evaluation_name)
    evaluator.evaluate(dataset)
    for task in report_tasks :
        evaluator.evaluation_report(task=task)
    return evaluator


for split in splits :

    for iteration in range(N_ITERATIONS) :

        experiment_name = f'{split}_split_{iteration}_new'

        # Resolve the model first: without this guard a missing experiment
        # either raised NameError or silently re-used the previous model.
        checkpoint_path = _find_checkpoint(experiment_name)
        if checkpoint_path is None :
            print(f'No checkpoint found for {experiment_name}, skipping')
            continue
        litschnet = LitSchNet.load_from_checkpoint(checkpoint_path=checkpoint_path)

        test_smiles = _load_test_smiles(data_dir, split, iteration)
        test_datasets = [ConfEnsembleDataset(loaded_chunk=chunk_number, smiles_list=test_smiles)
                         for chunk_number in tqdm(range(pdbbind_n_chunks))]
        test_dataset = ConcatDataset(test_datasets)

        evaluator = _evaluate_and_report(litschnet, test_dataset,
                                         experiment_name + '_pdbbind', tasks)
        evaluator = _evaluate_and_report(litschnet, platinum_dataset,
                                         experiment_name + '_platinum', tasks)

100%|█████████████████████████████████████████████| 3/3 [01:03<00:00, 21.27s/it]


Grouping data by smiles


  0%|                                                  | 0/1085 [00:00<?, ?it/s]

Starting evaluation


100%|███████████████████████████████████████| 1085/1085 [00:21<00:00, 49.89it/s]


Grouping data by smiles


  0%|                                          | 4/4548 [00:00<01:55, 39.40it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:26<00:00, 52.34it/s]
100%|█████████████████████████████████████████████| 3/3 [01:06<00:00, 22.16s/it]


Grouping data by smiles


  0%|▏                                         | 5/1085 [00:00<00:25, 41.72it/s]

Starting evaluation


100%|███████████████████████████████████████| 1085/1085 [00:21<00:00, 49.38it/s]


Grouping data by smiles


  0%|                                          | 3/4548 [00:00<02:35, 29.16it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:28<00:00, 51.22it/s]
100%|█████████████████████████████████████████████| 3/3 [01:08<00:00, 22.82s/it]


Grouping data by smiles


  1%|▍                                        | 10/1089 [00:00<00:13, 80.12it/s]

Starting evaluation


100%|███████████████████████████████████████| 1089/1089 [00:22<00:00, 49.14it/s]


Grouping data by smiles


  0%|                                          | 3/4548 [00:00<02:36, 28.99it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:27<00:00, 51.80it/s]
100%|█████████████████████████████████████████████| 3/3 [01:07<00:00, 22.49s/it]


Grouping data by smiles


  0%|▏                                         | 4/1088 [00:00<00:32, 33.68it/s]

Starting evaluation


100%|███████████████████████████████████████| 1088/1088 [00:21<00:00, 50.00it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<03:50, 19.76it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:27<00:00, 51.93it/s]
100%|█████████████████████████████████████████████| 3/3 [01:06<00:00, 22.15s/it]


Grouping data by smiles


  1%|▎                                         | 7/1088 [00:00<00:18, 57.01it/s]

Starting evaluation


100%|███████████████████████████████████████| 1088/1088 [00:21<00:00, 50.55it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<03:57, 19.17it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:29<00:00, 50.88it/s]
100%|█████████████████████████████████████████████| 3/3 [01:07<00:00, 22.56s/it]


Grouping data by smiles


  0%|▏                                         | 5/1098 [00:00<00:23, 46.94it/s]

Starting evaluation


100%|███████████████████████████████████████| 1098/1098 [00:22<00:00, 49.86it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<04:11, 18.09it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:27<00:00, 51.79it/s]
100%|█████████████████████████████████████████████| 3/3 [01:08<00:00, 22.88s/it]


Grouping data by smiles


  0%|▏                                         | 5/1093 [00:00<00:25, 42.26it/s]

Starting evaluation


100%|███████████████████████████████████████| 1093/1093 [00:21<00:00, 50.75it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<03:53, 19.43it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:28<00:00, 51.63it/s]
100%|█████████████████████████████████████████████| 3/3 [01:06<00:00, 22.17s/it]


Grouping data by smiles


  0%|▏                                         | 5/1100 [00:00<00:24, 44.96it/s]

Starting evaluation


100%|███████████████████████████████████████| 1100/1100 [00:23<00:00, 47.73it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<03:56, 19.21it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:28<00:00, 51.52it/s]
100%|█████████████████████████████████████████████| 3/3 [01:07<00:00, 22.48s/it]


Grouping data by smiles


  1%|▏                                         | 6/1069 [00:00<00:21, 50.40it/s]

Starting evaluation


100%|███████████████████████████████████████| 1069/1069 [00:19<00:00, 54.85it/s]


Grouping data by smiles


  0%|                                          | 2/4548 [00:00<03:57, 19.14it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:27<00:00, 52.22it/s]
100%|█████████████████████████████████████████████| 3/3 [01:08<00:00, 22.94s/it]


Grouping data by smiles


  1%|▏                                         | 6/1091 [00:00<00:19, 55.18it/s]

Starting evaluation


100%|███████████████████████████████████████| 1091/1091 [00:21<00:00, 50.02it/s]


Grouping data by smiles


  0%|                                          | 3/4548 [00:00<02:38, 28.68it/s]

Starting evaluation


100%|███████████████████████████████████████| 4548/4548 [01:25<00:00, 53.28it/s]
