In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Using device:", device)

class down(nn.Module):
    """Encoder stage: two (3x3 conv -> BatchNorm -> ReLU) layers mapping ``din`` to ``dout`` channels.

    padding=1 keeps the spatial size unchanged. The six layers live in ``self.down`` (an
    nn.Sequential), so state_dict keys are ``down.0.*`` ... ``down.4.*``.
    """

    def __init__(self, din, dout):
        super().__init__()
        layers = []
        # First conv maps din -> dout, second maps dout -> dout.
        for in_ch in (din, dout):
            layers.extend([
                nn.Conv2d(in_ch, dout, kernel_size=3, padding=1),
                nn.BatchNorm2d(dout),
                nn.ReLU(),
            ])
        self.down = nn.Sequential(*layers)

    def forward(self, x):
        out = self.down(x)
        return out
    
class up(nn.Module):
    """Decoder stage: two (3x3 conv -> BatchNorm -> ReLU) layers mapping ``din`` to ``dout`` channels.

    In ``unet`` the input is the concatenation of a skip connection and the upsampled path.
    Layers are stored in ``self.up`` (an nn.Sequential), giving state_dict keys ``up.0.*`` ... ``up.4.*``.
    """

    def __init__(self, din, dout):
        super().__init__()

        def conv_bn_relu(c_in):
            # One 3x3 conv (spatial size preserved) followed by BatchNorm and ReLU.
            return [nn.Conv2d(c_in, dout, kernel_size=3, padding=1), nn.BatchNorm2d(dout), nn.ReLU()]

        self.up = nn.Sequential(*conv_bn_relu(din), *conv_bn_relu(dout))

    def forward(self, x):
        y = self.up(x)
        return y

class unet(nn.Module):
    """Encoder-decoder U-Net producing per-pixel class logits.

    Five ``down`` stages (widths 64, 128, 256, 512, 1024, each times ``scale``) separated
    by 2x2 max-pooling. Four decoder steps follow. Each one upsamples 2x (nn.Upsample,
    default mode) and halves the channels with a 3x3 conv. It then concatenates the
    matching encoder feature map (skip connection) and applies an ``up`` stage.

    Args:
        in_channels: channels of the input image (default 3).
        num_classes: channels of the output logits (default 4).
        scale: integer multiplier applied to every stage width (default 1).

    Input height and width must be multiples of 16 (four 2x poolings must line up with the
    skip connections); ``forward`` raises ValueError otherwise.
    Output: (N, num_classes, H, W) raw logits. No softmax is applied here; the losses in
    this notebook apply it themselves.
    """

    def __init__(self, in_channels=3, num_classes=4, scale=1):
        super().__init__()
        self.scale = scale
        c1, c2, c3, c4, c5 = (self.scale * width for width in (64, 128, 256, 512, 1024))
        # Registration order matches earlier versions, so state_dict keys and the order of
        # parameters() (which optimizer state dicts rely on) are unchanged.
        self.down1 = down(in_channels, c1)
        self.down2 = down(c1, c2)
        self.down3 = down(c2, c3)
        self.down4 = down(c3, c4)
        self.down5 = down(c4, c5)
        self.up1 = up(c5, c4)
        self.up2 = up(c4, c3)
        self.up3 = up(c3, c2)
        self.up4 = up(c2, c1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2)
        self.upconv1 = nn.Conv2d(c5, c4, kernel_size=3, padding=1)
        self.upconv2 = nn.Conv2d(c4, c3, kernel_size=3, padding=1)
        self.upconv3 = nn.Conv2d(c3, c2, kernel_size=3, padding=1)
        self.upconv4 = nn.Conv2d(c2, c1, kernel_size=3, padding=1)
        # Kept as a Sequential so the checkpoint key stays "output.0.*".
        self.output = nn.Sequential(
            nn.Conv2d(c1, num_classes, kernel_size=1),
        )

    def forward(self, x):
        h, w = x.shape[-2], x.shape[-1]
        # Any other size makes a skip connection and its upsampled partner differ in size.
        if h % 16 or w % 16:
            raise ValueError(f"unet input height and width must be multiples of 16, got {h}x{w}")
        x1 = self.down1(x)
        x2 = self.down2(self.maxpool(x1))
        x3 = self.down3(self.maxpool(x2))
        x4 = self.down4(self.maxpool(x3))
        x5 = self.down5(self.maxpool(x4))
        x = self.up1(torch.cat([x4, self.upconv1(self.upsample(x5))], dim=1))
        x = self.up2(torch.cat([x3, self.upconv2(self.upsample(x))], dim=1))
        x = self.up3(torch.cat([x2, self.upconv3(self.upsample(x))], dim=1))
        x = self.up4(torch.cat([x1, self.upconv4(self.upsample(x))], dim=1))
        return self.output(x)

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, PILToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


