In [1]:
import torch
import numpy as np

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.nn.functional as F

from model import CombNetRW
from dataset import CombinationDatasetRW
from MSI.load_msi_data import LoadData

from train_rw import EarlyStopping, train_contrastive, train_ce, evaluate, seed_everything

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, average_precision_score

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fix the global random seed (via train_rw.seed_everything) before any sampling/training.
GLOBAL_SEED = 42
seed_everything(GLOBAL_SEED)

In [3]:
# MSI data loader; each get_dict call returns a pair that is unpacked into
# <entity>_id2name and <entity>_name2id lookups.
dataloader = LoadData()

(drug_id2name, drug_name2id), (ptn_id2name, ptn_name2id), (ind_id2name, ind_name2id) = (
    dataloader.get_dict(type=entity_type)
    for entity_type in ('drug', 'protein', 'indication')
)

In [4]:
# load labelled drug pairs
dc_df = pd.read_csv('data/labels/DC_combined_msi.tsv', sep='\t')  # drug combination
ddi_df = pd.read_csv('data/labels/TWOSIDES_msi.tsv', sep='\t')  # drug-drug interaction

# DrugBank ID -> drug name, used only to label the case-study tables.
db_df = pd.read_csv('database/drugbank_all_drug_links.csv')
db_id2name = dict(zip(db_df['DrugBank ID'], db_df['Name']))

# Case-study pairs: a random sample per label set plus hand-picked pairs.
# These are later passed to get_dc_score as `exclude_list` (excluded from the
# training data and then scored) -- they are not the model's regular test split.
N_CASE_PAIRS = 400
CASE_SAMPLE_SEED = 42

# hand-picked drug-combination case-study pairs
dc_more = pd.DataFrame([['DB01076', 'DB01039'],  # atorvastatin & fenofibrate
                        ['DB01095', 'DB01039'],  # fluvastatin & fenofibrate
                        ['DB01098', 'DB01039'],  # rosuvastatin & fenofibrate
                        ['DB01126', 'DB00706'],  # dutasteride & tamsulosin
                        ['DB00654', 'DB00373'],  # latanoprost & timolol
                        ['DB00235', 'DB00187'],  # milrinone & esmolol
                        ], columns=['drug_1', 'drug_2'])

# hand-picked drug-drug-interaction case-study pairs
ddi_more = pd.DataFrame([['DB00252', 'DB00904'],  # phenytoin & ondansetron
                         ['DB01026', 'DB00641'],  # ketoconazole & simvastatin
                         ['DB00482', 'DB00575'],  # celecoxib & clonidine
                         ['DB00745', 'DB00476'],  # modafinil & duloxetine
                         ['DB00448', 'DB01033'],  # lansoprazole & mercaptopurine
                         ['DB00277', 'DB00983'],  # theophylline & formoterol
                         ], columns=['drug_1', 'drug_2'])

# Random sample first, hand-picked pairs appended last (so they form the tail).
# ignore_index=True yields a fresh 0..n-1 index without in-place mutation.
case_dc_df = pd.concat([dc_df.sample(N_CASE_PAIRS, random_state=CASE_SAMPLE_SEED), dc_more],
                       axis=0, ignore_index=True)
case_ddi_df = pd.concat([ddi_df.sample(N_CASE_PAIRS, random_state=CASE_SAMPLE_SEED), ddi_more],
                        axis=0, ignore_index=True)

# add human-readable drug names (NaN when an ID is absent from the DrugBank table)
case_dc_df = case_dc_df.assign(drug1_name=case_dc_df['drug_1'].map(db_id2name),
                               drug2_name=case_dc_df['drug_2'].map(db_id2name))
case_ddi_df = case_ddi_df.assign(drug1_name=case_ddi_df['drug_1'].map(db_id2name),
                                 drug2_name=case_ddi_df['drug_2'].map(db_id2name))

# unordered drug pairs, row-aligned with the case-study tables
case_dc_pair = [{drug1, drug2} for drug1, drug2 in zip(case_dc_df['drug_1'], case_dc_df['drug_2'])]
case_ddi_pair = [{drug1, drug2} for drug1, drug2 in zip(case_ddi_df['drug_1'], case_ddi_df['drug_2'])]

