# UNet++ for Brain Tumor Segmentation

## 1. Set up

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchio as tio
import torchvision.datasets as dset
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, sampler
from models.unetplusplus import Generic_UNetPlusPlus

In [None]:
def softmax_helper(x):
    """Return the softmax of ``x`` taken over dimension 1."""
    return F.softmax(x, dim=1)

class InitWeights_He(object):
    """Callable that applies Kaiming-normal initialisation to convolution layers.

    An instance is called with a single module. 2-D/3-D convolutions and
    transposed convolutions get Kaiming-normal weights (using ``neg_slope`` as
    the ``a`` argument) and a zeroed bias; every other module is left untouched.
    """

    # Layer types that are (re-)initialised; anything else is ignored.
    _CONV_TYPES = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)

    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if not isinstance(module, self._CONV_TYPES):
            return

        module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
        bias = module.bias
        if bias is not None:
            module.bias = nn.init.constant_(bias, 0)

## 2. Load Data

In [None]:
# Each .pt file is unpacked into an (inputs, labels) pair for its split.
# NOTE(review): the tensor shapes/dtypes stored in these files are not visible
# here; downstream code presumably expects (N, 3, H, W) images — TODO confirm.
X_train, y_train = torch.load("data/brats_train.pt")
X_val, y_val = torch.load("data/brats_val.pt")
X_test, y_test = torch.load("data/brats_test.pt")

## 3. Preprocess

In [None]:
class BraTSDataset(Dataset):
    """Index-based dataset that pairs image tensors with segmentation masks.

    Args:
        X: Indexable collection of image tensors.
        y: Indexable collection of mask tensors, aligned with ``X``.
        transform: Optional TorchIO transform applied to each sample. When
            falsy, samples are returned without going through TorchIO.

    Each item is an ``(image, mask)`` tuple: the image cast to float, the mask
    cast to long with labels clamped into ``[0, 3]``. A 2-D mask gets a leading
    channel dimension added.
    """

    def __init__(self, X, y, transform=None):
        self.X, self.y, self.transform = X, y, transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx].float()
        # Labels outside 0..3 are forced into range rather than rejected.
        labels = self.y[idx].long().clamp(0, 3)

        if self.transform:
            # Wrap both tensors in a TorchIO Subject with an extra leading
            # dimension, then strip that dimension again afterwards.
            # NOTE(review): assumes the resulting shapes are what TorchIO
            # expects for images/label maps — TODO confirm against the data.
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=image[None]),
                mask=tio.LabelMap(tensor=labels[None]),
            )
            result = self.transform(subject)
            image = result.image.data.squeeze(0)
            labels = result.mask.data.squeeze(0)

        if labels.ndim == 2:
            # Ensure the mask carries an explicit channel dimension.
            labels = labels[None]

        return image, labels

In [None]:
# Intensity normalisation: rescale each sample to [0, 1] after clipping to the
# 0.5th-99.5th percentile range.
transform = tio.Compose([
    tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.5, 99.5)),
])

# RescaleIntensity is a normalisation step rather than a random augmentation,
# so it is applied to every split: the model must see validation/test inputs
# on the same intensity scale it was trained on.
train_dataset = BraTSDataset(X_train, y_train, transform=transform)
val_dataset = BraTSDataset(X_val, y_val, transform=transform)
test_dataset = BraTSDataset(X_test, y_test, transform=transform)

# Quick sanity check on the first training sample.
try:
    img, mask = train_dataset[0]
    print(f'Sample 0 - Image shape: {img.shape}, Mask shape: {mask.shape}')
    print(f'Image dtype: {img.dtype}, Mask dtype: {mask.dtype}')
    print(f'Image value range: [{img.min()}, {img.max()}]')
    print(f'Unique mask values: {torch.unique(mask)}')
except IndexError:
    print("Could not get sample from dataset. Check data loading.")