import os
from torchvision.io import decode_image

# target_batch_size: effective batch per optimizer step (reached by gradient accumulation).
# batch_size: micro-batch actually loaded per DataLoader iteration.
target_batch_size, batch_size = 64, 4

class dataset(Dataset):
    """Segmentation dataset pairing ``img_dir/<name><img_ext>`` with ``label_dir/<name><label_ext>``.

    Only entries of ``img_dir`` whose extension equals ``img_ext`` (case-insensitively) are
    indexed. Stray files such as macOS ``.DS_Store`` are therefore skipped instead of
    becoming bogus samples that fail on first access. The real image filename is kept, so
    ``foo.JPG`` is opened as ``foo.JPG`` rather than ``foo.jpg``.

    Each item is ``(img, label)``:
        img: the decoded image converted to float and divided by 255.
        label: the decoded label image, not converted.
    Optional ``transform`` / ``target_transform`` are applied to each, respectively.
    """

    def __init__(self, img_dir, label_dir, transform=None, target_transform=None,
                 img_ext=".jpg", label_ext=".png"):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_ext = img_ext
        self.label_ext = label_ext
        # (stem, actual filename) pairs, sorted by stem for a deterministic order.
        entries = sorted(
            (os.path.splitext(filename)[0], filename)
            for filename in os.listdir(img_dir)
            if os.path.splitext(filename)[1].lower() == img_ext.lower()
        )
        self.img_names = [stem for stem, _ in entries]
        self._img_files = [filename for _, filename in entries]
        self.len = len(self.img_names)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img = decode_image(os.path.join(self.img_dir, self._img_files[idx])).float() / 255
        label = decode_image(os.path.join(self.label_dir, self.img_names[idx] + self.label_ext))

        if self.transform:
            img = self.transform(img)

        if self.target_transform:
            label = self.target_transform(label)

        return img, label
    
def display_img_label(data, idx):
    """Plot sample ``idx`` of ``data``: the image on the left, its label map on the right.

    ``data[idx]`` must return ``(img, label)`` as channel-first (C, H, W) tensors, as the
    ``dataset`` class in this notebook does. The label is assumed to be single-channel
    (NOTE(review): confirm for the label PNGs in use).
    """
    img, label = data[idx]
    fig, (ax_img, ax_label) = plt.subplots(1, 2, figsize=(10, 5))

    ax_img.imshow(img.permute(1, 2, 0))
    ax_img.set_title(f"Image #{idx}")
    ax_img.axis("off")

    # 'gray' is the canonical Matplotlib colormap name ('grey' is only an alias in newer
    # releases). Nearest interpolation keeps class boundaries crisp instead of blending ids.
    ax_label.imshow(label.permute(1, 2, 0), cmap="gray", interpolation="nearest")
    ax_label.set_title(f"Label #{idx}")
    ax_label.axis("off")

    plt.show()

class target_remap:
    """Label transform that rewrites every ``src_value`` pixel to ``dst_value``, in place.

    The defaults (255 -> 3) reproduce the original behaviour. 255 presumably marks
    void/unlabelled pixels in the label PNGs (NOTE(review): confirm), and 3 is the class
    index that can then be excluded via ``ignore_index``.

    Args:
        src_value: label value to replace (default 255).
        dst_value: replacement value (default 3).
    """

    def __init__(self, src_value=255, dst_value=3):
        self.src_value = src_value
        self.dst_value = dst_value

    def __call__(self, img):
        # Mutates and returns the given tensor; the dataset passes a freshly decoded label.
        img[img == self.src_value] = self.dst_value
        return img

def diff_size_collate(batch):
    """Collate samples of differing sizes without stacking.

    Returns ``(imgs, labels)``: two lists holding element 0 and element 1 of every sample, in
    batch order. The default collate would try to stack the tensors, which fails on mixed sizes.
    """
    imgs, labels = [], []
    for sample in batch:
        imgs.append(sample[0])
        labels.append(sample[1])
    return imgs, labels



# target_remap() rewrites label value 255 to class 3 in every split.
training_data = dataset("atrain/color", "atrain/label", target_transform=target_remap())
val_data = dataset("Val/color", "Val/label", target_transform=target_remap())
test_data = dataset("Test/color", "Test/label", target_transform=target_remap())

# Pinned host memory only speeds up host->CUDA copies; on MPS/CPU PyTorch warns and ignores it.
PIN_MEMORY = torch.cuda.is_available()