In [5]:
def _pair_features(kgfeat_dict, pairs):
    '''Build one row per pair: the two drugs' KG embeddings concatenated, in sorted-ID order.

    Pairs in this notebook are Python sets, whose iteration order over strings depends on
    hash randomisation and can change between kernel restarts; sorting the IDs makes the
    concatenation order (and therefore the scores) reproducible.

    Raises:
        ValueError: if `pairs` is empty or a pair does not hold exactly two drugs
            (e.g. a set built from the same drug twice).
        KeyError: if a drug has no entry in `kgfeat_dict`.
    '''
    rows = []
    for pair in pairs:
        drugs = sorted(pair)
        if len(drugs) != 2:
            raise ValueError(f'expected exactly two drugs per pair, got {pair!r}')
        rows.append(np.concatenate([np.asarray(kgfeat_dict[drug]) for drug in drugs]))
    if not rows:
        raise ValueError('exclude_list is empty: no drug pairs to score')
    return torch.tensor(np.stack(rows), dtype=torch.float)


def _pretrain_contrastive(input_dim, hidden_dim, output_dim, device, train_loader, epochs, lr, weight_decay, path):
    '''SCL pretraining stage; EarlyStopping saves the weights to `path` whenever the loss improves.

    The model/optimizer are locals, so they are released when this helper returns.
    '''
    contra_model = CombNetRW(input_dim, hidden_dim, output_dim, comb_type='prod_fc').to(device)
    contra_optimizer = torch.optim.Adam(contra_model.parameters(), lr=lr, weight_decay=weight_decay)
    contra_early_stopping = EarlyStopping(patience=20, verbose=True, path=path)

    for epoch in range(epochs):
        train_loss = train_contrastive(contra_model, device, train_loader, contra_optimizer)
        print(f'Contra Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f}')
        # this stage has no validation pass; the train loss drives early stopping
        contra_early_stopping(train_loss, contra_model)
        if contra_early_stopping.early_stop:
            print("Early stopping")
            break


def _finetune_ce(model, device, train_loader, valid_loader, epochs, lr, weight_decay, path):
    '''CE finetuning stage with BCE loss; early stopping on validation loss saves weights to `path`.'''
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss()
    early_stopping = EarlyStopping(patience=20, verbose=True, path=path)
    metric_list = [accuracy_score, roc_auc_score, f1_score, average_precision_score, precision_score, recall_score]

    for epoch in range(epochs):
        train_loss, train_scores = train_ce(model, device, train_loader, criterion, optimizer, metric_list)
        valid_loss, valid_scores = evaluate(model, device, valid_loader, criterion, metric_list)
        # score indices follow metric_list: 0 = accuracy, 4 = precision, 5 = recall
        print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_scores[0]*100:.2f}% | Train Precision: {train_scores[4]:.4f} | Train Recall: {train_scores[5]:.4f} || Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_scores[0]*100:.2f}% | Valid Precision: {valid_scores[4]:.4f} | Valid Recall: {valid_scores[5]:.4f}')
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break


