In [1]:
from datetime import datetime
import random

import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch import optim
from torch.utils.data import DataLoader
from sklearn import metrics

import layers
import models
import os
import custom_loss
import time
# Force synchronous CUDA kernel launches so device-side errors are reported at the
# Python line that triggered them. This is a debugging aid that slows GPU training;
# drop it for full-speed runs. It only takes effect if set before CUDA is initialised.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Seed every RNG the notebook may touch (Python `random`, NumPy, and torch on CPU and
# all CUDA devices) so loader shuffling, weight init and any sampling are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

In [3]:
# Width of the per-atom feature vector built by data_preprocessing; it becomes
# n_atom_feats in the hyperparameter cell.
TOTAL_ATOM_FEATS

55

In [4]:
# Directory holding the ZhangDDI split CSVs. Set the ZHANGDDI_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_DIR = os.environ.get('ZHANGDDI_DATA_DIR', '/home/dwj/GTA-DDI/zhangddi-chchddi/data/ddi-data')

df_ddi_train = pd.read_csv(os.path.join(DATA_DIR, 'ZhangDDI_train.csv'))
# NOTE(review): these two names are cross-wired: `df_ddi_val` holds the *test* file and
# `df_ddi_test` holds the *valid* file. The triple-building cell crosses them back
# (val_tup comes from df_ddi_test, test_tup from df_ddi_val), so the final splits are
# right. Rename both cells together if you ever untangle this.
df_ddi_val = pd.read_csv(os.path.join(DATA_DIR, 'ZhangDDI_test.csv'))
df_ddi_test = pd.read_csv(os.path.join(DATA_DIR, 'ZhangDDI_valid.csv'))

In [5]:
def _to_triples(frame):
    '''Return the rows of a DDI frame as a list of (drugbank_id_1, drugbank_id_2, label) tuples.'''
    return list(zip(frame['drugbank_id_1'], frame['drugbank_id_2'], frame['label']))


train_tup = _to_triples(df_ddi_train)
# df_ddi_test was read from ZhangDDI_valid.csv and df_ddi_val from ZhangDDI_test.csv,
# so this crossing is what makes val_tup the validation split and test_tup the test split.
val_tup = _to_triples(df_ddi_test)
test_tup = _to_triples(df_ddi_val)

In [7]:
# Number of (drug_1, drug_2, label) training triples.
len(train_tup)

68383

In [8]:
# Number of validation triples (from ZhangDDI_valid.csv, loaded into df_ddi_test).
len(val_tup)

22794

In [9]:
# Number of test triples (from ZhangDDI_test.csv, loaded into df_ddi_val).
len(test_tup)

22795

In [10]:
# Share of all labelled pairs in each split, displayed as (train, test, val).
split_sizes = {'train': len(train_tup), 'val': len(val_tup), 'test': len(test_tup)}
total = sum(split_sizes.values())
split_sizes['train'] / total, split_sizes['test'] / total, split_sizes['val'] / total

(0.5999982451830274, 0.20000526445091776, 0.19999649036605482)

In [11]:
# ---- Hyperparameters -------------------------------------------------------
# Model shape
n_atom_feats = TOTAL_ATOM_FEATS   # input width per atom, fixed by the featurizer
n_atom_hid = 256
kge_dim = 384                     # NOTE(review): equals the 6 heads x 64 = 384 block width in the model printout; presumably must match
rel_total = 1                     # presumably the number of relation types (binary DDI task)

# Optimisation
lr = 1e-4
weight_decay = 5e-4
n_epochs = 300

# Data
batch_size = 1024
neg_samples = 1                   # forwarded to DrugDataset as neg_ent
data_size_ratio = 1               # forwarded to DrugDataset as ratio


In [12]:
# Training split. neg_samples is forwarded as DrugDataset's neg_ent argument
# (presumably negatives per positive; confirm in data_preprocessing).
train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)

In [13]:
# Evaluation splits, both built with disjoint_split=False.
# NOTE(review): unlike val_data, test_data is not given ratio=data_size_ratio, so it uses
# DrugDataset's default ratio. Confirm this asymmetry is intended.
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

In [14]:
# One-line summary of how many samples each dataset object exposes.
n_train, n_val, n_test = len(train_data), len(val_data), len(test_data)
print(f"Training with {n_train} samples, validating with {n_val}, and testing with {n_test}")

Training with 68383 samples, validating with 22794, and testing with 22795


In [15]:
# Validation/test batches are 3x the training batch size. Evaluation in train() runs
# under torch.no_grad(), which presumably leaves room for the larger batches.
eval_batch_size = batch_size * 3
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=eval_batch_size)
test_data_loader = DrugDataLoader(test_data, batch_size=eval_batch_size)

