In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Hyper-parameters and runtime settings for the NIHGCN experiments.

    Stands in for an ``argparse`` namespace inside the notebook. Every
    default is set in ``__init__``; any of them can be overridden by keyword,
    e.g. ``Args(data="ccle", epochs=200, device="cpu")``. Unknown keywords
    raise ``TypeError`` so a misspelled setting fails loudly instead of being
    silently ignored.
    """

    def __init__(self, **overrides):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )  # "cuda", "cuda:<n>" or "cpu"
        self.data = "gdsc1"  # dataset name handed to load_data (original note: gdsc or ccle)
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularization)
        self.layer_size = [1024, 1024]  # output size of each layer; a fresh list per instance
        self.alpha = 0.25  # scale balancing the GCN and NI parts of the model
        self.gamma = 8  # scale applied inside the sigmoid
        self.epochs = 1000  # number of training epochs

        for name, value in overrides.items():
            # Only allow overriding settings that exist, to catch typos early.
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected setting {name!r}")
            if name == "device":
                # Accept plain strings such as "cpu" or "cuda:1".
                value = torch.device(value)
            setattr(self, name, value)


args = Args()

In [4]:
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Row totals (one per cell line) and column totals (one per drug) of the
# response matrix; targets with a small total are skipped later on.
cell_sum, drug_sum = (np.sum(res, axis=axis) for axis in (1, 0))

# Axes to evaluate: 0 = per-cell-line targets, 1 = per-drug targets.
target_dim = [0]  # append 1 to also run the per-drug evaluation

load gdsc1


In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Fit one NIHGCN model on a single-target split and return its outputs.

    A ``NewSampler`` built from ``res_mat``/``null_mask`` for target
    ``target_index`` along axis ``target_dim`` supplies the train/test data
    and masks (split semantics live in ``sampler.NewSampler`` — presumably
    the target is held out; verify there).

    Parameters
    ----------
    cell_exprs, drug_finger :
        Cell-line and drug feature inputs, forwarded to ``nihgcn``.
    res_mat, null_mask :
        Response matrix and its mask, forwarded to ``NewSampler``.
    target_dim, target_index :
        Which axis (0 = cell, 1 = drug) and which row/column is the target.
    evaluate_fun :
        Metric callable forwarded to ``Optimizer``.
    args :
        Settings object providing layer_size, alpha, gamma, device, lr, wd
        and epochs.
    seed :
        Seed forwarded to ``NewSampler``.

    Returns
    -------
    tuple
        ``(true_data, predict_data)`` as produced by calling the
        ``Optimizer``; ``(None, None)`` when the validation labels contain
        fewer than two distinct values (ROC-AUC is undefined then).
    """
    sampler = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Bail out early: a single-class validation set cannot be scored.
    held_out_labels = sampler.test_data[sampler.test_mask]
    if len(np.unique(held_out_labels)) < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    model_kwargs = dict(
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    model = nihgcn(
        sampler.train_data,
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        **model_kwargs,
    )

    # Positional order matters here: test_mask comes before train_mask.
    trainer = Optimizer(
        model,
        sampler.train_data,
        sampler.test_data,
        sampler.test_mask,
        sampler.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )
    true_part, predicted_part = trainer()
    return true_part, predicted_part

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

n_kfold = 1  # number of train/validation splits per target
n_jobs = 5  # number of parallel workers


def process_iteration(dim, target_index, seed, args, min_positives=10):
    """Run ``nihgcn_new`` for one target cell line or drug, once per fold.

    Parameters
    ----------
    dim : int
        Target axis: 0 selects cell lines (checked against ``cell_sum``),
        any non-zero value selects drugs (checked against ``drug_sum``).
    target_index : int
        Index of the target along ``dim``.
    seed : int
        Base seed for the sampler; fold ``k`` uses ``seed + k`` so repeated
        folds get different splits (fold 0 uses ``seed`` unchanged).
    args :
        Settings object forwarded to ``nihgcn_new``.
    min_positives : int, default 10
        Targets whose response-matrix total is below this value are skipped
        (presumably too few positive labels to evaluate).

    Returns
    -------
    list of (true_data, predict_data) or None
        ``None`` when the target is skipped by the ``min_positives`` check.
        Otherwise one pair per fold that produced a result; folds that
        ``nihgcn_new`` skipped (single-class validation set) are left out,
        so the list may be empty.
    """
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < min_positives:
        # Plain None (not a (None, None) tuple) so the caller's
        # `fold_results is None` check really skips this target.
        return None

    fold_results = []
    for fold in range(n_kfold):
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed + fold,  # distinct split per fold
        )
        if true_data is None or predict_data is None:
            # nihgcn_new signals a skipped split with (None, None).
            continue
        fold_results.append((true_data, predict_data))

    return fold_results


