In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#! conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia
#! conda install -y -c conda-forge tqdm matplotlib

In [2]:
import sys, os
# import options_parser as op
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
import argparse
from torch.autograd import Variable
import pandas as pd


from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import Normalizer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from mlp import MLP_simple, MLP_batchnorm, MLP_rawcgc, MLP_w_mutation
from lantentDataset import LatentDataset
from tqdm import tqdm
SEED = 459
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Train an ANN to predict the log2-fold change vialbility for a given cell line+ mutations + drug

In [3]:
drug_latent = pd.read_csv('./data/drug_latent.csv')
encoded_cell_lines = pd.read_csv('./data/encoded_cell_lines.csv', index_col=0)
drug_resp = pd.read_csv('./data/primary-screen-replicate-collapsed-logfold-change_longFormat.csv', index_col=0)
mutations_df = pd.read_csv('../data/CCLE_muttion_final.csv', index_col=0)

  mask |= (ar1 == a)


In [4]:
cell_ids = drug_resp.cell_id.value_counts().index.values
train_cells, test_cells =  train_test_split(cell_ids, test_size=0.2)
train_cells, val_cells =  train_test_split(train_cells, test_size=0.1)

In [7]:
import pickle

data = {'train_cells':train_cells, 'val_cells':val_cells, 'test_cells':test_cells}
with open('./data/cells_split.p', 'wb') as f:
    pickle.dump(data, f)


In [9]:
trainDataset = LatentDataset(drug_resp[drug_resp.cell_id.isin(train_cells)].reset_index(drop=True),drug_latent,encoded_cell_lines, mutations_df, include_mutation=True  )
validationDataset = LatentDataset(drug_resp[drug_resp.cell_id.isin(val_cells)].reset_index(drop=True),drug_latent,encoded_cell_lines, mutations_df, include_mutation=True   )
testDataset = LatentDataset(drug_resp[drug_resp.cell_id.isin(test_cells)].reset_index(drop=True),drug_latent,encoded_cell_lines, mutations_df, include_mutation=True   )

trainLoader = DataLoader(trainDataset, batch_size=1000,num_workers=16, shuffle=True, drop_last=True)
validationLoader = DataLoader(validationDataset, batch_size=1000,num_workers=16, drop_last=True)
testLoader = DataLoader(testDataset, batch_size=1000,num_workers=16, drop_last=True)

In [10]:
len(testDataset),len(validationDataset), mutations_df.shape

(516123, 206399, (568, 18787))

In [28]:
drug_latent_size, ge_latent_size, mutat_size = 56, 1024, 18787
model = MLP_w_mutation(drug_latent_size, ge_latent_size, mutat_size)

In [29]:
learning_rate = 0.001

In [None]:
if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0005)


epoch = 0
bestR2 = -1
bestLoss = float("inf")
bestEpoch = 0
path = './trainedModels/'
num_epochs =100
pbar = tqdm(range(1,num_epochs+1))

for epoch in pbar:
    pbar.set_description(f"Epoch: {epoch}")
    train_loss = 0.
    rmse, r2 = 0,-1
    model.train()
    for batch in trainLoader:

        geLatentVec, dLatentVec, mutations, target = batch
        
        target = target.reshape(-1,1)
        # if geLatentVec.shape[0] != 50:
        #     continue

        if torch.cuda.is_available():
            geLatentVec = geLatentVec.cuda()
            dLatentVec = dLatentVec.cuda()
            mutations = mutations.cuda()
            target = target.cuda()
        else:
            geLatentVec = Variable(geLatentVec)
            dLatentVec = Variable(dLatentVec)
            mutations = Variable(mutations)
            target = Variable(target)
        out = model(geLatentVec, dLatentVec, mutations)
        loss = criterion(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.cpu().data.numpy() * len(target)
        
        pbar.set_postfix({'Train Loss': loss.cpu().data.numpy() * len(target), 'Validation Loss':rmse, 'R2_Score':r2, 'Best':bestLoss})
        
    train_loss = train_loss / len(trainLoader.dataset)
    pbar.set_postfix({'Train Loss': train_loss, 'Validation Loss':rmse, 'R2_Score':r2, 'Best':bestLoss})
    
    torch.save(model.state_dict(), path + f'MLP_w_mutation_{epoch}.pt')
    if epoch % 2 == 0:
        
        model.eval()

        for batch in validationLoader:
            geLatentVec, dLatentVec, mutations, target = batch
            target = target.reshape(-1,1)
            if torch.cuda.is_available():
                geLatentVec = geLatentVec.cuda()
                dLatentVec = dLatentVec.cuda()
                mutations = mutations.cuda()
                target = target.cuda()

            out = model(geLatentVec, dLatentVec, mutations)

            out = out.data.cpu().numpy().tolist()
            target = target.cpu().numpy().tolist()
            r2 += r2_score(target, out)
            rmse += mean_squared_error(target, out)**0.5
            # SS_tot = torch.std(target)
            # SS_res = evalLoss
            
            pbar.set_postfix({'Train Loss': train_loss, 'Validation Loss':rmse, 'R2_Score':r2, 'Best':bestLoss})
            
#             print('epoch: {}, Validation Loss: {:.6f}, R2_Score: {:.6f}'.format(epoch, rmse, r2))
        if (r2 > bestR2):
            bestLoss = rmse
            bestR2 = r2
            bestEpoch = epoch
            torch.save(model.state_dict(), path + 'MLP_w_mutation_best.pt')

Epoch: 11:  10%|█         | 10/100 [1:01:18<8:48:44, 352.49s/it, Train Loss=319, Validation Loss=0, R2_Score=-1, Best=117]      