In [1]:
import argparse
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from dataclasses import dataclass
from joblib import Parallel, delayed
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from model import GModel
from myutils import roc_auc, translate_result, filter_target
from load_data import load_data
from optimizer import Optimizer
from sampler import NewSampler

In [3]:
@dataclass
class Args:
    """Run configuration shared by the data loader, the trainer and the output naming."""

    # Compute device: "cuda" when torch reports a usable GPU, otherwise "cpu".
    # The expression is evaluated once, when the class body is executed.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    # Dataset identifier; also embedded in the result file names.
    data: str = "gdsc1"
    # Number of joblib workers -- tune this to the number of available cores.
    n_jobs: int = 15
    # Learning rate handed to the Optimizer.
    lr: float = 0.0005
    # Number of epochs handed to the Optimizer.
    epochs: int = 1_000

In [4]:
args = Args()
target_option = "drug"  # either "cell" or "drug"; decides which axis of the response matrix is iterated over

In [5]:
# Load the data for the configured dataset: response matrix, drug features,
# the exprs / mut / cna tables and the null mask.
res, drug_feature, exprs, mut, cna, null_mask = load_data(args)

load gdsc1
response matrix (res) shape: (300, 925)
exprs shape: (1084, 19562)
mut shape: (1084, 18099)
cna shape: (1084, 24502)
925
exprs shape: (925, 19562)
mut shape: (925, 18099)
cna shape: (925, 24502)
drug_feature shape: (300, 920)
response matrix (res) shape: (925, 300)
null_mask shape: (925, 300)


In [6]:
def run_single_model(exprs, cna, mut, drug_feature, res_mat, null_mask, target_dim, target_index, args, seed):
    """Train and evaluate one GModel for a single held-out target.

    Args:
        exprs: forwarded to GModel as ``gene``.
        cna: forwarded to GModel as ``cna``.
        mut: forwarded to GModel as ``mutation``.
        drug_feature: forwarded to GModel as ``feature_drug``.
        res_mat: response matrix handed to NewSampler.
        null_mask: mask handed to NewSampler together with ``res_mat``.
        target_dim: axis of ``res_mat`` on which the target lies (the caller uses 0 or 1).
        target_index: position of the target along ``target_dim``.
        args: configuration object; ``device``, ``lr`` and ``epochs`` are read.
        seed: integer used to seed the NumPy and PyTorch global RNGs before any
            sampling or model construction. ``None`` leaves the RNG state untouched.

    Returns:
        tuple: ``(true_data, predict_data)``, the second and third items returned
        by calling the Optimizer.
    """
    # Make the run reproducible for a given seed. Only the NumPy and PyTorch
    # global generators are seeded; code drawing from other RNGs is unaffected.
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    sampler = NewSampler(res_mat, null_mask, target_dim, target_index)

    model = GModel(
        adj_mat=sampler.train_data.float(),
        gene=exprs,
        cna=cna,
        mutation=mut,
        sigma=2,
        k=11,
        iterates=3,
        feature_drug=drug_feature,
        n_hid1=192,
        n_hid2=36,
        alpha=5.74,
        device=args.device,
    )
    opt = Optimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        evaluate_fun=roc_auc,
        lr=args.lr,
        epochs=args.epochs,
        device=args.device,
    ).to(args.device)
    # The first returned item is not needed by callers of this function.
    _, true_data, predict_data = opt()
    return true_data, predict_data

In [7]:
# Target axis and per-axis totals of the response matrix.
# Fail fast on a misspelled option: without this check any value other than
# "cell" would silently select the drug axis.
if target_option not in ("cell", "drug"):
    raise ValueError(f'target_option must be "cell" or "drug", got {target_option!r}')
target_dim = 0 if target_option == "cell" else 1  # 0 -> rows of res, 1 -> columns of res
samples = res.shape[target_dim]  # number of candidate targets along the chosen axis
cell_sum = np.sum(res.values, axis=1)  # one total per row
drug_sum = np.sum(res.values, axis=0)  # one total per column

In [8]:
# Wrapper used for parallel execution
def process_target(seed, target_index):
    """Run ``run_single_model`` for one target, turning any failure into ``None``.

    The datasets and settings (``exprs``, ``cna``, ``mut``, ``drug_feature``,
    ``res``, ``null_mask``, ``target_dim``, ``args``) are read from the
    notebook's global namespace.

    Args:
        seed: forwarded unchanged to ``run_single_model``.
        target_index: index of the target to process.

    Returns:
        Whatever ``run_single_model`` returns, or ``None`` when it raised.
    """
    outcome = None
    try:
        # The global lookups stay inside the try block so that a missing name is
        # reported through the same failure path as a training error.
        call_kwargs = {
            "exprs": exprs,
            "cna": cna,
            "mut": mut,
            "drug_feature": drug_feature,
            "res_mat": res.values,
            "null_mask": null_mask,
            "target_dim": target_dim,
            "target_index": target_index,
            "args": args,
            "seed": seed,
        }
        outcome = run_single_model(**call_kwargs)
    except Exception as err:
        print(f"❌ Failed at target {target_index}: {err}")
    return outcome

