In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [10]:
class Args:
    """Run configuration shared by the data loader, model and trainer.

    This is a plain attribute container standing in for command-line
    arguments. An instance is passed to ``load_data`` and to
    ``nihgcn_new``, which reads ``layer_size``, ``alpha``, ``gamma``,
    ``device``, ``lr``, ``wd`` and ``epochs`` from it. ``data`` is also
    embedded in the names of the CSV files written at the end of the
    notebook.
    """

    def __init__(self):
        # Run on the GPU when torch reports one, otherwise on the CPU.
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(backend)

        # Name of the dataset to load; this notebook runs on "nci".
        self.data = "nci"

        # Learning rate.
        self.lr = 0.001
        # Weight decay (L2 regularisation strength).
        self.wd = 1e-5

        # Output size of every layer.
        self.layer_size = [1024, 1024]
        # Scale balancing the GCN and NI parts of the model.
        self.alpha = 0.25
        # Scale applied to the sigmoid.
        self.gamma = 8

        # Number of training epochs.
        self.epochs = 10


args = Args()

In [4]:
# Load everything the experiment needs for the dataset named in `args`.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of the response matrix along each axis. Axis 0 of `res` indexes
# cells and axis 1 indexes drugs, so summing over axis 0 gives one total
# per drug and summing over axis 1 gives one total per cell.
drug_sum = res.sum(axis=0)
cell_sum = res.sum(axis=1)

# Axes of `res` to run the experiment over: 0 = cell, 1 = drug.
# Only the cell axis is enabled; add 1 to the list to include drugs too.
target_dim = [0]

load nci


In [5]:
# Inspect the per-drug totals (sum of `res` over axis 0).
drug_sum

array([38., 35., 34., ..., 23., 24., 25.], dtype=float32)

In [6]:
# Inspect the per-cell totals (sum of `res` over axis 1).
cell_sum

array([456., 593., 453., 532., 323., 567., 618., 499., 383., 172., 239.,
       654., 345., 641., 452., 586., 412., 528., 432., 434., 499., 704.,
       497., 408., 715., 258., 466.,  82., 620., 252., 355., 445., 241.,
       613., 576., 318., 221., 204., 360., 356., 444., 592., 375., 586.,
       615., 297., 223., 539., 389., 486., 280., 666., 701., 503., 451.,
       328., 525., 252., 622., 554.], dtype=float32)

