In [None]:
#!pip install torch
#!pip install torchvision

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import os
import re
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import zipfile
import matplotlib.pyplot as plt
import torchvision.transforms.functional as transForm
import tqdm

In [None]:
# Unpack the dataset archive into the current working directory
with zipfile.ZipFile('/content/pack.zip', mode='r') as archive:
    archive.extractall()

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# Root folder of the preprocessed CT slices; the train/val/test sub-folders are read below
data_path = '/content/pack/processed_data/ct_256'
# Google Drive alternative (needs the commented-out drive.mount cell).
# NOTE(review): this path points at the .zip archive, not an extracted folder.
# data_path = '/content/drive/My Drive/Imperial - AI MSc/Medical Imaging/pack.zip'

In [None]:
# Sanity check: show what the data directory contains
file_list = [entry for entry in os.listdir(data_path)]
print(file_list)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# **1. Data Loading**

In [None]:
class SliceAugmentation:
    """Random intensity and spatial augmentation for a 2D slice and its label.

    With probability ``p_intensity`` the image receives a random contrast change
    followed by a random brightness change. The two factors are drawn uniformly
    from ``contrast_range`` and ``brightness_range``. With probability ``p_flip``
    the image and the label are flipped together along their last (width) axis.
    The label is never intensity-transformed.

    NOTE(review): torchvision's intensity functions may clip float tensors to
    [0, 1]. Confirm the slices are normalised to that range before relying on
    the contrast/brightness augmentation.

    Args:
        p_intensity: probability of applying the contrast + brightness pair.
        contrast_range: ``(low, high)`` range of the contrast factor.
        brightness_range: ``(low, high)`` range of the brightness factor.
        p_flip: probability of the horizontal flip.
    """
    def __init__(self, p_intensity=0.5, contrast_range=(0.75, 1.25),
                 brightness_range=(0.75, 1.25), p_flip=0.5):
        self.p_intensity = p_intensity
        self.contrast_range = contrast_range
        self.brightness_range = brightness_range
        self.p_flip = p_flip

    def __call__(self, image, label):
        """Return the augmented ``(image, label)`` pair (new tensors; inputs are not modified)."""
        # Intensity variations (applied together)
        if torch.rand(1) < self.p_intensity:
            contrast_factor = torch.empty(1, dtype=torch.float32).uniform_(*self.contrast_range).item()
            image = transForm.adjust_contrast(image, contrast_factor)

            brightness_factor = torch.empty(1, dtype=torch.float32).uniform_(*self.brightness_range).item()
            image = transForm.adjust_brightness(image, brightness_factor)

        # Horizontal flip: reverse the last axis of both tensors so the mask stays aligned
        if torch.rand(1) < self.p_flip:
            image = torch.flip(image, dims=(-1,))
            label = torch.flip(label, dims=(-1,))

        return image, label

class SingleSliceDataset(Dataset):
    """Dataset of 2D slices, each stored as a ``.npz`` file with ``image`` and ``label`` arrays.

    Args:
        root_dir: directory holding the ``.npz`` files (not searched recursively).
        testing: stored as ``self.testing``; not used inside this class.
        transform: optional callable ``(image, label) -> (image, label)`` applied
            to the tensors whenever it is given (e.g. ``SliceAugmentation`` for
            the training split).

    File names are sorted so an index maps to the same slice on every run
    (``os.listdir`` returns entries in arbitrary order).
    """
    def __init__(self, root_dir, testing=False, transform=None):
        self.root_dir = root_dir
        self.testing = testing
        self.transform = transform
        self.file_list = sorted(f for f in os.listdir(root_dir) if f.endswith('.npz'))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float32 tensor and an int64 tensor.

        A 2-D image gets a leading channel axis ([H, W] -> [1, H, W]).
        The label is expected to be [H, W].
        """
        file_path = os.path.join(self.root_dir, self.file_list[idx])
        # np.load on an .npz returns an NpzFile that keeps the file open until
        # closed; the context manager releases the handle once the arrays are read
        with np.load(file_path) as slice_data:
            image = slice_data['image']
            label = slice_data['label']

        # Add a channel dimension only when the slice does not already have one
        if image.ndim == 2:
            image = image[np.newaxis]

        image_tensor = torch.from_numpy(image).float()
        label_tensor = torch.from_numpy(label).long()

        if self.transform is not None:
            image_tensor, label_tensor = self.transform(image_tensor, label_tensor)

        return image_tensor, label_tensor


# Split directories, expected to hold one .npz file per 2D slice
npz_directory_train = data_path + "/train/npz"
npz_directory_val = data_path + "/val/npz"
npz_directory_test = data_path + "/test/npz"

BATCH_SIZE = 16
# 12 worker processes oversubscribe small runtimes (e.g. 2-CPU Colab); never
# start more workers than the machine reports CPUs
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Augmentation is applied to the training split only
train_transform = SliceAugmentation()
train_dataset = SingleSliceDataset(npz_directory_train, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_dataset = SingleSliceDataset(npz_directory_val, transform=None)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_dataset = SingleSliceDataset(npz_directory_test, transform=None)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# **2. Implementation for Pure Segmentation**

### 2.1 Model & Loss Functions Definition

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size. Channels go ``in_channels -> out_channels``
    in the first stage and stay at ``out_channels`` in the second.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept as a single Sequential named `conv` so state-dict keys stay conv.<idx>.*
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """Plain U-Net for 2D multi-class segmentation.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output classes (channels of the returned logits).
        features: channel widths of the encoder levels. Each level is followed
            by a 2x max-pool, and the bottleneck has ``2 * features[-1]`` channels.

    ``forward`` returns logits with ``out_channels`` channels at the spatial
    size of the input. Sizes that do not halve evenly at every level are
    handled by resizing the upsampled map to its skip connection.
    """
    def __init__(self, in_channels=1, out_channels=8, features=(64, 128, 256, 512, 1024)):
        super().__init__()

        # Encoder
        self.encoders = nn.ModuleList()
        for feature in features:
            self.encoders.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Pooling layer (shared by all encoder levels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: transposed conv halves channels and doubles resolution, then
        # DoubleConv fuses it with the matching skip connection
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for feature in reversed(features):
            self.upconvs.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.decoders.append(DoubleConv(feature * 2, feature))

        # Final 1x1 projection to class logits
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder path
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path (deepest skip first)
        skip_connections = skip_connections[::-1]
        for i in range(len(self.decoders)):
            x = self.upconvs[i](x)
            skip_connection = skip_connections[i]
            # Max-pooling floors odd sizes, so the upsampled map can be one pixel
            # smaller than its skip; resize it so the concatenation is valid
            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = F.interpolate(x, size=skip_connection.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat((skip_connection, x), dim=1)
            x = self.decoders[i](x)

        return self.final_conv(x)

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed on softmax probabilities.

    The Dice score is computed per class, with sums taken over batch, height
    and width, and then averaged over classes. The loss is ``1 - mean Dice``.

    Args:
        num_classes: number of classes used to one-hot encode the targets; must
            equal the channel dimension of the logits.
        smooth: additive smoothing in numerator and denominator.
    """
    def __init__(self, num_classes, smooth=1e-7):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.num_classes = num_classes

    def forward(self, logits, targets):
        """
        Args:
            logits: raw scores [B, C, H, W] with C == num_classes.
            targets: integer labels [B, H, W] with values in [0, num_classes-1].

        Returns:
            Scalar tensor ``1 - mean_c Dice_c``.

        Raises:
            ValueError: if the channel dimension of ``logits`` differs from ``num_classes``.
        """
        if logits.shape[1] != self.num_classes:
            raise ValueError(
                f"logits have {logits.shape[1]} channels but DiceLoss expects {self.num_classes} classes"
            )

        # Convert targets to one-hot: [B, H, W] -> [B, C, H, W]
        targets_onehot = F.one_hot(targets, num_classes=self.num_classes).permute(0, 3, 1, 2).float()

        # Softmax probabilities
        probs = F.softmax(logits, dim=1)

        # Sum over batch, height, width -> one value per class
        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_onehot, dim=dims)
        cardinality = torch.sum(probs + targets_onehot, dim=dims)

        # Dice score and loss
        dice_score = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice_score.mean()