In [9]:
# Containers for the results
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()
skipped_targets = []
passed_targets = []

# Skip check: decide for every candidate target whether it is usable
for target_index in range(samples):
    # Row of res when targeting axis 0, column of res otherwise
    label_vec = res.iloc[target_index] if target_dim == 0 else res.iloc[:, target_index]
    passed, reason, pos, neg, total = filter_target(label_vec)

    if passed:
        passed_targets.append(target_index)
    else:
        skipped_targets.append((target_index, reason, pos, neg, total))

# Report the skipped targets
print(f"\n🚫 Skipped Targets: {len(skipped_targets)}")
for idx, reason, pos, neg, total in skipped_targets:
    print(f"Target {idx}: skipped because {reason} (total={total}, pos={pos}, neg={neg})")

# Parallel execution; the worker count comes from args.n_jobs.
# The position of a target within passed_targets is used as its seed.
# NOTE(review): tqdm wraps the task generator, so the bar presumably tracks task
# dispatch rather than task completion -- confirm before reading timings off it.
results = Parallel(n_jobs=args.n_jobs)(
    delayed(process_target)(seed, target_index)
    for seed, target_index in enumerate(tqdm(passed_targets, desc=f"MOFGCN ({args.data} - {target_option})"))
)

# Merge the results. Frames are collected first and concatenated once, which
# avoids re-copying the accumulated frame on every iteration.
true_frames = []
pred_frames = []
failed_targets = []
# Parallel returns results in submission order, so they line up with passed_targets.
for target_index, r in zip(passed_targets, results):
    if r is None:
        # process_target returns None when the run for this target raised.
        failed_targets.append(target_index)
        continue
    true_data, pred_data = r
    # translate_result is assumed to return a DataFrame -- TODO confirm.
    true_frames.append(translate_result(true_data))
    pred_frames.append(translate_result(pred_data))

# Make failed targets visible instead of dropping them silently
if failed_targets:
    print(f"\n⚠️ Failed Targets: {len(failed_targets)} -> {failed_targets}")

# pd.concat rejects an empty list, so the empty frames created above are kept
# when no target produced a result.
if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
    predict_data_s = pd.concat(pred_frames, ignore_index=True)


🚫 Skipped Targets: 3
Target 5: skipped because few_total_samples (total=8, pos=1, neg=7)
Target 229: skipped because few_total_samples (total=9, pos=1, neg=8)
Target 298: skipped because few_total_samples (total=7, pos=1, neg=6)


  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
MOFGCN (gdsc1 - drug):  76%|███████▌  | 225/297 [01:25<00:26,  2.69it/s]

epoch:   0 loss:0.732250 auc:0.6090
epoch:  20 loss:0.659111 auc:0.8518
epoch:  40 loss:0.626945 auc:0.8377
epoch:  60 loss:0.563546 auc:0.8474
epoch:  80 loss:0.519740 auc:0.8574
epoch: 100 loss:0.504845 auc:0.8678
epoch: 120 loss:0.504318 auc:0.8686
epoch: 140 loss:0.493551 auc:0.8718
epoch: 160 loss:0.487156 auc:0.8730
epoch: 180 loss:0.483310 auc:0.8686
epoch: 200 loss:0.472248 auc:0.8586
epoch: 220 loss:0.458642 auc:0.8482
epoch: 240 loss:0.435002 auc:0.8361
Fit finished.
epoch:   0 loss:0.723890 auc:0.4545
epoch:  20 loss:0.661003 auc:0.8539
epoch:  40 loss:0.625986 auc:0.8258
epoch:  60 loss:0.564244 auc:0.7879
epoch:  80 loss:0.520786 auc:0.7370
epoch: 100 loss:0.505427 auc:0.7348
epoch: 120 loss:0.496669 auc:0.7338
epoch: 140 loss:0.493030 auc:0.7305
epoch: 160 loss:0.489320 auc:0.7240
epoch: 180 loss:0.487751 auc:0.6818
Fit finished.
epoch:   0 loss:0.733286 auc:0.5346
epoch:  20 loss:0.637323 auc:0.7278
epoch:  40 loss:0.574777 auc:0.7842
epoch:  60 loss:0.521030 auc:0.8191


