In [1]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils import data
from torch import nn 
import copy

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from time import time
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve, confusion_matrix, precision_score, recall_score, auc
from sklearn.model_selection import KFold
torch.manual_seed(2)    # reproducible torch:2 np:3
np.random.seed(3)

from config import BIN_config_DBPE
from models import BIN_Interaction_Flat
from stream import BIN_Data_Encoder

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [2]:
def test(data_generator, model):
    y_pred = []
    y_label = []
    model.eval()
    loss_accumulate = 0.0
    count = 0.0
    for i, (d, p, d_mask, p_mask, label) in enumerate(data_generator):
        score, _ = model(d.long().cuda(), p.long().cuda(), d_mask.long().cuda(), p_mask.long().cuda())
        
        m = torch.nn.Sigmoid()
        logits = torch.squeeze(m(score))
        loss_fct = torch.nn.BCELoss()            
        
        label = Variable(torch.from_numpy(np.array(label)).float()).cuda()

        loss = loss_fct(logits, label)
        
        loss_accumulate += loss
        count += 1
        
        logits = logits.detach().cpu().numpy()
        
        label_ids = label.to('cpu').numpy()
        y_label = y_label + label_ids.flatten().tolist()
        y_pred = y_pred + logits.flatten().tolist()
        
    loss = loss_accumulate/count
    
    fpr, tpr, thresholds = roc_curve(y_label, y_pred)

    precision = tpr / (tpr + fpr)

    f1 = 2 * precision * tpr / (tpr + precision + 0.00001)

    thred_optim = thresholds[5:][np.argmax(f1[5:])]

    print("optimal threshold: " + str(thred_optim))

    y_pred_s = [1 if i else 0 for i in (y_pred >= thred_optim)]

    auc_k = auc(fpr, tpr)
    print("AUROC:" + str(auc_k))
    print("AUPRC: "+ str(average_precision_score(y_label, y_pred)))

    cm1 = confusion_matrix(y_label, y_pred_s)
    print('Confusion Matrix : \n', cm1)
    print('Recall : ', recall_score(y_label, y_pred_s))
    print('Precision : ', precision_score(y_label, y_pred_s))

    total1=sum(sum(cm1))
    #####from confusion matrix calculate accuracy
    accuracy1=(cm1[0,0]+cm1[1,1])/total1
    print ('Accuracy : ', accuracy1)

    sensitivity1 = cm1[0,0]/(cm1[0,0]+cm1[0,1])
    print('Sensitivity : ', sensitivity1 )

    specificity1 = cm1[1,1]/(cm1[1,0]+cm1[1,1])
    print('Specificity : ', specificity1)

    outputs = np.asarray([1 if i else 0 for i in (np.asarray(y_pred) >= 0.5)])
    return roc_auc_score(y_label, y_pred), average_precision_score(y_label, y_pred), f1_score(y_label, outputs), y_pred, loss.item()


