In [None]:
# Move the working directory up one level so the relative ./data/... and ./checkpoints/...
# paths used by later cells resolve. Not idempotent: re-running this cell moves up again.
%cd ../

In [1]:
# Notebook extensions: TensorBoard integration, and autoreload (mode 2 = reload all modules
# before running each cell) so edits to local modules such as toxicity.model are picked up
# without restarting the kernel.
%load_ext tensorboard
%load_ext autoreload
%autoreload 2

In [65]:
%%capture
import numpy as np
import pandas as pd
import os
from collections import Counter

import rdkit as rd
from rdkit import DataStructs
from rdkit.Chem import AllChem

import sklearn as sk
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import seaborn as sns

import os
import time
import random
import joblib
import shutil

import torch
from torch.utils.data import Dataset, DataLoader
from toxicity.model import MTDNN

import tensorflow as tf
import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

In [91]:
# Run on the first GPU when one is present, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TensorBoard event files are written under ./tensorboard/.
writer = SummaryWriter('tensorboard/')

# Seed every RNG in play (torch CPU, torch CUDA, numpy, stdlib) so runs are repeatable.
random_seed = 0
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(random_seed)
# cuDNN is switched off entirely; the deterministic flag is set as well.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True

# Hyper-parameters.
morgan_bits, morgan_radius = 4096, 2  # not referenced by any visible cell (inputs are SMILES embeddings)
train_epoch = 50
batch = 512  # DataLoader batch size for the training split

In [92]:
# ClinTox contributes a single label column.
clintox_task = ['CT_TOX']
# Tox21's twelve assay endpoints. The order matters: all_tasks fixes both the label-column
# order of the y arrays and which model output head is paired with which task.
tox21_tasks = [
    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',
    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',
    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',
]

# Tox21 tasks first, CT_TOX last.
all_tasks = [*tox21_tasks, *clintox_task]

In [93]:
# Both datasets use the same seed_124 split layout: <split>_data_<dataset>.pth.
data_path = "./data/toxicity/datasets/tox21/split_data/seed_124/"
clintox_data_path = "./data/toxicity/datasets/clintox/split_data/seed_124/"
split_names = ('train', 'test', 'valid')

# NOTE(review): these files are unpickled by torch.load; on torch >= 2.6 (weights_only=True by
# default) loading non-tensor objects such as DataFrames may need weights_only=False. Only load trusted files.
tox21_frames = [torch.load(data_path + name + '_data_tox21.pth') for name in split_names]
clintox_frames = [torch.load(clintox_data_path + name + '_data_clintox.pth') for name in split_names]
train_data_clintox, test_data_clintox, valid_data_clintox = clintox_frames

# Outer-join on SMILES so every molecule keeps whichever labels either dataset provides;
# labels a dataset does not have for a molecule come out as NaN.
train_data, test_data, valid_data = (
    tox21_frame.merge(clintox_frame, how='outer', on='smiles')
    for tox21_frame, clintox_frame in zip(tox21_frames, clintox_frames)
)

data = [train_data, test_data, valid_data]

In [94]:
# Pretrained SMILES -> embedding mapping (used via `in` and .get, i.e. as a dict-like).
smiles_embed = torch.load("./data/toxicity/smiles_embedding/smiles_embed_pretrain.pt")
for i in range(len(data)):
    # Fail here with a clear message. Previously a missing SMILES became None (then -1 after
    # the fillna cell) and only crashed later, cryptically, inside torch.stack.
    missing = [s for s in data[i]['smiles'] if s not in smiles_embed]
    if missing:
        raise KeyError(
            f"{len(missing)} SMILES in split {i} have no pretrained embedding, e.g. {missing[:3]}"
        )
    data[i]['smiles_embed'] = data[i]['smiles'].apply(lambda x: smiles_embed.get(x))

In [95]:
# Replace every missing value (labels absent after the outer merge, absent embeddings) with -1,
# the sentinel later cells treat as "no label" (they keep only entries >= 0).
# Slice assignment updates the same list object in place.
data[:3] = [split_df.fillna(-1) for split_df in data[:3]]

In [96]:
# Re-bind the split names to the filled frames (the order matches how `data` was built).
train_data, test_data, valid_data = data[0], data[1], data[2]

