# Environment Check

In [None]:
import torch, sys, platform, subprocess, os

# --- PyTorch build and CUDA visibility -------------------------------------
cuda_ok = torch.cuda.is_available()
print("torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("cuda.is_available:", cuda_ok)
print("device_count:", torch.cuda.device_count())
if cuda_ok:
    # Only query device 0 when CUDA is actually usable.
    gpu_name = torch.cuda.get_device_name(0)
    gpu_capability = torch.cuda.get_device_capability(0)
    print("device 0:", gpu_name, "capability:", gpu_capability)

# --- Driver check (best effort) --------------------------------------------
# Any failure here (missing binary, broken driver, empty output) is reported
# rather than raised, so the rest of the cell still runs.
try:
    smi_bytes = subprocess.check_output("nvidia-smi", shell=True)
    out = smi_bytes.decode().splitlines()[0]  # first output line only
    print("\n=== nvidia-smi ===\n", out)
except Exception as err:
    print("nvidia-smi not found / driver not working:", err)

# --- Interpreter and OS ------------------------------------------------------
print("python:", sys.executable)
print("platform:", platform.platform())

# Initialize Dataset

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torch.cuda.amp import autocast, GradScaler

In [None]:
import os
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

NUM_CLASSES = 4  # background + 3 brain-tissue classes
IMG_SIZE = 128


class OASIS2DPNGDataset(Dataset):
    """OASIS PNG 2D brain-slice dataset for segmentation.

    Images and masks are paired by position after sorting the file names of
    ``image_dir`` and ``mask_dir`` independently, so both directories must
    contain the same number of files in matching sort order.

    Args:
        image_dir: Directory containing the input slice images.
        mask_dir: Directory containing the segmentation masks.
        img_size: Side length both image and mask are resized to (square).
        augment: When True, apply random horizontal/vertical flips, using the
            same flip for the image and its mask.

    Raises:
        ValueError: If the two directories contain a different number of
            files (index-based pairing would otherwise mis-align samples).
    """

    def __init__(self, image_dir, mask_dir, img_size=IMG_SIZE, augment=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
        # Pairing is purely positional, so a count mismatch means at least one
        # image would be matched with the wrong mask (or hit an IndexError
        # much later, in the middle of training).
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.image_files)} files in "
                f"{image_dir!r} vs {len(self.mask_files)} files in {mask_dir!r}"
            )
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the sample at ``idx``.

        Returns:
            Tuple ``(img, mask)``: ``img`` is a float tensor ``[1, H, W]``
            produced by ``TF.to_tensor``; ``mask`` is a long tensor ``[H, W]``
            of class indices.
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # Load both as single-channel 8-bit grayscale.
        img = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")

        # Resize; nearest-neighbour for the mask so no new label values are
        # invented by interpolation.
        img = TF.resize(img, [self.img_size, self.img_size])
        mask = TF.resize(mask, [self.img_size, self.img_size], interpolation=TF.InterpolationMode.NEAREST)

        # To tensor
        img = TF.to_tensor(img)  # [1, H, W], float32
        mask = np.array(mask, dtype=np.uint8)

        # Map mask gray levels to class indices by binning the 0..255 range
        # into NUM_CLASSES equal-width bins.
        # NOTE(review): assumes mask gray levels are spread so that each class
        # falls in its own bin (e.g. 0/85/170/255 for 4 classes) - confirm
        # against the data.
        mask = mask // (256 // NUM_CLASSES)
        mask = torch.tensor(mask, dtype=torch.long)  # [H, W]

        # Data augmentation: identical flip applied to image and mask.
        if self.augment:
            if random.random() < 0.5:
                img = TF.hflip(img)
                mask = TF.hflip(mask)
            if random.random() < 0.5:
                img = TF.vflip(img)
                mask = TF.vflip(mask)

        return img, mask

# Create Dataset (DataLoader)

In [None]:
import sys

# Dataset root. Set the OASIS_DIR environment variable to point elsewhere;
# the default is the original local path.
base_path = os.environ.get("OASIS_DIR", r"C:\COMP3710\OASIS")
IMG_SIZE = 128
NUM_CLASSES = 4