In [None]:
def plot_samples(dataset, num_samples=4):
    """Plot the first ``num_samples`` items of ``dataset`` in a grid.

    Each row shows the three image channels followed by the segmentation mask.

    Args:
        dataset: Indexable dataset returning ``(img, mask)`` tensor pairs.
            Assumes ``img`` is (3, H, W) with channel order FLAIR, T1CE, T2 and
            ``mask`` squeezes to (H, W) — TODO confirm against the data files.
        num_samples: Number of rows to draw; clamped to ``len(dataset)``.

    Returns:
        None. Nothing is drawn (and no figure is created) when there is
        nothing to plot.
    """
    available = len(dataset)
    if num_samples > available:
        print(f"Warning: Requested {num_samples} samples, but dataset only has {available}.")
        num_samples = available
    if num_samples <= 0:
        print("No samples to plot.")
        return

    # Create the figure only once the final row count is known, so no empty
    # or oversized figure is left behind. squeeze=False keeps `axes` 2-D even
    # for a single row, so axes[row, col] indexing always works.
    _, axes = plt.subplots(
        num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False
    )  # 3 inputs + 1 mask

    channel_titles = ('FLAIR', 'T1CE', 'T2')

    for row in range(num_samples):
        img, mask = dataset[row]
        img = img.cpu().numpy()
        mask = mask.squeeze().cpu().numpy()  # Squeeze channel dim for plotting

        for col, title in enumerate(channel_titles):
            ax = axes[row, col]
            ax.imshow(img[col], cmap='gray')
            ax.set_title(f'Sample {row} - {title}')
            ax.axis('off')

        # Fixed vmin/vmax so labels 0..3 map to the same colours in every row.
        seg_ax = axes[row, 3]
        seg_ax.imshow(mask, cmap='tab10', vmin=0, vmax=3)
        seg_ax.set_title(f'Sample {row} - Segmentation')
        seg_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visual sanity check: plot the first four samples (indices 0-3) of the
# training dataset, i.e. after the dataset's transform has been applied.
plot_samples(train_dataset, num_samples=4)

## 4. Model Definition (Generic UNet++)

In [None]:
print("Generic_UNetPlusPlus model imported from models.unetplusplus")

input_channels = 3
base_num_features = 32 # Initial number of filters (nnU-Net default is 30 or 32)
num_classes = 4
num_pool = 5 # Number of pooling layers (corresponds to U-Net depth)
# One 2x2 pooling kernel per pooling step, and one 3x3 conv kernel per
# resolution level (num_pool + 1 levels).
pool_op_kernel_sizes = [(2, 2)] * num_pool
conv_kernel_sizes = [(3, 3)] * (num_pool + 1)

# NOTE(review): final_nonlin=softmax_helper presumably makes the network return
# softmax probabilities, while the loss/metric code in this notebook applies
# softmax / CrossEntropyLoss (which expects raw logits) again. Confirm against
# Generic_UNetPlusPlus and consider an identity final_nonlin for training.
unet_plusplus_model = Generic_UNetPlusPlus(
    input_channels=input_channels,
    base_num_features=base_num_features,
    num_classes=num_classes,
    num_pool=num_pool,
    num_conv_per_stage=2,
    feat_map_mul_on_downscale=2,
    conv_op=nn.Conv2d,
    norm_op=nn.BatchNorm2d,
    norm_op_kwargs={'eps': 1e-5, 'affine': True},
    dropout_op=nn.Dropout2d,
    dropout_op_kwargs={'p': 0, 'inplace': True}, # p=0: dropout effectively disabled
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True},
    deep_supervision=True, # UNet++ uses deep supervision
    dropout_in_localization=False,
    final_nonlin=softmax_helper,
    weightInitializer=InitWeights_He(1e-2),
    pool_op_kernel_sizes=pool_op_kernel_sizes,
    conv_kernel_sizes=conv_kernel_sizes,
    upscale_logits=False,
    convolutional_pooling=True,
    convolutional_upsampling=True,
    max_num_features=None
)

print("Generic_UNetPlusPlus model instantiated.")

## 5. Loss Function and Metrics

In [None]:
def dice_coeff_multiclass(pred, target, smooth=1e-6):
    """Hard (argmax-based) Dice coefficient averaged over classes.

    Args:
        pred: Per-class scores of shape (B, C, H, W).
        target: Integer class labels of shape (B, 1, H, W) or (B, H, W).
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Tensor of shape (B,): the mean Dice over the C classes for each sample.
    """
    n_classes = pred.shape[1]

    # Winning class per pixel, then one-hot encode both maps as (B, C, H, W).
    hard_labels = F.softmax(pred, dim=1).argmax(dim=1)
    pred_onehot = F.one_hot(hard_labels, n_classes).permute(0, 3, 1, 2).float()
    true_onehot = F.one_hot(target.squeeze(1), n_classes).permute(0, 3, 1, 2).float()

    spatial = (2, 3)
    overlap = (pred_onehot * true_onehot).sum(dim=spatial)
    total = pred_onehot.sum(dim=spatial) + true_onehot.sum(dim=spatial)

    per_class = (2. * overlap + smooth) / (total + smooth)
    return per_class.mean(dim=1)

