In [7]:
import os, json, time, csv, numpy as np, pandas as pd
from collections import Counter, defaultdict
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter

from datasets.ADNI import ADNI, ADNI_transform
from monai.data import Dataset

from sklearn.model_selection import StratifiedKFold, train_test_split
from utils.metrics import calculate_metrics


In [8]:
# -------------------- 配置 --------------------
def load_cfg(path):
    with open(path) as f: 
        return json.load(f)

class Cfg:
    def __init__(self, d):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for k, v in d.items(): 
            setattr(self, k, v)
# ----------------- 加载配置 -------------------
config_path = "config/config.json"
cfg = Cfg(load_cfg(config_path))
for name, val in vars(cfg).items():
    print(f"{name:15s}: {val}")
writer = SummaryWriter(cfg.checkpoint_dir)

device         : cuda:1
label_file     : adni_dataset/ADNI_902.csv
mri_dir        : adni_dataset/MRI
pet_dir        : adni_dataset/PET
table_dir      : adni_dataset/ADNI_Tabel.csv
tabular_emb    : models/tabular_emb.csv
table_startcol : 4
task           : SMCIPMCI
augment        : False
split_ratio_test: 0.2
seed           : 42
num_epochs     : 100
batch_size     : 6
lr             : 1e-06
weight_decay   : 1e-05
fp16           : True
checkpoint_dir : checkpoints_mmad-mci
nb_class       : 2
n_splits       : 5
dropout_rate   : 0.5
in_channels    : 2
seg_task       : False
img_dim        : 512
tab_dim        : 192


In [9]:
full_dataset = ADNI(cfg.label_file, cfg.mri_dir, cfg.pet_dir, cfg.task, cfg.augment)
full_ds      = full_dataset.data_dict         # list[dict]
labels       = [d["label"] for d in full_ds]

# -------------------- 划分 --------------------

fold_indices = defaultdict(dict)
outer_cv = StratifiedKFold(
    n_splits=cfg.n_splits, shuffle=True, random_state=cfg.seed
)

for fold, (train_val_idx, test_idx) in enumerate(outer_cv.split(full_ds, labels), start=1):
    # ——— 内层 90/10 再分验证集 ———
    train_val_labels = [labels[i] for i in train_val_idx]
    idxs_inner       = np.arange(len(train_val_idx))
    train_idx_in, val_idx_in = train_test_split(
        idxs_inner, test_size=0.125, stratify=train_val_labels, random_state=cfg.seed
    )

    # ——— 映射回 full_ds 的绝对索引 ———
    train_idx = np.array(train_val_idx)[train_idx_in]
    val_idx   = np.array(train_val_idx)[val_idx_in]

    fold_indices[fold]["train_idx"] = train_idx.tolist()
    fold_indices[fold]["val_idx"]   = val_idx.tolist()
    fold_indices[fold]["test_idx"]  = test_idx.tolist()

    print(f"Fold {fold}: train {len(train_idx)}, val {len(val_idx)}, test {len(test_idx)}")

# -------------------- 保存 JSON --------------------

os.makedirs(cfg.checkpoint_dir, exist_ok=True)
json_path = os.path.join(cfg.checkpoint_dir, "fold_indices.json")
with open(json_path, "w") as f:
    json.dump({str(k): v for k, v in fold_indices.items()}, f, indent=2)
print(f"Fold indices saved to {json_path}")


[ADNI Dataset: SMCIPMCI] 样本分布：
  SMCI (0): 321
  PMCI (1): 158

Fold 1: train 335, val 48, test 96
Fold 2: train 335, val 48, test 96
Fold 3: train 335, val 48, test 96
Fold 4: train 335, val 48, test 96
Fold 5: train 336, val 48, test 95
Fold indices saved to checkpoints_mmad-mci\fold_indices.json


In [10]:
import os
import json
import pandas as pd

# 假设 full_ds = ADNI(...).data_dict，cfg 已定义
json_path = os.path.join(cfg.checkpoint_dir, "fold_indices.json")
csv_path  = os.path.join(cfg.table_dir)  # 请替换为你的实际文件名

# 读取完整表格
df = pd.read_csv(csv_path)

# 读取 fold 索引
with open(json_path, 'r') as f:
    fold_indices = json.load(f)

# 提取 full_ds 中的 Subject_ID 列表
all_subjects = [entry['Subject'] for entry in full_ds]

