In [1]:
import argparse
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from dataclasses import dataclass
from joblib import Parallel, delayed
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from model import GModel
from myutils import roc_auc, translate_result, filter_target
from load_data import load_data
from optimizer import Optimizer
from sampler import NewSampler

In [3]:
@dataclass
class Args:
    """Run-time settings shared by the data loader, the model and the parallel driver."""

    # Compute device: "cuda" when torch reports CUDA as available, otherwise "cpu".
    # The expression is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Dataset identifier; appears in the progress-bar label and the output file names.
    data: str = "ctrp"
    n_jobs: int = 6 # <- adjust to the number of available CPU cores
    # Learning rate forwarded to Optimizer.
    lr: float = 5e-4
    # Epoch setting forwarded to Optimizer.
    epochs: int = 1000

In [4]:
# Instantiate the settings with their defaults.
args = Args()
# Which axis of the response matrix is treated as the prediction target.
target_option = "drug"  # either "cell" or "drug"

In [5]:
# Load the data.
# NOTE(review): judging by the variable names and the shapes load_data prints,
# this yields the response matrix, drug features, expression / mutation /
# copy-number tables and a null mask -- confirm against load_data.py.
res, drug_feature, exprs, mut, cna, null_mask = load_data(args)

response matrix (res) shape: (470, 807)
exprs shape: (1089, 19851)
mut shape: (1089, 1667)
cna shape: (1090, 23316)
807
exprs shape: (807, 19851)
mut shape: (807, 1667)
cna shape: (807, 23316)
drug_feature shape: (470, 920)
response matrix (res) shape: (807, 470)
null_mask shape: (807, 470)


In [6]:
def run_single_model(exprs, cna, mut, drug_feature, res_mat, null_mask, target_dim, target_index, args, seed):
    """Train and evaluate one GModel for a single target.

    Args:
        exprs: Forwarded to ``GModel`` as ``gene``.
        cna: Forwarded to ``GModel`` as ``cna``.
        mut: Forwarded to ``GModel`` as ``mutation``.
        drug_feature: Forwarded to ``GModel`` as ``feature_drug``.
        res_mat: Response matrix handed to ``NewSampler``.
        null_mask: Mask handed to ``NewSampler`` together with ``res_mat``.
        target_dim: Axis selector handed to ``NewSampler``.
        target_index: Index of the target along ``target_dim``.
        args: Settings object; ``device``, ``lr`` and ``epochs`` are read.
        seed: Integer used to seed the NumPy and PyTorch random number
            generators before sampling and training. ``None`` leaves the
            generators untouched.

    Returns:
        tuple: ``(true_data, predict_data)``, the second and third values
        returned by calling the optimizer.
    """
    # Previously `seed` was accepted but ignored, so repeated runs were not
    # reproducible. Seed before the sampler so that any random split it draws
    # is covered as well. This sets process-global RNG state.
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    sampler = NewSampler(res_mat, null_mask, target_dim, target_index)

    # Model hyper-parameters are fixed for this experiment.
    model = GModel(
        adj_mat=sampler.train_data.float(),
        gene=exprs,
        cna=cna,
        mutation=mut,
        sigma=2,
        k=11,
        iterates=3,
        feature_drug=drug_feature,
        n_hid1=192,
        n_hid2=36,
        alpha=5.74,
        device=args.device,
    )
    opt = Optimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        evaluate_fun=roc_auc,
        lr=args.lr,
        epochs=args.epochs,
        device=args.device,
    ).to(args.device)

    # The first value returned by the optimizer call is not needed here.
    _, true_data, predict_data = opt()
    return true_data, predict_data

In [7]:
# Target axis and summary statistics
if target_option == "cell":
    target_dim = 0
else:
    target_dim = 1

# Number of candidate targets along the chosen axis.
samples = res.shape[target_dim]

# Row-wise and column-wise totals of the response matrix.
cell_sum = res.values.sum(axis=1)
drug_sum = res.values.sum(axis=0)