# The training loader uses the default collate, so all training images must share one size.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=PIN_MEMORY)
# Val/test images may differ in size, so they are returned as lists (diff_size_collate).
# eval_loop aggregates over the whole split, so shuffling gains nothing; a fixed order keeps runs reproducible.
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, pin_memory=PIN_MEMORY, collate_fn=diff_size_collate)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=PIN_MEMORY, collate_fn=diff_size_collate)

## Resize and padding helpers for the evaluation loop

In [None]:
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from tqdm import tqdm

# Same device preference as the first cell (MPS > CUDA > CPU). A CUDA/CPU-only check here
# would silently move work off an Apple-silicon GPU when this cell is re-run.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
use_amp = True  # NOTE(review): not read by the training/eval loops shown; AMP is currently unused.

def resize_with_padding(image, target_size=512):
    """
    Letterbox one (C, H, W) image tensor into a target_size x target_size canvas.

    The image is scaled so its longer side becomes target_size (aspect ratio kept). It is
    then zero-padded (black) on both sides; when the padding is odd, the extra pixel goes
    to the right/bottom. Returns ``(padded_image, meta)``, where ``meta`` holds what
    ``reverse_resize_and_padding`` needs to undo the transform:
        "original_size": (H, W) before resizing
        "new_size": (H, W) after resizing, before padding
        "pad": (left, top, right, bottom)
        "scale": resize factor
    """
    _, orig_h, orig_w = image.shape
    ratio = min(target_size / orig_w, target_size / orig_h)
    new_h, new_w = (int(round(side * ratio)) for side in (orig_h, orig_w))

    # Split the leftover space; an odd leftover pixel lands on the right/bottom.
    spare_w, spare_h = target_size - new_w, target_size - new_h
    left, top = spare_w // 2, spare_h // 2
    right, bottom = spare_w - left, spare_h - top
    border = (left, top, right, bottom)  # TF.pad order: left, top, right, bottom

    padded = TF.pad(TF.resize(image, size=(new_h, new_w)), padding=border, fill=0)
    return padded, {
        "original_size": (orig_h, orig_w),
        "new_size": (new_h, new_w),
        "pad": border,
        "scale": ratio,
    }

def reverse_resize_and_padding(image, meta, interpolation="bilinear"):
    """
    Undo ``resize_with_padding`` for one (C, target_size, target_size) tensor.

    The padding recorded in ``meta["pad"]`` is cropped away. The remaining
    ``meta["new_size"]`` region is then resized back to ``meta["original_size"]``.

    Args:
        image: (C, target_size, target_size) tensor, e.g. per-class logits.
        meta: dict produced by ``resize_with_padding``.
        interpolation: any ``F.interpolate`` mode valid for 4D input. Use "bilinear" for
            continuous outputs and "nearest" (or "nearest-exact") for label maps.

    Returns:
        (C, H_orig, W_orig) tensor.
    """
    pad_left, pad_top, _, _ = meta["pad"]
    new_h, new_w = meta["new_size"]
    orig_h, orig_w = meta["original_size"]

    # Keep only the resized content; drop the padded border.
    image_cropped = image[..., pad_top: pad_top + new_h, pad_left: pad_left + new_w]

    # F.interpolate accepts align_corners only for the linear family of modes and raises for
    # the others (e.g. "area", "nearest-exact"), so pass None there.
    align_corners = False if interpolation in ("linear", "bilinear", "bicubic", "trilinear") else None

    # F.interpolate expects a batch dimension.
    image_original = F.interpolate(image_cropped.unsqueeze(0),
                                   size=(orig_h, orig_w),
                                   mode=interpolation,
                                   align_corners=align_corners)
    return image_original.squeeze(0)

def process_batch_forward(batch_images, target_size=512):
    """
    Letterbox every image of a batch to target_size x target_size and stack the results.

    Args:
        batch_images: iterable of (C, H, W) image tensors. Sizes may differ (e.g. the list
            produced by ``diff_size_collate``); an (N, C, H, W) tensor also works.
        target_size: side length of the square model input.

    Channel handling, so that every image reaches the network with 3 channels:
        4 channels (RGBA)       -> first 3 channels
        2 channels (gray+alpha) -> first channel, then replicated
        1 channel (grayscale)   -> replicated to 3 channels

    Returns:
        (stacked (N, C, target_size, target_size) tensor, list of per-image meta dicts from
        ``resize_with_padding``).
    """
    resized_batch = []
    meta_list = []
    for image in batch_images:
        if image.ndim == 3:
            if image.shape[0] == 4:
                image = image[:3, ...]  # drop alpha, keep R, G, B
            elif image.shape[0] == 2:
                image = image[:1, ...]  # drop alpha, keep gray
            if image.shape[0] == 1:
                image = image.repeat(3, 1, 1)  # gray -> RGB by channel replication
        image_resized, meta = resize_with_padding(image, target_size)
        resized_batch.append(image_resized)
        meta_list.append(meta)
    return torch.stack(resized_batch), meta_list