class DiceLoss(nn.Module):
    """Differentiable (soft) multi-class Dice loss.

    The Dice overlap is computed from softmax probabilities rather than from
    argmax-ed hard masks. An argmax-based Dice is piecewise constant, so it
    would contribute no gradient and could not drive training.

    Args:
        smooth: Constant added to numerator and denominator to avoid 0/0 for
            classes absent from both prediction and target.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - soft Dice`` averaged over classes and batch.

        Args:
            pred: Raw class scores of shape (B, C, H, W).
            target: Integer class labels of shape (B, 1, H, W) or (B, H, W).

        Returns:
            Scalar tensor that keeps the autograd graph of ``pred``.
        """
        num_classes = pred.shape[1]
        probs = F.softmax(pred, dim=1)
        target_one_hot = (
            F.one_hot(target.squeeze(1).long(), num_classes)
            .permute(0, 3, 1, 2)
            .to(probs.dtype)
        )

        spatial_dims = (2, 3)
        intersection = torch.sum(probs * target_one_hot, dim=spatial_dims)
        denominator = torch.sum(probs, dim=spatial_dims) + torch.sum(target_one_hot, dim=spatial_dims)

        dice_per_class = (2. * intersection + self.smooth) / (denominator + self.smooth)
        # Mean over classes per sample, then mean over the batch.
        return (1 - dice_per_class.mean(dim=1)).mean()

class DiceCELoss(nn.Module):
    """Weighted sum of a Dice loss term and a cross-entropy term.

    Args:
        dice_weight: Multiplier for the Dice term.
        ce_weight: Multiplier for the cross-entropy term.
        class_weights: Optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss``.
        ignore_index: Label value ignored by the cross-entropy term.
    """

    def __init__(self, dice_weight=0.5, ce_weight=0.5, class_weights=None, ignore_index=-100):
        super().__init__()
        self.dice_weight, self.ce_weight = dice_weight, ce_weight
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_index)

    def forward(self, pred, target):
        """Return ``dice_weight * dice + ce_weight * ce`` for one batch."""
        # Cross-entropy takes class indices without the channel dimension.
        ce_target = target.squeeze(1).long()
        dice_term = self.dice_weight * self.dice_loss(pred, target)
        ce_term = self.ce_weight * self.ce_loss(pred, ce_target)
        return dice_term + ce_term

def iou_score_multiclass(pred, target, smooth=1e-6):
    """Hard (argmax-based) Intersection over Union averaged over classes.

    Args:
        pred: Per-class scores of shape (B, C, H, W).
        target: Integer class labels of shape (B, 1, H, W) or (B, H, W).
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Tensor of shape (B,): the mean IoU over the C classes for each sample.
    """
    n_classes = pred.shape[1]

    # Winning class per pixel, then one-hot encode both maps as (B, C, H, W).
    hard_labels = F.softmax(pred, dim=1).argmax(dim=1)
    pred_onehot = F.one_hot(hard_labels, n_classes).permute(0, 3, 1, 2).float()
    true_onehot = F.one_hot(target.squeeze(1), n_classes).permute(0, 3, 1, 2).float()

    spatial = (2, 3)
    overlap = (pred_onehot * true_onehot).sum(dim=spatial)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    combined = pred_onehot.sum(dim=spatial) + true_onehot.sum(dim=spatial) - overlap

    per_class = (overlap + smooth) / (combined + smooth)
    return per_class.mean(dim=1)  # Mean IoU per sample in batch

