In [2]:
import argparse
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from dataclasses import dataclass
from joblib import Parallel, delayed
from tqdm import tqdm

In [3]:
%load_ext autoreload
%autoreload 2

from model import GModel
from myutils import roc_auc, translate_result, filter_target
from load_data import load_data
from optimizer import Optimizer
from sampler import NewSampler

In [4]:
@dataclass
class Args:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    data: str = "ctrp"
    n_jobs: int = 20  # ← ここをコア数に応じて調整
    lr: float = 5e-4
    epochs: int = 1000

In [5]:
args = Args()
target_option = "cell"  # "cell" か "drug"

In [6]:
# データ読み込み
res, drug_feature, exprs, mut, cna, null_mask = load_data(args)

response matrix (res) shape: (470, 807)
exprs shape: (1089, 19851)
mut shape: (1089, 1667)
cna shape: (1090, 23316)
807
exprs shape: (807, 19851)
mut shape: (807, 1667)
cna shape: (807, 23316)
drug_feature shape: (470, 920)
response matrix (res) shape: (807, 470)
null_mask shape: (807, 470)


In [7]:
def run_single_model(exprs, cna, mut, drug_feature, res_mat, null_mask, target_dim, target_index, args, seed):
    sampler = NewSampler(res_mat, null_mask, target_dim, target_index)
    val_labels = sampler.test_data[sampler.test_mask]

    model = GModel(
        adj_mat=sampler.train_data.float(),
        gene=exprs,
        cna=cna,
        mutation=mut,
        sigma=2,
        k=11,
        iterates=3,
        feature_drug=drug_feature,
        n_hid1=192,
        n_hid2=36,
        alpha=5.74,
        device=args.device,
    )
    opt = Optimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        evaluate_fun=roc_auc,
        lr=args.lr,
        epochs=args.epochs,
        device=args.device,
    ).to(args.device)
    _, true_data, predict_data = opt()
    return true_data, predict_data

In [8]:
# ターゲット次元と統計
target_dim = 0 if target_option == "cell" else 1
samples = res.shape[target_dim]
cell_sum = np.sum(res.values, axis=1)
drug_sum = np.sum(res.values, axis=0)

In [9]:
# 並列実行用ラッパー
def process_target(seed, target_index):
    try:
        return run_single_model(
            exprs=exprs,
            cna=cna,
            mut=mut,
            drug_feature=drug_feature,
            res_mat=res.values,
            null_mask=null_mask,
            target_dim=target_dim,
            target_index=target_index,
            args=args,
            seed=seed,
        )
    except Exception as e:
        print(f"❌ Failed at target {target_index}: {e}")
        return None

In [9]:
# 結果格納用
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()
skipped_targets = []
passed_targets = []

# スキップチェック
for target_index in range(samples):
    label_vec = res.iloc[target_index] if target_dim == 0 else res.iloc[:, target_index]
    passed, reason, pos, neg, total = filter_target(label_vec)

    if passed:
        passed_targets.append(target_index)
    else:
        skipped_targets.append((target_index, reason, pos, neg, total))

# スキップ情報表示
print(f"\n🚫 Skipped Targets: {len(skipped_targets)}")
for idx, reason, pos, neg, total in skipped_targets:
    print(f"Target {idx}: skipped because {reason} (total={total}, pos={pos}, neg={neg})")

# 並列実行（max_njobs を args.n_jobs に変更可能）
results = Parallel(n_jobs=args.n_jobs)(
    delayed(process_target)(seed, target_index)
    for seed, target_index in enumerate(tqdm(passed_targets, desc=f"MOFGCN ({args.data} - {target_option})"))
)

# 結果の統合（None を除外）
for r in results:
    if r is not None:
        true_data, pred_data = r
        true_data_s = pd.concat([true_data_s, translate_result(true_data)], ignore_index=True)
        predict_data_s = pd.concat([predict_data_s, translate_result(pred_data)], ignore_index=True)


