In [1]:
import os, json, time, csv, numpy as np, pandas as pd
from collections import Counter, defaultdict
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter

from datasets.ADNI import ADNI, ADNI_transform
from monai.data import Dataset

from sklearn.model_selection import StratifiedKFold, train_test_split
from utils.metrics import calculate_metrics

# -------------------- 配置 --------------------
def load_cfg(path):
    """Load a JSON configuration file and return the parsed object.

    The file is decoded explicitly as UTF-8 (``utf-8-sig`` also strips an
    optional byte-order mark, which editors on Windows often add). The
    platform default encoding is not guaranteed to be UTF-8, and a BOM would
    otherwise make ``json.load`` fail.

    Args:
        path: Path to the JSON config file.

    Returns:
        The parsed JSON value (a dict for the configs used in this notebook).
    """
    with open(path, encoding="utf-8-sig") as f:
        return json.load(f)

class Cfg:
    """Attribute-style view over a configuration dict.

    Every key of ``d`` becomes an attribute. ``device`` is always a
    ``torch.device``. A ``"device"`` entry in the config (e.g. the string
    ``"cuda:1"``) is converted to one. A CUDA request on a machine without
    CUDA falls back to CPU instead of failing at the first ``.to()`` call.
    """

    def __init__(self, d):
        # Set the default first so `device` stays the first attribute in
        # vars(self); the config printout iterates in insertion order.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for k, v in d.items():
            setattr(self, k, v)
        if "device" in d:
            # The loop above stored the raw config value (often a str); normalise it.
            self.device = self._resolve_device(d["device"])

    @staticmethod
    def _resolve_device(requested):
        """Convert a config device entry (str or torch.device) to a usable torch.device."""
        dev = torch.device(requested)
        if dev.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return dev
# ----------------- 加载配置 -------------------
config_path = "config/config.json"
cfg = Cfg(load_cfg(config_path))

# Print every config attribute as "name : value", names padded to 15 chars.
print("\n".join(f"{key:15s}: {value}" for key, value in vars(cfg).items()))

writer = SummaryWriter(cfg.checkpoint_dir)

  from .autonotebook import tqdm as notebook_tqdm


device         : cuda:1
label_file     : adni_dataset/ADNI_902.csv
mri_dir        : adni_dataset/MRI
pet_dir        : adni_dataset/PET
table_dir      : adni_dataset/ADNI_Tabel.csv
tabular_emb    : models/tabular_emb.csv
AAL_dir        : adni_dataset/AAL_space-MNI152NLin6_res-2x2x2/AAL.nii
table_startcol : 4
task           : SMCIPMCI
augment        : False
split_ratio_test: 0.2
seed           : 42
num_epochs     : 100
batch_size     : 4
lr             : 0.0001
weight_decay   : 1e-05
fp16           : True
checkpoint_dir : checkpoints
nb_class       : 2
n_splits       : 5
dropout_rate   : 0.5
in_channels    : 2
seg_task       : False
img_dim        : 512
tab_dim        : 192


In [2]:
full_dataset = ADNI(cfg.label_file, cfg.mri_dir, cfg.pet_dir, cfg.task, cfg.augment)
full_ds      = full_dataset.data_dict         # list[dict], one entry per sample
labels       = [d["label"] for d in full_ds]

# -------------------- Split --------------------
# Outer loop: stratified K-fold; each fold's held-out part is its test set.
# Inner split: a stratified VAL_RATIO share of the remaining samples becomes
# the validation set. The default 0.125 gives an 87.5/12.5 train/val split.
# Override it with `val_ratio` in the config if needed.
VAL_RATIO = getattr(cfg, "val_ratio", 0.125)

fold_indices = defaultdict(dict)
outer_cv = StratifiedKFold(
    n_splits=cfg.n_splits, shuffle=True, random_state=cfg.seed
)

for fold, (train_val_idx, test_idx) in enumerate(outer_cv.split(full_ds, labels), start=1):
    train_val_labels = [labels[i] for i in train_val_idx]
    # Splitting the absolute indices directly is equivalent to splitting
    # positions 0..n-1 and mapping back: the selection depends only on the
    # length, the stratification labels and the random_state.
    train_idx, val_idx = train_test_split(
        train_val_idx, test_size=VAL_RATIO, stratify=train_val_labels, random_state=cfg.seed
    )

    # Sanity check: the three splits of a fold never share a sample.
    tr_set, va_set, te_set = set(train_idx), set(val_idx), set(test_idx)
    assert not (tr_set & va_set or tr_set & te_set or va_set & te_set), \
        f"Fold {fold}: overlapping train/val/test indices"

    fold_indices[fold]["train_idx"] = train_idx.tolist()
    fold_indices[fold]["val_idx"]   = val_idx.tolist()
    fold_indices[fold]["test_idx"]  = test_idx.tolist()

    print(f"Fold {fold}: train {len(train_idx)}, val {len(val_idx)}, test {len(test_idx)}")

# Across folds, every sample must be a test sample exactly once.
_all_test = sorted(i for split in fold_indices.values() for i in split["test_idx"])
assert _all_test == list(range(len(full_ds))), "test folds do not partition the dataset"

# -------------------- Save JSON --------------------

os.makedirs(cfg.checkpoint_dir, exist_ok=True)
json_path = os.path.join(cfg.checkpoint_dir, "fold_indices.json")
with open(json_path, "w") as f:
    json.dump({str(k): v for k, v in fold_indices.items()}, f, indent=2)
print(f"Fold indices saved to {json_path}")


[ADNI Dataset: SMCIPMCI] 样本分布：
  SMCI (0): 321
  PMCI (1): 158

Fold 1: train 335, val 48, test 96
Fold 2: train 335, val 48, test 96
Fold 3: train 335, val 48, test 96
Fold 4: train 335, val 48, test 96
Fold 5: train 336, val 48, test 95
Fold indices saved to checkpoints\fold_indices.json


In [3]:
# ------------------ 加载数据 ------------------
import json, torch, os
from torch.utils.data import DataLoader
from monai.data import Dataset

