In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Run settings and hyper-parameters for the NIHGCN held-out-target experiment.

    Every attribute has a default. Any of them can be overridden by keyword,
    e.g. ``Args(data="gdsc", epochs=200)``. Unknown keywords raise
    ``TypeError`` so a misspelled setting is not silently ignored.
    """

    def __init__(self, **overrides):
        # Device to train on: "cuda" when available, otherwise "cpu".
        # Pass e.g. device=torch.device("cuda:1") to pin a specific GPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.data = "nci"  # dataset name given to load_data (e.g. "gdsc", "ccle", "nci")
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularization)
        self.layer_size = [1024, 1024]  # output size of each layer
        self.alpha = 0.25  # scale balancing the GCN and NI components
        self.gamma = 8  # scale factor for the sigmoid
        self.epochs = 1000  # number of training epochs

        # Apply caller overrides only to settings that already exist above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


args = Args()

In [4]:
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of the response matrix: column sums (axis=0, one value per drug)
# and row sums (axis=1, one value per cell). process_iteration compares
# these against a threshold to skip sparsely populated targets.
drug_sum = np.sum(res, axis=0)
cell_sum = np.sum(res, axis=1)

# Axes of `res` to evaluate as held-out targets: 0 = cells, 1 = drugs.
# Only the drug axis is evaluated here; add 0 to also hold out cells.
target_dim = [1]

load nci


In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Build a held-out split for one target, train NIHGCN on it, and evaluate.

    A ``NewSampler`` splits ``res_mat`` for the entity ``target_index`` along
    axis ``target_dim``. An ``nihgcn`` model is built on the training part
    and an ``Optimizer`` configured from ``args`` is called to produce results.

    Args:
        cell_exprs: cell features forwarded to ``nihgcn``.
        drug_finger: drug fingerprints forwarded to ``nihgcn``.
        res_mat: response matrix handed to ``NewSampler``.
        null_mask: mask handed to ``NewSampler`` (presumably marks missing
            entries -- verify against the sampler).
        target_dim: axis of ``res_mat`` on which the held-out target lies.
        target_index: index of the held-out target along ``target_dim``.
        evaluate_fun: metric passed to ``Optimizer`` (``roc_auc`` in this notebook).
        args: settings object; reads layer_size, alpha, gamma, device, lr,
            wd and epochs.
        seed: seed forwarded to ``NewSampler``.

    Returns:
        tuple: ``(true_data, predict_data)`` as yielded by calling the Optimizer.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    model_kwargs = dict(
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    net = nihgcn(split.train_data, **model_kwargs)

    # NOTE(review): test_mask is passed before train_mask here -- confirm this
    # matches Optimizer's positional parameter order.
    trainer = Optimizer(
        net,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )

    # Unpacking enforces that the Optimizer yields exactly two items.
    true_data, predict_data = trainer()
    return true_data, predict_data

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

# n_kfold: number of repeated train/evaluate runs per held-out target.
# n_jobs:  number of parallel joblib workers (threads) used for the run.
n_kfold, n_jobs = 1, 5


def process_iteration(dim, target_index, seed, args, min_count=10):
    """Run the held-out evaluation for a single target (one drug or one cell).

    Parameters
    ----------
    dim : int
        Axis of ``res`` being held out. A falsy value (``0``) selects a cell,
        checked against ``cell_sum``. A truthy value (``1``) selects a drug,
        checked against ``drug_sum``.
    target_index : int
        Index along ``dim`` of the target to hold out.
    seed : int
        Base seed. Fold ``fold`` receives ``seed * n_kfold + fold``. With
        ``n_kfold == 1`` this is exactly ``seed``.
    args : Args
        Settings object forwarded to ``nihgcn_new``.
    min_count : int, optional
        Targets whose total in ``cell_sum``/``drug_sum`` is below this value
        are skipped. Defaults to 10, the previously hard-coded threshold.

    Returns
    -------
    list of tuple or None
        One ``(true_data, predict_data)`` pair per fold, or ``None`` when the
        target is skipped.
    """
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < min_count:
        # Return a plain None, not a (None, None) tuple. The caller tests
        # `fold_results is None`; a tuple would slip past that check and crash
        # when unpacking its None elements as (true_data, predict_data).
        return None

    fold_results = []
    for fold in range(n_kfold):
        # Give each fold its own seed. Previously every fold reused `seed`,
        # so repeated folds got identical sampler seeds.
        fold_seed = seed * n_kfold + fold
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=fold_seed,
        )
        fold_results.append((true_data, predict_data))

    return fold_results


# Run the evaluation in parallel.
# NOTE(review): results are kept only in memory until the CSV export cell runs.
# An interrupted run (KeyboardInterrupt) loses every completed target, so
# consider checkpointing results per target for multi-hour runs.
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build every task up front; the seed is simply the target's position.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution with a progress bar. tqdm wraps the task generator,
    # so the bar advances as tasks are dispatched, not as they finish.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Collect results. Skipped targets come back as None, or as (None, None)
    # from older versions of process_iteration; both must be skipped rather
    # than unpacked.
    for fold_results in results:
        if not fold_results or all(item is None for item in fold_results):
            continue
        for true_data, predict_data in fold_results:
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# Concatenate once instead of growing the frames inside the loop (which is
# quadratic). The leading empty DataFrame mirrors the original accumulator,
# so the result is always a DataFrame, even when nothing was collected.
true_data_s = pd.concat([pd.DataFrame(), *true_frames], ignore_index=True)
predict_data_s = pd.concat([pd.DataFrame(), *predict_frames], ignore_index=True)

Processing dim 1:   0%|                                                        | 0/1005 [00:00<?, ?it/s]

epoch:   0 loss:0.700431 auc:0.5034
epoch:   0 loss:0.699659 auc:0.5074
epoch:   0 loss:0.702862 auc:0.5325
epoch:   0 loss:0.702356 auc:0.6352
epoch:   0 loss:0.700747 auc:0.4793
epoch:  20 loss:0.496713 auc:0.7474
epoch:  20 loss:0.495762 auc:0.6914
epoch:  20 loss:0.493439 auc:0.7274
epoch:  20 loss:0.499161 auc:0.8210
epoch:  20 loss:0.499788 auc:0.8047
epoch:  40 loss:0.406495 auc:0.7817
epoch:  40 loss:0.405870 auc:0.7452
epoch:  40 loss:0.408184 auc:0.7352
epoch:  40 loss:0.402569 auc:0.7998
epoch:  40 loss:0.407719 auc:0.8126
epoch:  60 loss:0.356062 auc:0.8066
epoch:  60 loss:0.360415 auc:0.8183
epoch:  60 loss:0.362171 auc:0.7488
epoch:  60 loss:0.361891 auc:0.7515
epoch:  60 loss:0.361250 auc:0.7189
epoch:  80 loss:0.332291 auc:0.7201
epoch:  80 loss:0.334286 auc:0.7886
epoch:  80 loss:0.331467 auc:0.7817
epoch:  80 loss:0.338685 auc:0.7101
epoch:  80 loss:0.334664 auc:0.7115
epoch: 100 loss:0.317021 auc:0.7165
epoch: 100 loss:0.316727 auc:0.8023
epoch: 100 loss:0.317534 auc

Processing dim 1:   1%|▍                                          | 10/1005 [18:06<30:00:58, 108.60s/it]

Fit finished.
epoch: 980 loss:0.241333 auc:0.5089
epoch:   0 loss:0.699013 auc:0.3893
Fit finished.
epoch:   0 loss:0.702149 auc:0.3762
Fit finished.
Fit finished.
epoch:   0 loss:0.702765 auc:0.4082
epoch:   0 loss:0.701891 auc:0.6133
Fit finished.
epoch:   0 loss:0.700711 auc:0.4972
epoch:  20 loss:0.496643 auc:0.4720
epoch:  20 loss:0.489977 auc:0.8469
epoch:  20 loss:0.489109 auc:0.5556
epoch:  20 loss:0.497614 auc:0.7197
epoch:  20 loss:0.493986 auc:0.8715
epoch:  40 loss:0.405228 auc:0.4953
epoch:  40 loss:0.399464 auc:0.4386
epoch:  40 loss:0.401455 auc:0.5167
epoch:  40 loss:0.406684 auc:0.3003
epoch:  40 loss:0.402096 auc:0.6975
epoch:  60 loss:0.358532 auc:0.5256
epoch:  60 loss:0.354703 auc:0.4556
epoch:  60 loss:0.357064 auc:0.5056
epoch:  60 loss:0.359330 auc:0.2770
epoch:  60 loss:0.356283 auc:0.7297
epoch:  80 loss:0.332499 auc:0.5501
epoch:  80 loss:0.330204 auc:0.5104
epoch:  80 loss:0.330058 auc:0.4789
epoch:  80 loss:0.336371 auc:0.2992
epoch:  80 loss:0.337277 auc:0

Processing dim 1:   1%|▋                                          | 15/1005 [40:02<47:35:11, 173.04s/it]

Fit finished.
epoch:   0 loss:0.699396 auc:0.6780
epoch: 980 loss:0.243332 auc:0.8147
Fit finished.
epoch:   0 loss:0.702708 auc:0.3025
Fit finished.
Fit finished.
epoch:   0 loss:0.699598 auc:0.6580
epoch:   0 loss:0.699544 auc:0.5393
Fit finished.
epoch:  20 loss:0.500498 auc:0.9297
epoch:   0 loss:0.702056 auc:0.5565
epoch:  20 loss:0.488066 auc:0.9527
epoch:  20 loss:0.491881 auc:0.9913
epoch:  20 loss:0.492624 auc:0.7169
epoch:  40 loss:0.409456 auc:0.9569
epoch:  20 loss:0.497412 auc:0.6159
epoch:  40 loss:0.400890 auc:0.9641
epoch:  40 loss:0.402561 auc:0.9878
epoch:  40 loss:0.401404 auc:0.6529
epoch:  60 loss:0.364105 auc:0.9388
epoch:  40 loss:0.406764 auc:0.5755
epoch:  60 loss:0.354966 auc:0.9603
epoch:  60 loss:0.356987 auc:0.9878
epoch:  60 loss:0.362724 auc:0.6281
epoch:  80 loss:0.338245 auc:0.9297
epoch:  60 loss:0.365251 auc:0.5886
epoch:  80 loss:0.329658 auc:0.9565
epoch:  80 loss:0.333861 auc:0.9688
epoch:  80 loss:0.328947 auc:0.6529
epoch: 100 loss:0.320505 auc:0

Processing dim 1:   2%|▊                                        | 20/1005 [1:00:49<55:08:51, 201.55s/it]

epoch: 980 loss:0.240925 auc:0.3722
Fit finished.
epoch:   0 loss:0.700055 auc:0.4605
Fit finished.
epoch:   0 loss:0.703550 auc:0.4688
Fit finished.
Fit finished.
epoch:   0 loss:0.698025 auc:0.4600
epoch:   0 loss:0.697468 auc:0.4059
Fit finished.
epoch:   0 loss:0.699852 auc:0.5975
epoch:  20 loss:0.492204 auc:0.6963
epoch:  20 loss:0.492373 auc:0.7146
epoch:  20 loss:0.491528 auc:0.6700
epoch:  20 loss:0.490087 auc:0.5155
epoch:  20 loss:0.494402 auc:0.6500
epoch:  40 loss:0.401455 auc:0.6007
epoch:  40 loss:0.402194 auc:0.6767
epoch:  40 loss:0.400411 auc:0.6900
epoch:  40 loss:0.400722 auc:0.5041
epoch:  40 loss:0.403234 auc:0.4550
epoch:  60 loss:0.359864 auc:0.5984
epoch:  60 loss:0.362677 auc:0.6730
epoch:  60 loss:0.360468 auc:0.6600
epoch:  60 loss:0.355519 auc:0.5025
epoch:  60 loss:0.358142 auc:0.5325
epoch:  80 loss:0.332000 auc:0.6484
epoch:  80 loss:0.331803 auc:0.5806
epoch:  80 loss:0.331030 auc:0.6800
epoch:  80 loss:0.331321 auc:0.4894
epoch:  80 loss:0.331954 auc:0

Processing dim 1:   2%|█                                        | 25/1005 [1:20:25<58:03:14, 213.26s/it]

Fit finished.
epoch:   0 loss:0.699125 auc:0.5148
Fit finished.
epoch:   0 loss:0.699217 auc:0.7387
Fit finished.
epoch:   0 loss:0.700735 auc:0.3984
Fit finished.
epoch:   0 loss:0.702572 auc:0.4333


Processing dim 1:   3%|█▏                                       | 30/1005 [1:20:50<39:01:41, 144.10s/it]

Fit finished.
epoch:   0 loss:0.703231 auc:0.5288
epoch:  20 loss:0.490664 auc:0.8950
epoch:  20 loss:0.493902 auc:0.8596
epoch:  20 loss:0.498933 auc:0.8272
epoch:  20 loss:0.491537 auc:0.5592
epoch:  20 loss:0.499001 auc:0.8646
epoch:  40 loss:0.400646 auc:0.6879
epoch:  40 loss:0.405312 auc:0.8523
epoch:  40 loss:0.407606 auc:0.8048
epoch:  40 loss:0.402075 auc:0.2489
epoch:  40 loss:0.405271 auc:0.8200
epoch:  60 loss:0.357037 auc:0.6686
epoch:  60 loss:0.365881 auc:0.8193
epoch:  60 loss:0.360937 auc:0.8112
epoch:  60 loss:0.357131 auc:0.3013
epoch:  60 loss:0.360183 auc:0.7570
epoch:  80 loss:0.331674 auc:0.6479
epoch:  80 loss:0.332882 auc:0.7863
epoch:  80 loss:0.333877 auc:0.8016
epoch:  80 loss:0.331291 auc:0.3703
epoch:  80 loss:0.332393 auc:0.6401
epoch: 100 loss:0.317389 auc:0.6169
epoch: 100 loss:0.315191 auc:0.8266
epoch: 100 loss:0.316412 auc:0.7968
epoch: 100 loss:0.315066 auc:0.3628
epoch: 100 loss:0.315816 auc:0.6679
epoch: 120 loss:0.303662 auc:0.6405
epoch: 120 los

Processing dim 1:   3%|█▎                                    | 35/1005 [11:14:54<644:07:20, 2390.56s/it]

Fit finished.
epoch:   0 loss:0.703036 auc:0.4707
epoch:  20 loss:0.499020 auc:0.8848
Fit finished.
epoch:   0 loss:0.699517 auc:0.5019
epoch:  20 loss:0.497301 auc:0.6389
epoch:  20 loss:0.496389 auc:0.6777
epoch:  20 loss:0.491347 auc:0.6238
epoch:  40 loss:0.409806 auc:0.8464
epoch:  20 loss:0.488487 auc:0.6983
epoch:  40 loss:0.406035 auc:0.6875
epoch:  40 loss:0.405737 auc:0.5289
epoch:  40 loss:0.403925 auc:0.6597
epoch:  60 loss:0.361167 auc:0.7776
epoch:  40 loss:0.398197 auc:0.6855
epoch:  60 loss:0.358002 auc:0.6944
epoch:  60 loss:0.357246 auc:0.7467
epoch:  60 loss:0.362005 auc:0.5041
epoch:  80 loss:0.333961 auc:0.7776
epoch:  60 loss:0.352809 auc:0.6611
epoch:  80 loss:0.336659 auc:0.6875
epoch:  80 loss:0.335460 auc:0.4959
epoch:  80 loss:0.332853 auc:0.7391
epoch: 100 loss:0.318900 auc:0.8304
epoch:  80 loss:0.327926 auc:0.6701
epoch: 100 loss:0.318324 auc:0.6806
epoch: 100 loss:0.319686 auc:0.5537
epoch: 100 loss:0.314745 auc:0.7240
epoch: 120 loss:0.307763 auc:0.7504


KeyboardInterrupt: 

In [None]:
# Write the collected true labels and predictions, one CSV per kind,
# named after the dataset in use.
exports = {
    f"new_drug_true_{args.data}.csv": true_data_s,
    f"new_drug_pred_{args.data}.csv": predict_data_s,
}
for csv_path, frame in exports.items():
    frame.to_csv(csv_path)