In [1]:
import torch
import numpy as np

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.nn.functional as F

from model import CombNetRW
from dataset import CombinationDatasetRW
from MSI.load_msi_data import LoadData

from train_rw import EarlyStopping, train_contrastive, train_ce, evaluate, seed_everything

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, average_precision_score

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instantiate the project's MSI data loader (LoadData from MSI.load_msi_data).
dataloader = LoadData()

# Fetch the pair of drug lookup tables returned by get_dict(type='drug').
# Judging by the variable names these map id -> name and name -> id — TODO confirm.
drug_id2name, drug_name2id = dataloader.get_dict(type='drug')

In [3]:
# Seed once for the whole notebook via train_rw.seed_everything.
# NOTE(review): which RNGs it seeds is not visible from this notebook — confirm in train_rw.
seed_everything(42)

In [4]:
# Label tables for the two case studies: DC_combined pairs and TWOSIDES pairs.
dc_df = pd.read_csv('data/labels/DC_combined_msi.tsv', sep='\t')
ddi_df = pd.read_csv('data/labels/TWOSIDES_msi.tsv', sep='\t')

# DrugBank ID -> drug name lookup, used to make the case-study tables readable.
db_df = pd.read_csv('database/drugbank_all_drug_links.csv')
db_id2name = dict(zip(db_df['DrugBank ID'], db_df['Name']))

# Hand-picked pairs appended after the random sample of each case study.
dc_more = pd.DataFrame(
    [
        ['DB01076', 'DB01039'],  # atorvastatin & fenofibrate
        ['DB01095', 'DB01039'],  # fluvastatin & fenofibrate
        ['DB01098', 'DB01039'],  # rosuvastatin & fenofibrate
        ['DB01126', 'DB00706'],  # dutasteride & tamsulosin
        ['DB00654', 'DB00373'],  # latanoprost & timolol
        ['DB00235', 'DB00187'],  # milrinone & esmolol
    ],
    columns=['drug_1', 'drug_2'],
)

ddi_more = pd.DataFrame(
    [
        ['DB00252', 'DB00904'],  # phenytoin & ondansetron
        ['DB01026', 'DB00641'],  # ketoconazole & simvastatin
        ['DB00482', 'DB00575'],  # celecoxib & clonidine
        ['DB00745', 'DB00476'],  # modafinil & duloxetine
        ['DB00448', 'DB01033'],  # lansoprazole & mercaptopurine
        ['DB00277', 'DB00983'],  # theophylline & formoterol
    ],
    columns=['drug_1', 'drug_2'],
)

# Randomly sample 400 drug pairs per table (fixed random_state), append the
# hand-picked pairs, renumber the rows 0..N-1 and attach readable drug names.
case_dc_df = pd.concat(
    [dc_df.sample(400, random_state=42), dc_more], axis=0, ignore_index=True
).assign(
    drug1_name=lambda frame: frame['drug_1'].map(db_id2name),
    drug2_name=lambda frame: frame['drug_2'].map(db_id2name),
)
case_ddi_df = pd.concat(
    [ddi_df.sample(400, random_state=42), ddi_more], axis=0, ignore_index=True
).assign(
    drug1_name=lambda frame: frame['drug_1'].map(db_id2name),
    drug2_name=lambda frame: frame['drug_2'].map(db_id2name),
)

In [5]:
# One unordered {drug_1, drug_2} set per case-study row, in row order.
case_dc_pair = [
    {first_drug, second_drug}
    for first_drug, second_drug in zip(case_dc_df['drug_1'], case_dc_df['drug_2'])
]
case_ddi_pair = [
    {first_drug, second_drug}
    for first_drug, second_drug in zip(case_ddi_df['drug_1'], case_ddi_df['drug_2'])
]