for fold_str, idxs in fold_indices.items():
    fold = int(fold_str)
    train_idx = idxs['train_idx']
    val_idx   = idxs['val_idx']
    test_idx  = idxs['test_idx']

    # 根据索引得到本折的 Subject_ID 列表
    train_subs = [ all_subjects[i] for i in train_idx ]
    val_subs   = [ all_subjects[i] for i in val_idx   ]
    test_subs  = [ all_subjects[i] for i in test_idx  ]

    # 在原始 df 中筛出对应行
    df_train = df[df['Subject_ID'].isin(train_subs)].reset_index(drop=True)
    df_val   = df[df['Subject_ID'].isin(val_subs)]  .reset_index(drop=True)
    df_test  = df[df['Subject_ID'].isin(test_subs)] .reset_index(drop=True)

    # 重排列顺序：第一列 Subject_ID，第二列 Group，后面是所有其他列
    def reorder(df_split):
        cols = df_split.columns.tolist()
        cols.remove('Subject_ID')
        if 'Group' not in cols:
            raise KeyError("'Group' 列未找到，请确认表格中有此列")
        cols.remove('Group')
        return df_split[['Subject_ID', 'Group'] + cols]

    df_train = reorder(df_train)
    df_val   = reorder(df_val)
    df_test  = reorder(df_test)

    # 保存到各自目录
    out_dir = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    os.makedirs(out_dir, exist_ok=True)

    df_train.to_csv(os.path.join(out_dir, "train.csv"), index=False)
    df_val  .to_csv(os.path.join(out_dir, "val.csv"),   index=False)
    df_test .to_csv(os.path.join(out_dir, "test.csv"),  index=False)

    print(f"Fold {fold} saved:")
    print(f"  train → {os.path.join(out_dir, 'train.csv')}")
    print(f"  val   → {os.path.join(out_dir, 'val.csv')}")
    print(f"  test  → {os.path.join(out_dir, 'test.csv')}")


Fold 1 saved:
  train → checkpoints_mmad-mci\fold1\train.csv
  val   → checkpoints_mmad-mci\fold1\val.csv
  test  → checkpoints_mmad-mci\fold1\test.csv
Fold 2 saved:
  train → checkpoints_mmad-mci\fold2\train.csv
  val   → checkpoints_mmad-mci\fold2\val.csv
  test  → checkpoints_mmad-mci\fold2\test.csv
Fold 3 saved:
  train → checkpoints_mmad-mci\fold3\train.csv
  val   → checkpoints_mmad-mci\fold3\val.csv
  test  → checkpoints_mmad-mci\fold3\test.csv
Fold 4 saved:
  train → checkpoints_mmad-mci\fold4\train.csv
  val   → checkpoints_mmad-mci\fold4\val.csv
  test  → checkpoints_mmad-mci\fold4\test.csv
Fold 5 saved:
  train → checkpoints_mmad-mci\fold5\train.csv
  val   → checkpoints_mmad-mci\fold5\val.csv
  test  → checkpoints_mmad-mci\fold5\test.csv


In [11]:
from models.tabular_encoder import tabular_encoder_fold

for fold in range(1, cfg.n_splits+1):
    fold_dir = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    if cfg.task == "ADCN":
        classes = ["CN", "AD"]
    elif cfg.task == "SMCIPMCI":
        classes = ["SMCI", "PMCI"]
    else:
        raise ValueError(f"Unsupported task: {cfg.task}")
    tabular_encoder_fold(
        fold_dir    = fold_dir,
        label_col   = "Group",
        classes     = classes,
        start_col   = 3,
        device      = cfg.device,
        n_fold      = 5,
        dropna      = False
    )
    



