In [1]:
import argparse
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Hyper-parameters and runtime settings consumed by the cells below.

    ``Args()`` with no arguments reproduces the notebook's default
    configuration; any individual setting can be overridden by keyword,
    e.g. ``Args(lr=1e-4, epochs=200)``.

    Attributes:
        device: ``torch.device`` forwarded to the model and the optimizer.
        data: Dataset identifier (default ``"nci"``). The original note
            also mentions "gdsc" and "ccle"; which names are accepted is
            decided by ``load_data`` — TODO confirm.
        lr: Learning rate.
        wd: Weight decay (L2 regularisation strength).
        layer_size: List with the output size of every layer.
        alpha: Scale balancing the GCN and NI parts (original author's note).
        gamma: Scale applied to the sigmoid (original author's note).
        epochs: Number of training epochs.
    """

    def __init__(
        self,
        device=None,
        data="nci",
        lr=0.001,
        wd=1e-5,
        layer_size=None,
        alpha=0.25,
        gamma=8,
        epochs=1000,
    ):
        # No explicit device: use CUDA when available, otherwise the CPU.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)  # accepts "cuda:N", "cpu" or a torch.device
        self.data = data
        self.lr = lr
        self.wd = wd
        # Always store a fresh list so instances never share (or mutate) a default.
        self.layer_size = [1024, 1024] if layer_size is None else list(layer_size)
        self.alpha = alpha
        self.gamma = gamma
        self.epochs = epochs


args = Args()

In [4]:
# Load everything the experiment needs for the dataset selected in `args`.
# `pos_num` is returned by the loader but not used in the cells shown here.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Axes of `res` to evaluate, one held-out index at a time:
#   0 -> cell, 1 -> drug. Append 1 to also run the drug-wise evaluation.
target_dim = [0]

# Totals of `res` along each axis; the evaluation loop compares these
# against a minimum before training on a given target.
cell_sum = np.sum(res, axis=1)
drug_sum = np.sum(res, axis=0)

load nci


In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
):
    """Train one model for a single held-out target and return its results.

    A ``NewSampler`` derives train/test data and masks from ``res_mat`` and
    ``null_mask`` for the given ``target_dim`` / ``target_index``. An
    ``nihgcn`` model is built on the training data, and an ``Optimizer`` is
    constructed and called to run the fit.

    Args:
        cell_exprs: Cell features forwarded to ``nihgcn``.
        drug_finger: Drug features forwarded to ``nihgcn``.
        res_mat: Response matrix handed to the sampler.
        null_mask: Mask handed to the sampler alongside ``res_mat``.
        target_dim: Axis of ``res_mat`` the target lives on (the notebook
            uses 0 for cell, 1 for drug).
        target_index: Index of the held-out target along ``target_dim``.
        evaluate_fun: Metric callable passed to the ``Optimizer``.
        args: Settings object; ``layer_size``, ``alpha``, ``gamma``,
            ``device``, ``lr``, ``wd`` and ``epochs`` are read from it.

    Returns:
        The ``(true_data, predict_data)`` pair produced by the optimizer call.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index)

    net = nihgcn(
        split.train_data,
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )

    # NOTE(review): the masks are passed positionally as (test_mask, train_mask);
    # confirm this matches the Optimizer signature.
    trainer = Optimizer(
        net,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )

    # Calling the optimizer runs training; it yields exactly two values.
    y_true, y_pred = trainer()
    return y_true, y_pred

In [None]:
n_kfold = 1  # number of training runs per held-out target
min_target_sum = 10  # targets whose total in `res` is below this are skipped

# Gather one frame per run and concatenate once after the loop: calling
# pd.concat on every iteration re-copies all previously collected rows.
true_frames = []
predict_frames = []
for dim in target_dim:
    # dim == 0 walks over cells, dim == 1 over drugs; pick the matching totals.
    target_sums = drug_sum if dim else cell_sum
    for target_index in tqdm(np.arange(res.shape[dim])):
        if target_sums[target_index] < min_target_sum:
            continue
        for _fold in range(n_kfold):
            true_data, predict_data = nihgcn_new(
                cell_exprs=exprs,
                drug_finger=drug_finger,
                res_mat=res,
                null_mask=null_mask,
                target_dim=dim,
                target_index=target_index,
                evaluate_fun=roc_auc,
                args=args,
            )
            # Record every run; accumulating outside this loop would silently
            # keep only the last run whenever n_kfold > 1.
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# pd.concat raises on an empty list, so fall back to an empty frame when
# every target was skipped. Assumes translate_result returns DataFrames.
true_data_s = (
    pd.concat(true_frames, ignore_index=True) if true_frames else pd.DataFrame()
)
predict_data_s = (
    pd.concat(predict_frames, ignore_index=True)
    if predict_frames
    else pd.DataFrame()
)

  0%|          | 0/60 [00:00<?, ?it/s]

epoch:   0 loss:0.702924 auc:0.4856
epoch:  20 loss:0.494176 auc:0.5104
epoch:  40 loss:0.404488 auc:0.5828
epoch:  60 loss:0.358112 auc:0.5828
epoch:  80 loss:0.332245 auc:0.5507
epoch: 100 loss:0.316429 auc:0.5771
epoch: 120 loss:0.303643 auc:0.5790
epoch: 140 loss:0.297601 auc:0.5946
epoch: 160 loss:0.291710 auc:0.5800
epoch: 180 loss:0.286714 auc:0.5515
epoch: 200 loss:0.284536 auc:0.5645
epoch: 220 loss:0.281118 auc:0.5598
epoch: 240 loss:0.277615 auc:0.5472
epoch: 260 loss:0.274232 auc:0.5692
epoch: 280 loss:0.277171 auc:0.5468
epoch: 300 loss:0.270490 auc:0.5641
epoch: 320 loss:0.268356 auc:0.5534
epoch: 340 loss:0.269411 auc:0.5462
epoch: 360 loss:0.268504 auc:0.5680
epoch: 380 loss:0.265549 auc:0.5273
epoch: 400 loss:0.261992 auc:0.5344
epoch: 420 loss:0.261335 auc:0.5251
epoch: 440 loss:0.258845 auc:0.5393
epoch: 460 loss:0.257846 auc:0.5319
epoch: 480 loss:0.257777 auc:0.4883
epoch: 500 loss:0.256361 auc:0.5156
epoch: 520 loss:0.257905 auc:0.4865
epoch: 540 loss:0.254992 auc

  2%|▏         | 1/60 [00:29<28:33, 29.03s/it]

Fit finished.
epoch:   0 loss:0.702065 auc:0.5346
epoch:  20 loss:0.494542 auc:0.6007
epoch:  40 loss:0.406006 auc:0.6116
epoch:  60 loss:0.357824 auc:0.5824
epoch:  80 loss:0.332256 auc:0.5863
epoch: 100 loss:0.321612 auc:0.5702
epoch: 120 loss:0.307360 auc:0.6007
epoch: 140 loss:0.296722 auc:0.6166
epoch: 160 loss:0.291787 auc:0.6115
epoch: 180 loss:0.286483 auc:0.6284
epoch: 200 loss:0.281385 auc:0.6380
epoch: 220 loss:0.278594 auc:0.6364
epoch: 240 loss:0.277260 auc:0.6312
epoch: 260 loss:0.276557 auc:0.6341
epoch: 280 loss:0.271583 auc:0.6417
epoch: 300 loss:0.270245 auc:0.6615
epoch: 320 loss:0.267662 auc:0.6402
epoch: 340 loss:0.267345 auc:0.6403
epoch: 360 loss:0.266338 auc:0.6271
epoch: 380 loss:0.263649 auc:0.6336
epoch: 400 loss:0.263488 auc:0.6625
epoch: 420 loss:0.259905 auc:0.6505
epoch: 440 loss:0.259842 auc:0.6664
epoch: 460 loss:0.259497 auc:0.6552
epoch: 480 loss:0.256126 auc:0.6559
epoch: 500 loss:0.257189 auc:0.6110
epoch: 520 loss:0.253669 auc:0.6637
epoch: 540 los

  3%|▎         | 2/60 [00:45<21:12, 21.93s/it]

Fit finished.
epoch:   0 loss:0.699565 auc:0.4961
epoch:  20 loss:0.493130 auc:0.4854
epoch:  40 loss:0.404072 auc:0.4678
epoch:  60 loss:0.358187 auc:0.4809
epoch:  80 loss:0.330035 auc:0.4818
epoch: 100 loss:0.316520 auc:0.4808
epoch: 120 loss:0.305527 auc:0.4904
epoch: 140 loss:0.297007 auc:0.5139
epoch: 160 loss:0.290552 auc:0.5245
epoch: 180 loss:0.286242 auc:0.5297
epoch: 200 loss:0.281844 auc:0.5250
epoch: 220 loss:0.279268 auc:0.5211
epoch: 240 loss:0.277384 auc:0.5237
epoch: 260 loss:0.275130 auc:0.5230
epoch: 280 loss:0.272650 auc:0.5266
epoch: 300 loss:0.272026 auc:0.5247
epoch: 320 loss:0.268768 auc:0.5275
epoch: 340 loss:0.265510 auc:0.5246
epoch: 360 loss:0.264807 auc:0.5132
epoch: 380 loss:0.263533 auc:0.5262
epoch: 400 loss:0.261750 auc:0.5296
epoch: 420 loss:0.261134 auc:0.5211
epoch: 440 loss:0.260563 auc:0.4903
epoch: 460 loss:0.258719 auc:0.5069
epoch: 480 loss:0.257181 auc:0.5161
epoch: 500 loss:0.256754 auc:0.5112
epoch: 520 loss:0.255195 auc:0.4632
epoch: 540 los

  5%|▌         | 3/60 [01:02<18:39, 19.65s/it]

Fit finished.
epoch:   0 loss:0.704495 auc:0.4625
epoch:  20 loss:0.495648 auc:0.6801
epoch:  40 loss:0.403521 auc:0.7492
epoch:  60 loss:0.357355 auc:0.7651
epoch:  80 loss:0.330630 auc:0.7726
epoch: 100 loss:0.314171 auc:0.7751
epoch: 120 loss:0.310693 auc:0.7654
epoch: 140 loss:0.295709 auc:0.7761
epoch: 160 loss:0.291377 auc:0.7695
epoch: 180 loss:0.286653 auc:0.7756
epoch: 200 loss:0.282527 auc:0.7780
epoch: 220 loss:0.282251 auc:0.7677
epoch: 240 loss:0.275805 auc:0.7751
epoch: 260 loss:0.273481 auc:0.7714
epoch: 280 loss:0.272161 auc:0.7731
epoch: 300 loss:0.270099 auc:0.7670
epoch: 320 loss:0.273245 auc:0.7538
epoch: 340 loss:0.267717 auc:0.7622
epoch: 360 loss:0.264760 auc:0.7636
epoch: 380 loss:0.264251 auc:0.7672
epoch: 400 loss:0.262366 auc:0.7630
epoch: 420 loss:0.262922 auc:0.7754
epoch: 440 loss:0.259298 auc:0.7770
epoch: 460 loss:0.257094 auc:0.7651
epoch: 480 loss:0.256142 auc:0.7658
epoch: 500 loss:0.255448 auc:0.7637
epoch: 520 loss:0.255429 auc:0.7596
epoch: 540 los

  7%|▋         | 4/60 [01:19<17:17, 18.53s/it]

Fit finished.
epoch:   0 loss:0.700004 auc:0.4951
epoch:  20 loss:0.492128 auc:0.5757
epoch:  40 loss:0.402097 auc:0.6367
epoch:  60 loss:0.356386 auc:0.6690
epoch:  80 loss:0.330435 auc:0.6719
epoch: 100 loss:0.315265 auc:0.6774
epoch: 120 loss:0.305474 auc:0.6717
epoch: 140 loss:0.298176 auc:0.6755
epoch: 160 loss:0.292972 auc:0.6591
epoch: 180 loss:0.287491 auc:0.6768
epoch: 200 loss:0.282340 auc:0.6896
epoch: 220 loss:0.279889 auc:0.6925
epoch: 240 loss:0.277080 auc:0.6910
epoch: 260 loss:0.274959 auc:0.6996
epoch: 280 loss:0.275444 auc:0.6951
epoch: 300 loss:0.274214 auc:0.7115
epoch: 320 loss:0.271232 auc:0.7114
epoch: 340 loss:0.268449 auc:0.7222
epoch: 360 loss:0.265845 auc:0.7261
epoch: 380 loss:0.262936 auc:0.7276
epoch: 400 loss:0.262570 auc:0.7268
epoch: 420 loss:0.262254 auc:0.7166
epoch: 440 loss:0.260689 auc:0.7203
epoch: 460 loss:0.258205 auc:0.7292
epoch: 480 loss:0.259003 auc:0.7187
epoch: 500 loss:0.256625 auc:0.7178
epoch: 520 loss:0.254517 auc:0.7153


In [None]:
# Write the collected ground-truth and prediction tables to CSV files in the
# working directory; the dataset name from `args` is part of each file name.
true_csv_path = f"new_cell_true_{args.data}.csv"
pred_csv_path = f"new_cell_pred_{args.data}.csv"

true_data_s.to_csv(true_csv_path)
predict_data_s.to_csv(pred_csv_path)