def main(fold_n, lr):
    config = BIN_config_DBPE()
    
    lr = lr
    BATCH_SIZE = config['batch_size']
    train_epoch = 100
    
    loss_history = []
    
    model = BIN_Interaction_Flat(**config)
    
    model = model.cuda()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, dim = 0)
            
    opt = torch.optim.Adam(model.parameters(), lr = lr)
    #opt = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
    
    print('--- Data Preparation ---')
    
    params = {'batch_size': BATCH_SIZE,
              'shuffle': True,
              'num_workers': 6, 
              'drop_last': True}

    dataFolder = './dataset/BindingDB'
    df_train = pd.read_csv(dataFolder + '/train.csv')
    df_val = pd.read_csv(dataFolder + '/val.csv')
    df_test = pd.read_csv(dataFolder + '/test.csv')
    
    training_set = BIN_Data_Encoder(df_train.index.values, df_train.Label.values, df_train)
    training_generator = data.DataLoader(training_set, **params)

    validation_set = BIN_Data_Encoder(df_val.index.values, df_val.Label.values, df_val)
    validation_generator = data.DataLoader(validation_set, **params)
    
    testing_set = BIN_Data_Encoder(df_test.index.values, df_test.Label.values, df_test)
    testing_generator = data.DataLoader(testing_set, **params)
    
    # early stopping
    max_auc = 0
    model_max = copy.deepcopy(model)
    
    print('--- Go for Training ---')
    torch.backends.cudnn.benchmark = True
    for epo in range(train_epoch):
        model.train()
        for i, (d, p, d_mask, p_mask, label) in enumerate(training_generator):
            score, _ = model(d.long().cuda(), p.long().cuda(), d_mask.long().cuda(), p_mask.long().cuda())

            label = Variable(torch.from_numpy(np.array(label)).float()).cuda()
            
            loss_fct = torch.nn.BCELoss()
            m = torch.nn.Sigmoid()
            n = torch.squeeze(m(score))
            
            loss = loss_fct(n, label)
            loss_history.append(loss)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            if (i % 100 == 0):
                print('Training at Epoch ' + str(epo + 1) + ' iteration ' + str(i) + ' with loss ' + str(loss.cpu().detach().numpy()))
            
        # every epoch test
        with torch.set_grad_enabled(False):
            auc, auprc, f1, logits, loss = test(validation_generator, model)
            if auc > max_auc:
                model_max = copy.deepcopy(model)
                max_auc = auc
            
            print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: '+ str(auc) + ' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1))
    
    print('--- Go for Testing ---')
    try:
        with torch.set_grad_enabled(False):
            auc, auprc, f1, logits, loss = test(testing_generator, model_max)
            print('Testing AUROC: ' + str(auc) + ' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1) + ' , Test loss: '+str(loss))
    except:
        print('testing failed')
    return model_max, loss_history

In [3]:
# fold 1
#biosnap interaction times 1e-6, flat, batch size 64, len 205, channel 3, epoch 50
s = time()
model_max, loss_history = main(1, 5e-6)
e = time()
print(e-s)
lh = list(filter(lambda x: x < 1, loss_history))
plt.plot(lh)

Let's use 8 GPUs!
--- Data Preparation ---
--- Go for Training ---


    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


Training at Epoch 1 iteration 0 with loss 0.67082864
Training at Epoch 1 iteration 100 with loss 0.69454414
Training at Epoch 1 iteration 200 with loss 0.68584555
Training at Epoch 1 iteration 300 with loss 0.65992177
Training at Epoch 1 iteration 400 with loss 0.7050383
Training at Epoch 1 iteration 500 with loss 0.68235034
Training at Epoch 1 iteration 600 with loss 0.66636235
Training at Epoch 1 iteration 700 with loss 0.7001598
optimal threshold: 0.07315666973590851
AUROC:0.5709169097763743
AUPRC: 0.17622596167517443
Confusion Matrix : 
 [[   0 5714]
 [   0  926]]
Recall :  1.0
Precision :  0.1394578313253012
Accuracy :  0.1394578313253012
Sensitivity :  0.0
Specificity :  1.0
Validation at Epoch 1 , AUROC: 0.5709169097763743 , AUPRC: 0.17622596167517443 , F1: 0.2542313117066291




Training at Epoch 2 iteration 0 with loss 0.7192378
Training at Epoch 2 iteration 100 with loss 0.7564486
Training at Epoch 2 iteration 200 with loss 0.68028253
Training at Epoch 2 iteration 300 with loss 0.6746302
Training at Epoch 2 iteration 400 with loss 0.6865299
Training at Epoch 2 iteration 500 with loss 0.71370286
Training at Epoch 2 iteration 600 with loss 0.68855023
Training at Epoch 2 iteration 700 with loss 0.7352923
optimal threshold: 3.7331603380152956e-05
AUROC:0.5987548297501268
AUPRC: 0.19717487997919692
Confusion Matrix : 
 [[   0 5714]
 [   0  926]]
Recall :  1.0
Precision :  0.1394578313253012
Accuracy :  0.1394578313253012
Sensitivity :  0.0
Specificity :  1.0
Validation at Epoch 2 , AUROC: 0.5987548297501268 , AUPRC: 0.19717487997919692 , F1: 0.2620274914089347




Training at Epoch 3 iteration 0 with loss 0.6567126
Training at Epoch 3 iteration 100 with loss 0.67624635
Training at Epoch 3 iteration 200 with loss 0.71628416
Training at Epoch 3 iteration 300 with loss 0.6761807
Training at Epoch 3 iteration 400 with loss 0.6982836
Training at Epoch 3 iteration 500 with loss 0.71921015
Training at Epoch 3 iteration 600 with loss 0.7202301
Training at Epoch 3 iteration 700 with loss 0.6918887
optimal threshold: 3.6369015438131314e-11
AUROC:0.6178995047348437
AUPRC: 0.21095644156146942
Confusion Matrix : 
 [[   1 5712]
 [   0  927]]
Recall :  1.0
Precision :  0.1396294622684139
Accuracy :  0.13975903614457832
Sensitivity :  0.00017503938386136881
Specificity :  1.0
Validation at Epoch 3 , AUROC: 0.6178995047348437 , AUPRC: 0.21095644156146942 , F1: 0.27842333105490996




Training at Epoch 4 iteration 0 with loss 0.7050623
Training at Epoch 4 iteration 100 with loss 0.6907485
Training at Epoch 4 iteration 200 with loss 0.6963889
Training at Epoch 4 iteration 300 with loss 0.6957656
Training at Epoch 4 iteration 400 with loss 0.6684682
Training at Epoch 4 iteration 500 with loss 0.68866754
Training at Epoch 4 iteration 600 with loss 0.72337866
Training at Epoch 4 iteration 700 with loss 0.72578484
optimal threshold: 0.4688915014266968
AUROC:0.620866959609087
AUPRC: 0.19764261937762512
Confusion Matrix : 
 [[1245 4471]
 [  83  841]]
Recall :  0.9101731601731602
Precision :  0.15832078313253012
Accuracy :  0.3141566265060241
Sensitivity :  0.21780965710286915
Specificity :  0.9101731601731602
Validation at Epoch 4 , AUROC: 0.620866959609087 , AUPRC: 0.19764261937762512 , F1: 0.28085310328300983




Training at Epoch 5 iteration 0 with loss 0.70766497
Training at Epoch 5 iteration 100 with loss 0.65859604
Training at Epoch 5 iteration 200 with loss 0.65823483
Training at Epoch 5 iteration 300 with loss 0.69431573
Training at Epoch 5 iteration 400 with loss 0.70292634
Training at Epoch 5 iteration 500 with loss 0.71530986
Training at Epoch 5 iteration 600 with loss 0.6983434
Training at Epoch 5 iteration 700 with loss 0.6781969
optimal threshold: 0.4660532474517822
AUROC:0.6296021243398967
AUPRC: 0.22324799154656727
Confusion Matrix : 
 [[ 592 5121]
 [  46  881]]
Recall :  0.9503775620280475
Precision :  0.14678440519826724
Accuracy :  0.22183734939759037
Sensitivity :  0.10362331524593034
Specificity :  0.9503775620280475
Validation at Epoch 5 , AUROC: 0.6296021243398967 , AUPRC: 0.22324799154656727 , F1: 0.2741742936729009




Training at Epoch 6 iteration 0 with loss 0.7039902
Training at Epoch 6 iteration 100 with loss 0.7233455
Training at Epoch 6 iteration 200 with loss 0.73803186
Training at Epoch 6 iteration 300 with loss 0.67067516
Training at Epoch 6 iteration 400 with loss 0.6885591
Training at Epoch 6 iteration 500 with loss 0.7188779
Training at Epoch 6 iteration 600 with loss 0.66856825
Training at Epoch 6 iteration 700 with loss 0.69552326
optimal threshold: 0.43776440620422363
AUROC:0.6239582843572381
AUPRC: 0.206099660428421
Confusion Matrix : 
 [[ 275 5438]
 [  14  913]]
Recall :  0.9848975188781014
Precision :  0.1437568886789482
Accuracy :  0.1789156626506024
Sensitivity :  0.048135830561876425
Specificity :  0.9848975188781014
Validation at Epoch 6 , AUROC: 0.6239582843572381 , AUPRC: 0.206099660428421 , F1: 0.28043282236248873




Training at Epoch 7 iteration 0 with loss 0.6819896
Training at Epoch 7 iteration 100 with loss 0.6983069
Training at Epoch 7 iteration 200 with loss 0.7219417
Training at Epoch 7 iteration 300 with loss 0.69254005
Training at Epoch 7 iteration 400 with loss 0.7024156
Training at Epoch 7 iteration 500 with loss 0.71235186
Training at Epoch 7 iteration 600 with loss 0.6606362
Training at Epoch 7 iteration 700 with loss 0.6493866
optimal threshold: 0.4446001946926117
AUROC:0.6171993471993982
AUPRC: 0.2271743894482957
Confusion Matrix : 
 [[ 321 5392]
 [  24  903]]
Recall :  0.9741100323624595
Precision :  0.14344718030182685
Accuracy :  0.18433734939759036
Sensitivity :  0.05618764221949939
Specificity :  0.9741100323624595
Validation at Epoch 7 , AUROC: 0.6171993471993982 , AUPRC: 0.2271743894482957 , F1: 0.2688836104513064




Training at Epoch 8 iteration 0 with loss 0.69071335
Training at Epoch 8 iteration 100 with loss 0.6409137
Training at Epoch 8 iteration 200 with loss 0.6385876
Training at Epoch 8 iteration 300 with loss 0.6805055
Training at Epoch 8 iteration 400 with loss 0.700977
Training at Epoch 8 iteration 500 with loss 0.6853907
Training at Epoch 8 iteration 600 with loss 0.676106
Training at Epoch 8 iteration 700 with loss 0.7196401
optimal threshold: 0.41210272908210754
AUROC:0.6235056218253677
AUPRC: 0.2235906071800078
Confusion Matrix : 
 [[  25 5689]
 [   0  926]]
Recall :  1.0
Precision :  0.1399848828420257
Accuracy :  0.14322289156626505
Sensitivity :  0.004375218760938047
Specificity :  1.0
Validation at Epoch 8 , AUROC: 0.6235056218253677 , AUPRC: 0.2235906071800078 , F1: 0.27739188886399285




Training at Epoch 9 iteration 0 with loss 0.65416056
Training at Epoch 9 iteration 100 with loss 0.71033823
Training at Epoch 9 iteration 200 with loss 0.7441087


KeyboardInterrupt: 