# Clock Drawing Segmentation with Deep Learning in the MoCA Test

This notebook implements a deep learning segmentation system for clock drawings from the Montreal Cognitive Assessment (MoCA) test. The system uses a U-Net++ architecture with an SE-ResNet50 backbone to segment four components of each drawing (entire clock, numbers, hands, and contour) to support automated analysis for cognitive impairment detection.

### Authors and Contact Information
- **Diego Aldahir Tovar Ledesma** - diego.tovar@udem.edu
- **Jorge Rodrigo Gómez Mayo** - jorger.gomez@udem.edu

**Organization:** Universidad de Monterrey  
**First created:** April 2025

### Project Overview
This segmentation model is designed to identify four key components in clock drawings:
- The entire clock face
- Numbers on the clock face
- Clock hands (hour and minute)
- Clock contour/outline

By accurately segmenting these elements, the system provides objective measurements that can help medical professionals detect early signs of cognitive impairment, particularly in conditions like Parkinson's disease and dementia.

### Technical Implementation
- **Architecture:** U-Net++ with SE-ResNet50 encoder (pre-trained on ImageNet)
- **Loss Function:** Combined Dice Loss and Focal Loss with class weighting
- **Data Processing:** Enhanced thin line detection for clock hands and contours
- **Performance Metrics:** IoU (Intersection over Union) for each component
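
To make the loss and metric listed above concrete, here is a minimal NumPy sketch of a class-weighted Dice plus Focal loss and a per-class IoU for masks shaped `(C, H, W)`. It is an illustration under assumptions, not the notebook's training code: the function names, the equal `dice_weight`/`focal_weight` split, `gamma=2.0`, and the example class weights are all hypothetical choices.

```python
import numpy as np

def per_class_iou(pred, target, eps=1e-7):
    """IoU per class for binary masks shaped (C, H, W)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum(axis=(1, 2))
    union = np.logical_or(pred, target).sum(axis=(1, 2))
    return (intersection + eps) / (union + eps)

def dice_loss(probs, target, eps=1e-7):
    """Soft Dice loss per class: 1 - 2|P*T| / (|P| + |T|)."""
    inter = (probs * target).sum(axis=(1, 2))
    denom = probs.sum(axis=(1, 2)) + target.sum(axis=(1, 2))
    return 1.0 - (2.0 * inter + eps) / (denom + eps)

def focal_loss(probs, target, gamma=2.0, eps=1e-7):
    """Binary focal loss per class; gamma down-weights easy pixels."""
    probs = np.clip(probs, eps, 1.0 - eps)
    p_t = np.where(target == 1, probs, 1.0 - probs)
    return (-((1.0 - p_t) ** gamma) * np.log(p_t)).mean(axis=(1, 2))

def combined_loss(probs, target, class_weights, dice_weight=0.5, focal_weight=0.5):
    """Weighted average over classes of (Dice + Focal). Weights are assumptions."""
    per_class = dice_weight * dice_loss(probs, target) + focal_weight * focal_loss(probs, target)
    w = np.asarray(class_weights, dtype=np.float64)
    return float((w * per_class).sum() / w.sum())

# Toy example with two classes on a 4x4 grid
target = np.zeros((2, 4, 4), dtype=np.float32)
target[0, 1:3, :] = 1.0
pred = np.zeros((2, 4, 4), dtype=np.float32)
pred[0, :2, :] = 1.0  # overlaps the target on one of its two rows

print("IoU per class:", per_class_iou(pred, target))
print("Loss (perfect prediction):", combined_loss(target, target, [1.0, 2.0]))
print("Loss (shifted prediction):", combined_loss(pred, target, [1.0, 2.0]))
```

Giving thin classes such as hands and contour a larger entry in `class_weights` is one way to keep them from being dominated by the much larger clock-face region.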

### Usage Instructions
The notebook is structured in sequential sections covering data loading, preprocessing, model definition, training, and evaluation. To run the entire pipeline, execute each cell in order.

## Initial Setup and Imports

This code imports necessary libraries for the project, including data manipulation, image processing, deep learning frameworks, and visualization tools. It also sets up a reproducibility function to ensure consistent results across runs.

In [None]:
# Standard libraries
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import cv2
from tqdm import tqdm
import random

# PyTorch and deep learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2

def seed_everything(seed=42):
    """
    Sets random seeds for reproducibility across all random number generators.
    
    Args:
        seed (int, optional): The seed value to use. Defaults to 42.
    
    Returns:
        None
    
    Note:
        This function sets seeds for Python's random module, NumPy, and PyTorch CPU and
        CUDA operations, and configures the cuDNN backend for deterministic behavior.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Initialize seed for reproducibility
seed_everything()
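
A quick way to see what seeding buys is to reseed and confirm that the same draws come back. The sketch below uses only `random` and NumPy so it runs without a GPU; `seed_basic` and `draw` are hypothetical helpers for illustration, not part of the notebook.

```python
import random
import numpy as np

def seed_basic(seed=42):
    """Seed the two generators used in this sketch (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)

def draw():
    """Pull one value from Python's RNG and three from NumPy's global RNG."""
    return random.random(), np.random.rand(3).tolist()

seed_basic(42)
first = draw()
seed_basic(42)
second = draw()
unseeded = draw()  # continues the stream, so it differs from the first draw

print("Reseeded draws match:", first == second)
print("Next draw differs:", unseeded != first)
```

The same idea applies to `torch.manual_seed`, although some GPU operations can remain non-deterministic even with a fixed seed.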

## Configuration Setup and Device Checking

This code sets up versioned checkpoint directories (`checkpoints/v1`, `checkpoints/v2`, ...) and defines the configuration parameters for training, including device selection, training hyperparameters, model architecture, and dataset paths.

In [None]:
def get_next_checkpoint_version(base_dir='checkpoints'):
    """
    Creates a new versioned directory for saving model checkpoints.
    
    Args:
        base_dir (str, optional): Base directory for all checkpoint versions. 
            Defaults to 'checkpoints'.
    
    Returns:
        str: Path to the newly created checkpoint directory with incremented version.
    
    Note:
        This function automatically increments version numbers (v1, v2, etc.)
        based on existing directories.
    """
    os.makedirs(base_dir, exist_ok=True)  # Create base folder if it doesn't exist

    version_pattern = re.compile(r'^v(\d+)$')
    existing_versions = []

    for name in os.listdir(base_dir):
        match = version_pattern.match(name)
        if match:
            existing_versions.append(int(match.group(1)))

    next_version = max(existing_versions, default=0) + 1
    new_path = os.path.join(base_dir, f'v{next_version}')
    os.makedirs(new_path, exist_ok=True)  # Create the new folder

    return new_path

# Create checkpoint path dynamically before defining the class
DYNAMIC_CHECKPOINT_PATH = get_next_checkpoint_version()

class Config:
    """
    Configuration class containing all parameters for model training and evaluation.
    
    Attributes:
        DEVICE (torch.device): Computing device (MPS on Apple Silicon, otherwise CPU).
        EPOCHS (int): Number of training epochs.
        BATCH_SIZE (int): Batch size for training.
        LEARNING_RATE (float): Learning rate for optimizer.
        IMG_SIZE (int): Target image size for resizing.
        ENCODER (str): Backbone encoder architecture.
        ENCODER_WEIGHTS (str): Pre-trained weights for encoder.
        CLASSES (list): List of segmentation class names.
        NUM_CLASSES (int): Number of segmentation classes.
        DATASET_PATH (str): Path to training dataset.
        TEST_PATH (str): Path to test dataset.
        CHECKPOINT_PATH (str): Path to save model checkpoints.
        PATIENCE (int): Number of epochs with no improvement for early stopping.
        VAL_FREQUENCY (int): Validation frequency in epochs.
        SAVE_FREQUENCY (int): Model saving frequency in epochs.
        VISUALIZATION_FREQUENCY (int): Visualization frequency in epochs.
    """
    # Device for training
    DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # For MacBook with Apple Silicon

    # Training parameters
    EPOCHS = 100
    BATCH_SIZE = 4
    LEARNING_RATE = 1e-4
    IMG_SIZE = 256  # Size for resizing images

    # Model
    ENCODER = 'se_resnet50'  # Backbone encoder architecture
    ENCODER_WEIGHTS = 'imagenet'  # Pre-trained weights
    CLASSES = ['entire', 'numbers', 'hands', 'contour']  # Classes to segment
    NUM_CLASSES = len(CLASSES)

    # Paths
    DATASET_PATH = "/Users/diegotovar/Pictures/MoCA/Images Mask Files Exit/train_aumented_v2"
    TEST_PATH = "/Users/diegotovar/Pictures/MoCA/Images Mask Files Exit/test"
    CHECKPOINT_PATH = DYNAMIC_CHECKPOINT_PATH

    # Additional parameters
    PATIENCE = 10  # For early stopping
    VAL_FREQUENCY = 1  # How often to validate (in epochs)
    SAVE_FREQUENCY = 5  # How often to save models (in epochs)
    VISUALIZATION_FREQUENCY = 5  # How often to visualize results (in epochs)

# Verify available device
print(f"Training device: {Config.DEVICE}")
print(f"MPS available: {torch.backends.mps.is_available()}" if hasattr(torch.backends, 'mps') else "MPS not available")
print(f"Checkpoint path created: {Config.CHECKPOINT_PATH}")
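
The versioning scheme can be exercised in isolation. The sketch below re-implements the same idea in a temporary directory so nothing is written to the project folder; `next_version_dir` is a hypothetical stand-in for `get_next_checkpoint_version`.

```python
import os
import re
import tempfile

VERSION_RE = re.compile(r"^v(\d+)$")

def next_version_dir(base_dir):
    """Create and return base_dir/v<N+1>, where N is the highest existing version."""
    os.makedirs(base_dir, exist_ok=True)
    versions = [int(m.group(1)) for m in map(VERSION_RE.match, os.listdir(base_dir)) if m]
    path = os.path.join(base_dir, f"v{max(versions, default=0) + 1}")
    os.makedirs(path)
    return path

with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "v10"))     # highest existing version
    os.makedirs(os.path.join(tmp, "v2_old"))  # does not match ^v(\d+)$, so it is ignored
    created = [os.path.basename(next_version_dir(tmp)) for _ in range(2)]

print(created)  # ['v11', 'v12']
```

Versions are compared as integers, so `v10` sorts after `v2`. Because the notebook calls `get_next_checkpoint_version()` at cell level, every execution of the configuration cell creates a new directory, even if training is never started.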

## Data Verification Function

This code spot-checks the training dataset: it confirms that the dataset directory exists and contains folders, then inspects up to five randomly chosen folders for the expected background image and mask files. It also saves sample mask visualizations to help identify potential issues in the segmentation data.

In [None]:
def verify_data():
    """
    Verifies that the dataset exists and has the correct format.
    
    This function checks:
    - If the dataset directory exists
    - If there are folders within the dataset directory
    - For up to five randomly sampled folders: if a background image is present,
      if a mask file exists for each class, and whether each mask has an
      alpha channel and non-empty content
    
    Returns:
        bool: True if the directory exists and contains at least one folder,
            False otherwise. Problems inside sampled folders are only printed
            as warnings and do not change the return value.
    
    Note:
        Sample masks will be saved to the checkpoint directory for visual inspection.
        For the first sample folder, it will save visualizations of all class masks.
    """
    print(f"Verifying data in {Config.DATASET_PATH}")
    
    # Check if directory exists
    if not os.path.exists(Config.DATASET_PATH):
        print(f"ERROR: Directory {Config.DATASET_PATH} does not exist.")
        return False
    
    # Get folders
    folders = [f for f in os.listdir(Config.DATASET_PATH) 
               if os.path.isdir(os.path.join(Config.DATASET_PATH, f))]
    
    if len(folders) == 0:
        print(f"ERROR: No folders found in {Config.DATASET_PATH}")
        return False
    
    print(f"Found {len(folders)} folders.")
    
    # Verify structure in some random folders
    sample_folders = random.sample(folders, min(5, len(folders)))
    
    for folder in sample_folders:
        folder_path = os.path.join(Config.DATASET_PATH, folder)
        print(f"\nVerifying folder: {folder}")
        
        # Look for background image
        background_files = glob(os.path.join(folder_path, f"{folder}_[Bb]ackground.*"))
        if not background_files:
            print(f"  WARNING: No background image found in {folder}")
            continue
        
        background_path = background_files[0]
        print(f"  Background image: {os.path.basename(background_path)}")
        
        # Verify masks
        for class_name in Config.CLASSES:
            mask_path = os.path.join(folder_path, f"{folder}_{class_name}.*")
            mask_files = glob(mask_path)
            
            if not mask_files:
                print(f"  WARNING: No mask found for {class_name}")
                continue
            
            mask_file = mask_files[0]
            # Read mask with alpha channel (transparency)
            mask = cv2.imread(mask_file, cv2.IMREAD_UNCHANGED)
            
            if mask is None:
                print(f"  ERROR: Could not read mask {os.path.basename(mask_file)}")
                continue
            
            # Check if mask has alpha channel (grayscale masks are 2-D, so test ndim first)
            has_alpha = mask.ndim == 3 and mask.shape[-1] == 4
            alpha_channel = mask[:,:,3] if has_alpha else None
            
            # Verify mask has data
            mask_sum = np.sum(alpha_channel) if has_alpha else np.sum(mask)
            mask_max = np.max(alpha_channel) if has_alpha else np.max(mask)
            
            print(f"  Mask {class_name}: sum={mask_sum}, max={mask_max}, "
                  f"shape={mask.shape}, has_alpha={has_alpha}")
            
            # Visualize the first mask of each type for the first folder
            if folder == sample_folders[0]:
                plt.figure(figsize=(5, 5))
                # If it has alpha channel, show only that channel
                if has_alpha:
                    plt.imshow(alpha_channel, cmap='gray')
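                else:
                    # Grayscale masks are 2-D; only colour masks need BGR -> RGB
                    plt.imshow(mask if mask.ndim == 2 else cv2.cvtColor(mask, cv2.COLOR_BGR2RGB), cmap='gray')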
                plt.title(f'Mask {class_name}')
                plt.colorbar()
                plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'sample_mask_{class_name}.png'))
                plt.close()
    
    return True

# Run verification
verify_data()
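
The alpha-channel test deserves a closer look, because an array's last dimension equals 4 both for a four-channel image and for a grayscale image that happens to be four pixels wide. The NumPy sketch below uses synthetic arrays in place of files read by `cv2.imread(..., cv2.IMREAD_UNCHANGED)`; `mask_layer` is a hypothetical helper, and the channel mean is only a crude stand-in for a proper grayscale conversion.

```python
import numpy as np

def mask_layer(mask):
    """Return the 2-D layer that carries the annotation."""
    if mask.ndim == 3 and mask.shape[-1] == 4:
        return mask[:, :, 3]       # alpha channel of a 4-channel image
    if mask.ndim == 3:
        return mask.mean(axis=-1)  # crude grayscale stand-in for colour masks
    return mask                    # already single-channel

# 4-channel mask with a 2x2 opaque patch
rgba = np.zeros((4, 6, 4), dtype=np.uint8)
rgba[1:3, 2:4, 3] = 255

# Grayscale mask that is 4 pixels wide: shape[-1] == 4, but there is no alpha channel
gray = np.zeros((6, 4), dtype=np.uint8)

for name, arr in [("rgba", rgba), ("gray", gray)]:
    layer = mask_layer(arr)
    print(f"{name}: input shape={arr.shape}, layer shape={layer.shape}, sum={int(layer.sum())}")
```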

## Path Collection for Training Data

This code retrieves and validates all image and mask paths from the dataset directory, keeping track of images that have incomplete mask sets while ensuring all files can be properly loaded.

In [None]:
def get_img_paths():
    """
    Gets paths for all images and their corresponding masks, even if they don't have all masks.
    
    This function:
    - Scans all folders in the dataset directory
    - Finds background images and corresponding masks
    - Verifies that the images and masks can be read
    - Tracks images with incomplete mask sets
    
    Returns:
        list: List of dictionaries containing image paths and available mask paths
    """
    img_folders = sorted([f for f in os.listdir(Config.DATASET_PATH) 
                          if os.path.isdir(os.path.join(Config.DATASET_PATH, f))])
    
    img_data = []
    incomplete_mask_count = 0
    
    for folder in tqdm(img_folders, desc="Loading data"):
        folder_path = os.path.join(Config.DATASET_PATH, folder)
        
        # Look for background image (may have different extensions and case)
        background_files = glob(os.path.join(folder_path, f"{folder}_[Bb]ackground.*"))
        
        if background_files:
            background_path = background_files[0]
            
            # Verify image can be read
            img = cv2.imread(background_path)
            if img is None:
                continue
            
            # Check available masks for each class
            masks = {}
            has_at_least_one_mask = False
            
            for class_name in Config.CLASSES:
                mask_files = glob(os.path.join(folder_path, f"{folder}_{class_name}.*"))
                
                if mask_files:
                    mask_path = mask_files[0]
                    # Verify mask can be read
                    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
                    if mask is not None:
                        masks[class_name] = mask_path
                        has_at_least_one_mask = True
            
            # Discard if it doesn't have at least one mask
            if not has_at_least_one_mask:
                continue
                
            # Increment counter if it doesn't have all masks
            if len(masks) < len(Config.CLASSES):
                incomplete_mask_count += 1
            
            img_data.append({
                'img_path': background_path,
                'masks': masks
            })
    
    print(f"Total images found: {len(img_data)}")
    print(f"Images with incomplete masks: {incomplete_mask_count}")
    
    return img_data

# Get image paths
train_img_data = get_img_paths()
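
Each entry returned by `get_img_paths` is a dictionary with an `img_path` and a `masks` mapping from class name to file path. The sketch below builds two such entries by hand to show how incomplete mask sets can be reported per image; the file names and the `missing_classes` helper are hypothetical.

```python
CLASSES = ["entire", "numbers", "hands", "contour"]

# Hand-written stand-ins for entries produced by get_img_paths()
samples = [
    {
        "img_path": "clock_001/clock_001_Background.png",
        "masks": {c: f"clock_001/clock_001_{c}.png" for c in CLASSES},
    },
    {
        "img_path": "clock_002/clock_002_background.jpg",
        "masks": {
            "entire": "clock_002/clock_002_entire.png",
            "contour": "clock_002/clock_002_contour.png",
        },
    },
]

def missing_classes(sample, classes=CLASSES):
    """List the classes for which a sample has no mask file."""
    return [c for c in classes if c not in sample["masks"]]

report = {s["img_path"]: missing_classes(s) for s in samples}
incomplete = sum(1 for missing in report.values() if missing)

for path, missing in report.items():
    print(path, "->", missing if missing else "complete")
print("Images with incomplete masks:", incomplete)
```

In the dataset class, a missing class is filled with an all-zero mask, so the model is trained as if that component were absent from the drawing.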

## Dataset Creation and Data Processing Pipeline

This code defines a custom dataset class for clock segmentation with special handling for thin lines and transparency. It also implements data augmentation techniques and creates the training and validation data loaders.


In [None]:
class ClockSegmentationDataset(Dataset):
    """
    Dataset for clock segmentation with enhancements for thin lines and transparency support.
    
    This dataset handles multiple mask classes and includes special processing for 
    thin elements like contours and clock hands.
    
    Args:
        img_data (list): List of dictionaries containing image and mask paths.
        transform (albumentations.Compose, optional): Transformations to apply. Defaults to None.
        enhance_masks (bool, optional): Whether to enhance thin lines. Defaults to True.
        dilation_kernel_size (int, optional): Size of dilation kernel. Defaults to 3.
    """
    def __init__(self, img_data, transform=None, enhance_masks=True, dilation_kernel_size=3):
        self.img_data = img_data
        self.transform = transform
        self.enhance_masks = enhance_masks
        self.dilation_kernel_size = dilation_kernel_size
        
    def __len__(self):
        return len(self.img_data)
    
    def enhance_thin_lines(self, mask):
        """
        Enhances thin lines in the mask through dilation and refinement.
        
        Args:
            mask (numpy.ndarray): Input binary mask.
            
        Returns:
            numpy.ndarray: Enhanced binary mask with thicker lines.
        """
        # Create kernel for dilation
        kernel = np.ones((self.dilation_kernel_size, self.dilation_kernel_size), np.uint8)
        
        # Dilate mask to thicken lines
        dilated = cv2.dilate(mask, kernel, iterations=1)
        
        # Optional: Smooth edges with a very light gaussian filter
        # This helps reduce noise while maintaining structure
        smoothed = cv2.GaussianBlur(dilated, (3, 3), 0.5)
        
        # Re-binarize to have a clear mask
        _, enhanced = cv2.threshold(smoothed, 0.5, 1.0, cv2.THRESH_BINARY)
        
        return enhanced
    
    def __getitem__(self, idx):
        sample = self.img_data[idx]
        img_path = sample['img_path']
        mask_paths = sample['masks']
        
        # Load image
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Record the image dimensions so masks can be resized to match
        original_height, original_width = image.shape[:2]
        
        # Load masks
        masks = []
        
        for class_name in Config.CLASSES:
            if class_name in mask_paths and os.path.exists(mask_paths[class_name]):
                mask_path = mask_paths[class_name]
                # Read with alpha channel
                mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
                
                if mask is None:
                    # If mask can't be read, use an empty one
                    mask = np.zeros((original_height, original_width), dtype=np.uint8)
                else:
                    # Resize to original image dimensions if different
                    if mask.shape[:2] != (original_height, original_width):
                        mask = cv2.resize(mask, (original_width, original_height), 
                                         interpolation=cv2.INTER_NEAREST)
                    
                    # Extract alpha channel as mask if it exists
                    if mask.ndim == 3 and mask.shape[-1] == 4:  # BGRA (OpenCV channel order)
                        alpha_channel = mask[:, :, 3]
                        # Binarize alpha channel
                        _, binary_mask = cv2.threshold(alpha_channel, 127, 255, cv2.THRESH_BINARY)
                        mask = binary_mask
                    else:  # No alpha channel: reduce to a single channel if needed
                        if mask.ndim == 3:
                            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
                        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
                    
                    # Normalize mask to values of 0 and 1 (float32 keeps later OpenCV calls well-defined)
                    mask = (mask / 255.0).astype(np.float32)
                    
                    # Apply enhancement for thin lines if enabled
                    if self.enhance_masks and class_name in ['contour', 'hands']:  # Classes with thin lines
                        mask = self.enhance_thin_lines(mask)
            else:
                # If mask doesn't exist for this class, use an empty one
                mask = np.zeros((original_height, original_width), dtype=np.float32)
            
            masks.append(mask)
        
        # Create a single mask with channels for each class
        multi_mask = np.stack(masks, axis=-1).astype(np.float32)
        
        # Apply transformations if they exist
        if self.transform:
            try:
                augmented = self.transform(image=image, mask=multi_mask)
                image = augmented['image']
                multi_mask = augmented['mask']
            except Exception as e:
                print(f"Error during transformation: {e}")
                print(f"Image details: {img_path}, shape={image.shape}")
                raise
            
        # Ensure tensors are in the correct format
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image.transpose(2, 0, 1))
            
        if not isinstance(multi_mask, torch.Tensor):
            multi_mask = torch.from_numpy(multi_mask.transpose(2, 0, 1))
        elif multi_mask.dim() == 3 and multi_mask.shape[0] != Config.NUM_CLASSES:
            multi_mask = multi_mask.permute(2, 0, 1)
        
        return image, multi_mask


def ensure_binary_mask(mask, threshold=0.5):
    """
    Ensures that the mask remains binary (0 or 1) after transformations.
    
    Args:
        mask (numpy.ndarray): Input mask that may have non-binary values.
        threshold (float, optional): Threshold for binarization. Defaults to 0.5.
        
    Returns:
        numpy.ndarray: Binary mask with values 0 or 1.
    """
    return (mask > threshold).astype(np.float32)


def get_transforms(phase):
    """
    Defines enhanced transformations that preserve thin lines.
    
    Args:
        phase (str): Either 'train' or another value (for validation/testing).
        
    Returns:
        albumentations.Compose: Composition of transformations.
    """
    if phase == 'train':
        return A.Compose([
            # Resize; `interpolation` applies to the image only
            # (Albumentations resizes masks with nearest-neighbour by default)
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE, interpolation=cv2.INTER_NEAREST),
            
            # Transformations that maintain thin line structure
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            ], p=0.5),
            
            # Limited noise to avoid distorting thin masks
            A.OneOf([
                A.GaussNoise(var_limit=(10, 30), p=0.5),
                A.GaussianBlur(blur_limit=3, p=0.5),
            ], p=0.3),
            
            # Gentle spatial transformations
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, 
                              interpolation=cv2.INTER_NEAREST, border_mode=cv2.BORDER_CONSTANT, p=0.5),
            
            # Lambda to ensure binary values in masks
            A.Lambda(mask=lambda x, **kwargs: (x > 0.5).astype(np.float32)),
            
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    else:  # valid or test
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE, interpolation=cv2.INTER_NEAREST),
            
            # Also here for consistency
            A.Lambda(mask=lambda x, **kwargs: (x > 0.5).astype(np.float32)),
            
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

# Split into training and validation sets (20% for validation)
train_indices, valid_indices = train_test_split(
    range(len(train_img_data)), test_size=0.2, random_state=42)

# Create subsets
train_data = [train_img_data[i] for i in train_indices]
valid_data = [train_img_data[i] for i in valid_indices]

print(f"Training images: {len(train_data)}")
print(f"Validation images: {len(valid_data)}")

# Create datasets with enhancements for thin lines
train_dataset = ClockSegmentationDataset(
    train_data, 
    transform=get_transforms('train'),
    enhance_masks=True,
    dilation_kernel_size=3  # Kernel used to thicken thin lines (contours, hands)
)
valid_dataset = ClockSegmentationDataset(
    valid_data, 
    transform=get_transforms('valid'),
    enhance_masks=True,
    dilation_kernel_size=3
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset, batch_size=Config.BATCH_SIZE, 
    shuffle=True, num_workers=0, drop_last=True)
valid_loader = DataLoader(
    valid_dataset, batch_size=Config.BATCH_SIZE, 
    shuffle=False, num_workers=0)
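
To see what the 3x3 dilation in `enhance_thin_lines` does to a one-pixel stroke, here is a NumPy-only sketch. It is a hand-rolled equivalent for illustration (the notebook itself uses `cv2.dilate`), and `dilate3x3` is a hypothetical name.

```python
import numpy as np

def dilate3x3(mask):
    """Binary dilation with a 3x3 square structuring element (zero-padded borders)."""
    padded = np.pad(mask, 1, mode="constant")
    h, w = mask.shape
    out = np.zeros_like(mask)
    # A pixel becomes foreground if any pixel in its 3x3 neighbourhood is foreground
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out

line = np.zeros((7, 7), dtype=np.float32)
line[3, 1:6] = 1.0  # one-pixel-wide horizontal stroke, 5 pixels long

thick = dilate3x3(line)
print(int(line.sum()), int(thick.sum()))  # 5 21
```

The stroke grows from 5 to 21 foreground pixels (three rows by seven columns). Thin classes therefore occupy noticeably more area in the training targets than in the raw annotations, which is worth remembering when interpreting IoU for hands and contours.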

## Mask Visualization and Enhancement Analysis

This code implements detailed visualization functions to inspect and compare different processing techniques for segmentation masks, with special focus on thin line enhancement and contour detection.

In [None]:
def visualize_mask_enhancements(dataset, idx=0, show_preprocessing=True):
    """
    Visualizes a sample from the dataset with comparison between original and enhanced masks.
    
    This function shows the original image alongside its masks in different processing stages.
    
    Args:
        dataset (ClockSegmentationDataset): The dataset to visualize from.
        idx (int, optional): Index of the sample to visualize. Defaults to 0.
        show_preprocessing (bool, optional): Whether to show intermediate preprocessing 
            steps. Defaults to True.
    
    Returns:
        tuple: (image, original_mask, enhanced_mask) as numpy arrays.
    """
    # Get original image and masks (without enhancements)
    temp_enhance_setting = dataset.enhance_masks
    dataset.enhance_masks = False
    image_orig, mask_orig = dataset[idx]
    
    # Restore configuration to get enhanced masks
    dataset.enhance_masks = temp_enhance_setting
    image, mask_enhanced = dataset[idx]
    
    img_path = dataset.img_data[idx]['img_path']
    # Background files may be named "_Background" or "_background"
    img_name = re.split(r'_[Bb]ackground', os.path.basename(img_path))[0]
    
    # Convert tensors to numpy arrays for visualization
    if isinstance(image, torch.Tensor):
        # Denormalize image
        mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
        
        image = image * std + mean
        image = image.permute(1, 2, 0).numpy()
        image = np.clip(image, 0, 1)
        
        # Convert masks to numpy
        if isinstance(mask_orig, torch.Tensor):
            mask_orig = mask_orig.numpy()
        if isinstance(mask_enhanced, torch.Tensor):
            mask_enhanced = mask_enhanced.numpy()
    
    # Prepare visualization in three rows (or two when show_preprocessing is False)
    if show_preprocessing:
        fig, axes = plt.subplots(3, Config.NUM_CLASSES + 1, figsize=(20, 12))
        fig.suptitle(f"Sample {idx+1}: {img_name} - Mask Comparison", fontsize=16)
        
        # First row: Original image and dilated masks
        axes[0, 0].imshow(image)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')
        
        # First and second rows: dilated masks (row 0) and original masks (row 1)
        for i, class_name in enumerate(Config.CLASSES):
            if mask_orig.shape[0] == len(Config.CLASSES):  # If in C,H,W format
                axes[1, i+1].imshow(mask_orig[i], cmap='gray', vmin=0, vmax=1)
                axes[0, i+1].imshow(cv2.dilate(mask_orig[i].astype(np.float32), 
                                             np.ones((3,3), np.uint8), iterations=1), 
                                   cmap='gray', vmin=0, vmax=1)
            else:  # If in H,W,C format
                axes[1, i+1].imshow(mask_orig[:,:,i], cmap='gray', vmin=0, vmax=1)
                axes[0, i+1].imshow(cv2.dilate(mask_orig[:,:,i].astype(np.float32), 
                                             np.ones((3,3), np.uint8), iterations=1),
                                   cmap='gray', vmin=0, vmax=1)
            axes[0, i+1].set_title(f'Dilated mask: {class_name}')
            axes[1, i+1].set_title(f'Original mask: {class_name}')
            axes[0, i+1].axis('off')
            axes[1, i+1].axis('off')
        
        # Reference image in the first column of rows two and three,
        # followed by the enhanced masks in the third row
        axes[1, 0].imshow(image)
        axes[1, 0].set_title('Image for reference')
        axes[1, 0].axis('off')
        
        axes[2, 0].imshow(image)
        axes[2, 0].set_title('Image for reference')
        axes[2, 0].axis('off')
        
        for i, class_name in enumerate(Config.CLASSES):
            if mask_enhanced.shape[0] == len(Config.CLASSES):
                axes[2, i+1].imshow(mask_enhanced[i], cmap='gray', vmin=0, vmax=1)
            else:
                axes[2, i+1].imshow(mask_enhanced[:,:,i], cmap='gray', vmin=0, vmax=1)
            axes[2, i+1].set_title(f'Enhanced mask: {class_name}')
            axes[2, i+1].axis('off')
    else:
        # Simple visualization without showing intermediate steps
        fig, axes = plt.subplots(2, Config.NUM_CLASSES + 1, figsize=(20, 8))
        fig.suptitle(f"Sample {idx+1}: {img_name}", fontsize=16)
        
        # First row: Original image and original masks
        axes[0, 0].imshow(image)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')
        
        for i, class_name in enumerate(Config.CLASSES):
            if mask_orig.shape[0] == len(Config.CLASSES):
                axes[0, i+1].imshow(mask_orig[i], cmap='gray', vmin=0, vmax=1)
            else:
                axes[0, i+1].imshow(mask_orig[:,:,i], cmap='gray', vmin=0, vmax=1)
            axes[0, i+1].set_title(f'Original: {class_name}')
            axes[0, i+1].axis('off')
            
            # Second row: Enhanced masks
            if mask_enhanced.shape[0] == len(Config.CLASSES):
                axes[1, i+1].imshow(mask_enhanced[i], cmap='gray', vmin=0, vmax=1)
            else:
                axes[1, i+1].imshow(mask_enhanced[:,:,i], cmap='gray', vmin=0, vmax=1)
            axes[1, i+1].set_title(f'Enhanced: {class_name}')
            axes[1, i+1].axis('off')
        
        axes[1, 0].imshow(image)
        axes[1, 0].set_title('Original Image')
        axes[1, 0].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return image, mask_orig, mask_enhanced

def compare_contours(dataset, idx=0):
    """
    Specialized function to analyze contour detection in masks.
    
    This function applies different edge detection and morphological operations
    to analyze and enhance contours in the segmentation masks.
    
    Args:
        dataset (ClockSegmentationDataset): The dataset to visualize from.
        idx (int, optional): Index of the sample to analyze. Defaults to 0.
    """
    # Get the sample
    image, mask = dataset[idx]
    img_path = dataset.img_data[idx]['img_path']
    # Background files may be named "_Background" or "_background"
    img_name = re.split(r'_[Bb]ackground', os.path.basename(img_path))[0]
    
    # Convert to numpy for processing
    if isinstance(image, torch.Tensor):
        # Denormalize image
        mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
        
        image = image * std + mean
        image = image.permute(1, 2, 0).numpy()
        image = np.clip(image, 0, 1)
        
        if isinstance(mask, torch.Tensor):
            mask = mask.numpy()
    
    # Extract contour mask (assuming it's a specific class)
    contour_idx = Config.CLASSES.index('contour')
    
    if mask.shape[0] == len(Config.CLASSES):  # C,H,W format
        contour_mask = mask[contour_idx]
    else:  # H,W,C format
        contour_mask = mask[:,:,contour_idx]
    
    # Convert for processing with OpenCV
    contour_mask_cv = (contour_mask * 255).astype(np.uint8)
    
    # Edge detection (Canny)
    edges_canny = cv2.Canny(contour_mask_cv, 50, 150)
    
    # Gentle dilation
    kernel1 = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(contour_mask_cv, kernel1, iterations=1)
    
    # Stronger dilation
    kernel2 = np.ones((5, 5), np.uint8)
    dilated_more = cv2.dilate(contour_mask_cv, kernel2, iterations=1)
    
    # Morphological filter to enhance lines
    kernel_line = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
    morph_line = cv2.morphologyEx(contour_mask_cv, cv2.MORPH_CLOSE, kernel_line)
    
    # Visualize
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f"Contour Analysis - {img_name}", fontsize=16)
    
    axes[0, 0].imshow(image)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(contour_mask, cmap='gray', vmin=0, vmax=1)
    axes[0, 1].set_title('Original Mask')
    axes[0, 1].axis('off')
    
    axes[0, 2].imshow(edges_canny, cmap='gray')
    axes[0, 2].set_title('Canny Edge Detection')
    axes[0, 2].axis('off')
    
    axes[1, 0].imshow(dilated, cmap='gray')
    axes[1, 0].set_title('Gentle Dilation')
    axes[1, 0].axis('off')
    
    axes[1, 1].imshow(dilated_more, cmap='gray')
    axes[1, 1].set_title('Strong Dilation')
    axes[1, 1].axis('off')
    
    axes[1, 2].imshow(morph_line, cmap='gray')
    axes[1, 2].set_title('Morphological Filter')
    axes[1, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize some random samples with applied enhancements
print("\nSamples with enhanced thin line processing:")
for i in range(3):
    idx = random.randint(0, len(valid_dataset)-1)
    img_name = re.split(r'_[Bb]ackground', os.path.basename(valid_dataset.img_data[idx]['img_path']))[0]
    print(f"\nSample {i+1}: {img_name}")
    visualize_mask_enhancements(valid_dataset, idx=idx)
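
Both visualization functions undo the ImageNet normalization before plotting. The NumPy sketch below checks that this inverse recovers the original pixel values; `normalize` mimics what `A.Normalize` followed by `ToTensorV2` produces under the default `max_pixel_value=255`, and both helper names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

def normalize(image_hwc_uint8):
    """Scale to [0, 1], move channels first, then standardize per channel."""
    chw = image_hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    return (chw - MEAN) / STD

def denormalize(image_chw):
    """Invert the standardization and return an H,W,C image in [0, 1]."""
    restored = image_chw * STD + MEAN
    return np.clip(restored, 0.0, 1.0).transpose(1, 2, 0)

rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)

roundtrip = denormalize(normalize(original))
max_error = float(np.abs(roundtrip - original / 255.0).max())
print("Round-trip shape:", roundtrip.shape)
print("Max absolute error:", max_error)
```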

## Thin Line Processing and Method Comparison

This code defines `process_thin_lines`, which thickens or cleans up thin strokes in binary masks (mainly clock contours and hands) using dilation, morphological closing, skeleton-style thinning, or a density-based adaptive dilation. It also defines `compare_methods`, which plots the output of several methods side by side.
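
To make the effect of dilation on a thin stroke concrete, here is a minimal NumPy-only sketch. The helper `dilate3x3` and the 9x9 test mask are hypothetical and written for illustration; the notebook itself relies on `cv2.dilate`. The sketch also computes the foreground density, the quantity the adaptive method inspects before choosing a kernel size.

```python
import numpy as np

def dilate3x3(mask):
    """Binary 3x3 dilation: a pixel is set if it or any of its 8 neighbours is set."""
    padded = np.pad(mask, 1, mode="constant")
    out = np.zeros_like(mask)
    h, w = mask.shape
    # OR together the nine shifted views of the padded mask
    for dy in range(3):
        for dx in range(3):
            out |= padded[dy:dy + h, dx:dx + w]
    return out

# Hypothetical 9x9 mask with a one-pixel-wide horizontal stroke
mask = np.zeros((9, 9), dtype=np.uint8)
mask[4, 1:8] = 1

thick = dilate3x3(mask)

density_before = mask.sum() / mask.size
density_after = thick.sum() / thick.size
print(f"foreground pixels: {mask.sum()} -> {thick.sum()}")
print(f"density: {density_before:.3f} -> {density_after:.3f}")
```

A single 3x3 dilation roughly triples the width of a one-pixel stroke, which is why the kernel size and iteration count need to stay small for thin structures.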

In [None]:
def process_thin_lines(mask, method='morphological', params=None):
    """
    Processes thin lines in masks using different techniques.
    
    This function implements various computer vision techniques to enhance thin lines
    in binary masks for improved segmentation performance.
    
    Args:
        mask (numpy.ndarray): Original mask (values between 0-1 or 0-255).
        method (str, optional): Method to use ('dilation', 'morphological', 'thinning', 'adaptive'). 
            Defaults to 'morphological'.
        params (dict, optional): Specific parameters for the chosen method. 
            Defaults to None.
        
    Returns:
        numpy.ndarray: Processed mask with the same value range as the input.
    """
    # Ensure mask is in 0-255 format for OpenCV
    if mask.max() <= 1.0:
        mask_cv = (mask * 255).astype(np.uint8)
    else:
        mask_cv = mask.astype(np.uint8)
    
    # Configure default parameters if not provided
    if params is None:
        params = {}
    
    if method == 'dilation':
        # Simple dilation to thicken lines
        kernel_size = params.get('kernel_size', 3)
        iterations = params.get('iterations', 1)
        
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        processed = cv2.dilate(mask_cv, kernel, iterations=iterations)
        
    elif method == 'morphological':
        # More advanced morphological operations
        kernel_size = params.get('kernel_size', 3)
        close_iterations = params.get('close_iterations', 1)
        open_iterations = params.get('open_iterations', 0)
        
        # Close first to connect nearby lines
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        processed = cv2.morphologyEx(mask_cv, cv2.MORPH_CLOSE, kernel, iterations=close_iterations)
        
        # Optional: open to remove noise if needed
        if open_iterations > 0:
            processed = cv2.morphologyEx(processed, cv2.MORPH_OPEN, kernel, iterations=open_iterations)
            
    elif method == 'thinning':
        # Reduce strokes to a thin skeleton
        # First dilate to ensure connectivity
        kernel_dilate = np.ones((3, 3), np.uint8)
        dilated = cv2.dilate(mask_cv, kernel_dilate, iterations=1)
        _, remaining = cv2.threshold(dilated, 127, 255, cv2.THRESH_BINARY)
        
        # Morphological skeleton (not Zhang-Suen): at each step keep the pixels that an
        # opening would remove, then erode. Erosion alone would erase the whole mask.
        # For a connectivity-preserving skeleton use skimage.morphology.skeletonize.
        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
        thinned = np.zeros_like(remaining)
        
        # Bounded loop: stops early once nothing is left to erode
        for _ in range(max(remaining.shape)):
            if np.count_nonzero(remaining) == 0:
                break
            opened = cv2.morphologyEx(remaining, cv2.MORPH_OPEN, kernel)
            thinned = cv2.bitwise_or(thinned, cv2.subtract(remaining, opened))
            remaining = cv2.erode(remaining, kernel, iterations=1)
        
        # Ensure visibility by slightly dilating the result
        processed = cv2.dilate(thinned, np.ones((2, 2), np.uint8), iterations=1)
        
    elif method == 'adaptive':
        # Adaptive method that combines different techniques
        # First analyze mask characteristics
        non_zero = np.count_nonzero(mask_cv)
        total_pixels = mask_cv.size
        density = non_zero / total_pixels
        
        # Apply technique according to density
        if density < 0.01:  # Very few lines -> more dilation
            kernel_size = 5
            iterations = 2
        elif density < 0.05:  # Few lines -> moderate dilation
            kernel_size = 3
            iterations = 1
        else:  # Higher density -> gentle dilation
            kernel_size = 2
            iterations = 1
        
        # Apply adaptive dilation
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        processed = cv2.dilate(mask_cv, kernel, iterations=iterations)
        
        # Smooth edges
        processed = cv2.GaussianBlur(processed, (3, 3), 0.5)
        
        # Re-binarize to maintain sharpness
        _, processed = cv2.threshold(processed, 127, 255, cv2.THRESH_BINARY)
        
    else:
        # If no valid method is specified, return the original mask
        processed = mask_cv
    
    # Normalize according to input format
    if mask.max() <= 1.0:
        return processed / 255.0
    else:
        return processed

def compare_methods(image_path, mask_path, methods=('original', 'dilation', 'morphological', 'adaptive')):
    """
    Compares different processing methods on a single mask.
    
    This function loads an image and mask, applies various enhancement methods,
    and visualizes the results side by side for comparison.
    
    Args:
        image_path (str): Path to the original image.
        mask_path (str): Path to the mask to process.
        methods (list, optional): List of methods to apply. 
            Defaults to ['original', 'dilation', 'morphological', 'adaptive'].
            
    Returns:
        dict: Dictionary containing the original and processed masks.
    """
    # Load image and mask
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not load image: {image_path}")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Load mask with alpha channel
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if mask is None:
        print(f"Could not load mask: {mask_path}")
        return None
    
    # Reduce the mask to a single channel
    if mask.ndim == 3 and mask.shape[2] == 4:  # BGRA: use the alpha channel
        mask = mask[:, :, 3]
    elif mask.ndim == 3:  # Color without alpha: convert to grayscale
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # A 2D array is already single-channel and needs no conversion
    
    # Binarize
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    
    # Normalize mask to 0-1
    mask = mask / 255.0
    
    # Apply different methods
    results = {'original': mask}
    
    for method in methods:
        if method == 'original':
            continue  # Already have it
            
        # Process with corresponding method
        processed = process_thin_lines(mask, method=method)
        results[method] = processed
    
    # Visualize results
    n_methods = len(methods)
    fig, axes = plt.subplots(1, n_methods + 1, figsize=(4*(n_methods + 1), 4))
    
    # Show original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Show each method
    for i, method in enumerate(methods):
        axes[i+1].imshow(results[method], cmap='gray', vmin=0, vmax=1)
        axes[i+1].set_title(f'Method: {method}')
        axes[i+1].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, 'method_comparison.png'))
    plt.show()
    
    return results

## Advanced Mask Processing Comparison

This code applies the dilation, morphological, and adaptive methods to up to three random validation samples, for both contour and clock-hand masks, so the results can be compared visually. The recommendation summary printed at the end is hard-coded, not computed from these samples.

In [None]:
# Test different methods on some random samples
sample_indices = random.sample(range(len(valid_dataset)), min(3, len(valid_dataset)))

for i, idx in enumerate(sample_indices):
    print(f"\nAnalyzing sample {i+1}/{len(sample_indices)} (index {idx})")
    
    # Sample metadata
    img_path = valid_dataset.img_data[idx]['img_path']
    img_name = os.path.basename(img_path).split('_Background')[0]
    
    print(f"Image: {img_name}")
    
    # Process contours if available
    if 'contour' in valid_dataset.img_data[idx]['masks']:
        contour_mask_path = valid_dataset.img_data[idx]['masks']['contour']
        print("\nComparing methods for contours:")
        compare_methods(img_path, contour_mask_path, 
                       methods=['original', 'dilation', 'morphological', 'adaptive'])
        
        # Deeper analysis for the first example
        if i == 0:
            compare_contours(valid_dataset, idx)
    
    # Process clock hands if available
    if 'hands' in valid_dataset.img_data[idx]['masks']:
        hands_mask_path = valid_dataset.img_data[idx]['masks']['hands']
        print("\nComparing methods for clock hands:")
        compare_methods(img_path, hands_mask_path, 
                       methods=['original', 'dilation', 'morphological', 'adaptive'])
                       
    # Visualize complete comparison
    print("\nVisualizing complete comparison of all masks:")
    visualize_mask_enhancements(valid_dataset, idx, show_preprocessing=True)

print("\nRecommendation summary:")
print("- For contours: 'adaptive' method with kernel_size=3")
print("- For clock hands: 'morphological' method with kernel_size=2, close_iterations=1")
print("- For both cases, always use INTER_NEAREST interpolation in transformations")
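
The last recommendation, nearest-neighbour interpolation for masks, can be illustrated with a one-dimensional sketch. This uses `np.repeat` and `np.interp` as stand-ins for image resizing, which is an assumption made for illustration and not the exact arithmetic of OpenCV's `INTER_NEAREST` or `INTER_LINEAR`. The point is only that linear blending introduces values strictly between 0 and 1, so a binary mask stops being binary.

```python
import numpy as np

# One row of a binary mask containing a one-pixel-wide stroke
row = np.array([0.0, 0.0, 1.0, 0.0, 0.0])

# Nearest-neighbour upsampling by 2: every output pixel copies one input pixel
nearest = np.repeat(row, 2)

# Linear upsampling by 2: output pixels blend neighbouring input pixels
x_old = np.arange(row.size)
x_new = np.linspace(0, row.size - 1, row.size * 2)
linear = np.interp(x_new, x_old, row)

print("nearest unique values:", np.unique(nearest))
print("linear unique values: ", np.round(np.unique(linear), 3))
```

With linear interpolation, a later threshold decides whether the blended stroke survives, and thin strokes are the first to be lost or distorted.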

## Loss Functions and Evaluation Metrics

This code defines the training loss (a weighted sum of Dice loss and focal loss with optional per-class weights) and the IoU metric used for evaluation. Both losses expect per-class probabilities in [0, 1], which matches the sigmoid activation of the model.
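
As a worked illustration of how the two loss terms behave, the following NumPy sketch evaluates simplified single-class versions of both on four pixels. The functions `dice_loss` and `focal_loss` here are hypothetical simplifications (no batching, no class weights) that follow the same formulas as the classes in the next cell.

```python
import numpy as np

def dice_loss(pred, target, smooth=1e-6):
    """1 - Dice coefficient for one class, computed on soft probabilities."""
    intersection = (pred * target).sum()
    denominator = pred.sum() + target.sum()
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

def focal_loss(pred, target, alpha=1.0, gamma=2.0, eps=1e-7):
    """Mean focal loss: binary cross-entropy scaled by (1 - pt) ** gamma."""
    pred = np.clip(pred, eps, 1.0 - eps)
    bce = -(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))
    pt = np.exp(-bce)  # probability the model assigns to the true label
    return float((alpha * (1.0 - pt) ** gamma * bce).mean())

target = np.array([1.0, 1.0, 0.0, 0.0])
confident = np.array([0.9, 0.9, 0.1, 0.1])  # correct and confident
unsure = np.array([0.6, 0.6, 0.4, 0.4])     # correct but hesitant

for name, pred in [("confident", confident), ("unsure", unsure)]:
    print(f"{name:>9}: dice={dice_loss(pred, target):.4f}  focal={focal_loss(pred, target):.5f}")
```

The Dice term changes roughly in proportion to the overlap, while the focal term shrinks much faster for confident pixels, so hesitant pixels dominate its gradient.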

In [None]:
class DiceLoss(nn.Module):
    """
    Implementation of Dice Loss for segmentation with class weighting options.
    
    The Dice coefficient measures the overlap between predicted and ground truth masks.
    This loss is particularly effective for imbalanced segmentation problems.
    
    Args:
        smooth (float, optional): Small constant to prevent division by zero. Defaults to 1e-6.
        weight (torch.Tensor, optional): Class weights tensor. Defaults to None.
    """
    def __init__(self, smooth=1e-6, weight=None):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.weight = weight  # Class weights
        
    def forward(self, inputs, targets):
        # Flatten tensors for per-class calculation
        batch_size = targets.size(0)
        num_classes = targets.size(1)
        
        # Use the provided class weights (moved to the input device), otherwise weight all classes equally
        weight = self.weight.to(inputs.device) if self.weight is not None else torch.ones(num_classes, device=inputs.device)
        
        # Calculate Dice for each class
        dice_scores = []
        for class_idx in range(num_classes):
            inputs_class = inputs[:, class_idx, ...]
            targets_class = targets[:, class_idx, ...]
            
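            # Flatten (reshape also handles non-contiguous slices)
            inputs_flat = inputs_class.reshape(batch_size, -1)
            targets_flat = targets_class.reshape(batch_size, -1)
            
            # Intersection and Dice denominator |A| + |B|
            # (the variable is named `union`, but it is the sum of both areas, not a true union)
            intersection = (inputs_flat * targets_flat).sum(1)
            union = inputs_flat.sum(1) + targets_flat.sum(1)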
            
            # Calculate Dice per instance and average over batch
            dice_score = (2. * intersection + self.smooth) / (union + self.smooth)
            weighted_dice = weight[class_idx] * (1 - dice_score.mean())
            dice_scores.append(weighted_dice)
        
        return torch.stack(dice_scores).mean()

class FocalLoss(nn.Module):
    """
    Implementation of Focal Loss for segmentation with weighting support.
    
    Focal loss addresses class imbalance by down-weighting well-classified examples,
    making the model focus more on difficult cases.
    
    Args:
        alpha (float, optional): Global scaling factor applied to the focal term. Defaults to 1.
        gamma (float, optional): Focusing parameter. Defaults to 2.
        reduction (str, optional): Reduction method ('mean', 'sum', 'none'). Defaults to 'mean'.
        weight (torch.Tensor, optional): Class weights tensor. Defaults to None.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean', weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.weight = weight  # Class weights
        
    def forward(self, inputs, targets):
        # Calculate BCE
        bce_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        
        # Focal term
        pt = torch.exp(-bce_loss)
        focal_term = self.alpha * (1 - pt) ** self.gamma
        
        # Apply class weighting if provided
        if self.weight is not None:
            # Broadcast per-class weights over the batch and spatial dimensions
            weight = self.weight.to(inputs.device).view(1, -1, 1, 1).expand_as(targets)
            focal_loss = weight * focal_term * bce_loss
        else:
            focal_loss = focal_term * bce_loss
        
        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:  # 'none'
            return focal_loss

class CombinedLoss(nn.Module):
    """
    Combines Dice Loss and Focal Loss with configurable weighting.
    
    This loss function leverages the strengths of both Dice loss (for shape and overlap)
    and Focal loss (for handling class imbalance).
    
    Args:
        alpha (float, optional): Alpha parameter for Focal Loss. Defaults to 1.0.
        gamma (float, optional): Gamma parameter for Focal Loss. Defaults to 2.0.
        dice_weight (float, optional): Weight for Dice Loss. Defaults to 1.0.
        focal_weight (float, optional): Weight for Focal Loss. Defaults to 1.0.
        class_weights (list or torch.Tensor, optional): Per-class weights. Defaults to None.
    """
    def __init__(self, alpha=1.0, gamma=2.0, dice_weight=1.0, focal_weight=1.0, class_weights=None):
        super(CombinedLoss, self).__init__()
        # Convert weights to a float tensor if a list was given
        if class_weights is not None and not isinstance(class_weights, torch.Tensor):
            class_weights = torch.tensor(class_weights, dtype=torch.float32)
            
        self.dice = DiceLoss(weight=class_weights)
        self.focal = FocalLoss(alpha=alpha, gamma=gamma, weight=class_weights)
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        
    def forward(self, inputs, targets):
        dice_loss = self.dice(inputs, targets)
        focal_loss = self.focal(inputs, targets)
        
        # Combine both losses
        return self.dice_weight * dice_loss + self.focal_weight * focal_loss

def iou_score(outputs, targets, threshold=0.5, smooth=1e-5):
    """
    Calculates IoU (Intersection over Union) for segmentation evaluation.
    
    IoU is a standard metric for evaluating the quality of segmentation masks.
    
    Args:
        outputs (torch.Tensor): Model predictions.
        targets (torch.Tensor): Ground truth masks.
        threshold (float, optional): Threshold for binary prediction. Defaults to 0.5.
        smooth (float, optional): Small constant to prevent division by zero. Defaults to 1e-5.
        
    Returns:
        tuple: (mean_iou, class_iou) containing the mean IoU and per-class IoU values.
    """
    # Convert to binary with threshold
    outputs_bin = (outputs > threshold).float()
    
    # Calculate intersection and union for each batch and class
    intersection = (outputs_bin * targets).sum(dim=(2, 3))
    union = outputs_bin.sum(dim=(2, 3)) + targets.sum(dim=(2, 3)) - intersection
    
    # Avoid division by zero
    iou = (intersection + smooth) / (union + smooth)
    
    # IoU per class
    class_iou = iou.mean(dim=0)
    
    # Global mean IoU
    mean_iou = class_iou.mean()
    
    return mean_iou, class_iou
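
For a concrete sense of the IoU values this metric produces, here is a small NumPy sketch on two overlapping squares. The helper `iou_and_dice` is hypothetical and handles a single mask pair, using the same smoothing scheme as `iou_score`. It also shows a side effect of the smoothing term: when both the prediction and the ground truth are empty, the score is 1.

```python
import numpy as np

def iou_and_dice(pred_bin, target, smooth=1e-5):
    """IoU and Dice coefficient for one pair of binary masks."""
    intersection = (pred_bin * target).sum()
    area_sum = pred_bin.sum() + target.sum()
    union = area_sum - intersection
    iou = (intersection + smooth) / (union + smooth)
    dice = (2.0 * intersection + smooth) / (area_sum + smooth)
    return float(iou), float(dice)

# Two 4x4 squares offset by one pixel in each direction
target = np.zeros((8, 8))
target[2:6, 2:6] = 1
pred = np.zeros((8, 8))
pred[3:7, 3:7] = 1

iou, dice = iou_and_dice(pred, target)
print(f"IoU={iou:.4f}  Dice={dice:.4f}  2*IoU/(1+IoU)={2 * iou / (1 + iou):.4f}")

# Edge case: the class is absent and nothing is predicted
empty_iou, _ = iou_and_dice(np.zeros((8, 8)), np.zeros((8, 8)))
print(f"IoU for two empty masks: {empty_iou:.1f}")
```

A one-pixel shift of a 4x4 square already drops IoU below 0.4, which helps explain why thin structures such as hands and contours tend to score lower than filled regions.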

## Model Architecture and Initialization

This code builds the segmentation model with the `smp` library and re-initializes every module outside the pre-trained encoder using Xavier-normal weights, with zero biases and unit-scale batch normalization.
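
Xavier-normal initialization draws weights from a zero-mean normal distribution whose standard deviation is `gain * sqrt(2 / (fan_in + fan_out))`. The NumPy sketch below computes that value for a convolution weight; the layer shape and the helper `xavier_normal_std` are hypothetical and chosen only for illustration.

```python
import numpy as np

def xavier_normal_std(shape, gain=1.0):
    """Standard deviation used by Xavier-normal init for a conv weight
    of shape (out_channels, in_channels, kernel_h, kernel_w)."""
    out_ch, in_ch, kh, kw = shape
    fan_in = in_ch * kh * kw
    fan_out = out_ch * kh * kw
    return gain * np.sqrt(2.0 / (fan_in + fan_out))

shape = (64, 128, 3, 3)  # hypothetical decoder convolution
std = xavier_normal_std(shape)

# Draw weights with that standard deviation and check the empirical value
rng = np.random.default_rng(0)
weights = rng.normal(0.0, std, size=shape)
print(f"target std: {std:.5f}, sampled std: {weights.std():.5f}")
```

Wider layers get smaller initial weights, which keeps activation and gradient magnitudes roughly comparable across decoder stages at the start of training.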

In [None]:
def init_weights(m):
    """
    Initializes model weights appropriately to improve convergence.
    
    This function applies Xavier normal initialization to convolutional and linear layers
    and constant initialization to batch normalization layers.
    
    Args:
        m (nn.Module): Module to initialize.
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

def build_model():
    """
    Builds the segmentation model with improved initialization.
    
    This function creates a segmentation model based on the configuration settings,
    with properly initialized weights for better training convergence.
    
    Returns:
        nn.Module: The constructed segmentation model.
    """
    # Architecture to build. Supported by the branches below: "Unet", "Unet++", "Linknet"
    model_type = "Unet++"
    
    if model_type == "Unet":
        model = smp.Unet(
            encoder_name=Config.ENCODER,
            encoder_weights=Config.ENCODER_WEIGHTS,
            classes=Config.NUM_CLASSES,
            activation='sigmoid'
        )
    elif model_type == "Unet++":
        model = smp.UnetPlusPlus(
            encoder_name=Config.ENCODER,
            encoder_weights=Config.ENCODER_WEIGHTS,
            classes=Config.NUM_CLASSES,
            activation='sigmoid'
        )
    elif model_type == "Linknet":
        model = smp.Linknet(
            encoder_name=Config.ENCODER,
            encoder_weights=Config.ENCODER_WEIGHTS,
            classes=Config.NUM_CLASSES,
            activation='sigmoid'
        )
    else:
        raise ValueError(f"Model type not supported: {model_type}")
    
    # Initialize decoder weights (encoder is already pre-trained)
    # This is crucial to prevent the model from getting stuck predicting only zeros or ones
    # Only apply to layers that are not part of the encoder
    for name, module in model.named_children():
        if name != 'encoder':
            module.apply(init_weights)
    
    print(f"Model built: {model_type} with encoder {Config.ENCODER}")
    return model

# Build the model
model = build_model()
model.to(Config.DEVICE)

# Print model summary
print(f"Total number of parameters: {sum(p.numel() for p in model.parameters())}")
print(f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

## Batch Visualization for Model Evaluation

This code defines `visualize_batch`, which saves two figures per call: a side-by-side view of image, ground truth, and prediction for up to three samples (with class channels mapped to RGB), and a per-class comparison for the first sample.
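
Before an image tensor can be displayed, the input normalization has to be undone. The sketch below shows that round trip in NumPy with the ImageNet mean and standard deviation that `visualize_batch` uses. It assumes the data pipeline normalized the images with the same statistics; the random test image is hypothetical.

```python
import numpy as np

# ImageNet channel statistics, shaped to broadcast over a (C, H, W) image
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img_chw):
    """What the input pipeline is assumed to apply before the model sees the image."""
    return (img_chw - MEAN) / STD

def denormalize(img_chw):
    """Invert the normalization and return an (H, W, C) array clipped to [0, 1]."""
    img = img_chw * STD + MEAN
    return np.clip(img.transpose(1, 2, 0), 0.0, 1.0)

rng = np.random.default_rng(0)
original = rng.random((3, 4, 4))  # hypothetical image with values in [0, 1)
restored = denormalize(normalize(original))

print(restored.shape, np.allclose(restored, original.transpose(1, 2, 0)))
```

The final clip only guards against small numerical overshoot; if the displayed images look washed out or too dark, the statistics used here probably differ from those used in the transform.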

In [None]:
def visualize_batch(model, data_loader, device, epoch):
    """
    Visualizes predictions on a batch for analysis.
    
    This function creates two visualization plots:
    1. A comparison of original images, ground truth masks, and model predictions
    2. A detailed view of each segmentation class for both ground truth and predictions
    
    Args:
        model (nn.Module): The trained segmentation model.
        data_loader (DataLoader): DataLoader containing validation/test data.
        device (torch.device): Device to run the model on.
        epoch (int): Current epoch number for labeling the saved visualizations.
    """
    model.eval()
    images, masks = next(iter(data_loader))
    batch_size = min(3, images.size(0))
    
    with torch.no_grad():
        outputs = model(images[:batch_size].to(device))
    
    # Denormalization
    mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
    
    # Create figure (squeeze=False keeps axes 2D even when batch_size == 1)
    fig, axes = plt.subplots(batch_size, 3, figsize=(15, 5*batch_size), squeeze=False)
    
    for i in range(batch_size):
        # Original image: undo the normalization and convert CHW -> HWC
        img = images[i].cpu().detach()
        img = img * std.squeeze(0) + mean.squeeze(0)
        img = img.permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)
        
        # Ground truth mask: map the first three class channels to RGB
        # (any further class is only shown in the per-class figure below)
        mask = masks[i].cpu().numpy()
        mask_combined = np.zeros((mask.shape[1], mask.shape[2], 3))
        for c in range(min(3, mask.shape[0])):
            mask_combined[:, :, c] = mask[c]
        
        # Prediction, combined the same way
        pred = outputs[i].cpu().numpy()
        pred_combined = np.zeros((pred.shape[1], pred.shape[2], 3))
        for c in range(min(3, pred.shape[0])):
            pred_combined[:, :, c] = pred[c]
        
        # One row per sample: image, ground truth, prediction
        panels = [
            (img, 'Original Image'),
            (mask_combined, 'Ground Truth Mask'),
            (pred_combined, 'Prediction'),
        ]
        for col, (panel, title) in enumerate(panels):
            axes[i, col].imshow(panel)
            axes[i, col].set_title(title)
            axes[i, col].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'viz_epoch_{epoch+1}.png'))
    plt.close()
    
    # Additionally, visualize each class separately for the first example
    fig, axes = plt.subplots(2, Config.NUM_CLASSES, figsize=(4*Config.NUM_CLASSES, 8))
    
    for c in range(Config.NUM_CLASSES):
        # Ground truth mask for class c
        axes[0, c].imshow(masks[0][c].cpu().numpy(), cmap='viridis')
        axes[0, c].set_title(f'GT: {Config.CLASSES[c]}')
        axes[0, c].axis('off')
        
        # Prediction for class c
        axes[1, c].imshow(outputs[0][c].cpu().numpy(), cmap='viridis')
        axes[1, c].set_title(f'Pred: {Config.CLASSES[c]}')
        axes[1, c].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'viz_classes_epoch_{epoch+1}.png'))
    plt.close()

# Smoke-test the function before training (it disables gradients internally)
visualize_batch(model, valid_loader, Config.DEVICE, 0)

## Training Loop Implementation

This code defines the training loop: AdamW with a `ReduceLROnPlateau` scheduler, a class-weighted Dice plus focal loss, gradient-norm clipping, early stopping on validation loss, periodic checkpoints, and plots of the training history.
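
Gradient-norm clipping rescales all gradients by a common factor whenever their combined L2 norm exceeds a limit, so the update direction is preserved. The NumPy sketch below mirrors that rule; the helper `clip_grad_norm` and the gradient values are hypothetical, and the training loop itself uses `torch.nn.utils.clip_grad_norm_`.

```python
import numpy as np

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient arrays so their global L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# Hypothetical gradients of two parameter tensors; global norm is sqrt(9 + 16 + 144) = 13
grads = [np.array([3.0, 4.0]), np.array([[12.0]])]

clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float((g ** 2).sum()) for g in clipped))

print(f"norm before: {norm_before:.3f}, norm after: {norm_after:.3f}")
print("ratio between components preserved:", clipped[0][0] / clipped[0][1])
```

Gradients whose norm is already below the limit pass through unchanged, so a limit of 1.0 only intervenes on unusually large steps.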

In [None]:
def train_model():
    """
    Trains the segmentation model with optimized parameters.
    
    This function implements a complete training pipeline including:
    - Class weighting for handling imbalanced data
    - Adaptive learning rate scheduling
    - Gradient clipping
    - Early stopping
    - Periodic model checkpointing
    - Comprehensive metrics tracking
    - Visualization of predictions during training
    
    Returns:
        tuple: (model, history) - Trained model and training history dictionary
    """
    # Define class weights to give more importance to thin lines
    class_weights = torch.tensor([1.0, 1.0, 2.0, 2.0]).to(Config.DEVICE)  # Higher weight for 'hands' and 'contour'
    
    # Define optimizer, loss function and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    criterion = CombinedLoss(
        alpha=1.0, gamma=2.0,
        dice_weight=1.0, focal_weight=1.0,
        class_weights=class_weights
    )
    # Halve the learning rate once validation loss has stalled for more than 5 validation rounds.
    # The deprecated `verbose` flag is omitted; the rate is recorded in `history` instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)
    
    # For storing losses and metrics
    history = {
        'train_loss': [],
        'valid_loss': [],
        'valid_iou': [],
        'class_iou': {class_name: [] for class_name in Config.CLASSES},
        'learning_rate': []
    }
    
    best_valid_loss = float('inf')
    counter = 0  # For early stopping
    
    # Training loop
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        
        # ===== TRAINING =====
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc="Training")
        
        for images, masks in progress_bar:
            images = images.to(Config.DEVICE)
            masks = masks.to(Config.DEVICE)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Backward pass
            loss.backward()
            
            # Clip gradients to prevent explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            # Update loss and progress bar
            train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)
        
        # Save current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rate'].append(current_lr)
        
        # ===== VALIDATION =====
        if (epoch + 1) % Config.VAL_FREQUENCY == 0:
            model.eval()
            valid_loss = 0
            all_ious = []
            all_class_ious = {class_name: [] for class_name in Config.CLASSES}
            
            progress_bar = tqdm(valid_loader, desc="Validation")
            with torch.no_grad():
                for images, masks in progress_bar:
                    images = images.to(Config.DEVICE)
                    masks = masks.to(Config.DEVICE)
                    
                    # Forward pass
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    valid_loss += loss.item()
                    
                    # Calculate IoU
                    mean_iou, class_ious = iou_score(outputs, masks)
                    all_ious.append(mean_iou.item())
                    
                    # Store IoU per class
                    for i, class_name in enumerate(Config.CLASSES):
                        all_class_ious[class_name].append(class_ious[i].item())
            
            # Average metrics
            valid_loss /= len(valid_loader)
            history['valid_loss'].append(valid_loss)
            
            valid_iou = np.mean(all_ious)
            history['valid_iou'].append(valid_iou)
            
            for class_name in Config.CLASSES:
                class_iou = np.mean(all_class_ious[class_name])
                history['class_iou'][class_name].append(class_iou)
            
            # Update scheduler
            scheduler.step(valid_loss)
            
            # Print metrics
            print(f"Epoch {epoch+1}")
            print(f"Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")
            print(f"Valid IoU: {valid_iou:.4f}")
            for class_name in Config.CLASSES:
                class_iou = history['class_iou'][class_name][-1]
                print(f"  - IoU {class_name}: {class_iou:.4f}")
            
            # Visualize predictions for analysis
            if (epoch + 1) % Config.VISUALIZATION_FREQUENCY == 0:
                visualize_batch(model, valid_loader, Config.DEVICE, epoch)
            
            # Check if it's the best model
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                counter = 0
                torch.save(model.state_dict(), 
                           os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth'))
                print(f"New best model saved! (Loss: {valid_loss:.4f})")
            else:
                counter += 1
                if counter >= Config.PATIENCE:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        
        # Save checkpoint periodically
        if (epoch + 1) % Config.SAVE_FREQUENCY == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'valid_loss': history['valid_loss'][-1] if history['valid_loss'] else None,
                'valid_iou': history['valid_iou'][-1] if history['valid_iou'] else None,
                'history': history
            }, os.path.join(Config.CHECKPOINT_PATH, f'checkpoint_epoch_{epoch+1}.pth'))
    
    # Save final model
    torch.save(model.state_dict(), 
               os.path.join(Config.CHECKPOINT_PATH, 'final_model.pth'))
    
    # Plot training history
    plt.figure(figsize=(16, 12))
    
    # Validation runs only every Config.VAL_FREQUENCY epochs, so validation
    # curves are plotted against the epochs at which they were measured
    train_epochs = range(1, len(history['train_loss']) + 1)
    val_epochs = [e for e in train_epochs if e % Config.VAL_FREQUENCY == 0]
    
    # Loss plot
    plt.subplot(2, 2, 1)
    plt.plot(train_epochs, history['train_loss'], label='Train Loss')
    plt.plot(val_epochs, history['valid_loss'], label='Valid Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Evolution')
    
    # Global IoU plot
    plt.subplot(2, 2, 2)
    plt.plot(val_epochs, history['valid_iou'], label='Valid IoU')
    plt.xlabel('Epochs')
    plt.ylabel('IoU Score')
    plt.legend()
    plt.title('IoU Evolution')
    
    # Class IoU plot
    plt.subplot(2, 2, 3)
    for class_name in Config.CLASSES:
        plt.plot(val_epochs, history['class_iou'][class_name], 
                 label=f'IoU {class_name}')
    plt.xlabel('Epochs')
    plt.ylabel('IoU Score')
    plt.legend()
    plt.title('Class IoU Evolution')
    
    # Learning rate plot
    plt.subplot(2, 2, 4)
    plt.plot(train_epochs, history['learning_rate'])
    plt.xlabel('Epochs')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Evolution')
    
    plt.tight_layout()
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, 'training_history.png'))
    plt.show()
    
    return model, history

## Model Training

This code runs `train_model()` and stores both the trained model and its training history.
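
The early-stopping rule inside the training loop can be sketched in plain Python. The function `early_stopping_round` and the loss values are hypothetical; the rule itself (reset a counter on every new best validation loss, stop once the counter reaches the patience) follows the loop defined above.

```python
def early_stopping_round(valid_losses, patience):
    """Return (round at which training stops, best loss seen so far).

    Rounds are 1-based validation rounds; they coincide with epochs
    only when validation runs every epoch.
    """
    best = float("inf")
    counter = 0
    for round_idx, loss in enumerate(valid_losses, start=1):
        if loss < best:
            best = loss
            counter = 0  # improvement: reset the patience counter
        else:
            counter += 1
            if counter >= patience:
                return round_idx, best
    return len(valid_losses), best

# Hypothetical validation losses: best value at round 6, then no improvement
losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.64, 0.67, 0.69, 0.70, 0.71]
stop_round, best_loss = early_stopping_round(losses, patience=3)
print(f"stopped at validation round {stop_round}, best loss {best_loss}")
```

Because the best weights are saved at the moment of improvement, the checkpoint to evaluate afterwards is `best_model.pth`, not the state of the model when training stopped.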

In [None]:
# Train the model
model, history = train_model()

## Model Evaluation

This code defines `evaluate_model`, which loads saved weights and reports the loss, mean IoU, and per-class IoU on a given data loader. It then runs the function on the validation set.
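
One detail worth keeping in mind when reading the reported numbers: the evaluation averages per-batch means, so a smaller final batch carries as much weight as a full one. The NumPy sketch below uses hypothetical per-sample IoU values to show how that can differ from a plain per-sample average. The two agree whenever all batches have the same size.

```python
import numpy as np

# Hypothetical per-sample IoU values for a 5-image dataset
per_sample_iou = np.array([0.9, 0.9, 0.9, 0.9, 0.1])
batch_size = 4

# Split into batches the way a loader without drop_last would: sizes 4 and 1
batches = [per_sample_iou[i:i + batch_size]
           for i in range(0, len(per_sample_iou), batch_size)]

mean_of_batch_means = float(np.mean([b.mean() for b in batches]))
per_sample_mean = float(per_sample_iou.mean())

print(f"mean of batch means: {mean_of_batch_means:.2f}")
print(f"per-sample mean:     {per_sample_mean:.2f}")
```

If the dataset size is not a multiple of the batch size, weighting each batch mean by its batch size would recover the per-sample average.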

In [None]:
def evaluate_model(model_path, data_loader, device):
    """
    Evaluates the model on a dataset.
    
    This function loads a trained model and evaluates its performance on the provided
    dataset using loss and IoU metrics, both overall and per-class.
    
    Args:
        model_path (str): Path to the saved model weights.
        data_loader (DataLoader): DataLoader containing the evaluation dataset.
        device (torch.device): Device to run the model on.
        
    Returns:
        tuple: (eval_loss, valid_iou, class_iou_values) - Loss, average IoU, and per-class IoU values.
    """
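    # Load the model (map_location lets GPU-trained weights load on a CPU-only machine)
    model = build_model()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    
    # Define criterion and class weights (same values as in training)
    class_weights = torch.tensor([1.0, 1.0, 2.0, 2.0]).to(device)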
    criterion = CombinedLoss(
        alpha=1.0, gamma=2.0,
        dice_weight=1.0, focal_weight=1.0,
        class_weights=class_weights
    )
    
    # Metrics
    eval_loss = 0
    all_ious = []
    all_class_ious = {class_name: [] for class_name in Config.CLASSES}
    
    # Evaluation
    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Evaluating"):
            images = images.to(device)
            masks = masks.to(device)
            
            # Prediction
            outputs = model(images)
            loss = criterion(outputs, masks)
            eval_loss += loss.item()
            
            # Calculate IoU
            mean_iou, class_ious = iou_score(outputs, masks)
            all_ious.append(mean_iou.item())
            
            # Store IoU per class
            for i, class_name in enumerate(Config.CLASSES):
                all_class_ious[class_name].append(class_ious[i].item())
    
    # Average metrics
    eval_loss /= len(data_loader)
    valid_iou = np.mean(all_ious)
    
    class_iou_values = {}
    for class_name in Config.CLASSES:
        class_iou_values[class_name] = np.mean(all_class_ious[class_name])
    
    # Print results
    print("Evaluation:")
    print(f"Loss: {eval_loss:.4f}")
    print(f"Average IoU: {valid_iou:.4f}")
    for class_name, iou in class_iou_values.items():
        print(f"  - IoU {class_name}: {iou:.4f}")
    
    return eval_loss, valid_iou, class_iou_values

# Evaluate the best model on the validation set
best_model_path = os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth')

if os.path.exists(best_model_path):
    print("Evaluating the best model on the validation set...")
    eval_loss, avg_iou, class_ious = evaluate_model(best_model_path, valid_loader, Config.DEVICE)
else:
    print("Best model not found. Make sure you've trained the model first.")

## Prediction and Visualization Function

This code runs the trained model on a single image, compares the binary masks obtained at several thresholds (0.2 to 0.6) against the raw prediction maps, and displays color-coded segmentation masks overlaid on the original image.
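
To make the effect of the threshold concrete, here is a minimal sketch that binarizes a small probability map at the same thresholds and counts the foreground pixels that survive. The probability values are made up for illustration.

```python
import numpy as np

# Hypothetical 1x5 probability map for a single class
prob_map = np.array([[0.15, 0.25, 0.45, 0.55, 0.95]])
thresholds = [0.2, 0.3, 0.4, 0.5, 0.6]

# One binary mask per threshold; higher thresholds keep fewer pixels
foreground_counts = {}
for thresh in thresholds:
    binary = (prob_map > thresh).astype(np.uint8)
    foreground_counts[thresh] = int(binary.sum())

print(foreground_counts)
```

Lower thresholds tend to recover more of a thin structure such as a clock hand at the cost of more false positives, which is why it can be useful to inspect several values side by side.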

In [None]:
def predict_and_visualize(model_path, image_path, threshold=0.5):
    """
    Makes a prediction and visualizes the results with different thresholds.
    
    This function loads a trained model, processes an input image, and creates
    two visualization plots:
    1. A comparison of prediction probabilities and binary masks at different thresholds
    2. Color-coded overlay visualizations of each segmentation class
    
    Args:
        model_path (str): Path to the saved model weights.
        image_path (str): Path to the input image.
        threshold (float, optional): Default threshold for binarization. Defaults to 0.5.
        
    Returns:
        tuple: (pred_np, pred_binary) - Raw prediction probabilities and binary masks.
    """
    # Load the model
    model = build_model()
    model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
    model.to(Config.DEVICE)
    model.eval()
    
    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not read image: {image_path}")
        # Return a pair so callers that unpack two values do not fail
        return None, None
    
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Transform the image
    transform = A.Compose([
        A.Resize(Config.IMG_SIZE, Config.IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.Lambda(mask=lambda x, **kwargs: (x > 0.5).astype(np.float32)),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    
    transformed = transform(image=image)
    image_tensor = transformed['image'].unsqueeze(0).to(Config.DEVICE)
    
    # Make prediction
    with torch.no_grad():
        pred = model(image_tensor)
    
    # Convert to numpy and test different thresholds
    pred_np = pred.squeeze().cpu().numpy()
    
    # Create a range of thresholds to test
    thresholds = [0.2, 0.3, 0.4, 0.5, 0.6]
    pred_binary_per_threshold = []
    
    for thresh in thresholds:
        pred_binary = (pred_np > thresh).astype(np.uint8)
        pred_binary_per_threshold.append(pred_binary)
    
    # Visualize results with different thresholds
    fig, axes = plt.subplots(len(thresholds) + 1, Config.NUM_CLASSES + 1, 
                           figsize=(5*(Config.NUM_CLASSES + 1), 5*(len(thresholds) + 1)))
    
    # Overall title
    fig.suptitle('Predictions with Different Thresholds', fontsize=20)
    
    # First row: Original image and probabilities
    axes[0, 0].imshow(image)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    
    for i, class_name in enumerate(Config.CLASSES):
        # Show probability map
        axes[0, i+1].imshow(pred_np[i], cmap='viridis')
        axes[0, i+1].set_title(f'Prob: {class_name}')
        axes[0, i+1].axis('off')
    
    # Following rows: Binary masks with different thresholds
    # (the loop variable is named "thresh" so it does not shadow the "threshold" argument)
    for t_idx, thresh in enumerate(thresholds):
        axes[t_idx + 1, 0].imshow(image)
        axes[t_idx + 1, 0].set_title(f'Threshold: {thresh}')
        axes[t_idx + 1, 0].axis('off')
        
        for i, class_name in enumerate(Config.CLASSES):
            # Show binary mask for this threshold
            axes[t_idx + 1, i+1].imshow(pred_binary_per_threshold[t_idx][i], cmap='gray')
            axes[t_idx + 1, i+1].set_title(f'{class_name} (threshold={thresh})')
            axes[t_idx + 1, i+1].axis('off')
    
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjustment for general title
    
    # Save visualization
    output_name = os.path.basename(image_path).split('.')[0]
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'pred_{output_name}_thresholds.png'))
    plt.show()
    
    # Also visualize overlays of the predictions at the requested threshold (default 0.5)
    fig, axes = plt.subplots(1, Config.NUM_CLASSES + 1, figsize=(5*(Config.NUM_CLASSES + 1), 5))
    
    # Overlay on original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    pred_binary = (pred_np > threshold).astype(np.uint8)
    
    # Semi-transparent RGBA color for each class, in Config.CLASSES order
    colors = [
        (255, 0, 0, 128),    # red
        (0, 255, 0, 128),    # green
        (0, 0, 255, 128),    # blue
        (255, 255, 0, 128),  # yellow
    ]
    
    for i, class_name in enumerate(Config.CLASSES):
        # Binary mask
        mask = pred_binary[i]
        
        # Paint the class color where the mask is set; other pixels stay fully transparent
        colored_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
        colored_mask[mask == 1] = colors[i % len(colors)]
        
        # Overlay the colored mask
        axes[i+1].imshow(image)
        axes[i+1].imshow(colored_mask, alpha=0.5)
        axes[i+1].set_title(f'Overlay: {class_name}')
        axes[i+1].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'pred_{output_name}_overlay.png'))
    plt.show()
    
    return pred_np, pred_binary

# Test prediction on some test images if they exist
test_images = glob(os.path.join(Config.TEST_PATH, '*/*_Background.*'))
if test_images:
    num_samples = min(3, len(test_images))
    sample_images = random.sample(test_images, num_samples)
    
    for img_path in sample_images:
        print(f"\nPrediction for: {os.path.basename(img_path)}")
        pred_np, pred_binary = predict_and_visualize(
            os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth'),
            img_path
        )

## Prediction Export for Clinical Use

This code saves the model's predictions for every test image as one transparent PNG mask per class, resized back to the original image size. Thin structures (contours and clock hands) additionally receive a morphological closing to bridge small gaps.
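
The transparent-mask encoding can be sketched in isolation: the color channels are filled with white and the binary mask is stored in the alpha channel, so only predicted pixels are visible. The helper name `mask_to_rgba` is hypothetical and used only for this example.

```python
import numpy as np

def mask_to_rgba(binary_mask):
    """Hypothetical helper: white RGBA image whose alpha channel is the mask."""
    h, w = binary_mask.shape
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[..., :3] = 255                                       # white color channels
    rgba[..., 3] = (binary_mask > 0).astype(np.uint8) * 255   # opaque where the mask is set
    return rgba

mask = np.array([[0, 1],
                 [1, 0]], dtype=np.uint8)
rgba = mask_to_rgba(mask)
print(rgba.shape, rgba[..., 3])
```

Because every color channel is white, the channel order expected by the image writer does not affect the result.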

In [None]:
def save_predictions(model_path, output_dir, threshold=0.5, test_dir=None):
    """
    Saves model predictions as transparent PNG masks.
    
    This function processes test images through the model and saves each class
    prediction as a separate PNG file with transparency, optimized for clinical use
    and further analysis.
    
    Args:
        model_path (str): Path to the saved model weights.
        output_dir (str): Directory to save prediction masks.
        threshold (float, optional): Threshold for binarization. Defaults to 0.5.
        test_dir (str, optional): Directory containing test images. Defaults to None.
    """
    # If test directory is not specified, use the configured one
    if test_dir is None:
        test_dir = Config.TEST_PATH
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Load the model
    model = build_model()
    model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
    model.to(Config.DEVICE)
    model.eval()
    
    # Find test images
    test_images = glob(os.path.join(test_dir, '*/*_Background.*'))
    
    if not test_images:
        print("No test images to process.")
        return
    
    # Transformation for model input
    transform = A.Compose([
        A.Resize(Config.IMG_SIZE, Config.IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.Lambda(mask=lambda x, **kwargs: (x > 0.5).astype(np.float32)),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    
    # Process each image
    for img_path in tqdm(test_images, desc="Saving predictions"):
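        # Base name for predictions: file name without the "_Background.<ext>" suffix,
        # so identifiers that themselves contain underscores are kept whole
        base_name = os.path.basename(img_path).rsplit('_Background', 1)[0]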
        
        # Create directory for this set of predictions
        image_output_dir = os.path.join(output_dir, base_name)
        os.makedirs(image_output_dir, exist_ok=True)
        
        # Load the image
        image = cv2.imread(img_path)
        if image is None:
            print(f"Could not read image: {img_path}")
            continue
        
        # Save the original image
        cv2.imwrite(os.path.join(image_output_dir, f"{base_name}_background.png"), image)
        
        # Preprocess for the model
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transformed = transform(image=image_rgb)
        image_tensor = transformed['image'].unsqueeze(0).to(Config.DEVICE)
        
        # Make prediction
        with torch.no_grad():
            pred = model(image_tensor)
        
        # Process the prediction
        pred_np = pred.squeeze().cpu().numpy()
        
        # Original size, used below to resize each class prediction back
        h, w = image.shape[:2]
        
        # Save each class as a transparent RGBA image
        for i, class_name in enumerate(Config.CLASSES):
            # Resize and apply threshold
            pred_class = cv2.resize(pred_np[i], (w, h), interpolation=cv2.INTER_NEAREST)
            pred_binary = (pred_class > threshold).astype(np.uint8) * 255
            
            # Apply enhancement for thin lines if it's a class with thin lines
            if class_name in ['contour', 'hands']:
                # Use morphological operations to enhance lines
                if class_name == 'contour':
                    kernel_size = 3
                else:  # hands
                    kernel_size = 2
                    
                kernel = np.ones((kernel_size, kernel_size), np.uint8)
                pred_binary = cv2.morphologyEx(pred_binary, cv2.MORPH_CLOSE, kernel)
            
            # Create an RGBA image with transparency
            # Use white for color (255, 255, 255) and binary value for alpha
            rgba = np.zeros((h, w, 4), dtype=np.uint8)
            rgba[..., 0:3] = 255  # White
            rgba[..., 3] = pred_binary  # Alpha channel
            
            # Save the mask as PNG with transparency
            output_path = os.path.join(image_output_dir, f"{base_name}_{class_name}.png")
            cv2.imwrite(output_path, rgba)
    
    print(f"Predictions saved to: {output_dir}")

# Save predictions
output_predictions_dir = os.path.join(Config.CHECKPOINT_PATH, 'predictions')
save_predictions(
    os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth'),
    output_predictions_dir
)

## Error Analysis and Post-Processing

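This code analyzes prediction errors on random validation samples, reporting the IoU, false positives, and false negatives of each class together with a color-coded error map. It then defines class-specific morphological post-processing to clean up the predicted masks and compares the masks before and after that step.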
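
The false-positive and false-negative breakdown can be sketched on a toy pair of masks. The helper name `error_breakdown` is hypothetical; the color convention (green for false positives, red for false negatives) follows the one used in this section.

```python
import numpy as np

def error_breakdown(pred_mask, gt_mask):
    """Hypothetical helper: count FP/FN pixels and build an RGB error map."""
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)
    false_pos = pred & ~gt   # predicted but not in the ground truth
    false_neg = ~pred & gt   # in the ground truth but missed
    error_map = np.zeros(pred.shape + (3,), dtype=np.float32)
    error_map[false_pos] = [0, 1, 0]  # green
    error_map[false_neg] = [1, 0, 0]  # red
    return int(false_pos.sum()), int(false_neg.sum()), error_map

pred = np.array([[1, 1, 0],
                 [0, 0, 0]])
gt = np.array([[1, 0, 0],
               [0, 1, 0]])
fp_count, fn_count, error_map = error_breakdown(pred, gt)
print(fp_count, fn_count)
```

Correctly classified pixels stay black in the error map, so systematic patterns such as missed hand tips stand out.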

In [None]:
def analyze_prediction_errors(model_path, data_loader, device, num_samples=5):
    """
    Analyzes specific errors in predictions.
    
    This function evaluates model performance on random samples by calculating false 
    positives, false negatives, and IoU scores, with visualizations to identify error patterns.
    
    Args:
        model_path (str): Path to the saved model weights.
        data_loader (DataLoader): DataLoader containing evaluation data.
        device (torch.device): Device to run the model on.
        num_samples (int, optional): Number of random samples to analyze. Defaults to 5.
    """
    # Load the model
    model = build_model()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    
    # Get some random samples
    indices = random.sample(range(len(data_loader.dataset)), min(num_samples, len(data_loader.dataset)))
    
    # For each sample, calculate specific errors
    with torch.no_grad():
        for idx in indices:
            # Get image and mask
            image, mask = data_loader.dataset[idx]
            image_tensor = image.unsqueeze(0).to(device)
            mask_tensor = mask.unsqueeze(0).to(device)
            
            # Make prediction
            pred = model(image_tensor)
            
            # Calculate errors for each class
            for c, class_name in enumerate(Config.CLASSES):
                gt_mask = mask_tensor[0, c].cpu().numpy()
                pred_mask = (pred[0, c] > 0.5).cpu().numpy()
                
                # Calculate metrics
                intersection = np.logical_and(gt_mask, pred_mask).sum()
                union = np.logical_or(gt_mask, pred_mask).sum()
                iou = intersection / union if union > 0 else 0
                
                # False positives and negatives
                false_positives = np.logical_and(pred_mask, np.logical_not(gt_mask)).sum()
                false_negatives = np.logical_and(np.logical_not(pred_mask), gt_mask).sum()
                
                print(f"Sample {idx}, Class {class_name}:")
                print(f"  IoU: {iou:.4f}")
                print(f"  False positives: {false_positives} pixels")
                print(f"  False negatives: {false_negatives} pixels")
                
                # Visualize comparison
                fig, axes = plt.subplots(1, 4, figsize=(20, 5))
                
                # Denormalize image
                img_np = image.permute(1, 2, 0).cpu().numpy()
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                img_np = std * img_np + mean
                img_np = np.clip(img_np, 0, 1)
                
                # Visualize original image
                axes[0].imshow(img_np)
                axes[0].set_title('Original Image')
                axes[0].axis('off')
                
                # Ground truth mask
                axes[1].imshow(gt_mask, cmap='gray')
                axes[1].set_title(f'Ground Truth: {class_name}')
                axes[1].axis('off')
                
                # Prediction
                axes[2].imshow(pred_mask, cmap='gray')
                axes[2].set_title(f'Prediction: {class_name}')
                axes[2].axis('off')
                
                # Errors: Green = false positives, Red = false negatives
                error_map = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
                error_map[np.logical_and(pred_mask, np.logical_not(gt_mask))] = [0, 1, 0]  # Green for FP
                error_map[np.logical_and(np.logical_not(pred_mask), gt_mask)] = [1, 0, 0]  # Red for FN
                
                axes[3].imshow(error_map)
                axes[3].set_title('Errors (Green: FP, Red: FN)')
                axes[3].axis('off')
                
                plt.tight_layout()
                plt.show()

# Analyze errors in some samples
best_model_path = os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth')
if os.path.exists(best_model_path):
    print("Analyzing prediction errors...")
    analyze_prediction_errors(best_model_path, valid_loader, Config.DEVICE, num_samples=3)
else:
    print("Best model not found. Make sure you've trained the model first.")

def apply_post_processing(mask, class_name):
    """
    Applies additional post-processing to predicted masks.
    
    This function uses class-specific morphological operations and filtering
    to enhance prediction quality and remove artifacts.
    
    Args:
        mask (numpy.ndarray): Input binary mask (values 0-1 or 0-255).
        class_name (str): Class name to determine appropriate processing.
        
    Returns:
        numpy.ndarray: Processed mask with the same value range as the input.
    """
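    # Convert mask to uint8 format (0-255) for OpenCV.
    # Only 0-1 masks are scaled; scaling a 0-255 mask would overflow uint8.
    if mask.max() <= 1.0:
        mask_cv = (mask * 255).astype(np.uint8)
    else:
        mask_cv = mask.astype(np.uint8)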
    
    if class_name == 'contour':
        # For contours, apply morphological closing
        kernel = np.ones((3, 3), np.uint8)
        processed = cv2.morphologyEx(mask_cv, cv2.MORPH_CLOSE, kernel, iterations=1)
        
        # Remove small components (noise)
        num_labels, labels = cv2.connectedComponents(processed)
        min_size = 20  # Threshold to consider noise
        
        # Keep only components larger than threshold
        for label in range(1, num_labels):
            component_mask = (labels == label).astype(np.uint8)
            if np.sum(component_mask) < min_size:
                processed[component_mask == 1] = 0
                
    elif class_name == 'hands':
        # For hands, emphasize thin lines
        kernel_line = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
        processed = cv2.morphologyEx(mask_cv, cv2.MORPH_CLOSE, kernel_line)
        
        # Also apply median filter to smooth
        processed = cv2.medianBlur(processed, 3)
        
    elif class_name == 'numbers':
        # For numbers, preserve more details
        kernel = np.ones((2, 2), np.uint8)
        processed = cv2.morphologyEx(mask_cv, cv2.MORPH_OPEN, kernel, iterations=1)
        
    else:  # 'entire'
        # For the entire figure, use opening to remove noise
        kernel = np.ones((3, 3), np.uint8)
        processed = cv2.morphologyEx(mask_cv, cv2.MORPH_OPEN, kernel, iterations=1)
    
    # Re-binarize to ensure values 0 and 255
    _, processed = cv2.threshold(processed, 127, 255, cv2.THRESH_BINARY)
    
    # Normalize to 0-1 range if input was in that range
    if mask.max() <= 1.0:
        return processed / 255.0
    else:
        return processed

def show_post_processing_comparison(model_path, image_path):
    """
    Shows a comparison between original and post-processed predictions.
    
    This function visualizes the effect of class-specific post-processing
    techniques on model predictions.
    
    Args:
        model_path (str): Path to the saved model weights.
        image_path (str): Path to the input image.
        
    Returns:
        tuple: (pred_binary, processed_masks) - Original binary predictions and
            post-processed masks.
    """
    # Load the model
    model = build_model()
    model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
    model.to(Config.DEVICE)
    model.eval()
    
    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not read image: {image_path}")
        # Return a pair so callers that unpack two values do not fail
        return None, None
    
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Transform the image
    transform = A.Compose([
        A.Resize(Config.IMG_SIZE, Config.IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.Lambda(mask=lambda x, **kwargs: (x > 0.5).astype(np.float32)),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    
    transformed = transform(image=image)
    image_tensor = transformed['image'].unsqueeze(0).to(Config.DEVICE)
    
    # Make prediction
    with torch.no_grad():
        pred = model(image_tensor)
    
    # Convert to numpy
    pred_np = pred.squeeze().cpu().numpy()
    
    # Apply threshold for binary predictions
    pred_binary = (pred_np > 0.5).astype(np.float32)
    
    # Apply post-processing
    processed_masks = []
    for i, class_name in enumerate(Config.CLASSES):
        processed = apply_post_processing(pred_binary[i], class_name)
        processed_masks.append(processed)
    
    # Visualize comparison
    fig, axes = plt.subplots(2, Config.NUM_CLASSES, figsize=(20, 10))
    
    for i, class_name in enumerate(Config.CLASSES):
        # Original prediction
        axes[0, i].imshow(pred_binary[i], cmap='gray')
        axes[0, i].set_title(f'Original: {class_name}')
        axes[0, i].axis('off')
        
        # Post-processed prediction
        axes[1, i].imshow(processed_masks[i], cmap='gray')
        axes[1, i].set_title(f'Post-processed: {class_name}')
        axes[1, i].axis('off')
    
    plt.tight_layout()
    
    # Save visualization
    output_name = os.path.basename(image_path).split('.')[0]
    plt.savefig(os.path.join(Config.CHECKPOINT_PATH, f'post_processing_{output_name}.png'))
    plt.show()
    
    return pred_binary, processed_masks

# Test post-processing on some images
test_images = glob(os.path.join(Config.TEST_PATH, '*/*_Background.*'))
if test_images:
    num_samples = min(2, len(test_images))
    sample_images = random.sample(test_images, num_samples)
    
    for img_path in sample_images:
        print(f"\nPost-processing for: {os.path.basename(img_path)}")
        pred_binary, processed_masks = show_post_processing_comparison(
            os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth'),
            img_path
        )

## Model Report Generation

This code writes a plain-text report documenting the segmentation model's architecture, training parameters, final metrics, and preprocessing steps, then prints its content.

In [None]:
def generate_model_report(model_path, output_path=None, training_duration=None):
    """
    Generates a text report with all relevant information about the model.
    
    This function creates a comprehensive document containing details about the model
    architecture, training process, performance metrics, and preprocessing techniques.
    
    Args:
        model_path (str): Path to the trained model.
        output_path (str, optional): Path to save the report. 
            Defaults to CHECKPOINT_PATH/model_report.txt.
        training_duration (float, optional): Duration of training in seconds, if known.
            
    Returns:
        str: Content of the generated report.
    """
    if output_path is None:
        output_path = os.path.join(Config.CHECKPOINT_PATH, 'model_report.txt')
    
    # Check if checkpoint exists to load additional data
    checkpoint_exists = False
    checkpoint_data = {}
    
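    # Find the latest checkpoint by epoch number (a plain string sort would
    # place "checkpoint_epoch_9" after "checkpoint_epoch_10")
    checkpoint_files = glob(os.path.join(Config.CHECKPOINT_PATH, 'checkpoint_epoch_*.pth'))
    if checkpoint_files:
        def checkpoint_epoch(path):
            match = re.search(r'checkpoint_epoch_(\d+)', os.path.basename(path))
            return int(match.group(1)) if match else -1
        
        latest_checkpoint = max(checkpoint_files, key=checkpoint_epoch)
        try:
            checkpoint_data = torch.load(latest_checkpoint, map_location='cpu')
            checkpoint_exists = True
        except Exception as e:
            print(f"Could not load checkpoint {latest_checkpoint}: {e}")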
    
    # Check if we can load training history
    history = {}
    if checkpoint_exists and 'history' in checkpoint_data:
        history = checkpoint_data['history']
    
    # Load the model to get information
    model_info = {}
    try:
        model = build_model()
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model_type = model.__class__.__name__
        num_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        model_info = {
            'model_type': model_type,
            'num_params': num_params,
            'trainable_params': trainable_params
        }
    except Exception as e:
        print(f"Could not load model {model_path} for analysis: {e}")
    
    # Generate final metrics if history exists
    final_metrics = {}
    if history:
        if 'train_loss' in history and history['train_loss']:
            final_metrics['final_train_loss'] = history['train_loss'][-1]
        if 'valid_loss' in history and history['valid_loss']:
            final_metrics['final_valid_loss'] = history['valid_loss'][-1]
        if 'valid_iou' in history and history['valid_iou']:
            final_metrics['final_valid_iou'] = history['valid_iou'][-1]
        if 'class_iou' in history:
            for class_name, iou_values in history['class_iou'].items():
                if iou_values:
                    final_metrics[f'final_iou_{class_name}'] = iou_values[-1]
    
    # Function to format duration
    def format_duration(seconds):
        if seconds is None:
            return "Not available"
        
        hours, remainder = divmod(seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        
        if hours > 0:
            return f"{int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds"
        elif minutes > 0:
            return f"{int(minutes)} minutes, {int(seconds)} seconds"
        else:
            return f"{int(seconds)} seconds"
    
    # Function to safely format numeric values
    def format_number(value, decimal_places=6):
        if isinstance(value, (int, float)):
            return f"{value:.{decimal_places}f}"
        return "Not available"
    
    # Write the report
    with open(output_path, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("CLOCK DRAWING SEGMENTATION MODEL REPORT\n")
        f.write("=" * 80 + "\n\n")
        
        # General information
        f.write("GENERAL INFORMATION\n")
        f.write("-" * 80 + "\n")
        f.write(f"Generation date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model path: {model_path}\n\n")
        
        # Model architecture
        f.write("MODEL ARCHITECTURE\n")
        f.write("-" * 80 + "\n")
        f.write(f"Model type: {model_info.get('model_type', 'Not available')}\n")
        f.write(f"Encoder: {Config.ENCODER}\n")
        f.write(f"Pre-trained weights: {Config.ENCODER_WEIGHTS}\n")
        f.write(f"Number of classes: {Config.NUM_CLASSES}\n")
        f.write(f"Classes: {', '.join(Config.CLASSES)}\n")
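        # The ',' format only works on numbers, so fall back to text if the model failed to load
        num_params = model_info.get('num_params')
        trainable_params = model_info.get('trainable_params')
        total_text = f"{num_params:,}" if num_params is not None else "Not available"
        trainable_text = f"{trainable_params:,}" if trainable_params is not None else "Not available"
        f.write(f"Total parameters: {total_text}\n")
        f.write(f"Trainable parameters: {trainable_text}\n\n")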
        
        # Training parameters
        f.write("TRAINING PARAMETERS\n")
        f.write("-" * 80 + "\n")
        f.write(f"Epochs: {Config.EPOCHS}\n")
        
        # Safely read the number of completed epochs from the checkpoint
        epochs_completed = "Not available"
        if checkpoint_exists and 'epoch' in checkpoint_data:
            epochs_completed = checkpoint_data['epoch'] + 1
        f.write(f"Completed epochs: {epochs_completed}\n")
        
        f.write(f"Batch Size: {Config.BATCH_SIZE}\n")
        f.write(f"Initial Learning Rate: {Config.LEARNING_RATE}\n")
        
        # Safely read the final learning rate from the optimizer state
        lr_final = "Not available"
        if checkpoint_exists and 'optimizer_state_dict' in checkpoint_data:
            optimizer_dict = checkpoint_data['optimizer_state_dict']
            if 'param_groups' in optimizer_dict and optimizer_dict['param_groups']:
                if 'lr' in optimizer_dict['param_groups'][0]:
                    lr_final = format_number(optimizer_dict['param_groups'][0]['lr'])
        f.write(f"Final Learning Rate: {lr_final}\n")
        
        f.write(f"Training duration: {format_duration(training_duration)}\n\n")
        
        # Final metrics
        f.write("FINAL METRICS\n")
        f.write("-" * 80 + "\n")
        f.write(f"Final training loss: {format_number(final_metrics.get('final_train_loss', None))}\n")
        f.write(f"Final validation loss: {format_number(final_metrics.get('final_valid_loss', None))}\n")
        f.write(f"Average validation IoU: {format_number(final_metrics.get('final_valid_iou', None))}\n\n")
        
        # IoU by class
        f.write("IoU BY CLASS\n")
        for class_name in Config.CLASSES:
            iou_key = f'final_iou_{class_name}'
            iou_value = final_metrics.get(iou_key, None)
            f.write(f"- {class_name}: {format_number(iou_value)}\n")
        f.write("\n")
        
        # Preprocessing
        f.write("APPLIED PREPROCESSING\n")
        f.write("-" * 80 + "\n")
        f.write("- Alpha channel extraction from transparent PNG masks\n")
        f.write("- Strict mask binarization (values 0 or 1)\n")
        f.write("- Thin line enhancement through controlled dilation for contours and hands\n")
        f.write("- INTER_NEAREST interpolation for resizing with fine detail preservation\n")
        f.write("- Image normalization with mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225)\n\n")
        
        # Other observations
        f.write("ADDITIONAL OBSERVATIONS\n")
        f.write("-" * 80 + "\n")
        f.write("- Combined loss function (Dice Loss + Focal Loss) with class weighting\n")
        f.write("- Higher weighting for thin line classes (hands, contour)\n")
        f.write("- Xavier weight initialization to improve convergence\n")
        f.write("- Early stopping with patience of 10 epochs\n")
        f.write("- Class-specific post-processing to improve final results\n")
        
    print(f"Model report generated at: {output_path}")
    
    # Read the report back so its content can be returned
    with open(output_path, 'r') as f:
        report_content = f.read()
    
    return report_content

# Generate the report
best_model_path = os.path.join(Config.CHECKPOINT_PATH, 'best_model.pth')
if os.path.exists(best_model_path):
    # If you have information about the training duration, you can pass it here
    # For example: training_duration = end_time - start_time
    report = generate_model_report(best_model_path)
    print("\nReport content:")
    print(report)
else:
    print("Best model not found. Make sure you've trained the model first.")