In [1]:
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd

from utils.utils import *
from utils.utils_train import *
from utils.utils_data import load_data_from_df, construct_loader

from models.CSG2A_net import CSG2A_net

In [2]:
def str2bool(v):
    """Interpret a command-line style value as a boolean.

    Args:
        v: Value to interpret. A ``bool`` is returned unchanged. Anything
            else is converted with ``str()``, stripped of surrounding
            whitespace and compared case-insensitively.

    Returns:
        bool: ``True`` when the normalised text is ``'true'`` or ``'1'``,
        otherwise ``False``.
    """
    if isinstance(v, bool):
        # Already a boolean (for example an argparse default): nothing to parse.
        return v
    return str(v).strip().lower() in ('true', '1')

# Hyper-parameters and paths, declared as (flag, type, default) rows.
# Rows are registered top to bottom, so the parsed Namespace keeps this order.
_arg_table = [
    ('--device', str, 'cuda:0'),
    ('--seed', int, 42),
    ('--dropout', float, 0.1),
    ('--batchsize', int, 128),
    ('--gene_hdim', int, 64),
    ('--valid_ratio', float, 0.1),
    ('--test_ratio', float, 0.1),
    ('--patience', int, 20),
    ('--lr', float, 1e-4),
    ('--data_dir', str, './data/Transcriptome_toy/'),
    ('--mat_pretrainf', str, './ckpts/mat_pretrained_weights.pt'),
]

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _arg_table:
    parser.add_argument(_flag, type=_kind, default=_default)

# An explicit empty argument list yields the defaults instead of reading sys.argv.
args = parser.parse_args([])

In [3]:
# Run name: passed to Logger and, through logger.model_name, used in the checkpoint filename.
model_name = 'CSG2A_pretrain_test'

In [4]:
# Create the run logger, then write a separator line and a start banner.
# Logger is a project class (presumably provided by one of the utils star-imports);
# it is called like a function and exposes .date and .model_name.
logger = Logger(model_name)
logger('='*50)
logger(f'start training {logger.date}_{logger.model_name}')

[07:39:49] start training 240131_CSG2A_pretrain_test


In [5]:
# Record the full configuration in the log, then seed the run.
# set_seed is a project helper; which random generators it seeds is not visible here.
logger(args)
set_seed(args.seed,logger)

[07:39:49] Namespace(device='cuda:0', seed=42, dropout=0.1, batchsize=128, gene_hdim=64, valid_ratio=0.1, test_ratio=0.1, patience=20, lr=0.0001, data_dir='./data/Transcriptome_toy/', mat_pretrainf='./ckpts/mat_pretrained_weights.pt')
[07:39:49] random seed with 42


In [6]:
# Load the chemical inputs from the condition table, reading structures from the
# 'canonical_smiles' column (load_data_from_df is a project helper).
# os.path.join keeps the path valid whether or not --data_dir ends with a slash.
chemical_feat = load_data_from_df(os.path.join(args.data_dir, 'condition_table.csv'),
                                  smiles_column = 'canonical_smiles')

In [7]:
# Load the condition metadata and the two expression tables.
# os.path.join keeps the paths valid whether or not --data_dir ends with a slash.
# NOTE(review): the saved preview of gex_comp shows an 'Unnamed: 0' header. This may only be
# an artifact of how the output was rendered, but confirm the CSVs do not carry a saved index
# column, because every column of gex_DMSO is later counted as a gene.
condition_df = pd.read_csv(os.path.join(args.data_dir, 'condition_table.csv'))
gex_DMSO = pd.read_csv(os.path.join(args.data_dir, 'gex_DMSO.csv'))
gex_comp = pd.read_csv(os.path.join(args.data_dir, 'gex_comp.csv'))

In [8]:
# Preview the gex_comp table read from gex_comp.csv.
gex_comp

