In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Hyper-parameters and runtime settings for the NIHGCN experiments.

    Stands in for an ``argparse.Namespace`` inside the notebook. Every default
    below can be overridden by keyword, e.g. ``Args(data="ccle", epochs=100)``;
    unknown keywords raise ``TypeError`` so typos do not go unnoticed.
    """

    def __init__(self, **overrides):
        # Device to run on: "cuda" when available, otherwise "cpu".
        # An override may be a string such as "cuda:1" or a torch.device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = "gdsc2"  # dataset name passed to load_data
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularization strength)
        self.layer_size = [1024, 1024]  # output size of every layer
        self.alpha = 0.25  # scale balancing the GCN and NI components
        self.gamma = 8  # scale for the sigmoid
        self.epochs = 1000  # number of training epochs

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected setting {name!r}")
            if name == "device":
                # Normalise strings like "cpu" / "cuda:0" to a torch.device.
                value = torch.device(value)
            setattr(self, name, value)


args = Args()

In [4]:
# Load the response matrix, features and masks for the dataset in ``args``.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Totals of the response matrix per row (axis=1, cell lines) and per column
# (axis=0, drugs); later used to skip targets with too few entries.
cell_sum, drug_sum = np.sum(res, axis=1), np.sum(res, axis=0)

# Dimensions to evaluate: 0 = cell lines, 1 = drugs (drugs currently disabled).
target_dim = [0]

load gdsc2


In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Train and evaluate one NIHGCN model for a single target.

    A ``NewSampler`` built from ``res_mat``/``null_mask`` for ``target_index``
    along ``target_dim`` (seeded with ``seed``) supplies the train/test data
    and masks. The model is trained by ``Optimizer`` and scored with
    ``evaluate_fun``; all hyper-parameters come from ``args``.

    Returns
    -------
    tuple
        ``(true_data, predict_data)`` as produced by the optimizer.
    """
    sampler = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    model_kwargs = dict(
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    model = nihgcn(sampler.train_data, **model_kwargs)

    training_kwargs = dict(
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )
    optimizer = Optimizer(
        model,
        sampler.train_data,
        sampler.test_data,
        sampler.test_mask,
        sampler.train_mask,
        evaluate_fun,
        **training_kwargs,
    )

    true_data, predict_data = optimizer()
    return true_data, predict_data

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

# Repeated train/evaluate runs (folds) per target, and joblib worker count.
n_kfold, n_jobs = 1, 5


def process_iteration(dim, target_index, seed, args, min_count=10):
    """Train and evaluate the model for one target across ``n_kfold`` folds.

    Parameters
    ----------
    dim : int
        0 to target a cell line (a row of ``res``), 1 to target a drug
        (a column of ``res``).
    target_index : int
        Index of the target along ``dim``.
    seed : int
        Base seed for this target; fold ``f`` uses ``seed * n_kfold + f``.
    args : Args
        Hyper-parameters forwarded to ``nihgcn_new``.
    min_count : int, default 10
        A target whose response-matrix total along ``dim`` is below this
        value is skipped.

    Returns
    -------
    list of (true_data, predict_data) or None
        One tuple per fold, or ``None`` when the target was skipped.
    """
    # Per-drug column totals or per-cell row totals of the response matrix.
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < min_count:
        # Return a bare None so the caller's ``is None`` check skips this
        # target; a (None, None) pair would instead fail to unpack there.
        return None

    fold_results = []
    for fold in range(n_kfold):
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            # Distinct seed per fold (reusing ``seed`` would give every fold
            # identical inputs); equals ``seed`` when n_kfold == 1, and never
            # collides with another target's seeds.
            seed=seed * n_kfold + fold,
        )
        fold_results.append((true_data, predict_data))

    return fold_results