def compute_deep_supervision_loss(criterion, outputs, target):
    """Combine the losses of all deep-supervision heads into one value.

    When ``outputs`` is a tuple, head ``i`` is weighted by ``1 / 2**i``
    (normalised so the weights sum to 1), so the first element — treated as
    the final output — dominates. Heads whose spatial size differs from the
    target are bilinearly resized to the target size before the criterion is
    applied.

    Args:
        criterion: Callable ``criterion(output, target)`` returning a loss.
        outputs: Either a single output tensor or a tuple of output tensors.
        target: Target tensor; its dimensions from index 2 on are the
            reference spatial size.

    Returns:
        The weighted loss sum for a tuple, otherwise ``criterion(outputs, target)``.
    """
    # Without deep supervision there is a single output: plain loss.
    if not isinstance(outputs, tuple):
        return criterion(outputs, target)

    # Halving weights 1, 1/2, 1/4, ... normalised to sum to one.
    level_weights = 0.5 ** np.arange(len(outputs))
    level_weights = level_weights / level_weights.sum()

    spatial_size = target.shape[2:]
    total_loss = 0
    for weight, output in zip(level_weights, outputs):
        if output.shape[2:] != spatial_size:
            output = F.interpolate(output, size=spatial_size, mode='bilinear', align_corners=False)
        total_loss += weight * criterion(output, target)
    return total_loss

