# Notebook 04: Evaluation and Demo

## Pancreas CT Segmentation using TransUNet

This notebook covers:
1. Loading trained model
2. 3D volume inference pipeline
3. Metric computation (Dice score, 95th-percentile Hausdorff distance)
4. Publication-ready visualizations
5. Attention map visualization (Bonus)

## 1. Environment Setup

In [None]:
import os
import sys
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import nibabel as nib
from tqdm import tqdm
from scipy.ndimage import zoom

import torch
import torch.nn.functional as F

import monai
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    EnsureTyped,
)

# Add src to path
sys.path.insert(0, str(Path.cwd()))

# Import custom modules
from src.model import create_transunet
from src.transforms import get_val_transforms

# Select the compute device: GPU when PyTorch can see one, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Paths
PROJECT_ROOT = Path.cwd()
CHECKPOINT_PATH = PROJECT_ROOT / "checkpoints" / "best_metric_model.pth"
SPLITS_PATH = PROJECT_ROOT / "outputs" / "data_splits.json"
OUTPUT_DIR = PROJECT_ROOT / "outputs"

# Configuration
IMG_SIZE = 224
NUM_CLASSES = 2

print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Splits: {SPLITS_PATH}")

## 2. Load Trained Model

In [None]:
# Load checkpoint
# SECURITY: torch.load unpickles the file -- only load checkpoints from a trusted source.
# NOTE: PyTorch >= 2.6 defaults to weights_only=True; if a trusted checkpoint fails to load
# because of non-tensor metadata, pass weights_only=False explicitly.
if CHECKPOINT_PATH.exists():
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # Metadata keys are optional; a missing one should not abort evaluation.
    print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
    if 'best_dice' in checkpoint:
        print(f"Best validation Dice: {checkpoint['best_dice']:.4f}")

    # Hyper-parameters stored at training time; `or {}` also covers an explicit None.
    config = checkpoint.get('config') or {}
    model_variant = config.get('model_variant', 'small')
else:
    print("Warning: No checkpoint found. Using default model.")
    checkpoint = None
    model_variant = 'small'

In [None]:
# Instantiate the network with the same input geometry used for training.
model = create_transunet(
    variant=model_variant,
    img_size=IMG_SIZE,
    out_channels=NUM_CLASSES,
    in_channels=1,
)

# Restore trained weights only when a checkpoint was loaded above.
if checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model weights loaded successfully.")

# Move to the target device and switch to inference mode (eval() returns the module).
model = model.to(device).eval()

n_params = sum(param.numel() for param in model.parameters())
print(f"\nModel variant: {model_variant}")
print(f"Parameters: {n_params:,}")

## 3. Load Test Data

In [None]:
# Read the saved data-split manifest and keep only the test partition.
data_splits = json.loads(SPLITS_PATH.read_text())
test_data = data_splits["test"]
print(f"Test samples: {len(test_data)}")