# Run every target in parallel, then stack the per-target results.
# Frames are collected in lists and concatenated once at the end: calling
# pd.concat inside the loop re-copies the accumulated frame every time
# (quadratic in the number of targets).
# NOTE(review): assumes translate_result returns a DataFrame — confirm in myutils.
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build the full task list up front; the enumerate index doubles as the seed.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Thread-based workers share the already-loaded arrays in memory.
    # NOTE: tqdm wraps the task generator, so the bar tracks dispatch to
    # joblib rather than completed fits.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    for fold_results in results:
        # A skipped target may come back as None, an empty list, or a
        # (None, None) sentinel; skip all of them without unpacking None.
        if not fold_results:
            continue
        for pair in fold_results:
            if pair is None:
                continue
            true_data, predict_data = pair
            if true_data is None or predict_data is None:
                continue
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# pd.concat([]) raises, so keep the empty frames when nothing was produced.
if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
if predict_frames:
    predict_data_s = pd.concat(predict_frames, ignore_index=True)

Processing dim 0:   0%|                                                         | 0/916 [00:00<?, ?it/s]

epoch:   0 loss:0.711293 auc:0.3309
epoch:   0 loss:0.700581 auc:0.4283
epoch:   0 loss:0.695801 auc:0.5562
epoch:   0 loss:0.704929 auc:0.5006
epoch:   0 loss:0.690704 auc:0.7174
epoch:  20 loss:0.198967 auc:0.9866
epoch:  20 loss:0.198290 auc:0.9699
epoch:  20 loss:0.198228 auc:0.9601
epoch:  20 loss:0.197443 auc:0.9763
epoch:  20 loss:0.197894 auc:0.9804
epoch:  40 loss:0.179765 auc:0.9879
epoch:  40 loss:0.178385 auc:0.9674
epoch:  40 loss:0.178610 auc:0.9716
epoch:  40 loss:0.178114 auc:0.9873
epoch:  40 loss:0.178280 auc:0.9865
epoch:  60 loss:0.166595 auc:0.9836
epoch:  60 loss:0.165170 auc:0.9688
epoch:  60 loss:0.165431 auc:0.9742
epoch:  60 loss:0.165646 auc:0.9854
epoch:  60 loss:0.165832 auc:0.9928
epoch:  80 loss:0.156780 auc:0.9786
epoch:  80 loss:0.155241 auc:0.9735
epoch:  80 loss:0.155983 auc:0.9752
epoch:  80 loss:0.156081 auc:0.9854
epoch:  80 loss:0.155722 auc:0.9931
epoch: 100 loss:0.147052 auc:0.9826
epoch: 100 loss:0.145945 auc:0.9757
epoch: 100 loss:0.146418 auc

Processing dim 0:   1%|▍                                           | 10/916 [30:40<46:19:16, 184.06s/it]

Fit finished.
epoch:   0 loss:0.698451 auc:0.6390
Fit finished.
epoch:   0 loss:0.691579 auc:0.3982
Fit finished.
epoch:   0 loss:0.705946 auc:0.4380
Fit finished.
epoch:   0 loss:0.700859 auc:0.4335
epoch:  20 loss:0.197323 auc:0.9567
Fit finished.
epoch:   0 loss:0.697270 auc:0.4112
epoch:  20 loss:0.201179 auc:0.9817
epoch:  20 loss:0.198990 auc:0.9523
epoch:  40 loss:0.177838 auc:0.9597
epoch:  20 loss:0.197019 auc:0.9573
epoch:  40 loss:0.179296 auc:0.9634
epoch:  20 loss:0.200953 auc:0.9494
epoch:  40 loss:0.180439 auc:0.9813
epoch:  60 loss:0.164879 auc:0.9699
epoch:  40 loss:0.176690 auc:0.9633
epoch:  60 loss:0.165977 auc:0.9665
epoch:  60 loss:0.166519 auc:0.9796
epoch:  40 loss:0.180023 auc:0.9492
epoch:  80 loss:0.155381 auc:0.9765
epoch:  80 loss:0.156000 auc:0.9691
epoch:  60 loss:0.163822 auc:0.9633
epoch:  80 loss:0.156339 auc:0.9817
epoch:  60 loss:0.166685 auc:0.9567
epoch: 100 loss:0.145599 auc:0.9769
epoch: 100 loss:0.146604 auc:0.9672
epoch: 100 loss:0.146692 auc:0

