In [None]:
import os
import re
import torch
import shutil
import cv2
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from datetime import datetime
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [None]:
# Names (not full paths) of every entry in the raw image folder, in the
# arbitrary order returned by os.listdir.
imgs_list = os.listdir('data/images')

In [None]:
# Tags dropped from the label set:
# Example, Marks, Time, Date, Address, Closing, Author, Source, Matrices

# Folder that holds the per-region label masks
labels_folder = 'data/labels'

# Tags whose mask files are moved out of the labels folder
classes_to_remove = [
    'Example',
    'Marks',
    'Time',
    'Date',
    'Address',
    'Closing',
    'Author',
    'Source',
    'Matrices',
]

# Removed masks are parked here instead of being deleted
backup_folder = 'removed_labels'
if not os.path.exists(backup_folder):
    os.makedirs(backup_folder)

# Label file names look like: task-34-annotation-10-by-1-tag-Author-0.png
# Group 1 captures the tag, which may contain letters, whitespace and hyphens
# (e.g. "Font - Bold"); the trailing "-<number>.png" is not part of the tag.
tag_pattern = re.compile(r'task-\d+-annotation-\d+-by-\d+-tag-([A-Za-z\s-]+)-\d+\.png')


def should_remove(filename):
    """Tell whether a label file carries one of the tags in ``classes_to_remove``.

    Args:
        filename (str): Bare file name of a label mask.

    Returns:
        bool: True when the name matches ``tag_pattern`` and its tag is either
        exactly a removable class or that class followed by a space or a
        hyphen; False otherwise (including names that do not match the
        pattern at all).
    """
    parsed = tag_pattern.match(filename)
    if parsed is None:
        return False

    # Tag portion of the name, e.g. "Author" or "Font - Bold"
    tag = parsed.group(1)

    return any(
        tag == cls or tag.startswith((cls + ' ', cls + '-'))
        for cls in classes_to_remove
    )


# Counters for the final report
total_files = 0
removed_files = 0
files_kept = 0

# Names of the masks that were moved, listed in the report below
removed_file_list = []

# Walk the labels folder; only .png masks are considered
for filename in os.listdir(labels_folder):
    if not filename.endswith('.png'):
        continue

    total_files += 1
    file_path = os.path.join(labels_folder, filename)

    if not should_remove(filename):
        files_kept += 1
        continue

    removed_files += 1
    removed_file_list.append(filename)

    print(f"Moving to backup: {filename}")
    shutil.move(file_path, os.path.join(backup_folder, filename))

# Report
print(
    "\nSummary:\n"
    f"Total files processed: {total_files}\n"
    f"Files removed: {removed_files}\n"
    f"Files kept: {files_kept}"
)

if removed_file_list:
    print("\nFiles that were moved to backup:")
    for file in sorted(removed_file_list):
        print(f"  - {file}")

## Data Loader