In [16]:
def do_compute(batch, device, training=True, net=None):
    '''Run one forward pass of the model over a batch.

    Args:
        batch: iterable whose elements each support `.to(device=...)` (tensors or graph
            batches). Presumably (batch_h, batch_t, batch_r); confirm against
            DrugDataLoader. The moved elements are passed to the model as one list.
        device: device every batch element is moved to.
        training: unused; kept so existing callers keep working.
        net: model to call. Defaults to the notebook-level global `model`.

    Returns:
        (p_score, probas_pred, ground_truth, rels):
            p_score      - raw model scores, still attached to the autograd graph;
            probas_pred  - sigmoid(p_score), detached and moved to CPU (torch tensor);
            ground_truth - `rels` copied to CPU as a NumPy array;
            rels         - the model's second output, used as the labels by train().
    '''
    if net is None:
        net = model  # fall back to the global created in the model-construction cell

    moved = [item.to(device=device) for item in batch]
    p_score, rels = net(moved)

    probas_pred = torch.sigmoid(p_score.detach()).cpu()
    ground_truth = rels.cpu().numpy()
    return p_score, probas_pred, ground_truth, rels


In [17]:
def do_compute_metrics(probas_pred, target, return_f1=False):
    '''Binary-classification metrics for predicted positive-class probabilities.

    Args:
        probas_pred: NumPy array of predicted probabilities, aligned with `target`.
        target: ground-truth 0/1 labels.
        return_f1: if True, also return the F1 score as a 4th value.

    Returns:
        (acc, auc_roc, auc_prc), or (acc, auc_roc, auc_prc, f1) when `return_f1` is True.
        Accuracy and F1 use a 0.5 decision threshold; auc_prc is the trapezoidal area
        under the precision-recall curve.

    Raises:
        ValueError: from sklearn's roc_auc_score when `target` contains only one class.
    '''
    pred = (probas_pred >= 0.5).astype(int)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)

    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)
    # NOTE: trapezoidal PR-AUC (metrics.auc) can be optimistic compared with sklearn's
    # average_precision_score; kept as-is so logged numbers stay comparable across runs.
    auc_prc = metrics.auc(recall, precision)

    if return_f1:
        # F1 is only computed on request; it used to be computed every call and discarded.
        return acc, auc_roc, auc_prc, metrics.f1_score(target, pred)
    return acc, auc_roc, auc_prc

