In [1]:
from datetime import datetime
import random

import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch import optim
from torch.utils.data import DataLoader
from sklearn import metrics

import layers
import models
import custom_loss
import time

In [2]:
from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

In [3]:
# Display the atom-feature width exported by data_preprocessing; it is reused
# below as n_atom_feats when the model is constructed.
TOTAL_ATOM_FEATS

55

In [4]:
# Read the three pre-split CSV tables (train / validation / test) in one pass.
# Each table is later indexed by the columns 'd1', 'd2' and 'type'.
df_ddi_train, df_ddi_val, df_ddi_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/ddi_training.csv',
        'data/ddi_validation.csv',
        'data/ddi_test.csv',
    )
)

In [5]:
# Turn every split into a plain list of (d1, d2, type) tuples, one per row,
# which is the form handed to DrugDataset further down.
train_tup = list(zip(df_ddi_train['d1'], df_ddi_train['d2'], df_ddi_train['type']))
val_tup = list(zip(df_ddi_val['d1'], df_ddi_val['d2'], df_ddi_val['type']))
test_tup = list(zip(df_ddi_test['d1'], df_ddi_test['d2'], df_ddi_test['type']))

In [6]:
# Number of (d1, d2, type) triples in the training split.
len(train_tup)

115185

In [7]:
# Number of (d1, d2, type) triples in the validation split.
len(val_tup)

38348

In [8]:
# Number of (d1, d2, type) triples in the test split.
len(test_tup)

38337

In [9]:
# Share of all triples held by each split, displayed in (train, test, val) order.
n_train, n_test, n_val = len(train_tup), len(test_tup), len(val_tup)
total = n_val + n_train + n_test
(n_train / total, n_test / total, n_val / total)

(0.6003283473184969, 0.19980716109866056, 0.19986449158284256)

In [10]:
# Hyperparameters
# Every tunable value of the run lives in this cell; the comments say where each
# one is consumed further down in the notebook.
n_atom_feats = TOTAL_ATOM_FEATS  # input feature width, first argument of models.SSI_DDI
n_atom_hid = 64  # second argument of models.SSI_DDI
rel_total = 86  # fourth argument of models.SSI_DDI; presumably the number of distinct 'type' values — verify against the data
lr = 1e-2  # initial Adam learning rate (decayed each epoch by the scheduler)
weight_decay = 5e-4  # Adam weight_decay
n_epochs = 300  # number of epochs passed to train()
neg_samples = 1  # passed as neg_ent to the training DrugDataset
batch_size = 1024  # training batch size; validation/test loaders use 3x this value
data_size_ratio = 1  # passed as ratio to the train and validation DrugDataset
kge_dim = 64  # third argument of models.SSI_DDI


In [11]:
# Training dataset built from the training triples; ratio and neg_ent come from
# the hyperparameter cell (data_size_ratio and neg_samples).
train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)

In [12]:
# Evaluation datasets. Both pass disjoint_split=False and leave neg_ent at the
# DrugDataset default.
# NOTE(review): test_data does not pass ratio, so it uses DrugDataset's default
# rather than data_size_ratio — confirm this is intended.
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

In [13]:
# Sanity check: report how many samples each dataset object exposes via len().
print(f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}")

Training with 115185 samples, validating with 38348, and testing with 38337


In [14]:
# Batch iterators. Only the training loader shuffles; the validation and test
# loaders use a batch three times larger than the training batch.
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size *3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size *3)

In [15]:
def do_compute(batch, device, training=True):
    '''
        Score one batch of positive and negative triples.

        *batch: (pos_tri, neg_tri)
        *pos/neg_tri: (batch_h, batch_t, batch_r)
        *device: device every tensor of the batch is moved to before scoring
        *training: accepted but not used by this function

        The network is looked up through the notebook-level name `model`;
        it is not passed in as an argument.

        Returns (p_score, n_score, probas_pred, ground_truth): the raw model
        scores for positives and negatives, followed by the sigmoid
        probabilities and the matching 1/0 labels, each concatenated with
        positives first.
    '''
    def _score(triple):
        # Move each part of the triple to the target device, then run the model.
        on_device = [part.to(device=device) for part in triple]
        return model(on_device)

    positives, negatives = batch
    p_score = _score(positives)
    n_score = _score(negatives)

    # Probabilities are taken from detached scores, so gradients only flow
    # through the raw p_score / n_score that are returned to the caller.
    probabilities = [
        torch.sigmoid(p_score.detach()).cpu(),
        torch.sigmoid(n_score.detach()).cpu(),
    ]
    labels = [np.ones(len(p_score)), np.zeros(len(n_score))]

    probas_pred = np.concatenate(probabilities)
    ground_truth = np.concatenate(labels)

    return p_score, n_score, probas_pred, ground_truth


In [16]:
def do_compute_metrics(probas_pred, target):
    """Score predicted probabilities against binary ground-truth labels.

    Args:
        probas_pred: array-like of predicted probabilities; values >= 0.5 are
            counted as positive predictions for the accuracy.
        target: array-like of 0/1 labels aligned with ``probas_pred``.

    Returns:
        Tuple ``(acc, auc_roc, auc_prc)``: accuracy at the 0.5 threshold, area
        under the ROC curve, and area under the precision-recall curve.
    """
    probas_pred = np.asarray(probas_pred)

    # The `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `int` selects the same dtype and works on every version.
    pred = (probas_pred >= 0.5).astype(int)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)

    # The F1 score used to be computed here but was never returned, so the
    # redundant call has been dropped.
    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(recall, precision)

    return acc, auc_roc, auc_prc