## 6. Training and Testing Functions

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, dtype, use_deep_supervision=True):
    """Train ``model`` for one pass over ``loader`` and report averaged metrics.

    Args:
        model: Network to train; may return a tensor or a tuple of tensors
            whose first element is the final output.
        loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss callable ``criterion(output, labels)``.
        device: Device the batches are moved to.
        dtype: Floating dtype the inputs are cast to.
        use_deep_supervision: When True and the model returns a tuple, the loss
            is accumulated over all outputs; when False only the final output
            is used for the loss.

    Returns:
        Tuple ``(avg_loss, avg_iou, avg_dice)`` averaged over batches.

    Raises:
        ValueError: If ``loader`` contains no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError("train_epoch received an empty loader; cannot average over zero batches.")

    model.train()
    total_loss = 0.0
    total_iou = 0.0
    total_dice = 0.0

    for inputs, labels in loader:
        inputs = inputs.to(device, dtype=dtype)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        if isinstance(outputs, tuple):
            # A tuple output always carries the final prediction first; take
            # it even when deep supervision is disabled so the criterion and
            # metrics never receive the whole tuple.
            final_output = outputs[0]
            if use_deep_supervision:
                loss = compute_deep_supervision_loss(criterion, outputs, labels)
            else:
                loss = criterion(final_output, labels)
        else:
            final_output = outputs
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            iou = iou_score_multiclass(final_output, labels).mean()
            dice = dice_coeff_multiclass(final_output, labels).mean()

        total_loss += loss.item()
        total_iou += iou.item()
        total_dice += dice.item()

    avg_loss = total_loss / num_batches
    avg_iou = total_iou / num_batches
    avg_dice = total_dice / num_batches
    return avg_loss, avg_iou, avg_dice

def evaluate(model, loader, criterion, device, dtype, use_deep_supervision=True):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; may return a tensor or a tuple of tensors
            whose first element is the final output.
        loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: Loss callable ``criterion(output, labels)``.
        device: Device the batches are moved to.
        dtype: Floating dtype the inputs are cast to.
        use_deep_supervision: When True and the model returns a tuple, the loss
            is accumulated over all outputs; when False only the final output
            is used for the loss.

    Returns:
        Tuple ``(avg_loss, avg_iou, avg_dice)`` averaged over batches.

    Raises:
        ValueError: If ``loader`` contains no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError("evaluate received an empty loader; cannot average over zero batches.")

    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    total_dice = 0.0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, dtype=dtype)
            labels = labels.to(device)

            outputs = model(inputs)

            if isinstance(outputs, tuple):
                # A tuple output always carries the final prediction first;
                # take it even when deep supervision is disabled so the
                # criterion and metrics never receive the whole tuple.
                final_output = outputs[0]
                if use_deep_supervision:
                    loss = compute_deep_supervision_loss(criterion, outputs, labels)
                else:
                    loss = criterion(final_output, labels)
            else:
                final_output = outputs
                loss = criterion(outputs, labels)

            iou = iou_score_multiclass(final_output, labels).mean()
            dice = dice_coeff_multiclass(final_output, labels).mean()

            total_loss += loss.item()
            total_iou += iou.item()
            total_dice += dice.item()

    avg_loss = total_loss / num_batches
    avg_iou = total_iou / num_batches
    avg_dice = total_dice / num_batches
    return avg_loss, avg_iou, avg_dice

## 7. Training Loop

In [None]:
import os  # only needed in this cell, for checkpoint directory handling

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
save_path = 'checkpoints/unetplusplus_brats_dce_adam.pth'
print_every = 1
use_deep_supervision = True

batch_size = 4
num_epochs = 10
learning_rate = 1e-4

model = unet_plusplus_model
model = model.to(device)

# Optional: Define class weights (example, adjust based on dataset analysis)
class_weights = torch.tensor([1.0, 10.0, 5.0, 8.0]).to(device, dtype=dtype) # Example weights

criterion = DiceCELoss(dice_weight=0.5, ce_weight=0.5, class_weights=class_weights)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Only the training loader is shuffled; validation/test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print(f"Using device: {device}")
print(f"Model: Generic UNet++ (Deep Supervision: {use_deep_supervision})")
print(f"Input channels: {input_channels}, Num classes: {num_classes}")
print(f"Batch size: {batch_size}, Epochs: {num_epochs}, LR: {learning_rate}")
print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# Create the checkpoint directory once, before training starts. The guard
# covers a save_path without a directory part, where dirname() is '' and
# os.makedirs('') would raise.
checkpoint_dir = os.path.dirname(save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# --- Training Loop ---
best_val_iou = -1.0  # below any IoU value, so the first finite result is always saved
train_losses, val_losses = [], []
train_ious, val_ious = [], []

for epoch in range(num_epochs):
    train_loss, train_iou, train_dice = train_epoch(model, train_loader, optimizer, criterion, device, dtype, use_deep_supervision)
    val_loss, val_iou, val_dice = evaluate(model, val_loader, criterion, device, dtype, use_deep_supervision)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_ious.append(train_iou)
    val_ious.append(val_iou)

    if (epoch + 1) % print_every == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, Train Dice: {train_dice:.4f}')
        print(f'          | Val Loss:   {val_loss:.4f}, Val IoU:   {val_iou:.4f}, Val Dice:   {val_dice:.4f}')

    # Save model if validation IoU improves
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), save_path)
        print(f'          | Model saved to {save_path} (Val IoU improved to {best_val_iou:.4f})')

print("\nTraining finished.")

## 8. Evaluation on Test Set

In [None]:
print(f"Loading best model from {save_path}")

# Rebuild the network with the same hyper-parameters used for training so the
# saved state dict matches the module layout.
best_model_kwargs = dict(
    input_channels=input_channels,
    base_num_features=base_num_features,
    num_classes=num_classes,
    num_pool=num_pool,
    num_conv_per_stage=2,
    feat_map_mul_on_downscale=2,
    conv_op=nn.Conv2d,
    norm_op=nn.BatchNorm2d,
    norm_op_kwargs={'eps': 1e-5, 'affine': True},
    dropout_op=nn.Dropout2d,
    dropout_op_kwargs={'p': 0, 'inplace': True},
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True},
    deep_supervision=use_deep_supervision,
    dropout_in_localization=False,
    final_nonlin=softmax_helper,
    weightInitializer=InitWeights_He(1e-2),
    pool_op_kernel_sizes=pool_op_kernel_sizes,
    conv_kernel_sizes=conv_kernel_sizes,
    upscale_logits=False,
    convolutional_pooling=True,
    convolutional_upsampling=True,
    max_num_features=None,
)
best_model = Generic_UNetPlusPlus(**best_model_kwargs)

# Loading and evaluation share one try block: any failure is reported via
# print instead of propagating, so later cells still run.
# NOTE(review): if loading fails, `best_model` stays defined with its freshly
# initialised weights — later cells would silently use that model.
try:
    checkpoint_state = torch.load(save_path, map_location=device)
    best_model.load_state_dict(checkpoint_state)
    best_model.to(device)
    print("Model loaded successfully.")

    test_metrics = evaluate(best_model, test_loader, criterion, device, dtype, use_deep_supervision)
    test_loss, test_iou, test_dice = test_metrics
    print("\n--- Test Set Evaluation ---")
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test IoU:  {test_iou:.4f}')
    print(f'Test Dice: {test_dice:.4f}')
except FileNotFoundError:
    print(f"Error: Saved model file not found at {save_path}. Cannot evaluate.")
except Exception as e:
    print(f"An error occurred during model loading or evaluation: {e}")

## 9. Visualize Segmentation Results

In [None]:
def _to_unit_range(image):
    """Min-max scale a NumPy image to [0, 1] for display; constant images become zeros."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi <= lo:
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def _blend_labels(background, label_map, colors, classes=(1, 2, 3), alpha=0.5):
    """Return an RGB copy of ``background`` with each class in ``classes`` tinted by its colour."""
    overlay = np.stack([background] * 3, axis=-1)
    for c in classes:
        region = label_map == c
        overlay[region] = (1 - alpha) * overlay[region] + alpha * colors[c][:3]
    return np.clip(overlay, 0, 1)