In [97]:
# Stack the per-molecule embedding tensors into one array with a leading molecule axis.
x_train = torch.stack(train_data['smiles_embed'].tolist()).numpy()

# Label matrix: one column per entry of all_tasks; -1 marks a missing label.
y_train = train_data[all_tasks].to_numpy()

In [98]:
# Same construction as for the training split: stacked embeddings and a task-ordered label matrix.
x_test = torch.stack(test_data['smiles_embed'].tolist()).numpy()

y_test = test_data[all_tasks].to_numpy()

In [99]:
# Same construction as for the training split: stacked embeddings and a task-ordered label matrix.
x_valid = torch.stack(valid_data['smiles_embed'].tolist()).numpy()

y_valid = valid_data[all_tasks].to_numpy()

In [100]:
# Per-task count of labelled (>= 0, i.e. not the -1 sentinel) entries in each split;
# the training/validation losses divide each task's loss by these counts.
N_train, N_test, N_valid = (
    (labels >= 0).sum(axis=0) for labels in (y_train, y_test, y_valid)
)

In [110]:
# Cast every split's features and labels to float32 (torch's default float dtype).
# Despite the "_torch" suffix these remain numpy arrays; DataLoader collation turns them into tensors.
(x_train_torch, y_train_torch), (x_test_torch, y_test_torch), (x_valid_torch, y_valid_torch) = (
    (features.astype(np.float32), labels.astype(np.float32))
    for features, labels in ((x_train, y_train), (x_test, y_test), (x_valid, y_valid))
)

# Width of one embedding row = input dimension of the network.
input_shape = x_train_torch.shape[1]

In [111]:
class MTDNNData(Dataset):
    """Map-style dataset that pairs feature row ``x[idx]`` with label row ``y[idx]``."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        # Size follows the features; y is assumed to have the same length.
        return len(self.x)

    def __getitem__(self, idx):
        features, labels = self.x[idx], self.y[idx]
        return features, labels

In [112]:
# Wrap each split; only the training split is shuffled and mini-batched, while the test and
# validation splits are each served as a single full-size batch.
training_set = MTDNNData(x_train_torch, y_train_torch)
testing_set = MTDNNData(x_test_torch, y_test_torch)
valid_set = MTDNNData(x_valid_torch, y_valid_torch)

training_generator = DataLoader(training_set, batch_size=batch, shuffle=True)
testing_generator = DataLoader(testing_set, batch_size=len(testing_set), shuffle=False)
valid_generator = DataLoader(valid_set, batch_size=len(valid_set), shuffle=False)

In [113]:
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Save ``state`` to ``checkpoint_path``; if ``is_best``, also copy that file to ``best_model_path``.

    Adapted from https://gist.github.com/vsay01/45dfced69687077be53dbdd4987b6b17
    """
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, best_model_path)
        
def load_ckp(checkpoint_fpath, input_model, optimizer, map_location=None, **load_kwargs):
    """Restore model and optimizer state from a checkpoint dict written with ``save_ckp``.

    Args:
        checkpoint_fpath: path of the checkpoint file.
        input_model: model whose state is overwritten in place from ``checkpoint['state_dict']``.
        optimizer: optimizer whose state is overwritten in place from ``checkpoint['optimizer']``.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a GPU checkpoint on a
            CPU-only machine. The default ``None`` keeps torch's behaviour.
        **load_kwargs: extra ``torch.load`` options, e.g. ``weights_only=False`` for older
            checkpoints that stored numpy scalars.

    Returns:
        ``(input_model, optimizer, epoch, train_loss_min)``, with ``train_loss_min`` as a Python float.
    """
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location, **load_kwargs)
    input_model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # float() accepts numpy scalars, 0-d tensors and plain floats alike (.item() failed on floats).
    train_loss_min = float(checkpoint['train_loss_min'])
    # Return the model that was passed in, not whatever global happens to be named `model`.
    return input_model, optimizer, checkpoint['epoch'], train_loss_min

In [114]:
ckpt_path = "./checkpoints/toxicity"

# makedirs also creates a missing ./checkpoints parent (os.mkdir raised FileNotFoundError there),
# and exist_ok makes the cell safe to re-run.
os.makedirs(ckpt_path, exist_ok=True)