In [8]:
# Wrapper executed for each target by the parallel driver
def process_target(seed, target_index):
    """Run ``run_single_model`` for one target using the notebook-level data.

    Any exception is reported on stdout and converted into ``None`` so that a
    single failing target does not abort the whole batch.

    Args:
        seed: Value forwarded as ``seed``.
        target_index: Index of the target to evaluate.

    Returns:
        The ``run_single_model`` result, or ``None`` when it raised.
    """
    try:
        # Built inside the try block so that lookup errors are reported too.
        model_inputs = {
            "exprs": exprs,
            "cna": cna,
            "mut": mut,
            "drug_feature": drug_feature,
            "res_mat": res.values,
            "null_mask": null_mask,
            "target_dim": target_dim,
            "target_index": target_index,
            "args": args,
            "seed": seed,
        }
        outcome = run_single_model(**model_inputs)
    except Exception as e:
        print(f"❌ Failed at target {target_index}: {e}")
        outcome = None
    return outcome

In [9]:
# Containers for the merged results and the filter bookkeeping
true_data_s, predict_data_s = pd.DataFrame(), pd.DataFrame()
skipped_targets, passed_targets = [], []

# Decide for every target whether it is evaluated or skipped
for target_index in range(samples):
    # Pick the row or the column of the response matrix for this target.
    if target_dim == 0:
        label_vec = res.iloc[target_index]
    else:
        label_vec = res.iloc[:, target_index]

    passed, reason, pos, neg, total = filter_target(label_vec)
    if not passed:
        skipped_targets.append((target_index, reason, pos, neg, total))
        continue
    passed_targets.append(target_index)

# Report which targets were skipped and why
print(f"\n🚫 Skipped Targets: {len(skipped_targets)}")
for skipped in skipped_targets:
    idx, reason, pos, neg, total = skipped
    print(f"Target {idx}: skipped because {reason} (total={total}, pos={pos}, neg={neg})")

# Parallel execution (the worker count comes from args.n_jobs)
# The progress bar wraps the task iterable, so it advances as joblib consumes
# tasks for dispatch rather than when each task finishes.
progress = tqdm(passed_targets, desc=f"MOFGCN ({args.data} - {target_option})")
jobs = (
    delayed(process_target)(seed, target_index)
    for seed, target_index in enumerate(progress)
)
results = Parallel(n_jobs=args.n_jobs)(jobs)

# Merge the per-target results (failed runs, returned as None, are dropped).
# The pieces are collected first and concatenated once; calling pd.concat
# inside the loop re-copies the accumulated frame on every iteration.
true_frames = []
pred_frames = []
for r in results:
    if r is None:
        continue
    true_data, pred_data = r
    true_frames.append(translate_result(true_data))
    pred_frames.append(translate_result(pred_data))

# The existing accumulators stay first in the list, so the outcome matches
# the incremental version, including the case where nothing succeeded.
true_data_s = pd.concat([true_data_s, *true_frames], ignore_index=True)
predict_data_s = pd.concat([predict_data_s, *pred_frames], ignore_index=True)


🚫 Skipped Targets: 36
Target 2: skipped because few_total_samples (total=4, pos=2, neg=2)
Target 21: skipped because few_total_samples (total=3, pos=1, neg=2)
Target 40: skipped because few_total_samples (total=5, pos=3, neg=2)
Target 41: skipped because few_total_samples (total=8, pos=5, neg=3)
Target 45: skipped because few_total_samples (total=6, pos=1, neg=5)
Target 48: skipped because few_total_samples (total=9, pos=4, neg=5)
Target 52: skipped because few_total_samples (total=1, pos=1, neg=0)
Target 54: skipped because few_total_samples (total=2, pos=1, neg=1)
Target 56: skipped because few_total_samples (total=9, pos=4, neg=5)
Target 58: skipped because few_total_samples (total=9, pos=4, neg=5)
Target 60: skipped because few_total_samples (total=2, pos=1, neg=1)
Target 63: skipped because few_total_samples (total=3, pos=1, neg=2)
Target 68: skipped because few_total_samples (total=3, pos=1, neg=2)
Target 83: skipped because few_total_samples (total=2, pos=1, neg=1)
Target 84: s

  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
