### Import Modules

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # %env CUDA_VISIBLE_DEVICES=0
import numpy as np
import pandas as pd
import nibabel as nib
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from monai.data import Dataset, DataLoader, decollate_batch
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    DivisiblePadd,
    Resize, Resized,
    ScaleIntensityd,
    RandFlipd,
    Activations,
    AsDiscrete
)
from monai.networks.nets import UNet, VNet, SegResNet, AttentionUnet, UNETR, SwinUNETR
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.utils import first, set_determinism
from monai.visualize.utils import blend_images

### Define Functions and Classes

In [None]:
def calculate_required_divisibility(model_name, **divisibility_kwargs):
    """Return the factor every spatial input dimension must be divisible by.

    Used to choose the ``k`` of ``DivisiblePadd`` so padded volumes pass
    through all down-sampling stages of the chosen network without shape
    mismatches on the skip connections.

    Args:
        model_name: One of "UNet", "AttentionUnet", "VNet", "SegResNet",
            "UNETR", "SwinUNETR".
        **divisibility_kwargs: Optional architecture settings that change the
            requirement: ``strides`` (UNet/AttentionUnet), ``patch_size`` and
            ``depths`` (SwinUNETR).

    Returns:
        The required divisor as an int, or None (after printing a message)
        for an unsupported ``model_name``.
    """
    if model_name in ["UNet", "AttentionUnet"]:
        # UNet and AttentionUNet: divisibility = product of all strides
        # Example: strides=(2,2,2) -> 2*2*2 = 8
        strides = divisibility_kwargs.get('strides', (2, 2, 2))
        total_downsampling = 1
        for stride in strides:
            total_downsampling *= stride
        return total_downsampling
    elif model_name == "VNet":
        # VNet: in -> down_tr32 -> down_tr64 -> down_tr128 -> down_tr256.
        # The four stride-2 down transitions imply 2^4 = 16.
        # NOTE(review): 32 is a stricter (still valid) requirement; confirm whether 16 suffices.
        return 32
    elif model_name == "SegResNet":
        # SegResNet: uses residual blocks with flexible architecture
        # Can technically handle various sizes but 16 minimizes edge effects
        return 16 # Recommended divisibility for optimal performance
    elif model_name == "UNETR":
        # UNETR: ViT encoder uses fixed 16x16x16 patches (not configurable)
        # Input dimensions must be divisible by 16
        return 16
    elif model_name == "SwinUNETR":
        # SwinUNETR: the patch embedding downsamples by patch_size and each of the
        # len(depths) stages ends in a 2x patch-merging step, so the deepest
        # feature map is at 1 / (patch_size * 2**num_stages) of the input size.
        # Defaults (patch_size=2, four stages) give 32, which matches MONAI's
        # SwinUNETR input check (divisible by patch_size**5). The previous value
        # of 16 produced inputs that SwinUNETR rejects.
        patch_size = divisibility_kwargs.get('patch_size', 2)
        depths = divisibility_kwargs.get('depths', (2, 2, 2, 2))
        num_stages = len(depths)
        return patch_size * 2 ** num_stages
    else:
        print(f"Unsupported model name: {model_name}")
        return None