class FocalLoss(nn.Module):
    """Focal loss built on per-pixel cross-entropy.

    Each pixel contributes ``alpha * (1 - p_t) ** gamma * CE``, where ``p_t`` is
    the softmax probability of the true class (recovered as ``exp(-CE)``). The
    result is averaged over all pixels. ``alpha`` is a single scalar applied to
    every class.
    """
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        per_pixel_ce = F.cross_entropy(logits, targets, reduction='none')
        p_true = per_pixel_ce.neg().exp()
        modulating = (1 - p_true).pow(self.gamma)
        return (self.alpha * modulating * per_pixel_ce).mean()

class HybridLoss(nn.Module):
    """Weighted sum of Dice, focal and cross-entropy losses for multi-class segmentation.

    Args:
        num_classes: number of classes (passed to ``DiceLoss``).
        dice_w, focal_w, cross_w: weights of the three terms. A term whose
            weight is 0 is not computed at all. With the defaults the focal term
            is skipped, which saves one extra cross-entropy pass per step.
    """
    def __init__(self, num_classes, dice_w = 1.0, focal_w = 0.0, cross_w = 1.0):
        super().__init__()
        self.dice = DiceLoss(num_classes)
        self.focal = FocalLoss()
        self.crosse = nn.CrossEntropyLoss()
        self.dice_w = dice_w
        self.focal_w = focal_w
        self.cross_w = cross_w

    def forward(self, output, target):
        """Return the weighted loss for logits ``output`` [B, C, H, W] and labels ``target`` [B, H, W]."""
        weighted_terms = (
            (self.dice_w, self.dice),
            (self.focal_w, self.focal),
            (self.cross_w, self.crosse),
        )
        total = None
        for weight, loss_fn in weighted_terms:
            if weight != 0:
                term = loss_fn(output, target) * weight
                total = term if total is None else total + term
        if total is None:
            # Every weight is zero: return a zero that still depends on `output`
            # so loss.backward() keeps working
            total = output.sum() * 0.0
        return total

### 2.2 Training & Testing