MOFGCN (ctrp - drug):  25%|██▍       | 108/434 [00:48<02:19,  2.33it/s]

epoch:   0 loss:0.757741 auc:0.5375
epoch:  20 loss:0.422271 auc:0.9500
epoch:  40 loss:0.377267 auc:0.9625
epoch:  60 loss:0.357410 auc:0.9500
epoch:  80 loss:0.337258 auc:0.9375
epoch: 100 loss:0.313532 auc:0.9375
epoch: 120 loss:0.283645 auc:0.9375
epoch: 140 loss:0.244958 auc:0.9375
epoch: 160 loss:0.216882 auc:0.9375
epoch: 180 loss:0.198974 auc:0.9375
Fit finished.
epoch:   0 loss:0.798014 auc:0.5000
epoch:  20 loss:0.441246 auc:0.5204
epoch:  40 loss:0.379826 auc:0.5204
epoch:  60 loss:0.362243 auc:0.5357
epoch:  80 loss:0.347705 auc:0.5459
epoch: 100 loss:0.327051 auc:0.5459
epoch: 120 loss:0.304272 auc:0.5510
epoch: 140 loss:0.276990 auc:0.5969
epoch: 160 loss:0.243280 auc:0.5714
epoch: 180 loss:0.216516 auc:0.5765
epoch: 200 loss:0.199469 auc:0.5612
epoch: 220 loss:0.187908 auc:0.5612
epoch: 240 loss:0.179105 auc:0.5306
Fit finished.
epoch:   0 loss:0.773756 auc:0.4701
epoch:  20 loss:0.470179 auc:0.8496
epoch:  40 loss:0.384805 auc:0.8700
epoch:  60 loss:0.367078 auc:0.8802


epoch: 140 loss:0.266050 epoch:   0 loss:0.768452 auc:0.4681
epoch:  20 loss:0.403405 auc:0.8570
epoch:  40 loss:0.376355 auc:0.8786
epoch:  60 loss:0.358608 auc:0.8750
epoch:  80 loss:0.341248 auc:0.8688
epoch: 100 loss:0.317393 auc:0.8431
epoch: 120 loss:0.288822 auc:0.8251
epoch: 140 loss:0.252713 auc:0.8302
epoch: 160 loss:0.221772 auc:0.8513
epoch: 180 loss:0.203051 auc:0.8709
epoch: 200 loss:0.188677 auc:0.8729
Fit finished.
epoch:   0 loss:0.813357 auc:0.4279
epoch:  20 loss:0.466004 auc:0.8758
epoch:  40 loss:0.384726 auc:0.9069
epoch:  60 loss:0.365521 auc:0.9091
epoch:  80 loss:0.352748 auc:0.9180
epoch: 100 loss:0.335282 auc:0.9069
epoch: 120 loss:0.314325 auc:0.8714
epoch: 140 loss:0.289809 auc:0.8182
epoch: 160 loss:0.255789 auc:0.8160
epoch: 180 loss:0.223725 auc:0.8271
Fit finished.
epoch:   0 loss:0.756911 auc:0.6575
epoch:  20 loss:0.397238 auc:0.9382
epoch:  40 loss:0.373354 auc:0.9410
epoch:  60 loss:0.355804 auc:0.9403
epoch:  80 loss:0.337466 auc:0.9403
epoch: 100 