In [None]:
class DocumentSegmentationDataset(Dataset):
    """Document images paired with one combined multi-class segmentation mask.

    Every per-region label mask that belongs to a document is painted into a
    single 2-D index mask in which each pixel holds the index of its class in
    ``class_names`` (0 = background). Label masks are expected to be named
    like ``task-31-annotation-4-by-1-tag-Header-0.png``.

    The class list is final once the constructor returns: tags found in
    ``labels_dir`` that are neither predefined nor excluded are appended
    up-front, so ``num_classes`` never changes while samples are being read
    (a model sized from ``num_classes`` stays consistent with the masks).
    """

    # "doc_31.jpg" -> document ID "31"
    _DOC_ID_RE = re.compile(r'doc_(\d+)')
    # Leading "task-<n>-annotation" part of a label file name
    _TASK_RE = re.compile(r'task-(\d+)-annotation')
    # Tag (class name) part of a label file name; the trailing "-<n>" is dropped
    _TAG_RE = re.compile(r'tag-([A-Za-z\s-]+)-\d+')

    def __init__(self, images_dir, labels_dir, transform=None, target_size=(512, 512)):
        """
        Args:
            images_dir (str): Path to the directory with document images
            labels_dir (str): Path to the directory with label mask images
            transform (callable, optional): Optional transform applied to the
                image tensor only (the mask is not transformed)
            target_size (tuple): ``(width, height)`` that images and masks are
                resized to, in the order ``cv2.resize`` expects
        """
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.target_size = target_size

        # Keep image_files and doc_ids index-aligned: an image whose name has
        # no "doc_<n>" part is left out instead of shifting every later ID.
        self.image_files = []
        self.doc_ids = []
        for img_file in sorted(os.listdir(images_dir)):
            if not img_file.endswith(('.jpg', '.png')):
                continue
            match = self._DOC_ID_RE.search(img_file)
            if match is None:
                continue
            self.image_files.append(img_file)
            self.doc_ids.append(match.group(1))

        # Define class names and assign index to each class
        self.class_names = [
            'Background',  # 0
            'Header',  # 1
            'Paragraph',  # 2
            'Page Number',  # 3
            'Footnotes',  # 4
            'Font - Bold',  # 5
            'Topic',  # 6
            'Caption',  # 7
            'Image',  # 8
            'Topic - Level1',  # 9
            'Topic - Level2',  # 10
            'Topic - Level3',  # 11
            'Diagram',  # 12
            'Table',  # 13
            'Mathematical Expression',  # 14
            'List',  # 15
            'Footer',  # 16
            'Font - Italic',  # 17
            'Poem',  # 18
            'Charts',  # 19
        ]

        # Tags that never become a class
        self.excluded_classes = [
            'Example', 'Marks', 'Time', 'Date', 'Address',
            'Closing', 'Author', 'Source', 'Matrices'
        ]

        # Create a mapping from class name to index
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.class_names)}

        # Register tags that are not predefined now, in sorted file order, so
        # the indices are deterministic and fixed before any sample is read.
        if os.path.isdir(labels_dir):
            for label_file in sorted(os.listdir(labels_dir)):
                if self._TASK_RE.match(label_file) is None:
                    continue
                class_name = self._label_class(label_file)
                if class_name is not None and class_name not in self.class_to_idx:
                    self.class_to_idx[class_name] = len(self.class_names)
                    self.class_names.append(class_name)

        self.num_classes = len(self.class_names)

    def _label_class(self, label_file):
        """Return the class name encoded in a label file name, or None to skip it.

        None is returned when the name has no ``tag-<name>-<n>`` part or when
        the tag contains one of ``excluded_classes`` as a substring.
        """
        match = self._TAG_RE.search(label_file)
        if match is None:
            return None

        class_name = match.group(1)
        if any(excluded in class_name for excluded in self.excluded_classes):
            return None
        return class_name

    def __len__(self):
        """Number of usable document images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one document and build its combined class-index mask.

        Returns:
            dict: ``image`` (float tensor, channels first, scaled to [0, 1]),
            ``mask`` (long tensor of shape ``(height, width)``), ``doc_id``
            (str) and ``image_path`` (str).

        Raises:
            FileNotFoundError: If the image file cannot be read by OpenCV.
        """
        # Load document image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize image to target size (cv2 takes the size as (width, height))
        image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_AREA)

        doc_id = self.doc_ids[idx]

        # Background-initialised mask; rows = height, columns = width, so it
        # matches the resized arrays for non-square target sizes as well.
        width, height = self.target_size
        mask = np.zeros((height, width), dtype=np.uint8)

        all_labels = os.listdir(self.labels_dir)

        # Collect the task numbers of every label file that mentions this
        # document ID.
        # NOTE(review): the substring test also hits files where the ID shows
        # up as the annotation or annotator number - confirm the naming scheme.
        task_numbers = set()
        for label_file in all_labels:
            if f"-{doc_id}-" in label_file:
                match = self._TASK_RE.search(label_file)
                if match:
                    task_numbers.add(match.group(1))

        # All masks that belong to those tasks (sorted for a stable paint order
        # between runs; set iteration order of strings is not stable)
        doc_labels = []
        for task_number in sorted(task_numbers):
            doc_labels.extend(
                f for f in all_labels if f.startswith(f"task-{task_number}-annotation"))

        # Paint each label mask into the combined mask; later files overwrite
        # earlier ones where regions overlap.
        for label_file in doc_labels:
            class_name = self._label_class(label_file)
            if class_name is None:
                continue

            class_idx = self.class_to_idx.get(class_name)
            if class_idx is None:
                # Tag appeared after construction; skip it so the class list
                # (and with it the model's output size) stays unchanged.
                continue

            # Load label mask
            label_path = os.path.join(self.labels_dir, label_file)
            label_mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label_mask is None:
                continue

            # Nearest-neighbour keeps the mask free of interpolated grey values
            label_mask = cv2.resize(label_mask, self.target_size, interpolation=cv2.INTER_NEAREST)

            # Binary threshold to ensure mask is binary (0 or 255)
            _, label_mask = cv2.threshold(label_mask, 127, 255, cv2.THRESH_BINARY)

            mask[label_mask > 0] = class_idx

        # Convert image and mask to tensors
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()

        # Apply additional transformations if specified (image only)
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'mask': mask,
            'doc_id': doc_id,
            'image_path': img_path
        }

    def get_class_names(self):
        """Return the list of class names, ordered by class index."""
        return self.class_names