In [6]:
def get_dc_score(neg_dataset, exclude_list, embeddings, seed):
    """Train a CombNetRW model on DC_combined and score the given drug pairs.

    The function runs two training stages and then scores ``exclude_list``:

    1. Contrastive pre-training with ``train_contrastive``; early stopping
       monitors the *training* loss and the checkpoint path is
       ``ckpt/casestudy_DREAMwalk_prod_fc_{seed}_contra.pt``.
    2. A fresh CombNetRW is initialised from that checkpoint and trained with
       ``BCEWithLogitsLoss``; early stopping monitors the loss on the
       concatenated valid+test split and the checkpoint path is
       ``ckpt/casestudy_DREAMwalk_prod_fc_{seed}.pt``.
    3. The stage-2 checkpoint is reloaded and every pair in ``exclude_list``
       is scored with ``sigmoid(model(...))``.

    Args:
        neg_dataset: Forwarded to ``CombinationDatasetRW(neg_dataset=...)``;
            this notebook passes ``'random'`` or ``'TWOSIDES'``.
        exclude_list: Iterable of drug pairs (this notebook passes sets of two
            DrugBank IDs). Forwarded to the dataset as ``exclude_list``
            (presumably to keep these pairs out of training — TODO confirm in
            ``dataset.py``) and used as the list of pairs to score.
        embeddings: Mapping from drug ID to its embedding vector.
        seed: Only used to build the checkpoint file names.

    Returns:
        NumPy array of sigmoid scores, one entry per pair, in the order of
        ``exclude_list``.

    Raises:
        ValueError: If an entry of ``exclude_list`` does not contain exactly
            two distinct drugs.
    """
    # NOTE(review): the dataset seed is hard-coded to 42 and does not follow `seed`.
    # Left unchanged because it selects the cached processed dataset — confirm intent.
    dataset = CombinationDatasetRW(database='DC_combined', embeddingf='DREAMwalk', neg_ratio=1, neg_dataset=neg_dataset, seed=42, exclude_list=exclude_list)
    train_dataset, valid_dataset, test_dataset = dataset['train'], dataset['valid'], dataset['test']
    # Valid and test splits are merged: both are only used for early stopping here.
    valid_dataset = ConcatDataset([valid_dataset, test_dataset])

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)

    # Half of the feature length of the first training sample
    # (presumably features are two concatenated drug embeddings — TODO confirm).
    input_dim = train_dataset[0][0].shape[0] // 2
    hidden_dim = input_dim
    output_dim = 1

    epochs = 100
    ce_lr = 1e-3
    contra_lr = 1e-1
    weight_decay = 1e-5

    # NOTE(review): the path depends only on `seed`, so calls with a different
    # neg_dataset / exclude_list overwrite the same checkpoint files.
    ckpt_name = f'ckpt/casestudy_DREAMwalk_prod_fc_{seed}'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- Stage 1: contrastive pre-training --------------------------------
    contra_model = CombNetRW(input_dim, hidden_dim, output_dim, comb_type='prod_fc').to(device)
    contra_optimizer = torch.optim.Adam(contra_model.parameters(), lr=contra_lr, weight_decay=weight_decay)
    contra_early_stopping = EarlyStopping(patience=20, verbose=True, path=f"{ckpt_name}_contra.pt")

    for epoch in range(epochs):
        train_loss = train_contrastive(contra_model, device, train_loader, contra_optimizer)
        print(f'Contra Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f}')
        # Early stopping is driven by the training loss in this stage.
        contra_early_stopping(train_loss, contra_model)
        if contra_early_stopping.early_stop:
            print("Early stopping")
            break

    # Release the pre-training objects before building the classifier.
    del contra_model
    del contra_optimizer
    del contra_early_stopping
    torch.cuda.empty_cache()

    # ---- Stage 2: supervised training from the contrastive checkpoint -----
    model = CombNetRW(input_dim, hidden_dim, output_dim, comb_type='prod_fc').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=ce_lr, weight_decay=weight_decay)

    model.load_state_dict(torch.load(f"{ckpt_name}_contra.pt"))

    criterion = nn.BCEWithLogitsLoss()
    early_stopping = EarlyStopping(patience=20, verbose=True, path=f"{ckpt_name}.pt")
    # Index positions matter: the log line below reads [0]=accuracy, [4]=precision, [5]=recall.
    metric_list = [accuracy_score, roc_auc_score, f1_score, average_precision_score, precision_score, recall_score]

    for epoch in range(epochs):
        train_loss, train_scores = train_ce(model, device, train_loader, criterion, optimizer, metric_list)
        valid_loss, valid_scores = evaluate(model, device, valid_loader, criterion, metric_list)
        print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_scores[0]*100:.2f}% | Train Precision: {train_scores[4]:.4f} | Train Recall: {train_scores[5]:.4f} || Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_scores[0]*100:.2f}% | Valid Precision: {valid_scores[4]:.4f} | Valid Recall: {valid_scores[5]:.4f}')
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # ---- Stage 3: score the requested pairs -------------------------------
    combs = []
    for pair in exclude_list:
        # Pairs arrive as sets of strings, whose iteration order changes between
        # interpreter runs (hash randomization). Sorting fixes the order in which
        # the two embeddings are concatenated, so scores are reproducible.
        drugs = sorted(pair)
        if len(drugs) != 2:
            # A self-pair collapses to a one-element set and would otherwise
            # produce a half-width row and an opaque shape error later.
            raise ValueError(f"Expected a pair of two distinct drugs, got {pair!r}")
        combs.append(torch.cat([torch.tensor(embeddings[drug], dtype=torch.float) for drug in drugs], dim=0).unsqueeze(0))
    combs = torch.cat(combs, dim=0).to(device)
    print(f"Pairs shape: {combs.shape}")

    # Score with the checkpoint at f"{ckpt_name}.pt", not the last-epoch weights.
    model.load_state_dict(torch.load(f"{ckpt_name}.pt"))
    model.eval()
    with torch.no_grad():
        scores = torch.sigmoid(model(combs)).cpu().numpy()
    return scores