# Rolling checkpoint plus the "best" copies written by the training loop.
checkpoint_file = ckpt_path + '/current_checkpoint.pt'
bestmodel_file = ckpt_path + '/best_model.pt'
bestmodel_byvalid_file = ckpt_path + '/best_model_by_valid.pt'
bestmodel_byvalid_crossed_file = ckpt_path + '/best_model_by_valid-crossed.pt'

In [115]:
# BCELoss expects probabilities, so this assumes MTDNN's task heads end in a sigmoid (TODO confirm).
criterion = torch.nn.BCELoss()
# Build the model BEFORE the optimizer. The optimizer was previously created from
# `model.parameters()` first: a NameError on a fresh kernel, and on re-run it bound to the
# previous model's parameters, so the new model was never updated.
model = MTDNN(input_shape, all_tasks).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [116]:
# Per-epoch histories (one entry appended per completed epoch).
loss_history = []
correct_history = []
val_loss_history = []
val_correct_history = []
# np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
train_loss_min = np.inf
val_loss_min = np.inf


def _multitask_loss(outputs, targets, n_labelled, y_true_acc, y_pred_acc):
    """Sum the masked per-task BCE losses for one batch and collect rounded predictions.

    outputs[i][:, 0] is task i's prediction and targets[:, i] its label. Entries whose label is
    negative (missing values were filled with -1) are excluded. Each task's loss is divided by
    n_labelled[i], the task's labelled count in the whole split, exactly as before. Labels and
    rounded predictions of the kept entries are appended to y_true_acc / y_pred_acc.

    Returns a scalar tensor, or the int 0 if no task had a labelled entry in this batch.
    """
    total = 0
    for i in range(len(all_tasks)):
        target_task = targets[:, i]
        output_task = outputs[i][:, 0]
        labelled = target_task >= 0
        # BCELoss (mean reduction) over an empty selection is NaN and would poison the sum.
        if not labelled.any():
            continue
        total = total + criterion(output_task[labelled], target_task[labelled]) / n_labelled[i]
        y_true_acc.extend(target_task[labelled].float().tolist())
        y_pred_acc.extend(np.round(output_task[labelled].detach().cpu().numpy()).reshape(-1).tolist())
    return total