In [None]:
def visualize_sample(dataset, idx):
    """Show one dataset sample: the image next to its colour-coded mask.

    Args:
        dataset: Dataset whose items are dicts with ``image`` (channels-first
            tensor), ``mask`` (class-index tensor) and ``doc_id``, and which
            exposes ``num_classes`` and ``class_names``.
        idx (int): Index of the sample to display.
    """
    sample = dataset[idx]
    doc_id = sample['doc_id']

    # Convert tensors to numpy for plotting (image becomes height x width x channels)
    image_np = sample['image'].permute(1, 2, 0).numpy()
    mask_np = sample['mask'].numpy()

    # One colour per class. plt.get_cmap is used because matplotlib.cm.get_cmap
    # is no longer available in current matplotlib releases.
    n_classes = dataset.num_classes
    cmap = plt.get_cmap('tab20', n_classes)
    # Integer input is looked up directly in the colour table: class i -> colour i
    colored_mask = cmap(mask_np)

    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))

    image_ax.imshow(image_np)
    image_ax.set_title(f"Document {doc_id}")
    image_ax.axis('off')

    mask_ax.imshow(colored_mask)
    mask_ax.set_title("Segmentation Mask")
    mask_ax.axis('off')

    # Colour bar with class names. Half-integer limits put class i in the
    # middle of colour band i, so integer ticks line up with the mask colours.
    norm = plt.Normalize(vmin=-0.5, vmax=n_classes - 0.5)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    cbar = fig.colorbar(sm, ax=mask_ax)
    cbar.set_ticks(np.arange(n_classes))
    cbar.set_ticklabels(dataset.class_names)

    fig.tight_layout()
    plt.show()

In [None]:
# Dataset locations
images_dir = 'data/images'  # Update with your actual path
labels_dir = 'data/labels'  # Update with your actual path

# Build the dataset and a shuffling loader over it
dataset = DocumentSegmentationDataset(images_dir, labels_dir, target_size=(256, 256))
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Summary of what was loaded
print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Class names: {dataset.class_names}")

# Look at the first sample
visualize_sample(dataset, 0)

# Sanity-check the loader: inspect the first batch only, then stop
for batch_idx, batch in enumerate(dataloader):
    images, masks, doc_ids = batch['image'], batch['mask'], batch['doc_id']

    print(f"Batch {batch_idx}:")
    print(f"  Image shape: {images.shape}")
    print(f"  Mask shape: {masks.shape}")
    print(f"  Document IDs: {doc_ids}")

    break

## U-Net Components

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by the second convolution.
        mid_channels (int, optional): Channels between the two convolutions;
            falls back to ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build both conv/norm/activation stages in order; padding=1 keeps the
        # spatial size, and the conv bias is omitted (bias=False).
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class Up(nn.Module):
    """Decoder stage: upsample, merge with the skip connection, then DoubleConv.

    Args:
        in_channels (int): Channels of the concatenated tensor fed to the
            DoubleConv block.
        out_channels (int): Channels produced by this stage.
        bilinear (bool): Upsample with bilinear interpolation when True,
            otherwise with a learned transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # Transposed conv halves the channels while doubling the size
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation keeps the channel count, so the DoubleConv block
            # narrows to `half` channels in its middle layer instead
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two."""
        upsampled = self.up(x1)

        # Dimensions 2 and 3 are height and width; pad the upsampled map so it
        # matches the skip tensor, splitting any odd gap with the extra pixel
        # on the right/bottom.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features first, upsampled features second, along the channel axis
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution that maps feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits


In [None]:
class UNet(nn.Module):
    """U-Net: four downsampling stages, four upsampling stages with skip
    connections, and a 1x1 output convolution.

    Args:
        n_channels (int): Channels of the input image.
        n_classes (int): Number of output logit channels.
        bilinear (bool): Passed to every ``Up`` stage; also halves the channel
            counts around the bottleneck when True.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        if bilinear:
            factor = 2
        else:
            factor = 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder (each stage consumes the matching encoder output)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Per-pixel class logits
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for ``x``."""
        # Encoder: keep every intermediate map for the skip connections
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Decoder: walk back up, fusing the encoder maps deepest-first
        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return self.outc(decoded)