In [None]:
# Initialize model and loss
# Baseline (deterministic) U-Net. HybridLoss defaults weight Dice 1.0, focal 0.0
# and cross-entropy 1.0. The LR is reduced when the validation loss stops improving.
model_pure = UNet().to(device)
criterion = HybridLoss(num_classes=8)
optimizer = torch.optim.AdamW(model_pure.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

num_epochs=25

for epoch in range(num_epochs):
    # Training
    model_pure.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model_pure(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch value is a per-sample mean
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_dataset)

    # Validation
    model_pure.eval()
    val_loss = 0.0
    dice_score_total = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model_pure(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            # Calculate softmax probabilities and one-hot encoded labels
            # NOTE: this is a *soft* Dice on probabilities (no argmax), computed
            # per batch and averaged over batches weighted by batch size
            probs = F.softmax(outputs, dim=1)
            targets_one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()

            dims = (0, 2, 3)
            intersection = torch.sum(probs * targets_one_hot, dims)
            cardinality = torch.sum(probs + targets_one_hot, dims)
            dice = (2. * intersection + 1e-7) / (cardinality + 1e-7)
            dice_score_total += dice.mean().item() * images.size(0)

    val_loss /= len(val_dataset)
    val_dice = dice_score_total / len(val_dataset)
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}')
    print('-'*20)

    # Plateau scheduler driven by validation loss
    scheduler.step(val_loss)

# Test Evaluation
# Uses the weights from the final epoch (no best-checkpoint selection) and the
# same soft-Dice computation as validation
model_pure.eval()
test_loss = 0.0
dice_score_total = 0.0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_pure(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item() * images.size(0)

        # Calculate softmax probabilities and one-hot encoded labels
        probs = F.softmax(outputs, dim=1)
        targets_one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()

        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_one_hot, dims)
        cardinality = torch.sum(probs + targets_one_hot, dims)
        dice = (2. * intersection + 1e-7) / (cardinality + 1e-7)
        dice_score_total += dice.mean().item() * images.size(0)

test_loss /= len(test_dataset)
test_dice = dice_score_total / len(test_dataset)
print(f'Test Loss: {test_loss:.4f} | Test Dice: {test_dice:.4f}')

# Save trained model
torch.save(model_pure.state_dict(), 'unet_segmentation_trained.pth')

# **3. Implementation for Uncertainty-aware Segmentation**

### 3.1 Model & Loss Functions Definition

In [None]:
class DoubleConvUnc(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU -> Dropout2d) stages.

    Dropout2d (channel dropout, probability ``dropout_p``) follows each stage,
    which allows MC-dropout sampling at inference time. padding=1 keeps the
    spatial size.
    """
    def __init__(self, in_channels, out_channels, dropout_p=0.2):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout_p),
            ]
        # One Sequential named `conv` so checkpoint keys stay conv.<idx>.*
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNetUncertainity(nn.Module):
    """U-Net with dropout in every conv block and two 1x1 output heads.

    ``forward`` returns ``(logits, var)``. ``logits`` are per-class scores and
    ``var`` is a strictly positive per-class variance (softplus + 1e-7) from a
    separate head. The dropout layers allow MC-dropout sampling at inference.

    Args:
        in_channels: input image channels.
        out_channels: number of classes (channels of both heads).
        features: encoder widths; the bottleneck has ``2 * features[-1]`` channels.
        dropout_p: dropout probability used in every ``DoubleConvUnc``. It is not
            stored in the state dict, so rebuild the model with the value it was
            trained with.
    """
    def __init__(self, in_channels=1, out_channels=8, features=(64, 128, 256, 512), dropout_p=0.1):
        super().__init__()
        # Encoder
        self.encoders = nn.ModuleList()
        for feature in features:
            self.encoders.append(DoubleConvUnc(in_channels, feature, dropout_p))
            in_channels = feature

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = DoubleConvUnc(features[-1], features[-1]*2, dropout_p)

        # Decoder: upsample, concatenate with the skip, fuse with DoubleConvUnc
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for feature in reversed(features):
            self.upconvs.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.decoders.append(DoubleConvUnc(feature*2, feature, dropout_p))

        # Separate heads for logits and variance
        self.logits_head = nn.Conv2d(features[0], out_channels, kernel_size=1)
        self.var_head = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # Initialize variance head to small values
        nn.init.normal_(self.var_head.weight, std=0.01)
        nn.init.constant_(self.var_head.bias, -5.0)  # softplus(-5) ≈ 0.0067

    def forward(self, x):
        skip_connections = []
        # Encoder path
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path (deepest skip first)
        for idx in range(len(self.decoders)):
            x = self.upconvs[idx](x)
            skip = skip_connections[-(idx+1)]
            # Max-pooling floors odd sizes; resize the upsampled map to match
            # its skip so the concatenation is valid for any input size
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = self.decoders[idx](x)

        # Output heads
        logits = self.logits_head(x)
        var = F.softplus(self.var_head(x)) + 1e-7
        return logits, var

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed on softmax probabilities.

    The Dice score is computed per class, with sums taken over batch, height
    and width, and then averaged over classes. The loss is ``1 - mean Dice``.
    (This class is also defined in section 2; it is repeated here so this
    section can run on its own.)

    Args:
        num_classes: number of classes used to one-hot encode the targets; must
            equal the channel dimension of the logits.
        smooth: additive smoothing in numerator and denominator.
    """
    def __init__(self, num_classes, smooth=1e-7):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.num_classes = num_classes

    def forward(self, logits, targets):
        """
        Args:
            logits: raw scores [B, C, H, W] with C == num_classes.
            targets: integer labels [B, H, W] with values in [0, num_classes-1].

        Returns:
            Scalar tensor ``1 - mean_c Dice_c``.

        Raises:
            ValueError: if the channel dimension of ``logits`` differs from ``num_classes``.
        """
        if logits.shape[1] != self.num_classes:
            raise ValueError(
                f"logits have {logits.shape[1]} channels but DiceLoss expects {self.num_classes} classes"
            )

        # Convert targets to one-hot: [B, H, W] -> [B, C, H, W]
        targets_onehot = F.one_hot(targets, num_classes=self.num_classes).permute(0, 3, 1, 2).float()
        probs = F.softmax(logits, dim=1)

        # Sum over batch, height, width -> one value per class
        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_onehot, dim=dims)
        cardinality = torch.sum(probs + targets_onehot, dim=dims)

        # Dice score and loss
        dice_score = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice_score.mean()