Unnamed: 0,DDR1,PAX8,RPS5,ABCF1,SPAG7,RHOA,RNPS1,SMNDC1,ATP6V0B,RPS6,...,P4HTM,SLC27A3,TBXA2R,RTN2,GFUS,PPARD,GNA11,WDTC1,PLSCR3,NPEPL1
0,0.659634,-0.903503,-0.580669,0.674749,-0.959752,-1.254497,1.197722,-0.388186,1.678402,-0.690094,...,0.821481,0.649951,0.466549,0.094080,0.929152,0.024015,0.398713,-2.429915,1.230694,1.087043
1,-0.524662,-1.023540,-0.271726,1.472793,-0.208879,-0.489726,-0.756635,-0.662601,0.227877,-0.409285,...,-0.156210,-0.029926,0.658519,-0.938777,-0.077835,0.050625,-0.066842,0.280632,-0.457598,0.755879
2,-0.448470,-0.727050,-0.811899,-1.164333,0.017085,0.001311,-0.045000,0.863134,1.418147,-0.044783,...,0.625126,0.012461,0.533003,-0.175997,0.578045,-1.149621,-0.046796,-0.525133,0.240329,-0.061210
3,0.824797,0.025431,-0.096312,-0.046504,0.210600,-0.696141,-0.530503,-0.796290,-0.361782,-0.056279,...,-0.080014,-1.100904,0.277633,1.750473,-0.199547,0.378748,-0.614527,-0.424776,0.492837,1.272370
4,2.075069,-0.155304,-1.843406,1.740102,-0.621561,-5.090036,-1.338589,1.908388,-0.400165,-0.960552,...,1.785844,3.662040,2.482540,2.262325,-2.178323,4.328634,-4.337384,-0.822997,-2.267418,4.080919
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,1.790350,0.459250,-0.363200,1.512900,-0.743400,-0.655950,-0.466650,0.797150,-1.821450,-0.000350,...,0.205750,0.708550,1.628100,1.175150,-0.758900,0.978050,-0.022400,-0.791700,-0.367250,1.837600
996,-0.721876,1.082948,0.279794,-0.595703,0.717329,0.310753,-0.230457,-0.384464,-0.199041,0.000000,...,-0.676923,-0.641107,0.125286,0.570545,0.215509,1.224879,0.023165,-0.220200,-0.099255,-0.061096
997,-0.320603,0.338305,-0.079068,-0.565042,-0.218407,0.058194,0.068526,0.066677,0.642542,0.000000,...,-0.581052,-0.268545,-0.499938,-0.482092,-0.021250,-0.719707,-0.191934,1.150662,-0.758694,-0.671634
998,0.824646,0.384784,0.211290,1.098780,0.100735,1.183465,0.662965,4.527091,1.008273,-0.000490,...,0.558384,-0.041820,1.988617,1.109972,-1.847901,1.202578,-1.801144,-2.210637,-1.776304,-2.893174


In [9]:
# Scale the two scalar conditions before handing them to the loaders.
dose_scaled = (condition_df['dose'] / 100).values  # NOTE(review): 100 is presumably the largest dose - confirm
time_scaled = (condition_df['time'] / 72).values   # scale by longest time: GDSC

# Build the train / validation / test loaders with the configured ratios and seed.
train_loader, valid_loader, test_loader = construct_loader(
    chemical_feat,
    gex_DMSO.values,
    gex_comp.values,
    dose_scaled,
    time_scaled,
    batch_size=args.batchsize,
    valid_ratio=args.valid_ratio,
    test_ratio=args.test_ratio,
    seed=args.seed,
)

In [10]:
# ppi adj processing
# Build a symmetric gene-by-gene adjacency matrix from the STRING edge list,
# starting from the identity so every gene is connected to itself.
genes = gex_DMSO.columns

# os.path.join keeps the path valid whether or not --data_dir ends with a slash.
STRING_df = pd.read_csv(os.path.join(args.data_dir, '..', 'STRING_edges.csv'))

# Keep only edges whose two endpoints are both present in `genes`.
# .copy() makes the filtered frame independent, so adding columns below does not
# write into a slice of the original frame (avoids SettingWithCopyWarning).
STRING_df = STRING_df[(STRING_df['source'].isin(genes)) & (STRING_df['target'].isin(genes))].copy()

# Translate gene names to column positions in one vectorised lookup.
STRING_df['source_idx'] = genes.get_indexer(STRING_df['source'])
STRING_df['target_idx'] = genes.get_indexer(STRING_df['target'])

ppi_adj = torch.eye(len(genes))

# Set both (target, source) and (source, target) so the matrix is symmetric.
source_idx = torch.as_tensor(STRING_df['source_idx'].to_numpy(), dtype=torch.long)
target_idx = torch.as_tensor(STRING_df['target_idx'].to_numpy(), dtype=torch.long)
ppi_adj[target_idx, source_idx] = 1
ppi_adj[source_idx, target_idx] = 1

In [12]:
# generate ckpts folder
# exist_ok=True creates the directory only when it is missing, without a separate existence check.
os.makedirs('ckpts', exist_ok=True)

