In [None]:
import os, json, time, csv, numpy as np
from collections import Counter
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter

from datasets.ADNI import ADNI, ADNI_transform
from monai.data import Dataset

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, matthews_corrcoef,
                             confusion_matrix, roc_curve, auc)
from models.unet3d import UNet3DClassifier,UNet3D
from utils.metrics import calculate_metrics

In [2]:
# -------------------- 配置 --------------------
def load_cfg(path):
    with open(path) as f: return json.load(f)

class Cfg:
    def __init__(self, d):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for k, v in d.items(): 
            setattr(self, k, v)

# # ----------------- 指标函数 -------------------
# def calculate_metrics(y_true, y_pred, y_score):
#     if len(y_true) == 0:
#         raise ValueError("No samples to evaluate. Please check your test_loader / data split.")
#     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
#     spe = tn / (tn + fp + 1e-8)
#     return {
#         'ACC': accuracy_score(y_true, y_pred),
#         'PRE': precision_score(y_true, y_pred, zero_division=0),
#         'SEN': recall_score(y_true, y_pred, zero_division=0),
#         'SPE': spe,
#         'F1' : f1_score(y_true, y_pred, zero_division=0),
#         'AUC': roc_auc_score(y_true, y_score),
#         'MCC': matthews_corrcoef(y_true, y_pred),
#         'cm' : np.array([[tn, fp], [fn, tp]])
#     }

config_path = "config/config.json"
cfg = Cfg(load_cfg(config_path))
for name, val in vars(cfg).items():
    print(f"{name:15s}: {val}")
writer = SummaryWriter(cfg.checkpoint_dir)

device         : cuda:0
label_file     : adni_dataset/ADNI_902.csv
mri_dir        : adni_dataset/MRI
pet_dir        : adni_dataset/PET
task           : ADCN
augment        : False
split_ratio_test: 0.2
seed           : 42
num_epochs     : 100
batch_size     : 8
lr             : 1e-06
weight_decay   : 1e-05
fp16           : True
checkpoint_dir : checkpoints
nb_class       : 2
n_splits       : 1
dropout_rate   : 0.5
in_channels    : 2
seg_task       : False


In [3]:
# 数据划分
full_ds = ADNI(cfg.label_file, cfg.mri_dir, cfg.pet_dir, cfg.task, cfg.augment).data_dict
train_val, test_ds = train_test_split(
    full_ds, test_size=0.2, random_state=42,
    stratify=[d['label'] for d in full_ds])

# 再划分验证集
train_ds, val_ds = train_test_split(
    train_val, test_size=0.2, random_state=42,
    stratify=[d['label'] for d in train_val])

# 数据加载
tr_tf, vl_tf = ADNI_transform(augment=cfg.augment)
tr_loader = DataLoader(Dataset(train_ds, tr_tf),
                       batch_size=cfg.batch_size, shuffle=True,
                       num_workers=4, pin_memory=True)
vl_loader = DataLoader(Dataset(val_ds, vl_tf),
                       batch_size=cfg.batch_size, shuffle=False,
                       num_workers=2, pin_memory=True)

from pprint import pprint
print("=== raw data_dict 前 5 条样本 ===")
pprint(full_ds[:5])


[ADNI Dataset: ADCN] 样本分布：
  CN (0): 204
  AD (1): 219

=== raw data_dict 前 5 条样本 ===
[{'MRI': 'adni_dataset/MRI\\002_S_4213.nii',
  'PET': 'adni_dataset/PET\\002_S_4213.nii',
  'Subject': '002_S_4213',
  'label': 0},
 {'MRI': 'adni_dataset/MRI\\002_S_5018.nii',
  'PET': 'adni_dataset/PET\\002_S_5018.nii',
  'Subject': '002_S_5018',
  'label': 1},
 {'MRI': 'adni_dataset/MRI\\003_S_4081.nii',
  'PET': 'adni_dataset/PET\\003_S_4081.nii',
  'Subject': '003_S_4081',
  'label': 0},
 {'MRI': 'adni_dataset/MRI\\003_S_4119.nii',
  'PET': 'adni_dataset/PET\\003_S_4119.nii',
  'Subject': '003_S_4119',
  'label': 0},
 {'MRI': 'adni_dataset/MRI\\003_S_4136.nii',
  'PET': 'adni_dataset/PET\\003_S_4136.nii',
  'Subject': '003_S_4136',
  'label': 1}]