MOFGCN (gdsc1 - drug):  81%|████████  | 240/297 [01:29<00:19,  2.88it/s]

epoch:   0 loss:0.704364 auc:0.3750
epoch:  20 loss:0.663565 auc:0.3542
epoch:  40 loss:0.636721 auc:0.3438
epoch:  60 loss:0.571633 auc:0.3438
epoch:  80 loss:0.527589 auc:0.3750
epoch: 100 loss:0.508906 auc:0.3854
epoch: 120 loss:0.501407 auc:0.4062
epoch: 140 loss:0.494765 auc:0.3958
epoch: 160 loss:0.492266 auc:0.4062
epoch: 180 loss:0.484858 auc:0.4375
epoch: 200 loss:0.478903 auc:0.4479
epoch: 220 loss:0.456873 auc:0.4687
epoch: 240 loss:0.443714 auc:0.4688
epoch: 260 loss:0.406776 auc:0.4792
epoch: 280 loss:0.430877 auc:0.4479
epoch: 300 loss:0.361890 auc:0.5000
epoch: 320 loss:0.325279 auc:0.4896
epoch: 340 loss:0.295096 auc:0.4792
epoch: 360 loss:0.303086 auc:0.4479
epoch: 380 loss:0.402549 auc:0.3750
Fit finished.
epoch:   0 loss:0.730385 auc:0.5279
epoch:  20 loss:0.672210 auc:0.5913
epoch:  40 loss:0.634490 auc:0.6083
epoch:  60 loss:0.583061 auc:0.6423
epoch:  80 loss:0.520652 auc:0.7268
epoch: 100 loss:0.506126 auc:0.7409
epoch: 120 loss:0.500112 auc:0.7459
epoch: 140 los

MOFGCN (gdsc1 - drug): 100%|██████████| 297/297 [01:47<00:00,  2.77it/s]


epoch:   0 loss:0.764417 auc:0.5462
epoch:  20 loss:0.676453 auc:0.7095
epoch:  40 loss:0.638765 auc:0.7221
epoch:  60 loss:0.592762 auc:0.7319
epoch:  80 loss:0.527687 auc:0.7221
epoch: 100 loss:0.511315 auc:0.7411
epoch: 120 loss:0.502559 auc:0.7496
epoch: 140 loss:0.497675 auc:0.7595
epoch: 160 loss:0.491067 auc:0.7714
epoch: 180 loss:0.487180 auc:0.7743
epoch: 200 loss:0.477853 auc:0.7638
epoch: 220 loss:0.466008 auc:0.6873
Fit finished.
epoch:   0 loss:0.737087 auc:0.4996
epoch:  20 loss:0.684977 auc:0.6869
epoch:  40 loss:0.647148 auc:0.6813
epoch:  60 loss:0.610543 auc:0.6795
epoch:  80 loss:0.550598 auc:0.7072
epoch: 100 loss:0.516584 auc:0.7364
epoch: 120 loss:0.505075 auc:0.7451
epoch: 140 loss:0.504692 auc:0.7513
epoch: 160 loss:0.491146 auc:0.7507
epoch: 180 loss:0.484688 auc:0.7551
epoch: 200 loss:0.471575 auc:0.7609
epoch: 220 loss:0.452665 auc:0.7496
epoch: 240 loss:0.430703 auc:0.7186
epoch: 260 loss:0.393032 auc:0.6479
Fit finished.
epoch:   0 loss:0.925404 auc:0.5184


epoch: 220 loss:0.474970 epoch:   0 loss:0.813333 auc:0.5549
epoch:  20 loss:0.687714 auc:0.6297
epoch:  40 loss:0.644049 auc:0.6124
epoch:  60 loss:0.604881 auc:0.6041
epoch:  80 loss:0.540350 auc:0.7243
epoch: 100 loss:0.517105 auc:0.7763
epoch: 120 loss:0.506379 auc:0.7845
epoch: 140 loss:0.500786 auc:0.7854
epoch: 160 loss:0.494941 auc:0.7859
epoch: 180 loss:0.492239 auc:0.7096
epoch: 200 loss:0.512031 auc:0.5150
Fit finished.
epoch:   0 loss:0.798004 auc:0.4415
epoch:  20 loss:0.664980 auc:0.6498
epoch:  40 loss:0.637999 auc:0.6675
epoch:  60 loss:0.585394 auc:0.6806
epoch:  80 loss:0.525014 auc:0.6859
epoch: 100 loss:0.507423 auc:0.6896
epoch: 120 loss:0.501757 auc:0.6920
epoch: 140 loss:0.495105 auc:0.6234
epoch: 160 loss:0.488027 auc:0.5136
epoch: 180 loss:0.482713 auc:0.5112
Fit finished.
epoch:   0 loss:0.732987 auc:0.5514
epoch:  20 loss:0.660060 auc:0.5657
epoch:  40 loss:0.607701 auc:0.5286
epoch:  60 loss:0.527365 auc:0.4514
epoch:  80 loss:0.514864 auc:0.4314
epoch: 100 