# Only the training split is augmented (random flips).
train_dataset = OASIS2DPNGDataset(
    os.path.join(base_path, "keras_png_slices_train"),
    os.path.join(base_path, "keras_png_slices_seg_train"),
    img_size=IMG_SIZE,
    augment=True
)
val_dataset = OASIS2DPNGDataset(
    os.path.join(base_path, "keras_png_slices_validate"),
    os.path.join(base_path, "keras_png_slices_seg_validate"),
    img_size=IMG_SIZE,
    augment=False
)
test_dataset = OASIS2DPNGDataset(
    os.path.join(base_path, "keras_png_slices_test"),
    os.path.join(base_path, "keras_png_slices_seg_test"),
    img_size=IMG_SIZE,
    augment=False
)
# The guard keeps DataLoader worker processes from re-running this section
# when the code is executed as a script on spawn-based platforms.
if __name__ == "__main__":
    # Worker processes on Windows/macOS are started with "spawn" and must
    # unpickle the dataset, which requires importing its class. A class
    # defined in a notebook cell exists only in the kernel process, so workers
    # fail to load it; fall back to in-process loading in that situation.
    spawn_platform = sys.platform in ("win32", "darwin")
    in_notebook = "ipykernel" in sys.modules
    train_workers = 0 if (spawn_platform and in_notebook) else 4

    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=train_workers,
        # Pinned host memory only helps (and only avoids a warning) with CUDA.
        pin_memory=torch.cuda.is_available(),
    )
    val_loader   = DataLoader(val_dataset, batch_size=8, shuffle=False)
    test_loader  = DataLoader(test_dataset, batch_size=8, shuffle=False)

    print(f"Train/Val/Test samples: {len(train_dataset)}/{len(val_dataset)}/{len(test_dataset)}")

# Model (UNet2D)

In [None]:
import torch.nn as nn
class UNet2D(nn.Module):
    """2D U-Net with four down-sampling stages and bilinear up-sampling.

    Encoder widths are ``base, 2*base, 4*base, 8*base`` with a ``16*base``
    bottleneck. Each decoder stage upsamples by 2, pads the result to the
    skip tensor's spatial size, concatenates ``[skip, upsampled]`` on the
    channel axis and applies a double convolution.

    Args:
        in_c: Number of input channels.
        n_classes: Number of output channels (one logit map per class).
        base: Channel width of the first encoder stage.
        act_layer: Zero-argument callable producing the activation module.
    """

    def __init__(self, in_c=1, n_classes=4, base=32, act_layer=nn.SiLU):
        super().__init__()

        def conv_bn_act(ch_in, ch_out):
            # 3x3 conv (bias-free, since BatchNorm follows) -> BN -> activation
            return [
                nn.Conv2d(ch_in, ch_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(ch_out),
                act_layer(),
            ]

        def double_conv(ch_in, ch_out):
            return nn.Sequential(*conv_bn_act(ch_in, ch_out), *conv_bn_act(ch_out, ch_out))

        def down(ch_in, ch_out):
            # Halve the resolution, then widen the channels.
            return nn.Sequential(nn.MaxPool2d(2), double_conv(ch_in, ch_out))

        def upsample():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        w1, w2, w4, w8, w16 = base, base * 2, base * 4, base * 8, base * 16

        # Encoder
        self.inc = double_conv(in_c, w1)
        self.down1 = down(w1, w2)
        self.down2 = down(w2, w4)
        self.down3 = down(w4, w8)
        self.bot = down(w8, w16)

        # Decoder: each conv consumes (upsampled + skip) channels.
        self.up1 = upsample()
        self.c1 = double_conv(w16 + w8, w8)
        self.up2 = upsample()
        self.c2 = double_conv(w8 + w4, w4)
        self.up3 = upsample()
        self.c3 = double_conv(w4 + w2, w2)
        self.up4 = upsample()
        self.c4 = double_conv(w2 + w1, w1)

        # 1x1 projection to per-class logits.
        self.outc = nn.Conv2d(w1, n_classes, 1)

    def _cat(self, up, skip):
        """Pad ``up`` to ``skip``'s height/width and concatenate (skip first)."""
        gap_h = skip.shape[2] - up.shape[2]
        gap_w = skip.shape[3] - up.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad order is (left, right, top, bottom) for the last two dims.
        padded = F.pad(up, (left, gap_w - left, top, gap_h - top))
        return torch.cat((skip, padded), dim=1)

    def forward(self, x):
        """Return per-class logits with the same spatial size as ``x``."""
        feat = self.inc(x)

        # Encoder: remember each stage's input as the skip for the decoder.
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.bot):
            skips.append(feat)
            feat = stage(feat)

        # Decoder: consume skips from deepest to shallowest.
        decoder = (
            (self.up1, self.c1),
            (self.up2, self.c2),
            (self.up3, self.c3),
            (self.up4, self.c4),
        )
        for up, fuse in decoder:
            feat = fuse(self._cat(up(feat), skips.pop()))

        return self.outc(feat)

# Use CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single-channel input, one output logit map per class.
model = UNet2D(in_c=1, n_classes=NUM_CLASSES)
model = model.to(device)

# Total parameter count, reported in millions.
print("Model params (M):", sum(weights.numel() for weights in model.parameters()) / 1e6)

# Loss, Optimizer, Scaler