def load_fold_indices(json_path, fold):
    """Read the saved CV split file and return the index lists of one fold.

    Args:
        json_path: Path to the ``fold_indices.json`` written by the split cell.
        fold: Fold number; it is looked up by its string form.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of index lists.

    Raises:
        KeyError: if the fold (or one of the split keys) is not in the file.
    """
    with open(json_path) as fh:
        entry = json.load(fh)[str(fold)]
    return tuple(entry[key] for key in ("train_idx", "val_idx", "test_idx"))

def get_dataloaders(cfg, fold, batch_size=4, num_workers=4, full_ds=None, augment=False):
    """Build the train / val / test DataLoaders of one CV fold.

    Args:
        cfg: Config providing ``label_file``, ``mri_dir``, ``pet_dir``,
            ``task`` and ``checkpoint_dir`` (where ``fold_indices.json`` lives).
        fold: Fold number, as stored in ``fold_indices.json``.
        batch_size: Batch size of all three loaders.
        num_workers: Worker processes per loader.
        full_ds: Optional pre-built ``ADNI(...).data_dict`` list. Passing it
            avoids re-scanning the dataset on every call. It must list the
            samples in the same order as when ``fold_indices.json`` was
            written, because the saved indices are positions in that list.
        augment: Forwarded to ``ADNI_transform``. The default False keeps the
            previous hard-coded behaviour. The transform pair is (train, val);
            only the train transform is meant to augment.

    Returns:
        ``(loader_tr, loader_val, loader_te)``; only the train loader shuffles.
    """
    # ---- 1. Full sample list ----
    if full_ds is None:
        full_ds = ADNI(cfg.label_file, cfg.mri_dir, cfg.pet_dir,
                       cfg.task, augment=False).data_dict

    # ---- 2. Fold indices ----
    train_idx, val_idx, test_idx = load_fold_indices(
        os.path.join(cfg.checkpoint_dir, "fold_indices.json"), fold)

    # ---- 3. Transforms: test uses the same (non-augmenting) one as val ----
    tf_tr, tf_val = ADNI_transform(augment=augment)

    # ---- 4/5. MONAI Dataset + DataLoader ----
    def _make_loader(indices, transform, shuffle):
        """DataLoader over the samples of `full_ds` at `indices`."""
        ds = Dataset([full_ds[i] for i in indices], transform=transform)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    loader_tr  = _make_loader(train_idx, tf_tr,  shuffle=True)
    loader_val = _make_loader(val_idx,   tf_val, shuffle=False)
    loader_te  = _make_loader(test_idx,  tf_val, shuffle=False)
    return loader_tr, loader_val, loader_te



In [4]:
import torch
from models.unet3d import UNet3D_Feature           # 已改成输出 64-ch 特征图的版本

# ---------- 0. Settings ----------
fold       = 1                 # any fold works for this shape check
BATCH_SIZE = 2                 # volumes per batch
# Use the configured device, as every training cell does, instead of a
# hard-coded 'cuda:0'. This check then runs on the same GPU as training.
device     = cfg.device

# ---------- 1. DataLoader ----------
loader_tr, loader_val, loader_te = get_dataloaders(
    cfg, fold, batch_size=BATCH_SIZE, num_workers=4)

# One batch is enough: take the first training batch.
batch = next(iter(loader_tr))                       # dict of batched fields
vol   = batch["MRI"].to(device)                    # e.g. (B,1,96,112,96), see output
print("输入体积张量 :", vol.shape)

# ---------- 2. Model ----------
model = UNet3D_Feature(in_channels=1).to(device)
model.eval()

with torch.no_grad():
    feat_map = model(vol)                          # e.g. (B,64,96,112,96), see output
    feat_vec = feat_map.flatten(1)                 # (B, 64*96*112*96)

print("输出特征图 :", feat_map.shape)
print("展平向量 :",   feat_vec.shape)

# Global average pooling gives one 64-d vector per volume:
gap_vec = feat_map.mean(dim=(2,3,4))               # (B,64)
print("全局均值池化后 :", gap_vec.shape)

# Optional: subject IDs of this batch
print("Subject IDs:", batch["Subject"])



[ADNI Dataset: SMCIPMCI] 样本分布：
  SMCI (0): 321
  PMCI (1): 158

输入体积张量 : torch.Size([2, 1, 96, 112, 96])
输出特征图 : torch.Size([2, 64, 96, 112, 96])
展平向量 : torch.Size([2, 66060288])
全局均值池化后 : torch.Size([2, 64])
Subject IDs: ['013_S_4791', '016_S_4601']


In [None]:
import os, csv, time, torch, json, numpy as np
from torch import nn
from sklearn.metrics import roc_auc_score, confusion_matrix

# ---------------- 0. 工具：指标计算 ----------------
def calc_stats(y_true, y_pred, y_prob):
    """Binary-classification metrics for one pass over a data split.

    Args:
        y_true: Ground-truth labels (0/1).
        y_pred: Predicted labels (0/1).
        y_prob: Predicted probability of class 1 (used for AUC).

    Returns:
        dict with ``ACC``, ``SEN`` (recall of class 1), ``SPE`` (recall of
        class 0) and ``AUC``. A ratio whose denominator is zero is reported as
        0.0. AUC is 0.5 when ``y_true`` holds a single class, where it is undefined.
    """
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def _ratio(num, den):
        # Exact ratio. An empty denominator yields 0.0 instead of NaN, and
        # e.g. 1/1 stays 1.0 rather than an epsilon-biased 0.999999.
        return float(num) / den if den else 0.0

    acc = _ratio(tp + tn, tp + tn + fp + fn)
    sen = _ratio(tp, tp + fn)                       # Recall / Sensitivity
    spe = _ratio(tn, tn + fp)
    auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) == 2 else 0.5
    return {'ACC': acc, 'SEN': sen, 'SPE': spe, 'AUC': auc}

# ---------------- 1. 模型组件 ----------------
from models.unet3d   import UNet3D_Feature
from models.ROI_pol  import ROIPooling3D, ROIClassifier

