In [1]:
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd

from utils.utils import *
from utils.utils_train import *
from utils.utils_data import load_data_from_df, construct_loader

from models.CSG2A_net import CSG2A_finetune

In [2]:
def str2bool(v):
    """Convert a command-line token into a bool (for use as argparse ``type=``).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string such as
            ``'true'``/``'false'``, ``'1'``/``'0'``, ``'yes'``/``'no'``.
            Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling. Rejecting unknown tokens prevents a typo from being
            silently interpreted as ``False``.
    """
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ('true', '1', 't', 'yes', 'y', 'on'):
        return True
    if token in ('false', '0', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')

# Options are declared on an ArgumentParser but parsed from an empty argv,
# so `args` always carries the defaults listed in the table below.
parser = argparse.ArgumentParser()

# (flag, type converter, default) -- registered in the order listed.
_OPTION_TABLE = (
    # runtime and model size
    ('--device', str, 'cuda:1'),
    ('--seed', int, 42),
    ('--dropout', float, 0.1),
    ('--batchsize', int, 128),
    ('--gene_hdim', int, 64),
    # data split
    ('--valid_ratio', float, 0.1),
    ('--test_ratio', float, 0.1),
    # early stopping
    ('--patience', int, 20),
    # learning-rate range
    ('--lr_init', float, 1e-4),
    ('--lr_final', float, 1e-5),
    # data location and pretrained checkpoints
    ('--data_dir', str, './data/Cell_viability_toy/'),
    ('--mat_pretrainf', str, None),
    ('--CSG2A_pretrainf', str, 'ckpts/CSG2A_pretrain_test.pt'),
    ('--Freeze_pretrained_params', str2bool, True),
    # fine-tuning head widths
    ('--finetune_hdim1', int, 512),
    ('--finetune_hdim2', int, 64),
)
for _flag, _caster, _default in _OPTION_TABLE:
    parser.add_argument(_flag, type=_caster, default=_default)

args = parser.parse_args([])

In [3]:
# Run identifier handed to Logger; the checkpoint filename is later built
# from logger.model_name.
model_name = 'CSG2A_finetune_GDSC_test'

In [4]:
# Create the run logger, then emit a separator line and a start-of-run banner.
logger = Logger(model_name)
logger('='*50)
logger(f'start training {logger.model_name}')

[08:30:30] start training CSG2A_finetune_GDSC_test


In [5]:
# Record the full configuration in the log, then apply the random seed through
# the project's set_seed helper (it receives the logger as well; the saved
# output shows it logging the seed value).
logger(args)
set_seed(args.seed,logger)

[08:30:30] Namespace(device='cuda:1', seed=42, dropout=0.1, batchsize=128, gene_hdim=64, valid_ratio=0.1, test_ratio=0.1, patience=20, lr_init=0.0001, lr_final=1e-05, data_dir='./data/Cell_viability_toy/', mat_pretrainf=None, CSG2A_pretrainf='ckpts/240131_CSG2A_pretrain_test.pt', Freeze_pretrained_params=True, finetune_hdim1=512, finetune_hdim2=64)
[08:30:30] random seed with 42


In [6]:
# Build the chemical features from label.csv, reading molecules from its
# 'canonical_smiles' column.
# os.path.join keeps the path valid whether or not data_dir ends with a separator.
chemical_feat = load_data_from_df(os.path.join(args.data_dir, 'label.csv'),
                                  smiles_column = 'canonical_smiles')

In [7]:
# Load the label table (its 'IC50' column is the regression target used when
# the loaders are built) and the gene-expression table gex_before.csv.
# os.path.join keeps the paths valid whether or not data_dir ends with a separator.
label_df = pd.read_csv(os.path.join(args.data_dir, 'label.csv'))
gex_before = pd.read_csv(os.path.join(args.data_dir, 'gex_before.csv'))

In [8]:
# Preview the expression table; pandas truncates the rendering of large frames.
gex_before

Unnamed: 0,DDR1,PAX8,RPS5,ABCF1,SPAG7,RHOA,RNPS1,SMNDC1,ATP6V0B,RPS6,...,P4HTM,SLC27A3,TBXA2R,RTN2,GFUS,PPARD,GNA11,WDTC1,PLSCR3,NPEPL1
0,-0.312002,-0.338365,-1.075805,0.670894,-0.723634,0.555828,-1.014549,-0.018791,0.754495,-0.563630,...,-0.335459,0.542933,2.744097,-0.461191,-0.444043,2.072120,-0.137510,-0.269562,-0.580998,-0.345085
1,0.058235,-0.414564,-0.800238,-0.078168,0.005980,-0.880170,1.071168,-1.330307,-1.760332,0.431784,...,1.499339,-0.224781,0.779428,2.140008,-1.558703,-1.591148,0.153446,-0.778635,0.497691,-1.011579
2,0.377200,0.422928,-0.507615,-0.122720,0.060485,0.985847,-0.232876,-0.240565,-0.198889,0.139608,...,-0.286461,-0.893090,1.820546,0.841572,0.112873,0.469826,-0.143019,-0.914002,0.951545,0.431429
3,-1.112858,-0.356872,0.920448,1.009038,-0.062330,-0.717800,0.539088,-1.252922,0.456988,0.656263,...,1.331777,-0.010417,-0.185400,-0.252697,0.511969,-0.402262,-1.937517,0.148372,0.212277,-0.579762
4,-1.244757,2.484428,-0.403377,-0.872755,-0.134601,0.430525,0.364445,0.516991,-0.388525,-0.083832,...,-0.734504,-0.902987,0.863945,0.366114,-0.085412,-0.142790,0.431350,-1.015167,-0.024406,-0.547580
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1223,0.451113,0.116617,-0.911532,0.439356,0.564247,-1.356667,0.938602,-0.407797,1.051078,-0.620631,...,0.549803,1.051863,-0.147007,1.048620,-1.246432,-1.110128,0.523925,0.946273,-0.226081,0.288084
1224,0.852975,-0.338365,-0.403882,0.435459,-0.076270,-0.214042,-1.148415,-1.010181,0.360228,0.202169,...,-0.347042,0.124023,0.586177,0.407158,-0.333841,0.781548,-0.515870,-0.426270,0.518706,-1.424627
1225,-1.368584,0.927844,1.199181,-1.130464,2.358451,0.097199,-0.036285,1.476893,0.526318,-0.099635,...,-0.331622,-0.091133,0.266691,-0.107206,0.117611,-0.591890,-0.836258,-1.571975,0.857795,1.165573
1226,-1.300526,-0.385298,0.782418,0.530788,0.536797,0.651323,0.000000,0.881306,-1.157269,1.016932,...,-0.873785,-0.052881,-0.325499,-1.436990,-0.681343,-0.201600,-2.320165,1.204619,-0.668314,1.147079


In [9]:
# Every sample gets the same constant dose and time covariates.
n_samples = len(gex_before)
dose_per_sample = np.array([0.1] * n_samples)  # 10 uM
time_per_sample = np.array([1] * n_samples)    # 72 hours

# Split into train / validation / test loaders with a seeded split.
loaders = construct_loader(
    chemical_feat,
    gex_before.values,
    label_df['IC50'].values,
    dose_per_sample,
    time_per_sample,
    batch_size=args.batchsize,
    valid_ratio=args.valid_ratio,
    test_ratio=args.test_ratio,
    seed=args.seed,
)
train_loader, valid_loader, test_loader = loaders

In [10]:
# PPI adjacency processing: build a symmetric 0/1 gene-gene matrix (with
# self-loops) from the STRING edge list, restricted to the expression columns.
genes = gex_before.columns

# STRING_edges.csv lives one level above data_dir; os.path.join keeps the path
# valid whether or not data_dir ends with a separator.
STRING_df = pd.read_csv(os.path.join(args.data_dir, '..', 'STRING_edges.csv'))

# Keep only edges whose two endpoints are both columns of gex_before. The
# explicit copy makes the column assignments below act on an independent frame
# instead of a slice (avoids SettingWithCopyWarning).
both_known = STRING_df['source'].isin(genes) & STRING_df['target'].isin(genes)
STRING_df = STRING_df[both_known].copy()

# Translate gene names into column positions in one vectorised lookup.
STRING_df['source_idx'] = genes.get_indexer(STRING_df['source'])
STRING_df['target_idx'] = genes.get_indexer(STRING_df['target'])

# Start from the identity (self-loops) and mark every edge in both directions.
ppi_adj = torch.eye(len(genes))
src_idx = torch.as_tensor(STRING_df['source_idx'].to_numpy(), dtype=torch.long)
tgt_idx = torch.as_tensor(STRING_df['target_idx'].to_numpy(), dtype=torch.long)
ppi_adj[tgt_idx, src_idx] = 1
ppi_adj[src_idx, tgt_idx] = 1

# Build Model

In [12]:
# Keyword arguments for the CSG2A sub-model; the adjacency matrix is moved to
# the configured device here.
CSG2A_params = dict(
    gex_dim=len(genes),
    hdim=args.gene_hdim,
    dropout=args.dropout,
    ppi_adj=ppi_adj.to(args.device),
    bias=False,
)

In [13]:
# Instantiate the fine-tuning network and place it on the configured device.
model = CSG2A_finetune(
    gex_dim=len(genes),
    CSG2A_params=CSG2A_params,
    finetune_hdim1=args.finetune_hdim1,
    finetune_hdim2=args.finetune_hdim2,
    dropout=args.dropout,
)
model = model.to(args.device)

In [14]:
def _copy_pretrained_weights(target_state, source_state, skip_substring=None):
    """Copy tensors from source_state into the same-named entries of target_state, in place.

    Entries whose name contains ``skip_substring`` (when given) are ignored.
    A name missing from ``target_state`` raises ``KeyError``.
    """
    for name, param in source_state.items():
        if skip_substring is not None and skip_substring in name:
            continue
        if isinstance(param, torch.nn.Parameter):
            param = param.data
        target_state[name].copy_(param)


# Initialise from a checkpoint. A full CSG2A checkpoint takes precedence; the
# MAT checkpoint is only consulted when no CSG2A checkpoint is configured.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
if args.CSG2A_pretrainf is not None:
    pretrained_state_dict = torch.load(args.CSG2A_pretrainf, map_location = args.device)
    model_state_dict = model.CSG2A.state_dict()
    logger(f'Loading pretrained CSG2A weights: {args.CSG2A_pretrainf}')
    _copy_pretrained_weights(model_state_dict, pretrained_state_dict)

elif args.mat_pretrainf is not None:
    pretrained_state_dict = torch.load(args.mat_pretrainf, map_location = args.device)
    model_state_dict = model.CSG2A.CCE.MAT.state_dict()
    logger(f'Loading pretrained MAT weights: {args.mat_pretrainf}')
    # Entries whose name contains 'generator' are not copied.
    _copy_pretrained_weights(model_state_dict, pretrained_state_dict,
                             skip_substring='generator')

else:
    logger('No pretrained weights loaded')

[08:30:31] Loading pretrained CSG2A weights: ckpts/240131_CSG2A_pretrain_test.pt


In [15]:
# Optionally exclude the pretrained CSG2A sub-model from gradient updates.
if not args.Freeze_pretrained_params:
    logger('Did not freeze CSG2A parameters')
else:
    logger('Freezing pretrained CSG2A parameters')
    for weight in model.CSG2A.parameters():
        weight.requires_grad = False

[08:30:31] Freezing pretrained CSG2A parameters


# Train & Evaluate model

In [None]:
# Make sure the checkpoint folder exists. exist_ok=True removes the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('ckpts', exist_ok=True)

# Checkpoint path for this run; the early stopper is given this path and
# reports through the logger.
modelf = f'ckpts/{logger.date}_{logger.model_name}.pt'
early_stopper = EarlyStopper(path = modelf, printfunc = logger, patience = args.patience)

In [16]:
# Adam over all model parameters with a small fixed weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr_init,
    weight_decay=1e-5,
)