In [7]:
# Load the pickled embedding object; get_dc_score indexes it by drug ID.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('data/embedding/embeddings_DREAMwalk_msi_seed0.pkl', 'rb') as f:
    embeddings = pickle.load(f)

In [8]:
# Seeds iterated by the scoring loop; each seed yields its own 'score (...{seed})' columns.
seed_lst = [42]

In [9]:
for seed in seed_lst:
    # Four train-and-score runs, always executed in this order:
    # random negatives (dc pairs, then ddi pairs), followed by
    # TWOSIDES negatives (dc pairs, then ddi pairs).
    rand_dc_score = get_dc_score('random', case_dc_pair, embeddings, seed)
    rand_ddi_score = get_dc_score('random', case_ddi_pair, embeddings, seed)
    twosides_dc_score = get_dc_score('TWOSIDES', case_dc_pair, embeddings, seed)
    twosides_ddi_score = get_dc_score('TWOSIDES', case_ddi_pair, embeddings, seed)

    # Both case-study tables get one score column per negative source and seed.
    random_column = f'score (random{seed})'
    twosides_column = f'score (twosides{seed})'

    case_dc_df[random_column] = rand_dc_score
    case_dc_df[twosides_column] = twosides_dc_score
    case_ddi_df[random_column] = rand_ddi_score
    case_ddi_df[twosides_column] = twosides_ddi_score

