# 📊 Loss Functions and Metrics for Cardiac Segmentation

This notebook implements loss functions and evaluation metrics for medical image segmentation, with cardiac segmentation as the target task. It focuses on hybrid losses that combine multiple objectives and on evaluation metrics used in medical imaging.

## Objectives
- Implement Dice Loss for overlap optimization
- Implement Binary Cross-Entropy Loss for pixel-wise classification
- Create hybrid loss functions (Dice + BCE, Focal + Dice)
- Implement boundary-aware losses
- Develop medical segmentation metrics (IoU, Hausdorff Distance)
- Create a comprehensive evaluation framework

## Key Components
1. **Individual Loss Functions**: Dice, BCE, Focal Loss
2. **Hybrid Loss Functions**: Combined losses with configurable weights
3. **Boundary-Aware Losses**: Surface loss, boundary loss
4. **Medical Metrics**: Dice coefficient, IoU, sensitivity, specificity
5. **Distance Metrics**: Hausdorff distance, average surface distance
6. **Evaluation Framework**: Comprehensive metrics computation

In [2]:
# Import Required Libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import ndimage
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("PyTorch version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU Device:", torch.cuda.get_device_name(0))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

PyTorch version: 2.7.1+cpu
CUDA Available: False
Using device: cpu


In [3]:
# Configuration Class for Loss Functions and Metrics
class LossConfig:
    """Configuration for loss functions and metrics"""
    
    def __init__(self):
        # Loss function weights
        self.dice_weight = 0.5
        self.bce_weight = 0.5
        self.focal_alpha = 0.25
        self.focal_gamma = 2.0
        self.boundary_weight = 0.1
        
        # Smoothing parameters
        self.smooth = 1e-6
        self.epsilon = 1e-7
        
        # Thresholds
        self.binary_threshold = 0.5
        
        # Evaluation parameters
        self.hausdorff_percentile = 95
        
    def get_config(self):
        return {
            'dice_weight': self.dice_weight,
            'bce_weight': self.bce_weight,
            'focal_alpha': self.focal_alpha,
            'focal_gamma': self.focal_gamma,
            'boundary_weight': self.boundary_weight,
            'smooth': self.smooth,
            'epsilon': self.epsilon,
            'binary_threshold': self.binary_threshold,
            'hausdorff_percentile': self.hausdorff_percentile
        }

# Initialize configuration
config = LossConfig()
print("Loss Functions Configuration:")
for key, value in config.get_config().items():
    print(f"  {key}: {value}")

Loss Functions Configuration:
  dice_weight: 0.5
  bce_weight: 0.5
  focal_alpha: 0.25
  focal_gamma: 2.0
  boundary_weight: 0.1
  smooth: 1e-06
  epsilon: 1e-07
  binary_threshold: 0.5
  hausdorff_percentile: 95


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Individual Loss Functions

class DiceLoss(nn.Module):
    """Dice loss for binary segmentation"""
    
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted mask (B, C, H, W)
            y_true: Ground truth mask (B, C, H, W) or (B, H, W)
        """
        # Ensure same shape
        if y_true.dim() == 3:  # (B, H, W)
            y_true = y_true.unsqueeze(1)  # (B, 1, H, W)
        
        # Flatten per sample; reshape also handles non-contiguous tensors
        y_pred_flat = y_pred.reshape(y_pred.size(0), -1)
        y_true_flat = y_true.reshape(y_true.size(0), -1)
        
        # Soft intersection and the sum of both mask sizes
        # (the Dice denominator, which is not the union)
        intersection = (y_pred_flat * y_true_flat).sum(dim=1)
        total = y_pred_flat.sum(dim=1) + y_true_flat.sum(dim=1)
        
        # Calculate Dice coefficient
        dice = (2. * intersection + self.smooth) / (total + self.smooth)
        
        # Return Dice loss
        return 1. - dice.mean()

class DiceCoefficient(nn.Module):
    """Dice coefficient metric (not loss)"""
    
    def __init__(self, smooth=1e-6):
        super(DiceCoefficient, self).__init__()
        self.smooth = smooth
    
    def forward(self, y_pred, y_true):
        # Same calculation as DiceLoss but return coefficient
        if y_true.dim() == 3:
            y_true = y_true.unsqueeze(1)
        
        y_pred_flat = y_pred.view(y_pred.size(0), -1)
        y_true_flat = y_true.view(y_true.size(0), -1)
        
        intersection = (y_pred_flat * y_true_flat).sum(dim=1)
        total = y_pred_flat.sum(dim=1) + y_true_flat.sum(dim=1)
        
        dice = (2. * intersection + self.smooth) / (total + self.smooth)
        return dice.mean()

class GeneralizedDiceLoss(nn.Module):
    """Generalized Dice Loss for multi-class segmentation"""
    
    def __init__(self, smooth=1e-6):
        super(GeneralizedDiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted logits (B, C, H, W)
            y_true: Ground truth (B, H, W) with class indices
        """
        # Convert to one-hot if needed
        if y_true.dim() == 3:  # (B, H, W)
            y_true_oh = F.one_hot(y_true.long(), num_classes=y_pred.size(1))
            y_true_oh = y_true_oh.permute(0, 3, 1, 2).float()  # (B, C, H, W)
        else:
            y_true_oh = y_true
        
        # Apply softmax to predictions
        y_pred_soft = F.softmax(y_pred, dim=1)
        
        # Flatten for computation
        y_pred_flat = y_pred_soft.view(y_pred_soft.size(0), y_pred_soft.size(1), -1)
        y_true_flat = y_true_oh.view(y_true_oh.size(0), y_true_oh.size(1), -1)
        
        # Calculate weights (inverse squared class volume, summed over the batch)
        class_sums = y_true_flat.sum(dim=(0, 2))
        weights = 1. / (class_sums ** 2 + self.smooth)
        
        # Calculate per-class intersection and the summed prediction and target volumes
        intersection = (y_pred_flat * y_true_flat).sum(dim=(0, 2))
        total = (y_pred_flat + y_true_flat).sum(dim=(0, 2))
        
        # Weighted generalized dice
        numerator = 2. * (weights * intersection).sum()
        denominator = (weights * total).sum()
        
        gd_loss = 1. - (numerator + self.smooth) / (denominator + self.smooth)
        
        return gd_loss

# Test dice loss functions
dice_loss = DiceLoss()
dice_coeff = DiceCoefficient()
gen_dice_loss = GeneralizedDiceLoss()

print("✅ Dice Loss Functions Implemented:")
print("- DiceLoss: Standard Dice loss for binary segmentation")
print("- DiceCoefficient: Dice coefficient metric")
print("- GeneralizedDiceLoss: Multi-class Dice loss")

# Quick test with dummy data
test_pred = torch.randn(2, 3, 64, 64)
test_true = torch.randint(0, 3, (2, 64, 64))

print(f"\n🧪 Test with dummy data:")
print(f"   Generalized Dice Loss: {gen_dice_loss(test_pred, test_true):.4f}")
print(f"   Input shape - Pred: {test_pred.shape}, True: {test_true.shape}")

✅ Dice Loss Functions Implemented:
- DiceLoss: Standard Dice loss for binary segmentation
- DiceCoefficient: Dice coefficient metric
- GeneralizedDiceLoss: Multi-class Dice loss

🧪 Test with dummy data:
   Generalized Dice Loss: 0.6688
   Input shape - Pred: torch.Size([2, 3, 64, 64]), True: torch.Size([2, 64, 64])
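
The test above exercises only `GeneralizedDiceLoss`. As a sanity check on the binary formula, the following minimal sketch recomputes the soft Dice coefficient by hand on a single 2x2 sample. `soft_dice` is a hypothetical standalone helper that mirrors the arithmetic in `DiceLoss.forward`; it is not one of the notebook's classes.

```python
import torch

def soft_dice(y_pred, y_true, smooth=1e-6):
    """Per-sample soft Dice, mirroring the arithmetic in DiceLoss.forward."""
    y_pred_flat = y_pred.reshape(y_pred.size(0), -1)
    y_true_flat = y_true.reshape(y_true.size(0), -1)
    intersection = (y_pred_flat * y_true_flat).sum(dim=1)
    total = y_pred_flat.sum(dim=1) + y_true_flat.sum(dim=1)
    return (2.0 * intersection + smooth) / (total + smooth)

# One 2x2 sample: two predicted foreground pixels, one of them correct
pred = torch.tensor([[[[1.0, 1.0], [0.0, 0.0]]]])  # (B=1, C=1, H=2, W=2)
true = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])

dice = soft_dice(pred, true)
# intersection = 1 and |pred| + |true| = 3, so Dice is approximately 2/3
print(f"Dice: {dice.item():.4f}, Dice loss: {1.0 - dice.item():.4f}")
```

With two empty masks the smoothing term makes the coefficient 1 rather than 0/0, which is the main reason `smooth` appears in both numerator and denominator.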


In [5]:
# Binary Cross-Entropy and Focal Loss

import torch
import torch.nn as nn

class BinaryCrossEntropyLoss(nn.Module):
    """Binary Cross-Entropy loss with numerical stability"""
    
    def __init__(self, epsilon=1e-7):
        super(BinaryCrossEntropyLoss, self).__init__()
        self.epsilon = epsilon
    
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted probabilities (B, 1, H, W)
            y_true: Ground truth binary mask (B, 1, H, W) or (B, H, W)
        """
        if y_true.dim() == 3:
            y_true = y_true.unsqueeze(1).float()
        
        # Clip predictions to prevent log(0)
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        
        # Calculate BCE
        bce = -(y_true * torch.log(y_pred) + (1. - y_true) * torch.log(1. - y_pred))
        
        return bce.mean()

class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    
    def __init__(self, alpha=0.25, gamma=2.0, epsilon=1e-7):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
    
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted probabilities (B, 1, H, W)
            y_true: Ground truth binary mask (B, 1, H, W) or (B, H, W)
        """
        if y_true.dim() == 3:
            y_true = y_true.unsqueeze(1).float()
        
        # Clip predictions
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        
        # Calculate cross entropy
        ce = -(y_true * torch.log(y_pred) + (1. - y_true) * torch.log(1. - y_pred))
        
        # Calculate focal weight
        pt = torch.where(y_true == 1, y_pred, 1 - y_pred)
        focal_weight = (1 - pt) ** self.gamma
        
        # Apply alpha weighting
        alpha_weight = torch.where(y_true == 1, self.alpha, 1 - self.alpha)
        
        # Combine everything
        focal_loss = alpha_weight * focal_weight * ce
        
        return focal_loss.mean()

class TverskyLoss(nn.Module):
    """Tversky Loss - generalization of Dice Loss"""
    
    def __init__(self, alpha=0.5, beta=0.5, smooth=1e-6):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha  # weight for false positives
        self.beta = beta    # weight for false negatives
        self.smooth = smooth
    
    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted probabilities (B, 1, H, W)
            y_true: Ground truth binary mask (B, 1, H, W) or (B, H, W)
        """
        if y_true.dim() == 3:
            y_true = y_true.unsqueeze(1).float()
        
        # Flatten tensors
        y_pred_flat = y_pred.view(-1)
        y_true_flat = y_true.view(-1)
        
        # Calculate true positives, false positives, false negatives
        TP = (y_pred_flat * y_true_flat).sum()
        FP = ((1 - y_true_flat) * y_pred_flat).sum()
        FN = (y_true_flat * (1 - y_pred_flat)).sum()
        
        # Calculate Tversky coefficient
        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        
        return 1. - tversky

# Test loss functions
bce_loss = BinaryCrossEntropyLoss()
focal_loss = FocalLoss()
tversky_loss = TverskyLoss()

print("✅ Additional Loss Functions Implemented:")
print("- BinaryCrossEntropyLoss: Stable BCE with clipping")
print("- FocalLoss: For handling class imbalance")
print("- TverskyLoss: Generalization of Dice Loss")

# Quick test with dummy binary data
test_pred_bin = torch.sigmoid(torch.randn(2, 1, 64, 64))
test_true_bin = torch.randint(0, 2, (2, 64, 64)).float()

print(f"\n🧪 Test with dummy binary data:")
print(f"   BCE Loss: {bce_loss(test_pred_bin, test_true_bin):.4f}")
print(f"   Focal Loss: {focal_loss(test_pred_bin, test_true_bin):.4f}")
print(f"   Tversky Loss: {tversky_loss(test_pred_bin, test_true_bin):.4f}")

✅ Additional Loss Functions Implemented:
- BinaryCrossEntropyLoss: Stable BCE with clipping
- FocalLoss: For handling class imbalance
- TverskyLoss: Generalization of Dice Loss

🧪 Test with dummy binary data:
   BCE Loss: 0.8104
   Focal Loss: 0.1772
   Tversky Loss: 0.5075
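
`TverskyLoss` is described above as a generalization of Dice. The sketch below checks that claim numerically: with `alpha = beta = 0.5` the Tversky index matches a batch-level soft Dice up to the smoothing term. `tversky_index` and `dice_index` are hypothetical helpers written for this check; they aggregate over the whole batch, as `TverskyLoss.forward` does, whereas `DiceLoss` averages per sample.

```python
import torch

def tversky_index(y_pred, y_true, alpha, beta, smooth=1e-6):
    """Soft Tversky index over the whole batch, as in TverskyLoss.forward."""
    p, t = y_pred.reshape(-1), y_true.reshape(-1)
    tp = (p * t).sum()
    fp = ((1 - t) * p).sum()
    fn = (t * (1 - p)).sum()
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

def dice_index(y_pred, y_true, smooth=1e-6):
    """Soft Dice over the whole batch (not per sample)."""
    p, t = y_pred.reshape(-1), y_true.reshape(-1)
    return (2 * (p * t).sum() + smooth) / (p.sum() + t.sum() + smooth)

torch.manual_seed(0)
pred = torch.sigmoid(torch.randn(2, 1, 16, 16))
true = torch.randint(0, 2, (2, 1, 16, 16)).float()

balanced = tversky_index(pred, true, alpha=0.5, beta=0.5)
recall_biased = tversky_index(pred, true, alpha=0.3, beta=0.7)
dice = dice_index(pred, true)

print(f"Tversky(0.5, 0.5): {balanced.item():.4f}  Dice: {dice.item():.4f}")
print(f"Tversky(0.3, 0.7): {recall_biased.item():.4f}")
```

Raising `beta` above `alpha` penalizes false negatives more heavily than false positives, which is the usual motivation for the `alpha_t=0.3, beta_t=0.7` style of setting when missed foreground is the costlier error.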


In [7]:
# Hybrid Loss Functions

import torch
import torch.nn as nn

class DiceBCELoss(nn.Module):
    """Combined Dice and Binary Cross-Entropy Loss"""
    
    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1e-6):
        super(DiceBCELoss, self).__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.dice_loss = DiceLoss(smooth)
        self.bce_loss = BinaryCrossEntropyLoss()
    
    def forward(self, y_pred, y_true):
        dice_loss_val = self.dice_loss(y_pred, y_true)
        bce_loss_val = self.bce_loss(y_pred, y_true)
        
        return self.dice_weight * dice_loss_val + self.bce_weight * bce_loss_val

class FocalDiceLoss(nn.Module):
    """Combined Focal and Dice Loss"""
    
    def __init__(self, focal_weight=0.5, dice_weight=0.5, 
                 alpha=0.25, gamma=2.0, smooth=1e-6):
        super(FocalDiceLoss, self).__init__()
        self.focal_weight = focal_weight
        self.dice_weight = dice_weight
        self.focal_loss = FocalLoss(alpha, gamma)
        self.dice_loss = DiceLoss(smooth)
    
    def forward(self, y_pred, y_true):
        focal_loss_val = self.focal_loss(y_pred, y_true)
        dice_loss_val = self.dice_loss(y_pred, y_true)
        
        return self.focal_weight * focal_loss_val + self.dice_weight * dice_loss_val

class TverskyFocalLoss(nn.Module):
    """Combined Tversky and Focal Loss"""
    
    def __init__(self, tversky_weight=0.5, focal_weight=0.5,
                 alpha_t=0.3, beta_t=0.7, alpha_f=0.25, gamma_f=2.0, smooth=1e-6):
        super(TverskyFocalLoss, self).__init__()
        self.tversky_weight = tversky_weight
        self.focal_weight = focal_weight
        self.tversky_loss = TverskyLoss(alpha_t, beta_t, smooth)
        self.focal_loss = FocalLoss(alpha_f, gamma_f)
    
    def forward(self, y_pred, y_true):
        tversky_loss_val = self.tversky_loss(y_pred, y_true)
        focal_loss_val = self.focal_loss(y_pred, y_true)
        
        return self.tversky_weight * tversky_loss_val + self.focal_weight * focal_loss_val

class AdaptiveLoss(nn.Module):
    """Adaptive loss that adjusts weights during training"""
    
    def __init__(self, initial_dice_weight=0.5):
        super(AdaptiveLoss, self).__init__()
        self.dice_weight = nn.Parameter(torch.tensor(initial_dice_weight), requires_grad=False)
        self.dice_loss = DiceLoss()
        self.bce_loss = BinaryCrossEntropyLoss()
        self.step_count = 0
    
    def forward(self, y_pred, y_true):
        dice_loss_val = self.dice_loss(y_pred, y_true)
        bce_loss_val = self.bce_loss(y_pred, y_true)
        
        # BCE weight is the complement of the current Dice weight;
        # the adaptation itself happens in update_weights()
        bce_weight = 1.0 - self.dice_weight
        
        return self.dice_weight * dice_loss_val + bce_weight * bce_loss_val
    
    def update_weights(self, dice_score):
        """Update weights based on current performance"""
        # Simple adaptation: increase Dice weight if score is low
        if dice_score < 0.5:
            self.dice_weight.data = torch.clamp(self.dice_weight.data + 0.01, 0.1, 0.9)
        elif dice_score > 0.8:
            self.dice_weight.data = torch.clamp(self.dice_weight.data - 0.01, 0.1, 0.9)

# Test hybrid loss functions
dice_bce_loss = DiceBCELoss()
focal_dice_loss = FocalDiceLoss()
tversky_focal_loss = TverskyFocalLoss()
adaptive_loss = AdaptiveLoss()

print("✅ Hybrid Loss Functions Implemented:")
print("- DiceBCELoss: Combined Dice + BCE")
print("- FocalDiceLoss: Combined Focal + Dice")
print("- TverskyFocalLoss: Combined Tversky + Focal")
print("- AdaptiveLoss: Adaptive weight adjustment")

# Quick test
test_pred_bin = torch.sigmoid(torch.randn(2, 1, 64, 64))
test_true_bin = torch.randint(0, 2, (2, 64, 64)).float()

print(f"\n🧪 Test hybrid losses:")
print(f"   Dice+BCE Loss: {dice_bce_loss(test_pred_bin, test_true_bin):.4f}")
print(f"   Focal+Dice Loss: {focal_dice_loss(test_pred_bin, test_true_bin):.4f}")
print(f"   Tversky+Focal Loss: {tversky_focal_loss(test_pred_bin, test_true_bin):.4f}")
print(f"   Adaptive Loss: {adaptive_loss(test_pred_bin, test_true_bin):.4f}")

✅ Hybrid Loss Functions Implemented:
- DiceBCELoss: Combined Dice + BCE
- FocalDiceLoss: Combined Focal + Dice
- TverskyFocalLoss: Combined Tversky + Focal
- AdaptiveLoss: Adaptive weight adjustment

🧪 Test hybrid losses:
   Dice+BCE Loss: 0.6438
   Focal+Dice Loss: 0.3305
   Tversky+Focal Loss: 0.3313
   Adaptive Loss: 0.6438
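
The hybrid losses above expect probabilities, so the test applies `torch.sigmoid` first. When a model emits raw logits, one common alternative is to compute the BCE term with `F.binary_cross_entropy_with_logits`, which fuses the sigmoid and the log for numerical stability. The sketch below is one possible variant under that assumption; `dice_bce_from_logits` is a hypothetical function, not part of the classes defined above.

```python
import torch
import torch.nn.functional as F

def dice_bce_from_logits(logits, y_true, dice_weight=0.5, bce_weight=0.5, smooth=1e-6):
    """Hypothetical Dice + BCE variant that takes raw logits instead of probabilities."""
    # BCE computed directly on logits (sigmoid and log are fused internally)
    bce = F.binary_cross_entropy_with_logits(logits, y_true)

    # Soft Dice on the sigmoid probabilities, averaged over samples
    probs = torch.sigmoid(logits).reshape(logits.size(0), -1)
    target = y_true.reshape(y_true.size(0), -1)
    intersection = (probs * target).sum(dim=1)
    total = probs.sum(dim=1) + target.sum(dim=1)
    dice = (2.0 * intersection + smooth) / (total + smooth)

    return dice_weight * (1.0 - dice.mean()) + bce_weight * bce

torch.manual_seed(0)
logits = torch.randn(2, 1, 32, 32, requires_grad=True)
target = torch.randint(0, 2, (2, 1, 32, 32)).float()

loss = dice_bce_from_logits(logits, target)
loss.backward()  # gradients flow through both the Dice and the BCE term
print(f"loss={loss.item():.4f}, grad norm={logits.grad.norm().item():.4f}")
```

Because no clamping is involved, very confident wrong predictions still produce a finite loss and a usable gradient.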


In [8]:
# Boundary-Aware Loss Functions

def boundary_loss(y_true, y_pred, theta=0.01):
    """
    Boundary loss to improve edge detection
    
    Args:
        y_true: Ground truth mask, (B, 1, H, W) or (B, H, W)
        y_pred: Predicted probabilities, (B, 1, H, W) or (B, H, W)
        theta: Threshold for boundary detection (currently unused)
    
    Returns:
        Boundary loss value
    """
    # Binarize the ground truth only; thresholding y_pred would block gradients
    y_true_binary = (y_true > 0.5).float()
    
    # Compute gradients to find boundaries
    def compute_boundary(mask):
        # Sobel filters for edge detection, shaped (out_channels, in_channels, kH, kW)
        sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                               dtype=mask.dtype, device=mask.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]],
                               dtype=mask.dtype, device=mask.device).view(1, 1, 3, 3)
        
        # Add channel dimension if needed: (B, H, W) -> (B, 1, H, W)
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)
        
        # Zero padding of 1 keeps the spatial size for a 3x3 kernel
        grad_x = F.conv2d(mask, sobel_x, padding=1)
        grad_y = F.conv2d(mask, sobel_y, padding=1)
        
        # The small constant keeps the sqrt gradient finite where the magnitude is zero
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-12)
    
    # Compute boundaries (both returned as (B, 1, H, W))
    true_boundary = compute_boundary(y_true_binary)
    pred_boundary = compute_boundary(y_pred)
    
    # Boundary loss
    boundary_diff = torch.abs(true_boundary - pred_boundary)
    return boundary_diff.mean()

def surface_loss(y_true, y_pred):
    """
    Simplified surface loss for boundary alignment
    
    Note: predictions are weighted by a one-pixel inner boundary band of the
    ground truth (mask minus its erosion), not by a true distance transform.
    """
    def compute_distance_transform(mask):
        # Simplified distance transform using morphological operations
        mask_binary = (mask > 0.5).float()
        if mask_binary.dim() == 3:  # (B, H, W) -> (B, 1, H, W)
            mask_binary = mask_binary.unsqueeze(1)
        
        # Binary erosion with a flat 3x3 structuring element,
        # implemented as min-pooling (a negated max-pool)
        eroded = -F.max_pool2d(-mask_binary, kernel_size=3, stride=1, padding=1)
        
        # Approximate distance as difference between original and eroded
        return mask_binary - eroded
    
    # Compute the boundary band of the ground truth
    true_dist = compute_distance_transform(y_true)
    
    # Match the (B, 1, H, W) layout before multiplying
    if y_pred.dim() == 3:
        y_pred = y_pred.unsqueeze(1)
    
    # Surface loss
    surface_loss_val = (y_pred * true_dist).mean()
    
    return surface_loss_val

def hausdorff_loss(y_true, y_pred, alpha=2.0):
    """
    Crude boundary-mismatch proxy inspired by the Hausdorff distance
    
    Note: both masks are thresholded, so the result carries no gradient and is
    better suited to monitoring than to training. Per sample it is 1 if the two
    boundary maps differ anywhere and 0 otherwise; it does not measure distance.
    
    Args:
        y_true: Ground truth mask, (B, 1, H, W) or (B, H, W)
        y_pred: Predicted mask, (B, 1, H, W) or (B, H, W)
        alpha: Smoothing parameter (currently unused)
    
    Returns:
        Mean boundary-mismatch indicator over the batch
    """
    # Convert to binary
    y_true_binary = (y_true > 0.5).float()
    y_pred_binary = (y_pred > 0.5).float()
    
    # Extract the inner boundary of a binary mask
    def get_boundary_points(mask):
        if mask.dim() == 3:  # (B, H, W) -> (B, 1, H, W)
            mask = mask.unsqueeze(1)
        
        # Erosion with a flat 3x3 structuring element (negated max-pool)
        eroded = -F.max_pool2d(-mask, kernel_size=3, stride=1, padding=1)
        
        return mask - eroded
    
    # Get boundaries
    true_boundary = get_boundary_points(y_true_binary)
    pred_boundary = get_boundary_points(y_pred_binary)
    
    # Largest per-sample difference between the two boundary maps
    boundary_diff = torch.abs(true_boundary - pred_boundary)
    hausdorff_approx = boundary_diff.flatten(1).amax(dim=1)
    
    return hausdorff_approx.mean()

# Combined boundary-aware loss
def boundary_aware_loss(y_true, y_pred, dice_weight=0.4, boundary_weight=0.3, 
                       surface_weight=0.3, smooth=1e-6):
    """
    Combined loss with boundary awareness
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        dice_weight: Weight for Dice loss
        boundary_weight: Weight for boundary loss
        surface_weight: Weight for surface loss
        smooth: Smoothing factor
    
    Returns:
        Combined boundary-aware loss
    """
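    # DiceLoss (defined earlier) expects (y_pred, y_true); the module-level
    # `dice_loss` instance does not accept a smoothing argument
    dice_loss_val = DiceLoss(smooth)(y_pred, y_true)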
    boundary_loss_val = boundary_loss(y_true, y_pred)
    surface_loss_val = surface_loss(y_true, y_pred)
    
    total_loss = (dice_weight * dice_loss_val + 
                  boundary_weight * boundary_loss_val + 
                  surface_weight * surface_loss_val)
    
    return total_loss

print("Boundary-Aware Loss Functions Implemented:")
print("- boundary_loss: Emphasizes edge detection")
print("- surface_loss: Improves boundary alignment")
print("- hausdorff_loss: Approximates Hausdorff distance")
print("- boundary_aware_loss: Combined boundary-aware loss")

Boundary-Aware Loss Functions Implemented:
- boundary_loss: Emphasizes edge detection
- surface_loss: Improves boundary alignment
- hausdorff_loss: Approximates Hausdorff distance
- boundary_aware_loss: Combined boundary-aware loss
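
The `surface_loss` above approximates distance with a one-pixel boundary band. A closer approximation to a distance-weighted boundary term precomputes a signed Euclidean distance map of the ground truth and averages the predicted probabilities weighted by that map. The sketch below assumes 2D binary masks held as NumPy arrays; `signed_distance_map` and `distance_weighted_loss` are hypothetical helper names. In training, the map would typically be precomputed per sample and multiplied with the probability tensor.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def signed_distance_map(mask):
    """Signed Euclidean distance to the object boundary:
    negative inside the object, positive outside, all zeros for a
    completely empty or completely full mask."""
    fg = mask > 0.5
    if not fg.any() or fg.all():
        return np.zeros(mask.shape, dtype=np.float32)
    dist_outside = distance_transform_edt(~fg)  # background pixels: distance to the object
    dist_inside = distance_transform_edt(fg)    # foreground pixels: distance to the background
    return (dist_outside - dist_inside).astype(np.float32)

def distance_weighted_loss(probs, mask):
    """Mean of predicted probabilities weighted by the signed distance map."""
    return float(np.mean(probs * signed_distance_map(mask)))

# Ground truth: a 4x4 square inside an 8x8 image
truth = np.zeros((8, 8), dtype=np.float32)
truth[2:6, 2:6] = 1.0

perfect = truth.copy()           # prediction identical to the ground truth
shifted = np.zeros_like(truth)
shifted[4:8, 4:8] = 1.0          # same square, shifted by two pixels

print("perfect:", distance_weighted_loss(perfect, truth))
print("shifted:", distance_weighted_loss(shifted, truth))
```

Probability mass placed far outside the object is penalized in proportion to its distance, while mass inside the object lowers the loss, so the shifted prediction scores worse than the perfect one.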


In [9]:
# Medical Segmentation Metrics

def iou_score(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """
    Intersection over Union (IoU) score
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
        smooth: Smoothing factor
    
    Returns:
        IoU score
    """
    y_true_binary = (y_true > threshold).float()
    y_pred_binary = (y_pred > threshold).float()
    
    intersection = torch.sum(y_true_binary * y_pred_binary)
    union = torch.sum(y_true_binary) + torch.sum(y_pred_binary) - intersection
    
    iou = (intersection + smooth) / (union + smooth)
    return iou

def sensitivity_score(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """
    Sensitivity (Recall/True Positive Rate)
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
        smooth: Smoothing factor
    
    Returns:
        Sensitivity score
    """
    y_true_binary = (y_true > threshold).float()
    y_pred_binary = (y_pred > threshold).float()
    
    true_positives = torch.sum(y_true_binary * y_pred_binary)
    possible_positives = torch.sum(y_true_binary)
    
    sensitivity = (true_positives + smooth) / (possible_positives + smooth)
    return sensitivity

def specificity_score(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """
    Specificity (True Negative Rate)
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
        smooth: Smoothing factor
    
    Returns:
        Specificity score
    """
    y_true_binary = (y_true > threshold).float()
    y_pred_binary = (y_pred > threshold).float()
    
    true_negatives = torch.sum((1 - y_true_binary) * (1 - y_pred_binary))
    possible_negatives = torch.sum(1 - y_true_binary)
    
    specificity = (true_negatives + smooth) / (possible_negatives + smooth)
    return specificity

def precision_score(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """
    Precision (Positive Predictive Value)
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
        smooth: Smoothing factor
    
    Returns:
        Precision score
    """
    y_true_binary = (y_true > threshold).float()
    y_pred_binary = (y_pred > threshold).float()
    
    true_positives = torch.sum(y_true_binary * y_pred_binary)
    predicted_positives = torch.sum(y_pred_binary)
    
    precision = (true_positives + smooth) / (predicted_positives + smooth)
    return precision

def f1_score(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """
    F1 Score (Harmonic mean of precision and recall)
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
        smooth: Smoothing factor
    
    Returns:
        F1 score
    """
    precision = precision_score(y_true, y_pred, threshold, smooth)
    sensitivity = sensitivity_score(y_true, y_pred, threshold, smooth)
    
    f1 = 2 * (precision * sensitivity) / (precision + sensitivity + smooth)
    return f1

def volume_similarity(y_true, y_pred, threshold=0.5):
    """
    Volume Similarity metric
    
    Args:
        y_true: Ground truth mask
        y_pred: Predicted mask
        threshold: Threshold for binary conversion
    
    Returns:
        Volume similarity score
    """
    y_true_binary = (y_true > threshold).float()
    y_pred_binary = (y_pred > threshold).float()
    
    vol_true = torch.sum(y_true_binary)
    vol_pred = torch.sum(y_pred_binary)
    
    vol_sim = 1.0 - torch.abs(vol_true - vol_pred) / (vol_true + vol_pred + 1e-6)
    return vol_sim

class MedicalMetrics:
    """
    Comprehensive medical segmentation metrics calculator
    """
    
    def __init__(self, threshold=0.5, smooth=1e-6):
        self.threshold = threshold
        self.smooth = smooth
        
    def compute_all_metrics(self, y_true, y_pred):
        """
        Compute all medical segmentation metrics
        
        Args:
            y_true: Ground truth mask
            y_pred: Predicted mask
        
        Returns:
            Dictionary with all computed metrics
        """
        metrics = {}
        
        # Basic metrics
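        # Dice on thresholded masks, consistent with the other metrics here
        y_true_b = (y_true > self.threshold).float()
        y_pred_b = (y_pred > self.threshold).float()
        metrics['dice'] = ((2.0 * torch.sum(y_true_b * y_pred_b) + self.smooth) /
                           (torch.sum(y_true_b) + torch.sum(y_pred_b) + self.smooth))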
        metrics['iou'] = iou_score(y_true, y_pred, self.threshold, self.smooth)
        metrics['sensitivity'] = sensitivity_score(y_true, y_pred, self.threshold, self.smooth)
        metrics['specificity'] = specificity_score(y_true, y_pred, self.threshold, self.smooth)
        metrics['precision'] = precision_score(y_true, y_pred, self.threshold, self.smooth)
        metrics['f1'] = f1_score(y_true, y_pred, self.threshold, self.smooth)
        metrics['volume_similarity'] = volume_similarity(y_true, y_pred, self.threshold)
        
        return metrics
    
    def compute_metrics_numpy(self, y_true_np, y_pred_np):
        """
        Compute metrics using numpy arrays (for post-processing)
        
        Args:
            y_true_np: Ground truth mask (numpy array)
            y_pred_np: Predicted mask (numpy array)
        
        Returns:
            Dictionary with computed metrics
        """
        y_true_binary = (y_true_np > self.threshold).astype(np.float32)
        y_pred_binary = (y_pred_np > self.threshold).astype(np.float32)
        
        # Calculate basic metrics
        intersection = np.sum(y_true_binary * y_pred_binary)
        union = np.sum(y_true_binary) + np.sum(y_pred_binary) - intersection
        
        # Dice coefficient
        dice = (2.0 * intersection + self.smooth) / (np.sum(y_true_binary) + np.sum(y_pred_binary) + self.smooth)
        
        # IoU
        iou = (intersection + self.smooth) / (union + self.smooth)
        
        # Sensitivity and Specificity
        true_positives = intersection
        false_negatives = np.sum(y_true_binary * (1 - y_pred_binary))
        false_positives = np.sum((1 - y_true_binary) * y_pred_binary)
        true_negatives = np.sum((1 - y_true_binary) * (1 - y_pred_binary))
        
        sensitivity = true_positives / (true_positives + false_negatives + self.smooth)
        specificity = true_negatives / (true_negatives + false_positives + self.smooth)
        precision = true_positives / (true_positives + false_positives + self.smooth)
        
        # F1 Score
        f1 = 2 * (precision * sensitivity) / (precision + sensitivity + self.smooth)
        
        # Volume similarity
        vol_true = np.sum(y_true_binary)
        vol_pred = np.sum(y_pred_binary)
        vol_sim = 1.0 - np.abs(vol_true - vol_pred) / (vol_true + vol_pred + self.smooth)
        
        return {
            'dice': dice,
            'iou': iou,
            'sensitivity': sensitivity,
            'specificity': specificity,
            'precision': precision,
            'f1': f1,
            'volume_similarity': vol_sim
        }

# Initialize metrics calculator
metrics_calculator = MedicalMetrics()

print("Medical Segmentation Metrics Implemented:")
print("- Dice Coefficient")
print("- IoU (Intersection over Union)")
print("- Sensitivity (Recall)")
print("- Specificity")
print("- Precision")
print("- F1 Score")
print("- Volume Similarity")
print("- MedicalMetrics class for comprehensive evaluation")

Medical Segmentation Metrics Implemented:
- Dice Coefficient
- IoU (Intersection over Union)
- Sensitivity (Recall)
- Specificity
- Precision
- F1 Score
- Volume Similarity
- MedicalMetrics class for comprehensive evaluation
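
All of the overlap metrics above reduce to four confusion counts. The sketch below works through a small 2x4 example without smoothing so the values can be checked by hand; `confusion_counts` is a hypothetical helper, and the unsmoothed ratios would divide by zero on empty masks, which is what the `smooth` terms in the functions above guard against.

```python
import numpy as np

def confusion_counts(y_true, y_pred, threshold=0.5):
    """Return (TP, FP, FN, TN) for two masks after thresholding."""
    t = y_true > threshold
    p = y_pred > threshold
    tp = int(np.sum(t & p))
    fp = int(np.sum(~t & p))
    fn = int(np.sum(t & ~p))
    tn = int(np.sum(~t & ~p))
    return tp, fp, fn, tn

y_true = np.array([[1, 1, 0, 0],
                   [1, 1, 0, 0]], dtype=np.float32)
y_pred = np.array([[0.9, 0.2, 0.7, 0.1],
                   [0.8, 0.6, 0.3, 0.4]], dtype=np.float32)

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
dice = 2 * tp / (2 * tp + fp + fn)
iou = tp / (tp + fp + fn)
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
precision = tp / (tp + fp)

print(f"TP={tp} FP={fp} FN={fn} TN={tn}")
print(f"dice={dice:.3f} iou={iou:.3f} sensitivity={sensitivity:.3f} "
      f"specificity={specificity:.3f} precision={precision:.3f}")
```

Dice and IoU are monotonically related through `dice = 2 * iou / (1 + iou)`, so they always rank predictions for a single case in the same order.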


In [10]:
# Distance-Based Metrics

def hausdorff_distance_numpy(mask1, mask2, percentile=95):
    """
    Calculate Hausdorff distance between two binary masks
    
    Args:
        mask1: First binary mask (numpy array)
        mask2: Second binary mask (numpy array)
        percentile: Percentile for robust Hausdorff distance
    
    Returns:
        Hausdorff distance value
    """
    # Convert to binary
    mask1_binary = (mask1 > 0.5).astype(np.uint8)
    mask2_binary = (mask2 > 0.5).astype(np.uint8)
    
    # Find boundary points
    def get_boundary_points(mask):
        # Use morphological operations to find boundary
        from scipy import ndimage
        eroded = ndimage.binary_erosion(mask)
        boundary = mask.astype(np.float32) - eroded.astype(np.float32)
        coords = np.argwhere(boundary > 0)
        return coords
    
    # Get boundary coordinates
    coords1 = get_boundary_points(mask1_binary)
    coords2 = get_boundary_points(mask2_binary)
    
    if len(coords1) == 0 or len(coords2) == 0:
        return float('inf')
    
    # Calculate the Hausdorff distance between the two boundary point sets
    try:
        if percentile == 100:
            # Standard Hausdorff: maximum of the two directed distances
            hd1 = directed_hausdorff(coords1, coords2)[0]
            hd2 = directed_hausdorff(coords2, coords1)[0]
            return max(hd1, hd2)
        
        # Robust version: percentile of nearest-neighbour boundary distances
        from scipy.spatial.distance import cdist
        pairwise = cdist(coords1, coords2)   # shape (len(coords1), len(coords2))
        distances1 = pairwise.min(axis=1)    # coords1 -> nearest point in coords2
        distances2 = pairwise.min(axis=0)    # coords2 -> nearest point in coords1
        
        hd1_robust = np.percentile(distances1, percentile)
        hd2_robust = np.percentile(distances2, percentile)
        
        return max(hd1_robust, hd2_robust)
    except Exception:
        # Fall back to infinity on failure; KeyboardInterrupt and SystemExit still propagate
        return float('inf')

def average_surface_distance(mask1, mask2):
    """
    Calculate Average Surface Distance (ASD)
    
    Args:
        mask1: First binary mask (numpy array)
        mask2: Second binary mask (numpy array)
    
    Returns:
        Average surface distance
    """
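    # Convert to binary; squeeze singleton axes (e.g. a trailing channel
    # dimension), which would otherwise make binary_erosion erase the whole mask
    mask1_binary = (np.squeeze(mask1) > 0.5).astype(np.uint8)
    mask2_binary = (np.squeeze(mask2) > 0.5).astype(np.uint8)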
    
    # Find boundary points
    def get_boundary_points(mask):
        eroded = ndimage.binary_erosion(mask)
        boundary = mask.astype(np.float32) - eroded.astype(np.float32)
        coords = np.argwhere(boundary > 0)
        return coords
    
    coords1 = get_boundary_points(mask1_binary)
    coords2 = get_boundary_points(mask2_binary)
    
    if len(coords1) == 0 or len(coords2) == 0:
        return float('inf')
    
    try:
        from scipy.spatial.distance import cdist
        
        # Calculate minimum distances
        distances1 = np.min(cdist(coords1, coords2), axis=1)
        distances2 = np.min(cdist(coords2, coords1), axis=1)
        
        # Average surface distance
        asd = (np.mean(distances1) + np.mean(distances2)) / 2.0
        return asd
    except Exception:
        return float('inf')

def surface_dice_coefficient(mask1, mask2, tolerance=1.0):
    """
    Surface Dice coefficient - measures boundary agreement within tolerance
    
    Args:
        mask1: First binary mask (numpy array)
        mask2: Second binary mask (numpy array)
        tolerance: Distance tolerance for surface matching
    
    Returns:
        Surface Dice coefficient
    """
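    # Convert to binary; squeeze singleton axes (e.g. a trailing channel
    # dimension), which would otherwise make binary_erosion erase the whole mask
    mask1_binary = (np.squeeze(mask1) > 0.5).astype(np.uint8)
    mask2_binary = (np.squeeze(mask2) > 0.5).astype(np.uint8)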
    
    # Find boundary points
    def get_boundary_points(mask):
        eroded = ndimage.binary_erosion(mask)
        boundary = mask.astype(np.float32) - eroded.astype(np.float32)
        coords = np.argwhere(boundary > 0)
        return coords
    
    coords1 = get_boundary_points(mask1_binary)
    coords2 = get_boundary_points(mask2_binary)
    
    if len(coords1) == 0 and len(coords2) == 0:
        return 1.0  # Both empty
    if len(coords1) == 0 or len(coords2) == 0:
        return 0.0  # One empty
    
    try:
        from scipy.spatial.distance import cdist
        
        # Find points within tolerance
        distances1 = np.min(cdist(coords1, coords2), axis=1)
        distances2 = np.min(cdist(coords2, coords1), axis=1)
        
        matched1 = np.sum(distances1 <= tolerance)
        matched2 = np.sum(distances2 <= tolerance)
        
        # Surface Dice coefficient
        surface_dice = (matched1 + matched2) / (len(coords1) + len(coords2))
        return surface_dice
    except Exception:
        return 0.0

class DistanceMetrics:
    """
    Comprehensive distance-based metrics calculator
    """
    
    def __init__(self, hausdorff_percentile=95, surface_tolerance=1.0):
        self.hausdorff_percentile = hausdorff_percentile
        self.surface_tolerance = surface_tolerance
    
    def compute_distance_metrics(self, y_true_np, y_pred_np):
        """
        Compute all distance-based metrics
        
        Args:
            y_true_np: Ground truth mask (numpy array)
            y_pred_np: Predicted mask (numpy array)
        
        Returns:
            Dictionary with distance metrics
        """
        metrics = {}
        
        try:
            # Hausdorff Distance
            metrics['hausdorff_distance'] = hausdorff_distance_numpy(
                y_true_np, y_pred_np, self.hausdorff_percentile)
            
            # Average Surface Distance
            metrics['average_surface_distance'] = average_surface_distance(
                y_true_np, y_pred_np)
            
            # Surface Dice Coefficient
            metrics['surface_dice'] = surface_dice_coefficient(
                y_true_np, y_pred_np, self.surface_tolerance)
            
        except Exception as e:
            print(f"Warning: Error computing distance metrics: {e}")
            metrics['hausdorff_distance'] = float('inf')
            metrics['average_surface_distance'] = float('inf')
            metrics['surface_dice'] = 0.0
        
        return metrics
    
    def batch_compute_distance_metrics(self, y_true_batch, y_pred_batch):
        """
        Compute distance metrics for a batch of predictions
        
        Args:
            y_true_batch: Batch of ground truth masks
            y_pred_batch: Batch of predicted masks
        
        Returns:
            Dictionary with averaged metrics
        """
        batch_metrics = {
            'hausdorff_distance': [],
            'average_surface_distance': [],
            'surface_dice': []
        }
        
        for i in range(len(y_true_batch)):
            metrics = self.compute_distance_metrics(y_true_batch[i], y_pred_batch[i])
            
            for key, value in metrics.items():
                if not np.isinf(value) and not np.isnan(value):
                    batch_metrics[key].append(value)
        
        # Calculate averages
        averaged_metrics = {}
        for key, values in batch_metrics.items():
            if values:
                averaged_metrics[f'{key}_mean'] = np.mean(values)
                averaged_metrics[f'{key}_std'] = np.std(values)
                averaged_metrics[f'{key}_median'] = np.median(values)
            else:
                averaged_metrics[f'{key}_mean'] = float('inf')
                averaged_metrics[f'{key}_std'] = 0.0
                averaged_metrics[f'{key}_median'] = float('inf')
        
        return averaged_metrics

# Initialize distance metrics calculator
distance_metrics = DistanceMetrics()

print("Distance-Based Metrics Implemented:")
print("- Hausdorff Distance (with percentile option)")
print("- Average Surface Distance (ASD)")
print("- Surface Dice Coefficient")
print("- DistanceMetrics class for batch processing")

Distance-Based Metrics Implemented:
- Hausdorff Distance (with percentile option)
- Average Surface Distance (ASD)
- Surface Dice Coefficient
- DistanceMetrics class for batch processing
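
To illustrate why `DistanceMetrics` defaults to `hausdorff_percentile=95`, here is a small self-contained sketch. It is not the notebook's own function: `boundary_points` and `percentile_hausdorff` are hypothetical names that follow the same erosion-plus-`cdist` idea. A single stray pixel in the prediction dominates the standard (maximum) Hausdorff distance but does not move the 95th-percentile version.

```python
import numpy as np
from scipy import ndimage
from scipy.spatial.distance import cdist

def boundary_points(mask):
    """Coordinates of foreground pixels removed by one erosion step."""
    mask = mask.astype(bool)
    boundary = mask & ~ndimage.binary_erosion(mask)
    return np.argwhere(boundary)

def percentile_hausdorff(mask_a, mask_b, percentile=95):
    """Symmetric percentile Hausdorff distance in pixels (illustrative only)."""
    pts_a, pts_b = boundary_points(mask_a), boundary_points(mask_b)
    pairwise = cdist(pts_a, pts_b)
    d_ab = pairwise.min(axis=1)  # each A boundary point -> nearest B point
    d_ba = pairwise.min(axis=0)  # each B boundary point -> nearest A point
    return max(np.percentile(d_ab, percentile), np.percentile(d_ba, percentile))

# Ground truth: a 20x20 square. Prediction: the same square plus one stray pixel.
truth = np.zeros((64, 64), dtype=np.uint8)
truth[20:40, 20:40] = 1
pred = truth.copy()
pred[5, 5] = 1

hd_max = percentile_hausdorff(truth, pred, percentile=100)
hd_95 = percentile_hausdorff(truth, pred, percentile=95)
print(f"HD (max): {hd_max:.2f} px, HD95: {hd_95:.2f} px")
```

The percentile version is only robust while outliers make up a small share of the boundary points; a large spurious region would still raise HD95.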


In [11]:
# Comprehensive Evaluation Framework

class SegmentationEvaluator:
    """
    Comprehensive evaluation framework for medical image segmentation
    """
    
    def __init__(self, threshold=0.5, smooth=1e-6, hausdorff_percentile=95):
        self.threshold = threshold
        self.smooth = smooth
        self.medical_metrics = MedicalMetrics(threshold, smooth)
        self.distance_metrics = DistanceMetrics(hausdorff_percentile)
        
    def evaluate_single_prediction(self, y_true, y_pred, include_distance=True):
        """
        Comprehensive evaluation of a single prediction
        
        Args:
            y_true: Ground truth mask
            y_pred: Predicted mask
            include_distance: Whether to compute distance-based metrics
        
        Returns:
            Dictionary with all computed metrics
        """
        results = {}
        
        # Convert to numpy if needed (also handles tensors on GPU or with gradients)
        def _to_numpy(x):
            if torch.is_tensor(x):
                return x.detach().cpu().numpy()
            return np.asarray(x)
        
        y_true_np = _to_numpy(y_true)
        y_pred_np = _to_numpy(y_pred)
        
        # Basic medical metrics
        basic_metrics = self.medical_metrics.compute_metrics_numpy(y_true_np, y_pred_np)
        results.update(basic_metrics)
        
        # Distance-based metrics (optional, as they can be computationally expensive)
        if include_distance:
            try:
                distance_metrics_result = self.distance_metrics.compute_distance_metrics(
                    y_true_np, y_pred_np)
                results.update(distance_metrics_result)
            except Exception as e:
                print(f"Warning: Could not compute distance metrics: {e}")
                results.update({
                    'hausdorff_distance': float('inf'),
                    'average_surface_distance': float('inf'),
                    'surface_dice': 0.0
                })
        
        return results
    
    def evaluate_batch(self, y_true_batch, y_pred_batch, include_distance=False):
        """
        Evaluate a batch of predictions
        
        Args:
            y_true_batch: Batch of ground truth masks
            y_pred_batch: Batch of predicted masks
            include_distance: Whether to compute distance-based metrics
        
        Returns:
            Dictionary with aggregated metrics
        """
        all_metrics = []
        
        for i in range(len(y_true_batch)):
            metrics = self.evaluate_single_prediction(
                y_true_batch[i], y_pred_batch[i], include_distance)
            all_metrics.append(metrics)
        
        # Aggregate results
        aggregated = self._aggregate_metrics(all_metrics)
        return aggregated
    
    def _aggregate_metrics(self, metrics_list):
        """
        Aggregate metrics from multiple predictions
        
        Args:
            metrics_list: List of metric dictionaries
        
        Returns:
            Dictionary with aggregated statistics
        """
        if not metrics_list:
            return {}
        
        # Get all metric names
        metric_names = metrics_list[0].keys()
        aggregated = {}
        
        for metric_name in metric_names:
            values = [m[metric_name] for m in metrics_list 
                     if not np.isinf(m[metric_name]) and not np.isnan(m[metric_name])]
            
            if values:
                aggregated[f'{metric_name}_mean'] = np.mean(values)
                aggregated[f'{metric_name}_std'] = np.std(values)
                aggregated[f'{metric_name}_median'] = np.median(values)
                aggregated[f'{metric_name}_min'] = np.min(values)
                aggregated[f'{metric_name}_max'] = np.max(values)
                aggregated[f'{metric_name}_count'] = len(values)
            else:
                aggregated[f'{metric_name}_mean'] = float('nan')
                aggregated[f'{metric_name}_std'] = float('nan')
                aggregated[f'{metric_name}_median'] = float('nan')
                aggregated[f'{metric_name}_min'] = float('nan')
                aggregated[f'{metric_name}_max'] = float('nan')
                aggregated[f'{metric_name}_count'] = 0
        
        return aggregated
    
    def create_metrics_report(self, metrics_dict, title="Segmentation Evaluation Report"):
        """
        Create a formatted report from metrics
        
        Args:
            metrics_dict: Dictionary with computed metrics
            title: Report title
        
        Returns:
            Formatted string report
        """
        report = f"\n{'='*60}\n{title}\n{'='*60}\n"
        
        # Group metrics by type
        basic_metrics = ['dice', 'iou', 'sensitivity', 'specificity', 'precision', 'f1', 'volume_similarity']
        distance_metrics = ['hausdorff_distance', 'average_surface_distance', 'surface_dice']
        
        # Basic metrics
        report += "\n📊 BASIC SEGMENTATION METRICS\n" + "-"*40 + "\n"
        for metric in basic_metrics:
            if f'{metric}_mean' in metrics_dict:
                mean_val = metrics_dict[f'{metric}_mean']
                std_val = metrics_dict[f'{metric}_std']
                report += f"{metric.replace('_', ' ').title():20}: {mean_val:.4f} ± {std_val:.4f}\n"
        
        # Distance metrics
        report += "\n📏 DISTANCE-BASED METRICS\n" + "-"*40 + "\n"
        for metric in distance_metrics:
            if f'{metric}_mean' in metrics_dict:
                mean_val = metrics_dict[f'{metric}_mean']
                std_val = metrics_dict[f'{metric}_std']
                if np.isfinite(mean_val):  # aggregated values may be inf or NaN
                    report += f"{metric.replace('_', ' ').title():20}: {mean_val:.4f} ± {std_val:.4f}\n"
                else:
                    report += f"{metric.replace('_', ' ').title():20}: Not computed\n"
        
        report += "\n" + "="*60 + "\n"
        return report

# Loss Function Testing and Comparison
def test_loss_functions():
    """
    Test different loss functions with synthetic data
    """
    print("🧪 Testing Loss Functions with Synthetic Data")
    print("-" * 50)
    
    # Create synthetic data
    batch_size, height, width = 4, 128, 128
    
    # Ground truth: circular mask
    y_true = np.zeros((batch_size, height, width, 1), dtype=np.float32)
    center = height // 2
    radius = 30
    
    for b in range(batch_size):
        for i in range(height):
            for j in range(width):
                if (i - center)**2 + (j - center)**2 <= radius**2:
                    y_true[b, i, j, 0] = 1.0
    
    # Prediction: slightly offset circular mask
    y_pred = np.zeros_like(y_true)
    offset_center = center + 5
    
    for b in range(batch_size):
        for i in range(height):
            for j in range(width):
                if (i - offset_center)**2 + (j - offset_center)**2 <= radius**2:
                    y_pred[b, i, j, 0] = 0.8  # Soft prediction
    
    # Convert to PyTorch tensors (this notebook imports torch, not TensorFlow)
    y_true_t = torch.from_numpy(y_true)
    y_pred_t = torch.from_numpy(y_pred)
    
    # Test different loss functions
    loss_functions = {
        'Dice Loss': dice_loss,
        'BCE Loss': binary_crossentropy_loss,
        'Focal Loss': lambda yt, yp: focal_loss(yt, yp, alpha=0.25, gamma=2.0),
        'Tversky Loss': lambda yt, yp: tversky_loss(yt, yp, alpha=0.3, beta=0.7),
        'Dice+BCE Loss': lambda yt, yp: dice_bce_loss(yt, yp, 0.5, 0.5),
        'Boundary Loss': boundary_loss,
    }
    
    print("Loss Function Comparison:")
    for name, loss_fn in loss_functions.items():
        try:
            # Gradients are not needed when only comparing loss values
            with torch.no_grad():
                loss_value = loss_fn(y_true_t, y_pred_t)
            print(f"{name:15}: {float(loss_value):.6f}")
        except Exception as e:
            print(f"{name:15}: Error - {str(e)[:50]}")
    
    return y_true, y_pred

# Initialize evaluator
evaluator = SegmentationEvaluator()

print("\n✅ Comprehensive Evaluation Framework Implemented:")
print("- SegmentationEvaluator class")
print("- Single prediction evaluation")
print("- Batch evaluation with aggregation")
print("- Metrics reporting")
print("- Loss function testing utilities")


✅ Comprehensive Evaluation Framework Implemented:
- SegmentationEvaluator class
- Single prediction evaluation
- Batch evaluation with aggregation
- Metrics reporting
- Loss function testing utilities
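
The synthetic circular masks in `test_loss_functions()` are filled with nested Python loops. As an optional alternative, the same discs can be built with NumPy broadcasting. This is a sketch only; `circle_mask` is a hypothetical helper name, not part of the notebook.

```python
import numpy as np

def circle_mask(height, width, center, radius, value=1.0):
    """Return a (height, width) float32 array with a filled disc set to `value`."""
    rows, cols = np.ogrid[:height, :width]
    inside = (rows - center[0]) ** 2 + (cols - center[1]) ** 2 <= radius ** 2
    return np.where(inside, value, 0.0).astype(np.float32)

# Same geometry as test_loss_functions(): 128x128 image, radius 30,
# ground truth centred at 64 and a soft prediction offset by 5 pixels.
y_true_2d = circle_mask(128, 128, center=(64, 64), radius=30)
y_pred_2d = circle_mask(128, 128, center=(69, 69), radius=30, value=0.8)

# Stack into the (batch, height, width, 1) layout used above.
y_true = np.repeat(y_true_2d[None, :, :, None], 4, axis=0)
y_pred = np.repeat(y_pred_2d[None, :, :, None], 4, axis=0)
print(y_true.shape, y_pred.shape)
```

The comparison `<= radius ** 2` matches the loop condition, so the resulting masks are pixel-for-pixel identical to the looped version.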


In [12]:
# Visualization and Demonstration

def visualize_loss_functions():
    """
    Visualize the behavior of different loss functions
    """
    print("📈 Visualizing Loss Function Behavior")
    print("-" * 40)
    
    # Create test data with varying prediction quality
    prediction_qualities = np.linspace(0.1, 0.9, 20)
    
    # Create synthetic ground truth (simple circle)
    size = 64
    y_true = np.zeros((size, size))
    center = size // 2
    radius = 15
    
    for i in range(size):
        for j in range(size):
            if (i - center)**2 + (j - center)**2 <= radius**2:
                y_true[i, j] = 1.0
    
    # Calculate losses for different prediction qualities
    loss_results = {
        'Dice Loss': [],
        'BCE Loss': [],
        'Focal Loss': [],
        'Tversky Loss': [],
        'Dice+BCE': []
    }
    
    for quality in prediction_qualities:
        # Create prediction with varying quality
        y_pred = y_true * quality + (1 - y_true) * (1 - quality) * 0.1
        
        # Convert to PyTorch tensors (this notebook imports torch, not TensorFlow)
        y_true_t = torch.tensor(y_true.reshape(1, size, size, 1), dtype=torch.float32)
        y_pred_t = torch.tensor(y_pred.reshape(1, size, size, 1), dtype=torch.float32)
        
        # Calculate losses
        loss_results['Dice Loss'].append(float(dice_loss(y_true_t, y_pred_t)))
        loss_results['BCE Loss'].append(float(binary_crossentropy_loss(y_true_t, y_pred_t)))
        loss_results['Focal Loss'].append(float(focal_loss(y_true_t, y_pred_t)))
        loss_results['Tversky Loss'].append(float(tversky_loss(y_true_t, y_pred_t)))
        loss_results['Dice+BCE'].append(float(dice_bce_loss(y_true_t, y_pred_t)))
    
    # Plot results
    plt.figure(figsize=(12, 8))
    
    for loss_name, loss_values in loss_results.items():
        plt.plot(prediction_qualities, loss_values, marker='o', label=loss_name, linewidth=2)
    
    plt.xlabel('Prediction Quality (foreground confidence)')
    plt.ylabel('Loss Value')
    plt.title('Comparison of Loss Functions vs Prediction Quality')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    return loss_results

def demonstrate_metrics_calculation():
    """
    Demonstrate metrics calculation with synthetic examples
    """
    print("🔍 Demonstrating Metrics Calculation")
    print("-" * 40)
    
    # Generate synthetic test data
    y_true_test, y_pred_test = test_loss_functions()
    
    # Evaluate with comprehensive metrics
    print("\nEvaluating synthetic predictions...")
    results = evaluator.evaluate_batch(y_true_test, y_pred_test, include_distance=False)
    
    # Generate report
    report = evaluator.create_metrics_report(results, "Synthetic Data Evaluation")
    print(report)
    
    return results

def plot_metrics_comparison():
    """
    Create a comprehensive metrics comparison plot
    """
    # Generate different prediction scenarios
    scenarios = {
        'Perfect': (1.0, 0.0),      # Perfect prediction, no offset
        'Good': (0.9, 2),           # Good prediction, small offset
        'Moderate': (0.7, 5),       # Moderate prediction, medium offset
        'Poor': (0.5, 10)           # Poor prediction, large offset
    }
    
    all_results = {}
    
    # Create test data for each scenario
    for scenario_name, (quality, offset) in scenarios.items():
        # Create synthetic data
        size = 64
        batch_size = 8
        
        y_true = np.zeros((batch_size, size, size))
        y_pred = np.zeros((batch_size, size, size))
        
        center = size // 2
        radius = 15
        
        for b in range(batch_size):
            # Ground truth
            for i in range(size):
                for j in range(size):
                    if (i - center)**2 + (j - center)**2 <= radius**2:
                        y_true[b, i, j] = 1.0
            
            # Prediction with offset and quality variation
            pred_center = center + offset
            for i in range(size):
                for j in range(size):
                    if (i - pred_center)**2 + (j - pred_center)**2 <= radius**2:
                        y_pred[b, i, j] = quality
        
        # Evaluate
        results = evaluator.evaluate_batch(y_true, y_pred, include_distance=False)
        all_results[scenario_name] = results
    
    # Plot comparison
    metrics_to_plot = ['dice_mean', 'iou_mean', 'sensitivity_mean', 'specificity_mean', 'f1_mean']
    
    fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(20, 4))
    fig.suptitle('Metrics Comparison Across Different Prediction Scenarios', fontsize=16)
    
    for idx, metric in enumerate(metrics_to_plot):
        values = [all_results[scenario][metric] for scenario in scenarios.keys()]
        
        bars = axes[idx].bar(scenarios.keys(), values, alpha=0.7, 
                           color=['green', 'blue', 'orange', 'red'])
        axes[idx].set_title(metric.replace('_mean', '').replace('_', ' ').title())
        axes[idx].set_ylabel('Score')
        axes[idx].set_ylim(0, 1)
        
        # Add value labels on bars
        for bar, value in zip(bars, values):
            axes[idx].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                          f'{value:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    return all_results

print("📊 Visualization and Demonstration Functions:")
print("- visualize_loss_functions(): Compare loss function behaviors")
print("- demonstrate_metrics_calculation(): Show comprehensive evaluation")
print("- plot_metrics_comparison(): Compare metrics across scenarios")
print("\nRun these functions to see the visualizations!")

📊 Visualization and Demonstration Functions:
- visualize_loss_functions(): Compare loss function behaviors
- demonstrate_metrics_calculation(): Show comprehensive evaluation
- plot_metrics_comparison(): Compare metrics across scenarios

Run these functions to see the visualizations!


## 📋 Summary and Next Steps

### ✅ What We've Implemented

This notebook provides a comprehensive suite of loss functions and metrics specifically designed for medical image segmentation:

#### **Loss Functions**
1. **Individual Loss Functions**:
   - Dice Loss (standard and generalized)
   - Binary Cross-Entropy Loss
   - Focal Loss (for handling class imbalance)
   - Tversky Loss (generalized Dice)

2. **Hybrid Loss Functions**:
   - Dice + BCE Loss (balanced approach)
   - Focal + Dice Loss (imbalance-aware)
   - Tversky + Focal Loss (advanced combination)
   - Adaptive Loss (dynamic weighting)

3. **Boundary-Aware Losses**:
   - Boundary Loss (edge detection focus)
   - Surface Loss (boundary alignment)
   - Hausdorff Loss (distance-based)
   - Combined Boundary-Aware Loss
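
For reference, a hybrid Dice + BCE loss is commonly written as a weighted sum of the two terms. The form below is the usual textbook formulation and an assumption about how this notebook implements it: $p_i$ are predicted probabilities, $g_i$ ground-truth labels over $N$ pixels, $\epsilon$ a small smoothing constant, and $w_{\text{dice}}$, $w_{\text{bce}}$ the configurable weights.

$$
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon},
\qquad
\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_i \left[ g_i \log p_i + (1 - g_i)\log(1 - p_i) \right]
$$

$$
\mathcal{L}_{\text{hybrid}} = w_{\text{dice}}\,\mathcal{L}_{\text{Dice}} + w_{\text{bce}}\,\mathcal{L}_{\text{BCE}}
$$

The Dice term responds to region overlap and is less sensitive to the foreground/background ratio, while the BCE term supplies a per-pixel gradient signal.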

#### **Evaluation Metrics**
1. **Basic Segmentation Metrics**:
   - Dice Coefficient
   - IoU (Intersection over Union)
   - Sensitivity (Recall)
   - Specificity
   - Precision
   - F1 Score
   - Volume Similarity

2. **Distance-Based Metrics**:
   - Hausdorff Distance (with percentile option)
   - Average Surface Distance
   - Surface Dice Coefficient

3. **Evaluation Framework**:
   - Comprehensive evaluator class
   - Batch processing capabilities
   - Statistical aggregation
   - Formatted reporting
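
As a quick worked example of the basic metrics listed above, the sketch below computes them from confusion-matrix counts on a tiny pair of masks. It is a standalone illustration with a hypothetical `basic_metrics` helper, not the notebook's `MedicalMetrics` class, and it omits the smoothing term, so it assumes neither mask is empty.

```python
import numpy as np

def basic_metrics(y_true, y_pred, threshold=0.5):
    """Overlap metrics from confusion-matrix counts (no smoothing; illustrative)."""
    t = np.asarray(y_true) > threshold
    p = np.asarray(y_pred) > threshold
    tp = np.sum(t & p)    # predicted foreground, truly foreground
    fp = np.sum(~t & p)   # predicted foreground, truly background
    fn = np.sum(t & ~p)   # missed foreground
    tn = np.sum(~t & ~p)  # correctly predicted background
    return {
        "dice": 2 * tp / (2 * tp + fp + fn),
        "iou": tp / (tp + fp + fn),
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
        "precision": tp / (tp + fp),
    }

# 8 pixels: 3 true positives, 1 false negative, 1 false positive, 3 true negatives
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
y_pred = np.array([0.9, 0.8, 0.7, 0.2, 0.6, 0.1, 0.3, 0.4])

metrics = basic_metrics(y_true, y_pred)
for name, value in metrics.items():
    print(f"{name:12}: {value:.3f}")
# dice 0.750, iou 0.600, sensitivity 0.750, specificity 0.750, precision 0.750
```

Dice and IoU are monotonically related (Dice = 2·IoU / (1 + IoU)), so they rank predictions identically; reporting both mainly helps comparison with other published results.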

### 🔧 Key Features

- **Medical-Specific Design**: Losses and metrics chosen for common medical segmentation challenges such as class imbalance and boundary accuracy
- **Numerical Stability**: Smoothing terms and explicit handling of empty masks and non-finite values
- **Modular Architecture**: Easy to integrate with training pipelines
- **Comprehensive Evaluation**: Both basic overlap metrics and distance-based metrics
- **Visualization Tools**: Functions to compare and visualize performance
- **Training Ready**: Loss functions are intended for use during training; distance-based metrics are better reserved for evaluation

### 🎯 Usage in Training Pipeline

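```python
# Example usage in a PyTorch training step
# (model, optimizer, images, and masks are assumed to come from your pipeline;
#  the model is assumed to output probabilities in [0, 1])
optimizer.zero_grad()
y_pred = model(images)
loss = dice_bce_loss(masks, y_pred)  # or any other implemented loss
loss.backward()
optimizer.step()
# Metrics such as dice_coefficient, iou_score, and sensitivity_score
# can be computed on y_pred here for logging

# Example usage in evaluation
evaluator = SegmentationEvaluator()
results = evaluator.evaluate_batch(y_true, y_pred, include_distance=True)
report = evaluator.create_metrics_report(results)
print(report)
```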

### 🚀 Next Steps

1. **Integration with Training Pipeline** (Notebook 05):
   - Implement training loop with these loss functions
   - Add callbacks for early stopping based on metrics
   - Learning rate scheduling strategies

2. **Model Evaluation** (Notebook 06):
   - Comprehensive model testing
   - Validation set evaluation
   - Cross-validation strategies

3. **Post-processing** (Notebook 07):
   - Morphological operations
   - Connected component analysis
   - Anatomical validation

### 💡 Tips for Usage

1. **Loss Function Selection**:
   - Use Dice + BCE for balanced training
   - Use Focal Loss for highly imbalanced datasets
   - Use boundary-aware losses when edge precision is critical

2. **Metric Interpretation**:
   - Dice > 0.8 is often treated as good for medical segmentation, but acceptable values depend on the size of the structure and the imaging modality
   - Hausdorff distance should be interpreted in context of image resolution
   - Always report multiple metrics for comprehensive evaluation

3. **Computational Considerations**:
   - Distance-based metrics are computationally expensive
   - Use them for final evaluation, not during training
   - Consider batch processing for efficiency
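
To illustrate the last two tips, here is a sketch that measures the Hausdorff distance in physical units and uses Euclidean distance transforms instead of an all-pairs distance matrix. The names `surface` and `hausdorff_mm` and the pixel spacing values are assumptions for illustration; in practice the spacing would come from the image metadata. The sketch also assumes both masks are non-empty.

```python
import numpy as np
from scipy import ndimage

def surface(mask):
    """Boolean map of boundary pixels (foreground removed by one erosion)."""
    mask = mask.astype(bool)
    return mask & ~ndimage.binary_erosion(mask)

def hausdorff_mm(mask_a, mask_b, spacing=(1.0, 1.0)):
    """Symmetric Hausdorff distance scaled by the pixel spacing (e.g. in mm)."""
    surf_a, surf_b = surface(mask_a), surface(mask_b)
    # distance_transform_edt gives, for every nonzero pixel, the distance to the
    # nearest zero pixel; inverting the surface map yields distance-to-surface.
    dist_to_a = ndimage.distance_transform_edt(~surf_a, sampling=spacing)
    dist_to_b = ndimage.distance_transform_edt(~surf_b, sampling=spacing)
    return max(dist_to_b[surf_a].max(), dist_to_a[surf_b].max())

# Two 10x10 squares, the second shifted 3 pixels along the column axis
a = np.zeros((32, 32), dtype=np.uint8)
b = np.zeros((32, 32), dtype=np.uint8)
a[5:15, 5:15] = 1
b[5:15, 8:18] = 1

hd_px = hausdorff_mm(a, b)                      # unit spacing: result in pixels
hd_mm = hausdorff_mm(a, b, spacing=(1.0, 0.5))  # assumed 1.0 x 0.5 mm pixels
print(f"{hd_px:.1f} px, {hd_mm:.1f} mm")
```

The same 3-pixel shift corresponds to 1.5 mm with the assumed 0.5 mm column spacing, which is why distances reported in pixels are hard to compare across scans with different resolutions.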

---

**Ready to proceed to the next notebook: `05_Training_Pipeline.ipynb`** 🎓