def process_batch_reverse(batch_outputs, meta_list, interpolation="bilinear"):
    """
    Map a batch of (N, C, target_size, target_size) outputs back to each image's original size.

    Pairs every output with its meta dict (from ``process_batch_forward``) and applies
    ``reverse_resize_and_padding``. Returns a list, because the restored outputs generally
    differ in spatial size and cannot be stacked.
    """
    return [
        reverse_resize_and_padding(output, meta, interpolation=interpolation)
        for output, meta in zip(batch_outputs, meta_list)
    ]




## Losses

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import PolynomialLR # Import PolynomialLR
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F # Needed for softmax in Dice Loss

# --- Configuration ---
# Training length and where the best checkpoint is written.
EPOCHS = 100
MODEL_SAVE_DIR = "unet"

# SGD hyper-parameters (nnU-Net-style defaults).
INITIAL_LR = 0.01     # standard nnU-Net initial LR for SGD
WEIGHT_DECAY = 3e-5   # common weight decay value; adjust if needed
SGD_MOMENTUM = 0.99   # nnU-Net standard momentum

# Make sure the checkpoint directory exists (no-op if it already does).
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

class MemoryEfficientDiceLoss(nn.Module):
    """Negative soft Dice loss with an optional ignore label ("batch Dice").

    Per-class intersections and sums are pooled over the spatial axes AND the whole batch
    before the Dice ratio is formed. The per-class Dice values are then averaged.

    Args:
        apply_softmax: apply softmax over dim 1 to ``x`` first (set False if ``x`` already
            holds probabilities).
        ignore_index: label value whose pixels are excluded from the loss. It may be a class
            index, in which case that class is also dropped from the class mean. It may also
            be an out-of-range value such as 255.
        smooth: additive smoothing for numerator and denominator.

    ``forward(x, y)``:
        x: (N, C, H, W) logits (or probabilities when ``apply_softmax=False``).
        y: integer label map of shape (N, 1, H, W) or (N, H, W), or a one-hot/soft target
            with the same shape as ``x``.
        Returns a scalar tensor equal to minus the mean Dice over the non-ignored classes.

    Raises:
        ValueError: if ``y``'s shape cannot be matched to ``x``.
    """

    def __init__(self, apply_softmax: bool = True, ignore_index: int = None, smooth: float = 1e-5):
        super(MemoryEfficientDiceLoss, self).__init__()
        self.apply_softmax = apply_softmax
        self.ignore_index = ignore_index
        self.smooth = smooth

    def forward(self, x, y):
        num_classes = x.shape[1]
        probs = F.softmax(x, dim=1) if self.apply_softmax else x

        # Targets need no gradient; build the one-hot map and mask without autograd.
        with torch.no_grad():
            if y.ndim != probs.ndim:
                # Accept an (N, H, W) label map by adding the channel axis.
                if (y.ndim == probs.ndim - 1 and y.shape[0] == probs.shape[0]
                        and y.shape[1:] == probs.shape[2:]):
                    y = y.unsqueeze(1)
                else:
                    raise ValueError(
                        f"Shape mismatch: target shape {tuple(y.shape)} is incompatible with "
                        f"prediction shape {tuple(probs.shape)}"
                    )

            mask = None
            if probs.shape == y.shape:
                # y is already one-hot (or soft); derive the ignore mask from its argmax.
                y_onehot = y.float()
                if self.ignore_index is not None:
                    import warnings
                    warnings.warn("Input y has same shape...")
                    mask = torch.argmax(y_onehot, dim=1, keepdim=True) != self.ignore_index
                    y_onehot = y_onehot * mask
            else:
                y_long = y.long()
                if self.ignore_index is not None:
                    mask = y_long != self.ignore_index
                    # Ignored pixels are zeroed by the mask below. Remap them to 0 first so
                    # scatter_ never sees an out-of-range index (e.g. 255).
                    y_long = y_long.masked_fill(~mask, 0)
                y_onehot = torch.zeros_like(probs)
                y_onehot.scatter_(1, y_long, 1)
                if mask is not None:
                    y_onehot = y_onehot * mask

            sum_gt = y_onehot.sum(dim=(2, 3)).sum(0)

        if mask is not None:
            probs = probs * mask

        # Per-class totals pooled over H, W and the batch -> shape [C].
        intersect = (probs * y_onehot).sum(dim=(2, 3)).sum(0)
        sum_pred = probs.sum(dim=(2, 3)).sum(0)

        dc = (2. * intersect + self.smooth) / torch.clip(sum_pred + sum_gt + self.smooth, 1e-8)

        # An ignored class that is a real class index is left out of the mean.
        valid_classes_mask = torch.ones_like(dc, dtype=torch.bool)
        if self.ignore_index is not None and 0 <= self.ignore_index < num_classes:
            valid_classes_mask[self.ignore_index] = False

        if valid_classes_mask.any():
            dc_mean = dc[valid_classes_mask].mean()
        else:  # every class ignored: avoid the NaN from an empty mean
            dc_mean = torch.tensor(0.0, device=dc.device)

        return -dc_mean