# Checkpoint path for this run; EarlyStopper receives it together with the logger and the patience.
modelf = f'ckpts/{logger.date}_{logger.model_name}.pt'
early_stopper = EarlyStopper(path = modelf, printfunc = logger, patience = args.patience)

In [13]:
# Instantiate the network on the configured device.
# gex_dim is the number of columns in gex_DMSO; the PPI adjacency is moved to the same device as the model.
model = CSG2A_net(gex_dim = len(genes), hdim = args.gene_hdim, dropout = args.dropout, 
                  ppi_adj = ppi_adj.to(args.device), bias=False).to(args.device)

In [14]:
# Optionally initialise the MAT sub-module (model.CCE.MAT) from a pretrained checkpoint.
if args.mat_pretrainf is not None:
    # map_location makes loading independent of the device the checkpoint was saved on,
    # matching how the best checkpoint is reloaded for testing in this notebook.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    pretrained_state_dict = torch.load(args.mat_pretrainf, map_location=args.device)

    # state_dict() tensors share storage with the live parameters, so copy_ updates the model in place.
    model_state_dict = model.CCE.MAT.state_dict()
    logger(f'Loading pretrained MAT weights: {args.mat_pretrainf}')
    for name, param in pretrained_state_dict.items():
        # Entries whose name contains 'generator' are not transferred.
        if 'generator' in name:
            continue
        if isinstance(param, torch.nn.Parameter):
            param = param.data
        # A name missing from the model raises KeyError here.
        model_state_dict[name].copy_(param)
else:
    logger('No pretrained MAT weights loaded')

[07:39:52] Loading pretrained MAT weights: ./ckpts/mat_pretrained_weights.pt


In [15]:
# Adam over all model parameters with the configured learning rate; mean-squared-error objective.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()

In [19]:
# Train until the early stopper signals a stop; there is no fixed epoch limit.
# NOTE(review): `eval` is called with (model, loader, criterion, args), so it is a project
# function (presumably from a utils star-import) that shadows Python's builtin eval.
epoch = 1
while True:
    train_loss = train(model, train_loader, optimizer, criterion, epoch, args, logger)
    # Validation returns a loss and a second metric that is logged as valid_pcc.
    valid_loss, valid_pcc = eval(model, valid_loader, criterion, args)
    logger(f'train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}, valid_pcc: {valid_pcc:.4f}')
    # The stopper receives the validation loss and the model; it sets .early_stop when training should end.
    early_stopper(valid_loss, model)
    if early_stopper.early_stop:
        break
    epoch += 1

[07:40:42] - batch1/7 of epoch1, loss: 1.3418649435043335
[07:40:46] train_loss: 1.3619, valid_loss: 1.2475, valid_pcc: 0.0890
[07:40:46] Validation loss decreased (inf --> 1.2475).  Saving model ...
[07:40:50] train_loss: 1.3197, valid_loss: 1.2344, valid_pcc: 0.1199
[07:40:50] Validation loss decreased (1.2475 --> 1.2344).  Saving model ...
[07:41:04] train_loss: 1.3233, valid_loss: 1.2165, valid_pcc: 0.1499
[07:41:04] Validation loss decreased (1.2344 --> 1.2165).  Saving model ...
[07:41:19] train_loss: 1.3359, valid_loss: 1.1987, valid_pcc: 0.1751
[07:41:19] Validation loss decreased (1.2165 --> 1.1987).  Saving model ...
[07:41:36] train_loss: 1.2857, valid_loss: 1.1896, valid_pcc: 0.1901
[07:41:36] Validation loss decreased (1.1987 --> 1.1896).  Saving model ...
[07:41:50] train_loss: 1.2720, valid_loss: 1.1813, valid_pcc: 0.2033
[07:41:50] Validation loss decreased (1.1896 --> 1.1813).  Saving model ...
[07:42:05] train_loss: 1.2362, valid_loss: 1.1787, valid_pcc: 0.2106
[07:42

In [20]:
# load best model
# modelf is the checkpoint path given to the early stopper; map_location places the tensors on args.device.
model.load_state_dict(torch.load(modelf, map_location=args.device))

# Evaluate on the test loader with the same criterion and log both returned metrics.
test_loss, test_pcc = eval(model, test_loader, criterion, args)
logger(f'test_loss: {test_loss:.4f}, test_pcc: {test_pcc:.4f}')

[07:44:05] test_loss: 1.2753, test_pcc: 0.2284