data/processed/casestudy_DC_combined_embf(DREAMwalk)_neg(random_1)_seed42.pt already exists in processed/ directory.
Loading dataset...data/processed/casestudy_DC_combined_embf(DREAMwalk)_neg(random_1)_seed42.pt
Dictionary of {train, valid, test, whole} dataset is loaded.
Contra Epoch 001: | Train Loss: 0.5307
Validation loss decreased (inf --> 0.530664).  Saving model ...
Contra Epoch 002: | Train Loss: 0.4899
Validation loss decreased (0.530664 --> 0.489889).  Saving model ...
Contra Epoch 003: | Train Loss: 0.4800
Validation loss decreased (0.489889 --> 0.479976).  Saving model ...
Contra Epoch 004: | Train Loss: 0.4732
Validation loss decreased (0.479976 --> 0.473236).  Saving model ...
Contra Epoch 005: | Train Loss: 0.4723
Validation loss decreased (0.473236 --> 0.472347).  Saving model ...
Contra Epoch 006: | Train Loss: 0.4710
Validation loss decreased (0.472347 --> 0.471011).  Saving model ...
Contra Epoch 007: | Train Loss: 0.4707
Validation loss decreased (0.471011 --> 0.470

In [10]:
# case_dc_df.to_csv('case_study_dc_final.csv', index=False)
# case_ddi_df.to_csv('case_study_ddi_final.csv', index=False)

In [13]:
# Mean and standard deviation (np.std defaults to ddof=0) of the DC case-study
# scores from the model trained with neg_dataset='random', seed 42.
np.mean(case_dc_df['score (random42)']), np.std(case_dc_df['score (random42)'])

(0.7963714, 0.2746458947658539)

In [14]:
# Mean and standard deviation (np.std defaults to ddof=0) of the DC case-study
# scores from the model trained with neg_dataset='TWOSIDES', seed 42.
np.mean(case_dc_df['score (twosides42)']), np.std(case_dc_df['score (twosides42)'])

(0.90615875, 0.2390493005514145)

In [15]:
# Mean and standard deviation (np.std defaults to ddof=0) of the DDI case-study
# scores from the model trained with neg_dataset='random', seed 42.
np.mean(case_ddi_df['score (random42)']), np.std(case_ddi_df['score (random42)'])

(0.44574577, 0.3382812440395355)

In [16]:
# Mean and standard deviation (np.std defaults to ddof=0) of the DDI case-study
# scores from the model trained with neg_dataset='TWOSIDES', seed 42.
np.mean(case_ddi_df['score (twosides42)']), np.std(case_ddi_df['score (twosides42)'])

(0.07809865, 0.14183305203914642)

In [17]:
# Last 10 rows of the DC table: the final six (index 400-405) are the hand-picked
# pairs appended after the 400 sampled ones.
case_dc_df.tail(10)

Unnamed: 0,drug_1,drug_2,drug1_name,drug2_name,score (random42),score (twosides42)
396,DB01132,DB00912,Pioglitazone,Repaglinide,0.998631,0.922784
397,DB00421,DB00381,Spironolactone,Amlodipine,0.99718,0.976938
398,DB00302,DB09526,Tranexamic acid,Hydroquinone,0.091462,0.99811
399,DB01248,DB01211,Docetaxel,Clarithromycin,0.971726,0.994592
400,DB01076,DB01039,Atorvastatin,Fenofibrate,0.968693,0.999525
401,DB01095,DB01039,Fluvastatin,Fenofibrate,0.522131,0.995555
402,DB01098,DB01039,Rosuvastatin,Fenofibrate,0.921302,0.99576
403,DB01126,DB00706,Dutasteride,Tamsulosin,0.329342,0.986248
404,DB00654,DB00373,Latanoprost,Timolol,0.311942,0.958969
405,DB00235,DB00187,Milrinone,Esmolol,0.520863,0.957944


In [18]:
# Last 10 rows of the DDI table: the final six (index 400-405) are the hand-picked
# pairs appended after the 400 sampled ones.
case_ddi_df.tail(10)

Unnamed: 0,drug_1,drug_2,drug1_name,drug2_name,score (random42),score (twosides42)
396,DB00690,DB00213,Flurazepam,Pantoprazole,0.182107,0.003466
397,DB00181,DB00745,Baclofen,Modafinil,0.949188,0.332245
398,DB00787,DB01611,Acyclovir,Hydroxychloroquine,0.880163,0.011064
399,DB00338,DB01238,Omeprazole,Aripiprazole,0.284575,0.127058
400,DB00252,DB00904,Phenytoin,Ondansetron,0.778398,0.003543
401,DB01026,DB00641,Ketoconazole,Simvastatin,0.954853,0.164367
402,DB00482,DB00575,Celecoxib,Clonidine,0.963811,0.252595
403,DB00745,DB00476,Modafinil,Duloxetine,0.994005,0.158354
404,DB00448,DB01033,Lansoprazole,Mercaptopurine,0.848539,0.004
405,DB00277,DB00983,Theophylline,Formoterol,0.904183,0.040363