class DiceCELoss(nn.Module):
    """Weighted sum of ``MemoryEfficientDiceLoss`` (on logits) and ``nn.CrossEntropyLoss``.

    Args:
        dice_weight: weight of the (negative) Dice term.
        ce_weight: weight of the cross-entropy term.
        ignore_index: label value excluded from both terms (None = use every pixel). It is
            also stored as ``self.ignore_index`` so evaluation code can read it.
        smooth_dice: smoothing constant for the Dice term.
        ce_kwargs: extra keyword arguments for ``nn.CrossEntropyLoss`` (copied, never mutated).

    ``forward(outputs, targets)``: ``outputs`` are (N, C, H, W) logits; ``targets`` is an
    integer label map of shape (N, 1, H, W) or (N, H, W). Returns a scalar tensor.
    """

    def __init__(self, dice_weight=1.0, ce_weight=1.0, ignore_index: int = None, smooth_dice=1e-5, ce_kwargs=None):
        super(DiceCELoss, self).__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        # Exposed so callers (e.g. eval_loop's getattr(loss_fn, 'ignore_index')) see the setting.
        self.ignore_index = ignore_index
        self.dice = MemoryEfficientDiceLoss(apply_softmax=True, ignore_index=ignore_index, smooth=smooth_dice)

        ce_ignore_kwargs = dict(ce_kwargs) if ce_kwargs else {}
        if ignore_index is not None:
            ce_ignore_kwargs['ignore_index'] = ignore_index
        self.cross_entropy = nn.CrossEntropyLoss(**ce_ignore_kwargs)

    def forward(self, outputs, targets):
        # Dice term: softmax is applied inside (apply_softmax=True).
        dice_loss = self.dice(outputs, targets)

        # CrossEntropyLoss needs an int64 (N, H, W) target.
        if targets.ndim == 4 and targets.shape[1] == 1:
            targets_ce = targets.squeeze(1).long()
        else:
            targets_ce = targets.long()

        ce_loss = self.cross_entropy(outputs, targets_ce)

        return (self.dice_weight * dice_loss) + (self.ce_weight * ce_loss)
    


## Train/Eval loops

In [None]:

# --- Training Loop (Adapted for nnU-Net style) ---
def train_loop(dataloader, model, loss_fn, optimizer, scheduler, accumulation_steps, device):
    """Run one training epoch with gradient accumulation.

    Micro-batches are accumulated in groups of ``accumulation_steps``. The optimizer and
    ``scheduler`` step once per group, including a final shorter group when
    ``len(dataloader)`` is not a multiple of ``accumulation_steps``. Each micro-batch loss
    is divided by the size of its own group, so every update applies the mean gradient of
    that group (the trailing short group is not down-weighted).

    Args:
        dataloader: yields ``(X, y)`` batches and supports ``len()``.
        model: network, already on ``device``.
        loss_fn: ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer step.
        accumulation_steps: micro-batches per optimizer step (>= 1).
        device: device the batches are moved to.

    Returns:
        Mean over optimizer steps of each group's mean micro-batch loss (0 if nothing ran).
    """
    model.train()
    total_loss = 0.0
    processed_batches = 0  # optimizer steps (effective batches) taken this epoch
    group_loss = 0.0       # running sum of unscaled micro-batch losses in the current group

    optimizer.zero_grad()

    total_iters_in_epoch = len(dataloader)

    pbar = tqdm(enumerate(dataloader), total=total_iters_in_epoch, desc="Training")
    for batch_idx, (X, y) in pbar:
        # TODO: nnU-Net-style training relies on heavy data augmentation (rotations, scaling,
        # elastic deformation, gamma, contrast, ...). None is applied here; add it in the
        # Dataset's __getitem__ or via an augmentation library.

        # Size of the accumulation group this micro-batch belongs to (the last may be shorter).
        group_start = (batch_idx // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, total_iters_in_epoch - group_start)

        X, y = X.to(device), y.to(device)

        pred = model(X)  # full precision, no autocast
        loss = loss_fn(pred, y)

        (loss / group_size).backward()
        # detach() keeps the running total off the graph and avoids a host sync per micro-batch.
        group_loss = group_loss + loss.detach()

        if batch_idx + 1 == group_start + group_size:  # last micro-batch of this group
            optimizer.step()
            scheduler.step()  # one scheduler step per optimizer step
            optimizer.zero_grad()

            step_loss = (group_loss / group_size).item()
            group_loss = 0.0
            total_loss += step_loss
            processed_batches += 1
            pbar.set_postfix({'loss': step_loss, 'lr': optimizer.param_groups[0]['lr']})

    avg_loss = total_loss / processed_batches if processed_batches > 0 else 0
    print(f"Training Avg loss (per effective batch): {avg_loss:>8f}")
    print(f"End of Epoch LR: {optimizer.param_groups[0]['lr']:>8f}")
    return avg_loss

# --- Evaluation Loop (Modified for aggregated IoU) ---
def _image_metric_stats(pred_logits, label, num_classes, ignore_index):
    """Per-class Dice/IoU accumulators for one image.

    Args:
        pred_logits: (C, H, W) logits at the label's resolution.
        label: (H, W) integer label map.
        num_classes: C.
        ignore_index: label value to exclude (None = keep all pixels).

    Returns:
        Five float64 CPU tensors of shape (C,): soft-Dice intersection, soft-Dice prediction
        sum, ground-truth sum, hard-IoU intersection, hard-IoU union.
    """
    probs = F.softmax(pred_logits, dim=0)
    pred_hard = F.one_hot(torch.argmax(pred_logits, dim=0), num_classes=num_classes).permute(2, 0, 1).bool()

    gt = label.long()
    mask = None
    if ignore_index is not None:
        mask = gt != ignore_index
        # Ignored pixels are masked out below. Remap them first so F.one_hot never sees an
        # out-of-range value such as 255.
        gt = gt.masked_fill(~mask, 0)
    gt_onehot = F.one_hot(gt, num_classes=num_classes).permute(2, 0, 1).float()

    if mask is not None:
        pixel_mask = mask.unsqueeze(0)
        gt_onehot = gt_onehot * pixel_mask
        probs = probs * pixel_mask
        pred_hard = pred_hard & pixel_mask
    gt_bool = gt_onehot.bool()

    stats = (
        (probs * gt_onehot).sum(dim=(1, 2)),
        probs.sum(dim=(1, 2)),
        gt_onehot.sum(dim=(1, 2)),
        (pred_hard & gt_bool).sum(dim=(1, 2)),
        (pred_hard | gt_bool).sum(dim=(1, 2)),
    )
    return tuple(s.cpu().to(torch.float64) for s in stats)


def eval_loop(dataloader, model, loss_fn, device, target_size=512):
    """
    Evaluation loop computing mean per-image loss, aggregated Dice and aggregated IoU.

    Every image is letterboxed to ``target_size``, passed through the model, and its logits
    are mapped back to the original resolution. Loss and metrics are computed there.

    Args:
        dataloader: yields batches of (list[Tensor(C,H,W)], list[Tensor(H,W) or Tensor(1,H,W)]).
        model: the network (on device).
        loss_fn: loss used for the reported loss (e.g. DiceCELoss). Its ``ignore_index`` (or
            that of its ``dice`` sub-loss) also selects the label excluded from the metrics;
            its Dice ``smooth`` is reused for the aggregated Dice.
        device: the torch device.
        target_size: square input size the model is fed.

    Returns:
        (avg_loss, micro-averaged Dice, mean IoU), or (0.0, 0.0, 0.0) if no image was seen.
    """
    model.eval()
    num_images_processed = 0
    total_loss = 0.0
    num_classes = -1  # determined from the first prediction
    # [dice_intersect, dice_sum_pred, dice_sum_gt, iou_intersection, iou_union], float64 (C,) on CPU
    totals = None

    ignore_index = getattr(loss_fn, 'ignore_index', None)
    if ignore_index is None and hasattr(loss_fn, 'dice'):
        # A combined loss may only hand ignore_index to its Dice term.
        ignore_index = getattr(loss_fn.dice, 'ignore_index', None)
    smooth_eval = getattr(loss_fn.dice, 'smooth', 1e-6) if hasattr(loss_fn, 'dice') else 1e-6

    with torch.no_grad():
        for X_batch_list, y_batch_list in tqdm(dataloader, desc="Eval"):
            X_processed, meta_list = process_batch_forward(X_batch_list, target_size=target_size)
            pred_processed = model(X_processed.to(device))  # logits [N, C, target, target]

            if num_classes == -1:
                num_classes = pred_processed.shape[1]
                totals = [torch.zeros(num_classes, dtype=torch.float64, device='cpu') for _ in range(5)]

            pred_original_list = process_batch_reverse(pred_processed, meta_list, interpolation='bilinear')

            for pred_logits, label in zip(pred_original_list, y_batch_list):
                pred_logits = pred_logits.to(device)  # [C, H_orig, W_orig]
                label_batched = label.to(device).unsqueeze(0)  # [1, H, W] or [1, 1, H, W]

                # Index map of shape [1, H, W] for the metrics.
                if label_batched.ndim == 4 and label_batched.shape[1] == 1:
                    label_idxmap = label_batched.squeeze(1)
                elif label_batched.ndim == 3:
                    label_idxmap = label_batched
                else:
                    raise ValueError(f"Unsupported label shape: {label_batched.shape}")

                loss = loss_fn(pred_logits.unsqueeze(0), label_batched)
                total_loss += loss.item()

                image_stats = _image_metric_stats(pred_logits, label_idxmap.squeeze(0), num_classes, ignore_index)
                for acc, stat in zip(totals, image_stats):
                    acc += stat

                num_images_processed += 1

    if num_images_processed == 0:
        print("Evaluation dataloader was empty.")
        return 0.0, 0.0, 0.0  # Loss, Dice, IoU

    dice_intersect, dice_sum_pred, dice_sum_gt, iou_intersection, iou_union = totals
    avg_loss = total_loss / num_images_processed

    # Classes that count towards the averages (the ignored class, if it is a real class, does not).
    valid_class_mask = torch.ones(num_classes, dtype=torch.bool)
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        valid_class_mask[ignore_index] = False
    num_valid = valid_class_mask.sum().item()

    dice_numerator = 2. * dice_intersect + smooth_eval
    dice_denominator = dice_sum_pred + dice_sum_gt + smooth_eval
    per_class_dice = dice_numerator / torch.clip(dice_denominator, 1e-8)
    per_class_iou = iou_intersection / (iou_union + 1e-8)

    avg_dice_micro = 0.0  # sum of numerators / sum of denominators over valid classes
    avg_dice_macro = 0.0  # mean of per-class Dice over valid classes
    mean_iou = 0.0
    if num_valid > 0:
        avg_dice_micro = (dice_numerator[valid_class_mask].sum() /
                          torch.clip(dice_denominator[valid_class_mask].sum(), 1e-8)).item()
        avg_dice_macro = per_class_dice[valid_class_mask].mean().item()
        mean_iou = per_class_iou[valid_class_mask].mean().item()

    print(f"\n--- Evaluation Complete ---")
    print(f"  Images Processed: {num_images_processed}")
    print(f"  Average Loss (Original Size): {avg_loss:>8f}")
    print(f"  Micro Avg Dice Score ({num_valid} classes): {avg_dice_micro:>8f}")
    print(f"  Macro Avg Dice Score ({num_valid} classes): {avg_dice_macro:>8f}")
    print(f"  Mean IoU (mIoU) ({num_valid} classes): {mean_iou:>8f}")
    print(f"  --- Per-Class IoU ---")
    for c in range(num_classes):
        if valid_class_mask[c]:
            print(f"    Class {c}: {per_class_iou[c].item():>8f}")
        else:
            print(f"    Class {c}: Ignored")
    print("-" * 25)

    return avg_loss, avg_dice_micro, mean_iou

In [None]:
# --- Setup ---
# Recomputed on every run (no `locals()` guard), so editing batch_size / target_batch_size
# and re-running this cell cannot leave a stale accumulation_steps behind.
assert 'target_batch_size' in globals() and 'batch_size' in globals(), \
       "Please define target_batch_size and batch_size"
assert target_batch_size >= batch_size, "target_batch_size must be >= batch_size"
assert target_batch_size % batch_size == 0, "target_batch_size must be divisible by batch_size for simple accumulation"
accumulation_steps = target_batch_size // batch_size
print(f"Using Gradient Accumulation: effective batch size {target_batch_size} ({accumulation_steps} steps)")


# --- Define Model, Loss, Optimizer, Scheduler ---
# Same device preference as the first cell (MPS > CUDA > CPU). A CUDA/CPU-only check here
# would silently move training to the CPU on Apple-silicon machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = unet().to(device)

# --- Define Loss, Optimizer, Scheduler ---
# ignore_index selects a label value to exclude: a class index (e.g. 3), an out-of-range
# value (e.g. 255), or None to use every class.
EVAL_IGNORE_INDEX = 3  # class 3 is excluded during evaluation
TRAIN_IGNORE_INDEX = None  # train on all classes (0, 1, 2, 3)

loss_fn = DiceCELoss(ignore_index=TRAIN_IGNORE_INDEX, smooth_dice=1)  # training loss
# Loss object handed to eval_loop. eval_loop computes the REPORTED validation loss with it
# and also looks up its ignore_index / Dice smoothing for the Dice and IoU metrics. The
# validation loss therefore skips EVAL_IGNORE_INDEX pixels and uses the default Dice
# smoothing, so it is not directly comparable to the training loss.
eval_settings_provider = DiceCELoss(ignore_index=EVAL_IGNORE_INDEX)

optimizer = optim.SGD(model.parameters(), lr=INITIAL_LR, momentum=SGD_MOMENTUM,
                      weight_decay=WEIGHT_DECAY, nesterov=True)

# One scheduler step per optimizer step; a trailing partial accumulation group also steps.
steps_per_epoch = (len(train_dataloader) // accumulation_steps) + (1 if len(train_dataloader) % accumulation_steps != 0 else 0)
total_iters = steps_per_epoch * EPOCHS
scheduler = PolynomialLR(optimizer, total_iters=total_iters, power=0.9)

best_dev_dice = -np.inf  # best validation micro Dice so far
best_dev_miou = -np.inf  # mIoU at that epoch
best_dev_loss = np.inf   # validation loss at that epoch

# --- Training and Evaluation Loop ---
# Only the checkpoint with the best validation micro Dice is kept; it is overwritten on each improvement.
best_ckpt_path = os.path.join(MODEL_SAVE_DIR, "unet_best_dice.pytorch")

print("\nStarting Training (nnU-Net style)...")
for t in range(EPOCHS):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer, scheduler, accumulation_steps, device)

    # eval_settings_provider supplies the evaluation loss and its ignore settings.
    val_loss, val_dice_micro, val_miou = eval_loop(val_dataloader, model, eval_settings_provider, device)

    # Select on micro Dice (compare val_miou instead to select on mIoU). Written as
    # `not (a > b)` so a NaN score never counts as an improvement.
    if not (val_dice_micro > best_dev_dice):
        print(f"Validation Micro Dice score did not improve from {best_dev_dice:.6f}")
        continue

    best_dev_dice, best_dev_miou, best_dev_loss = val_dice_micro, val_miou, val_loss
    print(f"Validation Micro Dice score improved ({best_dev_dice:.6f}). Saving model...")
    checkpoint_path = best_ckpt_path
    checkpoint = dict(
        epoch=t + 1,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        best_dev_dice=best_dev_dice,
        best_dev_miou=best_dev_miou,
        best_dev_loss=best_dev_loss,
        notes=f"Model saved based on best Micro Dice. Ignored index for metric: {EVAL_IGNORE_INDEX}",
    )
    torch.save(checkpoint, checkpoint_path)

print("\n--- Training Finished! ---")
print(f"Best validation Micro Dice score achieved: {best_dev_dice:.6f}")
print(f"Corresponding validation mIoU: {best_dev_miou:.6f}")
print(f"Corresponding validation loss: {best_dev_loss:.6f}")
print(f"Best model saved to: {best_ckpt_path}")

## Evaluate a saved checkpoint on the validation set

In [None]:
# Rebuild the network and evaluate saved weights on the validation set.
model = unet().to(device)
loss_fn = DiceCELoss(ignore_index=3)  # same ignored class as EVAL_IGNORE_INDEX in the training cell

# Prefer the checkpoint written by the training cell; fall back to the legacy location.
checkpoint_path = os.path.join(MODEL_SAVE_DIR, "unet_best_dice.pytorch")
if not os.path.exists(checkpoint_path):
    checkpoint_path = "unet/checkpoint"
# map_location remaps tensors saved on another device (e.g. CUDA) onto the current one.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The training cell stores weights under "model_state_dict"; the legacy checkpoint used "model".
model.load_state_dict(checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint["model"])

# eval_loop returns (loss, micro Dice, mIoU).
new_dev_loss, new_dev_dice, new_dev_miou = eval_loop(val_dataloader, model, loss_fn, device)

## Dice-loss sanity check on one training sample

In [None]:
# Compute the Dice loss of the trained model's prediction for a single training sample.
X, y = next(iter(train_dataloader))
X = X.to(device)
y = y.to(device)

dice = MemoryEfficientDiceLoss()

model = unet().to(device)
# Prefer the checkpoint written by the training cell; fall back to the legacy location.
checkpoint_path = os.path.join(MODEL_SAVE_DIR, "unet_best_dice.pytorch")
if not os.path.exists(checkpoint_path):
    checkpoint_path = "unet/checkpoint"
# map_location remaps tensors saved on another device (e.g. CUDA) onto the current one.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The training cell stores weights under "model_state_dict"; the legacy checkpoint used "model".
model.load_state_dict(checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint["model"])

# Inference: BatchNorm uses its running statistics (instead of updating them from this one
# image), and no autograd graph is built.
model.eval()
with torch.no_grad():
    pred = model(X[0].unsqueeze(0))

print(dice(pred, y[0].unsqueeze(0)))