# Cosine decay from lr_init towards lr_final over `patience` scheduler steps.
# NOTE(review): `verbose` may be deprecated in newer PyTorch releases --
# confirm against the installed version.
scheduler_options = dict(
    T_max=args.patience,
    eta_min=args.lr_final,
    last_epoch=-1,
    verbose=True,
)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                          **scheduler_options)

# Mean-squared-error loss.
criterion = nn.MSELoss()

Adjusting learning rate of group 0 to 1.0000e-04.


In [17]:
# Train until the early stopper signals a stop; there is no fixed epoch budget.
epoch = 1
stop_training = False
while not stop_training:
    train_loss = train(model, train_loader, optimizer, criterion, epoch, args, logger)

    # The scheduler only advances during the first `patience` epochs, and
    # only when lr_final is truthy.
    if args.lr_final and epoch <= args.patience:
        lr_scheduler.step()

    valid_loss, valid_pcc = eval(model, valid_loader, criterion, args)
    logger(f'train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}, valid_pcc: {valid_pcc:.4f}')

    # The early stopper receives the validation loss and the model, and
    # exposes its decision through `early_stop`.
    early_stopper(valid_loss, model)
    stop_training = early_stopper.early_stop
    if not stop_training:
        epoch += 1

[08:30:32] - batch1/8 of epoch1, loss: 18.159637451171875
Adjusting learning rate of group 0 to 9.9446e-05.
[08:30:33] train_loss: 14.6508, valid_loss: 14.9583, valid_pcc: 0.3883
[08:30:33] Validation loss decreased (inf --> 14.9583).  Saving model ...
Adjusting learning rate of group 0 to 9.7798e-05.
[08:30:34] train_loss: 12.8514, valid_loss: 14.4359, valid_pcc: 0.4794
[08:30:34] Validation loss decreased (14.9583 --> 14.4359).  Saving model ...
Adjusting learning rate of group 0 to 9.5095e-05.
[08:30:35] train_loss: 11.9911, valid_loss: 13.5835, valid_pcc: 0.5572
[08:30:35] Validation loss decreased (14.4359 --> 13.5835).  Saving model ...
Adjusting learning rate of group 0 to 9.1406e-05.
[08:30:36] train_loss: 11.1716, valid_loss: 12.8282, valid_pcc: 0.6067
[08:30:36] Validation loss decreased (13.5835 --> 12.8282).  Saving model ...
Adjusting learning rate of group 0 to 8.6820e-05.
[08:30:37] train_loss: 10.6858, valid_loss: 12.1921, valid_pcc: 0.6373
[08:30:37] Validation loss de

In [18]:
# Restore the weights saved at `modelf`, then score the held-out test split.
best_weights = torch.load(modelf, map_location=args.device)
model.load_state_dict(best_weights)

test_metrics = eval(model, test_loader, criterion, args)
test_loss, test_pcc = test_metrics
logger(f'test_loss: {test_loss:.4f}, test_pcc: {test_pcc:.4f}')

[08:35:50] test_loss: 7.0239, test_pcc: 0.7547