▶ Using device: cuda:0
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((335, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((96, 194))
✓ Saved train_emb.csv ((336, 194))
✓ Saved val_emb.csv ((48, 194))
✓ Saved test_emb.csv ((95, 194))


In [12]:
# -------------------------------------------------
# 1. 建立 Subject ➜ 影像样本字典，便于快速对齐
# -------------------------------------------------
subject_map = {d["Subject"]: d for d in full_ds}

# -------------------------------------------------
# 2. 为五折构造 train/val/test DataLoader
# -------------------------------------------------
import pandas as pd
from monai.data import Dataset
from torch.utils.data import DataLoader

def load_emb_csv(path):
    """CSV -> {sid: (label:int, emb:numpy[192])}"""
    df = pd.read_csv(path)
    sid   = df.iloc[:, 0].astype(str).tolist()
    label = df.iloc[:, 1].astype(int).tolist()
    emb   = df.iloc[:, 2:].astype("float32").values
    return {s: (l, e) for s, l, e in zip(sid, label, emb)}

tr_tf, vl_tf = ADNI_transform(augment=cfg.augment)
te_tf        = vl_tf

fold_loaders = []

for fold in range(1, cfg.n_splits + 1):
    fold_dir  = os.path.join(cfg.checkpoint_dir, f"fold{fold}")
    paths     = {sp: os.path.join(fold_dir, f"{sp}_emb.csv")
                 for sp in ["train", "val", "test"]}

    # 解析 CSV
    emb_maps  = {sp: load_emb_csv(p) for sp, p in paths.items()}

    split_ds  = {}
    for sp, emb_map in emb_maps.items():
        samples = []
        for sid, (lbl, emb) in emb_map.items():
            if sid not in subject_map:
                raise KeyError(f"{sid} not found in ADNI dataset")
            s = subject_map[sid].copy()     # MRI / PET / label / Subject
            s["label"] = lbl                # 以 CSV 为准
            s["table"] = emb                # 192‑d numpy
            samples.append(s)
        split_ds[sp] = samples

    # DataLoader
    dl_kw = dict(batch_size=cfg.batch_size, pin_memory=True)
    fold_loaders.append({
        "fold": fold,
        "train_loader": DataLoader(Dataset(split_ds["train"], tr_tf),
                                   shuffle=True,  num_workers=4, **dl_kw),
        "val_loader":   DataLoader(Dataset(split_ds["val"],   vl_tf),
                                   shuffle=False, num_workers=2, **dl_kw),
        "test_loader":  DataLoader(Dataset(split_ds["test"],  te_tf),
                                   shuffle=False, num_workers=2, **dl_kw),
    })

    print(f"Fold {fold}  ➜  train {len(split_ds['train'])} | "
          f"val {len(split_ds['val'])} | test {len(split_ds['test'])}")

# 现在 fold_loaders 就和之前影像-only 版本一模一样可直接用于训练：
# for fold_dict in fold_loaders:
#     tr_loader = fold_dict["train_loader"]
#     ...


Fold 1  ➜  train 335 | val 48 | test 96
Fold 2  ➜  train 335 | val 48 | test 96
Fold 3  ➜  train 335 | val 48 | test 96
Fold 4  ➜  train 335 | val 48 | test 96
Fold 5  ➜  train 336 | val 48 | test 95


In [13]:
def inspect_fold_loaders(fold_loaders):
    for fd in fold_loaders:
        fold = fd["fold"]
        tl, vl, te = fd["train_loader"], fd["val_loader"], fd["test_loader"]

        print(f"\n=== Fold {fold} ===")
        print(f"  ▸ train_loader ─ {len(tl.dataset):4d} samples "
              f"· {len(tl):3d} batches (batch={tl.batch_size})")
        print(f"  ▸ val_loader   ─ {len(vl.dataset):4d} samples "
              f"· {len(vl):3d} batches")
        print(f"  ▸ test_loader  ─ {len(te.dataset):4d} samples "
              f"· {len(te):3d} batches")

        # -------- 取一个 batch 测 shape --------
        batch = next(iter(tl))
        keys  = list(batch.keys())
        print("  --> keys:", keys)

        mri, pet, tab, y = (batch["MRI"], batch["PET"],
                            batch["table"], batch["label"])
        print(f"      MRI   : {tuple(mri.shape)}  {mri.dtype}")
        print(f"      PET   : {tuple(pet.shape)}  {pet.dtype}")
        print(f"      table : {tuple(tab.shape)} {tab.dtype}  "
              f"(should be [B, 192])")
        print(f"      label : {tuple(y.shape)}   {y.dtype}")

        # -------- 快速看标签分布 --------
        uniq, cnt = torch.unique(y, return_counts=True)
        dist = {int(k): int(v) for k, v in zip(uniq, cnt)}
        print("      label counts:", dist)

# 调用
inspect_fold_loaders(fold_loaders)


=== Fold 1 ===
  ▸ train_loader ─  335 samples ·  56 batches (batch=6)
  ▸ val_loader   ─   48 samples ·   8 batches
  ▸ test_loader  ─   96 samples ·  16 batches
  --> keys: ['MRI', 'PET', 'label', 'Subject', 'table']
      MRI   : (6, 1, 91, 109, 91)  torch.float32
      PET   : (6, 1, 91, 109, 91)  torch.float32
      table : (6, 192) torch.float32  (should be [B, 192])
      label : (6,)   torch.int64
      label counts: {0: 3, 1: 3}

=== Fold 2 ===
  ▸ train_loader ─  335 samples ·  56 batches (batch=6)
  ▸ val_loader   ─   48 samples ·   8 batches
  ▸ test_loader  ─   96 samples ·  16 batches
  --> keys: ['MRI', 'PET', 'label', 'Subject', 'table']
      MRI   : (6, 1, 91, 109, 91)  torch.float32
      PET   : (6, 1, 91, 109, 91)  torch.float32
      table : (6, 192) torch.float32  (should be [B, 192])
      label : (6,)   torch.int64
      label counts: {0: 5, 1: 1}

=== Fold 3 ===
  ▸ train_loader ─  335 samples ·  56 batches (batch=6)
  ▸ val_loader   ─   48 samples ·   8 batc

In [14]:
# ----------------- 创建模型 -------------------
from models.mmad_encoder import ImageEncoder, ImageEncoder_CEN  # 根据你的命名调整
from models.mmad_encoder import MultiModalClassifier
def generate_image_model(cfg):
    # 使用 CEN 版本的编码器 + 分类头
    model = ImageEncoder_CEN(
        in_ch_modality   = 1,
        level_channels   = [64, 128, 256],
        bottleneck_ch    = 512,
        share_layers     = 2,
        cen_ratios       = (0.20, 0.10),
    ).to(cfg.device)

    # 参数统计
    total_params     = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    bytes_per_param  = 2 if getattr(cfg, 'fp16', False) else 4

    print("-------------------- model --------------------")
    print(f"Total params(M)    : {total_params:,}")
    print(f"Trainable params(M): {trainable_params:,}")
    print(f"Approx. size       : {total_params * bytes_per_param / 1024**2:.2f} MB")
    print("Model type:", type(model).__name__)

    return model

def generate_mm_classifier(cfg):
    model = MultiModalClassifier(
        img_dim=1024,
        tab_dim=192,
        num_classes=2
    ).to(cfg.device)
    # 参数统计
    total_params     = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    bytes_per_param  = 2 if getattr(cfg, 'fp16', False) else 4

    print("-------------------- model --------------------")
    print(f"Total params(M)    : {total_params:,}")
    print(f"Trainable params(M): {trainable_params:,}")
    print(f"Approx. size       : {total_params * bytes_per_param / 1024**2:.2f} MB")
    print("Model type:", type(model).__name__)

    return model

model = generate_image_model(cfg)
model = generate_mm_classifier(cfg)
print(model)

-------------------- model --------------------
Total params(M)    : 13,667,328
Trainable params(M): 13,667,328
Approx. size       : 26.07 MB
Model type: ImageEncoder_CEN
-------------------- model --------------------
Total params(M)    : 312,066
Trainable params(M): 312,066
Approx. size       : 0.60 MB
Model type: MultiModalClassifier
MultiModalClassifier(
  (fc): Sequential(
    (0): Linear(in_features=1216, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
  )
)


In [None]:
# ----------------- 五折训练主循环 -----------------
os.makedirs(cfg.checkpoint_dir, exist_ok=True)

for fold_idx in range(cfg.n_splits):
    fold = fold_idx + 1
    print(f"\n=== Fold {fold}/{cfg.n_splits} ===")

    # ---- 模型：影像编码器 + 多模态分类器 ---- #
    img_encoder = generate_image_model(cfg).to(cfg.device)
    clf_model   = generate_mm_classifier(cfg).to(cfg.device)

    params = list(img_encoder.parameters()) + list(clf_model.parameters())
    optimizer = torch.optim.AdamW(
        params, lr=cfg.lr, weight_decay=getattr(cfg, 'weight_decay', 0)
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.num_epochs)
    scaler    = GradScaler(enabled=getattr(cfg, 'fp16', False))

    # ---- DataLoader ---- #
    tr_loader = fold_loaders[fold_idx]['train_loader']
    vl_loader = fold_loaders[fold_idx]['val_loader']

    criterion = nn.CrossEntropyLoss()

    # ---- CSV ---- #
    csv_path = os.path.join(cfg.checkpoint_dir, f"metrics_fold{fold}.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_Loss","train_ACC","train_PRE","train_SEN","train_SPE","train_F1","train_AUC","train_MCC",
            "val_Loss","val_ACC","val_PRE","val_SEN","val_SPE","val_F1","val_AUC","val_MCC",
        ])

    best_auc = -np.inf

    # -------------- Epoch 循环 --------------
    for epoch in range(1, cfg.num_epochs + 1):
        t0 = time.time()

        # -------- Train --------
        img_encoder.train(); clf_model.train()
        tr_loss_sum, tr_batches = 0.0, 0
        yt, yp, ys = [], [], []

        for batch in tr_loader:
            mri   = batch['MRI'].to(cfg.device)
            pet   = batch['PET'].to(cfg.device)
            table = batch['table'].to(cfg.device).float()
            y     = batch['label'].to(cfg.device).long()

            optimizer.zero_grad()
            with autocast(device_type='cuda', enabled=getattr(cfg, 'fp16', False)):
                img_feat = img_encoder(mri, pet)          # [B, img_dim]
                out      = clf_model(img_feat, table)     # [B, num_cls]
                loss     = criterion(out, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            tr_loss_sum += loss.item(); tr_batches += 1
            prob = torch.softmax(out, dim=1)[:, 1].detach().cpu().numpy()
            pred = out.argmax(1).detach().cpu().numpy()
            yt.extend(y.cpu().numpy()); yp.extend(pred); ys.extend(prob)

        tr_met  = calculate_metrics(yt, yp, ys)
        tr_loss = tr_loss_sum / tr_batches

        # -------- Validation --------
        img_encoder.eval(); clf_model.eval()
        vl_loss_sum, vl_batches = 0.0, 0
        yt, yp, ys = [], [], []
        with torch.no_grad():
            for batch in vl_loader:
                mri   = batch['MRI'].to(cfg.device)
                pet   = batch['PET'].to(cfg.device)
                table = batch['table'].to(cfg.device).float()
                y     = batch['label'].to(cfg.device).long()

                with autocast(device_type='cuda', enabled=getattr(cfg, 'fp16', False)):
                    img_feat = img_encoder(mri, pet)
                    out      = clf_model(img_feat, table)
                    loss     = criterion(out, y)

                vl_loss_sum += loss.item(); vl_batches += 1
                prob = torch.softmax(out, dim=1)[:, 1].cpu().numpy()
                pred = out.argmax(1).cpu().numpy()
                yt.extend(y.cpu().numpy()); yp.extend(pred); ys.extend(prob)

        vl_met  = calculate_metrics(yt, yp, ys)
        vl_loss = vl_loss_sum / vl_batches
        scheduler.step()

        print(f"Fold {fold} | Epoch {epoch:03d} | Train Loss={tr_loss:.4f} | Val Loss={vl_loss:.4f} | "
              f"Train ACC={tr_met['ACC']:.4f} | Val ACC={vl_met['ACC']:.4f} | Train AUC={tr_met['AUC']:.4f} | Val AUC={vl_met['AUC']:.4f} | "
              f"time={time.time()-t0:.1f}s")

        # -------- Save best model --------
        if vl_met['AUC'] > best_auc:
            best_auc = vl_met['AUC']
            torch.save({
                'img_encoder': img_encoder.state_dict(),
                'clf_model'  : clf_model.state_dict()
            }, os.path.join(cfg.checkpoint_dir, f"best_model_fold{fold}.pth"))
            print(f"✅ Fold {fold} saved best model (AUC={best_auc:.4f})")

        # -------- CSV log --------
        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                epoch,
                f"{tr_loss:.4f}", f"{tr_met['ACC']:.4f}", f"{tr_met['PRE']:.4f}",
                f"{tr_met['SEN']:.4f}", f"{tr_met['SPE']:.4f}", f"{tr_met['F1']:.4f}", f"{tr_met['AUC']:.4f}", f"{tr_met['MCC']:.4f}",
                f"{vl_loss:.4f}", f"{vl_met['ACC']:.4f}", f"{vl_met['PRE']:.4f}",
                f"{vl_met['SEN']:.4f}", f"{vl_met['SPE']:.4f}", f"{vl_met['F1']:.4f}", f"{vl_met['AUC']:.4f}", f"{vl_met['MCC']:.4f}",
            ])

    print(f"=== Fold {fold} 完成，Best AUC={best_auc:.4f} ===")



=== Fold 1/5 ===
-------------------- model --------------------
Total params(M)    : 13,667,328
Trainable params(M): 13,667,328
Approx. size       : 26.07 MB
Model type: ImageEncoder_CEN
-------------------- model --------------------
Total params(M)    : 312,066
Trainable params(M): 312,066
Approx. size       : 0.60 MB
Model type: MultiModalClassifier


  scaler    = GradScaler(enabled=getattr(cfg, 'fp16', False))


Fold 1 | Epoch 001 | Train Loss=0.6993 | Val Loss=0.7077 | Train ACC=0.4746 | Val ACC=0.3333 | Train AUC=0.4321 | Val AUC=0.2383 | time=221.2s
✅ Fold 1 saved best model (AUC=0.2383)
Fold 1 | Epoch 002 | Train Loss=0.6792 | Val Loss=0.6771 | Train ACC=0.5970 | Val ACC=0.6667 | Train AUC=0.5154 | Val AUC=0.4023 | time=223.5s
✅ Fold 1 saved best model (AUC=0.4023)


In [None]:
# -------------------- 测试 -----------------------------