epoch: 400 epoch:   0 loss:0.718217 auc:0.4226
epoch:  20 loss:0.407712 auc:0.9107
epoch:  40 loss:0.375095 auc:0.9152
epoch:  60 loss:0.359603 auc:0.9182
epoch:  80 loss:0.342078 auc:0.9256
epoch: 100 loss:0.315146 auc:0.9286
epoch: 120 loss:0.283720 auc:0.9182
epoch: 140 loss:0.244912 auc:0.8899
epoch: 160 loss:0.215096 auc:0.8973
epoch: 180 loss:0.196958 auc:0.8943
Fit finished.
epoch:   0 loss:0.814188 auc:0.5200
epoch:  20 loss:0.462728 auc:0.9467
epoch:  40 loss:0.382432 auc:0.9400
epoch:  60 loss:0.366079 auc:0.9467
epoch:  80 loss:0.355212 auc:0.9467
epoch: 100 loss:0.339795 auc:0.9200
epoch: 120 loss:0.317262 auc:0.9067
epoch: 140 loss:0.289564 auc:0.9267
epoch: 160 loss:0.253899 auc:0.9200
epoch: 180 loss:0.225043 auc:0.9267
Fit finished.
epoch:   0 loss:0.769072 auc:0.3485
epoch:  20 loss:0.449894 auc:0.9324
epoch:  40 loss:0.380423 auc:0.9310
epoch:  60 loss:0.365621 auc:0.9372
epoch:  80 loss:0.352072 auc:0.9420
epoch: 100 loss:0.332366 auc:0.9409
epoch: 120 loss:0.310099 

MOFGCN (ctrp - drug):  48%|████▊     | 210/434 [01:28<01:26,  2.58it/s]

loss:0.378907 auc:0.9867
epoch:  60 loss:0.360065 auc:0.9867
epoch:  80 loss:0.344371 auc:0.9867
epoch: 100 loss:0.322463 auc:1.0000
epoch: 120 loss:0.296764 auc:1.0000
epoch: 140 loss:0.260635 auc:0.9867
epoch: 160 loss:0.227166 auc:0.9600
epoch: 180 loss:0.206912 auc:0.9600
Fit finished.
epoch:   0 loss:0.864470 auc:0.6182
epoch:  20 loss:0.501042 auc:0.7636
epoch:  40 loss:0.388581 auc:0.7818
epoch:  60 loss:0.369574 auc:0.8182
epoch:  80 loss:0.357769 auc:0.8182
epoch: 100 loss:0.345425 auc:0.8091
epoch: 120 loss:0.327060 auc:0.8000
epoch: 140 loss:0.308296 auc:0.7818
epoch: 160 loss:0.284217 auc:0.7818
epoch: 180 loss:0.252496 auc:0.8273
epoch: 200 loss:0.223770 auc:0.8727
epoch: 220 loss:0.203249 auc:0.8818
epoch: 240 loss:0.190229 auc:0.8727
epoch: 260 loss:0.180927 auc:0.8636
epoch: 280 loss:0.174003 auc:0.8636
epoch: 300 loss:0.169809 auc:0.8545
epoch: 320 loss:0.165083 auc:0.8455
epoch: 340 loss:0.161645 auc:0.8364
epoch: 360 loss:0.160253 auc:0.8455
Fit finished.
epoch:   0 