def load_data(data_dir, batch_size, resize_dim=None, test_size=0.2, inference=False, model_name=None, **model_kwargs):
    """Build MONAI data loaders for training/validation or for inference.

    Brain volumes are read from ``<data_dir>/Brain/*.nii.gz`` and, in training
    mode, lesion masks from ``<data_dir>/Lesion/*.nii.gz``; files are paired
    by sorted order. Spatial sizes are made model-compatible either by
    end-padding to ``calculate_required_divisibility(model_name, ...)``
    (when ``resize_dim`` is None) or by resizing to ``resize_dim``.

    Args:
        data_dir: Directory containing the ``Brain`` (and ``Lesion``) folders.
        batch_size: Batch size for the train/val loaders (unused for inference).
        resize_dim: Target spatial size, or None to pad instead of resize.
        test_size: Validation fraction for ``train_test_split`` (training only).
        inference: If True, build a test loader from brain volumes only.
        model_name: Architecture name used to derive the padding divisor.
        **model_kwargs: Forwarded to ``calculate_required_divisibility``.

    Returns:
        ``(train_loader, val_loader)`` when ``inference`` is False; otherwise a
        batch-size-1 ``test_loader`` whose batches hold ``brain``, ``affine``,
        ``original_dim`` and ``filename``.

    Raises:
        ValueError: In training mode, if the numbers of brain and lesion files
            differ.
    """
    if not inference: # Training/Validation
        # Load file paths first so a broken dataset fails before any transform setup
        brains = sorted(glob.glob(os.path.join(data_dir, "Brain", "*.nii.gz")))
        lesions = sorted(glob.glob(os.path.join(data_dir, "Lesion", "*.nii.gz")))
        if len(brains) != len(lesions):
            # zip() would silently truncate and could pair brains with the wrong masks.
            raise ValueError(
                f"Found {len(brains)} brain volumes but {len(lesions)} lesion masks under {data_dir}"
            )
        # Define training transforms with data augmentation
        train_transforms = [
            LoadImaged(keys=["brain", "lesion"], image_only=True),
            EnsureChannelFirstd(keys=["brain", "lesion"]),
            ScaleIntensityd(keys="brain"),
            RandFlipd(keys=["brain", "lesion"], spatial_axis=0, prob=0.5),
            RandFlipd(keys=["brain", "lesion"], spatial_axis=1, prob=0.5),
            RandFlipd(keys=["brain", "lesion"], spatial_axis=2, prob=0.5)
        ]
        # Define validation transforms without augmentation
        val_transforms = [
            LoadImaged(keys=["brain", "lesion"], image_only=True),
            EnsureChannelFirstd(keys=["brain", "lesion"]),
            ScaleIntensityd(keys="brain")
        ]
        # Add padding or resizing before normalization (ScaleIntensityd, index 2)
        if resize_dim is None:
            required_k = calculate_required_divisibility(model_name, **model_kwargs)
            # Padding approach: ensure dimensions are divisible by model requirements
            if required_k is not None and required_k > 1:
                trainval_pad_transform = DivisiblePadd(keys=["brain", "lesion"], k=required_k, method="end")
                train_transforms.insert(2, trainval_pad_transform)
                val_transforms.insert(2, trainval_pad_transform)
        else:
            # Resizing approach: trilinear for the image, nearest for the mask
            trainval_resize_transform = Resized(keys=["brain", "lesion"], spatial_size=resize_dim, mode=["trilinear", "nearest"])
            train_transforms.insert(2, trainval_resize_transform)
            val_transforms.insert(2, trainval_resize_transform)
        train_transforms = Compose(train_transforms)
        val_transforms = Compose(val_transforms)
        data_dicts = [
            {"brain": brain_name, "lesion": lesion_name}
            for brain_name, lesion_name in zip(brains, lesions)
        ]
        # Split data into train and validation sets
        train_files, val_files = train_test_split(data_dicts, test_size=test_size, random_state=42)
        train_ds = Dataset(data=train_files, transform=train_transforms)
        val_ds = Dataset(data=val_files, transform=val_transforms)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
        val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=torch.cuda.is_available())
        return train_loader, val_loader
    else: # Test
        # Custom transform to save original metadata before any spatial transformations
        class SaveOriginalMetadata:
            def __call__(self, data):
                brain = data["brain"]
                # Save metadata for later restoration (cropping/resizing back and NIfTI saving)
                data["original_dim"] = tuple(brain.shape[1:])
                data["affine"] = brain.meta["affine"]
                data["filename"] = os.path.basename(brain.meta["filename_or_obj"])
                return data
        # Build test transforms: save original size, then apply same preprocessing as training
        test_transforms = [
            LoadImaged(keys="brain", image_only=False), # Load with metadata
            EnsureChannelFirstd(keys="brain"),
            SaveOriginalMetadata(), # Save original spatial dimensions
            ScaleIntensityd(keys="brain")
        ]
        # Add same padding/resizing as training, after the metadata snapshot (index 3)
        if resize_dim is None:
            required_k = calculate_required_divisibility(model_name, **model_kwargs)
            if required_k is not None and required_k > 1:
                test_pad_transform = DivisiblePadd(keys="brain", k=required_k, method="end")
                test_transforms.insert(3, test_pad_transform)
        else:
            test_transforms.insert(3, Resized(keys="brain", spatial_size=resize_dim, mode="trilinear"))
        test_transforms = Compose(test_transforms)
        brains = sorted(glob.glob(os.path.join(data_dir, "Brain", "*.nii.gz")))
        data_dicts = [{"brain": brain_name} for brain_name in brains]
        test_ds = Dataset(data=data_dicts, transform=test_transforms)
        # Custom collate to preserve metadata as lists instead of tensors
        def custom_collate(batch):
            return {
                'brain': torch.stack([item['brain'] for item in batch]),
                'affine': [item['affine'] for item in batch], # Keep as list
                'original_dim': [item['original_dim'] for item in batch], # Keep as list
                'filename': [item['filename'] for item in batch] # Keep as list
            }
        test_loader = DataLoader(test_ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available(), collate_fn=custom_collate)
        return test_loader