# Run process_iteration for every target of each requested dimension in
# parallel, then stack the translated per-fold results into two DataFrames.
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build every task up front; the seed is the target's position.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution with a progress bar. NOTE(review): tqdm wraps the
    # task generator joblib consumes, so the bar tracks dispatch (joblib
    # pre-dispatches ahead of the workers), not completed tasks.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    for fold_results in results:
        # Skipped targets come back as None (or as a (None, None) pair from
        # older versions of process_iteration); unpacking that pair below
        # would raise TypeError, so skip both forms.
        if not fold_results or fold_results[0] is None:
            continue
        for true_data, predict_data in fold_results:
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# Concatenate once: calling pd.concat inside the loop re-copies the growing
# frame on every iteration (quadratic). pd.concat([]) raises ValueError, so
# fall back to an empty DataFrame when every target was skipped.
# NOTE(review): assumes translate_result returns DataFrames.
true_data_s = (
    pd.concat(true_frames, ignore_index=True) if true_frames else pd.DataFrame()
)
predict_data_s = (
    pd.concat(predict_frames, ignore_index=True)
    if predict_frames
    else pd.DataFrame()
)

Processing dim 0:   0%|                                                         | 0/910 [00:00<?, ?it/s]

epoch:   0 loss:0.706324 auc:0.3721
epoch:   0 loss:0.692680 auc:0.5224
epoch:   0 loss:0.700092 auc:0.6296
epoch:   0 loss:0.695419 auc:0.4744
epoch:   0 loss:0.715908 auc:0.5283
epoch:  20 loss:0.153835 auc:0.9592
epoch:  20 loss:0.154698 auc:0.8980
epoch:  20 loss:0.154487 auc:0.9929
epoch:  20 loss:0.153625 auc:0.9923
epoch:  20 loss:0.153382 auc:0.9512
epoch:  40 loss:0.138309 auc:0.9964
epoch:  40 loss:0.137671 auc:0.9961
epoch:  40 loss:0.139144 auc:0.9229
epoch:  40 loss:0.137829 auc:0.9698
epoch:  40 loss:0.137113 auc:0.9609
epoch:  60 loss:0.124536 auc:0.9929
epoch:  60 loss:0.123757 auc:0.9954
epoch:  60 loss:0.125224 auc:0.9524
epoch:  60 loss:0.124250 auc:0.9722
epoch:  60 loss:0.123228 auc:0.9697
epoch:  80 loss:0.115868 auc:0.9592
epoch:  80 loss:0.115453 auc:0.9747
epoch:  80 loss:0.115605 auc:0.9893
epoch:  80 loss:0.114624 auc:0.9775
epoch:  80 loss:0.114848 auc:0.9977
epoch: 100 loss:0.109326 auc:0.9780
epoch: 100 loss:0.109732 auc:0.9705
epoch: 100 loss:0.108615 auc

Processing dim 0:   1%|▍                                           | 10/910 [30:19<45:28:47, 181.92s/it]

Fit finished.
epoch:   0 loss:0.697176 auc:0.5245
epoch: 980 loss:0.074296 auc:0.9731
Fit finished.
epoch: 980 loss:0.074602 auc:0.9660
epoch:   0 loss:0.689477 auc:0.4170
Fit finished.
epoch:   0 loss:0.707917 auc:0.3464
Fit finished.
epoch:  20 loss:0.154152 auc:0.9990
epoch:   0 loss:0.699970 auc:0.3318
Fit finished.
epoch:  20 loss:0.155054 auc:0.9974
epoch:   0 loss:0.697837 auc:0.5533
epoch:  20 loss:0.155242 auc:0.9557
epoch:  40 loss:0.138036 auc:1.0000
epoch:  20 loss:0.152727 auc:0.9884
epoch:  20 loss:0.156705 auc:0.9917
epoch:  40 loss:0.139142 auc:0.9948
epoch:  40 loss:0.139156 auc:0.9640
epoch:  60 loss:0.124337 auc:1.0000
epoch:  40 loss:0.140999 auc:1.0000
epoch:  60 loss:0.125279 auc:0.9948
epoch:  40 loss:0.136096 auc:0.9907
epoch:  60 loss:0.125194 auc:0.9685
epoch:  80 loss:0.115922 auc:0.9931
epoch:  60 loss:0.127121 auc:1.0000
epoch:  80 loss:0.115583 auc:1.0000
epoch:  60 loss:0.122614 auc:0.9900
epoch:  80 loss:0.116139 auc:0.9719
epoch: 100 loss:0.109777 auc:0