class HybridLossUnc(nn.Module):
    """Hybrid segmentation loss for models that output ``(logits, var)``.

    Gaussian noise with standard deviation ``sqrt(var)`` is added to the logits
    ``n_samples`` times. The Dice, focal and cross-entropy terms are averaged over
    these noisy samples. The Dice term additionally includes the Dice loss of
    the clean logits. A regulariser ``var_reg * log1p(mean(var))`` discourages
    large predicted variance.

    Args:
        num_classes: number of classes (passed to ``DiceLoss``).
        dice_w, focal_w, cross_w: term weights; a term with weight 0 is skipped
            entirely (it would contribute exactly 0).
        var_reg: weight of the variance regulariser.
        n_samples: number of noisy-logit samples per call (default 3).

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    def __init__(self, num_classes, dice_w=1.0, focal_w=0.0, cross_w=1.0, var_reg=10.0, n_samples=3):
        super().__init__()
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        self.dice = DiceLoss(num_classes)
        self.focal = FocalLoss()
        self.cross = nn.CrossEntropyLoss()
        self.weights = {'dice': dice_w, 'focal': focal_w, 'cross': cross_w}
        self.var_reg = var_reg
        self.n_samples = n_samples

    def forward(self, output, target):
        """``output`` is ``(logits, var)`` with matching shapes (var expected > 0); ``target`` holds class labels."""
        logits, var = output
        std = torch.sqrt(var)
        n = self.n_samples

        loss_fns = {'dice': self.dice, 'focal': self.focal, 'cross': self.cross}
        active = {name: self.weights[name] != 0 for name in loss_fns}

        # Multi-sample uncertainty estimation: one Gaussian draw per sample,
        # shared by all active terms
        sums = {name: 0.0 for name in loss_fns}
        for _ in range(n):
            epsilon = torch.randn_like(logits)
            noisy_logits = logits + epsilon * std
            for name, loss_fn in loss_fns.items():
                if active[name]:
                    sums[name] += loss_fn(noisy_logits, target)

        terms = []
        if active['dice']:
            # Average over samples + clean-logits Dice
            dice_loss = (sums['dice'] / n) + self.dice(logits, target)
            terms.append(dice_loss * self.weights['dice'])
        if active['focal']:
            terms.append((sums['focal'] / n) * self.weights['focal'])
        if active['cross']:
            terms.append((sums['cross'] / n) * self.weights['cross'])

        # Variance regularisation (log1p keeps it smooth and bounded in growth)
        variance_loss = torch.log1p(torch.mean(var))
        terms.append(variance_loss * self.var_reg)

        total_loss = sum(terms)
        return total_loss

### 3.2 Training & Testing

In [None]:
def enable_dropout(model):
    """Put only the ``nn.Dropout`` / ``nn.Dropout2d`` layers of ``model`` into training mode.

    Every other submodule keeps its current mode (e.g. BatchNorm stays in eval
    after ``model.eval()``), which is what MC-dropout sampling requires.
    """
    dropout_kinds = (nn.Dropout, nn.Dropout2d)
    for layer in filter(lambda mod: isinstance(mod, dropout_kinds), model.modules()):
        layer.train()

In [None]:
# Model, optimiser, LR scheduler and uncertainty-aware loss
model_unc = UNetUncertainity(in_channels=1, out_channels=8, dropout_p=0.2).to(device)
optimizer = torch.optim.AdamW(model_unc.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
criterion = HybridLossUnc(num_classes=8, var_reg=10.0)

num_epochs=25

for epoch in range(num_epochs):
    # Train (dropout active)
    model_unc.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, var = model_unc(images)
        loss = criterion((logits, var), labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_dataset)

    # Validation: deterministic forward pass with dropout OFF. A single pass
    # with dropout switched on is just one random MC sample, which makes Val
    # Dice and the ReduceLROnPlateau signal needlessly noisy. MC-dropout
    # sampling is done separately when producing the uncertainty maps.
    # (HybridLossUnc still adds Gaussian noise to the logits, so val_loss
    # itself remains somewhat stochastic.)
    model_unc.eval()
    val_loss = 0.0
    dice_score_total = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits, var = model_unc(images)

            loss = criterion((logits, var), labels)
            val_loss += loss.item() * images.size(0)

            # Soft Dice on the softmax of the logits (variance not used)
            probs = F.softmax(logits, dim=1)
            targets_one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()

            dims = (0, 2, 3)
            intersection = torch.sum(probs * targets_one_hot, dim=dims)
            cardinality = torch.sum(probs + targets_one_hot, dim=dims)
            dice = (2. * intersection + 1e-7) / (cardinality + 1e-7)
            dice_score_total += dice.mean().item() * images.size(0)

    val_loss /= len(val_dataset)
    val_dice = dice_score_total / len(val_dataset)
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}')
    print('-'*20)

    scheduler.step(val_loss)

# Test evaluation: deterministic (dropout off), same protocol as validation
model_unc.eval()
test_loss = 0.0
test_dice = 0.0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits, var = model_unc(images)
        loss = criterion((logits, var), labels)
        test_loss += loss.item() * images.size(0)

        # Calculate Dice for test set
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets_one_hot, dim=(0, 2, 3))
        cardinality = torch.sum(probs + targets_one_hot, dim=(0, 2, 3))
        dice = (2. * intersection + 1e-7) / (cardinality + 1e-7)
        test_dice += dice.mean().item() * images.size(0)

test_loss /= len(test_dataset)
test_dice /= len(test_dataset)
print(f'Test Loss: {test_loss:.4f} | Test Dice: {test_dice:.4f}')

# Save trained model
torch.save(model_unc.state_dict(), 'unet_uncertainity_trained.pth')

# **4. Visualisations**

In [None]:
def visualize_predictions(imgs, lbls, preds, num_samples=16):
    """
    Visualizes test images, ground truth labels, and model predictions.

    Draws one row per sample: input image (grayscale), ground-truth mask and
    predicted mask. Both mask columns share one colour scale, taken from the
    min/max over the displayed labels and predictions. This way a class gets
    the same colour in the ground truth and in the prediction.

    Args:
        imgs: Tensor [B, C, H, W] or [B, H, W] - Input images.
        lbls: Tensor [B, H, W] - Ground truth segmentation masks.
        preds: Tensor [B, H, W] - Predicted segmentation masks.
        num_samples: int - Number of samples to visualize (capped at B).
            Nothing is drawn when this is <= 0 or the batch is empty.
    """
    num_samples = min(num_samples, imgs.shape[0])
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j] always works
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    # Shared colour limits for the ground-truth and prediction columns
    shown_lbls = lbls[:num_samples]
    shown_preds = preds[:num_samples]
    vmin = min(shown_lbls.min().item(), shown_preds.min().item())
    vmax = max(shown_lbls.max().item(), shown_preds.max().item())

    for i in range(num_samples):
        if imgs.dim() == 4:  # If image has channels, convert to grayscale
            image_vis = imgs[i].mean(dim=0).cpu().numpy()  # Average over channels
        else:
            image_vis = imgs[i].cpu().numpy()  # Direct grayscale image

        label_vis = lbls[i].cpu().numpy()  # Ground truth mask
        pred_vis = preds[i].cpu().numpy()  # Model prediction mask

        # Ensure a 2-D image: if still multi-channel, take the middle slice
        if image_vis.ndim == 3:
            image_vis = image_vis[image_vis.shape[0] // 2]

        # Plot Image
        axes[i, 0].imshow(image_vis, cmap='gray')
        axes[i, 0].set_title(f"Test Image {i+1}")
        axes[i, 0].axis("off")

        # Plot Ground Truth
        axes[i, 1].imshow(label_vis, cmap="jet", vmin=vmin, vmax=vmax)
        axes[i, 1].set_title(f"Ground Truth {i+1}")
        axes[i, 1].axis("off")

        # Plot Prediction
        axes[i, 2].imshow(pred_vis, cmap="jet", vmin=vmin, vmax=vmax)
        axes[i, 2].set_title(f"Prediction {i+1}")
        axes[i, 2].axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
def plot_segmentations(model, num_batches, double_out=False, loader=None):
    """Show images, ground truth and argmax predictions for the first ``num_batches`` batches.

    Args:
        model: segmentation network. Its forward returns logits, or a
            ``(logits, var)`` tuple when ``double_out`` is True.
        num_batches: exact number of batches to visualise; nothing is drawn if <= 0.
        double_out: True for models whose forward returns ``(logits, var)``.
        loader: DataLoader to draw batches from; defaults to ``test_loader``.
    """
    if num_batches <= 0:
        return
    if loader is None:
        loader = test_loader

    model.eval()
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(loader):
            images, labels = images.to(device), labels.to(device)

            output = model(images)
            # Uncertainty models return (logits, var); only the logits are plotted
            logits = output[0] if double_out else output
            preds = torch.argmax(logits, dim=1)  # Convert logits to class indices

            visualize_predictions(images, labels, preds)

            # Stop right after the last requested batch (no extra batch is loaded)
            if batch_idx + 1 >= num_batches:
                break


### 4.1 Pure Segmentation Model

In [None]:
plot_segmentations(model=model_pure, num_batches=1)

### 4.2 Uncertainty-aware Segmentation Model

In [None]:
plot_segmentations(model=model_unc, num_batches=1, double_out=True)

### 4.3 Uncertainty-aware Segmentation + Uncertainty Visualisation

In [None]:
def visualize_uncertainties(image, true_mask, pred_mask, aleatoric, epistemic, num_classes=8):
    """
    Visualize input image, true segmentation, predicted segmentation,
    aleatoric uncertainty, and epistemic uncertainty side by side.

    Args:
        image: 2-D array shown in grayscale.
        true_mask, pred_mask: 2-D class-index maps with values in [0, num_classes-1].
        aleatoric, epistemic: 2-D uncertainty maps, each drawn with its own colourbar.
        num_classes: number of classes. Each class gets one fixed discrete
            colour, shared by both masks.
    """
    fig, axs = plt.subplots(1, 5, figsize=(25, 5))

    # One discrete colour per class so class k looks the same in both masks
    cmap = plt.get_cmap('tab10', num_classes)
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    # Input image
    axs[0].imshow(image, cmap='gray')
    axs[0].set_title('Input Image')
    axs[0].axis('off')

    # True segmentation
    axs[1].imshow(true_mask, cmap=cmap, norm=norm)
    axs[1].set_title('True Segmentation')
    axs[1].axis('off')

    # Predicted segmentation
    axs[2].imshow(pred_mask, cmap=cmap, norm=norm)
    axs[2].set_title('Predicted Segmentation')
    axs[2].axis('off')

    # Aleatoric uncertainty
    a = axs[3].imshow(aleatoric, cmap='viridis')
    axs[3].set_title('Aleatoric Uncertainty')
    axs[3].axis('off')
    plt.colorbar(a, ax=axs[3])

    # Epistemic uncertainty
    e = axs[4].imshow(epistemic, cmap='viridis')
    axs[4].set_title('Epistemic Uncertainty')
    axs[4].axis('off')
    plt.colorbar(e, ax=axs[4])

    plt.tight_layout()
    plt.show()

def process_single_sample(model, image, true_mask, num_mc_samples=50):
    """Run deterministic and MC-dropout inference on one sample.

    Args:
        model: network whose forward returns ``(logits, var)``; its dropout
            layers are used for MC sampling.
        image: image tensor, with or without batch dimension ([C, H, W] or [1, C, H, W]).
        true_mask: label tensor, with or without batch dimension ([H, W] or [1, H, W]).
        num_mc_samples: number of stochastic forward passes with dropout active.

    Returns:
        ``(image_np, true_mask_np, pred_mask, aleatoric, epistemic)`` as NumPy arrays:
        the squeezed input image; the squeezed ground truth; the argmax of the
        MC-averaged softmax; the class-averaged predicted variance from a
        deterministic (dropout-off) pass; and the class-averaged variance of the
        softmax probabilities across MC samples.

    Raises:
        ValueError: if ``num_mc_samples`` < 1.

    The model is always left in eval mode (dropout off) on return.
    """
    if num_mc_samples < 1:
        raise ValueError(f"num_mc_samples must be >= 1, got {num_mc_samples}")

    # Run on the model's own device, whether or not the input was already moved
    model_device = next(model.parameters()).device

    # Add batch dimension if missing
    if image.dim() == 3:
        image = image.unsqueeze(0)
    if true_mask.dim() == 2:
        true_mask = true_mask.unsqueeze(0)
    image = image.to(model_device)

    try:
        with torch.no_grad():
            # Deterministic pass (dropout off) -> aleatoric variance averaged over classes
            model.eval()
            _, aleatoric = model(image)
            aleatoric = aleatoric.mean(dim=1).squeeze().cpu().numpy()

            # MC-dropout passes -> [n_samples, C, H, W] softmax probabilities
            enable_dropout(model)
            mc_probs = torch.stack([
                F.softmax(model(image)[0], dim=1).squeeze(0).cpu()
                for _ in range(num_mc_samples)
            ])
    finally:
        # Do not leave dropout switched on for later callers
        model.eval()

    mean_pred = mc_probs.mean(dim=0).numpy()                   # [C, H, W]
    epistemic = mc_probs.var(dim=0).mean(dim=0).numpy()        # [H, W], averaged over classes

    image_np = image.squeeze().cpu().numpy()
    true_mask_np = true_mask.squeeze().cpu().numpy()
    pred_mask = np.argmax(mean_pred, axis=0)

    torch.cuda.empty_cache()

    return image_np, true_mask_np, pred_mask, aleatoric, epistemic

In [None]:
# Visualise segmentations and uncertainties for every sample of the FIRST test batch.
# Breaking out of the outer loop avoids loading all remaining test batches for nothing.
for images, labels in test_loader:
    for i in range(images.size(0)):
        single_image = images[i].unsqueeze(0).to(device)
        single_label = labels[i].unsqueeze(0).to(device)

        image_np, true_mask, pred_mask, aleatoric, epistemic = process_single_sample(
            model_unc, single_image, single_label, num_mc_samples=50
        )

        visualize_uncertainties(image_np, true_mask, pred_mask, aleatoric, epistemic)

        del single_image, single_label
        torch.cuda.empty_cache()
    break  # only the first batch is visualised


# **5. Calibration**

### 5.1 Aleatoric Scaling Tuning

In [None]:
class CalibratedUNetUncertainty(UNetUncertainity):
    """``UNetUncertainity`` whose predicted variance is multiplied by a learnable positive scale.

    The scale is ``softplus(aleatoric_scale)``. The raw parameter is initialised
    to ``log(e - 1)`` so that the initial scale is exactly 1. Before calibration
    the wrapper therefore reproduces the base model's variance, rather than
    inflating it by ``softplus(1) ≈ 1.313``. Constructor arguments are forwarded
    to ``UNetUncertainity``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # softplus(log(expm1(1))) == log(1 + e - 1) == 1  -> identity scale at start
        self.aleatoric_scale = nn.Parameter(torch.log(torch.expm1(torch.ones(1))))  # Learnable scale

    def forward(self, x):
        logits, var = super().forward(x)
        return logits, F.softplus(self.aleatoric_scale) * var  # Scaled output