def get_model(model_name, img_size, num_classes=1):
    """Instantiate a 3D, single-input-channel MONAI segmentation network.

    Args:
        model_name: One of "UNet", "VNet", "SegResNet", "AttentionUnet",
            "UNETR", "SwinUNETR".
        img_size: Spatial input size; only consumed by UNETR.
        num_classes: Number of output channels.

    Returns:
        The freshly constructed network.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Arguments every architecture below receives.
    shared = {'spatial_dims': 3, 'in_channels': 1, 'out_channels': num_classes}
    # Feature channels per level and per-level convolution strides for the U-shaped CNNs.
    unet_like = {'channels': (32, 64, 128, 256), 'strides': (2, 2, 2)}
    # Builders are lambdas so that only the requested network is created.
    builders = {
        "UNet": lambda: UNet(**shared, **unet_like),
        "VNet": lambda: VNet(**shared),
        "SegResNet": lambda: SegResNet(**shared),
        "AttentionUnet": lambda: AttentionUnet(**shared, **unet_like),
        "UNETR": lambda: UNETR(
            **shared,
            img_size=img_size,   # input image size the ViT encoder is built for
            feature_size=16,     # CNN decoder feature channels
            hidden_size=768,     # transformer embedding dimension
            mlp_dim=3072,        # MLP dimension (4 x hidden_size)
            num_heads=12,        # attention heads per transformer layer
        ),
        "SwinUNETR": lambda: SwinUNETR(
            **shared,
            patch_size=2,              # spatial patch size for tokenization
            feature_size=24,           # initial transformer embedding dimension
            depths=(2, 2, 2, 2),       # transformer blocks per stage
            num_heads=(3, 6, 12, 24),  # attention heads per stage
        ),
    }
    if model_name not in builders:
        raise ValueError(f"Unsupported model name: {model_name}")
    return builders[model_name]()

def get_grad_scaler(device):
    """Create a gradient scaler for float16 mixed-precision training.

    Returns None for non-CUDA devices; callers treat None as "train in full
    precision". Scaler constructors are tried from the top-level
    ``torch.GradScaler`` to ``torch.amp.GradScaler`` to the legacy
    ``torch.cuda.amp.GradScaler``, moving on whenever an API is missing or
    rejects the arguments.
    """
    if device.type == "cuda":
        candidates = (
            lambda: torch.GradScaler("cuda"),
            lambda: torch.amp.GradScaler("cuda"),
        )
        for make_scaler in candidates:
            try:
                return make_scaler()
            except (AttributeError, TypeError):
                continue
        return torch.cuda.amp.GradScaler()
    return None

def get_autocast_context(device, enabled=True):
    """Return a context manager that enables float16 autocast on ``device``.

    When ``enabled`` is False a no-op ``nullcontext`` is returned. Otherwise
    the device-generic ``torch.autocast`` is tried first, then
    ``torch.amp.autocast``, then the legacy CUDA-only autocast (with a no-op
    fallback on non-CUDA devices).
    """
    from contextlib import nullcontext
    if not enabled:
        return nullcontext()
    try:
        # Device-generic autocast with an explicit float16 dtype.
        return torch.autocast(device_type=device.type, dtype=torch.float16)
    except (AttributeError, TypeError):
        pass
    try:
        # torch.amp namespace; uses the device's default autocast dtype.
        return torch.amp.autocast(device.type)
    except (AttributeError, TypeError):
        pass
    # Legacy CUDA-only API; nothing to enable elsewhere.
    return torch.cuda.amp.autocast() if device.type == "cuda" else nullcontext()

def train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred):
    """Run one optimisation pass over ``train_loader``.

    Each batch's "brain" tensor is fed to ``model`` and compared to "lesion"
    with ``criterion``. Autocast and loss scaling are used only when a
    ``scaler`` is supplied. ``metric`` is reset at the start and fed the
    ``post_pred``-processed predictions of every batch.

    Returns:
        ``(mean_loss, metric_value)``: the loss averaged over batches and
        the aggregated training metric for the epoch.
    """
    model.train()
    running_loss = 0.0
    metric.reset()
    for batch in train_loader:
        inputs = batch["brain"].to(device)
        targets = batch["lesion"].to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass only when a grad scaler exists (CUDA).
        with get_autocast_context(device, enabled=(scaler is not None)):
            logits = model(inputs)
            loss = criterion(logits, targets)
        if not scaler:
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        running_loss += loss.item()
        predictions = [post_pred(sample) for sample in decollate_batch(logits)]
        metric(y_pred=predictions, y=targets)
    epoch_metric = metric.aggregate().item()
    return running_loss / len(train_loader), epoch_metric

def validate_one_epoch(model, device, val_loader, metric, post_pred):
    """Return the aggregated metric of ``model`` over ``val_loader``.

    Runs in eval mode without gradient tracking. ``metric`` is reset first,
    so the returned value covers this pass only.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["brain"].to(device)
            targets = batch["lesion"].to(device)
            predictions = [post_pred(sample) for sample in decollate_batch(model(inputs))]
            metric(y_pred=predictions, y=targets)
    return metric.aggregate().item()