def get_dc_score(kgfeat, neg_dataset, exclude_list, seed, ckpt_name=None, epochs=100, emb_seed=0):
    '''Train a CombNetRW scorer (SCL pretraining + CE finetuning) and score the excluded pairs.

    The DC_combined dataset is built with `exclude_list` removed; after training, the best
    finetuned checkpoint scores exactly those excluded (case-study) pairs.

    args:
        kgfeat: [None, 'node2vec', 'edge2vec', 'res2vec_homo', 'res2vec_hetero', 'DREAMwalk']
            Scoring reads `data/embedding/embeddings_{kgfeat}_msi_seed{emb_seed}.pkl`; a kgfeat
            without such a file cannot be scored (this is now checked before training).
        neg_dataset: ['random', 'TWOSIDES']
        exclude_list: list of drug pairs to exclude (case study drug pairs); each pair is an
            iterable of two drug IDs (sets in this notebook).
        seed: passed to CombinationDatasetRW; also part of the default checkpoint name.
        ckpt_name: checkpoint path prefix (without '.pt'). Defaults to
            f'ckpt/blackbox_{kgfeat}_prod_fc_{seed}'. The default does not include
            neg_dataset, so runs differing only in neg_dataset overwrite each other's
            checkpoints -- pass an explicit name to keep both.
        epochs: maximum epochs for each training stage (early stopping may end sooner).
        emb_seed: seed suffix of the embedding file used to build the pairs to score.

    returns:
        numpy array of sigmoid scores, one row per pair in `exclude_list`
        (presumably shape (len(exclude_list), 1) given output_dim = 1).
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if ckpt_name is None:
        ckpt_name = f'ckpt/blackbox_{kgfeat}_prod_fc_{seed}'
    ckpt_dir = os.path.dirname(ckpt_name)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # Build the pairs to score first, so a missing embedding file or drug fails
    # immediately instead of after both training stages.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted embedding files.
    with open(f'data/embedding/embeddings_{kgfeat}_msi_seed{emb_seed}.pkl', 'rb') as f:
        kgfeat_dict = pickle.load(f)
    combs = _pair_features(kgfeat_dict, exclude_list).to(device)
    print(f"Pairs shape: {combs.shape}")

    dataset = CombinationDatasetRW(database='DC_combined', kgfeat=kgfeat, neg_ratio=1, neg_dataset=neg_dataset, seed=seed, exclude_list=exclude_list)
    train_dataset, valid_dataset, test_dataset = dataset['train'], dataset['valid'], dataset['test']
    # valid and test splits are merged; together they drive early stopping in the CE stage
    valid_dataset = ConcatDataset([valid_dataset, test_dataset])

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)

    # features are presumably two concatenated drug embeddings (as built for scoring above)
    input_dim = train_dataset[0][0].shape[0] // 2
    hidden_dim = input_dim
    output_dim = 1

    ce_lr = 1e-3
    contra_lr = 1e-1
    weight_decay = 1e-5

    # SCL pretraining stage
    _pretrain_contrastive(input_dim, hidden_dim, output_dim, device, train_loader,
                          epochs, contra_lr, weight_decay, f"{ckpt_name}_contra.pt")
    torch.cuda.empty_cache()  # pretraining model/optimizer were freed when the helper returned

    # CE finetuning stage, starting from the SCL pretrained weights
    model = CombNetRW(input_dim, hidden_dim, output_dim, comb_type='prod_fc').to(device)
    model.load_state_dict(torch.load(f"{ckpt_name}_contra.pt", map_location=device))
    _finetune_ce(model, device, train_loader, valid_loader, epochs, ce_lr, weight_decay, f"{ckpt_name}.pt")

    # score the excluded pairs with the best finetuned checkpoint
    model.load_state_dict(torch.load(f"{ckpt_name}.pt", map_location=device))
    model.eval()
    with torch.no_grad():
        scores = torch.sigmoid(model(combs)).cpu().numpy()
    return scores

In [6]:
# One full train-and-score run per seed; each seed adds its own score column.
seed_lst = [
    42,
]

In [7]:
for seed in seed_lst:
    # Enabled setting: TWOSIDES negatives, drug-combination (DC) case-study pairs.
    # Other settings, disabled here, call get_dc_score(kgfeat='DREAMwalk', ...) the same way:
    #   - neg_dataset='random',   exclude_list=case_dc_pair  -> case_dc_df[f'score (random{seed})']
    #   - neg_dataset='random',   exclude_list=case_ddi_pair -> case_ddi_df[f'score (random{seed})']
    #   - neg_dataset='TWOSIDES', exclude_list=case_ddi_pair -> case_ddi_df[f'score (TWOSIDES{seed})']
    twosides_dc_score = get_dc_score(
        kgfeat='DREAMwalk',
        neg_dataset='TWOSIDES',
        exclude_list=case_dc_pair,
        seed=seed,
    )
    case_dc_df[f'score (TWOSIDES{seed})'] = twosides_dc_score

data/processed/casestudy_DC_combined_kgfeat(DREAMwalk)_chemfeat(None)_neg(TWOSIDES_1)_seed42.pt already exists in processed/ directory.
Loading dataset...data/processed/casestudy_DC_combined_kgfeat(DREAMwalk)_chemfeat(None)_neg(TWOSIDES_1)_seed42.pt
Dictionary of {train, valid, test, whole} dataset is loaded.


Contra Epoch 001: | Train Loss: 0.5122
Validation loss decreased (inf --> 0.512247).  Saving model ...
Contra Epoch 002: | Train Loss: 0.4722
Validation loss decreased (0.512247 --> 0.472213).  Saving model ...
Contra Epoch 003: | Train Loss: 0.4602
Validation loss decreased (0.472213 --> 0.460177).  Saving model ...
Contra Epoch 004: | Train Loss: 0.4539
Validation loss decreased (0.460177 --> 0.453910).  Saving model ...
Contra Epoch 005: | Train Loss: 0.4451
Validation loss decreased (0.453910 --> 0.445118).  Saving model ...
Contra Epoch 006: | Train Loss: 0.4451
Validation loss decreased (0.445118 --> 0.445095).  Saving model ...
Contra Epoch 007: | Train Loss: 0.4466
EarlyStopping counter: 1 out of 20
Contra Epoch 008: | Train Loss: 0.4382
Validation loss decreased (0.445095 --> 0.438164).  Saving model ...
Contra Epoch 009: | Train Loss: 0.4365
Validation loss decreased (0.438164 --> 0.436510).  Saving model ...
Contra Epoch 010: | Train Loss: 0.4254
Validation loss decreased (0

In [8]:
# Last 10 rows; the final 6 are the hand-picked pairs appended from dc_more.
case_dc_df.iloc[-10:]

Unnamed: 0,drug_1,drug_2,drug1_name,drug2_name,score (TWOSIDES42)
396,DB01132,DB00912,Pioglitazone,Repaglinide,0.978839
397,DB00421,DB00381,Spironolactone,Amlodipine,0.986407
398,DB00302,DB09526,Tranexamic acid,Hydroquinone,0.999938
399,DB01248,DB01211,Docetaxel,Clarithromycin,0.999068
400,DB01076,DB01039,Atorvastatin,Fenofibrate,0.999926
401,DB01095,DB01039,Fluvastatin,Fenofibrate,0.998363
402,DB01098,DB01039,Rosuvastatin,Fenofibrate,0.998648
403,DB01126,DB00706,Dutasteride,Tamsulosin,0.9909
404,DB00654,DB00373,Latanoprost,Timolol,0.989096
405,DB00235,DB00187,Milrinone,Esmolol,0.992646


In [9]:
# case_ddi_df.tail(10)

## Load the trained model and project proteins or indications

In [7]:
# KG embeddings used for projection below (the DREAMwalk, seed-0 embedding file).
# NOTE: pickle.load can run arbitrary code; only load trusted files.
with open('data/embedding/embeddings_DREAMwalk_msi_seed0.pkl', 'rb') as emb_file:
    kgfeat_dict = pickle.load(emb_file)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
# Rebuild the finetuned scorer saved by get_dc_score (kgfeat='DREAMwalk', seed=42).
# Its input/hidden width is the size of one KG embedding (input_dim = pair_dim // 2 in training).
emb_dim = len(next(iter(kgfeat_dict.values())))
model = CombNetRW(emb_dim, emb_dim, 1, comb_type='prod_fc').to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load("ckpt/blackbox_DREAMwalk_prod_fc_42.pt", map_location=device))
model.eval()

CombNetRW(
  (tr): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (cosine_similarity): CosineSimilarity()
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [9]:
# project all proteins
ptn_lst = list(ptn_id2name.keys())
# Stack into a single ndarray first: building a tensor from a Python list of ndarrays
# is very slow (PyTorch emits a UserWarning for it).
ptn_emb = torch.tensor(np.stack([kgfeat_dict[ptn] for ptn in ptn_lst]), dtype=torch.float).to(device)

  ptn_emb = torch.tensor([kgfeat_dict[ptn] for ptn in ptn_lst], dtype=torch.float).to(device)


In [10]:
ptn_emb.size()  # (number of proteins, embedding size)

torch.Size([17660, 128])

In [11]:
# Inference only: no_grad avoids recording an autograd graph over every protein row.
with torch.no_grad():
    ptn_proj = model.project_single_entity(ptn_emb).cpu().numpy()

In [12]:
np.shape(ptn_proj)

(17660, 128)

## Compute similarity between drug pair and proteins

In [13]:
# Embed the hand-picked DC case-study pairs, which were appended last (dc_more) to case_dc_pair.
# Drug IDs are sorted within each pair: set iteration order over strings depends on hash
# randomisation, so sorting keeps the concatenation order reproducible across kernel restarts.
combs = torch.tensor(
    np.stack([
        np.concatenate([np.asarray(kgfeat_dict[drug]) for drug in sorted(pair)])
        for pair in case_dc_pair[-len(dc_more):]
    ]),
    dtype=torch.float,
).to(device)

In [14]:
combs.size()  # (number of pairs, 2 * embedding size)

torch.Size([6, 256])

In [15]:
# Inference only: no autograd graph needed for the pair projections.
with torch.no_grad():
    comb_proj = model.project_pair_entity(combs).cpu().numpy()

In [16]:
np.shape(comb_proj)

(6, 128)