In [1]:
from datetime import datetime
import random

import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch import optim
from torch.utils.data import DataLoader
from sklearn import metrics

import layers
import models
import custom_loss
import time

In [2]:
from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [3]:
TOTAL_ATOM_FEATS

55

In [4]:
df_ddi_train = pd.read_csv('data/ddi_training.csv')
df_ddi_val = pd.read_csv('data/ddi_test.csv')
df_ddi_test = pd.read_csv('data/ddi_validation.csv')

In [5]:
train_tup = [(h, t, r) for h, t, r in zip(df_ddi_train['d1'], df_ddi_train['d2'], df_ddi_train['type'])]
val_tup = [(h, t, r) for h, t, r in zip(df_ddi_val['d1'], df_ddi_val['d2'], df_ddi_val['type'])]
test_tup = [(h, t, r) for h, t, r in zip(df_ddi_test['d1'], df_ddi_test['d2'], df_ddi_test['type'])]

In [6]:
len(train_tup)

115185

In [7]:
len(val_tup)

38337

In [8]:
len(test_tup)

38348

In [9]:
total = len(val_tup) + len(train_tup) + len(test_tup)
len(train_tup) / total, len(test_tup)/total, len(val_tup)/total

(0.6003283473184969, 0.19986449158284256, 0.19980716109866056)

In [10]:
# Hyperparameters
n_atom_feats = TOTAL_ATOM_FEATS
n_atom_hid = 256
rel_total = 86
lr = 1e-2
weight_decay = 5e-4
n_epochs = 300
neg_samples = 1
batch_size = 1024
data_size_ratio = 1
kge_dim = 384


In [11]:
train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)

In [12]:
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

In [13]:
print(f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}")

Training with 115185 samples, validating with 38337, and testing with 38348


In [14]:
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size *3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size *3)

In [15]:
def do_compute(batch, device, training=True):
        '''
            *batch: (pos_tri, neg_tri)
            *pos/neg_tri: (batch_h, batch_t, batch_r)
        '''
        probas_pred, ground_truth = [], []
        pos_tri, neg_tri = batch
        
        pos_tri = [tensor.to(device=device) for tensor in pos_tri]
        p_score = model(pos_tri)
        probas_pred.append(torch.sigmoid(p_score.detach()).cpu())
        ground_truth.append(np.ones(len(p_score)))

        neg_tri = [tensor.to(device=device) for tensor in neg_tri]
        n_score = model(neg_tri)
        probas_pred.append(torch.sigmoid(n_score.detach()).cpu())
        ground_truth.append(np.zeros(len(n_score)))

        probas_pred = np.concatenate(probas_pred)
        ground_truth = np.concatenate(ground_truth)

        return p_score, n_score, probas_pred, ground_truth


In [16]:
def do_compute_metrics(probas_pred, target):

    pred = (probas_pred >= 0.5).astype(int)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)

    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(r, p)

    return acc, auc_roc, auc_prc

In [17]:
def train(model, train_data_loader, val_data_loader, loss_fn,  optimizer, n_epochs, device, scheduler=None):
    print('Starting training at', datetime.today())
    for i in range(1, n_epochs+1):
        start = time.time()
        train_loss = 0
        train_loss_pos = 0
        train_loss_neg = 0
        val_loss = 0
        val_loss_pos = 0
        val_loss_neg = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []

        for batch in train_data_loader:
            model.train()
            p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, loss_p, loss_n = loss_fn(p_score, n_score)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(p_score)
        train_loss /= len(train_data)

        with torch.no_grad():
            train_probas_pred = np.concatenate(train_probas_pred)
            train_ground_truth = np.concatenate(train_ground_truth)

            train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(train_probas_pred, train_ground_truth)

            for batch in val_data_loader:
                model.eval()
                p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, loss_p, loss_n = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)            

            val_loss /= len(val_data)
            val_probas_pred = np.concatenate(val_probas_pred)
            val_ground_truth = np.concatenate(val_ground_truth)
            val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(val_probas_pred, val_ground_truth)
               
        if scheduler:
            print('scheduling')
            scheduler.step()


        output_file = 'training_log.txt'

        with open(output_file, 'a') as f:
            # 写入第一行
            line1 = (f'Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},'
                    f' train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}')
            print(line1)
            f.write(line1 + '\n')
            
            # 写入第二行
            line2 = (f'\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}')
            print(line2)
            f.write(line2 + '\n')

In [18]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = models.GTA_DDI(n_atom_feats, n_atom_hid, kge_dim, rel_total, heads_out_feat_params=[64, 64, 64, 64], blocks_params=[6, 6,6, 6])
loss = custom_loss.SigmoidLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
model

GTA_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(384, affine=True, mode=graph)
  )
  (block0): GTA_DDI_Block(
    (conv): GATConv(55, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block1): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block2): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block3): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (co_attention): CoAttentionLayer()
  (KGE): RESCAL(86, torch.Size([86, 147456]))
)

In [19]:
model.to(device=device);

In [20]:
train(model, train_data_loader, val_data_loader, loss, optimizer, n_epochs, device, scheduler)

Starting training at 2024-06-20 13:35:18.412732
scheduling
Epoch: 1 (49.5570s), train_loss: 0.6736, val_loss: 0.6663, train_acc: 0.5722, val_acc:0.5825
		train_roc: 0.6062, val_roc: 0.6204, train_auprc: 0.5985, val_auprc: 0.6124
scheduling
Epoch: 2 (49.0150s), train_loss: 0.6513, val_loss: 0.6367, train_acc: 0.6129, val_acc:0.6366
		train_roc: 0.6589, val_roc: 0.6889, train_auprc: 0.6471, val_auprc: 0.6810
scheduling
Epoch: 3 (48.9937s), train_loss: 0.6199, val_loss: 0.5878, train_acc: 0.6504, val_acc:0.6820
		train_roc: 0.7104, val_roc: 0.7591, train_auprc: 0.6966, val_auprc: 0.7452
scheduling
Epoch: 4 (48.9288s), train_loss: 0.5743, val_loss: 0.6187, train_acc: 0.6951, val_acc:0.6554
		train_roc: 0.7661, val_roc: 0.7176, train_auprc: 0.7517, val_auprc: 0.7046
scheduling
Epoch: 5 (48.9575s), train_loss: 0.5613, val_loss: 0.5177, train_acc: 0.7054, val_acc:0.7386
		train_roc: 0.7800, val_roc: 0.8413, train_auprc: 0.7670, val_auprc: 0.8359
scheduling
Epoch: 6 (49.0997s), train_loss: 0.5