## Training Loop

In [None]:
def _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss):
    """Bundle everything that is written to a checkpoint file into one dict."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }


def train_model(model, train_loader, val_loader, device, epochs=50, lr=0.001):
    """Train ``model`` with cross-entropy loss and Adam, validating every epoch.

    Files are written to ``checkpoints/<timestamp>/``: ``best_model.pth``
    whenever the validation loss improves, ``checkpoint_epoch_<n>.pth`` every
    5 epochs, and ``loss_plot.png`` at the end.

    Args:
        model: Network that maps an image batch to per-pixel class logits.
        train_loader: Loader yielding dicts with ``image`` and ``mask``.
        val_loader: Loader of the same form, used for validation.
        device: Device the batches are moved to (the model is expected to be
            on it already).
        epochs (int): Number of passes over ``train_loader``.
        lr (float): Initial learning rate for Adam.

    Returns:
        tuple: ``(model, history)`` where ``history`` has the per-epoch
        averages under ``train_loss`` and ``val_loss``.

    Raises:
        ValueError: If either loader yields no batches (the per-epoch average
            would otherwise fail with a division by zero after training).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    # Criterion (loss function) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Halve the learning rate after 5 epochs without validation improvement.
    # The deprecated `verbose` flag is not used; reductions are reported below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    # Track losses
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    # Create directory for model checkpoints
    checkpoint_dir = os.path.join('checkpoints', datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Training phase
        model.train()
        train_loss = 0

        with tqdm(train_loader, unit="batch", desc="Training") as train_pbar:
            for batch in train_pbar:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)

                optimizer.zero_grad()

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                train_pbar.set_postfix(loss=batch_loss)

        # Average over batches
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        # Validation phase (no gradients, eval-mode layers)
        model.eval()
        val_loss = 0

        with torch.no_grad():
            with tqdm(val_loader, unit="batch", desc="Validation") as val_pbar:
                for batch in val_pbar:
                    images = batch['image'].to(device)
                    masks = batch['mask'].to(device)

                    outputs = model(images)
                    batch_loss = criterion(outputs, masks).item()
                    val_loss += batch_loss
                    val_pbar.set_postfix(loss=batch_loss)

        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        # Update learning rate and report a reduction when one happens
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(
                _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss),
                os.path.join(checkpoint_dir, 'best_model.pth'))
            print(f"✅ Saved best model with validation loss: {val_loss:.4f}")

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(
                _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss),
                os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth'))
            print(f"💾 Saved checkpoint at epoch {epoch + 1}")

        # Print summary of losses only
        print(f'Summary - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Plot training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(checkpoint_dir, 'loss_plot.png'))
    plt.show()

    return model, {
        'train_loss': train_losses,
        'val_loss': val_losses
    }

In [None]:
def visualize_predictions(model, dataset, indices, device, save_dir=None):
    """
    Visualize model predictions for multiple samples.

    For every sample a 2x2 figure is drawn: the image, the ground-truth mask,
    the predicted mask and an overlay of the misclassified pixels, together
    with per-class IoU / F1 for the classes present in the ground truth.

    Args:
        model: The trained UNet model
        dataset: The dataset containing images and masks
        indices: List of sample indices to visualize (a single int is accepted too)
        device: The device to run inference on
        save_dir: Directory to save visualizations (optional); when given the
            figures are written as ``pred_<idx>.png`` instead of being shown
    """
    if isinstance(indices, (int, np.integer)):
        indices = [indices]

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()
    class_names = dataset.get_class_names()
    n_classes = len(class_names)

    # One colour per class. plt.get_cmap is used because matplotlib.cm.get_cmap
    # is no longer available in current matplotlib releases.
    cmap = plt.get_cmap('tab20', n_classes)
    # Half-integer limits centre class i on colour band i of the colour bar
    norm = plt.Normalize(vmin=-0.5, vmax=n_classes - 0.5)

    for idx in indices:
        sample = dataset[idx]
        image = sample['image'].unsqueeze(0).to(device)
        true_mask = sample['mask'].numpy()
        doc_id = sample['doc_id']

        # Get model prediction
        with torch.no_grad():
            output = model(image)
            # [0] removes only the batch axis (squeeze() would also drop any
            # real dimension of size 1)
            pred_mask = torch.argmax(output, dim=1)[0].cpu().numpy()

        # Channels-first tensor -> height x width x channels array
        image_np = image[0].cpu().permute(1, 2, 0).numpy()

        # Create figure with subplots
        fig, ax = plt.subplots(2, 2, figsize=(15, 12))

        # Original Image
        ax[0, 0].imshow(image_np)
        ax[0, 0].set_title(f"Document {doc_id}")
        ax[0, 0].axis('off')

        # Ground Truth Mask (integer input indexes the colour table directly)
        ax[0, 1].imshow(cmap(true_mask))
        ax[0, 1].set_title("Ground Truth")
        ax[0, 1].axis('off')

        # Prediction Mask
        ax[1, 0].imshow(cmap(pred_mask))
        ax[1, 0].set_title("Prediction")
        ax[1, 0].axis('off')

        # Difference Mask: only the wrong pixels are drawn over the image
        diff_mask = (pred_mask != true_mask).astype(np.int32)
        ax[1, 1].imshow(image_np)
        ax[1, 1].imshow(np.ma.masked_where(diff_mask == 0, diff_mask),
                        alpha=0.7, cmap='cool')
        ax[1, 1].set_title("Errors (Misclassified Pixels)")
        ax[1, 1].axis('off')

        # Add colorbar with one labelled tick per class
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(sm, cax=cbar_ax)
        cbar.set_ticks(np.arange(n_classes))
        cbar.set_ticklabels(class_names)

        # Calculate per-class metrics for this sample
        class_metrics = {}
        for c in range(n_classes):
            true_c = (true_mask == c)
            pred_c = (pred_mask == c)

            n_true = np.sum(true_c)
            if n_true == 0:
                # Skip if class not present in ground truth
                continue

            tp = np.sum(np.logical_and(true_c, pred_c))
            fp = np.sum(pred_c) - tp
            fn = n_true - tp
            n_union = np.sum(np.logical_or(true_c, pred_c))

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            iou = tp / n_union if n_union > 0 else 0

            class_metrics[class_names[c]] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'iou': iou,
                'pixels': n_true
            }

        # Add a text box with metrics
        metrics_text = "Per-class metrics:\n"
        for cls, metrics in class_metrics.items():
            metrics_text += f"{cls}: IoU={metrics['iou']:.2f}, F1={metrics['f1']:.2f}\n"

        props = dict(boxstyle='round', alpha=0.5)
        fig.text(0.02, 0.02, metrics_text, fontsize=10,
                 verticalalignment='bottom', bbox=props)

        # Leave the right-hand strip free for the colour bar
        fig.tight_layout(rect=[0, 0, 0.9, 1])
        fig.suptitle(f"Document Segmentation Analysis (Sample {idx})", fontsize=16, y=0.98)

        if save_dir:
            fig.savefig(os.path.join(save_dir, f"pred_{idx}.png"))
        else:
            plt.show()
        # Release the figure so long index lists do not accumulate open figures
        plt.close(fig)