def visualize_segmentation_overlay(model, dataloader, num_samples=4, device=torch.device('cuda'), dtype=torch.float32, use_deep_supervision=True):
    """Plot ground-truth and predicted segmentation overlays for random samples.

    For each of ``num_samples`` randomly chosen items of ``dataloader.dataset``
    a row of three panels is drawn: the first image channel, that channel with
    the ground-truth classes 1-3 tinted, and the same with the predicted
    classes tinted.

    Args:
        model: Network returning per-class scores, or a tuple whose first
            element is the final output.
        dataloader: Object exposing a ``dataset`` attribute; only the dataset
            is used, samples are indexed directly.
        num_samples: Number of rows to draw; clamped to the dataset size.
        device: Device the input is moved to before the forward pass.
        dtype: Floating dtype the input is cast to.
        use_deep_supervision: Kept for interface compatibility. A tuple output
            is always reduced to its first element, whatever this flag says.

    Returns:
        None. No figure is created when there is nothing to plot.
    """
    model.eval()

    dataset = dataloader.dataset
    available = len(dataset)
    if num_samples > available:
        # np.random.choice(..., replace=False) raises if asked for more
        # samples than exist, so clamp first.
        print(f"Warning: Requested {num_samples} samples, but dataset only has {available}.")
        num_samples = available
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single row.
    _, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    indices = np.random.choice(available, num_samples, replace=False)
    # Colour table is loop-invariant: row c is the tint for class c.
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    with torch.no_grad():
        for row, idx in enumerate(indices):
            image, mask = dataset[idx]  # Get raw sample
            image_tensor = image.unsqueeze(0).to(device, dtype=dtype)

            outputs = model(image_tensor)
            final_output = outputs[0] if isinstance(outputs, tuple) else outputs

            pred_prob = torch.softmax(final_output, dim=1)
            pred_mask = torch.argmax(pred_prob, dim=1).squeeze(0).cpu().numpy()  # (H, W)

            # The blending below mixes intensities with RGB colours in [0, 1],
            # so bring the background into that range first; raw intensities
            # would otherwise saturate when clipped.
            image_np = _to_unit_range(image[0].cpu().numpy())
            mask_np = mask.squeeze(0).cpu().numpy()

            # Original Image (FLAIR)
            axes[row, 0].imshow(image_np, cmap='gray')
            axes[row, 0].set_title(f"Sample {idx} - Original (FLAIR)")
            axes[row, 0].axis('off')

            # Ground Truth Overlay
            axes[row, 1].imshow(_blend_labels(image_np, mask_np, colors))
            axes[row, 1].set_title("Ground Truth Overlay")
            axes[row, 1].axis('off')

            # Predicted Overlay
            axes[row, 2].imshow(_blend_labels(image_np, pred_mask, colors))
            axes[row, 2].set_title("Predicted Overlay")
            axes[row, 2].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
print("Visualizing results on test set...")
# visualize_segmentation_overlay only reads `vis_loader.dataset` and indexes
# samples itself, so the loader here mainly serves as a handle to the test set.
vis_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) # Use batch_size=1 for visualization
try:
    visualize_segmentation_overlay(best_model, vis_loader, num_samples=4, device=device, dtype=dtype, use_deep_supervision=use_deep_supervision)
except NameError:
    # Reached when a name used above (e.g. `best_model`) was never defined,
    # typically because an earlier cell was not run.
    print("Could not visualize: 'best_model' not defined. Was training successful and the model loaded?")
except Exception as e:
    # Any other failure is reported instead of propagating.
    print(f"An error occurred during visualization: {e}")