epoch:   0 loss:0.833727 auc:0.6059
epoch:  20 loss:0.679355 auc:0.5600
epoch:  40 loss:0.647503 auc:0.5600
epoch:  60 loss:0.610770 auc:0.5496
epoch:  80 loss:0.544877 auc:0.6193
epoch: 100 loss:0.514260 auc:0.6474
epoch: 120 loss:0.503967 auc:0.6681
epoch: 140 loss:0.497765 auc:0.6726
epoch: 160 loss:0.489459 auc:0.6830
epoch: 180 loss:0.482068 auc:0.6844
epoch: 200 loss:0.480872 auc:0.6770
epoch: 220 loss:0.464842 auc:0.6711
epoch: 240 loss:0.449902 auc:0.6874
epoch: 260 loss:0.439722 auc:0.6607
epoch: 280 loss:0.414231 auc:0.6548
Fit finished.
epoch:   0 loss:0.813548 auc:0.4125
epoch:  20 loss:0.679670 auc:0.6745
epoch:  40 loss:0.644480 auc:0.6712
epoch:  60 loss:0.602818 auc:0.6740
epoch:  80 loss:0.533432 auc:0.7205
epoch: 100 loss:0.513259 auc:0.7234
epoch: 120 loss:0.503246 auc:0.7234
epoch: 140 loss:0.495554 auc:0.7202
epoch: 160 loss:0.493856 auc:0.7130
epoch: 180 loss:0.482843 auc:0.7062
epoch: 200 loss:0.473460 auc:0.6917
epoch: 220 loss:0.492275 auc:0.6665
Fit finished.


epoch: 220 epoch:   0 loss:0.829392 auc:0.8462
epoch:  20 loss:0.682471 auc:1.0000
epoch:  40 loss:0.645224 auc:1.0000
epoch:  60 loss:0.597590 auc:1.0000
epoch:  80 loss:0.528631 auc:0.9231
epoch: 100 loss:0.512813 auc:0.8462
epoch: 120 loss:0.503516 auc:0.8462
epoch: 140 loss:0.498800 auc:0.8462
epoch: 160 loss:0.492549 auc:0.8462
epoch: 180 loss:0.486452 auc:0.8462
Fit finished.
epoch:   0 loss:0.811926 auc:0.5364
epoch:  20 loss:0.681204 auc:0.7637
epoch:  40 loss:0.651838 auc:0.7637
epoch:  60 loss:0.615918 auc:0.7568
epoch:  80 loss:0.559538 auc:0.7942
epoch: 100 loss:0.521317 auc:0.8247
epoch: 120 loss:0.506561 auc:0.8369
epoch: 140 loss:0.498778 auc:0.8451
epoch: 160 loss:0.495148 auc:0.8500
epoch: 180 loss:0.491689 auc:0.8519
epoch: 200 loss:0.488950 auc:0.8520
epoch: 220 loss:0.484940 auc:0.8521
epoch: 240 loss:0.482430 auc:0.8379
epoch: 260 loss:0.484976 auc:0.8240
Fit finished.
epoch:   0 loss:0.957609 auc:0.5318
epoch:  20 loss:0.668161 auc:0.8693
epoch:  40 loss:0.636758 

epoch:  60 epoch:   0 loss:0.840419 auc:0.7524
epoch:  20 loss:0.684159 auc:0.6444
epoch:  40 loss:0.645521 auc:0.6619
epoch:  60 loss:0.598648 auc:0.6698
epoch:  80 loss:0.527503 auc:0.7254
epoch: 100 loss:0.510398 auc:0.7270
epoch: 120 loss:0.499949 auc:0.7397
epoch: 140 loss:0.498756 auc:0.7365
epoch: 160 loss:0.492299 auc:0.7365
Fit finished.
epoch:   0 loss:0.797333 auc:0.4540
epoch:  20 loss:0.672981 auc:0.6789
epoch:  40 loss:0.638690 auc:0.7583
epoch:  60 loss:0.581577 auc:0.7748
epoch:  80 loss:0.519065 auc:0.8175
epoch: 100 loss:0.505806 auc:0.8269
epoch: 120 loss:0.497795 auc:0.8336
epoch: 140 loss:0.496421 auc:0.8375
epoch: 160 loss:0.489310 auc:0.8284
epoch: 180 loss:0.482875 auc:0.7627
epoch: 200 loss:0.482801 auc:0.5829
Fit finished.
epoch:   0 loss:0.768474 auc:0.4922
epoch:  20 loss:0.681215 auc:0.8906
epoch:  40 loss:0.649022 auc:0.9219
epoch:  60 loss:0.588612 auc:0.9609
epoch:  80 loss:0.529297 auc:0.9766
epoch: 100 loss:0.513472 auc:0.9844
epoch: 120 loss:0.505024 