In [None]:
def dice_loss(logits, target, eps=1e-6):
    """Soft multi-class Dice loss.

    Softmax probabilities are compared with the one-hot target; overlap and
    totals are accumulated over the batch and both spatial axes, giving one
    Dice coefficient per class. The loss is one minus their mean.

    Args:
        logits: Raw class scores, shape ``[N, C, H, W]``.
        target: Integer class indices, shape ``[N, H, W]``.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean_c(dice_c)``.
    """
    class_probs = F.softmax(logits, dim=1)
    n_classes = class_probs.shape[1]

    # [N, H, W] -> [N, H, W, C] -> [N, C, H, W]
    target_onehot = F.one_hot(target, n_classes).permute(0, 3, 1, 2).float()

    reduce_axes = (0, 2, 3)  # everything except the class axis
    overlap = (class_probs * target_onehot).sum(dim=reduce_axes)
    total = class_probs.sum(dim=reduce_axes) + target_onehot.sum(dim=reduce_axes)

    per_class_dice = (2 * overlap + eps) / (total + eps)
    return 1 - per_class_dice.mean()

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=3e-4,
    weight_decay=1e-5,
)

# Gradient scaler for mixed-precision training; enabled only when CUDA is available.
scaler = GradScaler(enabled=torch.cuda.is_available())

# Training & Validation Loop

In [None]:
from tqdm import tqdm  # 用于显示进度条

NUM_EPOCHS = 20


def dice_score(logits, target, eps=1e-6):
    """Mean per-class Dice of the argmax prediction against ``target``.

    Uses the same reduction as ``dice_loss`` (batch and spatial axes summed,
    classes averaged, ``eps`` smoothing) but on hard predictions.

    Args:
        logits: Raw class scores, shape ``[N, C, H, W]``.
        target: Integer class indices, shape ``[N, H, W]``.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Python float in ``[0, 1]``.
    """
    n_classes = logits.shape[1]
    pred_1h = F.one_hot(logits.argmax(dim=1), n_classes).permute(0, 3, 1, 2).float()
    tgt_1h = F.one_hot(target, n_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    inter = torch.sum(pred_1h * tgt_1h, dims)
    total = torch.sum(pred_1h, dims) + torch.sum(tgt_1h, dims)
    return ((2 * inter + eps) / (total + eps)).mean().item()


for epoch in range(1, NUM_EPOCHS+1):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    # tqdm shows a per-batch progress bar for the epoch
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Train]", leave=False)
    for imgs, masks in loop:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass (active only when CUDA is available)
        with autocast(enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = dice_loss(logits, masks)
        # Scale the loss before backward, then step and update the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss value once per step
        step_loss = loss.item()
        # Weight by batch size so the epoch average is per-sample
        train_loss += step_loss * imgs.size(0)

        # Show the current batch loss on the progress bar
        loop.set_postfix({"batch_loss": step_loss})

    train_loss /= len(train_dataset)

    # ---- Validation ----
    model.eval()
    val_loss, val_dice = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            batch_loss = dice_loss(logits, masks).item()
            val_loss += batch_loss * imgs.size(0)
            val_dice += dice_score(logits, masks) * imgs.size(0)
    val_loss /= len(val_dataset)
    val_dice /= len(val_dataset)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Val Dice: {val_dice:.4f}")


# Test & Visualization

In [None]:
def visualize_predictions(model, loader, n=3):
    """Plot image / ground-truth mask / predicted mask for the first batch.

    Switches ``model`` to eval mode, runs it on the first batch from
    ``loader`` and shows one three-panel figure per sample, for at most ``n``
    samples (fewer when the batch is smaller).

    Args:
        model: Segmentation network returning per-class logits.
        loader: Iterable yielding ``(images, masks)`` batches.
        n: Maximum number of samples to display.
    """
    model.eval()

    batch_imgs, batch_masks = next(iter(loader))
    batch_imgs = batch_imgs.to(device)
    batch_masks = batch_masks.to(device)

    with torch.no_grad():
        batch_preds = model(batch_imgs).argmax(dim=1)

    img_np = batch_imgs.cpu().numpy()
    mask_np = batch_masks.cpu().numpy()
    pred_np = batch_preds.cpu().numpy()

    # Shared colour scale so mask and prediction use identical label colours.
    label_style = dict(cmap='jet', vmin=0, vmax=NUM_CLASSES-1)

    n_shown = min(n, img_np.shape[0])
    for row in range(n_shown):
        _, (ax_img, ax_mask, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))

        ax_img.imshow(img_np[row, 0], cmap='gray')
        ax_img.set_title("Image")

        ax_mask.imshow(mask_np[row], **label_style)
        ax_mask.set_title("Mask")

        ax_pred.imshow(pred_np[row], **label_style)
        ax_pred.set_title("Prediction")

        plt.show()


# Visualize the first 3 test samples
visualize_predictions(model, test_loader, n=3)