Processing dim 0:   2%|▋                                         | 15/916 [1:00:20<63:59:21, 255.67s/it]

Fit finished.
epoch:   0 loss:0.701279 auc:0.5428
epoch: 980 loss:0.105184 auc:0.9666
Fit finished.
epoch: 980 loss:0.105677 auc:0.9633
epoch:   0 loss:0.700854 auc:0.4455
epoch: 980 loss:0.103897 auc:0.9564
Fit finished.
epoch:  20 loss:0.200788 auc:0.9744
epoch:   0 loss:0.700693 auc:0.4752
Fit finished.
epoch:   0 loss:0.715470 auc:0.5389
Fit finished.
epoch:   0 loss:0.697245 auc:0.5013
epoch:  20 loss:0.197947 auc:0.9869
epoch:  40 loss:0.180338 auc:0.9758
epoch:  20 loss:0.200908 auc:0.9935
epoch:  20 loss:0.198170 auc:0.9566
epoch:  20 loss:0.200812 auc:0.9295
epoch:  60 loss:0.166460 auc:0.9779
epoch:  40 loss:0.180648 auc:0.9923
epoch:  40 loss:0.178578 auc:0.9696
epoch:  40 loss:0.178306 auc:0.9896
epoch:  40 loss:0.180063 auc:0.9391
epoch:  80 loss:0.155973 auc:0.9788
epoch:  60 loss:0.165094 auc:0.9882
epoch:  60 loss:0.166863 auc:0.9935
epoch: 100 loss:0.146954 auc:0.9797
epoch:  60 loss:0.166348 auc:0.9462
epoch:  60 loss:0.165275 auc:0.9700
epoch:  80 loss:0.155598 auc:0

Processing dim 0:   2%|▊                                      | 20/916 [11:05:14<714:01:07, 2868.82s/it]

Fit finished.
epoch:   0 loss:0.695390 auc:0.5217
epoch: 960 loss:0.103958 auc:0.9612
epoch: 980 loss:0.104043 auc:0.9768
epoch: 980 loss:0.103816 auc:0.9950
epoch:  20 loss:0.199015 auc:0.9279
Fit finished.
epoch: 980 loss:0.104747 auc:0.9654
epoch:   0 loss:0.701871 auc:0.5116
Fit finished.
epoch:   0 loss:0.687875 auc:0.5995
Fit finished.
epoch:   0 loss:0.707048 auc:0.4858
epoch:  40 loss:0.178863 auc:0.9399
epoch:  20 loss:0.197900 auc:0.9703
epoch:  20 loss:0.197884 auc:0.9654
Fit finished.
epoch:   0 loss:0.705134 auc:0.5005
epoch:  40 loss:0.178427 auc:0.9724
epoch:  20 loss:0.198491 auc:0.9256
epoch:  20 loss:0.198189 auc:0.9638
epoch:  40 loss:0.177576 auc:0.9698
epoch:  60 loss:0.165811 auc:0.9457
epoch:  60 loss:0.165576 auc:0.9719
epoch:  40 loss:0.178635 auc:0.9344
epoch:  80 loss:0.157287 auc:0.9493
epoch:  40 loss:0.177952 auc:0.9699
epoch:  60 loss:0.164473 auc:0.9674
epoch:  60 loss:0.165770 auc:0.9392
epoch:  60 loss:0.165194 auc:0.9702
epoch:  80 loss:0.156292 auc:0

KeyboardInterrupt: 

In [None]:
# Write the stacked ground-truth labels and predictions, one CSV each,
# tagged with the dataset name.
result_files = {
    f"new_cell_true_{args.data}.csv": true_data_s,
    f"new_cell_pred_{args.data}.csv": predict_data_s,
}
for csv_path, frame in result_files.items():
    frame.to_csv(csv_path)