epoch: 120 loss:0.503069 epoch:   0 loss:0.912405 auc:0.5509
epoch:  20 loss:0.678399 auc:0.6688
epoch:  40 loss:0.643401 auc:0.7459
epoch:  60 loss:0.601700 auc:0.7834
epoch:  80 loss:0.537385 auc:0.8800
epoch: 100 loss:0.514188 auc:0.8997
epoch: 120 loss:0.503994 auc:0.9009
epoch: 140 loss:0.497870 auc:0.9000
epoch: 160 loss:0.493807 auc:0.9020
epoch: 180 loss:0.490599 auc:0.9000
epoch: 200 loss:0.489083 auc:0.8911
epoch: 220 loss:0.481615 auc:0.8624
epoch: 240 loss:0.474408 auc:0.8161
Fit finished.
epoch:   0 loss:0.802145 auc:0.3421
epoch:  20 loss:0.685036 auc:0.9040
epoch:  40 loss:0.651704 auc:0.9156
epoch:  60 loss:0.618870 auc:0.9197
epoch:  80 loss:0.551764 auc:0.9356
epoch: 100 loss:0.516131 auc:0.9363
epoch: 120 loss:0.506395 auc:0.9377
epoch: 140 loss:0.498326 auc:0.9381
epoch: 160 loss:0.494005 auc:0.9397
epoch: 180 loss:0.485673 auc:0.9399
epoch: 200 loss:0.473425 auc:0.9303
epoch: 220 loss:0.461208 auc:0.9367
epoch: 240 loss:0.433159 auc:0.9191
Fit finished.
epoch:   0 

epoch: 260 epoch:   0 loss:0.824680 auc:0.5915
epoch:  20 loss:0.681153 auc:0.7870
epoch:  40 loss:0.649759 auc:0.8224
epoch:  60 loss:0.616902 auc:0.8394
epoch:  80 loss:0.556585 auc:0.8815
epoch: 100 loss:0.518363 auc:0.9069
epoch: 120 loss:0.506049 auc:0.9082
epoch: 140 loss:0.498996 auc:0.9082
epoch: 160 loss:0.496444 auc:0.9071
epoch: 180 loss:0.488037 auc:0.9068
epoch: 200 loss:0.481538 auc:0.8995
epoch: 220 loss:0.469101 auc:0.8966
epoch: 240 loss:0.471618 auc:0.9006
epoch: 260 loss:0.438239 auc:0.8975
Fit finished.
epoch:   0 loss:0.754811 auc:0.5528
epoch:  20 loss:0.680355 auc:0.6504
epoch:  40 loss:0.647146 auc:0.6594
epoch:  60 loss:0.601969 auc:0.6614
epoch:  80 loss:0.527861 auc:0.6598
epoch: 100 loss:0.513709 auc:0.6581
epoch: 120 loss:0.502351 auc:0.6606
epoch: 140 loss:0.495356 auc:0.6665
epoch: 160 loss:0.492842 auc:0.6678
epoch: 180 loss:0.485183 auc:0.6604
epoch: 200 loss:0.480976 auc:0.6333
Fit finished.
epoch:   0 loss:0.710140 auc:0.4596
epoch:  20 loss:0.644763 

In [10]:
# Write the collected true labels and predictions to CSV files named after
# the dataset and the target option, then report where they went.
true_path = f"mofgcn_true_{args.data}_{target_option}.csv"
pred_path = f"mofgcn_pred_{args.data}_{target_option}.csv"
for frame, destination in ((true_data_s, true_path), (predict_data_s, pred_path)):
    frame.to_csv(destination, index=False)

print(f"\n✅ Done. Results saved to:\n  - {true_path}\n  - {pred_path}")


✅ Done. Results saved to:
  - mofgcn_true_gdsc1_drug.csv
  - mofgcn_pred_gdsc1_drug.csv
