In [1]:
import argparse

In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import RandomSampler

In [3]:
class Args:
    """Hyper-parameters for the NIHGCN cross-validation run in this notebook.

    Every attribute has a default; pass keyword arguments to override any of
    them, e.g. ``Args(data="ccle", epochs=200)``. An unknown keyword raises
    ``TypeError`` so a typo cannot silently fall back to the default value.
    """

    def __init__(self, **overrides):
        self.device = "cuda:0"  # torch device string: "cuda:<index>" or "cpu"
        self.data = "gdsc1"  # dataset name; presumably read by load_data via args.data
        self.lr = 0.001  # learning rate passed to Optimizer
        self.wd = 1e-5  # weight decay (L2 regularization) passed to Optimizer
        self.layer_size = [1024, 1024]  # output size of every layer of nihgcn
        self.alpha = 0.25  # scale balancing the GCN and NI parts of nihgcn
        self.gamma = 8  # scale applied inside the sigmoid of nihgcn
        self.epochs = 1000  # training epochs per cross-validation fold
        for name, value in overrides.items():
            # Only allow overriding attributes that were defined above.
            if name not in vars(self):
                raise TypeError(f"Args got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


args = Args()

In [4]:
# Load the inputs consumed by the cross-validation cell.
(
    res,  # passed to RandomSampler as the data to split
    drug_finger,  # passed to nihgcn as drug_finger
    exprs,  # passed to nihgcn as cell_exprs
    null_mask,  # passed to RandomSampler alongside res
    pos_num,  # length of the index range that KFold splits into folds
) = load_data(args)

load gdsc1


In [7]:
# k-fold cross-validation: every fold builds a fresh sampler, model and
# optimizer, trains it, and keeps the held-out targets and predictions.
k = 5  # number of cross-validation folds
SEED = 42  # seed shared by the fold split and the global RNGs

# KFold is seeded through random_state; the global NumPy and torch generators
# are seeded as well so any NumPy-based sampling (presumably RandomSampler) and
# torch weight initialisation repeat on a fresh kernel.
# NOTE(review): this does not force deterministic cuDNN kernels on GPU.
np.random.seed(SEED)
torch.manual_seed(SEED)
kfold = KFold(n_splits=k, shuffle=True, random_state=SEED)

# Collect one frame per fold and concatenate once after the loop; growing a
# DataFrame with pd.concat inside the loop re-copies every earlier fold.
true_frames = []
predict_frames = []

fold_splits = kfold.split(np.arange(pos_num))
for fold, (train_index, test_index) in enumerate(fold_splits, start=1):
    # Epoch counters restart each fold, so label the training log.
    print(f"fold {fold}/{k}")
    sampler = RandomSampler(res, train_index, test_index, null_mask)
    model = nihgcn(
        adj_mat=sampler.train_data,
        cell_exprs=exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    ).to(args.device)
    opt = Optimizer(
        model,
        sampler.train_data,
        sampler.test_data,
        sampler.test_mask,
        sampler.train_mask,
        roc_auc,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    ).to(args.device)
    true_data, predict_data = opt()
    true_frames.append(translate_result(true_data))
    predict_frames.append(translate_result(predict_data))

true_datas = pd.concat(true_frames, ignore_index=True)
predict_datas = pd.concat(predict_frames, ignore_index=True)

epoch:   0 loss:0.691767 auc:0.5307
epoch:  20 loss:0.178715 auc:0.9590
epoch:  40 loss:0.162143 auc:0.9657
epoch:  60 loss:0.150074 auc:0.9695
epoch:  80 loss:0.140111 auc:0.9720
epoch: 100 loss:0.130622 auc:0.9740
epoch: 120 loss:0.121743 auc:0.9754
epoch: 140 loss:0.116201 auc:0.9754
epoch: 160 loss:0.110880 auc:0.9764
epoch: 180 loss:0.107448 auc:0.9768
epoch: 200 loss:0.106403 auc:0.9764
epoch: 220 loss:0.103569 auc:0.9771
epoch: 240 loss:0.102607 auc:0.9770
epoch: 260 loss:0.102776 auc:0.9764
epoch: 280 loss:0.100249 auc:0.9769
epoch: 300 loss:0.099409 auc:0.9771
epoch: 320 loss:0.099675 auc:0.9769
epoch: 340 loss:0.098431 auc:0.9770
epoch: 360 loss:0.098401 auc:0.9770
epoch: 380 loss:0.097726 auc:0.9770
epoch: 400 loss:0.098069 auc:0.9770
epoch: 420 loss:0.097509 auc:0.9769
epoch: 440 loss:0.096910 auc:0.9770
epoch: 460 loss:0.098060 auc:0.9766
epoch: 480 loss:0.096675 auc:0.9769
epoch: 500 loss:0.097106 auc:0.9768
epoch: 520 loss:0.096412 auc:0.9769
epoch: 540 loss:0.098140 auc

In [8]:
# Persist the concatenated per-fold targets and predictions. File names follow
# the configured dataset so switching args.data does not overwrite (or
# mislabel) results from another dataset.
true_datas.to_csv(f"true_{args.data}.csv")
predict_datas.to_csv(f"pred_{args.data}.csv")