# 1. Load the trained uncertainty model.
#    dropout_p must match training (0.2): the dropout rate is not stored in the
#    state dict, and the class default (0.1) would silently change MC-dropout
#    behaviour for every model built from these weights.
original_model = UNetUncertainity(dropout_p=0.2).to(device)
original_model.load_state_dict(torch.load('unet_uncertainity_trained.pth', map_location=device))

# 2. Build the calibrated model and copy all pretrained weights into it.
#    strict=False because the checkpoint has no `aleatoric_scale`; check that this
#    is the only mismatch so a renamed layer cannot be skipped silently.
calib_model = CalibratedUNetUncertainty(dropout_p=0.2).to(device)
load_result = calib_model.load_state_dict(original_model.state_dict(), strict=False)
assert list(load_result.missing_keys) == ['aleatoric_scale'], load_result
assert not load_result.unexpected_keys, load_result

# 3. Freeze the network; only the aleatoric scale is trained
for param in calib_model.parameters():
    param.requires_grad = False
calib_model.aleatoric_scale.requires_grad = True

# 4. Calibrate: minimise a Gaussian NLL of the softmax residuals over the
#    validation set. eval() keeps dropout off so the objective is deterministic.
#    The scale is positive by construction (softplus), so a line search with
#    lr=1 is safe.
calib_model.eval()
optimizer = torch.optim.LBFGS(
    [calib_model.aleatoric_scale], lr=1.0, max_iter=50, line_search_fn='strong_wolfe'
)