🚫 Skipped Targets: 42
Target 8: skipped because low_negative_ratio (total=165, pos=163, neg=2)
Target 22: skipped because low_negative_ratio (total=118, pos=116, neg=2)
Target 29: skipped because low_negative_ratio (total=87, pos=86, neg=1)
Target 30: skipped because few_total_samples (total=9, pos=9, neg=0)
Target 39: skipped because low_negative_ratio (total=14, pos=14, neg=0)
Target 84: skipped because low_negative_ratio (total=10, pos=10, neg=0)
Target 117: skipped because low_negative_ratio (total=83, pos=82, neg=1)
Target 128: skipped because low_negative_ratio (total=64, pos=64, neg=0)
Target 130: skipped because low_negative_ratio (total=98, pos=97, neg=1)
Target 142: skipped because low_positive_ratio (total=92, pos=1, neg=91)
Target 146: skipped because low_negative_ratio (total=105, pos=103, neg=2)
Target 148: skipped because low_negative_ratio (total=235, pos=234, neg=1)
Target 162: skipped because low_negative_ratio (total=189, pos=187, neg=2)
Target 228: skipped because 

  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
MOFGCN (ctrp - cell):  50%|█████████████████████▎                     | 380/765 [06:02<06:14,  1.03it/s]

epoch:   0 loss:0.766382 auc:0.4388
epoch:  20 loss:0.458135 auc:0.6277
epoch:  40 loss:0.381996 auc:0.6649
epoch:  60 loss:0.365382 auc:0.6543
epoch:  80 loss:0.354101 auc:0.6489
epoch: 100 loss:0.338953 auc:0.6596
epoch: 120 loss:0.314324 auc:0.6782
epoch: 140 loss:0.278165 auc:0.6303
epoch: 160 loss:0.235463 auc:0.5798
epoch: 180 loss:0.207634 auc:0.5293
Fit finished.
epoch:   0 loss:0.737039 auc:0.4889
epoch:  20 loss:0.411070 auc:0.8111
epoch:  40 loss:0.375623 auc:0.8722
epoch:  60 loss:0.354230 auc:0.8667
epoch:  80 loss:0.331404 auc:0.9111
epoch: 100 loss:0.304240 auc:0.8722
epoch: 120 loss:0.261988 auc:0.8667
epoch: 140 loss:0.226666 auc:0.8444
epoch: 160 loss:0.203897 auc:0.8222
epoch: 180 loss:0.190434 auc:0.7889
Fit finished.
epoch:   0 loss:0.755379 auc:0.6198
epoch:  20 loss:0.407901 auc:0.6698
epoch:  40 loss:0.370564 auc:0.7115
epoch:  60 loss:0.354403 auc:0.7396
epoch:  80 loss:0.335397 auc:0.7417
epoch: 100 loss:0.307692 auc:0.7479
epoch: 120 loss:0.267612 auc:0.8042


MOFGCN (ctrp - cell):  58%|████████████████████████▋                  | 440/765 [06:56<04:58,  1.09it/s]

epoch:   0 loss:0.754745 auc:0.5309
epoch:  20 loss:0.455796 auc:0.6605
epoch:  40 loss:0.380165 auc:0.7037
epoch:  60 loss:0.365047 auc:0.7253
epoch:  80 loss:0.354274 auc:0.7346
epoch: 100 loss:0.340864 auc:0.7716
epoch: 120 loss:0.321411 auc:0.8117
epoch: 140 loss:0.296070 auc:0.8117
epoch: 160 loss:0.259891 auc:0.8302
epoch: 180 loss:0.226088 auc:0.8333
epoch: 200 loss:0.204146 auc:0.8302
epoch: 220 loss:0.189647 auc:0.8179
epoch: 240 loss:0.180139 auc:0.8117
epoch: 260 loss:0.172899 auc:0.8086
epoch: 280 loss:0.169145 auc:0.8056
Fit finished.
epoch:   0 loss:0.795336 auc:0.0000
epoch:  20 loss:0.448282 auc:0.5882
epoch:  40 loss:0.378948 auc:0.7353
epoch:  60 loss:0.362717 auc:0.7353
epoch:  80 loss:0.348022 auc:0.7941
epoch: 100 loss:0.326705 auc:0.7353
epoch: 120 loss:0.299728 auc:0.5588
epoch: 140 loss:0.263470 auc:0.3529
epoch: 160 loss:0.228284 auc:0.3235
epoch: 180 loss:0.204936 auc:0.5000
Fit finished.
epoch:   0 loss:0.776982 auc:0.4124
epoch:  20 loss:0.412150 auc:0.7026


