In [1]:
import argparse
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from dataclasses import dataclass
from joblib import Parallel, delayed
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from model import GModel
from myutils import roc_auc, translate_result, filter_target
from load_data import load_data
from optimizer import Optimizer
from sampler import NewSampler

In [3]:
@dataclass
class Args:
    """Run configuration for the per-target MOFGCN experiment.

    Fields, in constructor order:
        device: torch device string; "cuda" when a GPU is visible, else "cpu".
        data: dataset name (also embedded in the output CSV file names).
        n_jobs: number of joblib workers used for the per-target runs.
        lr: learning rate handed to the Optimizer.
        epochs: number of training epochs per model.
    """

    # The default is resolved once, when this class body is executed.
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")
    data: str = "gdsc2"
    n_jobs: int = 25  # adjust to the number of available CPU cores
    lr: float = 0.0005
    epochs: int = 1_000

In [4]:
# Which axis of the response matrix is held out per run: "cell" or "drug".
target_option = "drug"
# Default run configuration.
args = Args()

In [5]:
# Load the dataset selected by args.data. For gdsc2 the log below shows res and
# null_mask as (654, 153), exprs/mut/cna with 654 rows and drug_feature with
# 153 rows: res rows align with the omics rows and res columns with the drugs.
(
    res,
    drug_feature,
    exprs,
    mut,
    cna,
    null_mask,
) = load_data(args)

load gdsc2
response matrix (res) shape: (153, 654)
exprs shape: (1084, 19562)
mut shape: (1084, 18099)
cna shape: (1084, 24502)
654
exprs shape: (654, 19562)
mut shape: (654, 18099)
cna shape: (654, 24502)
drug_feature shape: (153, 920)
response matrix (res) shape: (654, 153)
null_mask shape: (654, 153)


In [6]:
def run_single_model(exprs, cna, mut, drug_feature, res_mat, null_mask, target_dim, target_index, args, seed):
    """Train one MOFGCN model with a single target held out and return its test outputs.

    ``NewSampler`` splits ``res_mat`` into train/test data around the
    ``target_index``-th entry along ``target_dim``. ``GModel`` is built on the
    training adjacency and trained by ``Optimizer`` for ``args.epochs`` epochs
    with learning rate ``args.lr`` on ``args.device``.

    Args:
        exprs, cna, mut: omics feature matrices, passed to ``GModel`` as
            ``gene`` / ``cna`` / ``mutation``.
        drug_feature: drug feature matrix, passed to ``GModel`` as ``feature_drug``.
        res_mat: response matrix (the notebook passes ``res.values``).
        null_mask: mask forwarded unchanged to ``NewSampler``.
        target_dim: axis of ``res_mat`` the target is taken from (the notebook
            uses 0 for "cell" and 1 for "drug").
        target_index: index of the held-out target along ``target_dim``.
        args: run configuration; ``device``, ``lr`` and ``epochs`` are read.
        seed: seed for Python's ``random``, NumPy's global RNG and PyTorch.
            It is applied before sampling and model construction so that each
            target's run is repeatable.

    Returns:
        ``(true_data, predict_data)``: the second and third values returned
        by calling the ``Optimizer``.
    """
    # `seed` used to be accepted but ignored, so repeated runs of the same target
    # could differ. Seed every global RNG before anything random happens.
    # NOTE(review): assumes NewSampler / GModel / Optimizer draw from these global
    # RNGs rather than private generators; also, seeding alone does not make
    # CUDA kernels deterministic. Confirm both if bit-exact reruns are required.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

    sampler = NewSampler(res_mat, null_mask, target_dim, target_index)

    # Model hyperparameters are fixed here rather than taken from `args`.
    model = GModel(
        adj_mat=sampler.train_data.float(),
        gene=exprs,
        cna=cna,
        mutation=mut,
        sigma=2,
        k=11,
        iterates=3,
        feature_drug=drug_feature,
        n_hid1=192,
        n_hid2=36,
        alpha=5.74,
        device=args.device,
    )
    opt = Optimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        evaluate_fun=roc_auc,
        lr=args.lr,
        epochs=args.epochs,
        device=args.device,
    ).to(args.device)
    # The first value returned by the optimizer is not needed here.
    _, true_data, predict_data = opt()
    return true_data, predict_data

In [7]:
# Target dimension and response-matrix statistics.
# A "cell" target is a row of res (axis 0) and a "drug" target a column (axis 1).
if target_option not in ("cell", "drug"):
    # Previously any other value (e.g. a typo) silently fell through to "drug".
    raise ValueError(f'target_option must be "cell" or "drug", got {target_option!r}')
target_dim = 0 if target_option == "cell" else 1
samples = res.shape[target_dim]  # number of candidate targets along that axis
# Per-row (cell) and per-column (drug) sums of the response matrix; not used by
# the cells shown in this notebook section.
cell_sum = np.sum(res.values, axis=1)
drug_sum = np.sum(res.values, axis=0)