In [None]:
# Preprocessing transforms.
# Image and label receive identical spatial transforms so they stay voxel-aligned.
# NOTE(review): `get_val_transforms` is imported above but not used here -- confirm this
# pipeline matches the preprocessing used during training/validation.
_PAIRED_KEYS = ["image", "label"]
preprocess = Compose([
    # Read the image/label files referenced by the sample dict; channel axis first.
    LoadImaged(keys=_PAIRED_KEYS),
    EnsureChannelFirstd(keys=_PAIRED_KEYS),
    # Canonical RAS orientation and 1.0 isotropic spacing (nearest-neighbour for labels).
    Orientationd(keys=_PAIRED_KEYS, axcodes="RAS"),
    Spacingd(keys=_PAIRED_KEYS, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
    # Clip image intensities to [-175, 250] and rescale them to [0, 1].
    ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    # Crop both volumes to the image foreground bounding box.
    CropForegroundd(keys=_PAIRED_KEYS, source_key="image"),
    EnsureTyped(keys=_PAIRED_KEYS),
])

## 4. 3D Volume Inference Pipeline

In [None]:
def predict_volume(model, volume, target_size=224, device='cpu', slice_axis=0):
    """
    Run 2D slice-wise inference over a 3D volume and stack the results.

    Each 2D slice is resized to ``target_size x target_size`` with linear
    interpolation and passed through ``model``. The per-pixel argmax label map
    is then resized back to the slice's original size with nearest-neighbour
    interpolation.

    Args:
        model: Trained 2D segmentation model. It receives a (1, 1, S, S) float
            tensor and must return logits of shape (1, num_classes, S, S).
        volume: 3D numpy array or tensor, or 4D with a leading channel axis
            (only channel 0 is used). Tensors may live on any device.
        target_size: Side length S of the square model input.
        device: Torch device the model lives on.
        slice_axis: Spatial axis along which 2D slices are taken. The default 0
            preserves the original behaviour. NOTE(review): MONAI channel-first
            volumes after ``Orientationd(axcodes="RAS")`` are laid out (X, Y, Z),
            so axial slices lie along axis 2. Confirm which axis the model was
            trained on.

    Returns:
        int64 label array with the same shape as the (channel-stripped) input.

    Raises:
        ValueError: if the input is not 3D (or 4D with a channel axis).
    """
    model.eval()

    # Tensors may be on the GPU or track gradients; bring them to host numpy first.
    if torch.is_tensor(volume):
        volume = volume.detach().cpu().numpy()
    volume = np.asarray(volume)

    # Drop a leading channel axis: (C, ...) -> first channel only.
    if volume.ndim == 4:
        volume = volume[0]
    if volume.ndim != 3:
        raise ValueError(
            f"Expected a 3D volume (or 4D with a channel axis), got shape {volume.shape}"
        )

    # Move the slicing axis to the front so the loop always iterates over axis 0.
    slices_first = np.moveaxis(volume, slice_axis, 0)
    num_slices, H, W = slices_first.shape
    predictions = np.zeros((num_slices, H, W), dtype=np.int64)

    # Every slice has the same size, so the resize factors are loop-invariant.
    zoom_in = (target_size / H, target_size / W)
    zoom_out = (H / target_size, W / target_size)

    with torch.no_grad():
        for slice_idx in range(num_slices):
            slice_resized = zoom(slices_first[slice_idx], zoom_in, order=1)
            slice_tensor = torch.from_numpy(np.ascontiguousarray(slice_resized)).float()
            slice_tensor = slice_tensor[None, None].to(device)  # (1, 1, S, S)

            output = model(slice_tensor)
            # Index the batch dim explicitly; squeeze() could also drop a size-1 spatial dim.
            pred = torch.argmax(output, dim=1)[0].cpu().numpy()

            # order=0 (nearest neighbour) keeps the result a valid integer label map.
            predictions[slice_idx] = zoom(pred, zoom_out, order=0)

    # Restore the caller's axis layout.
    return np.moveaxis(predictions, 0, slice_axis)

In [None]:
def evaluate_volume(prediction, ground_truth):
    """
    Compute the Dice score and 95th-percentile Hausdorff distance for one volume.

    All non-zero labels are merged into a single foreground class before scoring,
    so tumour voxels count as pancreas.

    Args:
        prediction: 3D predicted label array (non-zero = foreground).
        ground_truth: 3D ground-truth label array with the same shape.

    Returns:
        Dict of plain Python numbers (JSON-serialisable):
            "dice": Dice coefficient; 1.0 when both masks are empty.
            "hausdorff_95": HD95 as computed by MONAI without a spacing
                argument (voxel units); ``inf`` when either mask is empty or
                the computation fails.
            "pred_volume", "gt_volume": foreground voxel counts.
    """
    pred_binary = np.asarray(prediction) > 0
    gt_binary = np.asarray(ground_truth) > 0

    pred_count = int(pred_binary.sum())
    gt_count = int(gt_binary.sum())
    intersection = int(np.logical_and(pred_binary, gt_binary).sum())

    # Dice; two empty masks are treated as a perfect match.
    denominator = pred_count + gt_count
    dice = 1.0 if denominator == 0 else 2.0 * intersection / denominator

    # HD95 is undefined when either surface is empty. MONAI warns and returns inf/NaN
    # instead of raising, so short-circuit to inf to mark "undefined" consistently.
    if pred_count == 0 or gt_count == 0:
        hd = float('inf')
    else:
        pred_tensor = torch.from_numpy(pred_binary.astype(np.float32))[None, None]
        gt_tensor = torch.from_numpy(gt_binary.astype(np.float32))[None, None]
        # The single channel IS the foreground, so it must not be dropped as background.
        hd_metric = HausdorffDistanceMetric(include_background=True, percentile=95)
        try:
            hd_metric(y_pred=pred_tensor, y=gt_tensor)
            hd = float(hd_metric.aggregate().item())
        except Exception as err:  # best-effort metric: record as undefined, but say why
            print(f"Warning: HD95 computation failed ({err}); recording inf.")
            hd = float('inf')
        if not np.isfinite(hd):
            hd = float('inf')

    return {
        "dice": float(dice),
        "hausdorff_95": hd,
        "pred_volume": pred_count,
        "gt_volume": gt_count,
    }

## 5. Run Evaluation on Test Set

In [None]:
# Evaluate on test set.
# Full evaluation is slow, so the demo caps the number of volumes.
# Set MAX_TEST_SAMPLES = None to evaluate the entire test split.
MAX_TEST_SAMPLES = 5

results = []
predictions_list = []
ground_truths_list = []
images_list = []

eval_samples = test_data if MAX_TEST_SAMPLES is None else test_data[:MAX_TEST_SAMPLES]

print(f"Evaluating on {len(eval_samples)} of {len(test_data)} test volumes...")
print("="*60)

for i, sample in enumerate(tqdm(eval_samples, desc="Processing volumes")):
    # Load and preprocess (EnsureChannelFirstd yields channel-first arrays).
    data = preprocess(sample)
    image = data["image"]
    label = data["label"]

    # Drop the leading channel axis: (1, ...) -> 3D.
    if image.ndim == 4:
        image = image[0]
    if label.ndim == 4:
        label = label[0]

    # Convert to host numpy; detach()/cpu() are no-ops for plain CPU tensors.
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()
    if torch.is_tensor(label):
        label = label.detach().cpu().numpy()

    # Predict
    prediction = predict_volume(model, image, target_size=IMG_SIZE, device=device)

    # Evaluate
    metrics = evaluate_volume(prediction, label)
    metrics["sample_id"] = i
    results.append(metrics)

    # Store for visualization
    predictions_list.append(prediction)
    ground_truths_list.append(label)
    images_list.append(image)

    print(f"Sample {i+1}: Dice={metrics['dice']:.4f}, HD95={metrics['hausdorff_95']:.2f}mm")

print("="*60)

In [None]:
# Summary statistics
dice_scores = [r["dice"] for r in results]
# HD95 is undefined when a mask is empty (reported as inf, or possibly NaN);
# keep only finite values so they cannot poison the mean. `!= inf` would let NaN through.
hd_scores = [r["hausdorff_95"] for r in results if np.isfinite(r["hausdorff_95"])]

print("\nTest Set Performance Summary:")
print("="*60)
if dice_scores:
    print("Dice Score:")
    print(f"  Mean: {np.mean(dice_scores):.4f}")
    print(f"  Std:  {np.std(dice_scores):.4f}")
    print(f"  Min:  {np.min(dice_scores):.4f}")
    print(f"  Max:  {np.max(dice_scores):.4f}")
else:
    # np.min/np.max raise on empty input.
    print("No volumes were evaluated.")

if hd_scores:
    print("\nHausdorff Distance (95th percentile):")
    print(f"  Mean: {np.mean(hd_scores):.2f} mm")
    print(f"  Std:  {np.std(hd_scores):.2f} mm")

num_hd_excluded = len(results) - len(hd_scores)
if num_hd_excluded:
    print(f"  ({num_hd_excluded} volume(s) excluded from HD95: undefined distance)")

## 6. Publication-Ready Visualizations

In [None]:
def create_comparison_figure(image, ground_truth, prediction, patient_idx=0, slice_idx=None):
    """
    Create a four-panel comparison figure for one slice of a volume.

    Panels:
        1. Raw CT.
        2. Ground-truth overlay.
        3. Prediction overlay.
        4. Error map: true positives green, false positives red, false negatives blue.

    Args:
        image: 3D image array, sliced along axis 0.
        ground_truth: 3D ground-truth label array (same shape as image).
        prediction: 3D predicted label array (same shape as image).
        patient_idx: Zero-based patient index used in the title.
        slice_idx: Slice to display. Defaults to the slice with the largest
            ground-truth foreground area.

    Returns:
        The matplotlib Figure.
    """
    # Find optimal slice (max pancreas area)
    if slice_idx is None:
        gt_area = np.sum(ground_truth > 0, axis=(1, 2))
        slice_idx = int(np.argmax(gt_area))

    # Extract slices
    img_slice = image[slice_idx]
    gt_slice = ground_truth[slice_idx]
    pred_slice = prediction[slice_idx]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # 1. Original CT
    axes[0].imshow(img_slice, cmap="gray")
    axes[0].set_title("Original CT", fontsize=14, fontweight="bold")
    axes[0].axis("off")

    # 2. Ground Truth (Green); fixed vmin/vmax so labels 1 and 2 get distinct shades.
    axes[1].imshow(img_slice, cmap="gray")
    gt_mask = np.ma.masked_where(gt_slice == 0, gt_slice)
    axes[1].imshow(gt_mask, cmap="Greens", alpha=0.7, vmin=0, vmax=2)
    axes[1].set_title("Ground Truth", fontsize=14, fontweight="bold")
    axes[1].axis("off")

    # 3. Prediction (Red)
    axes[2].imshow(img_slice, cmap="gray")
    pred_mask = np.ma.masked_where(pred_slice == 0, pred_slice)
    axes[2].imshow(pred_mask, cmap="Reds", alpha=0.7, vmin=0, vmax=2)
    axes[2].set_title("Model Prediction", fontsize=14, fontweight="bold")
    axes[2].axis("off")

    # 4. Overlay comparison (labels binarised: any non-zero label is foreground)
    axes[3].imshow(img_slice, cmap="gray")

    gt_binary = gt_slice > 0
    pred_binary = pred_slice > 0

    tp = gt_binary & pred_binary
    fp = ~gt_binary & pred_binary
    fn = gt_binary & ~pred_binary

    # RGBA layer; the three masks are disjoint, so assignment order does not matter.
    overlay = np.zeros((*img_slice.shape, 4))
    overlay[tp] = [0, 1, 0, 0.5]  # Green
    overlay[fp] = [1, 0, 0, 0.5]  # Red
    overlay[fn] = [0, 0, 1, 0.5]  # Blue

    axes[3].imshow(overlay)
    axes[3].set_title("Comparison\n(TP: Green, FP: Red, FN: Blue)", fontsize=14, fontweight="bold")
    axes[3].axis("off")

    # Slice Dice; an empty ground truth with an empty prediction is a perfect match
    # (an epsilon-only denominator would report 0.0 here).
    denominator = np.sum(gt_binary) + np.sum(pred_binary)
    dice = 1.0 if denominator == 0 else 2 * np.sum(tp) / denominator

    plt.suptitle(
        f"Patient {patient_idx + 1} - Slice {slice_idx} | Dice Score: {dice:.4f}",
        fontsize=16, fontweight="bold", y=1.02
    )

    plt.tight_layout()
    return fig

In [None]:
# Generate comparison figures for all evaluated test samples
print("Generating comparison figures...")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

volumes = zip(images_list, ground_truths_list, predictions_list)
for i, (image_vol, gt_vol, pred_vol) in enumerate(volumes):
    fig = create_comparison_figure(image_vol, gt_vol, pred_vol, patient_idx=i)

    # Save figure
    fig.savefig(
        OUTPUT_DIR / f"comparison_patient_{i+1}.png",
        dpi=150,
        bbox_inches="tight",
        facecolor="white",
    )
    plt.show()
    # Release the figure; non-inline backends otherwise keep every figure open.
    plt.close(fig)

In [None]:
def create_multi_slice_figure(image, ground_truth, prediction, patient_idx=0, num_slices=6, dice=None):
    """
    Create a 3-row figure (CT / ground truth / prediction) across several slices.

    Slices are sampled evenly from those containing ground-truth foreground.
    If the ground truth is empty, they are sampled evenly across the whole volume,
    so the figure is never empty.

    Args:
        image: 3D image array, sliced along axis 0.
        ground_truth: 3D ground-truth label array (same shape as image).
        prediction: 3D predicted label array (same shape as image).
        patient_idx: Zero-based patient index used in the title.
        num_slices: Maximum number of slices (columns) to show.
        dice: Optional overall Dice to show in the title. When omitted, it is
            computed from the binarised volumes (1.0 if both are empty).

    Returns:
        The matplotlib Figure.
    """
    # Candidate slices: those with ground-truth foreground, else every slice.
    gt_area = np.sum(ground_truth > 0, axis=(1, 2))
    candidates = np.flatnonzero(gt_area > 0)
    if candidates.size == 0:
        candidates = np.arange(image.shape[0])

    if len(candidates) <= num_slices:
        selected = candidates
    else:
        picks = np.linspace(0, len(candidates) - 1, num_slices, dtype=int)
        selected = candidates[picks]

    n_cols = len(selected)
    # squeeze=False keeps `axes` 2D even when only one column is shown.
    fig, axes = plt.subplots(3, n_cols, figsize=(3 * n_cols, 9), squeeze=False)

    for col, slice_idx in enumerate(selected):
        img_slice = image[slice_idx]
        gt_slice = ground_truth[slice_idx]
        pred_slice = prediction[slice_idx]

        # Row 0: CT
        axes[0, col].imshow(img_slice, cmap="gray")
        axes[0, col].set_title(f"Slice {slice_idx}", fontsize=10)

        # Row 1: Ground Truth
        axes[1, col].imshow(img_slice, cmap="gray")
        gt_mask = np.ma.masked_where(gt_slice == 0, gt_slice)
        axes[1, col].imshow(gt_mask, cmap="Greens", alpha=0.7, vmin=0, vmax=2)

        # Row 2: Prediction
        axes[2, col].imshow(img_slice, cmap="gray")
        pred_mask = np.ma.masked_where(pred_slice == 0, pred_slice)
        axes[2, col].imshow(pred_mask, cmap="Reds", alpha=0.7, vmin=0, vmax=2)

    # Hide ticks and frames instead of axis("off"), which would also hide the row labels.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Row labels
    axes[0, 0].set_ylabel("CT Image", fontsize=12, fontweight="bold", rotation=0, ha="right", va="center")
    axes[1, 0].set_ylabel("Ground Truth", fontsize=12, fontweight="bold", rotation=0, ha="right", va="center")
    axes[2, 0].set_ylabel("Prediction", fontsize=12, fontweight="bold", rotation=0, ha="right", va="center")

    # Compute the overall Dice locally rather than reading a notebook-global results list.
    if dice is None:
        pred_fg = prediction > 0
        gt_fg = ground_truth > 0
        denominator = pred_fg.sum() + gt_fg.sum()
        dice = 1.0 if denominator == 0 else 2.0 * np.logical_and(pred_fg, gt_fg).sum() / denominator

    plt.suptitle(
        f"Patient {patient_idx + 1} - Multi-Slice View | Overall Dice: {dice:.4f}",
        fontsize=14, fontweight="bold"
    )

    plt.tight_layout()
    return fig

# Render the multi-slice overview for the first evaluated patient only.
if predictions_list:
    fig = create_multi_slice_figure(
        images_list[0], ground_truths_list[0], predictions_list[0], patient_idx=0
    )
    multi_slice_path = OUTPUT_DIR / "multi_slice_view.png"
    fig.savefig(multi_slice_path, dpi=150, bbox_inches="tight", facecolor="white")
    plt.show()

## 7. Attention Map Visualization (Bonus)

In [None]:
def visualize_attention(model, image_slice, device='cpu', img_size=None):
    """
    Run one slice through the model and build an attention heatmap.

    Steps:
        1. Resize the slice to ``img_size x img_size`` (linear interpolation).
        2. Run the slice through ``model``.
        3. Take the last entry of ``model.get_attention_maps()``.
        4. Average it over heads and over query tokens, giving the mean
           attention each token receives.
        5. Reshape that token vector to a square grid and upsample it to
           the input size.

    Args:
        model: TransUNet-style model exposing ``get_attention_maps()``.
        image_slice: 2D image slice (H, W).
        device: Torch device the model lives on.
        img_size: Side length of the model input. Defaults to the notebook's IMG_SIZE.

    Returns:
        Tuple ``(attn_map, slice_resized, output)``:
            attn_map: the (img_size, img_size) heatmap, or ``None`` when no
                attention maps are available.
            slice_resized: the resized input slice.
            output: the raw model output.

    Raises:
        ValueError: if the number of attention tokens is not a perfect square
            (e.g. the model prepends a class token).
    """
    size = IMG_SIZE if img_size is None else img_size
    model.eval()

    # Resize and prepare input (zoom is imported at notebook level).
    H, W = image_slice.shape
    slice_resized = zoom(image_slice, (size / H, size / W), order=1)
    input_tensor = torch.from_numpy(np.ascontiguousarray(slice_resized)).float()
    input_tensor = input_tensor[None, None].to(device)  # (1, 1, size, size)

    with torch.no_grad():
        output = model(input_tensor)

    attention_maps = model.get_attention_maps()
    # Validate the layer that is actually used below (the last one).
    if not attention_maps or attention_maps[-1] is None:
        return None, slice_resized, output

    attn = attention_maps[-1]  # presumably (B, num_heads, N, N) -- TODO confirm against src.model
    attn = attn.mean(dim=1)    # average over heads -> (B, N, N)
    # Average over query tokens: mean attention received by each key token -> (N,).
    attn = attn[0].mean(dim=0)

    num_tokens = attn.shape[0]
    grid = int(round(np.sqrt(num_tokens)))
    if grid * grid != num_tokens:
        raise ValueError(
            f"Cannot reshape {num_tokens} attention tokens into a square grid"
        )

    attn_map = attn.detach().cpu().numpy().reshape(grid, grid)
    attn_map = zoom(attn_map, (size / grid, size / grid), order=1)
    return attn_map, slice_resized, output

In [None]:
# Visualize attention for a sample
if len(images_list) > 0:
    # Use the first patient's slice with the largest ground-truth area.
    gt_area = np.sum(ground_truths_list[0] > 0, axis=(1, 2))
    best_slice = int(np.argmax(gt_area))

    image_slice = images_list[0][best_slice]

    attn_map, resized_img, output = visualize_attention(model, image_slice, device)

    if attn_map is not None:
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Original image
        axes[0].imshow(resized_img, cmap="gray")
        axes[0].set_title("Input Image", fontsize=14, fontweight="bold")
        axes[0].axis("off")

        # Attention map
        im = axes[1].imshow(attn_map, cmap="hot")
        axes[1].set_title("Attention Map", fontsize=14, fontweight="bold")
        axes[1].axis("off")
        plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

        # Attention overlay
        axes[2].imshow(resized_img, cmap="gray")
        axes[2].imshow(attn_map, cmap="hot", alpha=0.5)
        axes[2].set_title("Attention Overlay", fontsize=14, fontweight="bold")
        axes[2].axis("off")

        # Prediction
        pred = torch.argmax(output, dim=1)[0].cpu().numpy()
        axes[3].imshow(resized_img, cmap="gray")
        pred_mask = np.ma.masked_where(pred == 0, pred)
        # Fixed colour limits (as in the other figures). With autoscaling, a mask holding
        # only label 1 gives vmin == vmax, which maps to the palest colour (near-invisible).
        axes[3].imshow(pred_mask, cmap="Reds", alpha=0.7, vmin=0, vmax=2)
        axes[3].set_title("Segmentation Output", fontsize=14, fontweight="bold")
        axes[3].axis("off")

        plt.suptitle("Transformer Attention Visualization", fontsize=16, fontweight="bold")
        plt.tight_layout()
        fig.savefig(
            OUTPUT_DIR / "attention_visualization.png",
            dpi=150,
            bbox_inches="tight",
            facecolor="white",
        )
        plt.show()
    else:
        print("Attention maps not available.")

## 8. Save Results

In [None]:
# Save evaluation results
def _to_jsonable(value):
    """Recursively convert NumPy scalars to Python types and non-finite floats to None."""
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, np.generic):
        value = value.item()
    # inf/NaN (e.g. an undefined HD95) have no representation in strict JSON.
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value