MOFGCN (ctrp - cell):  76%|████████████████████████████████▌          | 580/765 [09:13<03:06,  1.01s/it]

epoch:   0 loss:0.788994 auc:0.5587
epoch:  20 loss:0.401950 auc:0.5043
epoch:  40 loss:0.369459 auc:0.6543
epoch:  60 loss:0.351375 auc:0.6565
epoch:  80 loss:0.328268 auc:0.6978
epoch: 100 loss:0.298313 auc:0.7587
epoch: 120 loss:0.257021 auc:0.7587
epoch: 140 loss:0.221240 auc:0.7565
epoch: 160 loss:0.199461 auc:0.7457
epoch: 180 loss:0.185342 auc:0.7326
epoch: 200 loss:0.176533 auc:0.7304
epoch: 220 loss:0.171575 auc:0.7326
epoch: 240 loss:0.165458 auc:0.7522
epoch: 260 loss:0.161057 auc:0.7500
Fit finished.
epoch:   0 loss:0.924151 auc:0.5238
epoch:  20 loss:0.499443 auc:0.6270
epoch:  40 loss:0.386909 auc:0.6476
epoch:  60 loss:0.366440 auc:0.6365
epoch:  80 loss:0.351282 auc:0.6365
epoch: 100 loss:0.333351 auc:0.6540
epoch: 120 loss:0.315943 auc:0.6254
epoch: 140 loss:0.295244 auc:0.6000
epoch: 160 loss:0.261746 auc:0.5825
epoch: 180 loss:0.227303 auc:0.5794
Fit finished.
epoch:   0 loss:0.745071 auc:0.5884
epoch:  20 loss:0.439399 auc:0.4852
epoch:  40 loss:0.376534 auc:0.4684


Process LokyProcess-15:
Process LokyProcess-14:
Traceback (most recent call last):
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 501, in _process_worker
    with worker_exit_lock:
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/site-packages/joblib/externals/loky/backend/synchronize.py", line 119, in __enter__
    return self._semlock.acquire()
KeyboardInterrupt
Traceback (most recent call last):
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/inouey2/miniconda3/envs/torch/lib/python3.10/multiprocessing/process.py", line 108, in r

epoch:   0 loss:0.785896 auc:0.4292
epoch:  20 loss:0.390026 auc:0.2750
epoch:  40 loss:0.365097 auc:0.3250
epoch:  60 loss:0.343100 auc:0.3208
epoch:  80 loss:0.316076 auc:0.4125
epoch: 100 loss:0.282524 auc:0.4708
epoch: 120 loss:0.244863 auc:0.4750
epoch: 140 loss:0.215964 auc:0.4708
epoch: 160 loss:0.195875 auc:0.4708
epoch: 180 loss:0.182616 auc:0.4875
epoch: 200 loss:0.175462 auc:0.4750
epoch: 220 loss:0.167854 auc:0.4833
epoch: 240 loss:0.162813 auc:0.4875
epoch: 260 loss:0.159771 auc:0.4917
epoch: 280 loss:0.156523 auc:0.4833
epoch: 300 loss:0.154064 auc:0.5000
epoch: 320 loss:0.151755 auc:0.4875
epoch: 340 loss:0.154959 auc:0.4583
Fit finished.
epoch:   0 loss:0.824332 auc:0.6232
epoch:  20 loss:0.489091 auc:0.7705
epoch:  40 loss:0.389111 auc:0.7464
epoch:  60 loss:0.369017 auc:0.7343
epoch:  80 loss:0.355906 auc:0.7271
epoch: 100 loss:0.337768 auc:0.7126
epoch: 120 loss:0.315378 auc:0.7198
epoch: 140 loss:0.287438 auc:0.7343
epoch: 160 loss:0.254107 auc:0.7367
epoch: 180 los


KeyboardInterrupt