In [17]:
def train(model, train_data_loader, val_data_loader, loss_fn,  optimizer, n_epochs, device, scheduler=None):
    """Run the train/validate loop and print metrics after every epoch.

    Args:
        model: network switched between train() and eval() mode here. The
            forward passes themselves go through do_compute, which reads the
            notebook-level name ``model``, so both must refer to the same object.
        train_data_loader: iterable of (pos_tri, neg_tri) batches used for the
            optimisation steps.
        val_data_loader: iterable of (pos_tri, neg_tri) batches that are only
            evaluated (under torch.no_grad()).
        loss_fn: callable mapping (p_score, n_score) to (loss, loss_p, loss_n);
            only the first element is used.
        optimizer: optimizer stepped once per training batch.
        n_epochs: number of epochs to run.
        device: device forwarded to do_compute.
        scheduler: optional learning-rate scheduler, stepped once per epoch.

    Returns:
        None. Per-epoch loss, accuracy, ROC-AUC and PR-AUC are printed.
    """
    print('Starting training at', datetime.today())
    for i in range(1, n_epochs+1):
        start = time.time()
        train_loss = 0
        val_loss = 0
        # Number of positive samples actually yielded by each loader. The loss
        # is averaged over these counts instead of over the notebook globals
        # train_data / val_data, so the function stays correct for any loader.
        n_train_seen = 0
        n_val_seen = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []

        # Switching mode once per phase is enough; nothing inside the loop changes it.
        model.train()
        for batch in train_data_loader:
            p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, _, _ = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so the epoch figure is a per-sample mean.
            n_batch = len(p_score)
            train_loss += loss.item() * n_batch
            n_train_seen += n_batch
        train_loss /= max(n_train_seen, 1)

        model.eval()
        with torch.no_grad():
            train_probas_pred = np.concatenate(train_probas_pred)
            train_ground_truth = np.concatenate(train_ground_truth)

            train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(train_probas_pred, train_ground_truth)

            for batch in val_data_loader:
                p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, _, _ = loss_fn(p_score, n_score)
                n_batch = len(p_score)
                val_loss += loss.item() * n_batch
                n_val_seen += n_batch

            val_loss /= max(n_val_seen, 1)
            val_probas_pred = np.concatenate(val_probas_pred)
            val_ground_truth = np.concatenate(val_ground_truth)
            val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(val_probas_pred, val_ground_truth)

        if scheduler:
            print('scheduling')
            scheduler.step()


        print(f'Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},'
        f' train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}')
        print(f'\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}')

In [18]:
# Build everything train() needs: device, network, loss, optimizer and LR schedule.
# NOTE(review): no random seed is set anywhere in the visible cells, so weight
# initialisation and batch shuffling will differ between runs — consider seeding.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Four blocks, each configured with 32 output features and 2 heads (see the repr below).
model = models.SSI_DDI(n_atom_feats, n_atom_hid, kge_dim, rel_total, heads_out_feat_params=[32, 32, 32, 32], blocks_params=[2, 2, 2, 2])
loss = custom_loss.SigmoidLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Learning rate is lr * 0.96**epoch; train() steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
# Bare expression: display the module structure as the cell output.
model

SSI_DDI(
  (initial_norm): LayerNorm(55)
  (net_norms): ModuleList(
    (0): LayerNorm(64)
    (1): LayerNorm(64)
    (2): LayerNorm(64)
    (3): LayerNorm(64)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)
  )
  (co_attention): CoAttentionLayer()
  (KGE): RESCAL(86, torch.Size([86, 4096]))
)

In [19]:
# Move the model to the selected device; the trailing ';' suppresses the repr output.
model.to(device=device);

In [20]:
# Launch the full run: n_epochs epochs, with the LR scheduler stepped after each epoch.
# Only the validation loader is evaluated here; test_data_loader is not used by train().
train(model, train_data_loader, val_data_loader, loss, optimizer, n_epochs, device, scheduler)

ss: 0.4736, val_loss: 0.4750, train_acc: 0.7712, val_acc:0.7714
		train_roc: 0.8506, val_roc: 0.8497, train_auprc: 0.8303, val_auprc: 0.8305
scheduling
Epoch: 12 (56.1362s), train_loss: 0.4580, val_loss: 0.4600, train_acc: 0.7818, val_acc:0.7807
		train_roc: 0.8612, val_roc: 0.8615, train_auprc: 0.8416, val_auprc: 0.8438
scheduling
Epoch: 13 (54.5477s), train_loss: 0.4470, val_loss: 0.4521, train_acc: 0.7888, val_acc:0.7863
		train_roc: 0.8685, val_roc: 0.8662, train_auprc: 0.8493, val_auprc: 0.8471
scheduling
Epoch: 14 (53.8652s), train_loss: 0.4397, val_loss: 0.4519, train_acc: 0.7948, val_acc:0.7853
		train_roc: 0.8734, val_roc: 0.8663, train_auprc: 0.8537, val_auprc: 0.8445
scheduling
Epoch: 15 (59.7325s), train_loss: 0.4283, val_loss: 0.4822, train_acc: 0.8018, val_acc:0.7563
		train_roc: 0.8806, val_roc: 0.8657, train_auprc: 0.8619, val_auprc: 0.8499
scheduling
Epoch: 16 (56.8572s), train_loss: 0.4327, val_loss: 0.4286, train_acc: 0.7996, val_acc:0.8034
		train_roc: 0.8782, val_r