# Batches seen across all epochs, used as the step for per-batch TensorBoard points.
# (The old counter was named `batch` and overwrote the batch-size setting.)
global_step = 0
for e in range(train_epoch):
    model.train()
    running_train_loss = 0.0
    running_valid_loss = 0.0
    y_train_true, y_train_pred = [], []
    y_valid_true, y_valid_pred = [], []

    # ---- training pass ----
    for x_batch, y_batch in training_generator:
        # `device` is cuda:0 whenever CUDA is available, matching the old .cuda() calls.
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_pred = model(x_batch)

        loss = _multitask_loss(y_pred, y_batch, N_train, y_train_true, y_train_pred)
        if not torch.is_tensor(loss):
            continue  # nothing labelled in this batch: no gradient to take

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
        writer.add_scalar("Loss/train_batch", loss.item(), global_step)
        global_step += 1

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for val_x_batch, val_y_batch in valid_generator:
            val_x_batch, val_y_batch = val_x_batch.to(device), val_y_batch.to(device)
            val_output = model(val_x_batch)
            val_loss = _multitask_loss(val_output, val_y_batch, N_valid, y_valid_true, y_valid_pred)
            running_valid_loss += float(val_loss)

    # Epoch losses are the summed batch losses (same values as before), kept as Python floats
    # so the checkpoint holds no numpy scalars (torch.load's weights_only mode rejects those).
    train_epoch_loss = running_train_loss
    val_epoch_loss = running_valid_loss

    train_epoch_acc = accuracy_score(y_train_true, y_train_pred)
    val_epoch_acc = accuracy_score(y_valid_true, y_valid_pred)

    loss_history.append(train_epoch_loss)
    correct_history.append(train_epoch_acc)
    val_loss_history.append(val_epoch_loss)
    val_correct_history.append(val_epoch_acc)

    # One point per epoch per curve. "Accuracy/train" used to receive the loss, and the running
    # losses were re-logged at the same epoch step after every batch.
    writer.add_scalar("Loss/train", train_epoch_loss, e)
    writer.add_scalar("Loss/valid", val_epoch_loss, e)
    writer.add_scalar("Accuracy/train", train_epoch_acc, e)
    writer.add_scalar("Accuracy/valid", val_epoch_acc, e)

    print(f"Epoch: {e}, Training loss: {train_epoch_loss:.4f}, Valid loss: {val_epoch_loss:.4f}")

    checkpoint = {
        'epoch': e + 1,
        'train_loss_min': train_epoch_loss,
        'val_loss_min': val_epoch_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    # Always refresh the rolling checkpoint, then copy it to each "best" file it qualifies for.
    save_ckp(checkpoint, False, checkpoint_file, bestmodel_file)

    if train_epoch_loss <= train_loss_min:
        save_ckp(checkpoint, True, checkpoint_file, bestmodel_file)
        train_loss_min = train_epoch_loss

    # Epochs where the training loss has reached or exceeded the validation loss.
    # This branch no longer overwrites train_loss_min, which corrupted best-by-train tracking.
    if train_epoch_loss >= val_epoch_loss:
        save_ckp(checkpoint, True, checkpoint_file, bestmodel_byvalid_crossed_file)

    # Best-by-validation gets its own file instead of overwriting the best-by-train model.
    if val_epoch_loss <= val_loss_min:
        save_ckp(checkpoint, True, checkpoint_file, bestmodel_byvalid_file)
        val_loss_min = val_epoch_loss

writer.flush()

KeyboardInterrupt: 

In [None]:
# Restore the checkpoint saved as bestmodel_file; load_ckp loads the state into `model` and
# `optimizer` in place and also returns the epoch and training loss it stored.
loaded_model, optimizer, start_epoch, train_loss_min = load_ckp(
    checkpoint_fpath=bestmodel_file,
    input_model=model,
    optimizer=optimizer,
)

In [90]:
# NOTE(review): `results` is not defined in any visible cell (only `results_valid` is built
# further down); it presumably came from a deleted test-set metrics cell. TODO restore it,
# otherwise this cell raises NameError on a fresh kernel.
metric_columns = ('  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ')
print('Task'.ljust(10), '\t', *metric_columns)
for task_name, task_metrics in results.items():
    print(task_name.ljust(10), '\t', np.round(task_metrics, 3))

Task       	   AUC   ACC   BACC   TN    TP    PR    RC    F1  
NR-AR      	 [0.552 0.469 0.516 0.464 0.567 0.514 0.567 0.079]
NR-Aromatase 	 [0.569 0.352 0.546 0.327 0.765 0.532 0.765 0.119]
NR-PPAR-gamma 	 [0.491 0.559 0.53  0.56  0.5   0.532 0.5   0.045]
SR-HSE     	 [0.575 0.698 0.526 0.718 0.333 0.542 0.333 0.105]
NR-AR-LBD  	 [0.42  0.522 0.411 0.528 0.294 0.384 0.294 0.029]
NR-ER      	 [0.537 0.603 0.537 0.631 0.443 0.546 0.443 0.249]
SR-ARE     	 [0.472 0.434 0.469 0.417 0.521 0.472 0.521 0.231]
SR-MMP     	 [0.472 0.56  0.453 0.611 0.295 0.431 0.295 0.177]
NR-AhR     	 [0.453 0.583 0.493 0.611 0.375 0.491 0.375 0.175]
NR-ER-LBD  	 [0.543 0.289 0.526 0.256 0.795 0.517 0.795 0.12 ]
SR-ATAD5   	 [0.587 0.48  0.57  0.473 0.667 0.559 0.667 0.077]
SR-p53     	 [0.51  0.492 0.511 0.489 0.533 0.511 0.533 0.12 ]
CT_TOX     	 [0.268 0.304 0.22  0.314 0.125 0.154 0.125 0.019]


In [None]:
# Validation loss of the current model. The old header said "test", but this reads
# valid_generator, so each task's loss is now scaled by N_valid (it used N_test).
model.eval()
with torch.no_grad():
    # Loop names differ from the x_valid_torch / y_valid_torch arrays built earlier,
    # which the old loop silently overwrote with tensors.
    for val_x, val_y in valid_generator:
        # Keep the model on `device` (the old `.to(device).cpu()` chain left it on the CPU and
        # broke a re-run of training); only the per-task outputs are moved to the CPU, which
        # the metrics cell below expects.
        y_valid_pred = [task_out.cpu() for task_out in model(val_x.to(device))]

        # Compute loss over all tasks, ignoring missing (-1) labels.
        loss = 0
        for i in range(len(all_tasks)):
            val_y_task = val_y[:, i]
            y_pred_task = y_valid_pred[i][:, 0]
            labelled = val_y_task >= 0
            if not labelled.any():
                continue  # BCELoss over an empty selection would be NaN
            loss += criterion(y_pred_task[labelled], val_y_task[labelled]) / N_valid[i]

print(float(loss))

In [None]:
results_valid = {}
# Collects performance metrics for every task on the validation set.
# Stored per task as [auc, acc, bacc, tn, tp, pr, rc, f1], where tn / tp are the row-normalised
# rates (specificity / sensitivity) and pr / rc come from raw counts.
for i, task in enumerate(all_tasks):
    valid_datapoints = y_valid[:, i] >= 0  # -1 marks a missing label
    if not valid_datapoints.any():
        print('No labelled validation data for', task)
        continue

    y_valid_task = y_valid[valid_datapoints, i].astype(int)
    # .cpu() makes this work whether or not the predictions were left on the GPU.
    y_valid_pred_task = y_valid_pred[i].detach().cpu().numpy()[valid_datapoints, 0]
    y_valid_label = np.round(y_valid_pred_task).astype(int)
    has_both_classes = np.unique(y_valid_task).size == 2

    acc = accuracy_score(y_valid_task, y_valid_label)
    print('Accuracy for deepnn on Morgan Fingerprint:', acc)

    bacc = sk.metrics.balanced_accuracy_score(y_valid_task, y_valid_label)

    f1 = f1_score(y_valid_task, y_valid_label, pos_label=1)
    print('F1 for deepnn on Morgan Fingerprint:', f1)

    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent, so ravel() always yields 4 values.
    cm_counts = sk.metrics.confusion_matrix(y_valid_task, y_valid_label, labels=[0, 1])
    tn_count, fp_count, fn_count, tp_count = cm_counts.ravel()
    # Row-normalise (rate per true class); an empty row stays 0 instead of dividing by zero.
    row_totals = cm_counts.sum(axis=1, keepdims=True)
    cfm = np.divide(cm_counts, row_totals, out=np.zeros(cm_counts.shape), where=row_totals > 0)

    print('Confusion Matrix for deepnn on Morgan Fingerprint:\n', cfm)

    tn, fp, fn, tp = cfm.ravel()
    # Precision must come from raw counts: on the row-normalised matrix, tp / (tp + fp)
    # is TPR / (TPR + FPR), which is not precision.
    pr = tp_count / (tp_count + fp_count) if (tp_count + fp_count) else 0.0
    rc = tp_count / (tp_count + fn_count) if (tp_count + fn_count) else 0.0
    print(' True Positive:', tp)
    print(' True Negative:', tn)
    print('False Positive:', fp)
    print('False Negative:', fn)

    # ROC AUC is undefined (roc_auc_score raises) when only one class is present.
    auc = roc_auc_score(y_valid_task, y_valid_pred_task) if has_both_classes else float('nan')
    print('Test ROC AUC ({}):'.format(task), auc)

    results_valid[task] = [auc, acc, bacc, tn, tp, pr, rc, f1]

    if has_both_classes:
        fpr, tpr, threshold = sk.metrics.roc_curve(y_valid_task, y_valid_pred_task)
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, 'b', label = 'AUC')
        ax.plot([0, 1], [0, 1], 'r--')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_ylabel('True Positive Rate')
        ax.set_xlabel('False Positive Rate')
        ax.set_title(task)
        ax.legend(loc = 'lower right')
        plt.show()
        plt.close(fig)  # release the figure; one is created per task

In [None]:
# Summary table of the validation metrics, rounded to three decimals (one row per task).
valid_columns = ['  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ']
print('Task'.ljust(10), '\t', *valid_columns)
for task_name, task_metrics in results_valid.items():
    print(task_name.ljust(10), '\t', np.round(task_metrics, 3))