class EarlyStopping:
    """Signal when a higher-is-better validation metric stops improving.

    Call the instance once per validation with the new metric value. A value
    counts as an improvement only if it exceeds the best value seen so far by
    more than ``delta``. After ``patience`` consecutive non-improving calls,
    ``early_stop`` becomes True.

    Attributes:
        patience: Consecutive non-improving calls tolerated before stopping.
        delta: Minimum increase over ``best_score`` that counts as improvement.
        best_score: Best metric value seen so far (None before the first call).
        early_stop: True once ``patience`` is exhausted.
        counter: Non-improving calls since the last improvement.
    """

    def __init__(self, patience=30, delta=0):
        self.patience = patience # Number of epochs to wait before stopping
        self.delta = delta # Minimum improvement threshold
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, metric):
        score = metric
        if self.best_score is None or score > self.best_score + self.delta:
            # First call or a genuine improvement: record it and reset patience.
            self.best_score = score
            self.counter = 0
        else:
            # Plateau, regression, or a gain smaller than delta. best_score is left
            # untouched so a slow downward drift cannot keep resetting patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

def train_model(model_dir, model, device, train_loader, val_loader, logger,
        criterion, metric, post_pred, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` and restore the weights with the best validation metric.

    Uses Adam with a cosine-annealed learning rate and mixed precision when
    ``get_grad_scaler`` provides a scaler (CUDA only). Validation runs every
    ``val_interval`` epochs. Training stops early after ``es_patience``
    validations without improvement. Each time the validation metric improves,
    the weights are also saved to ``<model_dir>/BestMetricModel.pth``.

    Returns:
        ``(model, epoch_loss_values, epoch_metric_values, metric_values)``:
        the model with its best validation weights loaded (if any validation
        ran), the per-epoch training losses and training metrics, and the
        metric of each validation epoch.
    """
    # Setup optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    scaler = get_grad_scaler(device)
    start_time = time.time()
    best_metric = -1
    best_metric_epoch = -1
    best_model_state = None
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    # Defined up front: with val_interval > 1 the first epochs have no validation yet.
    val_metric = None
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        # Training phase
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        # Validation phase
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, metric, post_pred)
            metric_values.append(val_metric)
            # Save best model
            if val_metric > best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live tensors, which keep
                # changing as training continues; snapshot an independent copy.
                best_model_state = {
                    key: value.detach().cpu().clone() if isinstance(value, torch.Tensor) else value
                    for key, value in model.state_dict().items()
                }
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best DSC: {best_metric:.4f} at epoch {best_metric_epoch}")
            # Check early stopping
            early_stopping(val_metric)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
                break
        epoch_end_time = time.time()
        val_text = f"{val_metric:.4f}" if val_metric is not None else "N/A"
        logger.info(
            f"Epoch {epoch + 1} completed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - "
            f"Training loss: {epoch_loss:.4f}, Training DSC: {epoch_metric:.4f}, Validation DSC: {val_text}"
        )
        # Update learning rate
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    end_time = time.time()
    total_time = end_time - start_time
    logger.info(
        f"Best DSC: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    print(
        f"\nBest DSC: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    # Load best model weights (load_state_dict copies them onto the model's device)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot the training loss and the training/validation DSC curves.

    The left panel shows the training loss per epoch. The right panel
    overlays the training DSC per epoch with the validation DSC, whose points
    sit at epochs ``val_interval, 2 * val_interval, ...``. The figure is
    saved to ``<model_dir>/Performance.png`` at 300 dpi.
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    train_dsc_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_dsc_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
    _, (loss_ax, dsc_ax) = plt.subplots(1, 2, figsize=(8, 5))

    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    dsc_ax.plot(train_dsc_epochs, epoch_metric_values, label='Training DSC', color='red')
    dsc_ax.plot(val_dsc_epochs, metric_values, label='Validation DSC', color='blue')
    dsc_ax.set_title('Training DSC vs. Validation DSC')
    dsc_ax.set_xlabel('Epoch')
    dsc_ax.set_ylabel('DSC')
    dsc_ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

class GradCAM3D:
    """Grad-CAM for 3D segmentation networks, driven by a segmentation loss.

    Hooks on ``target_layer`` capture its forward output and the gradient of
    ``criterion(model(brain), lesion)`` with respect to that output. The
    spatially averaged gradients weight the activation channels, and the
    weighted sum is ReLU-ed, upsampled to the input size and min-max
    normalised to [0, 1].

    NOTE(review): assumes ``target_layer`` outputs a channel-first 5D tensor
    (B, C, D, H, W), since dims 2-4 are averaged and dim 1 is summed below;
    layers with other output layouts would need reshaping first. TODO confirm
    for transformer-block targets.

    Call ``remove()`` (or use the instance as a context manager) when done,
    otherwise the hooks stay attached to the model.
    """

    def __init__(self, model, target_layer, criterion):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.model.eval()
        # Register hooks to capture intermediate features and gradients; keep the
        # handles so they can be detached instead of firing on every later pass.
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]
        self.loss = criterion

    def remove(self):
        """Detach this instance's hooks from ``target_layer``; safe to call twice."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove()
        return False

    def save_activation(self, _module, _input, output):
        # Capture forward pass activations
        self.activations = output.detach()

    def save_gradient(self, _module, _grad_input, grad_output):
        # Capture backward pass gradients w.r.t. the layer output
        self.gradients = grad_output[0].detach()

    def __call__(self, brain, lesion):
        """Return the normalised Grad-CAM map, shape ``brain.shape[1:]`` with one channel.

        Args:
            brain: Batched input volume (batch of 1 expected, since the batch
                axis is squeezed on return).
            lesion: Matching target passed to ``criterion``.
        """
        # Forward pass to compute loss
        self.model.zero_grad()
        output = self.model(brain)
        loss = self.loss(output, lesion)
        loss.backward()
        # Compute importance weights by global average pooling of gradients
        gradients = self.gradients
        activations = self.activations
        weights = torch.mean(gradients, dim=(2, 3, 4), keepdim=True) # Global average pooling
        # Weighted combination of activation maps
        cam = torch.sum(weights * activations, dim=1, keepdim=True) # Weighted sum
        cam = F.relu(cam) # Apply ReLU to focus on positive contributions
        # Resize CAM to match input dimensions
        cam = F.interpolate(cam, size=brain.shape[2:], mode='trilinear', align_corners=False)
        # Normalize to [0, 1] range
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # A constant map (e.g. all zero after ReLU) would otherwise give 0 / 0 = NaN.
        if peak > 0:
            cam = cam / peak
        return cam.squeeze(0)

def apply_gradcam_to_sample(model, val_loader, target_layers, device, criterion, sample_index, slice_index):
    """Plot brain, lesion and Grad-CAM overlays for one validation sample.

    One figure row is drawn per entry of ``target_layers`` (label -> module):
    the brain slice, the lesion slice, and the brain slice overlaid with that
    layer's Grad-CAM map. Slices are taken at ``slice_index`` along the last
    spatial axis and rotated 90 degrees for display.

    Args:
        model: Network to explain.
        val_loader: Loader whose ``dataset`` yields dicts with "brain" and "lesion".
        target_layers: Mapping of display label to a sub-module of ``model``.
        device: Device the sample is moved to.
        criterion: Loss whose gradient drives Grad-CAM.
        sample_index: Index into ``val_loader.dataset``.
        slice_index: Slice position along the last spatial axis.
    """
    if not target_layers:
        # plt.subplots(0, 3) would raise; get_gradcam_target_layers can return {}.
        print("Warning: No Grad-CAM target layers to visualize.")
        return
    # Load the sample once (each dataset lookup re-runs the transform pipeline)
    sample = val_loader.dataset[sample_index]
    brain = sample["brain"].to(device)
    lesion = sample["lesion"].to(device)
    brain_batch = brain.unsqueeze(0)
    lesion_batch = lesion.unsqueeze(0)
    # squeeze=False keeps ``axes`` 2D even when there is a single target layer
    _, axes = plt.subplots(len(target_layers), 3, figsize=(12, 4 * len(target_layers)), squeeze=False)
    # Extract 2D slices for visualization
    brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1))
    lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1))
    for i, (layer_name, target_layer) in enumerate(target_layers.items()):
        # Compute the Grad-CAM heatmap for the current layer
        grad_cam = GradCAM3D(model, target_layer, criterion)
        cam = grad_cam(brain_batch, lesion_batch)
        cam_slice = torch.rot90(cam[0, :, :, slice_index], k=1, dims=(0, 1))
        # Original image
        axes[i, 0].imshow(brain_slice.detach().cpu(), cmap='gray')
        axes[i, 0].set_title("Brain")
        axes[i, 0].axis('off')
        # Ground truth lesion
        axes[i, 1].imshow(lesion_slice.detach().cpu(), cmap='gray')
        axes[i, 1].set_title("Lesion")
        axes[i, 1].axis('off')
        # Grad-CAM overlay
        axes[i, 2].imshow(brain_slice.detach().cpu(), cmap='gray')
        axes[i, 2].imshow(cam_slice.detach().cpu(), cmap='jet', alpha=0.5)
        axes[i, 2].set_title(f"GradCAM: {layer_name}")
        axes[i, 2].axis('off')
    plt.tight_layout()
    plt.show()

def get_gradcam_target_layers(model, model_name):
    """Select labelled sub-modules of ``model`` to inspect with Grad-CAM.

    Args:
        model: A network built by ``get_model`` (not wrapped in DataParallel).
        model_name: Architecture name that decides which attribute paths are read.

    Returns:
        An insertion-ordered dict mapping a display label to a sub-module.
        For "AttentionUnet", an empty dict is returned (after a printed
        warning) when the expected attributes or indices are missing.

    Raises:
        ValueError: If ``model_name`` is not supported.

    NOTE(review): labels such as "Bottleneck" describe the intended role of
    each pick. The ``model.model[...]`` indices for UNet/AttentionUnet depend
    on how MONAI nests its Sequential containers; verify they select the
    intended layers.
    """
    if model_name == "AttentionUnet":
        try:
            return {
                "Encoder Last": model.model[0][-1],
                "Decoder First": model.model[2][0],
                "Final Conv": model.model[-1],
            }
        except (AttributeError, IndexError):
            print("Warning: Could not auto-detect AttentionUnet layers")
            return {}

    # Attribute lookups are deferred so only the requested architecture is probed.
    selectors = {
        "UNet": lambda m: {
            "Encoder Last": m.model[0][-1],
            "Bottleneck": m.model[1],
            "Decoder First": m.model[2][0],
            "Final Conv": m.model[-1],
        },
        # VNet uses "transition" modules rather than plain blocks.
        "VNet": lambda m: {
            "Encoder Last": m.down_tr256,
            "Decoder First": m.up_tr256,
            "Final Conv": m.out_tr,
        },
        # down_layers / up_layers are indexed as lists of lists, conv_final as a list.
        "SegResNet": lambda m: {
            "Encoder Last": m.down_layers[-1][-1],
            "Decoder First": m.up_layers[0][0],
            "Final Conv": m.conv_final[-1],
        },
        "UNETR": lambda m: {
            "Transformer Last": m.vit.blocks[-1],
            "Encoder Last": m.encoder4,
            "Decoder First": m.decoder5,
            "Final Conv": m.out,
        },
        "SwinUNETR": lambda m: {
            "Swin Last": m.swinViT.layers3[-1].blocks[-1],
            "Encoder Last": m.encoder4,
            "Decoder First": m.decoder5,
            "Final Conv": m.out,
        },
    }
    if model_name not in selectors:
        raise ValueError(f"Unsupported model name: {model_name}")
    return selectors[model_name](model)

def apply_best_model(model_dir, model, device, test_loader, post_pred, pred_dir, resize_dim):
    """Predict test volumes with the saved best checkpoint and write NIfTI masks.

    Loads ``<model_dir>/BestMetricModel.pth`` into ``model`` and runs every
    batch of ``test_loader``. Each prediction is post-processed with
    ``post_pred`` and restored to its original spatial size: by nearest
    resizing if ``resize_dim`` was used, otherwise by cropping the end
    padding. It is then saved to ``pred_dir`` under the input's filename,
    with the input's affine.

    Args:
        model_dir: Directory containing ``BestMetricModel.pth``.
        model: Network with the same architecture (and wrapping) as when saved.
        device: Inference device; checkpoint tensors are mapped onto it.
        test_loader: Loader from ``load_data(..., inference=True)`` (batch size 1).
        post_pred: Transform turning raw outputs into the saved mask.
        pred_dir: Output directory (created if missing).
        resize_dim: The ``resize_dim`` used for preprocessing, or None if padding was used.
    """
    # Load best model weights; map_location lets a checkpoint saved from GPU
    # tensors load on a CPU-only machine (and vice versa).
    checkpoint_path = os.path.join(model_dir, "BestMetricModel.pth")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    with torch.no_grad():
        for batch_data in test_loader:
            brain = batch_data["brain"].to(device)
            affine = batch_data["affine"][0] # Affine matrix for NIfTI
            original_dim = tuple(batch_data["original_dim"][0]) # Original spatial dimensions
            filename = batch_data["filename"][0]
            # Model inference
            output = model(brain)
            output = post_pred(output[0]) # Apply post-processing
            current_dim = tuple(output.shape[1:])
            # Restore to original dimensions if needed
            if current_dim != original_dim:
                if resize_dim is not None:
                    # If resizing was used: apply inverse resizing (nearest keeps labels binary)
                    output = Resize(spatial_size=original_dim, mode="nearest")(output)
                    print(f"{filename}: Resized {current_dim} → {original_dim}")
                else:
                    # If padding was used: padding was added at the end, so crop it off
                    output = output[:, :original_dim[0], :original_dim[1], :original_dim[2]]
                    print(f"{filename}: Cropped {current_dim} → {original_dim}")
            # Save as NIfTI file
            nifti_image = nib.Nifti1Image(output.squeeze().detach().cpu().numpy(), affine)
            pred_file = os.path.join(pred_dir, filename)
            nib.save(nifti_image, pred_file)

def calculate_test_metric(metric, gt_dir, pred_dir, pred_prefix="", pred_suffix=""):
    """Score saved predictions against ground-truth masks and write ``DSC.csv``.

    For every ``*.nii.gz`` in ``gt_dir`` (processed in sorted order), the
    prediction ``<pred_prefix><name><pred_suffix>.nii.gz`` in ``pred_dir`` is
    scored with ``metric``. Missing predictions are skipped. A NaN score is
    recorded as 1.0 when both masks are empty and skipped otherwise. Summary
    statistics are printed and per-case scores are saved to
    ``<pred_dir>/DSC.csv``.

    Args:
        metric: Callable metric (e.g. MONAI DiceMetric) taking ``y_pred`` and ``y``.
        gt_dir: Directory with ground-truth NIfTI masks.
        pred_dir: Directory with predicted NIfTI masks; also receives DSC.csv.
        pred_prefix: Prefix added to the case name to find the prediction file.
        pred_suffix: Suffix added to the case name to find the prediction file.
    """
    suffix = '.nii.gz'
    metric_values = []
    nos = []
    # Sorted so the printed warnings and the CSV row order are reproducible
    # (os.listdir order is arbitrary).
    filenames = sorted(f for f in os.listdir(gt_dir) if f.endswith(suffix))
    for filename in filenames:
        no = filename[:-len(suffix)]  # strip only the trailing extension
        gt_file = os.path.join(gt_dir, filename)
        pred_file = os.path.join(pred_dir, f"{pred_prefix}{no}{pred_suffix}.nii.gz")
        if not os.path.exists(pred_file):
            print(f"Warning: {no} not found. Skipping this sample.")
            continue
        # Load ground truth and prediction
        gt_data = nib.load(gt_file).get_fdata()
        pred_data = nib.load(pred_file).get_fdata()
        gt_sum = gt_data.sum()
        pred_sum = pred_data.sum()
        # Compute metric on batch-first, channel-first tensors: (1, 1, D, H, W)
        output = torch.from_numpy(pred_data).unsqueeze(0).unsqueeze(0)
        label = torch.from_numpy(gt_data).unsqueeze(0).unsqueeze(0)
        metric_value = metric(y_pred=output, y=label)
        # Handle NaN cases
        if not torch.isnan(metric_value):
            metric_values.append(metric_value.item())
            nos.append(no)
        else:
            print(f"Warning: Metric calculation for {no} resulted in NaN.")
            print(f"  GT sum: {gt_sum}, Pred sum: {pred_sum}")
            if gt_sum == 0 and pred_sum == 0:
                # Both empty: perfect prediction
                print(f"  → Both GT and Prediction are empty. Setting Dice to 1.0")
                metric_values.append(1.0)
                nos.append(no)
            else:
                print(f"  → Skipping this sample.")
    if not metric_values:
        print("Error: No valid metrics calculated!")
        return
    # Compute statistics
    metric_values = np.array(metric_values)
    print(f"\u2022 mean \u00B1 standard deviation: {np.mean(metric_values):.3f} \u00B1 {np.std(metric_values):.3f}")
    print(f"\u2022 [minimum, maximum]: [{np.min(metric_values):.3f}, {np.max(metric_values):.3f}]")
    # Save results to CSV
    pred_df = pd.DataFrame({"No": nos, "DSC": [f"{value:.3f}" for value in metric_values]})
    pred_df.to_csv(os.path.join(pred_dir, "DSC.csv"), index=False)

### Prepare Inputs

In [None]:
data_dir = os.path.join("LesionSegmentation", "Datasets")
model_dir_prefix = "LesionSegmentation"
model_name = "SegResNet" # any supported model name: UNet, VNet, SegResNet, AttentionUnet, UNETR, SwinUNETR
num_classes = 1 # for binary segmentation
resize_dim = None # Use padding or specify tuple for resizing
test_size = 0.2
batch_size = 5
max_epochs = 100
learning_rate = 1e-4
weight_decay = 1e-5
val_interval = 1
es_patience = 30

# Setup output directory and logging
model_dir = f"{model_dir_prefix}_{model_name}"
os.makedirs(model_dir, exist_ok=True)
log_file = os.path.join(model_dir, "Prediction.log")
# force=True replaces handlers from an earlier run of this cell. Without it,
# basicConfig is a no-op once the root logger has handlers, so changing
# model_name and re-running would keep logging into the previous model's file.
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s", force=True)
logger = logging.getLogger()

### Read Data

In [None]:
set_determinism(seed=0)
train_loader, val_loader = load_data(
    os.path.join(data_dir, "train"), batch_size, resize_dim=resize_dim, test_size=test_size,
    inference=False, model_name=model_name
)

# Check data shape
tr = first(train_loader)
img_size = tuple(tr["brain"].shape[-3:])  # spatial size of the preprocessed batch; passed to get_model
print('\nData shape for training:')
for key, value in tr.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(train_loader)}')
vl = first(val_loader)
print('\nData shape for validation:')
for key, value in vl.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(val_loader)}')

# Visualize data
sample_index = 19
slice_index = 52
_, axs = plt.subplots(1, 3, figsize=(12, 5))
# Index the dataset once: every lookup re-runs the random flips, so separate
# lookups for "brain" and "lesion" could return differently flipped volumes.
sample = train_loader.dataset[sample_index]
brain = sample["brain"].detach().cpu()
lesion = sample["lesion"].detach().cpu()
brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1))
lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1))
blended = blend_images(brain, lesion, alpha=0.5)
# blend_images returns a channel-first RGB volume. Keep all three colour
# channels (indexing [0] showed only red) and move them last for imshow.
blended_slice = torch.rot90(blended[:, :, :, slice_index], k=1, dims=(1, 2)).permute(1, 2, 0)
axs[0].imshow(brain_slice, cmap='gray')
axs[0].set_title("Brain")
axs[0].axis('off')
axs[1].imshow(lesion_slice, cmap='gray')
axs[1].set_title("Lesion")
axs[1].axis('off')
axs[2].imshow(blended_slice)
axs[2].set_title("Lesion-overlaid Brain")
axs[2].axis('off')
plt.tight_layout()
plt.show()

### Train Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(model_name, img_size, num_classes=num_classes)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)
print(f"Selected model: {model_name}")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")
# Soft Dice loss applied to sigmoid probabilities of the single-channel output
criterion = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, sigmoid=True)
metric = DiceMetric(include_background=True, reduction="mean")
# Logits -> sigmoid probabilities -> binary mask at 0.5
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(
    model_dir, model, device, train_loader, val_loader, logger,
    criterion, metric, post_pred, max_epochs, learning_rate, weight_decay, val_interval,
    es_patience=es_patience,  # use the configured patience, not the function default
)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome
sample_index = 18
slice_index = 45
model.eval()
metric.reset()
with torch.no_grad():
    # One dataset lookup so brain and lesion come from the same transform pass
    sample = val_loader.dataset[sample_index]
    brain = sample["brain"].to(device)
    lesion = sample["lesion"].to(device)
    output = model(brain.unsqueeze(0))
    output = post_pred(output).squeeze(0)
    # DiceMetric expects batch-first (B, C, D, H, W). Passing the unbatched
    # (C, D, H, W) tensors would read depth as channels and average per-slice 2D Dice.
    metric(y_pred=output.unsqueeze(0), y=lesion.unsqueeze(0))
metric_value = metric.aggregate().item()
_, axs = plt.subplots(1, 3, figsize=(12, 5))
brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1))
lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1))
output_slice = torch.rot90(output[0, :, :, slice_index], k=1, dims=(0, 1))
axs[0].imshow(brain_slice.detach().cpu(), cmap="gray")
axs[0].set_title("Brain")
axs[0].axis('off')
axs[1].imshow(lesion_slice.detach().cpu(), cmap="gray")
axs[1].set_title("Lesion")
axs[1].axis('off')
axs[2].imshow(output_slice.detach().cpu(), cmap="gray")
axs[2].set_title(f"Predicted Lesion: DSC = {metric_value:.3f}")
axs[2].axis('off')
plt.tight_layout()
plt.show()

### GradCAM

In [None]:
# Same validation sample and slice as the outcome visualization in the training cell
sample_index = 18
slice_index = 45
# Layers to probe are chosen per architecture from model_name.
# NOTE(review): get_gradcam_target_layers reads architecture-specific attributes
# (e.g. model.model, model.down_tr256); if model was wrapped in nn.DataParallel
# they live under model.module instead. Confirm before running on multiple GPUs.
target_layers = get_gradcam_target_layers(model, model_name)
apply_gradcam_to_sample(model, val_loader, target_layers, device, criterion, sample_index, slice_index)

### Inference

In [None]:
# Test loader with the same padding/resizing as training; batch_size and
# test_size are not used by load_data in inference mode (batch size is fixed to 1).
test_loader = load_data(
    os.path.join(data_dir, "test"), None, resize_dim=resize_dim, test_size=None,
    inference=True, model_name=model_name
)
# Predicted masks are saved here as NIfTI files named after each input brain volume
pred_dir = os.path.join(model_dir, "Prediction")
apply_best_model(model_dir, model, device, test_loader, post_pred, pred_dir, resize_dim)

### Assess Performance
#### Research level
- Excellent: DSC > 0.60
- Good: DSC 0.50-0.60
- Acceptable: DSC 0.40-0.50
- Poor: DSC < 0.40
#### Practical level
- Excellent: DSC > 0.50
- Good: DSC 0.40-0.50
- Acceptable: DSC 0.30-0.40
- Poor: DSC < 0.30

In [None]:
# Ground-truth test masks; predictions in pred_dir are matched to them by filename
gt_dir = os.path.join(data_dir, "test", "Lesion")
print('DSC on test set:')
# Prints mean/std and min/max DSC and writes per-case scores to pred_dir/DSC.csv
calculate_test_metric(metric, gt_dir, pred_dir, pred_prefix="", pred_suffix="")