In [8]:
# Wrapper used for parallel execution.
def process_target(seed, target_index):
    """Run ``run_single_model`` for one target using this notebook's globals.

    The data (exprs, cna, mut, drug_feature, res, null_mask), ``target_dim``
    and ``args`` are read from module-level names; only ``seed`` and
    ``target_index`` vary between calls.

    Returns:
        ``(true_data, predict_data)`` on success, or ``None`` if any exception
        was raised. The failure is printed, not re-raised.
    """
    outcome = None
    try:
        outcome = run_single_model(
            exprs=exprs,
            cna=cna,
            mut=mut,
            drug_feature=drug_feature,
            res_mat=res.values,
            null_mask=null_mask,
            target_dim=target_dim,
            target_index=target_index,
            args=args,
            seed=seed,
        )
    except Exception as err:
        # Best-effort: one failing target must not abort the whole sweep.
        # NOTE(review): inside joblib worker processes this print may go to the
        # Jupyter server's terminal rather than the notebook output.
        print(f"❌ Failed at target {target_index}: {err}")
    return outcome

In [9]:
# Result containers.
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()
skipped_targets = []  # (target_index, reason, pos, neg, total) per rejected target
passed_targets = []   # indices of targets that will be trained and evaluated

# Skip check: keep only targets whose label vector passes filter_target.
for target_index in range(samples):
    # A "cell" target is a row of res, a "drug" target a column.
    if target_dim == 0:
        label_vec = res.iloc[target_index]
    else:
        label_vec = res.iloc[:, target_index]
    passed, reason, pos, neg, total = filter_target(label_vec)
    if not passed:
        skipped_targets.append((target_index, reason, pos, neg, total))
        continue
    passed_targets.append(target_index)

# Report which targets were skipped and why.
print("\n🚫 Skipped Targets: " + str(len(skipped_targets)))
for idx, reason, pos, neg, total in skipped_targets:
    print(
        f"Target {idx}: skipped because {reason} "
        f"(total={total}, pos={pos}, neg={neg})"
    )

# Parallel execution: one job per passed target; the worker count is args.n_jobs.
# Each job's seed is the target's position in passed_targets.
# NOTE(review): tqdm wraps the *input* iterable, so the bar advances as joblib
# dispatches tasks, not as they finish; it can reach 100% while the last models
# are still training.
# NOTE(review): with device "cuda", all workers presumably share one GPU;
# confirm n_jobs is appropriate for GPU memory.
results = Parallel(n_jobs=args.n_jobs)(
    delayed(process_target)(*seed_and_target)
    for seed_and_target in enumerate(
        tqdm(passed_targets, desc=f"MOFGCN ({args.data} - {target_option})")
    )
)

# Merge the per-target results, excluding failed runs (None).
# joblib returns results in the same order as passed_targets, so the two can be
# zipped to recover which targets failed.
# Frames are collected in lists and concatenated once: concatenating inside the
# loop re-copied the accumulated frame on every iteration (quadratic) and
# concatenated onto an empty DataFrame, which recent pandas deprecates.
# NOTE(review): assumes translate_result returns a DataFrame; confirm.
true_frames = []
pred_frames = []
failed_targets = []
for failed_candidate, r in zip(passed_targets, results):
    if r is None:
        failed_targets.append(failed_candidate)
        continue
    true_data, pred_data = r
    true_frames.append(translate_result(true_data))
    pred_frames.append(translate_result(pred_data))

if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
    predict_data_s = pd.concat(pred_frames, ignore_index=True)
else:
    # Every target failed (or none passed the filter): keep empty frames rather
    # than letting pd.concat([]) raise.
    true_data_s = pd.DataFrame()
    predict_data_s = pd.DataFrame()

# Failures were previously dropped silently; surface them here, because worker
# prints may not reach the notebook.
if failed_targets:
    print(f"\n⚠️ Failed Targets: {len(failed_targets)} -> {failed_targets}")


🚫 Skipped Targets: 11
Target 29: skipped because few_total_samples (total=8, pos=5, neg=3)
Target 30: skipped because few_total_samples (total=7, pos=3, neg=4)
Target 35: skipped because few_total_samples (total=8, pos=6, neg=2)
Target 68: skipped because few_total_samples (total=5, pos=3, neg=2)
Target 70: skipped because few_total_samples (total=6, pos=3, neg=3)
Target 97: skipped because few_total_samples (total=6, pos=2, neg=4)
Target 105: skipped because few_total_samples (total=5, pos=1, neg=4)
Target 115: skipped because few_total_samples (total=5, pos=4, neg=1)
Target 117: skipped because few_total_samples (total=8, pos=3, neg=5)
Target 134: skipped because few_total_samples (total=1, pos=1, neg=0)
Target 152: skipped because few_total_samples (total=6, pos=3, neg=3)


  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)


  gene = torch.from_numpy(gene).to(device)
MOFGCN (gdsc2 - drug):  35%|███▌      | 50/142 [00:16<00:35,  2.59it/s]from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
  gene = torch.from_numpy(gene).to(device)
MOFGCN (gdsc2 - drug): 100%|██████████| 142/142 [00:43<00:00,  3.27it/s]


In [10]:
# Save the merged ground truth and predictions as CSV files (no index column).
true_path = f"mofgcn_true_{args.data}_{target_option}.csv"
pred_path = f"mofgcn_pred_{args.data}_{target_option}.csv"
true_data_s.to_csv(path_or_buf=true_path, index=False)
predict_data_s.to_csv(path_or_buf=pred_path, index=False)

print(
    "\n✅ Done. Results saved to:"
    f"\n  - {true_path}"
    f"\n  - {pred_path}"
)


✅ Done. Results saved to:
  - mofgcn_true_gdsc2_drug.csv
  - mofgcn_pred_gdsc2_drug.csv
