In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Run configuration for the NIHGCN experiments in this notebook.

    Every setting can be overridden by keyword; ``Args()`` with no arguments
    yields the notebook's default configuration.

    Args:
        device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``,
            ``"cuda:0"``). ``None`` selects CUDA when available, else CPU.
        data: Dataset name (the original note lists gdsc or ccle; the
            notebook default is ``"gdsc1"``).
        lr: Learning rate.
        wd: Weight decay used for L2 normalization.
        layer_size: Output sizes of every layer. ``None`` means
            ``[1024, 1024]``; any other iterable is copied into a new list.
        alpha: Scale balancing the GCN and NI parts.
        gamma: Scale for the sigmoid.
        epochs: Number of training epochs.
    """

    def __init__(
        self,
        device=None,
        data="gdsc1",
        lr=0.001,
        wd=1e-5,
        layer_size=None,
        alpha=0.25,
        gamma=8,
        epochs=1000,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)  # cuda:number or cpu
        self.data = data
        self.lr = lr
        self.wd = wd
        # Always store a fresh list so instances never share (or alias a
        # caller's) mutable layer configuration.
        self.layer_size = [1024, 1024] if layer_size is None else list(layer_size)
        self.alpha = alpha
        self.gamma = gamma
        self.epochs = epochs


args = Args()

In [4]:
# Fetch everything the experiment needs for the dataset configured in `args`.
# NOTE(review): the exact contents of each returned object are defined in
# load_data — confirm there before relying on shapes or dtypes.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of `res` along each axis: axis 0 gives one value per column,
# axis 1 gives one value per row.
drug_sum = np.sum(res, axis=0)
cell_sum = np.sum(res, axis=1)

# Axes to evaluate as targets: 0 = cell, 1 = drug (only drugs are enabled).
target_dim = [1]

load gdsc1


In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Fit and evaluate one ``nihgcn`` model for a single target.

    A ``NewSampler`` is built from the response matrix, the null mask, the
    target axis/index and the seed. Its train split feeds the model, and its
    train/test data and masks feed the ``Optimizer``.

    Args:
        cell_exprs: Forwarded to ``nihgcn`` as ``cell_exprs``.
        drug_finger: Forwarded to ``nihgcn`` as ``drug_finger``.
        res_mat: Response matrix handed to ``NewSampler``.
        null_mask: Mask handed to ``NewSampler``.
        target_dim: Axis of the target, handed to ``NewSampler``.
        target_index: Index of the target along ``target_dim``.
        evaluate_fun: Metric callable handed to ``Optimizer``.
        args: Configuration providing ``layer_size``, ``alpha``, ``gamma``,
            ``device``, ``lr``, ``wd`` and ``epochs``.
        seed: Seed handed to ``NewSampler``.

    Returns:
        The ``(true_data, predict_data)`` pair produced by calling the
        optimizer, or ``(None, None)`` when the masked test labels hold
        fewer than two distinct values.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Bail out early when the held-out labels are all the same value.
    held_out_labels = split.test_data[split.test_mask]
    distinct_labels = np.unique(held_out_labels)
    if len(distinct_labels) < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    network = nihgcn(
        split.train_data,
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )

    trainer = Optimizer(
        network,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )

    # Calling the optimizer runs the fit and yields labels and predictions.
    truth, prediction = trainer()
    return truth, prediction

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

n_kfold = 1  # number of training/evaluation repetitions per target
n_jobs = 50  # number of parallel workers


def process_iteration(dim, target_index, seed, args):
    """Run every fold for one target and collect the usable results.

    Args:
        dim: Axis of the response matrix the target lives on. A falsy value
            (0) checks ``cell_sum``; a truthy value (1) checks ``drug_sum``.
        target_index: Index of the target along ``dim``.
        seed: Seed forwarded to ``nihgcn_new``.
        args: Run configuration forwarded to ``nihgcn_new``.

    Returns:
        A list of ``(true_data, predict_data)`` tuples, one per fold that
        produced a result, or ``None`` when the target is skipped (its
        summed response is below 10, or no fold produced a result).
        A single ``None`` is returned rather than ``(None, None)`` so that
        callers testing ``result is None`` skip the target instead of trying
        to unpack the tuple's ``None`` elements.
    """
    # Targets whose summed responses are below 10 are not evaluated.
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < 10:
        return None

    fold_results = []
    # NOTE(review): every fold receives the same seed; if NewSampler is
    # deterministic for a given seed, folds beyond the first would repeat
    # the same split — confirm before raising n_kfold above 1.
    for _ in range(n_kfold):
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed,
        )
        # nihgcn_new signals "nothing to evaluate" with (None, None); keep
        # those out of the results so downstream code never sees None data.
        if true_data is None or predict_data is None:
            continue
        fold_results.append((true_data, predict_data))

    return fold_results or None


# Run every target in parallel and gather the per-fold results.
# The output names are bound up front so they exist even if the run below
# is interrupted before the final concatenation.
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()

# Collect the per-fold frames in lists and concatenate once at the end;
# growing a DataFrame with pd.concat inside the loop re-copies all rows
# gathered so far on every iteration.
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build all tasks up front; a target's position doubles as its seed.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution with a progress bar.
    # NOTE(review): tqdm wraps the task iterator, so the bar appears to
    # advance when tasks are handed to workers, not when they finish.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Merge the results.
    for fold_results in results:
        # A skipped target arrives as None or as a bare (None, None) tuple;
        # real results are a list of per-fold tuples.
        if fold_results is None or isinstance(fold_results, tuple):
            continue
        for true_data, predict_data in fold_results:
            # A fold without usable data carries None in place of results.
            if true_data is None or predict_data is None:
                continue
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# pd.concat raises on an empty list, so keep the empty frames in that case.
if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
if predict_frames:
    predict_data_s = pd.concat(predict_frames, ignore_index=True)

Processing dim 1:   0%|                                                         | 0/331 [00:00<?, ?it/s]

epoch:   0 loss:0.705151 auc:0.5048
epoch:   0 loss:0.706738 auc:0.6139
epoch:   0 loss:0.700720 auc:0.5860
epoch:   0 loss:0.700732 auc:0.4976
epoch:   0 loss:0.699030 auc:0.4847
epoch:   0 loss:0.697764 auc:0.5367
epoch:   0 loss:0.708055 auc:0.4216
epoch:   0 loss:0.700076 auc:0.5199
epoch:   0 loss:0.706589 auc:0.5703
epoch:   0 loss:0.703264 auc:0.5105
epoch:   0 loss:0.702710 auc:0.5261
epoch:   0 loss:0.702192 auc:0.5518
epoch:   0 loss:0.694649 auc:0.4563
epoch:   0 loss:0.702028 auc:0.7188
epoch:   0 loss:0.699541 auc:0.4418
epoch:   0 loss:0.699551 auc:0.4857
epoch:   0 loss:0.694415 auc:0.4373
epoch:   0 loss:0.692943 auc:0.4514
epoch:   0 loss:0.703838 auc:0.5818
epoch:   0 loss:0.707023 auc:0.5407
epoch:   0 loss:0.693480 auc:0.4261
epoch:   0 loss:0.706946 auc:0.5446
epoch:   0 loss:0.710978 auc:0.5428
epoch:   0 loss:0.697394 auc:0.5134
epoch:   0 loss:0.706673 auc:0.5207
epoch:   0 loss:0.700500 auc:0.4619
epoch:   0 loss:0.700831 auc:0.5207
epoch:   0 loss:0.697524 auc

Processing dim 1:  45%|███████████████████▍                       | 150/331 [1:37:30<1:57:39, 39.00s/it]

Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.704467 auc:0.3793
Fit finished.
epoch:   0 loss:0.697340 auc:0.5550
epoch:   0 loss:0.698616 auc:0.4653
epoch:   0 loss:0.688838 auc:0.6070
Fit finished.
Target 108 skipped: Validation set has only one class.
epoch:   0 loss:0.705563 auc:0.4587
epoch:   0 loss:0.700725 auc:0.5717
epoch:   0 loss:0.691911 auc:0.3414
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.701198 auc:0.4760
epoch:   0 loss:0.694533 auc:0.5008
Fit finished.
epoch:   0 loss:0.708314 auc:0.5474
Fit finished.
epoch:   0 loss:0.705121 auc:0.6756
epoch:   0 loss:0.708397 auc:0.5422
epoch:   0 loss:0.695588 auc:0.4663
epoch:   0 loss:0.717655 auc:0.5623
epoch:   0 loss:0.697426 auc:0.5243
epoch:   0 loss:0.710120 auc:0.4378
Fit finished.
epoch:   0 loss:0.713831 auc:0.4263
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.700073 auc:0.6311
epoch:   0 loss:0

Processing dim 1:  60%|█████████████████████████▉                 | 200/331 [3:28:18<2:31:20, 69.32s/it]

Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.708179 auc:0.5100
Fit finished.
epoch:   0 loss:0.700734 auc:0.3392
epoch:   0 loss:0.698153 auc:0.5073
epoch:   0 loss:0.705217 auc:0.5003
epoch:   0 loss:0.698287 auc:0.3802
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.699260 auc:0.4430
epoch:   0 loss:0.701007 auc:0.4592
epoch:   0 loss:0.699994 auc:0.4448
epoch:   0 loss:0.704574 auc:0.1361
epoch:   0 loss:0.707475 auc:0.6746
epoch:   0 loss:0.710047 auc:0.4090
epoch:   0 loss:0.705700 auc:0.4567
epoch:   0 loss:0.700383 auc:0.5194
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.703459 auc:0.5104
epoch:   0 loss:0.704387 auc:0.6069
Fit finished.


Processing dim 1:  76%|████████████████████████████████▍          | 250/331 [3:28:33<1:04:11, 47.54s/it]

Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.706091 auc:0.5181
Fit finished.
Fit finished.
Fit finished.
Fit finished.
epoch:   0 loss:0.702266 auc:0.5179
epoch:   0 loss:0.696561 auc:0.4888
epoch:   0 loss:0.699299 auc:0.4951
epoch:   0 loss:0.701523 auc:0.5113
epoch:   0 loss:0.696602 auc:0.4142
epoch:   0 loss:0.701279 auc:0.5648
epoch:   0 loss:0.707690 auc:0.6219
epoch:   0 loss:0.694575 auc:0.4958
Fit finished.
Fit finished.
epoch:   0 loss:0.707684 auc:0.3165
Fit finished.
Fit finished.
epoch:   0 loss:0.713080 auc:0.5038
epoch:   0 loss:0.708045 auc:0.5311
Fit finished.
Fit finished.
epoch:   0 loss:0.700413 auc:0.3633
epoch:   0 loss:0.698771 auc:0.4696
epoch:   0 loss:0.699316 auc:0.3793
epoch:   0 loss:0.700827 auc:0.5490
epoch:   0 loss:0.699838 auc:0.5661
epoch:   0 loss:0.695894 auc:0.5097
epoch:   0 loss:0.705773 auc:0.3802
epoch:   0 loss:0.690413 auc:0.5579
epoch:   0 loss:0.696338 auc

KeyboardInterrupt: 

In [None]:
# Write the collected labels and predictions to CSV files whose names
# include the dataset name from `args.data`.
true_data_s.to_csv(f"new_drug_true_{args.data}.csv")
predict_data_s.to_csv(f"new_drug_pred_{args.data}.csv")