In [4]:
# ----------------- 创建模型 -------------------
model = UNet3D(in_channels=cfg.in_channels, num_classes=cfg.nb_class)

# 参数统计
total_params     = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
bytes_per_param  = 2 if cfg.fp16 else 4
print("--------------------model------------------")
print(f"Total params(M)    : {total_params:,}")
print(f"Trainable params(M): {trainable_params:,}")
print(f"Approx. size    : {total_params*bytes_per_param/1024**2:.2f} MB")
print("model type:", type(model).__name__)


# ----------------- 测试创建模型 -------------------
from torchsummary import summary
start_time = time.time()
summary(
    model=model,
    input_size=(2, 96, 112, 96),    # (C, D, H, W)
    batch_size=-1,
    device="cpu"                                  # 若想在 GPU 上汇总，可填 "cuda"
)
print(f"--- {time.time() - start_time:.2f} seconds ---")

--------------------model------------------
Total params(M)    : 19,073,698
Trainable params(M): 19,073,698
Approx. size    : 36.38 MB
model type: UNet3D
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 32, 96, 112, 96]           1,760
       BatchNorm3d-2      [-1, 32, 96, 112, 96]              64
              ReLU-3      [-1, 32, 96, 112, 96]               0
            Conv3d-4      [-1, 64, 96, 112, 96]          55,360
       BatchNorm3d-5      [-1, 64, 96, 112, 96]             128
              ReLU-6      [-1, 64, 96, 112, 96]               0
         MaxPool3d-7       [-1, 64, 48, 56, 48]               0
       Conv3DBlock-8  [[-1, 64, 48, 56, 48], [-1, 64, 96, 112, 96]]               0
            Conv3d-9       [-1, 64, 48, 56, 48]         110,656
      BatchNorm3d-10       [-1, 64, 48, 56, 48]             128
             ReLU-11       [-1, 64, 48, 56, 48]          

In [6]:
# ------------------- 训练 ---------------------
model = UNet3DClassifier(in_channels=cfg.in_channels, num_classes=cfg.nb_class).to(cfg.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.num_epochs)
scaler = GradScaler(enabled=cfg.fp16)