epoch:  20 loss:0.437347 loss:0.187701 auc:0.9904
epoch: 240 loss:0.178570 auc:0.9891
epoch: 260 loss:0.173767 auc:0.9891
Fit finished.
epoch:   0 loss:0.771969 auc:0.6088
epoch:  20 loss:0.415878 auc:0.9363
epoch:  40 loss:0.374682 auc:0.9452
epoch:  60 loss:0.357785 auc:0.9582
epoch:  80 loss:0.341544 auc:0.9592
epoch: 100 loss:0.317097 auc:0.9607
epoch: 120 loss:0.287032 auc:0.9527
epoch: 140 loss:0.247053 auc:0.9577
epoch: 160 loss:0.217905 auc:0.9637
epoch: 180 loss:0.200204 auc:0.9731
epoch: 200 loss:0.186478 auc:0.9726
epoch: 220 loss:0.177740 auc:0.9731
epoch: 240 loss:0.171897 auc:0.9756
epoch: 260 loss:0.166556 auc:0.9751
epoch: 280 loss:0.162648 auc:0.9751
epoch: 300 loss:0.159678 auc:0.9756
epoch: 320 loss:0.157979 auc:0.9761
epoch: 340 loss:0.156101 auc:0.9776
epoch: 360 loss:0.153601 auc:0.9766
epoch: 380 loss:0.153444 auc:0.9751
epoch: 400 loss:0.150586 auc:0.9751
Fit finished.
epoch:   0 loss:0.764039 auc:0.5647
epoch:  20 loss:0.474102 auc:0.8814
epoch:  40 loss:0.3881

MOFGCN (ctrp - drug):  71%|███████   | 306/434 [02:11<00:52,  2.46it/s]

loss:0.349983 auc:0.9042
epoch:  80 loss:0.326787 auc:0.8979
epoch: 100 loss:0.295266 auc:0.8565
epoch: 120 loss:0.250062 auc:0.8346
epoch: 140 loss:0.216801 auc:0.7698
epoch: 160 loss:0.196555 auc:0.7695
epoch: 180 loss:0.183779 auc:0.7818
Fit finished.
epoch:   0 loss:0.760805 auc:0.4434
epoch:  20 loss:0.399587 auc:0.9386
epoch:  40 loss:0.369255 auc:0.9556
epoch:  60 loss:0.351403 auc:0.9581
epoch:  80 loss:0.328162 auc:0.9478
epoch: 100 loss:0.300128 auc:0.9256
epoch: 120 loss:0.261453 auc:0.8806
epoch: 140 loss:0.226600 auc:0.8648
epoch: 160 loss:0.204425 auc:0.8732
epoch: 180 loss:0.188637 auc:0.8637
Fit finished.
epoch:   0 loss:0.760140 auc:0.4480
epoch:  20 loss:0.462355 auc:0.6199
epoch:  40 loss:0.384790 auc:0.6199
epoch:  60 loss:0.365067 auc:0.6425
epoch:  80 loss:0.352559 auc:0.6425
epoch: 100 loss:0.335894 auc:0.6561
epoch: 120 loss:0.314993 auc:0.6652
epoch: 140 loss:0.287222 auc:0.6244
epoch: 160 loss:0.254899 auc:0.5973
epoch: 180 loss:0.222036 auc:0.6425
epoch: 200 

epoch: 140 Fit finished.
epoch:   0 loss:0.795186 auc:0.6372
epoch:  20 loss:0.432956 auc:0.9271
epoch:  40 loss:0.379133 auc:0.9210
epoch:  60 loss:0.359420 auc:0.9193
epoch:  80 loss:0.341004 auc:0.9195
epoch: 100 loss:0.317838 auc:0.9208
epoch: 120 loss:0.292146 auc:0.9190
epoch: 140 loss:0.260326 auc:0.9097
epoch: 160 loss:0.227870 auc:0.9093
epoch: 180 loss:0.205924 auc:0.9039
Fit finished.
epoch:   0 loss:0.768679 auc:0.5848
epoch:  20 loss:0.394574 auc:0.7394
epoch:  40 loss:0.369715 auc:0.7182
epoch:  60 loss:0.355189 auc:0.6848
epoch:  80 loss:0.337837 auc:0.6455
epoch: 100 loss:0.311174 auc:0.5455
epoch: 120 loss:0.277548 auc:0.4424
epoch: 140 loss:0.235411 auc:0.4242
epoch: 160 loss:0.207383 auc:0.4576
Fit finished.
epoch:   0 loss:0.826697 auc:0.4900
epoch:  20 loss:0.461340 auc:0.9792
epoch:  40 loss:0.382859 auc:0.9843
epoch:  60 loss:0.365978 auc:0.9836
epoch:  80 loss:0.353814 auc:0.9838
epoch: 100 loss:0.336822 auc:0.9819
epoch: 120 loss:0.313427 auc:0.9788
epoch: 140 