In [None]:
def analyze_class_distribution(dataset):
    """
    Analyze the class distribution in the dataset.

    Counts the pixels of every class over all samples, plots the counts as a
    bar chart and derives inverse-frequency class weights.

    Args:
        dataset: The dataset to analyze; items must provide a ``mask`` tensor
            of class indices, and ``get_class_names()`` must list the classes

    Returns:
        tuple: ``(class_counts, class_weights)`` as arrays with one entry per
        class. Weights are proportional to ``1 / count`` and scaled so that
        the classes that actually occur average to 1; classes with no pixels
        get weight 0 (instead of an enormous weight that would push every
        real class towards 0 after normalisation).
    """
    class_names = dataset.get_class_names()
    n_classes = len(class_names)

    # Count pixels per class
    class_counts = np.zeros(n_classes)

    for i in tqdm(range(len(dataset)), desc="Analyzing classes"):
        mask = dataset[i]['mask'].numpy()
        # One pass per mask; indices at or beyond n_classes are not counted,
        # matching a per-class equality count over range(n_classes)
        counts = np.bincount(mask.ravel(), minlength=n_classes)
        class_counts += counts[:n_classes]

    # Plot class distribution
    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, class_counts)
    plt.xlabel('Classes')
    plt.ylabel('Pixel Count')
    plt.title('Class Distribution in Dataset')
    plt.xticks(rotation=45, ha='right')

    # Add counts on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                 f'{int(height)}', ha='center', va='bottom', rotation=0)

    plt.tight_layout()
    plt.show()

    # Inverse-frequency weights for a potential weighted loss
    present = class_counts > 0
    class_weights = np.zeros(n_classes)
    class_weights[present] = 1.0 / class_counts[present]
    if present.any():
        # Normalize so the weights of the occurring classes sum to their number
        class_weights *= present.sum() / class_weights.sum()

    print("\nClass weights for weighted loss:")
    for name, weight in zip(class_names, class_weights):
        print(f"{name}: {weight:.4f}")

    return class_counts, class_weights

