In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Hyperparameters and runtime settings for one experiment.

    Every setting is a keyword argument whose default reproduces the value
    that used to be hard-coded, so ``Args()`` behaves exactly as before while
    e.g. ``Args(lr=1e-4, epochs=200)`` overrides individual settings without
    editing this class.

    Attributes:
        device: ``torch.device`` to run on. When not given, CUDA is used if
            it is available and the CPU otherwise.
        data: name of the dataset; it is also embedded in the output CSV
            file names.
        lr: the learning rate.
        wd: the weight decay for L2 normalization.
        layer_size: output sizes of every layer.
        alpha: the scale that balances the gcn and ni parts.
        gamma: the scale for the sigmoid.
        epochs: number of training epochs.
    """

    def __init__(
        self,
        device=None,
        data="ctrp",
        lr=0.001,
        wd=1e-5,
        layer_size=None,
        alpha=0.25,
        gamma=8,
        epochs=1000,
    ):
        # Auto-detect the device only when the caller did not choose one.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)  # cuda:number or cpu
        self.data = data
        self.lr = lr
        self.wd = wd
        # Build a fresh list per instance so instances never share state.
        self.layer_size = [1024, 1024] if layer_size is None else list(layer_size)
        self.alpha = alpha
        self.gamma = gamma
        self.epochs = epochs


args = Args()

In [4]:
# Load the response matrix and its companion inputs for the configured dataset.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of `res` along axis 1 and axis 0; used later to skip sparse targets.
cell_sum, drug_sum = (np.sum(res, axis=axis) for axis in (1, 0))

# Axes of `res` to evaluate: 0 = cell, 1 = drug.
target_dim = [
    # 0,  # Cell (currently disabled)
    1,  # Drug
]

In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Train and evaluate one nihgcn model for a single target.

    A ``NewSampler`` built from ``res_mat``, ``null_mask``, ``target_dim``,
    ``target_index`` and ``seed`` supplies the train/test data and masks. The
    model and the optimizer take their hyperparameters from ``args``
    (``layer_size``, ``alpha``, ``gamma``, ``device``, ``lr``, ``wd``,
    ``epochs``).

    Returns:
        The ``(true_data, predict_data)`` pair produced by calling the
        optimizer, or ``(None, None)`` when the masked test labels contain
        fewer than two distinct values.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Bail out early when the held-out labels hold a single class only.
    held_out_labels = split.test_data[split.test_mask]
    n_classes = len(np.unique(held_out_labels))
    if n_classes < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    net = nihgcn(
        split.train_data,
        **dict(
            cell_exprs=cell_exprs,
            drug_finger=drug_finger,
            layer_size=args.layer_size,
            alpha=args.alpha,
            gamma=args.gamma,
            device=args.device,
        ),
    )

    trainer = Optimizer(
        net,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        **dict(
            lr=args.lr,
            wd=args.wd,
            epochs=args.epochs,
            device=args.device,
        ),
    )

    # Calling the optimizer runs the fit and yields the evaluation pair.
    true_data, predict_data = trainer()
    return true_data, predict_data

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

n_kfold = 1  # number of folds run per target
n_jobs = 5  # number of parallel workers


def process_iteration(dim, target_index, seed, args):
    """Run every fold for one target and collect the usable results.

    Args:
        dim: which axis the target lives on; a truthy value checks the target
            against ``drug_sum``, a falsy value against ``cell_sum``.
        target_index: position of the target along ``dim``.
        seed: seed forwarded to ``nihgcn_new``.
        args: settings object forwarded to ``nihgcn_new``.

    Returns:
        A list of ``(true_data, predict_data)`` pairs, one per fold that
        produced a result, or ``None`` when the target was skipped (its total
        is below 10, or no fold produced a result). Returning a plain
        ``None`` keeps the caller's ``is None`` check working; a
        ``(None, None)`` tuple would slip past it and fail on unpacking.
    """
    # Targets whose total is below 10 are not evaluated at all.
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < 10:
        return None

    fold_results = []
    for _ in range(n_kfold):
        # NOTE(review): every fold receives the same seed, so with
        # n_kfold > 1 the folds may be identical — confirm against NewSampler.
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed,
        )
        # nihgcn_new signals a skipped run with (None, None); drop it so the
        # caller never has to handle empty results.
        if true_data is None or predict_data is None:
            continue
        fold_results.append((true_data, predict_data))

    return fold_results or None


# 並列処理の実行
# Run all targets in parallel and gather the per-fold results.
# The result frames start out empty so they exist even if the run is
# interrupted before the final concatenation.
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()

# Frames are collected in lists and concatenated once at the end; growing a
# DataFrame with pd.concat inside the loop copies all prior rows every time.
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build the complete task list up front; the seed equals the position of
    # the target along `dim`.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution with a progress bar. The bar wraps the task
    # generator, so it advances as tasks are handed out to the workers.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Combine the results.
    for fold_results in results:
        # A skipped target comes back as None or as a bare (None, None)
        # tuple; real results are always a list of pairs.
        if fold_results is None or isinstance(fold_results, tuple):
            continue
        for true_data, predict_data in fold_results:
            # A fold with no result is reported as (None, None).
            if true_data is None or predict_data is None:
                continue
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# pd.concat raises on an empty list, so keep the empty frames in that case.
if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
if predict_frames:
    predict_data_s = pd.concat(predict_frames, ignore_index=True)

Processing dim 1:   0%|                                                         | 0/460 [00:00<?, ?it/s]

epoch:   0 loss:0.700125 auc:0.6650
epoch:   0 loss:0.701826 auc:0.5306
epoch:   0 loss:0.698057 auc:0.5485
epoch:   0 loss:0.699401 auc:0.5259
epoch:   0 loss:0.699864 auc:0.4961
epoch:  20 loss:0.352026 auc:0.5825
epoch:  20 loss:0.353488 auc:0.6670
epoch:  20 loss:0.354796 auc:0.6603
epoch:  20 loss:0.353556 auc:0.7383
epoch:  20 loss:0.352843 auc:0.7519
epoch:  40 loss:0.330609 auc:0.7109
epoch:  40 loss:0.328146 auc:0.5850
epoch:  40 loss:0.329086 auc:0.8906
epoch:  40 loss:0.327884 auc:0.6685
epoch:  40 loss:0.327367 auc:0.7348
epoch:  60 loss:0.307274 auc:0.5750
epoch:  60 loss:0.310535 auc:0.9570
epoch:  60 loss:0.304176 auc:0.6364
epoch:  60 loss:0.304537 auc:0.7469
epoch:  60 loss:0.311164 auc:0.7107
epoch:  80 loss:0.282168 auc:0.9492
epoch:  80 loss:0.282012 auc:0.6325
epoch:  80 loss:0.279761 auc:0.7534
epoch:  80 loss:0.281367 auc:0.6582
epoch: 100 loss:0.262552 auc:0.9258
epoch:  80 loss:0.286870 auc:0.7139
epoch: 100 loss:0.261070 auc:0.7711
epoch: 100 loss:0.262578 auc

Processing dim 1:   2%|▉                                           | 10/460 [28:40<21:30:20, 172.05s/it]

Fit finished.
epoch:   0 loss:0.698007 auc:0.5826
Fit finished.
epoch:   0 loss:0.700253 auc:0.5228
Fit finished.
epoch:   0 loss:0.702347 auc:0.5690
Fit finished.
epoch:   0 loss:0.700224 auc:0.5402
epoch: 960 loss:0.212442 auc:0.6907
epoch:  20 loss:0.351194 auc:0.5969
epoch:  20 loss:0.352918 auc:0.5063
epoch:  20 loss:0.355307 auc:0.4215
epoch:  20 loss:0.355148 auc:0.7016
epoch: 980 loss:0.214205 auc:0.6942
epoch:  40 loss:0.325584 auc:0.4019
epoch:  40 loss:0.328320 auc:0.5427
epoch:  40 loss:0.332843 auc:0.4469
epoch:  40 loss:0.330807 auc:0.7850
Fit finished.
epoch:   0 loss:0.698598 auc:0.4822
epoch:  60 loss:0.301058 auc:0.3793
epoch:  60 loss:0.305320 auc:0.5914
epoch:  60 loss:0.310816 auc:0.4791
epoch:  60 loss:0.309043 auc:0.7315
epoch:  20 loss:0.352489 auc:0.5557
epoch:  80 loss:0.283209 auc:0.6356
epoch:  80 loss:0.281258 auc:0.4126
epoch:  80 loss:0.289426 auc:0.5588
epoch:  80 loss:0.285413 auc:0.7228
epoch:  40 loss:0.331582 auc:0.5412
epoch: 100 loss:0.260632 auc:0

Processing dim 1:   3%|█▍                                          | 15/460 [57:53<30:28:03, 246.48s/it]

Fit finished.
epoch:   0 loss:0.701839 auc:0.5317
epoch: 980 loss:0.212978 auc:0.6428
epoch: 960 loss:0.216967 auc:0.6126
Fit finished.
epoch:   0 loss:0.700716 auc:0.4205
Fit finished.
epoch:   0 loss:0.698572 auc:0.5032
epoch:  20 loss:0.352854 auc:0.6151
Fit finished.
epoch:   0 loss:0.705282 auc:0.5144
epoch: 980 loss:0.215792 auc:0.6138
epoch:  20 loss:0.351720 auc:0.4863
epoch:  20 loss:0.349718 auc:0.5793
epoch:  40 loss:0.327901 auc:0.6156
Fit finished.
epoch:  20 loss:0.353210 auc:0.6417
epoch:   0 loss:0.702539 auc:0.4390
epoch:  40 loss:0.327999 auc:0.5841
epoch:  40 loss:0.324548 auc:0.6338
epoch:  60 loss:0.304902 auc:0.6407
epoch:  60 loss:0.302851 auc:0.5607
epoch:  20 loss:0.352641 auc:0.4498
epoch:  40 loss:0.328681 auc:0.6446
epoch:  60 loss:0.300610 auc:0.6768
epoch:  80 loss:0.280829 auc:0.6157
epoch:  40 loss:0.328887 auc:0.5545
epoch:  60 loss:0.306758 auc:0.6493
epoch:  80 loss:0.276883 auc:0.6705
epoch:  80 loss:0.280914 auc:0.5101
epoch: 100 loss:0.263732 auc:0

Processing dim 1:   4%|█▋                                      | 20/460 [5:49:24<178:25:04, 1459.78s/it]

Fit finished.
epoch:   0 loss:0.702162 auc:0.6678
Fit finished.
epoch: 980 loss:0.215886 auc:0.6056
epoch:   0 loss:0.699643 auc:0.5261
epoch: 980 loss:0.254496 auc:0.6583
epoch: 960 loss:0.335440 auc:0.5085
epoch:  20 loss:0.350058 auc:0.3499
epoch:  20 loss:0.351832 auc:0.4400
Fit finished.
epoch:   0 loss:0.700535 auc:0.4838


Processing dim 1:   5%|██▏                                      | 25/460 [5:50:20<115:31:32, 956.07s/it]

Fit finished.
epoch:   0 loss:0.703238 auc:0.6118
epoch: 980 loss:0.329262 auc:0.5508
epoch:  40 loss:0.324255 auc:0.4537
epoch:  40 loss:0.327756 auc:0.5599
epoch:  20 loss:0.351004 auc:0.4766
Fit finished.
epoch:   0 loss:0.698233 auc:0.4682
epoch:  20 loss:0.354858 auc:0.5110
epoch:  60 loss:0.299142 auc:0.4856
epoch:  60 loss:0.305137 auc:0.6670
epoch:  40 loss:0.325496 auc:0.4974
epoch:  20 loss:0.353770 auc:0.5990
epoch:  80 loss:0.274814 auc:0.4371
epoch:  40 loss:0.328880 auc:0.5281
epoch:  80 loss:0.280385 auc:0.7678
epoch:  40 loss:0.329153 auc:0.6493
epoch:  60 loss:0.302042 auc:0.5266
epoch:  60 loss:0.305867 auc:0.5495
epoch: 100 loss:0.258006 auc:0.4712
epoch: 100 loss:0.262496 auc:0.7950
epoch:  60 loss:0.306526 auc:0.6610
epoch:  80 loss:0.277599 auc:0.5328
epoch: 120 loss:0.246315 auc:0.4687
epoch:  80 loss:0.282268 auc:0.6613
epoch:  80 loss:0.281955 auc:0.5379
epoch: 120 loss:0.249647 auc:0.8307
epoch: 100 loss:0.260391 auc:0.5359
epoch: 100 loss:0.263630 auc:0.6635


Processing dim 1:   7%|██▌                                    | 30/460 [15:18:33<346:35:30, 2901.70s/it]

Fit finished.
epoch:   0 loss:0.699739 auc:0.3625
epoch:  20 loss:0.352869 auc:0.6303
epoch: 960 loss:0.212866 auc:0.5657
epoch:  60 loss:0.306526 auc:0.6224
epoch:  20 loss:0.354727 auc:0.5909
epoch:  20 loss:0.354638 auc:0.7475
epoch: 980 loss:0.216063 auc:0.5573
epoch:  40 loss:0.328757 auc:0.6646
epoch:  80 loss:0.282363 auc:0.6272
epoch:  40 loss:0.330350 auc:0.6715
epoch:  40 loss:0.329587 auc:0.7000
Fit finished.
epoch:   0 loss:0.700788 auc:0.5484
epoch:  60 loss:0.306403 auc:0.6822
epoch: 100 loss:0.262733 auc:0.6330
epoch:  60 loss:0.307566 auc:0.7851
epoch:  60 loss:0.308039 auc:0.7350
epoch:  20 loss:0.353832 auc:0.4134
epoch:  80 loss:0.283601 auc:0.7519
epoch: 120 loss:0.250151 auc:0.6384


KeyboardInterrupt: 

In [None]:
# Write the collected true values and predictions to one CSV each; the
# dataset name (args.data) is part of the file name.
# NOTE(review): the "new_drug" prefix is fixed even though target_dim can
# also select the cell axis — confirm the naming if that axis is enabled.
true_data_s.to_csv(f"new_drug_true_{args.data}.csv")
predict_data_s.to_csv(f"new_drug_pred_{args.data}.csv")