MOFGCN (ctrp - drug):  72%|███████▏  | 312/434 [02:13<00:49,  2.48it/s]

auc:0.7876
epoch:  40 loss:0.377900 auc:0.8006
epoch:  60 loss:0.361977 auc:0.8067
epoch:  80 loss:0.347192 auc:0.8014
epoch: 100 loss:0.323460 auc:0.8127
epoch: 120 loss:0.290484 auc:0.8118
epoch: 140 loss:0.253285 auc:0.8604
epoch: 160 loss:0.221628 auc:0.8796
epoch: 180 loss:0.201709 auc:0.8761
epoch: 200 loss:0.189043 auc:0.8702
epoch: 220 loss:0.179264 auc:0.8614
epoch: 240 loss:0.173621 auc:0.8563
epoch: 260 loss:0.168027 auc:0.8420
epoch: 280 loss:0.164502 auc:0.8345
epoch: 300 loss:0.160852 auc:0.8243
Fit finished.
epoch:   0 loss:0.815171 auc:0.7130
epoch:  20 loss:0.452656 auc:0.8519
epoch:  40 loss:0.383313 auc:0.8426
epoch:  60 loss:0.366425 auc:0.8704
epoch:  80 loss:0.354800 auc:0.8611
epoch: 100 loss:0.339366 auc:0.8704
epoch: 120 loss:0.311114 auc:0.8519
epoch: 140 loss:0.279370 auc:0.8519
epoch: 160 loss:0.244960 auc:0.8704
epoch: 180 loss:0.216321 auc:0.8611
epoch: 200 loss:0.199535 auc:0.8796
epoch: 220 loss:0.186394 auc:0.9074
epoch: 240 loss:0.177980 auc:0.9074
epo

epoch: 320 loss:0.163875 auc:0.7357
epoch: 340 loss:0.160916 auc:0.7358
epoch: 360 loss:0.158259 auc:0.7394
epoch: 380 loss:0.156338 auc:0.7414
epoch: 400 loss:0.154774 auc:0.7419
epoch: 420 loss:0.152701 auc:0.7400
epoch: 440 loss:0.152224 auc:0.7407
epoch: 460 loss:0.150098 auc:0.7392
epoch: 480 loss:0.149562 auc:0.7411
epoch: 500 loss:0.148658 auc:0.7390
epoch: 520 loss:0.147725 auc:0.7395
epoch: 540 loss:0.146283 auc:0.7402
Fit finished.
epoch:   0 loss:0.810088 auc:0.4391
epoch:  20 loss:0.423199 auc:0.9426
epoch:  40 loss:0.378535 auc:0.9504
epoch:  60 loss:0.363942 auc:0.9530
epoch:  80 loss:0.352297 auc:0.9483
epoch: 100 loss:0.336503 auc:0.9426
epoch: 120 loss:0.315757 auc:0.9317
epoch: 140 loss:0.290924 auc:0.9271
epoch: 160 loss:0.255336 auc:0.9336
epoch: 180 loss:0.223187 auc:0.9284
Fit finished.
epoch:   0 loss:0.821690 auc:0.5000
epoch:  20 loss:0.430146 auc:0.7917
epoch:  40 loss:0.379264 auc:0.7917
epoch:  60 loss:0.361572 auc:0.7917
epoch:  80 loss:0.344838 auc:0.7917


MOFGCN (ctrp - drug):  95%|█████████▌| 414/434 [02:54<00:07,  2.52it/s]