In [None]:
def run_training(images_dir, labels_dir, target_size=(512, 512), batch_size=4, epochs=50, seed=42):
    """Build the dataset, split it 80/20, and train a fresh UNet on it.

    Args:
        images_dir (str): Directory with the document images.
        labels_dir (str): Directory with the label masks.
        target_size (tuple): Size every image and mask is resized to.
        batch_size (int): Batch size of both loaders.
        epochs (int): Number of training epochs.
        seed (int, optional): Seed for the train/validation split so the same
            samples end up in each part on every run; pass None for an
            unseeded split.

    Returns:
        The trained model.

    Raises:
        ValueError: If the dataset has fewer than two samples, in which case
            one of the two splits would be empty.
    """
    # Create dataset and split to train/val
    full_dataset = DocumentSegmentationDataset(images_dir, labels_dir, target_size=target_size)

    if len(full_dataset) < 2:
        raise ValueError(
            f"Need at least 2 samples for a train/validation split, found {len(full_dataset)}")

    # 80% train, remainder validation
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # Split dataset (seeded unless seed is None)
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], **split_kwargs)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model with one output channel per dataset class
    model = UNet(n_channels=3, n_classes=full_dataset.num_classes)
    model.to(device)

    # Print model summary
    print("Model: UNet")
    print("Input channels: 3")
    print(f"Output classes: {full_dataset.num_classes}")
    print(f"Class names: {full_dataset.get_class_names()}")

    # Train model (the loss history is not needed here)
    trained_model, _ = train_model(
        model, train_loader, val_loader, device, epochs=epochs)

    return trained_model

In [None]:
# Train on 256x256 inputs for 10 epochs and keep the resulting model
trained_model = run_training(
    'data/images',
    'data/labels',
    target_size=(256, 256),
    batch_size=12,
    epochs=10,
)

In [None]:
# Rebuild the architecture on CPU and restore the weights of a finished run
model = UNet(n_channels=3, n_classes=dataset.num_classes)
model.to("cpu")

checkpoint_path = 'checkpoints/20250316_221703/best_model.pth'
# Map the stored tensors to the CPU the model lives on. No `device` variable
# exists at notebook scope (it is local to the training functions), so
# referring to it here fails on a fresh kernel.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
# Fresh dataset instance for inspection, then plot the first three samples on CPU
full_dataset = DocumentSegmentationDataset(
    'data/images',
    'data/labels',
    target_size=(256, 256),
)
visualize_predictions(model, full_dataset, indices=[0, 1, 2], device='cpu')