In [11]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed
):
    """Fit one model for a single target and return its evaluation pair.

    The function chains three project components:

    1. ``NewSampler`` builds the train/test data and masks from
       ``res_mat`` and ``null_mask`` for the given ``target_dim``,
       ``target_index`` and ``seed``.
    2. ``nihgcn`` is constructed on the sampler's training data together
       with the cell and drug features and the model settings in ``args``.
    3. ``Optimizer`` wraps the model, the sampler outputs and
       ``evaluate_fun``; calling it yields the result pair.

    Args:
        cell_exprs: Cell features forwarded to ``nihgcn``.
        drug_finger: Drug features forwarded to ``nihgcn``.
        res_mat: Response matrix forwarded to ``NewSampler``.
        null_mask: Mask forwarded to ``NewSampler``.
        target_dim: Axis of ``res_mat`` the target lives on.
        target_index: Index of the target along ``target_dim``.
        evaluate_fun: Metric callable forwarded to ``Optimizer``.
        args: Object exposing ``layer_size``, ``alpha``, ``gamma``,
            ``device``, ``lr``, ``wd`` and ``epochs``.
        seed: Seed forwarded to ``NewSampler``.

    Returns:
        The two-element tuple ``(true_data, predict_data)`` produced by
        calling the ``Optimizer`` instance.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Model hyper-parameters and features, all passed by keyword.
    model_kwargs = dict(
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    net = nihgcn(split.train_data, **model_kwargs)

    # Training settings, all passed by keyword.
    fit_kwargs = dict(
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )
    # Positional order matters here: the test mask comes before the train mask.
    trainer = Optimizer(
        net,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        **fit_kwargs,
    )

    # Calling the trainer runs the fit and yields exactly two values.
    observed, predicted = trainer()
    return observed, predicted

In [12]:
# Number of training runs per target.
# NOTE(review): every run for a target receives the same `seed`, and only the
# result of the last run is kept below - confirm this is intended before
# raising n_kfold above 1.
n_kfold = 1

# A target is evaluated only if its total in `res` reaches this threshold.
MIN_TARGET_SUM = 10

# Collect one translated result per evaluated target and concatenate once at
# the end. Growing a DataFrame with pd.concat inside the loop re-copies every
# earlier row on each iteration.
# Assumes translate_result returns a DataFrame - TODO confirm in myutils.
true_frames = []
predict_frames = []

for dim in target_dim:
    # dim == 0 walks the cell axis of `res`; any other value walks the drug
    # axis. The matching totals are picked once per axis, not per target.
    target_sum = drug_sum if dim else cell_sum

    # `seed` is the enumerate counter over np.arange, so it always equals
    # `target_index`.
    for seed, target_index in enumerate(tqdm(np.arange(res.shape[dim]))):
        if target_sum[target_index] < MIN_TARGET_SUM:
            continue

        for fold in range(n_kfold):
            true_data, predict_data = nihgcn_new(
                cell_exprs=exprs,
                drug_finger=drug_finger,
                res_mat=res,
                null_mask=null_mask,
                target_dim=dim,
                target_index=target_index,
                evaluate_fun=roc_auc,
                args=args,
                seed=seed,
            )

        true_frames.append(translate_result(true_data))
        predict_frames.append(translate_result(predict_data))

# pd.concat raises ValueError on an empty list, so fall back to an empty
# frame when every target was skipped.
true_data_s = (
    pd.concat(true_frames, ignore_index=True) if true_frames else pd.DataFrame()
)
predict_data_s = (
    pd.concat(predict_frames, ignore_index=True)
    if predict_frames
    else pd.DataFrame()
)

  2%|▏         | 1/60 [00:00<00:19,  3.09it/s]

epoch:   0 loss:0.699949 auc:0.4734
Fit finished.


  3%|▎         | 2/60 [00:00<00:18,  3.11it/s]

epoch:   0 loss:0.702829 auc:0.5612
Fit finished.


  5%|▌         | 3/60 [00:01<00:21,  2.61it/s]

epoch:   0 loss:0.701169 auc:0.5483
Fit finished.


  7%|▋         | 4/60 [00:01<00:22,  2.48it/s]

epoch:   0 loss:0.703291 auc:0.4655
Fit finished.


  8%|▊         | 5/60 [00:02<00:24,  2.27it/s]

epoch:   0 loss:0.703208 auc:0.4805
Fit finished.


 10%|█         | 6/60 [00:02<00:22,  2.45it/s]

epoch:   0 loss:0.698892 auc:0.5115
Fit finished.


 12%|█▏        | 7/60 [00:02<00:20,  2.65it/s]

epoch:   0 loss:0.702504 auc:0.5509
Fit finished.


 13%|█▎        | 8/60 [00:03<00:18,  2.79it/s]

epoch:   0 loss:0.700220 auc:0.4700
Fit finished.


 15%|█▌        | 9/60 [00:03<00:17,  2.89it/s]

epoch:   0 loss:0.698411 auc:0.5536
Fit finished.


 17%|█▋        | 10/60 [00:03<00:16,  2.95it/s]

epoch:   0 loss:0.700790 auc:0.4765
Fit finished.


 18%|█▊        | 11/60 [00:03<00:16,  2.99it/s]

epoch:   0 loss:0.699751 auc:0.4431
Fit finished.


 20%|██        | 12/60 [00:04<00:17,  2.80it/s]

epoch:   0 loss:0.703000 auc:0.5206
Fit finished.


 22%|██▏       | 13/60 [00:04<00:16,  2.90it/s]

epoch:   0 loss:0.699796 auc:0.5476
Fit finished.


 23%|██▎       | 14/60 [00:05<00:17,  2.63it/s]

epoch:   0 loss:0.699736 auc:0.4758
Fit finished.


 25%|██▌       | 15/60 [00:05<00:16,  2.66it/s]

epoch:   0 loss:0.701971 auc:0.5292
Fit finished.
epoch:   0 loss:0.700045 auc:0.5055
Fit finished.


 28%|██▊       | 17/60 [00:06<00:16,  2.55it/s]

epoch:   0 loss:0.703530 auc:0.5306
Fit finished.


 30%|███       | 18/60 [00:06<00:15,  2.71it/s]

epoch:   0 loss:0.698117 auc:0.5409
Fit finished.


 32%|███▏      | 19/60 [00:07<00:14,  2.83it/s]

epoch:   0 loss:0.697627 auc:0.5097
Fit finished.


 33%|███▎      | 20/60 [00:07<00:13,  2.92it/s]

epoch:   0 loss:0.700100 auc:0.5449
Fit finished.


 35%|███▌      | 21/60 [00:07<00:13,  2.98it/s]

epoch:   0 loss:0.699539 auc:0.5231
Fit finished.


 37%|███▋      | 22/60 [00:07<00:12,  3.05it/s]

epoch:   0 loss:0.698798 auc:0.4893
Fit finished.


 38%|███▊      | 23/60 [00:08<00:11,  3.11it/s]

epoch:   0 loss:0.699165 auc:0.5288
Fit finished.


 40%|████      | 24/60 [00:08<00:11,  3.12it/s]

epoch:   0 loss:0.700981 auc:0.5563
Fit finished.


 42%|████▏     | 25/60 [00:08<00:11,  3.09it/s]

epoch:   0 loss:0.702652 auc:0.5383
Fit finished.


 43%|████▎     | 26/60 [00:09<00:12,  2.75it/s]

epoch:   0 loss:0.703327 auc:0.4973
Fit finished.


 45%|████▌     | 27/60 [00:09<00:14,  2.31it/s]

epoch:   0 loss:0.699150 auc:0.5338
Fit finished.


 47%|████▋     | 28/60 [00:10<00:12,  2.50it/s]

epoch:   0 loss:0.703100 auc:0.4297
Fit finished.


 48%|████▊     | 29/60 [00:10<00:11,  2.67it/s]

epoch:   0 loss:0.702883 auc:0.5329
Fit finished.


 50%|█████     | 30/60 [00:10<00:10,  2.80it/s]

epoch:   0 loss:0.701304 auc:0.5276
Fit finished.


 52%|█████▏    | 31/60 [00:11<00:10,  2.77it/s]

epoch:   0 loss:0.702848 auc:0.4704
Fit finished.


 53%|█████▎    | 32/60 [00:11<00:09,  2.87it/s]

epoch:   0 loss:0.699564 auc:0.5174
Fit finished.


 55%|█████▌    | 33/60 [00:11<00:09,  2.96it/s]

epoch:   0 loss:0.699627 auc:0.5313
Fit finished.


 57%|█████▋    | 34/60 [00:12<00:08,  3.02it/s]

epoch:   0 loss:0.702314 auc:0.5045
Fit finished.


 58%|█████▊    | 35/60 [00:12<00:09,  2.69it/s]

epoch:   0 loss:0.698395 auc:0.4821
Fit finished.


 60%|██████    | 36/60 [00:13<00:08,  2.82it/s]

epoch:   0 loss:0.704409 auc:0.5613
Fit finished.
epoch:   0 loss:0.698394 auc:0.4270


 62%|██████▏   | 37/60 [00:13<00:09,  2.51it/s]

Fit finished.


 63%|██████▎   | 38/60 [00:13<00:08,  2.69it/s]

epoch:   0 loss:0.700943 auc:0.4447
Fit finished.


 65%|██████▌   | 39/60 [00:14<00:07,  2.82it/s]

epoch:   0 loss:0.702217 auc:0.4761
Fit finished.


 67%|██████▋   | 40/60 [00:14<00:06,  2.91it/s]

epoch:   0 loss:0.699293 auc:0.5071
Fit finished.


 68%|██████▊   | 41/60 [00:14<00:06,  2.98it/s]

epoch:   0 loss:0.699145 auc:0.5084
Fit finished.


 70%|███████   | 42/60 [00:15<00:06,  2.72it/s]

epoch:   0 loss:0.703877 auc:0.5335
Fit finished.


 72%|███████▏  | 43/60 [00:15<00:06,  2.83it/s]

epoch:   0 loss:0.703917 auc:0.5832
Fit finished.


 73%|███████▎  | 44/60 [00:15<00:05,  2.81it/s]

epoch:   0 loss:0.702709 auc:0.4976
Fit finished.


 75%|███████▌  | 45/60 [00:16<00:05,  2.74it/s]

epoch:   0 loss:0.705776 auc:0.5451
Fit finished.


 77%|███████▋  | 46/60 [00:16<00:04,  2.82it/s]

epoch:   0 loss:0.700485 auc:0.5373
Fit finished.


 78%|███████▊  | 47/60 [00:17<00:04,  2.66it/s]

epoch:   0 loss:0.702660 auc:0.5587
Fit finished.


 80%|████████  | 48/60 [00:17<00:04,  2.43it/s]

epoch:   0 loss:0.701861 auc:0.5130
Fit finished.


 82%|████████▏ | 49/60 [00:17<00:04,  2.36it/s]

epoch:   0 loss:0.702447 auc:0.5387
Fit finished.


 83%|████████▎ | 50/60 [00:18<00:03,  2.66it/s]

epoch:   0 loss:0.699252 auc:0.4880
Fit finished.


 85%|████████▌ | 51/60 [00:18<00:02,  3.07it/s]

epoch:   0 loss:0.700002 auc:0.4847
Fit finished.


 87%|████████▋ | 52/60 [00:18<00:02,  3.44it/s]

epoch:   0 loss:0.701480 auc:0.5265
Fit finished.


 88%|████████▊ | 53/60 [00:18<00:01,  3.76it/s]

epoch:   0 loss:0.700112 auc:0.4654
Fit finished.


 90%|█████████ | 54/60 [00:19<00:01,  4.01it/s]

epoch:   0 loss:0.704135 auc:0.4959
Fit finished.


 92%|█████████▏| 55/60 [00:19<00:01,  4.19it/s]

epoch:   0 loss:0.701705 auc:0.4952
Fit finished.


 93%|█████████▎| 56/60 [00:19<00:00,  4.29it/s]

epoch:   0 loss:0.699987 auc:0.5741
Fit finished.


 95%|█████████▌| 57/60 [00:19<00:00,  4.43it/s]

epoch:   0 loss:0.699343 auc:0.4964
Fit finished.


 97%|█████████▋| 58/60 [00:19<00:00,  4.50it/s]

epoch:   0 loss:0.702465 auc:0.4839
Fit finished.


 98%|█████████▊| 59/60 [00:20<00:00,  4.45it/s]

epoch:   0 loss:0.703517 auc:0.4030
Fit finished.


100%|██████████| 60/60 [00:20<00:00,  2.93it/s]

epoch:   0 loss:0.701695 auc:0.5129
Fit finished.





In [13]:
# Write both result tables to CSV files in the working directory; the
# dataset name from `args.data` is embedded in each file name.
true_csv_path = f"new_cell_true_{args.data}.csv"
pred_csv_path = f"new_cell_pred_{args.data}.csv"

true_data_s.to_csv(true_csv_path)
predict_data_s.to_csv(pred_csv_path)

In [None]:
# Display the accumulated `true_data_s` table (the same frame that was
# written to the "new_cell_true_*.csv" file).
true_data_s