unet  = UNet3D_Feature(in_channels=1).to(cfg.device).eval()   # frozen: never given to an optimizer
pool  = ROIPooling3D(cfg.AAL_dir).to(cfg.device).eval()


def _roi_run_epoch(loader, clf, crit, opt=None):
    """One pass of the ROI classifier over `loader`.

    Image features come from the frozen `unet` + `pool`, computed without
    gradients. If `opt` is given, the classifier is trained during this pass.
    Otherwise it is evaluated with gradients disabled.

    Returns:
        dict with the sample-weighted mean ``Loss`` plus the ``calc_stats`` keys.
    """
    training = opt is not None
    clf.train(training)
    loss_sum, yt, yp, ys = 0.0, [], [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            mri = batch["MRI"].to(cfg.device)
            y   = batch["label"].to(cfg.device)

            with torch.no_grad():
                feat = pool(unet(mri))           # (B,94,64) per the original note
            logit = clf(feat)
            loss  = crit(logit, y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            loss_sum += loss.item() * y.size(0)
            yt.extend(y.cpu().numpy())
            yp.extend(logit.argmax(1).cpu().numpy())
            ys.extend(torch.softmax(logit, 1)[:, 1].detach().cpu().numpy())
    return {"Loss": loss_sum / len(loader.dataset), **calc_stats(yt, yp, ys)}


# ---------------- 2. K-fold cross-validation ----------------
os.makedirs(cfg.checkpoint_dir, exist_ok=True)

for fold in range(1, cfg.n_splits+1):
    print(f"\n=== Fold {fold}/{cfg.n_splits} ===")
    # Same values as the former hard-coded batch_size=4 / lr=1e-4, now from cfg.
    loader_tr, loader_val, _ = get_dataloaders(cfg, fold, batch_size=cfg.batch_size)

    # Seed per fold so classifier init and shuffling are reproducible.
    torch.manual_seed(cfg.seed + fold)
    clf  = ROIClassifier().to(cfg.device)
    opt  = torch.optim.AdamW(clf.parameters(), lr=cfg.lr)
    crit = nn.CrossEntropyLoss()

    # "roi_" prefix: the hypergraph training cell also logs per-fold metrics
    # into checkpoint_dir, and a shared file name would overwrite this log.
    csv_path = os.path.join(cfg.checkpoint_dir, f"roi_fold{fold}_metrics.csv")
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerow(
            ["epoch","set","Loss","ACC","SEN","SPE","AUC"])

    best_auc = -np.inf
    for epoch in range(1, cfg.num_epochs+1):
        t0 = time.time()
        stats = {
            "train": _roi_run_epoch(loader_tr, clf, crit, opt),
            "val":   _roi_run_epoch(loader_val, clf, crit),
        }

        # ------- Print / log -------
        msg=(f"Fold{fold} Ep{epoch:03d} | "
             f"TrL={stats['train']['Loss']:.4f} "
             f"TrACC={stats['train']['ACC']:.3f} "
             f"TrSEN={stats['train']['SEN']:.3f} "
             f"TrSPE={stats['train']['SPE']:.3f} "
             f"TrAUC={stats['train']['AUC']:.3f} || "
             f"VaL={stats['val']['Loss']:.4f} "
             f"VaACC={stats['val']['ACC']:.3f} "
             f"VaSEN={stats['val']['SEN']:.3f} "
             f"VaSPE={stats['val']['SPE']:.3f} "
             f"VaAUC={stats['val']['AUC']:.3f} | "
             f"time {time.time()-t0:.1f}s")
        print(msg)

        with open(csv_path, "a", newline="") as f:
            w = csv.writer(f)
            for split in ("train", "val"):
                w.writerow([epoch, split] +
                           [f"{stats[split][k]:.4f}" for k in ("Loss", "ACC", "SEN", "SPE", "AUC")])

        # ------- Keep the best checkpoint (by validation AUC) -------
        if stats["val"]["AUC"] > best_auc:
            best_auc = stats["val"]["AUC"]
            torch.save(clf.state_dict(),
                os.path.join(cfg.checkpoint_dir, f"roi_clf_fold{fold}.pth"))
            print(f"✅ new best AUC={best_auc:.3f}")

    print(f"=== Fold {fold} done, best AUC={best_auc:.3f} ===")



=== Fold 1/5 ===

[ADNI Dataset: SMCIPMCI] 样本分布：
  SMCI (0): 321
  PMCI (1): 158

Fold1 Ep001 | TrL=0.6540 TrACC=0.654 TrSEN=0.009 TrSPE=0.973 TrAUC=0.498 || VaL=0.6370 VaACC=0.667 VaSEN=0.000 VaSPE=1.000 VaAUC=0.535 | time 173.1s
✅ new best AUC=0.535
Fold1 Ep002 | TrL=0.6372 TrACC=0.669 TrSEN=0.000 TrSPE=1.000 TrAUC=0.470 || VaL=0.6363 VaACC=0.667 VaSEN=0.000 VaSPE=1.000 VaAUC=0.537 | time 174.6s
✅ new best AUC=0.537
Fold1 Ep003 | TrL=0.6383 TrACC=0.669 TrSEN=0.000 TrSPE=1.000 TrAUC=0.462 || VaL=0.6373 VaACC=0.667 VaSEN=0.000 VaSPE=1.000 VaAUC=0.537 | time 180.9s
Fold1 Ep004 | TrL=0.6378 TrACC=0.669 TrSEN=0.000 TrSPE=1.000 TrAUC=0.458 || VaL=0.6365 VaACC=0.667 VaSEN=0.000 VaSPE=1.000 VaAUC=0.539 | time 174.6s
✅ new best AUC=0.539
Fold1 Ep005 | TrL=0.6370 TrACC=0.669 TrSEN=0.000 TrSPE=1.000 TrAUC=0.474 || VaL=0.6371 VaACC=0.667 VaSEN=0.000 VaSPE=1.000 VaAUC=0.541 | time 174.3s
✅ new best AUC=0.541
Fold1 Ep006 | TrL=0.6381 TrACC=0.669 TrSEN=0.000 TrSPE=1.000 TrAUC=0.480 || VaL=0.6363 V

In [6]:
# ----------------- 超图 ---------------------
# main_train.py  --------------------------------------------------------------
import os, csv, time, json, numpy as np
import torch, torch.nn as nn
from sklearn.metrics import roc_auc_score, confusion_matrix

# ---------------- 1. 指标函数 ----------------
def calc_stats(y_true, y_pred, y_prob):
    """Binary-classification metrics for one pass over a data split.

    Args:
        y_true: Ground-truth labels (0/1).
        y_pred: Predicted labels (0/1).
        y_prob: Predicted probability of class 1 (used for AUC).

    Returns:
        dict with ``ACC``, ``SEN`` (recall of class 1), ``SPE`` (recall of
        class 0) and ``AUC``. A ratio whose denominator is zero is reported as
        0.0. AUC is 0.5 when ``y_true`` holds a single class, where it is undefined.
    """
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def _ratio(num, den):
        # Exact ratio. An empty denominator yields 0.0 instead of NaN, and
        # e.g. 1/1 stays 1.0 rather than an epsilon-biased 0.999999.
        return float(num) / den if den else 0.0

    acc = _ratio(tp + tn, tp + tn + fp + fn)
    sen = _ratio(tp, tp + fn)                       # Recall / Sensitivity
    spe = _ratio(tn, tn + fp)
    auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) == 2 else 0.5
    return {'ACC': acc, 'SEN': sen, 'SPE': spe, 'AUC': auc}

# ---------------- 2. 导入模型组件 ----------------
from models.unet3d        import UNet3D_Feature
from models.ROI_pol       import ROIPooling3D
from models.HGNN    import DualModalHyperGraphWithAttn
from models.fc_classifier import FCClassifier

device = cfg.device
# a) Image encoders (frozen: never given to an optimizer)
unet_mri = UNet3D_Feature(in_channels=1).to(device).eval()
unet_pet = UNet3D_Feature(in_channels=1).to(device).eval()  # weights could be shared

# b) ROI pooling
pool = ROIPooling3D(cfg.AAL_dir).to(device).eval()

# c) Hypergraph fusion module (frozen)
hg = DualModalHyperGraphWithAttn(in_dim=64, hidden_dim=128,
                                 num_layers=2, dk=64).to(device).eval()

crit = nn.CrossEntropyLoss()


def _build_head():
    """Fresh classification head and its optimizer (the only trained parameters)."""
    head = FCClassifier(in_dim=128, num_classes=2,
                        hidden_dims=[256,128,64], p_drop=0.3).to(device)
    return head, torch.optim.AdamW(head.parameters(), lr=1e-4, weight_decay=1e-4)


def _hg_fused_features(batch):
    """Frozen MRI+PET → ROI → hypergraph features for one batch, plus labels."""
    mri = batch["MRI"].to(device)    # [B,1,D,H,W]
    pet = batch["PET"].to(device)
    with torch.no_grad():
        f_mri = pool(unet_mri(mri))  # [B,94,64] per the original note
        f_pet = pool(unet_pet(pet))
        g = hg(f_mri, f_pet)         # dict
        fused = torch.cat([g["global_mod1"], g["global_mod2"]], 1)  # [B,128]
    return fused, batch["label"].to(device)


def _hg_run_epoch(loader, clf, opt=None):
    """One pass of `clf` over `loader`; trains when `opt` is given, else evaluates.

    Returns:
        dict with the sample-weighted mean ``Loss`` plus the ``calc_stats`` keys.
    """
    training = opt is not None
    clf.train(training)
    loss_sum, yt, yp, ys = 0.0, [], [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            fused, y = _hg_fused_features(batch)
            logit = clf(fused)
            loss  = crit(logit, y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            loss_sum += loss.item() * y.size(0)
            yt.extend(y.cpu().numpy())
            yp.extend(logit.argmax(1).cpu().numpy())
            ys.extend(torch.softmax(logit, 1)[:, 1].detach().cpu().numpy())
    return {"Loss": loss_sum / len(loader.dataset), **calc_stats(yt, yp, ys)}


# Create the output directory
os.makedirs(cfg.checkpoint_dir, exist_ok=True)

# ---------------- 4. K-fold cross-validation ----------------
for fold in range(1, cfg.n_splits+1):
    print(f"\n=== Fold {fold}/{cfg.n_splits} ===")
    loader_tr, loader_val, _ = get_dataloaders(cfg, fold, batch_size=cfg.batch_size)

    # A NEW head and optimizer per fold. Re-using them across folds would
    # carry weights (and Adam state) trained on other folds' data into this
    # fold, including samples that are in this fold's test set (CV leakage).
    torch.manual_seed(cfg.seed + fold)
    clf, opt = _build_head()

    best_auc = -np.inf
    # "hg_" prefix: the ROI cell also logs per-fold metrics into checkpoint_dir.
    csv_path = os.path.join(cfg.checkpoint_dir, f"hg_fold{fold}_metrics.csv")
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerow(
            ["epoch","set","Loss","ACC","SEN","SPE","AUC"])

    for epoch in range(1, cfg.num_epochs+1):
        t0 = time.time()
        stats = {
            "train": _hg_run_epoch(loader_tr, clf, opt),
            "val":   _hg_run_epoch(loader_val, clf),
        }

        # ======== Log ========
        msg = (f"Fold{fold} Ep{epoch:03d} | "
               f"TrL={stats['train']['Loss']:.4f} "
               f"TrAUC={stats['train']['AUC']:.3f} || "
               f"VaL={stats['val']['Loss']:.4f} "
               f"VaAUC={stats['val']['AUC']:.3f} "
               f"time={time.time()-t0:.1f}s")
        print(msg)

        with open(csv_path, "a", newline="") as f:
            w = csv.writer(f)
            for split in ("train", "val"):
                w.writerow([epoch, split] +
                           [f"{stats[split][k]:.4f}" for k in ("Loss", "ACC", "SEN", "SPE", "AUC")])

        # ======== Keep the best checkpoint (by validation AUC) ========
        if stats["val"]["AUC"] > best_auc:
            best_auc = stats["val"]["AUC"]
            torch.save(clf.state_dict(),
                       os.path.join(cfg.checkpoint_dir,
                                    f"best_fold{fold}.pth"))
            print(f"✅ new best AUC = {best_auc:.3f}")

    print(f"=== Fold {fold} finished, best AUC = {best_auc:.3f} ===")



=== Fold 1/5 ===

[ADNI Dataset: SMCIPMCI] 样本分布：
  SMCI (0): 321
  PMCI (1): 158

Fold1 Ep001 | TrL=0.6713 TrAUC=0.527 || VaL=0.6456 VaAUC=0.516 time=324.1s
✅ new best AUC = 0.516
Fold1 Ep002 | TrL=0.6513 TrAUC=0.464 || VaL=0.6375 VaAUC=0.509 time=325.2s
Fold1 Ep003 | TrL=0.6392 TrAUC=0.509 || VaL=0.6369 VaAUC=0.606 time=325.7s
✅ new best AUC = 0.606
Fold1 Ep004 | TrL=0.6498 TrAUC=0.451 || VaL=0.6366 VaAUC=0.514 time=325.9s
Fold1 Ep005 | TrL=0.6424 TrAUC=0.484 || VaL=0.6365 VaAUC=0.516 time=326.0s
Fold1 Ep006 | TrL=0.6328 TrAUC=0.540 || VaL=0.6365 VaAUC=0.609 time=325.9s
✅ new best AUC = 0.609
Fold1 Ep007 | TrL=0.6468 TrAUC=0.470 || VaL=0.6365 VaAUC=0.484 time=326.5s
Fold1 Ep008 | TrL=0.6368 TrAUC=0.528 || VaL=0.6365 VaAUC=0.531 time=331.0s
Fold1 Ep009 | TrL=0.6317 TrAUC=0.545 || VaL=0.6368 VaAUC=0.484 time=330.5s
Fold1 Ep010 | TrL=0.6413 TrAUC=0.504 || VaL=0.6365 VaAUC=0.498 time=326.9s


KeyboardInterrupt: 

In [7]:
import os
import json
import pandas as pd

# Relies on `full_ds` (the ADNI(...).data_dict list from the split cell) and `cfg`.
json_path = os.path.join(cfg.checkpoint_dir, "fold_indices.json")
csv_path  = cfg.table_dir          # clinical table CSV with Subject_ID / Group columns

# Full clinical table
df = pd.read_csv(csv_path)

# Saved fold indices (string fold keys)
with open(json_path, 'r') as f:
    fold_indices = json.load(f)

# Subject IDs in full_ds order; the fold indices are positions in this list.
all_subjects   = [entry['Subject'] for entry in full_ds]
table_subjects = set(df['Subject_ID'])


def reorder(df_split):
    """Return `df_split` with Subject_ID first, Group second, other columns unchanged in order.

    Raises:
        KeyError: if the table has no ``Group`` column.
    """
    cols = df_split.columns.tolist()
    cols.remove('Subject_ID')
    if 'Group' not in cols:
        raise KeyError("'Group' 列未找到，请确认表格中有此列")
    cols.remove('Group')
    return df_split[['Subject_ID', 'Group'] + cols]


def _select_rows(subjects, split, fold):
    """Reordered table rows for `subjects`, warning about subjects absent from the table."""
    missing = [s for s in subjects if s not in table_subjects]
    if missing:
        # Absent subjects would otherwise vanish silently from this split.
        print(f"⚠ Fold {fold} {split}: {len(missing)} subject(s) not in table, e.g. {missing[:3]}")
    rows = df[df['Subject_ID'].isin(subjects)].reset_index(drop=True)
    return reorder(rows)


for fold_str, idxs in fold_indices.items():
    fold = int(fold_str)
    split_frames = {
        split: _select_rows([all_subjects[i] for i in idxs[f"{split}_idx"]], split, fold)
        for split in ("train", "val", "test")
    }

    # Save into a per-fold directory
    out_dir = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    os.makedirs(out_dir, exist_ok=True)
    out_paths = {split: os.path.join(out_dir, f"{split}.csv") for split in split_frames}
    for split, frame in split_frames.items():
        frame.to_csv(out_paths[split], index=False)

    print(f"Fold {fold} saved:")
    print(f"  train → {out_paths['train']}")
    print(f"  val   → {out_paths['val']}")
    print(f"  test  → {out_paths['test']}")


Fold 1 saved:
  train → checkpoints\fold1\train.csv
  val   → checkpoints\fold1\val.csv
  test  → checkpoints\fold1\test.csv
Fold 2 saved:
  train → checkpoints\fold2\train.csv
  val   → checkpoints\fold2\val.csv
  test  → checkpoints\fold2\test.csv
Fold 3 saved:
  train → checkpoints\fold3\train.csv
  val   → checkpoints\fold3\val.csv
  test  → checkpoints\fold3\test.csv
Fold 4 saved:
  train → checkpoints\fold4\train.csv
  val   → checkpoints\fold4\val.csv
  test  → checkpoints\fold4\test.csv
Fold 5 saved:
  train → checkpoints\fold5\train.csv
  val   → checkpoints\fold5\val.csv
  test  → checkpoints\fold5\test.csv


In [8]:
from models.tabular_encoder import tabular_encoder_fold

# Class names per task, listed in label-index order
# (the dataset printout shows SMCI=0, PMCI=1).
TASK_CLASSES = {"ADCN": ["CN", "AD"], "SMCIPMCI": ["SMCI", "PMCI"]}
if cfg.task not in TASK_CLASSES:
    # Fail before any fold is processed.
    raise ValueError(f"Unsupported task: {cfg.task}")
classes = TASK_CLASSES[cfg.task]   # loop-invariant, resolved once

for fold in range(1, cfg.n_splits+1):
    fold_dir = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    tabular_encoder_fold(
        fold_dir    = fold_dir,
        label_col   = "Group",
        classes     = classes,
        # NOTE(review): differs from cfg.table_startcol (4); presumably because
        # reorder() moved Subject_ID/Group to the front of the fold CSVs — confirm.
        start_col   = 3,
        device      = cfg.device,
        n_fold      = 5,
        dropna      = False
    )
    



▶ Using device: cuda:0
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((336, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((95, 194))


In [None]:
# -------------------------------------------------
# 1. Subject ➜ image-sample lookup, used to align the embedding CSVs
# -------------------------------------------------
from collections import Counter

_subject_counts = Counter(d["Subject"] for d in full_ds)
_dup_subjects = [s for s, n in _subject_counts.items() if n > 1]
if _dup_subjects:
    # The dict below keeps only the last sample per subject; make that visible.
    print(f"⚠ {len(_dup_subjects)} duplicated Subject ID(s) in full_ds "
          f"(last sample kept), e.g. {_dup_subjects[:3]}")
subject_map = {d["Subject"]: d for d in full_ds}

# -------------------------------------------------
# 2. 为五折构造 train/val/test DataLoader
# -------------------------------------------------
import pandas as pd
from monai.data import Dataset
from torch.utils.data import DataLoader

def load_emb_csv(path):
    """Parse a fold embedding CSV into ``{subject_id: (label, embedding)}``.

    Column layout: column 0 is the subject id (read as str), column 1 the
    integer label, and every remaining column one embedding value. The
    embedding is returned as a float32 numpy row (192 values for the files
    written above).
    """
    frame = pd.read_csv(path)
    subject_ids = frame.iloc[:, 0].astype(str).tolist()
    class_ids = frame.iloc[:, 1].astype(int).tolist()
    vectors = frame.iloc[:, 2:].astype("float32").to_numpy()

    mapping = {}
    for row, sid in enumerate(subject_ids):
        mapping[sid] = (class_ids[row], vectors[row])
    return mapping

tr_tf, vl_tf = ADNI_transform(augment=cfg.augment)
te_tf        = vl_tf

fold_loaders = []


def _attach_tabular(emb_map, split, fold):
    """Image sample dicts for every subject in `emb_map`, extended with its embedding.

    The CSV label overrides the dataset label. Disagreements are counted and
    reported rather than overwritten silently.

    Raises:
        KeyError: if a CSV subject has no image sample in `subject_map`.
    """
    samples, n_mismatch = [], 0
    for sid, (lbl, emb) in emb_map.items():
        if sid not in subject_map:
            raise KeyError(f"{sid} not found in ADNI dataset")
        s = subject_map[sid].copy()     # shallow copy: subject_map stays untouched
        if "label" in s and int(s["label"]) != int(lbl):
            n_mismatch += 1
        s["label"] = lbl                # CSV label is authoritative
        s["table"] = emb                # tabular embedding (float32 numpy row)
        samples.append(s)
    if n_mismatch:
        print(f"⚠ Fold {fold} {split}: {n_mismatch} CSV label(s) differ from the ADNI label")
    return samples


for fold in range(1, cfg.n_splits + 1):
    fold_dir  = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    paths     = {sp: os.path.join(fold_dir, f"{sp}_emb.csv")
                 for sp in ["train", "val", "test"]}

    # Parse the embedding CSVs and join them with the image samples
    split_ds = {sp: _attach_tabular(load_emb_csv(p), sp, fold) for sp, p in paths.items()}

    # DataLoaders
    dl_kw = dict(batch_size=cfg.batch_size, pin_memory=True)
    fold_loaders.append({
        "fold": fold,
        "train_loader": DataLoader(Dataset(split_ds["train"], tr_tf),
                                   shuffle=True,  num_workers=4, **dl_kw),
        "val_loader":   DataLoader(Dataset(split_ds["val"],   vl_tf),
                                   shuffle=False, num_workers=2, **dl_kw),
        "test_loader":  DataLoader(Dataset(split_ds["test"],  te_tf),
                                   shuffle=False, num_workers=2, **dl_kw),
    })

    print(f"Fold {fold}  ➜  train {len(split_ds['train'])} | "
          f"val {len(split_ds['val'])} | test {len(split_ds['test'])}")

# fold_loaders has the same structure as the image-only version and can be
# used for training directly:
# for fold_dict in fold_loaders:
#     tr_loader = fold_dict["train_loader"]
#     ...


In [None]:
# ----------------- 创建模型 -------------------
from sklearn.metrics import roc_curve
from models.mmad_encoder import ImageEncoder, ImageEncoder_CEN  # 根据你的命名调整
from models.mmad_encoder import MultiModalClassifier
def generate_image_model(cfg):
    # 使用 CEN 版本的编码器 + 分类头
    model = ImageEncoder_CEN(
        in_ch_modality   = 1,
        level_channels   = [64, 128, 256],
        bottleneck_ch    = 512,
        share_layers     = 2,
        cen_ratios       = (0.20, 0.10),
    ).to(cfg.device)

    # 参数统计
    total_params     = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    bytes_per_param  = 2 if getattr(cfg, 'fp16', False) else 4

    print("-------------------- model --------------------")
    print(f"Total params(M)    : {total_params:,}")
    print(f"Trainable params(M): {trainable_params:,}")
    print(f"Approx. size       : {total_params * bytes_per_param / 1024**2:.2f} MB")
    print("Model type:", type(model).__name__)

    return model

def generate_mm_classifier(cfg):
    model = MultiModalClassifier(
        img_dim=1024,
        tab_dim=192,
        num_classes=2
    ).to(cfg.device)
    # 参数统计
    total_params     = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    bytes_per_param  = 2 if getattr(cfg, 'fp16', False) else 4

    print("-------------------- model --------------------")
    print(f"Total params(M)    : {total_params:,}")
    print(f"Trainable params(M): {trainable_params:,}")
    print(f"Approx. size       : {total_params * bytes_per_param / 1024**2:.2f} MB")
    print("Model type:", type(model).__name__)

    return model

# Build each component once (prints its parameter stats); `model` ends up
# bound to the multimodal classifier, which is then shown.
for _builder in (generate_image_model, generate_mm_classifier):
    model = _builder(cfg)
print(model)

# -------------------- 测试 -----------------------------
from matplotlib import pyplot as plt

def load_test_data(cfg, fold):
    """Return the raw sample dicts of the test split of `fold`.

    Rebuilds the ADNI sample list and picks the positions stored under
    ``test_idx`` for this fold in ``<checkpoint_dir>/fold_indices.json``.
    """
    samples = ADNI(cfg.label_file, cfg.mri_dir, cfg.pet_dir,
                   cfg.task, cfg.augment).data_dict

    split_file = os.path.join(cfg.checkpoint_dir, "fold_indices.json")
    with open(split_file, "r") as fh:
        test_positions = json.load(fh)[str(fold)]["test_idx"]

    return [samples[pos] for pos in test_positions]

def test_models(checkpoint_dir: str, test_data: list, fold: int):
    """Evaluate the saved image encoder + multimodal classifier on one fold's test set.

    Loads ``<checkpoint_dir>/best_model_fold{fold}.pth`` (a dict holding the
    ``"img_encoder"`` and ``"clf_model"`` state dicts). It runs inference over
    `test_data` with the non-augmenting ADNI transform, scores it with
    ``calculate_metrics`` and saves an ROC plot to
    ``<checkpoint_dir>/roc_fold{fold}.png``. Reads settings from the global ``cfg``.

    Returns:
        metrics : dict        — output of ``calculate_metrics``
        y_prob  : list        — predicted probability of class 1 per sample
        y_true  : list        — ground-truth labels
        y_pred  : np.ndarray  — 0/1 predictions (threshold 0.5 on y_prob)

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # Accept cfg.device as either a str or a torch.device.
    device = torch.device(cfg.device)
    # Mixed precision only applies on CUDA; CPU runs in full precision.
    use_amp = bool(getattr(cfg, "fp16", False)) and device.type == "cuda"
    # The config key is `tab_dim` (there is no `table_dim`); 192 matches the
    # embedding width written by the tabular encoder and the classifier head.
    tab_dim = getattr(cfg, "tab_dim", 192)

    # ---------- DataLoader ----------
    _, test_tf = ADNI_transform(augment=False)
    from monai.data import Dataset
    from torch.utils.data import DataLoader
    ds = Dataset(data=test_data, transform=test_tf)
    loader = DataLoader(ds,
                        batch_size=cfg.batch_size,
                        shuffle=False,
                        num_workers=2,
                        pin_memory=True)

    # ---------- Model structure ----------
    img_encoder = generate_image_model(cfg)      # Dual-Stream UNet3D encoder
    clf_model   = generate_mm_classifier(cfg)    # image features + tabular head

    # ---------- Best weights ----------
    ckpt_path = os.path.join(checkpoint_dir, f"best_model_fold{fold}.pth")
    if not os.path.exists(ckpt_path):
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"❌ 找不到 {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)
    img_encoder.load_state_dict(ckpt["img_encoder"])
    clf_model.load_state_dict(ckpt["clf_model"])
    img_encoder.to(device).eval()
    clf_model.to(device).eval()
    print(f"✅ Loaded checkpoint from {ckpt_path}")

    # ---------- Inference ----------
    y_true, y_prob = [], []
    with torch.no_grad():
        for batch in loader:
            mri  = batch["MRI"].to(device)
            pet  = batch["PET"].to(device)

            # Zero placeholder when samples carry no tabular features.
            # NOTE(review): assumes the test transform keeps the "table" key when present — confirm.
            if "table" in batch:
                table = batch["table"].to(device).float()
            else:
                table = torch.zeros(
                    (mri.size(0), tab_dim), device=device, dtype=torch.float32
                )

            label = batch["label"].to(device).long()

            with autocast(device_type=device.type, enabled=use_amp):
                img_feat = img_encoder(mri, pet)        # [B, feat_dim]
                out      = clf_model(img_feat, table)   # [B, num_cls]

            # Softmax in float32 so probabilities are not fp16-rounded.
            probs = torch.softmax(out.float(), dim=1)[:, 1].cpu().numpy()
            y_prob.extend(probs)
            y_true.extend(label.cpu().numpy())

    # ---------- Evaluation ----------
    y_pred  = (np.array(y_prob) > 0.5).astype(int)
    metrics = calculate_metrics(y_true, y_pred, y_prob)

    # ---------- ROC curve ----------
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(fpr, tpr, lw=2, label=f"AUC={metrics['AUC']:.3f}")
    ax.plot([0, 1], [0, 1], "k--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"ROC - Fold {fold}")
    ax.legend()
    roc_path = os.path.join(checkpoint_dir, f"roc_fold{fold}.png")
    fig.savefig(roc_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"✅ ROC curve saved to {roc_path}")

    return metrics, y_prob, y_true, y_pred


In [None]:
# ---------- Metric functions used by this cell ----------
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, matthews_corrcoef,
                             confusion_matrix, roc_curve, auc)

# ---------- Load the evaluation config ----------
# os.path.join instead of the former r"config\config2.json", which only
# resolved on Windows.
config_path = os.path.join("config", "config2.json")
cfg = Cfg(load_cfg(config_path))

# `fold_loaders` was built by an earlier cell with the previous cfg.
# NOTE(review): config2.json should share checkpoint_dir / n_splits / data paths with it — confirm.
if len(fold_loaders) < cfg.n_splits:
    raise RuntimeError(
        f"fold_loaders has {len(fold_loaders)} folds but cfg.n_splits={cfg.n_splits}; "
        "rebuild the fold loaders with this config first")

# ----------- Parameter count (image encoder + multimodal classifier) -----------
temp_img_enc = generate_image_model(cfg)     # Dual-Stream UNet3D encoder
temp_clf     = generate_mm_classifier(cfg)   # tabular + image classifier head

model_params     = [*temp_img_enc.parameters(), *temp_clf.parameters()]
total_params     = sum(p.numel() for p in model_params)
trainable_params = sum(p.numel() for p in model_params if p.requires_grad)
bytes_per_param  = 2 if getattr(cfg, 'fp16', False) else 4
approx_size_mb   = total_params * bytes_per_param / 1024 ** 2
del temp_img_enc, temp_clf, model_params     # free the throwaway models

# ------------- Output files -------------
all_metrics  = []
all_probs    = []
all_labels   = []
per_fold_roc = []   # (y_true, y_prob) per fold, for the mean ROC below

ckpt_dir = cfg.checkpoint_dir
os.makedirs(ckpt_dir, exist_ok=True)

results_txt = os.path.join(ckpt_dir, "test_results.txt")
result_csv  = os.path.join(ckpt_dir, "result.csv")

# TXT: model parameters + header
with open(results_txt, "w") as f:
    f.write("===== MODEL PARAMETERS =====\n")
    f.write(f"Total params       : {total_params}\n")
    f.write(f"Trainable params   : {trainable_params}\n")
    f.write(f"Approx. size (MB)  : {approx_size_mb:.2f}\n\n")
    f.write("===== FOLD RESULTS =====\n")
    f.write("Fold\tACC\tPRE\tSEN\tSPE\tF1\tAUC\tMCC\n")

# CSV: header
with open(result_csv, "w", newline="") as csv_f:
    writer = csv.writer(csv_f)
    writer.writerow([
        "fold", "idx_in_fold", "sample_id",
        "true_label", "pred_label", "correct"
    ])

# ------------- Test every fold -------------
for fold in range(1, cfg.n_splits + 1):
    print(f"\n=== Testing Fold {fold}/{cfg.n_splits} ===")

    # Reuse the raw sample list behind the already-built test loader.
    te_loader  = fold_loaders[fold - 1]["test_loader"]
    test_data  = te_loader.dataset.data

    # fold_* names avoid clobbering the notebook-level `labels` list.
    metrics, fold_probs, fold_labels, fold_preds = test_models(
        ckpt_dir, test_data, fold
    )

    # labels=[0, 1] keeps the matrix 2x2 even when a fold contains or predicts
    # a single class, so the TN/FP/FN/TP unpacking below cannot fail.
    cm = confusion_matrix(fold_labels, fold_preds, labels=[0, 1])

    # Console output
    print(
        f"Fold {fold} - "
        f"ACC={metrics['ACC']:.4f}, PRE={metrics['PRE']:.4f}, "
        f"SEN={metrics['SEN']:.4f}, SPE={metrics['SPE']:.4f}, "
        f"F1={metrics['F1']:.4f}, AUC={metrics['AUC']:.4f}, "
        f"MCC={metrics['MCC']:.4f}"
    )
    print(f"Confusion Matrix:\n{cm}")

    # ---------- TXT ----------
    with open(results_txt, "a") as f:
        # 1) metric row
        f.write(
            f"{fold}\t"
            f"{metrics['ACC']:.4f}\t{metrics['PRE']:.4f}\t"
            f"{metrics['SEN']:.4f}\t{metrics['SPE']:.4f}\t"
            f"{metrics['F1']:.4f}\t{metrics['AUC']:.4f}\t"
            f"{metrics['MCC']:.4f}\n"
        )
        # 2) confusion-matrix row, kept separate from the metric row
        TN, FP, FN, TP = cm.ravel()
        f.write(f"CM\t{TN}\t{FP}\t{FN}\t{TP}\n")

    # CSV: per-sample results
    with open(result_csv, "a", newline="") as csv_f:
        writer = csv.writer(csv_f)
        for idx, (sample_dict, y_t, y_p) in enumerate(
                zip(test_data, fold_labels, fold_preds)):
            # Sample dicts use the "Subject" key (capital S); fall back to the MRI file name.
            sample_id = (
                sample_dict.get("Subject")
                or os.path.basename(sample_dict.get("MRI", f"s{idx}"))
            )
            writer.writerow([
                fold, idx, sample_id,
                int(y_t), int(y_p), int(y_t == y_p)
            ])

    # Aggregate
    all_metrics.append(metrics)
    all_probs.extend(fold_probs)
    all_labels.extend(fold_labels)
    per_fold_roc.append((fold_labels, fold_probs))


# ------------- Mean ROC across folds -------------
# Vertical averaging: interpolate each fold's TPR onto a common FPR grid and
# average. The previous version pooled all folds into a single curve while
# labelling it "Mean ROC".
mean_fpr = np.linspace(0, 1, 100)
interp_tprs = []
for y_t, y_s in per_fold_roc:
    if len(np.unique(y_t)) < 2:
        continue                         # ROC undefined for a single-class fold
    fpr, tpr, _ = roc_curve(y_t, y_s)
    tpr_i = np.interp(mean_fpr, fpr, tpr)
    tpr_i[0] = 0.0
    interp_tprs.append(tpr_i)

if interp_tprs:
    mean_tpr = np.mean(interp_tprs, axis=0)
    mean_tpr[-1] = 1.0
    roc_auc = auc(mean_fpr, mean_tpr)
    fig, ax = plt.subplots()
    ax.plot(mean_fpr, mean_tpr, 'b-', lw=2,
            label=f'Mean ROC (AUC={roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')
    ax.legend(loc='lower right')
    fig.savefig(os.path.join(ckpt_dir, 'mean_test_roc.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)
else:
    print("⚠ No fold has both classes; mean ROC not plotted")

# ------------- Summary (mean ± population std over folds) -------------
print("\n=== Final Test Results (mean ± std) ===")
summary_lines = []
for k in ['ACC', 'PRE', 'SEN', 'SPE', 'F1', 'AUC', 'MCC']:
    vals = [m[k] for m in all_metrics]
    mean_val = np.mean(vals)
    std_val  = np.std(vals)
    line = f"{k}: {mean_val:.4f} ± {std_val:.4f}"
    print(line)
    summary_lines.append(line)

with open(results_txt, "a") as f:
    f.write("\n===== SUMMARY =====\n")
    for line in summary_lines:
        f.write(line + "\n")