Processing dim 0:   2%|▋                                           | 15/910 [59:22<62:30:26, 251.43s/it]

Fit finished.
epoch:   0 loss:0.699021 auc:0.5235
Fit finished.
epoch:   0 loss:0.703062 auc:0.3648
epoch: 980 loss:0.072156 auc:0.9856
epoch: 960 loss:0.072583 auc:0.9952
epoch:  20 loss:0.155755 auc:0.9882
Fit finished.
epoch:   0 loss:0.699281 auc:0.5608
Fit finished.
epoch:  20 loss:0.153629 auc:0.9920
epoch:   0 loss:0.718660 auc:0.3535
epoch: 980 loss:0.072369 auc:0.9952
epoch:  40 loss:0.139467 auc:0.9882
epoch:  20 loss:0.155884 auc:0.9809
epoch:  40 loss:0.137127 auc:0.9968
epoch:  20 loss:0.155045 auc:0.9825
epoch:  40 loss:0.139616 auc:0.9826
Fit finished.
epoch:  60 loss:0.125101 auc:0.9896
epoch:   0 loss:0.700324 auc:0.6379
epoch:  60 loss:0.125170 auc:0.9826
epoch:  40 loss:0.138388 auc:0.9854
epoch:  60 loss:0.123179 auc:0.9968
epoch:  20 loss:0.153440 auc:1.0000
epoch:  80 loss:0.115881 auc:0.9917
epoch:  80 loss:0.114658 auc:0.9968
epoch:  60 loss:0.124606 auc:0.9883
epoch:  80 loss:0.115681 auc:0.9792
epoch:  40 loss:0.137069 auc:1.0000
epoch: 100 loss:0.109629 auc:0

Processing dim 0:   2%|▉                                       | 20/910 [7:19:37<459:52:44, 1860.18s/it]

Fit finished.
epoch:   0 loss:0.695019 auc:0.4601
epoch: 980 loss:0.076088 auc:0.9878
Fit finished.
epoch:   0 loss:0.701276 auc:0.4974
Fit finished.
epoch:   0 loss:0.684412 auc:0.6173
epoch: 960 loss:0.072443 auc:1.0000
Fit finished.
epoch:   0 loss:0.709253 auc:0.3398
epoch:  20 loss:0.153788 auc:0.9302
epoch:  20 loss:0.153966 auc:0.9803
epoch: 980 loss:0.072310 auc:1.0000
epoch:  20 loss:0.153532 auc:0.9694
epoch:  40 loss:0.137788 auc:0.9825
epoch:  20 loss:0.153448 auc:0.9695
epoch:  40 loss:0.137616 auc:0.9366
Fit finished.
epoch:   0 loss:0.704015 auc:0.4819
epoch:  60 loss:0.123815 auc:0.9861
epoch:  40 loss:0.137043 auc:0.9783
epoch:  40 loss:0.136924 auc:0.9707
epoch:  60 loss:0.124026 auc:0.9431
epoch:  20 loss:0.153099 auc:0.9564
epoch:  60 loss:0.123437 auc:0.9744
epoch:  60 loss:0.123011 auc:0.9719
epoch:  80 loss:0.115495 auc:0.9532
epoch:  80 loss:0.115328 auc:0.9876
epoch:  40 loss:0.137177 auc:0.9603
epoch: 100 loss:0.109430 auc:0.9624
epoch:  80 loss:0.115000 auc:0

Processing dim 0:   3%|█                                      | 25/910 [15:36:20<807:47:08, 3285.91s/it]

Fit finished.



KeyboardInterrupt



epoch:   0 loss:0.704263 auc:0.5798
epoch: 980 loss:0.072430 auc:0.9876
epoch: 940 loss:0.072244 auc:0.9606
epoch: 960 loss:0.072440 auc:0.9796


In [None]:
# Write the stacked ground-truth and prediction tables, one CSV each.
output_files = {
    f"new_cell_true_{args.data}.csv": true_data_s,
    f"new_cell_pred_{args.data}.csv": predict_data_s,
}
for csv_path, frame in output_files.items():
    frame.to_csv(csv_path)