best_auc = -np.inf
for epoch in range(1, cfg.num_epochs + 1):
    t0 = time.time()

    # -------- Train --------
    model.train(); yt, yp, ys = [], [], []
    for batch in tr_loader:
        # ❶ 从 batch 里取出两模态，并拼通道
        mri = batch['MRI'].to(cfg.device)    # [B,1,D,H,W]
        pet = batch['PET'].to(cfg.device)    # [B,1,D,H,W]
        x   = torch.cat([mri, pet], dim=1)   # [B,2,D,H,W]
        y   = batch['label'].to(cfg.device).long().view(-1)

        optimizer.zero_grad()
        with autocast(device_type='cuda', enabled=cfg.fp16):
            out  = model(x)
            loss = criterion(out, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        prob = torch.softmax(out, dim=1)[:, 1].detach().cpu().numpy()
        pred = out.argmax(1).detach().cpu().numpy()
        yt.extend(y.cpu().numpy())
        yp.extend(pred)
        ys.extend(prob)

    tr_met = calculate_metrics(yt, yp, ys)

    # -------- Validation --------
    model.eval(); yt, yp, ys = [], [], []
    with torch.no_grad():
        for batch in vl_loader:
            mri = batch['MRI'].to(cfg.device)
            pet = batch['PET'].to(cfg.device)
            x   = torch.cat([mri, pet], dim=1)
            y   = batch['label'].to(cfg.device).long().view(-1)

            with autocast(device_type='cuda', enabled=cfg.fp16):
                out  = model(x)
                loss = criterion(out, y)

            prob = torch.softmax(out, dim=1)[:, 1].cpu().numpy()
            pred = out.argmax(1).cpu().numpy()
            yt.extend(y.cpu().numpy())
            yp.extend(pred)
            ys.extend(prob)

    vl_met = calculate_metrics(yt, yp, ys)
    scheduler.step()

    print(f"Epoch {epoch:03d} | "
          f"Train ACC={tr_met['ACC']:.4f} F1={tr_met['F1']:.4f} AUC={tr_met['AUC']:.4f} | "
          f"Val   ACC={vl_met['ACC']:.4f} F1={vl_met['F1']:.4f} AUC={vl_met['AUC']:.4f} | "
          f"time={time.time()-t0:.1f}s")

    if vl_met['AUC'] > best_auc:
        best_auc = vl_met['AUC']
        torch.save(model.state_dict(), os.path.join(cfg.checkpoint_dir, "best_model.pth"))
        print("✅ Saved best model.")


  scaler = GradScaler(enabled=cfg.fp16)


Epoch 001 | Train ACC=0.5185 F1=0.6829 AUC=0.5449 | Val   ACC=0.4853 F1=0.0000 AUC=0.6048 | time=136.3s
✅ Saved best model.
Epoch 002 | Train ACC=0.5185 F1=0.6829 AUC=0.6681 | Val   ACC=0.5441 F1=0.6804 AUC=0.6316 | time=131.3s
✅ Saved best model.
Epoch 003 | Train ACC=0.5370 F1=0.6898 AUC=0.6954 | Val   ACC=0.5441 F1=0.6804 AUC=0.6952 | time=132.3s
✅ Saved best model.
Epoch 004 | Train ACC=0.5741 F1=0.7028 AUC=0.7468 | Val   ACC=0.5735 F1=0.7010 AUC=0.7485 | time=131.2s
✅ Saved best model.
Epoch 005 | Train ACC=0.6556 F1=0.7240 AUC=0.7492 | Val   ACC=0.6471 F1=0.7143 AUC=0.7117 | time=130.2s
Epoch 006 | Train ACC=0.6889 F1=0.7358 AUC=0.7572 | Val   ACC=0.6618 F1=0.7294 AUC=0.7675 | time=130.1s
✅ Saved best model.
Epoch 007 | Train ACC=0.6741 F1=0.7086 AUC=0.7448 | Val   ACC=0.6765 F1=0.6071 AUC=0.7848 | time=140.7s
✅ Saved best model.
Epoch 008 | Train ACC=0.7000 F1=0.7197 AUC=0.7546 | Val   ACC=0.7647 F1=0.7838 AUC=0.8065 | time=132.5s
✅ Saved best model.
Epoch 009 | Train ACC=0.7185

In [5]:
# -------------------- 测试 --------------------
def calculate_metrics(y_true, y_pred, y_score):
    if len(y_true) == 0:
        raise ValueError("No samples to evaluate. Please check your test_loader / data split.")
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    sen = recall_score(y_true, y_pred, zero_division=0)
    spe = tn / (tn + fp + 1e-8)

    # 近似 AUC = (Sensitivity + Specificity) / 2
    auc_val = (sen + spe) / 2

    return {
        'ACC': accuracy_score(y_true, y_pred),
        'PRE': precision_score(y_true, y_pred, zero_division=0),
        'SEN': sen,
        'SPE': spe,
        'F1' : f1_score(y_true, y_pred, zero_division=0),
        'AUC': auc_val,
        'MCC': matthews_corrcoef(y_true, y_pred),
        'cm' : np.array([[tn, fp], [fn, tp]])
    }

# ❶ 用同样的 transform（无增强）处理测试集
_, test_tf = ADNI_transform(augment=False)

# ❷ 构造 MONAI Dataset → DataLoader
test_loader = DataLoader(
    Dataset(test_ds, test_tf),
    batch_size=cfg.batch_size,
    shuffle=False,    # 测试一般不打乱
    num_workers=2,
    pin_memory=True
)

# ❸ （可选）打印 test_ds 的前几条，确认路径/标签无误
from pprint import pprint
print("=== raw test_ds 前 5 条样本 ===")
pprint(test_ds[:5])

# ❹ （可选）迭代一个 batch，确认形状
batch = next(iter(test_loader))
print("MRI batch shape:", batch['MRI'].shape)   # → [B,1,D,H,W]
print("PET batch shape:", batch['PET'].shape)   # → [B,1,D,H,W]
print("Label batch shape:", batch['label'].shape)  # → [B]

import torch
from torch.cuda.amp import autocast
from utils.metrics import calculate_metrics
from models.unet3d import UNet3DClassifier

# 1) 构建模型并加载权重
model = UNet3DClassifier(
    in_channels=cfg.in_channels, 
    num_classes=cfg.nb_class
).to(cfg.device)
model.load_state_dict(torch.load(
    os.path.join(cfg.checkpoint_dir, "best_model.pth"),
    map_location=cfg.device
))
model.eval()

# 2) 遍历 test_loader
yt, yp, ys = [], [], []
with torch.no_grad():
    for batch in test_loader:
        mri = batch['MRI'].to(cfg.device)    # [B,1,D,H,W]
        pet = batch['PET'].to(cfg.device)    # [B,1,D,H,W]
        x   = torch.cat([mri, pet], dim=1)   # [B,2,D,H,W]
        y   = batch['label'].to(cfg.device).long().view(-1)

        with autocast(enabled=cfg.fp16):
            out = model(x)                   # [B, num_classes]

        prob = torch.softmax(out, dim=1)[:, 1].cpu().numpy()  # 正例概率
        pred = out.argmax(1).cpu().numpy()

        yt.extend(y.cpu().numpy())
        yp.extend(pred)
        ys.extend(prob)

# 3) 计算并打印指标
metrics = calculate_metrics(yt, yp, ys)
print("\n=== Test Results ===")
print(f"ACC: {metrics['ACC']:.4f}")
print(f"PRE: {metrics['PRE']:.4f}")
print(f"SEN: {metrics['SEN']:.4f}")
print(f"SPE: {metrics['SPE']:.4f}")
print(f"F1 : {metrics['F1']:.4f}")
print(f"AUC: {metrics['AUC']:.4f}")
print(f"MCC: {metrics['MCC']:.4f}")
print("Confusion Matrix:")
print(metrics['cm'])  # [[tn, fp], [fn, tp]]


=== raw test_ds 前 5 条样本 ===
[{'MRI': 'adni_dataset/MRI\\941_S_4376.nii',
  'PET': 'adni_dataset/PET\\941_S_4376.nii',
  'Subject': '941_S_4376',
  'label': 0},
 {'MRI': 'adni_dataset/MRI\\031_S_4024.nii',
  'PET': 'adni_dataset/PET\\031_S_4024.nii',
  'Subject': '031_S_4024',
  'label': 1},
 {'MRI': 'adni_dataset/MRI\\032_S_4348.nii',
  'PET': 'adni_dataset/PET\\032_S_4348.nii',
  'Subject': '032_S_4348',
  'label': 0},
 {'MRI': 'adni_dataset/MRI\\094_S_1090.nii',
  'PET': 'adni_dataset/PET\\094_S_1090.nii',
  'Subject': '094_S_1090',
  'label': 1},
 {'MRI': 'adni_dataset/MRI\\016_S_4887.nii',
  'PET': 'adni_dataset/PET\\016_S_4887.nii',
  'Subject': '016_S_4887',
  'label': 1}]
MRI batch shape: torch.Size([8, 1, 91, 109, 91])
PET batch shape: torch.Size([8, 1, 91, 109, 91])
Label batch shape: torch.Size([8])


  model.load_state_dict(torch.load(
  with autocast(enabled=cfg.fp16):



=== Test Results ===
ACC: 0.8824
PRE: 0.9250
SEN: 0.8409
SPE: 0.9268
F1 : 0.8810
AUC: 0.9454
MCC: 0.7686
Confusion Matrix:
[[38  3]
 [ 7 37]]
