In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Hyper-parameters and runtime settings for the NIHGCN experiment.

    Every attribute below has a default. Any of them can be overridden by
    passing it as a keyword argument, e.g. ``Args(data="gdsc", epochs=200)``.
    Unknown keywords raise ``TypeError`` so that a typo does not silently
    fall back to the default value.
    """

    def __init__(self, **overrides):
        # Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Dataset name passed to load_data.
        # NOTE(review): the old comment listed only "gdsc or ccle" although the
        # value is "ctrp"; confirm the full set of names load_data accepts.
        self.data = "ctrp"
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularization strength)
        self.layer_size = [1024, 1024]  # output size of each layer
        self.alpha = 0.25  # weight balancing the GCN and NI components
        self.gamma = 8  # scale factor used in the sigmoid
        self.epochs = 1000  # number of training epochs per model

        # Apply caller overrides only to attributes that already exist.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


args = Args()

In [4]:
# Load the response matrix together with the drug fingerprints, expression
# features, null mask and positive count for the configured dataset.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of `res` along each axis: axis=1 gives one value per row (cell),
# axis=0 one value per column (drug). process_iteration uses them to skip
# sparsely observed targets.
cell_sum, drug_sum = (np.sum(res, axis=axis) for axis in (1, 0))

# Axes evaluated one target at a time: 0 = cells, 1 = drugs.
# Only the cell axis is enabled; append 1 to also run the per-drug evaluation.
target_dim = [0]

In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Train NIHGCN with one target (cell or drug) held out and score it.

    The split comes from a ``NewSampler`` built from ``res_mat``,
    ``null_mask`` and ``(target_dim, target_index, seed)``. If the held-out
    labels contain fewer than two distinct values, the target is skipped,
    since metrics such as ROC AUC are undefined for a single class.

    Returns
    -------
    tuple
        ``(true_data, predict_data)`` as produced by the ``Optimizer`` run,
        or ``(None, None)`` when the target was skipped.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Guard: the evaluation needs both classes among the held-out entries.
    # NOTE(review): assumes test_data[test_mask] is something np.unique can
    # convert (e.g. a CPU array/tensor) — confirm for CUDA tensors.
    held_out_labels = split.test_data[split.test_mask]
    if np.unique(held_out_labels).size < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    network = nihgcn(
        split.train_data,
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    trainer = Optimizer(
        network,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )
    y_true, y_pred = trainer()
    return y_true, y_pred

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

n_kfold = 1
n_jobs = 5  # 並列数


def process_iteration(dim, target_index, seed, args):
    """各反復処理をカプセル化した関数"""
    if dim:
        if drug_sum[target_index] < 10:
            return None, None
    else:
        if cell_sum[target_index] < 10:
            return None, None

    fold_results = []
    for fold in range(n_kfold):
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed,
        )
        fold_results.append((true_data, predict_data))

    return fold_results


# Run every target in parallel and merge the per-target results.


def _valid_pairs(fold_results):
    """Yield the (true_data, predict_data) pairs that actually contain results.

    Tolerates every skip sentinel process_iteration / nihgcn_new may emit:
    ``None``, an empty list, and the ``(None, None)`` tuple. Unpacking the
    ``(None, None)`` tuple directly would raise TypeError after the whole
    run had finished.
    """
    for pair in fold_results or ():
        if pair is None:
            continue
        true_data, predict_data = pair
        if true_data is None or predict_data is None:
            continue
        yield true_data, predict_data


# Collect frames in lists and concatenate once at the end. Calling pd.concat
# inside the loop re-copies the accumulated frame every time (quadratic).
true_parts = []
predict_parts = []

for dim in target_dim:
    # Build the full task list up front: one task per index along this axis,
    # with the index doubling as the base seed.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution with a progress bar. tqdm wraps the task iterator,
    # so it advances as joblib dispatches tasks, not as they complete.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Convert each valid (true, predicted) pair to a frame.
    for fold_results in results:
        for true_data, predict_data in _valid_pairs(fold_results):
            true_parts.append(translate_result(true_data))
            predict_parts.append(translate_result(predict_data))

# pd.concat raises on an empty list, so fall back to an empty frame.
true_data_s = (
    pd.concat(true_parts, ignore_index=True) if true_parts else pd.DataFrame()
)
predict_data_s = (
    pd.concat(predict_parts, ignore_index=True) if predict_parts else pd.DataFrame()
)

Processing dim 0:   0%|                                                         | 0/823 [00:00<?, ?it/s]

epoch:   0 loss:0.700400 auc:0.5486
epoch:   0 loss:0.700169 auc:0.4396
epoch:   0 loss:0.699614 auc:0.4705
epoch:   0 loss:0.701952 auc:0.4683
epoch:   0 loss:0.701382 auc:0.4739
epoch:  20 loss:0.351870 auc:0.9411
epoch:  20 loss:0.353710 auc:0.9096
epoch:  20 loss:0.351859 auc:0.9426
epoch:  20 loss:0.352427 auc:0.9389
epoch:  20 loss:0.353012 auc:0.8886
epoch:  40 loss:0.327984 auc:0.9286
epoch:  40 loss:0.326826 auc:0.9449
epoch:  40 loss:0.327702 auc:0.9410
epoch:  40 loss:0.329016 auc:0.8982
epoch:  40 loss:0.329561 auc:0.9137
epoch:  60 loss:0.303458 auc:0.9450
epoch:  60 loss:0.304093 auc:0.9451
epoch:  60 loss:0.306829 auc:0.9045
epoch:  60 loss:0.306551 auc:0.9352
epoch:  60 loss:0.307292 auc:0.9133
epoch:  80 loss:0.279289 auc:0.9510
epoch:  80 loss:0.279415 auc:0.9489
epoch:  80 loss:0.282388 auc:0.9091
epoch:  80 loss:0.280566 auc:0.9366
epoch:  80 loss:0.284722 auc:0.9094
epoch: 100 loss:0.260573 auc:0.9512
epoch: 100 loss:0.260552 auc:0.9540
epoch: 100 loss:0.262513 auc

Processing dim 0:   1%|▌                                           | 10/823 [27:34<37:22:07, 165.47s/it]

Fit finished.
epoch:   0 loss:0.697994 auc:0.4534
epoch: 980 loss:0.331591 auc:0.9386
epoch: 980 loss:0.212419 auc:0.9138
epoch: 960 loss:0.214862 auc:0.9408
epoch: 940 loss:0.212805 auc:0.9014
epoch:  20 loss:0.350272 auc:0.9403
Fit finished.
Fit finished.
epoch:   0 loss:0.700274 auc:0.4805
epoch:   0 loss:0.702258 auc:0.4984
epoch: 980 loss:0.213058 auc:0.9401
epoch: 960 loss:0.212582 auc:0.9014
epoch:  40 loss:0.325219 auc:0.9500
epoch:  20 loss:0.353828 auc:0.9191
epoch:  20 loss:0.355460 auc:0.9396
Fit finished.
epoch:   0 loss:0.700235 auc:0.4694
epoch: 980 loss:0.213871 auc:0.8999
epoch:  60 loss:0.302530 auc:0.9524
epoch:  40 loss:0.329070 auc:0.9261
epoch:  40 loss:0.331974 auc:0.9409
epoch:  20 loss:0.355041 auc:0.9562
Fit finished.
epoch:   0 loss:0.698635 auc:0.4555
epoch:  80 loss:0.276938 auc:0.9532
epoch:  60 loss:0.305862 auc:0.9269
epoch:  60 loss:0.310889 auc:0.9428
epoch:  40 loss:0.330995 auc:0.9619
epoch:  20 loss:0.353212 auc:0.9502
epoch: 100 loss:0.260940 auc:0

Processing dim 0:   2%|▊                                           | 15/823 [57:08<54:50:44, 244.36s/it]

Fit finished.
epoch:   0 loss:0.701811 auc:0.4729
epoch: 980 loss:0.215270 auc:0.9673
epoch: 980 loss:0.212832 auc:0.9274
Fit finished.
epoch:   0 loss:0.700705 auc:0.5363
epoch: 940 loss:0.237710 auc:0.8496
epoch:  20 loss:0.353085 auc:0.9332
Fit finished.
epoch:   0 loss:0.698575 auc:0.4909
Fit finished.
epoch:   0 loss:0.705277 auc:0.5339
epoch:  20 loss:0.351817 auc:0.8939
epoch: 960 loss:0.235486 auc:0.8005
epoch:  40 loss:0.327899 auc:0.9316
epoch:  20 loss:0.350552 auc:0.9530
epoch:  40 loss:0.328358 auc:0.8953
epoch:  20 loss:0.353988 auc:0.9226
epoch: 980 loss:0.232282 auc:0.8407
epoch:  60 loss:0.305217 auc:0.9320
epoch:  60 loss:0.302939 auc:0.8918
epoch:  40 loss:0.325177 auc:0.9573
epoch:  40 loss:0.329296 auc:0.9297
Fit finished.
epoch:   0 loss:0.702583 auc:0.5131
epoch:  80 loss:0.281211 auc:0.9326
epoch:  80 loss:0.279060 auc:0.8818
epoch:  60 loss:0.301843 auc:0.9588
epoch:  60 loss:0.307980 auc:0.9282
epoch: 100 loss:0.266880 auc:0.9262
epoch:  20 loss:0.352764 auc:0

Processing dim 0:   2%|▉                                       | 20/823 [5:48:13<324:53:35, 1456.56s/it]

Fit finished.
epoch:   0 loss:0.701981 auc:0.4889
epoch: 920 loss:0.370149 auc:0.9179
epoch: 980 loss:0.228737 auc:0.8663
Fit finished.
epoch: 980 loss:0.212634 auc:0.9161
epoch:   0 loss:0.699671 auc:0.5397
epoch:  20 loss:0.349586 auc:0.8812
epoch: 940 loss:0.355446 auc:0.9239
Fit finished.
epoch:  20 loss:0.351389 auc:0.9006
epoch:   0 loss:0.700496 auc:0.4492
Fit finished.
epoch:   0 loss:0.702849 auc:0.4571
epoch:  40 loss:0.323684 auc:0.8794
epoch: 960 loss:0.348702 auc:0.9240
epoch:  40 loss:0.327078 auc:0.9005
epoch:  20 loss:0.351868 auc:0.8382
epoch:  60 loss:0.298905 auc:0.8812
epoch:  20 loss:0.354049 auc:0.9658
epoch: 980 loss:0.342839 auc:0.9258
epoch:  40 loss:0.328428 auc:0.9728
epoch:  40 loss:0.326106 auc:0.8509
epoch:  80 loss:0.274673 auc:0.8842
epoch:  60 loss:0.304755 auc:0.8982
Fit finished.
epoch:   0 loss:0.703190 auc:0.5905
epoch:  80 loss:0.279818 auc:0.9047
epoch:  60 loss:0.304692 auc:0.9718
epoch: 100 loss:0.258008 auc:0.8733
epoch:  60 loss:0.302179 auc:0

Processing dim 0:   3%|█▏                                     | 25/823 [15:15:42<735:27:42, 3317.87s/it]

Fit finished.
epoch: 940 loss:0.212618 auc:0.8719
epoch:   0 loss:0.703344 auc:0.5797
epoch: 920 loss:0.213384 auc:0.9599
epoch: 960 loss:0.212442 auc:0.8368
epoch: 920 loss:0.213474 auc:0.8983
epoch: 960 loss:0.212407 auc:0.8706
epoch:  20 loss:0.354110 auc:0.9267
epoch: 980 loss:0.213883 auc:0.8239
epoch: 940 loss:0.212760 auc:0.9604
epoch: 940 loss:0.212654 auc:0.8984
epoch: 980 loss:0.217522 auc:0.8787
epoch:  40 loss:0.328125 auc:0.9248
Fit finished.
epoch:   0 loss:0.698148 auc:0.5309
epoch: 960 loss:0.215583 auc:0.9627
epoch: 960 loss:0.212432 auc:0.8978
Fit finished.
epoch:   0 loss:0.698732 auc:0.5769
epoch:  60 loss:0.305380 auc:0.9254
epoch:  20 loss:0.353742 auc:0.9130
epoch: 980 loss:0.213203 auc:0.9639
epoch: 980 loss:0.214913 auc:0.8887
epoch:  40 loss:0.329058 auc:0.9209
epoch:  80 loss:0.281107 auc:0.9290
epoch:  20 loss:0.352438 auc:0.9492
Fit finished.
epoch:   0 loss:0.699414 auc:0.4794
epoch:  60 loss:0.306899 auc:0.9220
epoch: 100 loss:0.263452 auc:0.9265
Fit fini

KeyboardInterrupt: 

In [None]:
true_data_s.to_csv(f"new_cell_true_{args.data}.csv")
predict_data_s.to_csv(f"new_cell_pred_{args.data}.csv")