In [5]:
import argparse

import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.model_selection import KFold

In [6]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import GModel
from myutils import roc_auc, translate_result
from optimizer import Optimizer
from sampler import Sampler

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
class Args:
    """Minimal stand-in for an ``argparse.Namespace`` that is passed to ``load_data``.

    Both settings are keyword parameters whose defaults reproduce the original
    hard-coded configuration, so ``Args()`` behaves exactly as before while
    ``Args(data=...)`` / ``Args(device=...)`` allow switching without editing
    the class body.

    Parameters
    ----------
    device : str, default "cpu"
        Device string, ``"cpu"`` or ``"cuda:<index>"``.
    data : str, default "gdsc2"
        Dataset name handed to ``load_data``. This notebook uses ``"gdsc2"``;
        an earlier comment also mentioned ``"gdsc"``/``"ccle"``. The exact set of
        accepted names is defined by ``load_data`` (not visible here).
    """

    def __init__(self, device="cpu", data="gdsc2"):
        self.device = device  # "cpu" or "cuda:<index>"
        self.data = data  # dataset identifier consumed by load_data


# Build the run configuration, then load every input used by the
# cross-validation cell in a single call to load_data.
args = Args()
(
    res,
    drug_feature,
    exprs,
    mut,
    cna,
    null_mask,
    pos_num,
) = load_data(args)

load gdsc2


In [8]:
# K-fold cross-validation: every fold builds a fresh Sampler / GModel / Optimizer,
# trains it via opt(), and the per-fold labels and predictions are pooled.
epochs = []  # first value returned by opt() for each fold
true_frames = []  # translate_result(true_data) for each fold
predict_frames = []  # translate_result(predict_data) for each fold
k = 5
kfold = KFold(n_splits=k, shuffle=True, random_state=42)  # fixed seed -> reproducible folds
# Take the device from the shared configuration instead of repeating a literal,
# so data loading and training cannot silently use different devices.
device = args.device

# Folds are drawn over indices 0..pos_num-1. pos_num comes from load_data
# (presumably the number of positive pairs -- verify against load_data).
for train_index, test_index in kfold.split(np.arange(pos_num)):
    sampler = Sampler(res, train_index, test_index, null_mask)
    # Fixed hyperparameters; their meaning is defined by GModel (not visible here).
    model = GModel(
        adj_mat=sampler.train_data,
        gene=exprs,
        cna=cna,
        mutation=mut,
        sigma=2,
        k=11,
        iterates=3,
        feature_drug=drug_feature,
        n_hid1=192,
        n_hid2=36,
        alpha=5.74,
        device=device,
    )
    opt = Optimizer(
        model,
        sampler.train_data,
        sampler.test_data,
        sampler.test_mask,
        sampler.train_mask,
        roc_auc,
        lr=5e-4,
        epochs=1000,
        device=device,
    ).to(device)
    epoch, true_data, predict_data = opt()
    epochs.append(epoch)
    true_frames.append(translate_result(true_data))
    predict_frames.append(translate_result(predict_data))

# Concatenate once after the loop: calling pd.concat inside the loop re-copies the
# accumulated frame on every fold. Row order and index labels are the same as the
# previous incremental concatenation. Assumes translate_result returns a DataFrame.
true_datas = pd.concat(true_frames)
predict_datas = pd.concat(predict_frames)

epoch:   0 loss:0.876219 auc:0.6036
epoch:  20 loss:0.247919 auc:0.9490
epoch:  40 loss:0.183503 auc:0.9530
epoch:  60 loss:0.162410 auc:0.9561
epoch:  80 loss:0.152232 auc:0.9603
epoch: 100 loss:0.142288 auc:0.9654
epoch: 120 loss:0.136250 auc:0.9680
epoch: 140 loss:0.134553 auc:0.9688
epoch: 160 loss:0.131941 auc:0.9695
epoch: 180 loss:0.130492 auc:0.9701
epoch: 200 loss:0.129556 auc:0.9707
epoch: 220 loss:0.127514 auc:0.9713
epoch: 240 loss:0.127389 auc:0.9718
epoch: 260 loss:0.124547 auc:0.9727
epoch: 280 loss:0.136587 auc:0.9729
epoch: 300 loss:0.122504 auc:0.9731
epoch: 320 loss:0.120469 auc:0.9740
epoch: 340 loss:0.118593 auc:0.9744
epoch: 360 loss:0.117195 auc:0.9747
epoch: 380 loss:0.119901 auc:0.9749
epoch: 400 loss:0.115560 auc:0.9752
epoch: 420 loss:0.114478 auc:0.9753
epoch: 440 loss:0.113918 auc:0.9754
epoch: 460 loss:0.112774 auc:0.9761
epoch: 480 loss:0.111996 auc:0.9759
epoch: 500 loss:0.113120 auc:0.9751
epoch: 520 loss:0.111168 auc:0.9761
epoch: 540 loss:0.113101 auc

In [9]:
# Persist the pooled per-fold labels and predictions with a fresh 0..n-1 index.
# File names follow the configured dataset, so switching ``args.data`` cannot
# write results under another dataset's name (default "gdsc2" gives the same files).
true_datas.reset_index(drop=True).to_csv(f"true_{args.data}.csv")
predict_datas.reset_index(drop=True).to_csv(f"pred_{args.data}.csv")