In [1]:
# === Fixed === #
import os
import random
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

from sklearn.metrics import mean_squared_error
from lifelines.utils import concordance_index
from scipy.stats import pearsonr,spearmanr
import scipy.stats as stats

# === Task-specific === #
import sys
sys.path.append('..')
from copy import deepcopy
from dataset import NPweightingDataSet
from utils import *
from models import *
from trainers import logging, train, validate, test

In [None]:
# ====== Argument Parsing ====== #
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` converts every non-empty string, including
    ``"False"``, to ``True``. This converter accepts the usual spellings
    explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--WORKDIR_PATH', type=str, default="E:/Task/Task1_deep_learning_IC50")
parser.add_argument('--inputdir', type=str, default='E:/Task/Task1_deep_learning_IC50/Drug Response/GDSC/Data/model_input')
parser.add_argument('--mut_input_fpath', type=str, default='E:/Task/Task1_deep_learning_IC50/Drug Response/GDSC/Data/model_input/GDSC/GDSC_mutation_input.csv')
parser.add_argument('--drug_input_fpath', type=str, default='E:/Task/Task1_deep_learning_IC50/Drug Response/GDSC/Data/model_input/GDSC/GDSC_SMILE_input.csv')
parser.add_argument('--exprs_input_fpath', type=str, default='E:/Task/Task1_deep_learning_IC50/Drug Response/GDSC/Data/model_input/GDSC/GDSC_ssgsea_input.csv')
parser.add_argument('--device', type=str, default='0')
parser.add_argument('--model_name', type=str, default='my')
parser.add_argument('--split_type', type=str, default='both', choices=['cell','drug','both','mix'])
parser.add_argument('--response_type', type=str, default='IC50', choices=['IC50', 'AUC'])

# === Train setting === #
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--patience', type=int, default=10)
# type=bool would treat "--testset_yes False" as True; parse explicitly instead.
parser.add_argument('--testset_yes', type=str2bool, default=True)
parser.add_argument('--seed', type=int, default=42)

# parse_known_args so the Jupyter kernel's own argv entries are ignored.
args = parser.parse_known_args()[0]

# === Reproducibility === #
# Seed every RNG the training pipeline may touch (DataLoader shuffling,
# dropout, weight init). CUDA seeding is lazy, so this is safe on CPU-only hosts.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# === Model Setting === #
args.code_dim = 30
args.drug_hidden_dims = [300, 100]
args.mut_hidden_dims = [300, 100]
args.code_dropout = True
args.code_dropout_rate = 0.2

args.forward_net_hidden_dim1 = 98
args.forward_net_hidden_dim2 = 98
args.forward_net_out_act = None

# Weights of the terms combined by MergedLoss.
args.y_loss_weight=1.
args.drug_reconstruction_loss_weight=0.2
args.mut_reconstruction_loss_weight=0.3

device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
print('torch version: ', torch.__version__)
print(device)

In [5]:
def _flatten_batches(batch_values):
    """Collect per-sample values gathered over an epoch into a 1-D numpy array.

    Equivalent to ``np.array(batch_values).squeeze(1)``. This assumes each
    collected item carries a trailing length-1 axis (i.e. the stacked array is
    (N, 1)) — TODO confirm against ``trainers.train`` / ``trainers.validate``.
    """
    return np.array(batch_values).squeeze(1)


def _regression_metrics(pred, true):
    """Return ``(rmse, pearson_r, spearman_rho, c_index)`` for 1-D predictions.

    Args:
        pred: predicted responses.
        true: observed responses, same length as ``pred``.
    """
    rmse = np.sqrt(mean_squared_error(true, pred))
    pearson, _ = pearsonr(pred, true)
    spearman, _ = spearmanr(pred, true)
    # lifelines' signature is concordance_index(event_times, predicted_scores):
    # observed values go first (the argument order affects how ties are counted).
    ci = concordance_index(true, pred)
    return rmse, pearson, spearman, ci


def experiment(args, dataset_partition, model, loss_fn, device):
    """Train ``model`` with early stopping and optionally evaluate on the test set.

    Args:
        args: namespace with training settings (learning_rate, weight_decay,
            epochs, patience, testset_yes, outdir, exp_name). An optional
            ``lr_patience`` attribute overrides the LR-scheduler patience
            (default 10).
        dataset_partition: dict with 'train_loader', 'valid_loader' and
            'test_loader' DataLoaders.
        model: network to train.
        loss_fn: loss module; its parameters are optimised alongside the model's.
        device: torch device forwarded to the train/validate/test helpers.

    Returns:
        ``(vars(args), result, best_performances, model_max)``.
        - ``result`` holds the per-epoch loss/correlation curves.
        - ``best_performances`` holds the metrics at the epoch with the lowest
          validation loss.
        - ``model_max`` is a copy of the model from that epoch. If validation
          loss never improved, it is the final model instead.

    Side effects:
        Writes the following files under ``args.outdir``:
        - ``<exp_name>.model`` and ``<exp_name>.log``;
        - ``<exp_name>_best_performances.json``;
        - if ``args.testset_yes``: ``<exp_name>_test.log`` and
          ``<exp_name>_test.csv``.
    """
    import json  # used for the best-performance dump; not imported explicitly by the notebook

    # Take the loaders from the partition rather than from notebook globals.
    train_loader = dataset_partition['train_loader']
    valid_loader = dataset_partition['valid_loader']
    test_loader = dataset_partition['test_loader']

    optimizer = optim.Adam([
            {'params': model.parameters()},
            {'params': loss_fn.parameters()}
        ], lr=args.learning_rate, weight_decay=args.weight_decay)

    # NOTE: with the defaults (LR patience 10, early-stopping patience 10), early
    # stopping will usually trigger before the learning rate is ever halved.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=getattr(args, 'lr_patience', 10), verbose=True)

    # ====== Cross Validation Best Performance Dict ====== #
    best_performances = {
        'best_epoch': 0,
        'best_train_loss': float('inf'),
        'best_train_corr': 0.0,
        'best_valid_loss': float('inf'),
        'best_valid_corr': 0.0,
    }
    # ==================================================== #

    list_epoch = []
    list_train_epoch_loss = []
    list_epoch_rmse = []
    list_epoch_corr = []
    list_epoch_spearman = []
    list_epoch_ci = []

    list_val_epoch_loss = []
    list_val_epoch_rmse = []
    list_val_epoch_corr = []
    list_val_spearman = []
    list_val_ci = []

    counter = 0
    # Fallbacks so everything below stays defined even if validation loss never
    # improves (e.g. NaN losses) or args.epochs == 0.
    model_max = model
    epoch_train_corr = float('nan')
    epoch_val_corr = float('nan')

    for epoch in range(args.epochs):
        list_epoch.append(epoch)

        # ====== TRAIN Epoch ====== #
        model, list_train_batch_loss, list_train_batch_out, list_train_batch_true = train(
            model, epoch, train_loader, optimizer, loss_fn, device)

        epoch_train_rmse, epoch_train_corr, epoch_train_spearman, epoch_train_ci = _regression_metrics(
            _flatten_batches(list_train_batch_out), _flatten_batches(list_train_batch_true))
        train_epoch_loss = sum(list_train_batch_loss) / len(list_train_batch_loss)

        list_train_epoch_loss.append(train_epoch_loss)
        list_epoch_rmse.append(epoch_train_rmse)
        list_epoch_corr.append(epoch_train_corr)
        list_epoch_spearman.append(epoch_train_spearman)
        list_epoch_ci.append(epoch_train_ci)

        # ====== VALID Epoch ====== #
        list_val_batch_loss, list_val_batch_out, list_val_batch_true = validate(
            model, valid_loader, loss_fn, device)

        epoch_val_rmse, epoch_val_corr, epoch_val_spearman, epoch_val_ci = _regression_metrics(
            _flatten_batches(list_val_batch_out), _flatten_batches(list_val_batch_true))
        val_epoch_loss = sum(list_val_batch_loss) / len(list_val_batch_loss)

        list_val_epoch_loss.append(val_epoch_loss)
        list_val_epoch_rmse.append(epoch_val_rmse)
        list_val_epoch_corr.append(epoch_val_corr)
        list_val_spearman.append(epoch_val_spearman)
        list_val_ci.append(epoch_val_ci)

        # Model selection and early stopping are driven by validation loss.
        if val_epoch_loss < best_performances['best_valid_loss']:
            best_performances['best_epoch'] = epoch
            best_performances['best_train_loss'] = train_epoch_loss
            best_performances['best_train_corr'] = epoch_train_corr
            best_performances['best_valid_loss'] = val_epoch_loss
            best_performances['best_valid_corr'] = epoch_val_corr
            torch.save(model, os.path.join(args.outdir, args.exp_name + f'.model'))
            model_max = deepcopy(model)  # snapshot; `model` keeps training in place
            counter = 0
        else:
            counter += 1
            logging(f'Early Stopping counter: {counter} out of {args.patience}', args.outdir, args.exp_name+'.log')

        logging(f'Epoch: {epoch:02d}, Train loss: {list_train_epoch_loss[-1]:.4f}, rmse: {epoch_train_rmse:.4f}, corr: {epoch_train_corr:.4f}, Valid loss: {list_val_epoch_loss[-1]:.4f}, rmse: {epoch_val_rmse:.4f}, pcc: {epoch_val_corr:.4f}', args.outdir, args.exp_name+'.log')

        if counter >= args.patience:
            break
        scheduler.step(val_epoch_loss)

    if args.testset_yes:
        (test_loss, test_rmse, test_corr, test_spearman, test_ci,
         list_test_loss, list_test_output_loss, list_test_out, list_test_true) = test(
            model_max, test_loader, loss_fn, device)
        logging(f"Test:\tLoss: {test_loss}\tRMSE: {test_rmse}\tCORR: {test_corr}\tSPEARMAN: {test_spearman}\tCI: {test_ci}", args.outdir, f'{args.exp_name}_test.log')

        # The prediction columns are added to the test dataset's own response
        # table in place; each list must have one entry per row.
        response_df = test_loader.dataset.response_df
        response_df['test_loss'] = list_test_loss
        response_df['output_loss'] = list_test_output_loss
        response_df['test_pred'] = list_test_out
        response_df['test_true'] = list_test_true

        filename = os.path.join(args.outdir, f'{args.exp_name}_test.csv')
        response_df.to_csv(filename, sep=',', header=True, index=False)

    # ====== Add Result to Dictionary ====== #
    result = {}
    result['train_losses'] = list_train_epoch_loss
    result['val_losses'] = list_val_epoch_loss
    result['train_accs'] = list_epoch_corr
    result['val_accs'] = list_val_epoch_corr
    result['train_acc'] = epoch_train_corr
    result['val_acc'] = epoch_val_corr
    if args.testset_yes:
        result['test_acc'] = test_corr

    filename = os.path.join(args.outdir, f'{args.exp_name}_best_performances.json')
    with open(filename, 'w') as f:
        json.dump(best_performances, f)

    return vars(args), result, best_performances, model_max

In [None]:
# ====== Experiment  ====== #
total_results = defaultdict(list)
best_best_epoch = 0
best_best_train_loss = 99.
best_best_train_metric = 0
best_best_valid_loss = 99.
best_best_valid_metric = 0


def _make_dataset(split_name):
    """Build the NPweightingDataSet for one CV split: 'train', 'valid' or 'test'.

    The response file is
    ``<inputdir>/GDSC/cv_<split_type>/GDSC_<split>_IC50_by_<split_type>_cv00.csv``.
    Only fold 00 is used.
    """
    # NOTE(review): the file name always contains 'IC50', even when
    # args.response_type == 'AUC'. Confirm the files carry both response columns.
    response_fpath = os.path.join(
        args.inputdir, 'GDSC', f'cv_{args.split_type}',
        f'GDSC_{split_name}_IC50_by_{args.split_type}_cv00.csv')
    return NPweightingDataSet(
        response_fpath=response_fpath,
        drug_input=args.drug_input_fpath,
        exprs_input=args.exprs_input_fpath,
        mut_input=args.mut_input_fpath,
        response_type=args.response_type)


if __name__ == '__main__':

    args.exp_name = f'{args.model_name}_{args.split_type}'
    args.outdir = os.path.join(args.WORKDIR_PATH, 'Results', args.exp_name)

    createFolder(args.outdir)

    # =============== #
    # === Dataset === #
    # =============== #
    train_set = _make_dataset('train')
    valid_set = _make_dataset('valid')
    test_set = _make_dataset('test')

    # === input === #
    # NOTE(review): each feature table is assumed to have 2 non-feature columns
    # (presumably identifiers) — confirm against NPweightingDataSet.
    args.drug_in_dim = train_set.drug_fp_df.shape[1] - 2
    args.cell_in_dim = train_set.cell_exprs_df.shape[1] - 2
    args.mut_score_dim = train_set.mut_score_df.shape[1] - 2

    for banner, dataset in (("-----------TRAIN DATASET-----------", train_set),
                            ("-----------VALID DATASET-----------", valid_set),
                            ("-----------TEST  DATASET-----------", test_set)):
        print(banner)
        print("NUMBER OF DATA:", len(dataset))

    # === Data Set/Loader === #
    loader_kwargs = dict(drop_last=False, pin_memory=True, num_workers=12)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=args.batch_size, shuffle=False, **loader_kwargs)
    # batch_size=1: presumably so test() returns one loss per sample, which
    # experiment() writes row-by-row to the test CSV — confirm.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, **loader_kwargs)

    dataset_partition = {
        'train_loader': train_loader,
        'valid_loader': valid_loader,
        'test_loader' : test_loader
    }

    # ============= #
    # === Model === #
    # ============= #

    drug_autoencoder = DeepAutoencoderThreeHiddenLayers(input_dim=args.drug_in_dim,
                                    hidden_dims=args.drug_hidden_dims,
                                    code_dim=args.code_dim,
                                    code_activation=True, dropout=args.code_dropout, dropout_rate=args.code_dropout_rate)

    mut_autoencoder = DeepAutoencoderThreeHiddenLayers(input_dim=args.mut_score_dim,
                                        hidden_dims=args.mut_hidden_dims,
                                        code_dim=args.code_dim,
                                        code_activation=True, dropout=args.code_dropout, dropout_rate=args.code_dropout_rate)

    # Forward network over the two concatenated codes plus 38 extra features.
    # NOTE(review): 38 is hard-coded. It may be the expression feature width
    # (args.cell_in_dim) — confirm, and derive it instead if so.
    net = ForwardNetworkTwoHiddenLayers((2 * args.code_dim +38),
                                            args.forward_net_hidden_dim1,
                                            args.forward_net_hidden_dim2,
                                            out_activation=args.forward_net_out_act)

    # Make the model together
    model = DEERS_Concat(drug_autoencoder=drug_autoencoder,
                            mut_line_autoencoder=mut_autoencoder,
                            forward_network=net).to(device)

    # === Loss === #
    loss_fn = MergedLoss(y_loss_weight=args.y_loss_weight,
                            drug_reconstruction_loss_weight=args.drug_reconstruction_loss_weight,
                            mut_reconstruction_loss_weight=args.mut_reconstruction_loss_weight,
                            )

    # =============== #
    # ===== Run ===== #
    # =============== #
    setting, result, best_performances, model_max = experiment(args, dataset_partition, model, loss_fn, device)
    save_exp_result(setting, result, args.outdir)

    if best_performances['best_valid_corr'] >= best_best_valid_metric:
        best_best_epoch = best_performances['best_epoch']
        best_best_train_loss = best_performances['best_train_loss']
        best_best_train_metric = best_performances['best_train_corr']
        best_best_valid_loss = best_performances['best_valid_loss']
        best_best_valid_metric = best_performances['best_valid_corr']
        best_setting = setting
        best_result = result

    for key in ('best_epoch', 'best_train_loss', 'best_train_corr',
                'best_valid_loss', 'best_valid_corr'):
        total_results[key].append(best_performances[key])


print(f'Best Train Loss: {best_best_train_loss:.4f}')
print(f'Best Train Corr: {best_best_train_metric:.4f}')
print(f'Best Valid Loss: {best_best_valid_loss:.4f}')
print(f'Best Valid Corr: {best_best_valid_metric:.4f}')