results_summary = {"num_samples": len(results)}
# np.min/np.max raise on an empty list, so only add Dice statistics when available.
if dice_scores:
    results_summary.update(
        dice_mean=float(np.mean(dice_scores)),
        dice_std=float(np.std(dice_scores)),
        dice_min=float(np.min(dice_scores)),
        dice_max=float(np.max(dice_scores)),
    )
results_summary["individual_results"] = results

if hd_scores:
    results_summary["hd95_mean"] = float(np.mean(hd_scores))
    results_summary["hd95_std"] = float(np.std(hd_scores))

# Save to JSON
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
results_path = OUTPUT_DIR / "evaluation_results.json"
with open(results_path, "w", encoding="utf-8") as f:
    # allow_nan=False guarantees strict JSON output; _to_jsonable maps inf/NaN to null.
    json.dump(_to_jsonable(results_summary), f, indent=2, allow_nan=False)

print(f"Results saved to: {results_path}")

## Summary

This notebook implements the complete evaluation pipeline:

1. **Model Loading**: Loaded trained TransUNet from checkpoint

2. **3D Volume Inference**:
   - Slice-by-slice prediction
   - Resize to model input size
   - Stack predictions back to 3D

3. **Metrics**:
   - Dice Score: Overlap measure
   - Hausdorff Distance (95th percentile): Surface distance measure

4. **Visualizations**:
   - Side-by-side comparison (CT, Ground Truth, Prediction)
   - Error analysis (TP, FP, FN overlay)
   - Multi-slice view

5. **Attention Maps**:
   - Extracted from Transformer bottleneck
   - Overlaid on input to show model focus areas

**Outputs**:
- `outputs/evaluation_results.json`: Quantitative metrics
- `outputs/comparison_patient_*.png`: Per-patient visualizations
- `outputs/multi_slice_view.png`: Multi-slice overview
- `outputs/attention_visualization.png`: Attention map visualization