def aleatoric_nll_closure():
    """Re-evaluate the validation NLL and its gradient; LBFGS calls this repeatedly."""
    optimizer.zero_grad()
    total_nll = 0.0
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        logits, var = calib_model(images)
        probs = F.softmax(logits, dim=1)
        targets_onehot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()
        batch_nll = 0.5 * ((probs - targets_onehot).pow(2) / var + torch.log(var)).mean()
        # Backward per batch: gradients accumulate in aleatoric_scale.grad and no
        # autograd graph is kept alive across the whole validation set
        batch_nll.backward()
        total_nll += batch_nll.item()
    return total_nll


optimizer.step(aleatoric_nll_closure)

print(f"Aleatoric scale: {F.softplus(calib_model.aleatoric_scale).item():.4f}")
aleotoric_scale = F.softplus(calib_model.aleatoric_scale).item()

### 5.2 Temperature Scaling Tuning

In [None]:
class TemperatureScaledModel(nn.Module):
    """Wrap a ``(logits, var)`` model and divide its logits by a learnable temperature.

    The base model's variance output is returned unchanged. The temperature is
    a 0-d parameter stored under the state-dict key ``temperature``.
    """
    def __init__(self, base_model, temp_init=1.0):
        super().__init__()
        self.base_model = base_model
        self.temperature = nn.Parameter(torch.tensor(temp_init))

    def forward(self, x):
        base_logits, base_var = self.base_model(x)
        # T > 1 softens the softmax, T < 1 sharpens it
        return base_logits / self.temperature, base_var

In [None]:
def calibrate_temperature(model_scaled, val_loader, device, lr=0.01, max_iter=50):
    """Fit a single softmax temperature on a validation set by minimising cross-entropy.

    Args:
        model_scaled: base model returning ``(logits, var)``. It is wrapped in
            ``TemperatureScaledModel``; only the temperature is passed to the optimiser.
        val_loader: yields ``(images, labels)`` batches.
        device: device for the wrapper and the forward passes.
        lr, max_iter: LBFGS settings (defaults equal the previous hard-coded values).

    Returns:
        The fitted ``TemperatureScaledModel``, left in eval mode.
    """
    temp_model = TemperatureScaledModel(model_scaled).to(device)
    # Dropout must be off: LBFGS evaluates the closure many times and assumes
    # the objective is deterministic. This also resets a base model that was
    # left with dropout enabled by earlier MC sampling.
    temp_model.eval()
    optimizer = torch.optim.LBFGS([temp_model.temperature], lr=lr, max_iter=max_iter)
    criterion = nn.CrossEntropyLoss()  # Only CE used here

    def closure():
        optimizer.zero_grad()
        total_loss = 0.0
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass with temperature (var unused in calibration)
            logits, _ = temp_model(images)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            # Per-batch backward: gradients accumulate in temperature.grad
            loss.backward()
        return total_loss

    optimizer.step(closure)  # a single step runs up to max_iter LBFGS iterations

    print(f"Optimal temperature: {temp_model.temperature.item():.4f}")
    return temp_model

In [None]:
# Fit the softmax temperature of the aleatoric-calibrated model on the validation split
temp_scaled_model = calibrate_temperature(model_scaled=calib_model, val_loader=val_loader, device=device)

In [None]:
# Persist the temperature-scaled model (base-model weights, aleatoric scale and temperature)
torch.save(obj=temp_scaled_model.state_dict(), f='unet_uncertainity_calibrated_trained.pth')

In [None]:
# Visualise segmentations and uncertainties of the calibrated model for every
# sample of the FIRST test batch. Breaking out of the outer loop avoids loading
# all remaining test batches for nothing.
for images, labels in test_loader:
    for i in range(images.size(0)):
        single_image = images[i].unsqueeze(0).to(device)
        single_label = labels[i].unsqueeze(0).to(device)

        image_np, true_mask, pred_mask, aleatoric, epistemic = process_single_sample(
            temp_scaled_model, single_image, single_label, num_mc_samples=50
        )

        visualize_uncertainties(image_np, true_mask, pred_mask, aleatoric, epistemic)

        del single_image, single_label
        torch.cuda.empty_cache()
    break  # only the first batch is visualised