In [None]:
def evaluate_model(model, test_loader, device, save_dir=None):
    """
    Evaluate model on test set and generate detailed metrics.

    A confusion matrix (rows = true class, columns = predicted class) is
    accumulated over all pixels, and precision, recall, F1 and IoU are derived
    from it per class. The mean metrics skip class 0 when there is more than
    one class; classes that never occur still take part in the means with a
    score of 0.

    Args:
        model: The trained UNet model (must expose ``n_classes``)
        test_loader: DataLoader for test data
        device: Device to run evaluation on
        save_dir: Directory to save results (optional); created if missing

    Returns:
        dict: Per-class arrays ``precision``, ``recall``, ``f1``, ``iou``; the
        scalars ``pixel_accuracy``, ``mean_precision``, ``mean_recall``,
        ``mean_f1``, ``mean_iou``; and the raw ``confusion_matrix`` counts.

    Raises:
        IndexError: If a mask contains a label outside ``[0, n_classes)``.
    """
    model.eval()
    n_classes = model.n_classes

    # Exact integer counts, kept on the CPU: float32 stops counting exactly
    # once a cell passes 2**24 pixels.
    confusion = torch.zeros((n_classes, n_classes), dtype=torch.long)

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images = batch['image'].to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            targets = batch['mask'].reshape(-1).long().cpu()
            predicted = preds.reshape(-1).cpu()

            if targets.numel() and (targets.min() < 0 or targets.max() >= n_classes):
                raise IndexError(
                    f"Mask labels must lie in [0, {n_classes}); "
                    f"found range [{int(targets.min())}, {int(targets.max())}]")

            # Encode each (true, predicted) pair as one cell number and count
            # all pixels of the batch in a single call instead of one by one.
            cells = targets * n_classes + predicted
            confusion += torch.bincount(
                cells, minlength=n_classes * n_classes).reshape(n_classes, n_classes)

    confusion_matrix = confusion.double()

    # Per-class metrics; eps keeps classes without any pixels at 0 instead of NaN
    eps = 1e-10
    tp = torch.diag(confusion_matrix)        # diagonal: correctly labelled pixels
    fp = confusion_matrix.sum(dim=0) - tp    # column sum minus true positives
    fn = confusion_matrix.sum(dim=1) - tp    # row sum minus true positives

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)

    # Calculate overall metrics (class 0 is left out when other classes exist)
    mean_precision = precision[1:].mean() if n_classes > 1 else precision.mean()
    mean_recall = recall[1:].mean() if n_classes > 1 else recall.mean()
    mean_f1 = f1[1:].mean() if n_classes > 1 else f1.mean()
    mean_iou = iou[1:].mean() if n_classes > 1 else iou.mean()

    # Calculate pixel accuracy
    pixel_accuracy = tp.sum() / confusion_matrix.sum()

    # Print results
    print("Overall Metrics:")
    print(f"  Pixel Accuracy: {pixel_accuracy:.4f}")
    print(f"  Mean IoU: {mean_iou:.4f}")
    print(f"  Mean Precision: {mean_precision:.4f}")
    print(f"  Mean Recall: {mean_recall:.4f}")
    print(f"  Mean F1 Score: {mean_f1:.4f}")

    # Resolve class names, looking through wrappers such as the Subset objects
    # returned by random_split (they keep the wrapped dataset in `.dataset`).
    source = test_loader.dataset
    while not hasattr(source, 'get_class_names') and hasattr(source, 'dataset'):
        source = source.dataset
    if hasattr(source, 'get_class_names'):
        class_names = list(source.get_class_names())[:n_classes]
    else:
        class_names = []
    # Fall back to generic names for anything the dataset does not name
    class_names += [f"Class {i}" for i in range(len(class_names), n_classes)]

    print("\nPer-class Metrics:")
    for i in range(n_classes):
        print(f"  {class_names[i]}:")
        print(f"    IoU: {iou[i]:.4f}")
        print(f"    Precision: {precision[i]:.4f}")
        print(f"    Recall: {recall[i]:.4f}")
        print(f"    F1 Score: {f1[i]:.4f}")

    # Plot confusion matrix, each row normalised to sum to 1
    plt.figure(figsize=(10, 8))
    confusion_normalized = confusion_matrix.numpy()
    confusion_normalized = confusion_normalized / (confusion_normalized.sum(axis=1, keepdims=True) + 1e-10)

    plt.imshow(confusion_normalized, cmap='Blues')
    plt.colorbar()

    # Add class names to axes
    plt.xticks(np.arange(n_classes), class_names, rotation=45, ha='right')
    plt.yticks(np.arange(n_classes), class_names)

    # Add text annotations (white text on dark cells, black on light ones)
    thresh = confusion_normalized.max() / 2.0
    for i in range(n_classes):
        for j in range(n_classes):
            plt.text(j, i, f"{confusion_normalized[i, j]:.2f}",
                     ha="center", va="center",
                     color="white" if confusion_normalized[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix (Normalized)')
    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, "confusion_matrix.png"))

    plt.show()

    # Prepare results to return
    results = {
        'precision': precision.numpy(),
        'recall': recall.numpy(),
        'f1': f1.numpy(),
        'iou': iou.numpy(),
        'pixel_accuracy': pixel_accuracy.item(),
        'mean_precision': mean_precision.item(),
        'mean_recall': mean_recall.item(),
        'mean_f1': mean_f1.item(),
        'mean_iou': mean_iou.item(),
        'confusion_matrix': confusion_matrix.numpy()
    }

    return results