# Environment Check

In [1]:
import torch, sys, platform, subprocess, os

print("torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("cuda.is_available:", torch.cuda.is_available())
print("device_count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("device 0:", torch.cuda.get_device_name(0),
          "capability:", torch.cuda.get_device_capability(0))
try:
    # List-form argv: running a single binary needs no shell.
    smi_lines = subprocess.check_output(["nvidia-smi"]).decode().splitlines()
    # Line 0 is only a timestamp; also show the header line carrying the
    # driver / CUDA versions when nvidia-smi prints one.
    driver_line = next((ln for ln in smi_lines if "Driver Version" in ln), None)
    print("\n=== nvidia-smi ===\n", smi_lines[0] if smi_lines else "")
    if driver_line is not None:
        print(driver_line)
except Exception as e:
    # Missing binary (FileNotFoundError) or non-zero exit (CalledProcessError).
    print("nvidia-smi not found / driver not working:", e)
print("python:", sys.executable)
print("platform:", platform.platform())

torch: 2.5.1
torch.version.cuda: 12.1
cuda.is_available: True
device_count: 1
device 0: NVIDIA GeForce RTX 3060 Ti capability: (8, 6)

=== nvidia-smi ===
 Fri Oct 17 20:57:39 2025       
python: c:\Users\nhwen\anaconda3\envs\comp3710\python.exe
platform: Windows-10-10.0.26200-SP0


# Initialize Dataset

In [2]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torch.cuda.amp import autocast, GradScaler

In [3]:
import os
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

NUM_CLASSES = 4  # background + 3 brain-tissue classes
IMG_SIZE = 128

class OASIS2DPNGDataset(Dataset):
    """OASIS PNG 2D brain-slice dataset for semantic segmentation.

    Images and masks are paired by position after sorting each directory's
    file names, so both directories must hold the same number of files whose
    sorted orders correspond one-to-one.

    Args:
        image_dir: directory of slice PNGs (loaded as grayscale).
        mask_dir: directory of segmentation-mask PNGs (loaded as grayscale).
        img_size: side length that both image and mask are resized to.
        augment: if True, apply random horizontal and vertical flips
            (each with probability 0.5) identically to image and mask.

    Each item is ``(img, mask)``: ``img`` is a float32 tensor ``[1, H, W]``
    produced by ``TF.to_tensor``; ``mask`` is an int64 tensor ``[H, W]`` of
    class indices in ``[0, NUM_CLASSES - 1]``.

    NOTE(review): the mask mapping assumes class labels are stored as evenly
    spaced 8-bit gray levels (e.g. 0/85/170/255 for 4 classes) — confirm for
    other label encodings.

    Raises:
        ValueError: if the image and mask directories contain different
            numbers of files.
    """
    def __init__(self, image_dir, mask_dir, img_size=IMG_SIZE, augment=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
        # Pairing is positional, so a count mismatch would silently mis-pair
        # every sample after the first missing file.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"image/mask count mismatch: {len(self.image_files)} files in "
                f"{image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # Load both as single-channel 8-bit images.
        img = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")

        # Resize; NEAREST for the mask so no new (interpolated) label values appear.
        img = TF.resize(img, [self.img_size, self.img_size])
        mask = TF.resize(mask, [self.img_size, self.img_size], interpolation=TF.InterpolationMode.NEAREST)

        img = TF.to_tensor(img)  # [1, H, W], float32
        mask = np.array(mask, dtype=np.uint8)

        # Bin 8-bit gray levels into class indices. With 4 classes the bin width
        # is 256 // 4 = 64, so 0 -> 0, 85 -> 1, 170 -> 2, 255 -> 3. The clamp only
        # matters when NUM_CLASSES does not divide 256 (e.g. 3 classes: 255 // 85
        # would give 3, an out-of-range label).
        mask = np.minimum(mask // (256 // NUM_CLASSES), NUM_CLASSES - 1)
        mask = torch.from_numpy(mask.astype(np.int64))  # [H, W], int64

        # Random flips applied identically to image and mask.
        if self.augment:
            if random.random() < 0.5:
                img = TF.hflip(img); mask = TF.hflip(mask)
            if random.random() < 0.5:
                img = TF.vflip(img); mask = TF.vflip(mask)

        return img, mask

# Create Datasets and DataLoaders

In [4]:
# from google.colab import drive
# drive.mount('/content/drive')

# ROOT = "/content/drive/MyDrive/COMP3710/OASIS"

In [5]:
# Dataset root. Set the OASIS_ROOT environment variable to point elsewhere
# (e.g. a mounted Google Drive folder) without editing this cell.
# Raw string: in a plain literal "\C" and "\O" are invalid escape sequences
# (a SyntaxWarning on recent Python versions).
base_path = os.environ.get("OASIS_ROOT", r"D:\COMP3710\OASIS")
IMG_SIZE = 128
NUM_CLASSES = 4

# Seed the RNGs used in this notebook: Python `random` (flip augmentation),
# NumPy, and torch (weight init, DataLoader shuffling), so reruns are comparable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
def _make_oasis_split(image_subdir, mask_subdir, augment):
    """Build one OASIS split from an image folder and a mask folder under ``base_path``."""
    return OASIS2DPNGDataset(
        os.path.join(base_path, image_subdir),
        os.path.join(base_path, mask_subdir),
        img_size=IMG_SIZE,
        augment=augment,
    )


# Only the training split gets random-flip augmentation.
train_dataset = _make_oasis_split("keras_png_slices_train", "keras_png_slices_seg_train", augment=True)
val_dataset = _make_oasis_split("keras_png_slices_validate", "keras_png_slices_seg_validate", augment=False)
test_dataset = _make_oasis_split("keras_png_slices_test", "keras_png_slices_seg_test", augment=False)

if __name__ == "__main__":
    # Loaders run in the main process (default num_workers=0). Optional speed-up:
    # num_workers=4 and pin_memory=True (faster CPU->GPU transfer).
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)  # reshuffled each epoch
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    print(f"Train/Val/Test samples: {len(train_dataset)}/{len(val_dataset)}/{len(test_dataset)}")

Train/Val/Test samples: 9664/1120/544


# Model (UNet2D)

In [7]:
import torch.nn as nn
class UNet2D(nn.Module):
    """Four-level 2D U-Net for semantic segmentation.

    Encoder: a double-conv stem followed by four (2x max-pool -> double-conv)
    stages with channel widths base, 2*base, 4*base, 8*base, 16*base.
    Decoder: four (2x bilinear upsample -> concat skip -> double-conv) stages
    mirroring the encoder, then a 1x1 conv to ``n_classes`` logits.
    A "double conv" is (3x3 conv, no bias -> BatchNorm -> act_layer) applied twice.

    Upsampled maps are padded (or cropped, via negative padding) to their skip
    connection's size, so the output has the same spatial size as the input:
    (N, in_c, H, W) -> (N, n_classes, H, W).
    """
    def __init__(self, in_c=1, n_classes=4, base=32, act_layer=nn.SiLU):
        super().__init__()
        w1, w2, w3, w4, w5 = (base * m for m in (1, 2, 4, 8, 16))

        def conv_block(c_in, c_out):
            # Two (conv -> BN -> activation) triples; the first maps c_in -> c_out.
            layers = []
            for src in (c_in, c_out):
                layers += [nn.Conv2d(src, c_out, 3, padding=1, bias=False),
                           nn.BatchNorm2d(c_out),
                           act_layer()]
            return nn.Sequential(*layers)

        def pooled(c_in, c_out):
            return nn.Sequential(nn.MaxPool2d(2), conv_block(c_in, c_out))

        def upsampler():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Encoder (attribute names and creation order fix the state_dict keys).
        self.inc = conv_block(in_c, w1)
        self.down1 = pooled(w1, w2)
        self.down2 = pooled(w2, w3)
        self.down3 = pooled(w3, w4)
        self.bot = pooled(w4, w5)
        # Decoder: each conv consumes upsampled features + the matching skip.
        self.up1 = upsampler()
        self.c1 = conv_block(w5 + w4, w4)
        self.up2 = upsampler()
        self.c2 = conv_block(w4 + w3, w3)
        self.up3 = upsampler()
        self.c3 = conv_block(w3 + w2, w2)
        self.up4 = upsampler()
        self.c4 = conv_block(w2 + w1, w1)
        self.outc = nn.Conv2d(w1, n_classes, 1)

    def _cat(self, up, skip):
        """Pad/crop ``up`` to ``skip``'s H, W and concatenate as [skip, up] on channels."""
        dy = skip.size(2) - up.size(2)
        dx = skip.size(3) - up.size(3)
        left, top = dx // 2, dy // 2
        up = F.pad(up, [left, dx - left, top, dy - top])
        return torch.cat([skip, up], dim=1)

    def forward(self, x):
        skips = []
        for stage in (self.inc, self.down1, self.down2, self.down3):
            x = stage(x)
            skips.append(x)
        x = self.bot(x)
        decoder = ((self.up1, self.c1), (self.up2, self.c2),
                   (self.up3, self.c3), (self.up4, self.c4))
        # Deepest skip pairs with the first decoder stage.
        for (up, conv), skip in zip(decoder, reversed(skips)):
            x = conv(self._cat(up(x), skip))
        return self.outc(x)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = UNet2D(in_c=1, n_classes=NUM_CLASSES)
model = model.to(device)
n_params = sum(t.numel() for t in model.parameters())
print("Model params (M):", n_params / 1e6)

Model params (M): 7.849124


# Loss, Optimizer, Scaler

In [11]:
def dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6):
    """
    Soft Dice loss, averaged over channels.

    logits: (N, C, H, W) raw scores.
    target:
      - C == 1 (binary): (N, H, W) or (N, 1, H, W) holding 0/1.
      - C > 1 (multi-class): (N, H, W) or (N, 1, H, W) class indices in 0..C-1.
    Probabilities come from a sigmoid (C == 1) or a softmax over channels
    (C > 1). Overlap and cardinality are pooled over the batch and spatial
    axes, giving one Dice value per channel. Returns ``1 - mean(dice)`` as a
    0-dim tensor.
    """
    assert logits.dim() == 4, "logits 应为 (N,C,H,W)"
    _, n_channels, _, _ = logits.shape

    # Accept a singleton channel axis on the target.
    if target.dim() == 4 and target.shape[1] == 1:
        target = target.squeeze(1)
    target = target.long()

    if n_channels == 1:
        probs = torch.sigmoid(logits)
        reference = target.float().unsqueeze(1)  # (N,1,H,W)
    else:
        probs = F.softmax(logits, dim=1)
        reference = F.one_hot(target, num_classes=n_channels).permute(0, 3, 1, 2).float()  # (N,C,H,W)

    reduce_axes = (0, 2, 3)
    overlap = (probs * reference).sum(dim=reduce_axes)                          # (C,)
    cardinality = probs.sum(dim=reduce_axes) + reference.sum(dim=reduce_axes)   # (C,)
    per_channel = (2 * overlap + eps) / (cardinality + eps)
    return 1 - per_channel.mean()

@torch.no_grad()
def dice_score(logits: torch.Tensor, targets: torch.Tensor, exclude_bg: bool = True, eps: float = 1e-6) -> float:
    """
    Hard Dice for one batch, returned as a Python float.

    - C == 1: threshold sigmoid at 0.5, compute Dice per sample, then average.
    - C > 1: argmax over channels, one-hot both prediction and target, compute
      one Dice per class pooled over the whole batch, then average the classes
      (dropping class 0 when ``exclude_bg`` is True).
    ``targets`` may be (N, H, W) or (N, 1, H, W).
    """
    _, n_cls, _, _ = logits.shape
    if targets.dim() == 4 and targets.shape[1] == 1:
        targets = targets.squeeze(1)
    targets = targets.long()

    if n_cls == 1:
        hard = (torch.sigmoid(logits) > 0.5).float()
        truth = targets.float().unsqueeze(1)
        per_sample_axes = (1, 2, 3)
        overlap = (hard * truth).sum(dim=per_sample_axes)
        total = hard.sum(dim=per_sample_axes) + truth.sum(dim=per_sample_axes)
        return ((2 * overlap + eps) / (total + eps)).mean().item()

    def to_onehot(labels):
        return F.one_hot(labels, num_classes=n_cls).permute(0, 3, 1, 2).float()

    predicted = F.softmax(logits, dim=1).argmax(dim=1)  # (N,H,W)
    pred_1h, true_1h = to_onehot(predicted), to_onehot(targets)

    pooled_axes = (0, 2, 3)
    overlap = (pred_1h * true_1h).sum(dim=pooled_axes)                          # (C,)
    total = pred_1h.sum(dim=pooled_axes) + true_1h.sum(dim=pooled_axes)         # (C,)
    per_class = (2 * overlap + eps) / (total + eps)
    if exclude_bg and n_cls > 1:
        per_class = per_class[1:]
    return per_class.mean().item()

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
# torch.amp.GradScaler("cuda", ...) replaces the deprecated torch.cuda.amp.GradScaler.
# Loss scaling is enabled only when a CUDA device is available.
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

  scaler = GradScaler(enabled=torch.cuda.is_available())


# Training & Validation Loop

In [12]:
from tqdm import tqdm  # 用于显示进度条

NUM_EPOCHS = 20
use_amp = torch.cuda.is_available()  # mixed precision only when a GPU is present

for epoch in range(1, NUM_EPOCHS+1):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    # tqdm progress bar over batches; leave=False clears it when the epoch ends.
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Train]", leave=False)
    for imgs, masks in loop:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits = model(imgs)
            loss = dice_loss(logits, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once: each .item() forces a device->host sync.
        batch_loss = loss.item()
        train_loss += batch_loss * imgs.size(0)

        # Live per-batch loss on the progress bar.
        loop.set_postfix({"batch_loss": batch_loss})

    # Sample-weighted mean loss over the training set.
    train_loss /= len(train_dataset)

    # ---- Validation ----
    model.eval()
    val_loss, val_dice = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            batch_loss = dice_loss(logits, masks).item()
            val_loss += batch_loss * imgs.size(0)
            # Per-batch Dice weighted by batch size (an approximation of dataset-level Dice).
            val_dice += dice_score(logits, masks) * imgs.size(0)
    val_loss /= len(val_dataset)
    val_dice /= len(val_dataset)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Val Dice: {val_dice:.4f}")


  with autocast(enabled=torch.cuda.is_available()):
                                                                                          

Epoch 1/20 | Train Loss: 0.0560 | Val Loss: 0.0497 | Val Dice: 0.9392


                                                                                          

Epoch 2/20 | Train Loss: 0.0468 | Val Loss: 0.0509 | Val Dice: 0.9351


                                                                                          

Epoch 3/20 | Train Loss: 0.0436 | Val Loss: 0.0479 | Val Dice: 0.9383


                                                                                          

Epoch 4/20 | Train Loss: 0.0409 | Val Loss: 0.0522 | Val Dice: 0.9320


                                                                                          

Epoch 5/20 | Train Loss: 0.0392 | Val Loss: 0.0438 | Val Dice: 0.9433


                                                                                          

Epoch 6/20 | Train Loss: 0.0378 | Val Loss: 0.0427 | Val Dice: 0.9446


                                                                                          

Epoch 7/20 | Train Loss: 0.0363 | Val Loss: 0.0400 | Val Dice: 0.9481


                                                                                          

Epoch 8/20 | Train Loss: 0.0357 | Val Loss: 0.0387 | Val Dice: 0.9497


                                                                                          

Epoch 9/20 | Train Loss: 0.0350 | Val Loss: 0.0401 | Val Dice: 0.9476


                                                                                           

Epoch 10/20 | Train Loss: 0.0342 | Val Loss: 0.0448 | Val Dice: 0.9413


                                                                                           

Epoch 11/20 | Train Loss: 0.0337 | Val Loss: 0.0408 | Val Dice: 0.9466


                                                                                           

Epoch 12/20 | Train Loss: 0.0332 | Val Loss: 0.0382 | Val Dice: 0.9501


                                                                                           

Epoch 13/20 | Train Loss: 0.0328 | Val Loss: 0.0378 | Val Dice: 0.9505


                                                                                           

Epoch 14/20 | Train Loss: 0.0324 | Val Loss: 0.0375 | Val Dice: 0.9508


                                                                                           

Epoch 15/20 | Train Loss: 0.0318 | Val Loss: 0.0363 | Val Dice: 0.9524


                                                                                           

Epoch 16/20 | Train Loss: 0.0317 | Val Loss: 0.0374 | Val Dice: 0.9509


                                                                                           

Epoch 17/20 | Train Loss: 0.0314 | Val Loss: 0.0377 | Val Dice: 0.9505


                                                                                           

Epoch 18/20 | Train Loss: 0.0311 | Val Loss: 0.0367 | Val Dice: 0.9519


                                                                                           

Epoch 19/20 | Train Loss: 0.0310 | Val Loss: 0.0368 | Val Dice: 0.9517


                                                                                           

Epoch 20/20 | Train Loss: 0.0306 | Val Loss: 0.0375 | Val Dice: 0.9508


# Test Evaluation (per-class Dice)

In [17]:
import torch
from torch import amp as torch_amp

CLASS_NAMES = ["bg", "class1", "class2", "class3"]  # display names; indices past the list print as "class{i}"
AMP_DTYPE = torch.float16  # autocast dtype used when evaluating on CUDA

@torch.no_grad()
def evaluate_per_class_dice(model, loader, num_classes: int,
                            exclude_bg: bool = True,
                            device: torch.device = None,
                            amp_enabled: bool = True,
                            pass_threshold: float = 0.90):
    """
    Dataset-level per-class Dice from a confusion matrix accumulated over ``loader``.

    Args:
        model: segmentation model mapping images to (N, C, H, W) logits.
        loader: iterable of ``(imgs, masks)``; masks are (N, H, W) or
            (N, 1, H, W) integer labels in ``[0, num_classes - 1]``.
        num_classes: number of classes C (class 0 is treated as background).
        exclude_bg: if True (and C > 1), the PASS/FAIL verdict considers only
            the foreground classes 1..C-1; otherwise every class, background included.
        device: where to run; defaults to the device of the model's first parameter.
        amp_enabled: use float16 autocast (only takes effect on a CUDA device).
        pass_threshold: minimum Dice every judged class must reach to PASS.

    Prints per-class tables (with and without background) and the verdict.

    Returns:
        (per-class Dice list including background, mean Dice including background).

    Raises:
        ValueError: if a mask label (or predicted class) is >= ``num_classes``.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    n_bins = num_classes * num_classes
    use_amp = amp_enabled and device.type == "cuda"

    # cm[t, p] counts pixels whose ground truth is t and prediction is p.
    cm = torch.zeros((num_classes, num_classes), dtype=torch.long, device=device)

    for imgs, masks in loader:
        imgs = imgs.to(device, non_blocking=True)
        if masks.ndim == 4 and masks.shape[1] == 1:
            masks = masks.squeeze(1)
        masks = masks.to(device, non_blocking=True).long()  # (N,H,W)

        with torch_amp.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=use_amp):
            logits = model(imgs)
            preds = logits.argmax(dim=1)  # (N,H,W)

        # Encode each (truth, prediction) pair as the single bin index t*C + p.
        k = (masks * num_classes + preds).reshape(-1)
        counts = torch.bincount(k, minlength=n_bins)
        if counts.numel() != n_bins:
            # bincount only grows past C*C bins when some label/prediction is >= C.
            raise ValueError(
                f"labels or predicted classes outside [0, {num_classes - 1}]: "
                f"check num_classes against the mask values and model output channels"
            )
        cm += counts.view(num_classes, num_classes)

    TP = cm.diag().to(torch.float32)
    FP = cm.sum(0).to(torch.float32) - TP  # predicted as c but truth differs
    FN = cm.sum(1).to(torch.float32) - TP  # truth is c but predicted otherwise
    dice_c = (2*TP) / (2*TP + FP + FN + 1e-6)  # (C,)

    def print_table(dice_vals, names, title):
        print("\n" + title)
        print("-"*len(title))
        for name, v in zip(names, dice_vals):
            print(f"{name:>8}: Dice = {float(v):.4f}")
        print(f"{'MEAN':>8}: Dice = {float(torch.tensor(dice_vals).mean()):.4f}")

    # Including background
    names_all = [CLASS_NAMES[i] if i < len(CLASS_NAMES) else f"class{i}" for i in range(num_classes)]
    print_table(dice_c.tolist(), names_all, "Per-class Dice (including background)")

    # Excluding background (reported whenever there is a foreground class)
    if num_classes > 1:
        print_table(dice_c[1:].tolist(), names_all[1:], "Per-class Dice (excluding background)")

    # Verdict: honor exclude_bg when choosing which classes must reach the threshold.
    if exclude_bg and num_classes > 1:
        judged, tag = dice_c[1:], " (exclude bg)"
    else:
        judged, tag = dice_c, ""
    passed = bool((judged >= pass_threshold).all().item())
    print("\n==> " + ("PASS ✅" if passed else "FAIL ❌") + tag,
          f"| min Dice = {float(judged.min()):.4f}")

    return dice_c.cpu().tolist(), float(dice_c.mean().cpu())


In [18]:
from torch import amp as torch_amp  # 确保已导入
import torch

# Evaluate the trained model on the test split and report per-class Dice.
# Run the cell that defines evaluate_per_class_dice first.
try:
    per_class, mean_dice = evaluate_per_class_dice(
        model,
        test_loader,
        num_classes=NUM_CLASSES,
        exclude_bg=True,        # pass/fail is judged on foreground classes only
        device=device,
        amp_enabled=True,
    )
    print("\n[CALL RESULT]", flush=True)
    print("per_class (including bg):", per_class, flush=True)
    print("mean_dice (including bg):", mean_dice, flush=True)

    # Independent pass/fail check on the foreground classes (background excluded).
    if per_class[1:]:
        fg = torch.tensor(per_class[1:], dtype=torch.float32)
        passed = bool(fg.ge(0.90).all())
        min_fg = float(fg.min())
        print("PASS(exclude bg):", passed, "| min_fg_dice =", min_fg, flush=True)

except Exception as e:
    # Best-effort: report the failure with a traceback instead of stopping the notebook.
    import sys
    import traceback
    print("ERROR while evaluating:", e, file=sys.stderr, flush=True)
    traceback.print_exc()



Per-class Dice (including background)
-------------------------------------
      bg: Dice = 0.9988
  class1: Dice = 0.9444
  class2: Dice = 0.9514
  class3: Dice = 0.9727
    MEAN: Dice = 0.9668

Per-class Dice (excluding background)
-------------------------------------
  class1: Dice = 0.9444
  class2: Dice = 0.9514
  class3: Dice = 0.9727
    MEAN: Dice = 0.9562

==> PASS ✅ (exclude bg) | min Dice = 0.9444

[CALL RESULT]
per_class (including bg): [0.9988069534301758, 0.9444249868392944, 0.951388418674469, 0.9727374315261841]
mean_dice (including bg): 0.9668394327163696
PASS(exclude bg): True | min_fg_dice = 0.9444249868392944