In [18]:
def _run_train_epoch(model, data_loader, loss_fn, optimizer, device):
    '''One optimisation pass over `data_loader`; returns (mean loss, probabilities, labels).'''
    model.train()  # set once per epoch; nothing inside the loop changes the mode
    loss_sum, n_items = 0.0, 0
    probas, truths = [], []
    for batch in data_loader:
        p_score, probas_pred, ground_truth, rels = do_compute(batch, device)
        probas.append(probas_pred)
        truths.append(ground_truth)

        loss = loss_fn(p_score.float(), rels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assumes loss_fn returns a per-batch mean, so re-weight by the batch's item count.
        loss_sum += loss.item() * len(p_score)
        n_items += len(p_score)
    if n_items == 0:
        raise ValueError('training data loader yielded no samples')
    return loss_sum / n_items, np.concatenate(probas), np.concatenate(truths)


def _run_eval_epoch(model, data_loader, loss_fn, device):
    '''Forward-only pass in eval mode without autograd; returns (mean loss, probabilities, labels).'''
    model.eval()
    loss_sum, n_items = 0.0, 0
    probas, truths = [], []
    with torch.no_grad():
        for batch in data_loader:
            p_score, probas_pred, ground_truth, rels = do_compute(batch, device)
            probas.append(probas_pred)
            truths.append(ground_truth)
            loss = loss_fn(p_score.float(), rels.float())
            loss_sum += loss.item() * len(p_score)
            n_items += len(p_score)
    if n_items == 0:
        raise ValueError('validation data loader yielded no samples')
    return loss_sum / n_items, np.concatenate(probas), np.concatenate(truths)


def train(model, train_data_loader, val_data_loader, loss_fn,  optimizer, n_epochs, device, scheduler=None,
          log_file='training_log.txt'):
    '''Train for `n_epochs` epochs, evaluating on the validation loader after each one.

    Each epoch's loss, accuracy, ROC-AUC and PR-AUC for both splits is printed and
    appended to `log_file`. Losses are averaged over the items actually scored from the
    loaders passed in. (Previously they were divided by the sizes of the global
    train_data/val_data.)

    NOTE: the forward pass goes through do_compute(), which uses the notebook-level
    global `model`. The `model` argument here only switches train/eval mode, so it must
    be that same object.

    Args:
        model: the network being trained (see note above).
        train_data_loader: iterable of training batches accepted by do_compute().
        val_data_loader: iterable of validation batches accepted by do_compute().
        loss_fn: callable (scores, labels) -> scalar loss tensor.
        optimizer: torch optimizer over the model's parameters.
        n_epochs: number of epochs to run.
        device: device the batch elements are moved to.
        scheduler: optional LR scheduler, stepped once at the end of every epoch.
        log_file: path the two per-epoch summary lines are appended to.
    '''
    print('Starting training at', datetime.today())
    for i in range(1, n_epochs + 1):
        start = time.time()

        train_loss, train_probas_pred, train_ground_truth = _run_train_epoch(
            model, train_data_loader, loss_fn, optimizer, device)
        train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(train_probas_pred, train_ground_truth)

        val_loss, val_probas_pred, val_ground_truth = _run_eval_epoch(
            model, val_data_loader, loss_fn, device)
        val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(val_probas_pred, val_ground_truth)

        if scheduler is not None:
            print('scheduling')
            scheduler.step()

        # First line: timing, losses and accuracy; second line: ranking metrics.
        line1 = (f'Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},'
                 f' train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}')
        line2 = (f'\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}')
        with open(log_file, 'a') as f:
            for line in (line1, line2):
                print(line)
                f.write(line + '\n')

In [19]:
# Preferred GPU index. Falls back to the default CUDA device when this machine has
# fewer GPUs ('cuda:5' would otherwise fail with an invalid device ordinal), and to
# the CPU when CUDA is unavailable.
GPU_INDEX = 5
if torch.cuda.is_available():
    device = f'cuda:{GPU_INDEX}' if GPU_INDEX < torch.cuda.device_count() else 'cuda'
else:
    device = 'cpu'

# Four GAT blocks, each 6 heads x 64 features (384 wide, as the printout shows).
model = models.GTA_DDI(n_atom_feats, n_atom_hid, kge_dim, rel_total,
                       heads_out_feat_params=[64, 64, 64, 64], blocks_params=[6, 6, 6, 6])
# NB: `loss` is the loss *function* (handed to train() as loss_fn), not a loss value.
loss = custom_loss.SigmoidLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# LR multiplier is 0.96 ** epoch; train() steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** epoch)
model

GTA_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(384, affine=True, mode=graph)
  )
  (block0): GTA_DDI_Block(
    (conv): GATConv(55, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block1): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block2): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (block3): GTA_DDI_Block(
    (conv): GATConv(384, 64, heads=6)
    (readout): TopKPooling(384, ratio=0.5, multiplier=1.0)
  )
  (co_attention): CoAttentionLayer()
  (KGE): RESCAL(1, torch.Size([1, 147456]))
)

In [20]:
# Move the model's parameters to the selected device; the trailing `;` hides the module repr.
model.to(device=device);

In [21]:
# Full training run; per-epoch metrics are printed and appended to the training log.
train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn=loss,
    optimizer=optimizer,
    n_epochs=n_epochs,
    device=device,
    scheduler=scheduler,
)

Starting training at 2024-06-21 00:57:23.485165
scheduling
Epoch: 1 (115.4607s), train_loss: 0.6435, val_loss: 0.6049, train_acc: 0.6265, val_acc:0.6660
		train_roc: 0.6357, val_roc: 0.7221, train_auprc: 0.5287, val_auprc: 0.6314
scheduling
Epoch: 2 (126.3425s), train_loss: 0.5613, val_loss: 0.5193, train_acc: 0.7115, val_acc:0.7517
		train_roc: 0.7729, val_roc: 0.8205, train_auprc: 0.6901, val_auprc: 0.7627
scheduling
Epoch: 3 (125.9905s), train_loss: 0.4961, val_loss: 0.4712, train_acc: 0.7630, val_acc:0.7839
		train_roc: 0.8336, val_roc: 0.8566, train_auprc: 0.7749, val_auprc: 0.8087
scheduling
Epoch: 4 (126.3196s), train_loss: 0.4638, val_loss: 0.4467, train_acc: 0.7845, val_acc:0.7936
		train_roc: 0.8575, val_roc: 0.8764, train_auprc: 0.8079, val_auprc: 0.8366
scheduling
Epoch: 5 (126.3211s), train_loss: 0.4358, val_loss: 0.4290, train_acc: 0.8013, val_acc:0.8099
		train_roc: 0.8756, val_roc: 0.8807, train_auprc: 0.8336, val_auprc: 0.8452
scheduling
Epoch: 6 (126.6871s), train_los

: 