### 5.3 Calibration Evaluation

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve
from tqdm import tqdm

def compute_uncertainties(model, dataloader, device, mc_samples=50):
    """Compute per-pixel logits, aleatoric and epistemic uncertainty over a dataset.

    Aleatoric uncertainty comes from a deterministic pass (dropout disabled). It
    is the model's predicted variance gathered at the predicted class. Epistemic
    uncertainty is the (unbiased) variance of the softmax probabilities across
    ``mc_samples`` MC-dropout passes, averaged over classes. The variance is
    accumulated in streaming form (Welford), so the MC passes are never stacked
    in memory.

    Args:
        model: network returning ``(logits, var)`` with matching shapes
            (assumed [B, C, H, W] — TODO confirm for other models).
        dataloader: yields ``(images, labels)`` batches.
        device: device used for inference.
        mc_samples: number of MC-dropout passes per batch (>= 2).

    Returns:
        ``(logits, aleatoric, epistemic, labels)`` concatenated along the batch
        dimension and moved to CPU (so results for a whole dataset do not
        accumulate in GPU memory). Logits are [N, C, H, W]; aleatoric and
        epistemic are [N, 1, H, W].

    Raises:
        ValueError: if ``mc_samples`` < 2 (a variance needs at least two samples).

    The model is left in eval mode (dropout off), also if an error occurs.
    """
    if mc_samples < 2:
        raise ValueError(f"mc_samples must be >= 2 to estimate a variance, got {mc_samples}")

    all_logits = []
    all_aleatoric = []
    all_epistemic = []
    all_labels = []

    try:
        with torch.no_grad():
            for images, labels in tqdm(dataloader, desc="Computing Uncertainties"):
                images = images.to(device)
                labels = labels.to(device)

                # --- Deterministic pass (dropout off): logits + aleatoric at the predicted class ---
                model.eval()
                logits_single, var_single = model(images)
                probs_single = torch.softmax(logits_single, dim=1)
                pred_class = probs_single.argmax(dim=1, keepdim=True)
                aleatoric_single = torch.gather(var_single, 1, pred_class)

                # --- Epistemic uncertainty via MC dropout (streaming mean / M2) ---
                enable_dropout(model)
                mean_probs = torch.zeros_like(probs_single)
                m2 = torch.zeros_like(probs_single)
                for k in range(1, mc_samples + 1):
                    logits_mc, _ = model(images)  # aleatoric output of MC passes is ignored
                    probs_mc = torch.softmax(logits_mc, dim=1)
                    delta = probs_mc - mean_probs
                    mean_probs += delta / k
                    m2 += delta * (probs_mc - mean_probs)
                # Unbiased variance per class, then averaged over classes -> [B, 1, H, W]
                epistemic = (m2 / (mc_samples - 1)).mean(dim=1, keepdim=True)

                all_logits.append(logits_single.cpu())
                all_aleatoric.append(aleatoric_single.cpu())
                all_epistemic.append(epistemic.cpu())
                all_labels.append(labels.cpu())
    finally:
        # Restore dropout off for subsequent operations
        model.eval()

    return (
        torch.cat(all_logits),
        torch.cat(all_aleatoric),
        torch.cat(all_epistemic),
        torch.cat(all_labels)
    )


def flatten_segmentation_data(*tensors):
    """Flatten each tensor (batch and spatial dims merged) into a 1-D numpy array.

    Tensors are detached first, so this also works on tensors that require
    grad (``Tensor.numpy()`` raises on those).

    Returns:
        list of 1-D numpy arrays, one per input tensor, in input order.
    """
    return [tensor.detach().flatten().cpu().numpy() for tensor in tensors]

# def calibration_curves(logits, aleatoric, epistemic, labels, n_quantiles=10):
#     """Plot calibration curves for both uncertainties using quantile binning"""
#     probs = torch.softmax(logits, dim=1)
#     pred_classes = probs.argmax(dim=1)
#     errors = (pred_classes != labels).float()

#     # Flatten all tensors
#     prob_flat, error_flat, alea_flat, epi_flat = flatten_segmentation_data(
#         probs.max(dim=1)[0], errors, aleatoric.squeeze(), epistemic.squeeze()
#     )

#     # Common quantile bins based on combined uncertainties
#     combined_unc = alea_flat + epi_flat
#     quantiles = np.quantile(combined_unc, np.linspace(0, 1, n_quantiles+1))

#     # Bin indices for both uncertainties
#     alea_bins = np.digitize(alea_flat, quantiles) - 1
#     epi_bins = np.digitize(epi_flat, quantiles) - 1

#     # Calculate calibration metrics per bin
#     alea_acc, alea_conf = [], []
#     epi_acc, epi_conf = [], []

#     for i in range(n_quantiles):
#         # Aleatoric
#         alea_mask = (alea_bins == i)
#         if alea_mask.sum() > 0:
#             alea_acc.append(error_flat[alea_mask].mean())
#             alea_conf.append(alea_flat[alea_mask].mean())

#         # Epistemic
#         epi_mask = (epi_bins == i)
#         if epi_mask.sum() > 0:
#             epi_acc.append(error_flat[epi_mask].mean())
#             epi_conf.append(epi_flat[epi_mask].mean())

#     # Plotting
#     plt.figure(figsize=(10, 5))
#     plt.plot(alea_conf, alea_acc, 'o-', label='Aleatoric')
#     plt.plot(epi_conf, epi_acc, 'o-', label='Epistemic')
#     plt.plot([0, 1], [0, 1], 'k--', label='Perfect')
#     plt.xlabel('Predicted Uncertainty (Quantile Bins)')
#     plt.ylabel('Empirical Error Rate')
#     plt.title('Uncertainty Calibration Curves')
#     plt.legend()
#     plt.show()

def quantile_bin_indices(values, n_bins=10):
    """Assign each value to one of ``n_bins`` equal-count (quantile) bins.

    Args:
        values: 1-D array-like of scores to bin.
        n_bins: number of quantile bins.

    Returns:
        Integer numpy array, same length as ``values``, with every entry in
        ``[0, n_bins - 1]``; the smallest values land in bin 0 and the maximum
        in bin ``n_bins - 1``.
    """
    values = np.asarray(values)
    if values.size == 0:
        return np.zeros(0, dtype=np.intp)
    edges = np.quantile(values, np.linspace(0, 1, n_bins + 1))
    # Only the interior edges are used: digitizing against the full edge array
    # yields 1..n_bins+1, which leaves bin 0 empty and pushes the top quantile
    # (including the maximum) outside range(n_bins).
    return np.digitize(values, edges[1:-1])