loss:0.846221 auc:0.5770
epoch:  20 loss:0.508268 auc:0.8114
epoch:  40 loss:0.388629 auc:0.8214
epoch:  60 loss:0.369472 auc:0.8399
epoch:  80 loss:0.356060 auc:0.8525
epoch: 100 loss:0.341744 auc:0.8595
epoch: 120 loss:0.322329 auc:0.8536
epoch: 140 loss:0.302250 auc:0.8364
epoch: 160 loss:0.276077 auc:0.8141
epoch: 180 loss:0.243713 auc:0.7806
Fit finished.
epoch:   0 loss:0.858111 auc:0.5168
epoch:  20 loss:0.430189 auc:0.6852
epoch:  40 loss:0.377205 auc:0.6674
epoch:  60 loss:0.361313 auc:0.6714
epoch:  80 loss:0.349290 auc:0.6751
epoch: 100 loss:0.333487 auc:0.6753
epoch: 120 loss:0.311071 auc:0.6704
epoch: 140 loss:0.282804 auc:0.6547
epoch: 160 loss:0.248496 auc:0.6352
epoch: 180 loss:0.219174 auc:0.6744
Fit finished.
epoch:   0 loss:0.784705 auc:0.3835
epoch:  20 loss:0.458006 auc:0.9649
epoch:  40 loss:0.380138 auc:0.9649
epoch:  60 loss:0.364490 auc:0.9654
epoch:  80 loss:0.350652 auc:0.9677
epoch: 100 loss:0.328691 auc:0.9666
epoch: 120 loss:0.300420 auc:0.9638
epoch: 140 

MOFGCN (ctrp - drug):  97%|█████████▋| 420/434 [02:56<00:05,  2.59it/s]

loss:0.811057 auc:0.4375
epoch:  20 loss:0.481498 auc:0.6042
epoch:  40 loss:0.381588 auc:0.6042
epoch:  60 loss:0.364926 auc:0.6042
epoch:  80 loss:0.352271 auc:0.6042
epoch: 100 loss:0.335183 auc:0.6458
epoch: 120 loss:0.312531 auc:0.6458
epoch: 140 loss:0.289000 auc:0.6042
epoch: 160 loss:0.259628 auc:0.5625
epoch: 180 loss:0.229083 auc:0.5000
Fit finished.
epoch:   0 loss:0.809204 auc:0.5553
epoch:  20 loss:0.441672 auc:0.9449
epoch:  40 loss:0.380709 auc:0.9539
epoch:  60 loss:0.363359 auc:0.9658
epoch:  80 loss:0.348567 auc:0.9718
epoch: 100 loss:0.328558 auc:0.9718
epoch: 120 loss:0.308504 auc:0.9720
epoch: 140 loss:0.281214 auc:0.9550
epoch: 160 loss:0.246110 auc:0.8971
epoch: 180 loss:0.217782 auc:0.8948
Fit finished.
epoch:   0 loss:0.819860 auc:0.3900
epoch:  20 loss:0.413141 auc:0.8475
epoch:  40 loss:0.373859 auc:0.8638
epoch:  60 loss:0.355261 auc:0.8745
epoch:  80 loss:0.335757 auc:0.8815
epoch: 100 loss:0.309122 auc:0.8764
epoch: 120 loss:0.271219 auc:0.8588
epoch: 140 

MOFGCN (ctrp - drug): 100%|██████████| 434/434 [03:02<00:00,  2.38it/s]


In [10]:
# Write the merged ground-truth values and predictions to CSV files
true_path = f"mofgcn_true_{args.data}_{target_option}.csv"
pred_path = f"mofgcn_pred_{args.data}_{target_option}.csv"

# Both frames are written without the index column.
for frame, path in ((true_data_s, true_path), (predict_data_s, pred_path)):
    frame.to_csv(path, index=False)

print(f"\n✅ Done. Results saved to:\n  - {true_path}\n  - {pred_path}")


✅ Done. Results saved to:
  - mofgcn_true_ctrp_drug.csv
  - mofgcn_pred_ctrp_drug.csv