def uncertainty_reliability_diagram(logits, aleatoric, epistemic, labels, n_bins=10):
    """Plot per-pixel error rate against total (aleatoric + epistemic) uncertainty.

    Pixels are grouped into ``n_bins`` quantile bins of total uncertainty; each
    plotted point is (mean total uncertainty, empirical error rate) of one bin.

    Args:
        logits: class logits; softmax/argmax are taken over dim 1.
        aleatoric, epistemic: per-pixel uncertainty tensors with as many
            elements as the per-pixel error map (e.g. the ``[B, 1, H, W]``
            outputs of ``compute_uncertainties``).
        labels: ground-truth class indices compared element-wise with the
            argmax prediction; presumably ``[B, H, W]`` — TODO confirm.
        n_bins: number of quantile bins (default 10).

    Raises:
        ValueError: if there are no pixels to bin.
    """
    probs = torch.softmax(logits, dim=1)
    errors = (probs.argmax(dim=1) != labels).float()

    error_flat, alea_flat, epi_flat = flatten_segmentation_data(errors, aleatoric, epistemic)
    total_unc = alea_flat + epi_flat
    bin_indices = quantile_bin_indices(total_unc, n_bins)

    # (mean uncertainty, error rate) for every non-empty bin
    points = []
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.any():
            points.append((total_unc[mask].mean(), error_flat[mask].mean()))

    if not points:
        raise ValueError("uncertainty_reliability_diagram: no pixels to bin")

    unc_vals, err_vals = zip(*sorted(points))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(unc_vals, err_vals, 'o-')
    # Guide line joining the extreme points. This is not a perfect-calibration
    # diagonal, because the uncertainty is not necessarily on a probability scale.
    ax.plot([min(unc_vals), max(unc_vals)], [min(err_vals), max(err_vals)], 'k--')
    ax.set_xlabel('Total Uncertainty')
    ax.set_ylabel('Error Rate')
    ax.set_title('Uncertainty-Weighted Reliability Diagram')
    plt.show()


def compute_calibration_metrics(logits, aleatoric, epistemic, labels, n_bins=10):
    """Compute quantitative calibration metrics for a segmentation model.

    Args:
        logits: class logits; softmax/argmax are taken over dim 1.
        aleatoric, epistemic: per-pixel uncertainty tensors with as many
            elements as the per-pixel error map.
        labels: ground-truth class indices compared element-wise with the
            argmax prediction; presumably ``[B, H, W]`` — TODO confirm.
        n_bins: number of quantile bins used by the ECE (default 10).

    Returns:
        dict with keys:
          - ``'brier_score'``: mean squared gap between the top-class
            probability and pixel correctness (binary, top-label Brier).
          - ``'aleatoric_ece'``, ``'epistemic_ece'``, ``'total_ece'``: count-
            weighted mean over quantile bins of |error rate - mean uncertainty|.
            Each is NaN when there are no pixels.
    """
    probs = torch.softmax(logits, dim=1)
    pred_classes = probs.argmax(dim=1)
    errors = (pred_classes != labels).float()

    prob_flat, error_flat, alea_flat, epi_flat = flatten_segmentation_data(
        probs.max(dim=1)[0], errors, aleatoric, epistemic
    )

    def ece(uncertainties):
        """Expected calibration error of an uncertainty score vs. the error rate."""
        if len(uncertainties) == 0:
            return float('nan')
        bin_indices = quantile_bin_indices(uncertainties, n_bins)
        weighted_gap = 0.0
        for i in range(n_bins):
            mask = bin_indices == i
            count = mask.sum()
            if count > 0:
                gap = np.abs(error_flat[mask].mean() - uncertainties[mask].mean())
                weighted_gap += gap * count
        return weighted_gap / len(uncertainties)

    # Brier score of the max-probability confidence against correctness (1 - error)
    brier = np.mean((prob_flat - (1 - error_flat)) ** 2)

    return {
        'brier_score': brier,
        'aleatoric_ece': ece(alea_flat),
        'epistemic_ece': ece(epi_flat),
        'total_ece': ece(alea_flat + epi_flat)
    }

In [None]:
# Uncalibrated model: single-pass logits/aleatoric plus MC-dropout epistemic
# uncertainty (50 stochastic passes) over the test set, then a calibration summary.
logits, aleatoric, epistemic, labels = compute_uncertainties(
    model_unc, test_loader, device, mc_samples=50
)
metrics = compute_calibration_metrics(logits, aleatoric, epistemic, labels)

# One print call, one line per entry (identical output to separate prints).
print(
    "Uncalibrated Model:\n",
    "\nCalibration Metrics:",
    f"Brier Score: {metrics['brier_score']:.4f}",
    f"Aleatoric ECE: {metrics['aleatoric_ece']:.4f}",
    f"Epistemic ECE: {metrics['epistemic_ece']:.4f}",
    f"Total ECE: {metrics['total_ece']:.4f}",
    sep="\n",
)

In [None]:
# Temperature-scaled (calibrated) model: same uncertainty pipeline and summary.
# Note: this overwrites the uncalibrated results held in the same variable names.
logits, aleatoric, epistemic, labels = compute_uncertainties(
    temp_scaled_model, test_loader, device, mc_samples=50
)
metrics = compute_calibration_metrics(logits, aleatoric, epistemic, labels)

# One print call, one line per entry (identical output to separate prints).
print(
    "Calibrated Model:\n",
    "\nCalibration Metrics:",
    f"Brier Score: {metrics['brier_score']:.4f}",
    f"Aleatoric ECE: {metrics['aleatoric_ece']:.4f}",
    f"Epistemic ECE: {metrics['epistemic_ece']:.4f}",
    f"Total ECE: {metrics['total_ece']:.4f}",
    sep="\n",
)

In [None]:
# Visualise segmentations from the temperature-scaled model on a single batch.
# double_out=True presumably tells plot_segmentations that the model returns a
# (logits, aleatoric) pair. TODO: confirm against plot_segmentations' definition.
plot_segmentations(temp_scaled_model, num_batches=1, double_out=True)