# Simplified Where's Waldo Object Detection - Kaggle/Google Colab Execution Version

Here's a fully executable version of the notebook optimized for Google Colab or Kaggle. I've streamlined the code, reduced computational requirements, and made it more robust for cloud execution:



In [None]:
# %% [markdown]
# # Simplified Object Detection: Finding Waldo Characters
# 
# This notebook implements a complete object detection pipeline that:
# 1. Creates a synthetic dataset of Waldo characters on backgrounds
# 2. Builds a custom object detection model with a ResNet backbone
# 3. Trains the model with early stopping and learning rate scheduling
# 4. Evaluates performance using precision, recall, IoU and other metrics
# 5. Fine-tunes a YOLOv8 model on the same dataset for comparison

# %% [markdown]
# ## 1. Install Requirements and Import Libraries

# %%
# Install required packages
# !pip install -q ultralytics torch torchvision matplotlib tqdm

# Import necessary libraries
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import random
import time
from PIL import Image
import requests
import io
from tqdm.notebook import tqdm
from pathlib import Path
import pandas as pd

# Check if GPU is available
# Pick CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
# `device` is read later when the model is moved and as the default training device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# %% [markdown]
# ## 2. Download and Prepare Character Images

# %%
# Set up directories
# Relative paths (resolved against the current working directory) for the
# character cut-outs, the generated backgrounds, and the synthetic dataset root.
object_dir = "objects"
background_dir = "backgrounds"
dataset_dir = "dataset"

# Only the two input folders are created here; the dataset root is created
# later, immediately before the datasets are generated.
os.makedirs(object_dir, exist_ok=True)
os.makedirs(background_dir, exist_ok=True)

# Download the Waldo character images
def download_character_images():
    """Fetch the character images and turn their near-white background transparent.

    Each character is downloaded, converted to RGBA, and every pixel whose
    R, G and B values are all above 200 gets alpha 0. The result is saved as
    ``<object_dir>/<name>.png``.

    If a download or the image processing fails, a procedurally drawn
    fallback character is used instead, so the returned lists always contain
    one entry per character and class indices stay stable.

    Returns:
        tuple: ``(object_images, object_names)`` - parallel lists of RGBA PIL
        images and their class names, in the order waldo, wilma, wenda.
    """
    # Waldo character URLs
    character_urls = {
        "waldo": "https://static.wikia.nocookie.net/waldo/images/9/9d/Character.Waldo.jpg",
        "wilma": "https://static.wikia.nocookie.net/waldo/images/8/86/Character.Wilma.jpg",
        "wenda": "https://static.wikia.nocookie.net/waldo/images/3/3e/Character.Wenda.jpg"
    }

    object_images = []
    object_names = []

    for name, url in character_urls.items():
        try:
            # A bounded timeout keeps an unresponsive host from hanging the notebook.
            response = requests.get(url, timeout=30)
            if response.status_code != 200:
                print(f"⚠️ Failed to download {name}. Creating fallback.")
                character_img = create_fallback_character(name, object_dir)
            else:
                # Create character image with transparent background
                img = Image.open(io.BytesIO(response.content)).convert("RGBA")

                # Simple background removal (white to transparent)
                data = np.array(img)
                white_areas = (
                    (data[..., 0] > 200) & (data[..., 1] > 200) & (data[..., 2] > 200)
                )
                data[white_areas, 3] = 0

                # Save image
                character_img = Image.fromarray(data)
                character_img.save(os.path.join(object_dir, f"{name}.png"))
                print(f"✅ Downloaded {name}")

        except Exception as e:
            # Best-effort: any network or decoding problem falls back to a drawn character.
            print(f"❌ Error processing {name}: {e}")
            character_img = create_fallback_character(name, object_dir)

        # Register the character whether it was downloaded or drawn; otherwise
        # a failed download would silently remove a class (or leave no objects at all).
        object_images.append(character_img)
        object_names.append(name)

    return object_images, object_names

def create_fallback_character(character, output_dir):
    """Draw a stick-figure stand-in for `character` and save it as a PNG.

    The figure is painted on a transparent 200x300 RGBA canvas in a colour
    looked up by character name (orange for unknown names). "waldo" also gets
    white stripes across the torso. The image is written to
    ``<output_dir>/<character>.png`` and returned.
    """
    from PIL import ImageDraw

    palette = {
        "waldo": (255, 0, 0, 255),  # Red
        "wilma": (0, 0, 255, 255),  # Blue
        "wenda": (255, 105, 180, 255)  # Pink
    }
    fill_color = palette.get(character, (255, 165, 0, 255))

    # Fully transparent canvas to paint the silhouette on
    canvas = Image.new('RGBA', (200, 300), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # Head first, then the rectangular body parts in a fixed order
    pen.ellipse((75, 30, 125, 80), fill=fill_color)
    rectangles = (
        (85, 80, 115, 180),    # body
        (50, 100, 85, 120),    # left arm
        (115, 100, 150, 120),  # right arm
        (85, 180, 95, 250),    # left leg
        (105, 180, 115, 250),  # right leg
    )
    for box in rectangles:
        pen.rectangle(box, fill=fill_color)

    # Waldo's shirt: a white band every 20 px down the torso
    if character == "waldo":
        stripe_fill = (255, 255, 255, 255)
        for top in range(80, 180, 20):
            pen.rectangle((85, top, 115, top + 10), fill=stripe_fill)

    # Persist and report
    canvas.save(os.path.join(output_dir, f"{character}.png"))
    print(f"🎨 Created fallback image for {character}")

    return canvas

# Visualize the characters
def visualize_objects(object_images, object_names):
    """Show the character images side by side, each titled with its name.

    Parameters:
        object_images: Images accepted by ``plt.imshow``
        object_names: Titles, paired with the images by position

    Prints a message and returns without plotting when there are no images,
    because a subplot grid with zero columns cannot be created.
    """
    if not object_images:
        print("❌ No character images to visualize")
        return

    plt.figure(figsize=(15, 5))
    for i, (img, name) in enumerate(zip(object_images, object_names)):
        plt.subplot(1, len(object_images), i+1)
        plt.imshow(img)
        plt.title(name)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Download and visualize characters
# Runs when the cell executes: fetches the character images over the network,
# then previews whatever images were returned.
object_images, object_names = download_character_images()
visualize_objects(object_images, object_names)

# %% [markdown]
# ## 3. Create Background Images

# %%
# Create procedural backgrounds (to avoid web crawling on Kaggle/Colab)
def create_background_images(num_images=200):
    """Generate procedural background images and save them as JPEG files.

    Each background is a 640x640 white canvas covered with randomly coloured
    lines, filled rectangles and filled circles. Files are written to the
    module-level ``background_dir`` as ``background_NNN.jpg``.

    Parameters:
        num_images: Number of backgrounds to generate

    Returns:
        List of paths to the saved background images
    """
    # Imported once here rather than on every loop iteration.
    from PIL import ImageDraw

    print(f"Creating {num_images} background images...")

    # Harmless if the directory already exists; lets the function run on its own.
    os.makedirs(background_dir, exist_ok=True)

    bg_width, bg_height = 640, 640
    background_paths = []

    def random_color():
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    for i in range(num_images):
        # Create a procedural background with random patterns
        background = Image.new("RGB", (bg_width, bg_height), (255, 255, 255))
        draw = ImageDraw.Draw(background)

        # Add random lines (endpoint order does not matter for lines)
        for _ in range(random.randint(10, 30)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            x2 = random.randint(0, bg_width)
            y2 = random.randint(0, bg_height)
            color = random_color()
            width = random.randint(1, 5)
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)

        # Add random rectangles
        for _ in range(random.randint(5, 15)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            x2 = random.randint(0, bg_width)
            y2 = random.randint(0, bg_height)
            color = random_color()
            # Recent Pillow releases reject boxes whose second corner lies
            # left of / above the first, so order the corners before drawing.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], fill=color)

        # Add random circles
        for _ in range(random.randint(5, 15)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            radius = random.randint(5, 50)
            color = random_color()
            draw.ellipse([x1-radius, y1-radius, x1+radius, y1+radius], fill=color)

        # Save the background
        bg_path = os.path.join(background_dir, f"background_{i:03d}.jpg")
        background.save(bg_path)
        background_paths.append(bg_path)

        # Show progress
        if (i + 1) % 50 == 0:
            print(f"  Created {i+1}/{num_images} backgrounds")

    print(f"✅ Created {num_images} background images")
    return background_paths

# Visualize backgrounds
def visualize_backgrounds(background_paths, num_samples=8):
    """Display a random sample of background images in a 4-column grid.

    Parameters:
        background_paths: Paths of the background image files
        num_samples: Maximum number of images to show

    The grid keeps the original two rows for up to 8 images and grows by
    whole rows beyond that, so larger ``num_samples`` values no longer
    overflow a fixed 2x4 layout. Prints a message and returns without
    plotting when there is nothing to show.
    """
    samples = random.sample(background_paths, min(num_samples, len(background_paths)))
    if not samples:
        print("❌ No background images to visualize")
        return

    cols = 4
    # Ceiling division without importing math; never fewer than two rows.
    rows = max(2, -(-len(samples) // cols))

    plt.figure(figsize=(15, 8))
    for i, path in enumerate(samples):
        # The context manager closes the file once matplotlib has read the pixels.
        with Image.open(path) as img:
            plt.subplot(rows, cols, i+1)
            plt.imshow(img)
        plt.title(f"Background {i+1}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Generate and visualize backgrounds
# Writes 200 procedural JPEGs and keeps their paths in `background_paths`,
# then previews a random sample of them.
background_paths = create_background_images(200)
visualize_backgrounds(background_paths)

# %% [markdown]
# ## 4. Create Synthetic Dataset

# %%
def create_synthetic_dataset(background_paths, object_images, object_names, 
                            output_dir, split, img_size=(640, 640), num_images=500):
    """
    Create a synthetic dataset by pasting one object onto each background.

    For every generated image a random background and a random object are
    chosen, the object is scaled to 10-30% of the image width (aspect ratio
    kept), pasted at a random position using its own alpha channel as the
    mask, and a one-line YOLO label
    ``class x_center y_center width height`` (normalised to 0-1) is written.
    The box is the full pasted rectangle, including transparent margins.

    Parameters:
        background_paths: List of paths to background images
        object_images: List of object images with transparency
        object_names: List of object class names (accepted for interface
            compatibility; the class id written is the index into
            object_images)
        output_dir: Root directory to save dataset
        split: Dataset split ('train', 'val', or 'test')
        img_size: Size of output images (width, height)
        num_images: Number of images to generate

    Returns:
        (images_dir, labels_dir) for the generated split

    Raises:
        ValueError: if object_images is empty
    """
    # Fail early with a clear message instead of an "empty range" error
    # from the random index draw further down.
    if not object_images:
        raise ValueError("object_images is empty: at least one object image is required")

    # Create directory structure
    dataset_dir = os.path.join(output_dir, split)
    images_dir = os.path.join(dataset_dir, "images")
    labels_dir = os.path.join(dataset_dir, "labels")

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # Newer Pillow exposes the filter as Image.Resampling.LANCZOS, older
    # releases as Image.LANCZOS; resolve it once instead of per image.
    lanczos = getattr(getattr(Image, "Resampling", Image), "LANCZOS", None)

    print(f"🎯 Creating {num_images} synthetic images for {split} set...")

    for i in range(num_images):
        # Select random background
        bg_path = random.choice(background_paths)
        try:
            with Image.open(bg_path) as bg_file:
                background = bg_file.convert("RGB").resize(img_size)
        except Exception as e:
            print(f"⚠️ Error loading background {bg_path}: {e}")
            # Create a simple background as fallback
            background = Image.new("RGB", img_size, (200, 200, 200))

        # Select random object
        obj_idx = random.randint(0, len(object_images) - 1)
        obj_image = object_images[obj_idx].copy()

        # Resize object: its width becomes 10-30% of the image width
        scale_factor = random.uniform(0.1, 0.3)
        obj_width = max(1, int(img_size[0] * scale_factor))
        obj_height = max(1, int(obj_width * (obj_image.height / obj_image.width)))  # Maintain aspect ratio

        # A very tall object could exceed the canvas height, which would make
        # the placement range negative; shrink it proportionally to fit.
        if obj_height > img_size[1]:
            shrink = img_size[1] / obj_height
            obj_height = img_size[1]
            obj_width = max(1, int(obj_width * shrink))

        if lanczos is None:
            obj_image = obj_image.resize((obj_width, obj_height))
        else:
            obj_image = obj_image.resize((obj_width, obj_height), lanczos)

        # Place object at random position (fully inside the image)
        max_x = img_size[0] - obj_width
        max_y = img_size[1] - obj_height
        x_pos = random.randint(0, max_x)
        y_pos = random.randint(0, max_y)

        # Paste object on background, using the object itself as the paste mask
        background.paste(obj_image, (x_pos, y_pos), obj_image)

        # Calculate YOLO format bounding box
        x_center = (x_pos + obj_width / 2) / img_size[0]
        y_center = (y_pos + obj_height / 2) / img_size[1]
        width = obj_width / img_size[0]
        height = obj_height / img_size[1]

        # Save image with proper padding in filename
        img_filename = f"{i:05d}.jpg"
        background.save(os.path.join(images_dir, img_filename))

        # Save label in YOLO format
        label_filename = f"{i:05d}.txt"
        with open(os.path.join(labels_dir, label_filename), "w") as f:
            f.write(f"{obj_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

        # Show progress
        if (i + 1) % 100 == 0 or i == num_images - 1:
            print(f"  Progress: {i+1}/{num_images} images created")

    print(f"✅ Created {split} dataset with {num_images} images")
    return images_dir, labels_dir

# Create all datasets
def create_all_datasets(background_paths, object_images, object_names, output_dir="dataset"):
    """Generate the train, val and test splits, in that order.

    Split sizes are kept small (1000 / 200 / 100 images) for quicker
    execution in Colab/Kaggle.

    Returns:
        dict mapping 'train', 'val' and 'test' to the
        (images_dir, labels_dir) pair of that split.
    """
    split_sizes = (("train", 1000), ("val", 200), ("test", 100))

    generated = {}
    for split_name, image_count in split_sizes:
        images_path, labels_path = create_synthetic_dataset(
            background_paths, object_images, object_names,
            output_dir, split_name, num_images=image_count
        )
        generated[split_name] = (images_path, labels_path)

    return generated

# Create the datasets
# Make sure the dataset root exists, then generate all three splits under it;
# `dataset_paths` maps each split name to its (images_dir, labels_dir) pair.
os.makedirs(dataset_dir, exist_ok=True)
dataset_paths = create_all_datasets(background_paths, object_images, object_names, dataset_dir)

# %% [markdown]
# ## 5. Define Dataset and Create DataLoaders

# %%
# Define the PyTorch Dataset class
class ObjectDetectionDataset(Dataset):
    """Image / single-box dataset read from ``<root>/<split>/{images,labels}``.

    Every image is paired with a label file of the same stem whose first
    line is ``class x_center y_center width height``.
    """

    def __init__(self, root_dir, split, num_classes, transform=None):
        """
        Index the image files of one split.

        Parameters:
            root_dir: Root directory of the dataset
            split: 'train', 'val', or 'test'
            num_classes: Number of object classes
            transform: PyTorch transformations to apply
        """
        self.root_dir = root_dir
        self.split = split
        self.num_classes = num_classes
        self.transform = transform

        # Folder layout: <root_dir>/<split>/images and <root_dir>/<split>/labels
        split_root = os.path.join(root_dir, split)
        self.images_dir = os.path.join(split_root, "images")
        self.labels_dir = os.path.join(split_root, "labels")

        # Keep only image files, in sorted order so indexing is deterministic
        image_extensions = (".jpg", ".jpeg", ".png")
        found = [
            entry for entry in os.listdir(self.images_dir)
            if entry.endswith(image_extensions)
        ]
        found.sort()
        self.image_files = found

    def __len__(self):
        """Number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, bbox, class_id)`` for the image at position `idx`.

        When the label file cannot be read or parsed, a warning is printed
        and whatever defaults were not yet overwritten are returned
        (class 0, box [0.5, 0.5, 0.2, 0.2]).
        """
        file_name = self.image_files[idx]

        # Load image
        image = Image.open(os.path.join(self.images_dir, file_name)).convert("RGB")

        # The label shares the image's stem
        stem = os.path.splitext(file_name)[0]
        label_path = os.path.join(self.labels_dir, stem + ".txt")

        # Defaults used when the label is missing or malformed
        class_id = 0
        bbox = torch.tensor([0.5, 0.5, 0.2, 0.2])  # [x_center, y_center, width, height]

        try:
            with open(label_path, "r") as label_file:
                fields = label_file.readline().strip().split()
            class_id = int(float(fields[0]))
            bbox = torch.tensor([float(value) for value in fields[1:5]])
        except Exception as e:
            print(f"⚠️ Error loading label for {self.image_files[idx]}: {e}")

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, bbox, class_id

# Define transformations
# Training uses photometric augmentation only. A random horizontal flip is
# deliberately NOT applied here: these transforms see the image alone, so a
# flip would mirror the picture while the stored bounding box stayed where it
# was, corrupting the regression target for every flipped sample.
train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation and test images are only converted to tensors and normalised
# with the same mean/std as the training images.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create dataset objects
# One dataset per split, all rooted at `dataset_dir`; the class count is the
# number of character names.
train_dataset = ObjectDetectionDataset(dataset_dir, "train", len(object_names), train_transform)
val_dataset = ObjectDetectionDataset(dataset_dir, "val", len(object_names), val_transform)
test_dataset = ObjectDetectionDataset(dataset_dir, "test", len(object_names), test_transform)

# Create DataLoaders
# Only the training loader shuffles; validation and test keep file order.
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Summary of split sizes and batch counts
print(f"✅ DataLoaders created")
print(f"  • Train: {len(train_dataset)} images ({len(train_loader)} batches)")
print(f"  • Val: {len(val_dataset)} images ({len(val_loader)} batches)")
print(f"  • Test: {len(test_dataset)} images ({len(test_loader)} batches)")

# %% [markdown]
# ## 6. Visualize Training Samples

# %%
# Visualize dataset samples
def visualize_dataset_sample(dataset, num_samples=4):
    """Plot the first few dataset items in one row with their boxes drawn in red.

    Each item is expected to be ``(image_tensor, bbox_tensor, class_id)``.
    The image is de-normalised with the ImageNet mean/std before display and
    the title shows the class name looked up in the module-level
    ``object_names``. Prints a message and returns when the dataset is empty.
    """
    total = len(dataset)
    if total == 0:
        print("❌ No images in dataset to visualize")
        return

    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    plt.figure(figsize=(15, 5))
    for slot in range(min(num_samples, total)):
        image, bbox, class_id = dataset[slot]

        # Channels-last, undo the normalisation, clamp to the displayable range
        pixels = image.permute(1, 2, 0).numpy()
        pixels = pixels * channel_std + channel_mean
        pixels = np.clip(pixels, 0, 1)
        img_height, img_width = pixels.shape[0], pixels.shape[1]

        # Normalised centre/size box -> pixel corner coordinates
        x_center, y_center, width, height = bbox.numpy()
        left = (x_center - width/2) * img_width
        top = (y_center - height/2) * img_height
        right = (x_center + width/2) * img_width
        bottom = (y_center + height/2) * img_height

        plt.subplot(1, num_samples, slot + 1)
        plt.imshow(pixels)
        plt.title(f"Class: {object_names[class_id]}")

        # Outline of the ground-truth box
        outline = plt.Rectangle((left, top), right - left, bottom - top,
                                fill=False, edgecolor='red', linewidth=2)
        plt.gca().add_patch(outline)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize samples from the training dataset (the only split previewed here)
print("🖼️ Visualizing training samples:")
visualize_dataset_sample(train_dataset)

# %% [markdown]
# ## 7. Build Custom Object Detection Model

# %%
# Define the custom object detection model with a pre-trained backbone
class CustomObjectDetectionModel(nn.Module):
    """Single-object detector: ResNet18 features feeding a class head and a box head.

    The backbone feature map is reduced to 256 channels by a 1x1 convolution,
    max-pooled to a fixed 5x5 grid, flattened, and passed to two independent
    MLP heads.

    Parameters:
        num_classes: Number of output classes of the classification head
        pretrained: Load ImageNet weights into the ResNet18 backbone

    ``forward`` returns ``(class_logits, bbox_coords)`` where ``bbox_coords``
    is ``[x_center, y_center, width, height]`` squashed into 0-1 by a sigmoid.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        # Use a ResNet18 as the backbone. Newer torchvision selects weights
        # through the `weights=` enum and deprecates `pretrained=`; older
        # releases only understand `pretrained=`.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            resnet = models.resnet18(pretrained=pretrained)

        # Remove the final fully connected layer and avgpool
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # 1x1 convolution: reduce the 512 backbone channels to 256
        self.conv1x1 = nn.Conv2d(512, 256, kernel_size=1)

        # Pool to a fixed 5x5 grid so the heads see a constant feature size
        # regardless of the input resolution (single pooling level).
        self.spp = nn.Sequential(
            nn.AdaptiveMaxPool2d(5),
            nn.Flatten()
        )

        # Feature size after pooling and flattening
        feature_size = 5 * 5 * 256

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Add dropout for regularization
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )

        # Bounding box regression head
        self.regression_head = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Add dropout for regularization
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 4),  # [x_center, y_center, width, height]
            nn.Sigmoid()  # Bound outputs between 0 and 1 for normalized coordinates
        )

        # Initialize weights of the newly added layers only
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise the added layers, leaving the backbone untouched.

        Walking ``self.modules()`` would also visit the backbone and
        overwrite its (possibly pretrained) convolution and batch-norm
        weights, so only the layers created in ``__init__`` are visited.
        """
        added_blocks = (self.conv1x1, self.classification_head, self.regression_head)
        for block in added_blocks:
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(class_logits, bbox_coords)`` for a batch of images."""
        # Extract features from the backbone
        features = self.backbone(x)

        # Apply 1x1 convolution to reduce channels
        features = self.conv1x1(features)

        # Pool to the fixed grid and flatten
        features_flat = self.spp(features)

        # Process features through the classification and regression heads
        class_logits = self.classification_head(features_flat)
        bbox_coords = self.regression_head(features_flat)

        return class_logits, bbox_coords

# Create the model
# One output class per character name; the model is moved to the selected device.
num_classes = len(object_names)
model = CustomObjectDetectionModel(num_classes=num_classes, pretrained=True).to(device)

# Print model info
def count_parameters(model):
    """Return the total number of scalar values in the model's trainable parameters."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report the trainable parameter count (thousands-separated) and the class count
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")
print(f"Model created with {len(object_names)} classes")

# %% [markdown]
# ## 8. Define Loss Function and Optimizer

# %%
# Define the loss functions
classification_loss_fn = nn.CrossEntropyLoss()  # For class probabilities
regression_loss_fn = nn.SmoothL1Loss()  # Better choice for bounding box regression than MSE

# Define the optimizer
learning_rate = 0.001
weight_decay = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Learning rate scheduler to reduce LR when training plateaus.
# The `verbose` argument is intentionally not passed: it is deprecated in
# recent PyTorch releases, and omitting it keeps this call valid on both old
# and new versions.
# NOTE(review): the scheduler waits 5 stagnant epochs while the training run
# is launched with a shorter early-stopping patience - confirm the two
# settings are meant to interact this way.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5
)

# Define a custom combined loss function
def object_detection_loss(class_pred, bbox_pred, class_target, bbox_target):
    """
    Combine the classification and box-regression losses with equal weight.

    Uses the module-level ``classification_loss_fn`` and
    ``regression_loss_fn``.

    Args:
        class_pred: predicted class scores [batch_size, num_classes]
        bbox_pred: predicted bounding boxes [batch_size, 4]
        class_target: ground truth class indices [batch_size]
        bbox_target: ground truth bounding boxes [batch_size, 4]

    Returns:
        (total_loss, cls_loss, reg_loss) where total_loss = cls_loss + reg_loss
    """
    classification_term = classification_loss_fn(class_pred, class_target)
    regression_term = regression_loss_fn(bbox_pred, bbox_target)

    # Unweighted sum: both terms contribute equally
    combined = classification_term + regression_term
    return combined, classification_term, regression_term

# %% [markdown]
# ## 9. Train the Model

# %%
# Function to train the object detection model
def _run_detection_epoch(model, loader, loss_fn, device, desc, optimizer=None):
    """Run one pass over `loader` and return the per-batch mean (total, cls, reg) losses.

    Parameters are updated only when `optimizer` is given. The caller is
    responsible for ``model.train()`` / ``model.eval()`` and for any
    ``torch.no_grad()`` context.
    """
    running_total = 0.0
    running_cls = 0.0
    running_reg = 0.0

    progress = tqdm(loader, desc=desc, leave=False)
    for images, bboxes, class_ids in progress:
        # Move data to device
        images = images.to(device)
        bboxes = bboxes.to(device)
        class_ids = class_ids.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        # Forward pass and loss
        class_pred, bbox_pred = model(images)
        loss, cls_loss, reg_loss = loss_fn(class_pred, bbox_pred, class_ids, bboxes)

        # Backward pass and optimize (training only)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        # Update running losses
        running_total += loss.item()
        running_cls += cls_loss.item()
        running_reg += reg_loss.item()

        # Update progress bar
        progress.set_postfix(loss=f"{loss.item():.4f}")

    num_batches = len(loader)
    return (running_total / num_batches,
            running_cls / num_batches,
            running_reg / num_batches)


def train_model(model, train_loader, val_loader, loss_fn, optimizer, scheduler, 
                num_epochs=10, early_stopping_patience=5, device=device):
    """
    Train the custom object detection model with early stopping.

    After every epoch the mean validation loss is passed to the scheduler.
    Whenever it improves, a checkpoint is written to 'best_model.pth' in the
    current working directory; once training ends, that checkpoint is loaded
    back into the model. If no checkpoint was written during this run (zero
    epochs, or the validation loss never compared below infinity), the model
    is returned as it is instead of loading a missing or stale file.

    Args:
        model: The model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        loss_fn: Combined loss function returning (total, cls, reg)
        optimizer: Optimizer for parameter updates
        scheduler: Learning rate scheduler stepped with the validation loss
        num_epochs: Maximum number of epochs to train
        early_stopping_patience: Number of epochs without improvement before stopping
        device: Device to train on (cuda/cpu)

    Returns:
        model: Trained model (best checkpoint loaded when one was saved)
        history: Training history (per-epoch losses)
    """
    # Initialize history dictionary to track metrics
    history = {
        'train_loss': [], 'val_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_reg_loss': [], 'val_reg_loss': []
    }

    # Variables for early stopping and best model tracking
    best_val_loss = float('inf')
    early_stopping_counter = 0
    best_model_path = 'best_model.pth'
    checkpoint_saved = False  # True once THIS run has written best_model_path

    print(f"🏋️ Starting training for {num_epochs} epochs...")

    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Training phase
        model.train()
        avg_train_loss, avg_train_cls_loss, avg_train_reg_loss = _run_detection_epoch(
            model, train_loader, loss_fn, device, "Training", optimizer=optimizer
        )
        history['train_loss'].append(avg_train_loss)
        history['train_cls_loss'].append(avg_train_cls_loss)
        history['train_reg_loss'].append(avg_train_reg_loss)

        # Validation phase (no gradients, no parameter updates)
        model.eval()
        with torch.no_grad():
            avg_val_loss, avg_val_cls_loss, avg_val_reg_loss = _run_detection_epoch(
                model, val_loader, loss_fn, device, "Validation"
            )
        history['val_loss'].append(avg_val_loss)
        history['val_cls_loss'].append(avg_val_cls_loss)
        history['val_reg_loss'].append(avg_val_reg_loss)

        # Update learning rate scheduler
        scheduler.step(avg_val_loss)

        # Print epoch summary
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f} (Cls: {avg_train_cls_loss:.4f}, Reg: {avg_train_reg_loss:.4f})")
        print(f"  Val Loss: {avg_val_loss:.4f} (Cls: {avg_val_cls_loss:.4f}, Reg: {avg_val_reg_loss:.4f})")

        # Check if this is the best model so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stopping_counter = 0

            # Save the best model
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'history': history
            }, best_model_path)
            checkpoint_saved = True

            print(f"  ✅ Model improved! Saved checkpoint to {best_model_path}")
        else:
            early_stopping_counter += 1
            print(f"  ⚠️ Model did not improve. Early stopping counter: {early_stopping_counter}/{early_stopping_patience}")

            # Check if we should stop early
            if early_stopping_counter >= early_stopping_patience:
                print(f"  🛑 Early stopping triggered after {epoch+1} epochs")
                break

    # Without a checkpoint from this run there is nothing valid to restore:
    # the file is either absent or left over from an earlier run.
    if not checkpoint_saved:
        print("\n⚠️ No checkpoint was saved during this run; returning the model as-is")
        return model, history

    # Load the best model; map_location keeps the tensors on the training
    # device even if the file is later loaded on a machine without CUDA.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n✅ Training complete! Best model from epoch {checkpoint['epoch']} loaded (Val Loss: {checkpoint['val_loss']:.4f})")

    return model, history

# Set training parameters - reduced for Colab/Kaggle
num_epochs = 10  # Use 20-30 if running on more powerful hardware
early_stopping_patience = 3  # stop after 3 consecutive epochs without a better validation loss

# Start training
# Rebinds `model` to the returned model and keeps the per-epoch losses in
# `history`; the training device is left at train_model's default.
model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=object_detection_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    early_stopping_patience=early_stopping_patience
)

# %% [markdown]
# ## 10. Visualize Training Metrics

# %%
# Plot the training history
def plot_training_metrics(history):
    """Plot training/validation losses and print a short convergence report.

    Args:
        history: dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_cls_loss', 'val_cls_loss', 'train_reg_loss' and
            'val_reg_loss'.

    The top panel shows the total losses and marks the epoch with the lowest
    validation loss; the bottom panel shows the classification and
    regression components. Prints a message and returns without plotting
    when the history holds no epochs.
    """
    train_losses = history.get('train_loss') or []
    val_losses = history.get('val_loss') or []
    if not train_losses or not val_losses:
        print("❌ No training history to plot")
        return

    def percent_drop(initial, final):
        # A zero starting loss cannot be reduced; avoid dividing by it.
        if initial == 0:
            return 0.0
        return ((initial - final) / initial) * 100

    # Epoch (0-based) with the lowest validation loss
    best_epoch = int(np.argmin(val_losses))
    best_val_loss = val_losses[best_epoch]

    # Create a figure with 2 subplots
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 10), sharex=True)

    # Plot total loss (main metric)
    axes[0].plot(train_losses, label="Training Loss", color="blue", marker="o")
    axes[0].plot(val_losses, label="Validation Loss", color="orange", marker="o")
    axes[0].set_ylabel("Total Loss")
    axes[0].set_title("Training and Validation Loss Over Time")
    axes[0].grid(True)

    # Add annotations for best model
    axes[0].axvline(x=best_epoch, color='r', linestyle='--', alpha=0.5)
    axes[0].scatter(best_epoch, best_val_loss, s=100, c='red', label=f'Best Model (Epoch {best_epoch+1})')
    # Build the legend only after every labelled artist exists, otherwise the
    # "Best Model" entry is missing from it.
    axes[0].legend()

    # Plot component losses (classification and regression)
    axes[1].plot(history['train_cls_loss'], label="Train Classification Loss", color="blue", linestyle="-")
    axes[1].plot(history['val_cls_loss'], label="Val Classification Loss", color="blue", linestyle="--")
    axes[1].plot(history['train_reg_loss'], label="Train Regression Loss", color="green", linestyle="-")
    axes[1].plot(history['val_reg_loss'], label="Val Regression Loss", color="green", linestyle="--")
    axes[1].set_xlabel("Epochs")
    axes[1].set_ylabel("Component Losses")
    axes[1].set_title("Classification and Regression Loss Components")
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()

    # Analyze convergence and provide text report
    print("📊 Model Convergence Analysis:")

    min_loss_epoch = best_epoch
    last_epoch = len(val_losses) - 1

    # Calculate training and validation loss reduction
    initial_train_loss = train_losses[0]
    final_train_loss = train_losses[last_epoch]
    train_reduction = percent_drop(initial_train_loss, final_train_loss)

    initial_val_loss = val_losses[0]
    final_val_loss = val_losses[last_epoch]
    val_reduction = percent_drop(initial_val_loss, best_val_loss)

    # Check if loss is still decreasing at the end of training
    if min_loss_epoch == last_epoch:
        print(f"  • The model was STILL IMPROVING when training stopped at epoch {last_epoch+1}")
        print(f"  • Consider training for more epochs to potentially achieve better performance")
    elif min_loss_epoch < last_epoch - 2:
        print(f"  • The model CONVERGED around epoch {min_loss_epoch+1} (best validation loss)")
        print(f"  • Early stopping prevented overfitting by loading the best model")
    else:
        print(f"  • The model appears to have CONVERGED near the end of training (best at epoch {min_loss_epoch+1})")

    print(f"\n  • Training loss reduced by {train_reduction:.2f}% (from {initial_train_loss:.4f} to {final_train_loss:.4f})")
    print(f"  • Validation loss reduced by {val_reduction:.2f}% (from {initial_val_loss:.4f} to {best_val_loss:.4f})")

    # Check for overfitting: final training loss far below final validation loss
    if final_train_loss < final_val_loss * 0.7:
        print("\n  ⚠️ OVERFITTING DETECTED: The training loss is much lower than validation loss")
    else:
        print("\n  ✅ HEALTHY CONVERGENCE: Training and validation losses decreased together")
        print("  • The model appears to generalize well to unseen data")

# Plot training metrics
# Uses the per-epoch loss history returned by the training run above.
plot_training_metrics(history)

# %% [markdown]
# ## 11. Evaluate the Custom Model

# %%
# Function to calculate Intersection over Union (IoU)
def calculate_iou(pred_box, true_box):
    """Return the Intersection-over-Union of two boxes given in YOLO format.

    Each box is an indexable ``[x_center, y_center, width, height]``.
    The result is ``0.0`` when the union area is exactly zero.
    """
    def _to_corners(box):
        # Center/size representation -> (x1, y1, x2, y2) corner representation.
        return (
            box[0] - box[2] / 2,
            box[1] - box[3] / 2,
            box[0] + box[2] / 2,
            box[1] + box[3] / 2,
        )

    p_left, p_top, p_right, p_bottom = _to_corners(pred_box)
    t_left, t_top, t_right, t_bottom = _to_corners(true_box)

    # Overlap extents are clamped at zero so disjoint boxes contribute no area.
    overlap_w = max(0, min(p_right, t_right) - max(p_left, t_left))
    overlap_h = max(0, min(p_bottom, t_bottom) - max(p_top, t_top))
    inter_area = overlap_w * overlap_h

    pred_area = (p_right - p_left) * (p_bottom - p_top)
    true_area = (t_right - t_left) * (t_bottom - t_top)
    union_area = pred_area + true_area - inter_area

    # Guard the division for two degenerate (zero-area) boxes.
    if union_area == 0:
        return 0.0
    return inter_area / union_area

# Function to evaluate the model on the test set
def evaluate_model(model, test_loader, device, iou_threshold=0.5):
    """
    Evaluate the model on test data with multiple metrics.

    Parameters:
        model: Network called as ``model(images)`` that returns
            ``(class_logits, pred_boxes)``.
        test_loader: Iterable yielding ``(images, true_boxes, true_classes)``
            batches.
        device: Device the batches are moved to before inference.
        iou_threshold: A sample counts as a correct detection only when its
            IoU is strictly greater than this value AND the class matches.

    Returns:
        ``(mean_metrics, class_metrics)`` where ``mean_metrics`` holds the
        averaged 'precision', 'recall', 'f1', 'iou', 'class_accuracy' and
        'inference_time' (seconds per image), and ``class_metrics`` maps each
        name in the module-level ``object_names`` to its 'correct', 'total'
        and 'accuracy' values.

    Note:
        Exactly one prediction is scored per image, so per-image precision,
        recall and F1 are all 1.0 for a correct detection and 0.0 otherwise.
    """
    model.eval()  # Set model to evaluation mode

    # CUDA kernels are launched asynchronously. Without an explicit
    # synchronization the clock is read before the GPU has finished, and the
    # measured "inference time" is mostly kernel-launch overhead.
    sync_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"

    # Initialize metrics
    metrics = {
        'precision': [], 'recall': [], 'f1': [], 'iou': [],
        'class_accuracy': [], 'inference_times': []
    }

    # Class-specific metrics
    class_metrics = {class_name: {'correct': 0, 'total': 0}
                     for class_name in object_names}

    with torch.no_grad():
        for images, true_boxes, true_classes in test_loader:
            images = images.to(device)
            true_boxes = true_boxes.to(device)
            true_classes = true_classes.to(device)

            # Measure inference time (bracketed by device synchronization)
            if sync_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            class_logits, pred_boxes = model(images)
            if sync_cuda:
                torch.cuda.synchronize()
            end_time = time.time()

            # Calculate inference time per image
            batch_inference_time = (end_time - start_time) / len(images)
            metrics['inference_times'].append(batch_inference_time)

            # Get predicted classes
            _, pred_classes = torch.max(class_logits, 1)

            # Move the whole batch to host memory once instead of once per sample.
            pred_boxes_np = pred_boxes.cpu().numpy()
            true_boxes_np = true_boxes.cpu().numpy()
            pred_class_ids = pred_classes.cpu().tolist()
            true_class_ids = true_classes.cpu().tolist()

            # Calculate metrics for each image in the batch
            for pred_box, true_box, pred_class, true_class in zip(
                pred_boxes_np, true_boxes_np, pred_class_ids, true_class_ids
            ):
                # Calculate IoU
                iou = calculate_iou(pred_box, true_box)
                metrics['iou'].append(iou)

                # Class prediction accuracy
                class_correct = (pred_class == true_class)
                metrics['class_accuracy'].append(float(class_correct))

                # Update class-specific metrics
                class_name = object_names[true_class]
                class_metrics[class_name]['total'] += 1
                if class_correct:
                    class_metrics[class_name]['correct'] += 1

                # Detection is correct if IoU > threshold AND class is correct.
                # With a single prediction per image, precision, recall and F1
                # all collapse to the same 0/1 score.
                correct_detection = (iou > iou_threshold) and class_correct
                score = 1.0 if correct_detection else 0.0

                metrics['precision'].append(score)
                metrics['recall'].append(score)
                metrics['f1'].append(score)

    # Calculate mean metrics
    mean_metrics = {
        'precision': np.mean(metrics['precision']),
        'recall': np.mean(metrics['recall']),
        'f1': np.mean(metrics['f1']),
        'iou': np.mean(metrics['iou']),
        'class_accuracy': np.mean(metrics['class_accuracy']),
        'inference_time': np.mean(metrics['inference_times'])
    }

    # Calculate class-specific accuracy (0.0 for classes absent from the data)
    for stats in class_metrics.values():
        total = stats['total']
        stats['accuracy'] = stats['correct'] / total if total > 0 else 0.0

    # Calculate model size from the parameter tensors only
    model_size_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    model_size_mb = model_size_bytes / (1024 * 1024)

    # Print metrics summary
    print("\n📊 Model Evaluation Metrics:")
    print(f"  • Mean Precision: {mean_metrics['precision']:.4f}")
    print(f"  • Mean Recall: {mean_metrics['recall']:.4f}")
    print(f"  • Mean F1-Score: {mean_metrics['f1']:.4f}")
    print(f"  • Mean IoU: {mean_metrics['iou']:.4f}")
    print(f"  • Class Prediction Accuracy: {mean_metrics['class_accuracy']:.4f}")
    print(f"  • Average Inference Time: {mean_metrics['inference_time']*1000:.2f} ms per image")
    print(f"  • Model Size: {model_size_mb:.2f} MB")

    # Print class-specific metrics
    print("\n📊 Class-Specific Metrics:")
    for class_name, stats in class_metrics.items():
        accuracy = stats['accuracy']
        total = stats['total']
        print(f"  • {class_name}: Accuracy = {accuracy:.4f} (from {total} samples)")

    return mean_metrics, class_metrics

# Run evaluation on the test set. The returned averages and per-class counts
# are kept in module-level names so later cells can reuse them.
mean_metrics, class_metrics = evaluate_model(model, test_loader, device)

# %% [markdown]
# ## 12. Visualize Custom Model Predictions

# %%
# Visualize predictions on test images
def visualize_predictions(model, test_loader, device, num_images=8):
    """
    Visualize model predictions vs ground truth with detailed metrics.

    Parameters:
        model: Network called as ``model(images)`` that returns
            ``(class_logits, pred_boxes)``.
        test_loader: Loader yielding ``(images, true_boxes, true_classes)``;
            only its first batch is shown.
        device: Device the images are moved to for inference.
        num_images: Maximum number of images to draw (capped by the batch size).

    The grid is four columns wide and at least two rows tall; further rows are
    added when more than eight images are requested, so the axes array can
    never be indexed out of range.
    """
    model.eval()  # Set model to evaluation mode

    # Get a batch from the test loader
    data_iter = iter(test_loader)
    images, true_boxes, true_classes = next(data_iter)

    # Ensure we don't try to visualize more images than we have
    num_images = min(num_images, len(images))

    # Make predictions
    images = images.to(device)
    with torch.no_grad():
        class_logits, pred_boxes = model(images)
        _, pred_classes = torch.max(class_logits, 1)

    # Set up the plot: keep the original 2x4 layout for up to eight images and
    # grow by whole rows beyond that (ceil division without importing math).
    n_cols = 4
    n_rows = max(2, (num_images + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows))
    axes = axes.flatten()

    def _add_box(ax, box, img_height, img_width, color, label):
        # Convert a normalized YOLO box to pixel corners and draw its outline.
        x_center, y_center, width, height = box
        x_min = (x_center - width/2) * img_width
        y_min = (y_center - height/2) * img_height
        x_max = (x_center + width/2) * img_width
        y_max = (y_center + height/2) * img_height
        ax.add_patch(patches.Rectangle(
            (x_min, y_min), x_max-x_min, y_max-y_min,
            linewidth=2, edgecolor=color, facecolor='none', label=label
        ))

    for i in range(num_images):
        # Get image and undo the ImageNet mean/std normalization for display
        img = images[i].cpu()
        img = img.permute(1, 2, 0).numpy()
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)

        # Get ground truth box and class
        true_box = true_boxes[i].cpu().numpy()
        true_class = true_classes[i].item()
        true_class_name = object_names[true_class]

        # Get predicted box and class
        pred_box = pred_boxes[i].cpu().numpy()
        pred_class = pred_classes[i].item()
        pred_class_name = object_names[pred_class]

        # Calculate IoU
        iou = calculate_iou(pred_box, true_box)

        # Plot the image, then the ground truth (green) and predicted (red) boxes
        axes[i].imshow(img)
        _add_box(axes[i], true_box, img.shape[0], img.shape[1], 'green', 'True')
        _add_box(axes[i], pred_box, img.shape[0], img.shape[1], 'red', 'Pred')

        # Add detailed title with metrics
        axes[i].set_title(
            f"True: {true_class_name}, Pred: {pred_class_name}\nIoU: {iou:.2f}",
            fontsize=10
        )
        axes[i].axis('off')
        axes[i].legend()

    # Hide any unused subplots
    for i in range(num_images, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.suptitle("Model Predictions vs Ground Truth", fontsize=16, y=1.02)
    plt.show()

# Visualize predictions for the first test batch (default: up to 8 images).
visualize_predictions(model, test_loader, device)

# Visualize predictions on test images with detailed metrics
def visualize_predictions_with_metrics(model, test_loader, iou_threshold=0.5):
    """
    Draw the first test batch with ground-truth and predicted boxes.

    Up to eight images are shown in a 2x4 grid. Each panel is titled with the
    IoU plus a per-image precision/recall, which are 1.0 only when the IoU is
    strictly above ``iou_threshold`` and the predicted class is correct.
    Tensors are moved to the module-level ``device``.
    """
    model.eval()  # inference mode for the forward pass below

    # Only the first batch of the loader is visualized.
    images, true_boxes, true_classes = next(iter(test_loader))
    images = images.to(device)
    true_boxes = true_boxes.to(device)
    true_classes = true_classes.to(device)

    with torch.no_grad():
        class_logits, pred_boxes = model(images)
        pred_classes = torch.max(class_logits, 1)[1]

    num_images = min(8, len(images))
    fig, axes = plt.subplots(2, 4, figsize=(15, 10))
    axes = axes.flatten()

    # ImageNet statistics used to undo the input normalization.
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    def _outline(box, img_h, img_w, color, label):
        # Normalized (cx, cy, w, h) -> pixel-space rectangle outline.
        cx, cy, box_w, box_h = box
        left = (cx - box_w/2) * img_w
        top = (cy - box_h/2) * img_h
        right = (cx + box_w/2) * img_w
        bottom = (cy + box_h/2) * img_h
        return plt.Rectangle(
            (left, top), right-left, bottom-top,
            linewidth=2, edgecolor=color, facecolor="none",
            label=label
        )

    for idx, ax in enumerate(axes[:num_images]):
        picture = images[idx].cpu().permute(1, 2, 0).numpy()
        picture = np.clip(picture * channel_std + channel_mean, 0, 1)
        ax.imshow(picture)

        gt_box = true_boxes[idx].cpu().numpy()
        guess_box = pred_boxes[idx].cpu().numpy()
        gt_class = true_classes[idx].item()
        guess_class = pred_classes[idx].item()

        gt_name = object_names[gt_class]
        guess_name = object_names[guess_class]

        height_px, width_px = picture.shape[0], picture.shape[1]
        ax.add_patch(_outline(gt_box, height_px, width_px, "green", f"True: {gt_name}"))
        ax.add_patch(_outline(guess_box, height_px, width_px, "red", f"Pred: {guess_name}"))

        iou = calculate_iou(guess_box, gt_box)

        # A single prediction per image: precision and recall share one 0/1 score.
        hit = (iou > iou_threshold) and (guess_class == gt_class)
        if hit:
            precision = recall = 1.0
        else:
            precision = recall = 0.0

        ax.set_title(
            f"IoU: {iou:.2f} | Precision: {precision:.1f} | Recall: {recall:.1f}\n"
            f"True: {gt_name} | Pred: {guess_name}",
            fontsize=9
        )
        ax.axis("off")
        ax.legend(loc="upper right", fontsize=8)

    # Blank out panels that received no image.
    for ax in axes[num_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.suptitle("Model Predictions vs Ground Truth with Metrics", fontsize=14, y=1.02)
    plt.show()

# Run the detailed visualization with the default IoU threshold (0.5).
visualize_predictions_with_metrics(model, test_loader)

# %% [markdown]
# ## 13. Fine-tune YOLOv8 on the Same Dataset

# %%
# Prepare the data.yaml file for YOLO training
train_images_path = os.path.join(dataset_dir, "train", "images")
val_images_path = os.path.join(dataset_dir, "val", "images")
test_images_path = os.path.join(dataset_dir, "test", "images")

# Create the data.yaml file in the dataset directory.
# Ultralytics interprets `train`/`val`/`test` relative to `path`, and resolves
# a *relative* `path` against its own datasets directory rather than the
# current working directory. Writing an absolute root plus root-relative split
# folders avoids both the doubled "dataset/dataset/..." path and the wrong base.
data_yaml_path = os.path.join(dataset_dir, "data.yaml")
dataset_root = os.path.abspath(dataset_dir)
with open(data_yaml_path, "w") as f:
    f.write(f"path: {dataset_root}\n")
    f.write("train: train/images\n")
    f.write("val: val/images\n")
    f.write("test: test/images\n")
    f.write(f"nc: {len(object_names)}\n")  # Number of classes from object_names

    # Write class names dynamically from object_names
    f.write("names:\n")
    for i, name in enumerate(object_names):
        f.write(f"  {i}: {name}\n")

print(f"✅ data.yaml file created at {data_yaml_path}")
print("📄 Content preview:")
with open(data_yaml_path, "r") as f:
    print(f.read())

# Import YOLO from ultralytics
try:
    from ultralytics import YOLO

    # Load a pre-trained YOLO model
    yolo_model = YOLO("yolov8n.pt")  # Using YOLOv8 nano version

    # Print model information
    print("\n🔍 Using pre-trained model: YOLOv8n")
    print("  • Architecture: YOLOv8 (Detection)")
    print("  • Size: Nano (compact)")
    print("  • Pre-trained on: COCO dataset (80 classes)")

    # Fine-tune the YOLO model on the custom dataset
    print("\n🏋️ Starting YOLO model fine-tuning...")
    results = yolo_model.train(
        data=data_yaml_path,
        epochs=5,  # Reduced for Colab/Kaggle
        imgsz=640,
        batch=8,  # Reduced for Colab/Kaggle
        patience=3,  # Early stopping patience
        save=True,  # Save best model
        device=0 if torch.cuda.is_available() else "cpu",
        project="yolo_training",
        name="run1"
    )

    # Save the fine-tuned YOLO model.
    # "pytorch" is not an export format (the trained weights already are the
    # PyTorch checkpoint) and the wrapper has no `best` attribute; the best
    # checkpoint path lives on the trainer. Copy it to the advertised location.
    fine_tuned_model_path = "yolo_fine_tuned.pt"
    best_weights_path = getattr(getattr(yolo_model, "trainer", None), "best", None)

    if best_weights_path is not None and os.path.exists(best_weights_path):
        with open(best_weights_path, "rb") as src, open(fine_tuned_model_path, "wb") as dst:
            dst.write(src.read())
        print("✅ Fine-tuned YOLO model saved")
        print(f"  • Best model path: {best_weights_path}")
        print(f"  • Exported model: {fine_tuned_model_path}")
    else:
        print("⚠️ Training finished but no best checkpoint was found to save")

except Exception as e:
    # Deliberately broad: YOLO fine-tuning is optional and must not abort the notebook.
    print(f"❌ Error during YOLO training: {str(e)}")
    print("You can continue with the rest of the notebook.")

# %% [markdown]
# ## 14. Evaluate YOLO Model

# %%
# Function to evaluate the YOLO model on the test set
def evaluate_yolo_model(yolo_model, test_loader):
    """
    Evaluate YOLO model on test data and calculate standard metrics.

    Parameters:
        yolo_model: Ultralytics YOLO model, called once per image.
        test_loader: Iterable yielding ``(images, true_boxes, true_classes)``
            batches of ImageNet-normalized tensors with YOLO-format boxes.

    Returns:
        Dict with the mean 'precision', 'recall', 'f1', 'iou' and
        'inference_time' (seconds per image), or ``None`` when no image could
        be evaluated.

    Only the highest-confidence detection of each image is scored; an image
    without detections scores 0 on every metric.
    """
    # NOTE: `yolo_model.eval()` is intentionally not called. On the Ultralytics
    # wrapper, `nn.Module.eval()` dispatches to the overridden `train()` method
    # (or is unavailable on older releases); prediction already runs the
    # network in inference mode.

    # Initialize metrics
    metrics = {
        'precision': [], 'recall': [], 'f1': [], 'iou': [],
        'inference_times': []
    }

    # ImageNet statistics used to undo the dataset normalization
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    with torch.no_grad():
        for images, true_boxes, true_classes in test_loader:
            # Process each image individually to avoid batching issues
            for i in range(len(images)):
                # Get single image and convert to numpy with correct format
                image = images[i].cpu().numpy().transpose(1, 2, 0)  # CHW to HWC

                # Undo the mean/std normalization and rescale to uint8 [0, 255]
                image = image * channel_std + channel_mean
                image = np.clip(image, 0, 1) * 255
                image = image.astype(np.uint8)

                # Ultralytics treats numpy inputs as BGR (OpenCV convention),
                # while the loader yields RGB: flip the channel axis.
                image = image[:, :, ::-1].copy()

                # Get ground truth
                true_box = true_boxes[i].cpu().numpy()
                true_class = true_classes[i].item()

                # Measure inference time for single image
                start_time = time.time()
                try:
                    # Run inference with YOLO - process one image at a time
                    result = yolo_model(image, conf=0.25, iou=0.45)[0]  # Get first (only) result
                    end_time = time.time()
                    metrics['inference_times'].append(end_time - start_time)

                    if len(result.boxes) > 0:
                        # Score the most confident detection (normalized xywh)
                        top = int(result.boxes.conf.argmax())
                        pred_box = result.boxes.xywhn[top].cpu().numpy()
                        pred_class = int(result.boxes.cls[top].item())

                        # Calculate IoU between predicted and ground truth box
                        iou = calculate_iou(pred_box, true_box)

                        # Detection is correct if IoU > 0.5 and the class matches.
                        # One prediction per image: precision, recall and F1
                        # all take the same 0/1 value.
                        detection_correct = (iou > 0.5) and (pred_class == true_class)
                        score = 1.0 if detection_correct else 0.0
                    else:
                        # No detection
                        iou = 0.0
                        score = 0.0

                    # Store metrics
                    metrics['iou'].append(iou)
                    metrics['precision'].append(score)
                    metrics['recall'].append(score)
                    metrics['f1'].append(score)

                except Exception as e:
                    # Best effort: report the failing image and keep evaluating.
                    print(f"Error processing image {i}: {str(e)}")
                    continue

    # Calculate mean metrics if we have data
    if len(metrics['iou']) > 0:
        mean_metrics = {
            'precision': np.mean(metrics['precision']),
            'recall': np.mean(metrics['recall']),
            'f1': np.mean(metrics['f1']),
            'iou': np.mean(metrics['iou']),
            'inference_time': np.mean(metrics['inference_times'])
        }

        # Print detailed evaluation results
        print("\n📊 YOLO Model Evaluation Results:")
        print(f"  • Mean Precision: {mean_metrics['precision']:.4f}")
        print(f"  • Mean Recall: {mean_metrics['recall']:.4f}")
        print(f"  • Mean F1-Score: {mean_metrics['f1']:.4f}")
        print(f"  • Mean IoU: {mean_metrics['iou']:.4f}")
        print(f"  • Average Inference Time: {mean_metrics['inference_time']*1000:.2f} ms per image")
        return mean_metrics

    # NOTE(review): the original text of this function was cut off after the
    # "Average Inference Time" line; this fallback is a reconstruction.
    print("⚠️ No images could be evaluated with the YOLO model")
    return None

# %% [markdown]
# # Simplified Object Detection: Finding Waldo Characters
# 
# This notebook implements a complete object detection pipeline that:
# 1. Creates a synthetic dataset of Waldo characters on backgrounds
# 2. Builds a custom object detection model with a ResNet backbone
# 3. Trains the model with early stopping and learning rate scheduling
# 4. Evaluates performance using precision, recall, IoU and other metrics
# 5. Fine-tunes a YOLOv8 model on the same dataset for comparison

# %% [markdown]
# ## 1. Install Requirements and Import Libraries

# %%
# Install required packages
!pip install -q ultralytics torch torchvision matplotlib tqdm

# Import necessary libraries
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import random
import time
from PIL import Image
import requests
import io
from tqdm.notebook import tqdm
from pathlib import Path
import pandas as pd

# Check if GPU is available: use CUDA when torch reports it, otherwise fall
# back to the CPU. The selected device is printed for the notebook log.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# %% [markdown]
# ## 2. Download and Prepare Character Images

# %%
# Set up directories (all relative to the current working directory)
object_dir = "objects"          # character images
background_dir = "backgrounds"  # generated background images
dataset_dir = "dataset"         # synthetic dataset root (not created here)

# Only the object and background folders are created at this point.
os.makedirs(object_dir, exist_ok=True)
os.makedirs(background_dir, exist_ok=True)

# Download the Waldo character images
def download_character_images():
    """Download the Waldo character images and make their white background transparent.

    Every character in the URL table always ends up in the result: when the
    download or the image processing fails, the generated fallback silhouette
    is used instead, so class indices stay stable and the lists are never
    silently shortened.

    Returns:
        ``(object_images, object_names)``: parallel lists of RGBA PIL images
        and the matching character names, in URL-table order. Each image is
        also saved as ``<name>.png`` in the module-level ``object_dir``.
    """
    # Waldo character URLs
    character_urls = {
        "waldo": "https://static.wikia.nocookie.net/waldo/images/9/9d/Character.Waldo.jpg",
        "wilma": "https://static.wikia.nocookie.net/waldo/images/8/86/Character.Wilma.jpg",
        "wenda": "https://static.wikia.nocookie.net/waldo/images/3/3e/Character.Wenda.jpg"
    }

    object_images = []
    object_names = []

    for name, url in character_urls.items():
        try:
            # Download image; the timeout keeps an unresponsive host from
            # hanging the notebook indefinitely.
            response = requests.get(url, timeout=30)
            if response.status_code != 200:
                print(f"⚠️ Failed to download {name}. Creating fallback.")
                character_img = create_fallback_character(name, object_dir)
            else:
                # Create character image with transparent background
                img = Image.open(io.BytesIO(response.content)).convert("RGBA")

                # Simple background removal: near-white pixels become transparent
                data = np.array(img)
                white_areas = (
                    (data[..., 0] > 200) & (data[..., 1] > 200) & (data[..., 2] > 200)
                )
                data[white_areas, 3] = 0

                # Save image
                character_img = Image.fromarray(data)
                img_path = os.path.join(object_dir, f"{name}.png")
                character_img.save(img_path)
                print(f"✅ Downloaded {name}")

        except Exception as e:
            # Deliberately broad: any network/decoding problem falls back to a
            # generated silhouette instead of aborting the notebook.
            print(f"❌ Error processing {name}: {e}")
            character_img = create_fallback_character(name, object_dir)

        object_images.append(character_img)
        object_names.append(name)

    return object_images, object_names

def create_fallback_character(character, output_dir):
    """Draw a plain stick-figure stand-in for a character whose download failed.

    The figure is painted on a transparent 200x300 RGBA canvas, written to
    ``<output_dir>/<character>.png`` and returned. Unknown characters are
    drawn in orange; "waldo" additionally gets white stripes on the torso.
    """
    from PIL import ImageDraw

    palette = {
        "waldo": (255, 0, 0, 255),  # Red
        "wilma": (0, 0, 255, 255),  # Blue
        "wenda": (255, 105, 180, 255)  # Pink
    }
    fill = palette.get(character, (255, 165, 0, 255))

    canvas = Image.new('RGBA', (200, 300), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # Head first, then the rectangular body parts in a fixed order.
    pen.ellipse((75, 30, 125, 80), fill=fill)
    body_parts = (
        (85, 80, 115, 180),    # torso
        (50, 100, 85, 120),    # left arm
        (115, 100, 150, 120),  # right arm
        (85, 180, 95, 250),    # left leg
        (105, 180, 115, 250),  # right leg
    )
    for part in body_parts:
        pen.rectangle(part, fill=fill)

    # Waldo's trademark stripes: white bands painted over the torso.
    if character == "waldo":
        white = (255, 255, 255, 255)
        for band_top in range(80, 180, 20):
            pen.rectangle((85, band_top, 115, band_top+10), fill=white)

    canvas.save(os.path.join(output_dir, f"{character}.png"))
    print(f"🎨 Created fallback image for {character}")

    return canvas

# Visualize the characters
def visualize_objects(object_images, object_names):
    """Show the character images side by side, each titled with its name."""
    plt.figure(figsize=(15, 5))
    panel_count = len(object_images)
    for slot, (picture, label) in enumerate(zip(object_images, object_names), start=1):
        # One row, one panel per character (subplot indices are 1-based).
        plt.subplot(1, panel_count, slot)
        plt.imshow(picture)
        plt.title(label)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Download and visualize characters. `object_images` and `object_names` are
# parallel lists reused by the dataset-generation cells below.
object_images, object_names = download_character_images()
visualize_objects(object_images, object_names)

# %% [markdown]
# ## 3. Create Background Images

# %%
# Create procedural backgrounds (to avoid web crawling on Kaggle/Colab)
def create_background_images(num_images=200):
    """Generate procedural background images.

    Each background is a 640x640 white canvas covered with random lines,
    rectangles and circles, saved as ``background_NNN.jpg`` inside the
    module-level ``background_dir``.

    Parameters:
        num_images: Number of backgrounds to generate.

    Returns:
        List of the saved file paths, in creation order.
    """
    from PIL import ImageDraw  # imported once, not on every iteration

    print(f"Creating {num_images} background images...")

    # Make the function usable even if the setup cell was not run first.
    os.makedirs(background_dir, exist_ok=True)

    background_paths = []
    bg_width, bg_height = 640, 640

    def _random_color():
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    for i in range(num_images):
        # Create a procedural background with random patterns
        background = Image.new("RGB", (bg_width, bg_height), (255, 255, 255))
        draw = ImageDraw.Draw(background)

        # Add random lines (any endpoint order is valid for a line)
        for _ in range(random.randint(10, 30)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            x2 = random.randint(0, bg_width)
            y2 = random.randint(0, bg_height)
            color = _random_color()
            width = random.randint(1, 5)
            draw.line([(x1, y1), (x2, y2)], fill=color, width=width)

        # Add random rectangles
        for _ in range(random.randint(5, 15)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            x2 = random.randint(0, bg_width)
            y2 = random.randint(0, bg_height)
            color = _random_color()
            # Recent Pillow releases raise ValueError when the second corner
            # precedes the first, so order the corners before drawing.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], fill=color)

        # Add random circles (corners are ordered by construction: radius > 0)
        for _ in range(random.randint(5, 15)):
            x1 = random.randint(0, bg_width)
            y1 = random.randint(0, bg_height)
            radius = random.randint(5, 50)
            color = _random_color()
            draw.ellipse([x1-radius, y1-radius, x1+radius, y1+radius], fill=color)

        # Save the background
        bg_path = os.path.join(background_dir, f"background_{i:03d}.jpg")
        background.save(bg_path)
        background_paths.append(bg_path)

        # Show progress
        if (i + 1) % 50 == 0:
            print(f"  Created {i+1}/{num_images} backgrounds")

    print(f"✅ Created {num_images} background images")
    return background_paths

# Visualize backgrounds
def visualize_backgrounds(background_paths, num_samples=8):
    """Show a random sample of the generated backgrounds in a 4-column grid.

    Parameters:
        background_paths: Paths of the available background images.
        num_samples: Maximum number of backgrounds to display.

    The grid keeps the original 2x4 layout for up to eight samples and grows
    by whole rows beyond that, so a larger ``num_samples`` no longer requests
    a subplot index outside the grid.
    """
    samples = random.sample(background_paths, min(num_samples, len(background_paths)))

    n_cols = 4
    n_rows = max(2, (len(samples) + n_cols - 1) // n_cols)  # ceil division
    plt.figure(figsize=(15, 4 * n_rows))

    for i, path in enumerate(samples):
        plt.subplot(n_rows, n_cols, i+1)
        # Close the file handle as soon as the pixels have been handed over.
        with Image.open(path) as img:
            plt.imshow(img)
        plt.title(f"Background {i+1}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Generate and visualize backgrounds. The returned path list is reused when
# the synthetic dataset is composed.
background_paths = create_background_images(200)
visualize_backgrounds(background_paths)

# %% [markdown]
# ## 4. Create Synthetic Dataset

# %%
def create_synthetic_dataset(background_paths, object_images, object_names, 
                            output_dir, split, img_size=(640, 640), num_images=500):
    """
    Create a synthetic dataset by placing objects on backgrounds

    Every generated image contains exactly one object; its label file holds a
    single YOLO line ``class x_center y_center width height`` (normalized).

    Parameters:
        background_paths: List of paths to background images
        object_images: List of object images with transparency (RGBA); the
            list index is used as the class id
        object_names: List of object class names (currently unused here; kept
            for interface compatibility)
        output_dir: Root directory to save dataset
        split: Dataset split ('train', 'val', or 'test')
        img_size: Size of output images (width, height)
        num_images: Number of images to generate

    Returns:
        ``(images_dir, labels_dir)`` for the generated split.

    Raises:
        ValueError: If images are requested but ``object_images`` is empty.
    """
    if num_images > 0 and not object_images:
        raise ValueError("object_images is empty: at least one object image is required")

    # Create directory structure (local name avoids shadowing the global dataset_dir)
    split_dir = os.path.join(output_dir, split)
    images_dir = os.path.join(split_dir, "images")
    labels_dir = os.path.join(split_dir, "labels")

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # Newer Pillow exposes the filter as Image.Resampling.LANCZOS, older
    # releases as Image.LANCZOS; fall back to the default filter otherwise.
    lanczos = getattr(getattr(Image, "Resampling", Image), "LANCZOS", None)

    canvas_width, canvas_height = img_size

    print(f"🎯 Creating {num_images} synthetic images for {split} set...")

    for i in range(num_images):
        # Select random background
        bg_path = random.choice(background_paths)
        try:
            background = Image.open(bg_path).convert("RGB").resize(img_size)
        except Exception as e:
            print(f"⚠️ Error loading background {bg_path}: {e}")
            # Create a simple background as fallback
            background = Image.new("RGB", img_size, (200, 200, 200))

        # Select random object
        obj_idx = random.randint(0, len(object_images) - 1)
        obj_image = object_images[obj_idx].copy()

        # Resize object to random size
        scale_factor = random.uniform(0.1, 0.3)  # Object will be 10-30% of image size
        obj_width = max(1, int(canvas_width * scale_factor))
        obj_height = max(1, int(obj_width * (obj_image.height / obj_image.width)))  # Maintain aspect ratio

        # A very tall object could exceed the canvas, which would make the
        # placement range below empty; shrink it proportionally to fit.
        if obj_height > canvas_height:
            obj_width = max(1, int(obj_width * canvas_height / obj_height))
            obj_height = canvas_height

        if lanczos is not None:
            obj_image = obj_image.resize((obj_width, obj_height), lanczos)
        else:
            obj_image = obj_image.resize((obj_width, obj_height))

        # Place object at random position (fully inside the canvas)
        max_x = canvas_width - obj_width
        max_y = canvas_height - obj_height
        x_pos = random.randint(0, max_x)
        y_pos = random.randint(0, max_y)

        # Paste object on background, using its own alpha channel as the mask
        background.paste(obj_image, (x_pos, y_pos), obj_image)

        # Calculate YOLO format bounding box
        x_center = (x_pos + obj_width / 2) / canvas_width
        y_center = (y_pos + obj_height / 2) / canvas_height
        width = obj_width / canvas_width
        height = obj_height / canvas_height

        # Save image with proper padding in filename
        img_filename = f"{i:05d}.jpg"
        background.save(os.path.join(images_dir, img_filename))

        # Save label in YOLO format
        label_filename = f"{i:05d}.txt"
        with open(os.path.join(labels_dir, label_filename), "w") as f:
            f.write(f"{obj_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

        # Show progress
        if (i + 1) % 100 == 0 or i == num_images - 1:
            print(f"  Progress: {i+1}/{num_images} images created")

    print(f"✅ Created {split} dataset with {num_images} images")
    return images_dir, labels_dir

# Create all datasets
def create_all_datasets(background_paths, object_images, object_names, output_dir="dataset"):
    """Generate the train, val and test splits and report where they live.

    Returns a dict mapping each split name to the ``(images_dir, labels_dir)``
    pair produced for it.
    """
    # Split sizes are kept small for quicker execution in Colab/Kaggle.
    split_sizes = (("train", 1000), ("val", 200), ("test", 100))

    generated = {}
    for split_name, image_count in split_sizes:
        images_dir, labels_dir = create_synthetic_dataset(
            background_paths, object_images, object_names,
            output_dir, split_name, num_images=image_count
        )
        generated[split_name] = (images_dir, labels_dir)

    return generated

# Create the datasets: ensure the dataset root exists, then generate all
# splits. `dataset_paths` maps each split name to its (images, labels) folders.
os.makedirs(dataset_dir, exist_ok=True)
dataset_paths = create_all_datasets(background_paths, object_images, object_names, dataset_dir)

# %% [markdown]
# ## 5. Define Dataset and Create DataLoaders

# %%
# Define the PyTorch Dataset class
class ObjectDetectionDataset(Dataset):
    """Single-object detection dataset backed by YOLO-format label files.

    Images are read from ``<root_dir>/<split>/images`` and labels from
    ``<root_dir>/<split>/labels``; a label shares its image's base name with a
    ``.txt`` extension. Only the first line of each label file is used.
    """

    # Fallback target used when a label file is missing or malformed
    DEFAULT_CLASS_ID = 0
    DEFAULT_BBOX = (0.5, 0.5, 0.2, 0.2)  # [x_center, y_center, width, height]

    def __init__(self, root_dir, split, num_classes, transform=None):
        """
        Dataset for object detection

        Parameters:
            root_dir: Root directory of the dataset
            split: 'train', 'val', or 'test'
            num_classes: Number of object classes
            transform: PyTorch transformations to apply
        """
        self.root_dir = root_dir
        self.split = split
        self.num_classes = num_classes
        self.transform = transform

        # Get the paths
        self.images_dir = os.path.join(root_dir, split, "images")
        self.labels_dir = os.path.join(root_dir, split, "labels")

        # Sorted so that sample indices are stable across runs
        self.image_files = sorted(
            f for f in os.listdir(self.images_dir)
            if f.endswith((".jpg", ".jpeg", ".png"))
        )

    def __len__(self):
        return len(self.image_files)

    def _load_label(self, label_path, image_name):
        """Parse one label file into ``(class_id, bbox_tensor)``.

        Falls back to the class-level defaults (and prints a warning) when the
        file cannot be read or its first line is not
        ``class x_center y_center width height``. The line is parsed completely
        before anything is returned, so a truncated label can never yield a
        real class id paired with a wrongly sized box tensor.
        """
        try:
            with open(label_path, "r") as f:
                fields = f.readline().split()
            if len(fields) < 5:
                raise ValueError(f"expected 5 fields, found {len(fields)}")
            class_id = int(float(fields[0]))
            coords = [float(value) for value in fields[1:5]]
        except (OSError, ValueError) as e:
            print(f"⚠️ Error loading label for {image_name}: {e}")
            return self.DEFAULT_CLASS_ID, torch.tensor(self.DEFAULT_BBOX)
        return class_id, torch.tensor(coords)

    def __getitem__(self, idx):
        """Return ``(image, bbox, class_id)`` for the sample at ``idx``."""
        image_name = self.image_files[idx]

        # Load image
        image = Image.open(os.path.join(self.images_dir, image_name)).convert("RGB")

        # Get corresponding label (same base name, .txt extension)
        label_path = os.path.join(self.labels_dir,
                                  os.path.splitext(image_name)[0] + ".txt")
        class_id, bbox = self._load_label(label_path, image_name)

        # Apply transformations (image only; the box is returned as stored)
        if self.transform:
            image = self.transform(image)

        return image, bbox, class_id

# Define transformations
# NOTE: no RandomHorizontalFlip here. These transforms only see the image, so a
# flip would mirror the picture while the bounding box target stayed unflipped,
# leaving roughly half of the training boxes pointing at the wrong place.
# Only photometric augmentation (which does not move the object) is applied.
train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation and test use the same deterministic preprocessing (no augmentation)
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Build one dataset per split, pairing each split with its own transform
train_dataset, val_dataset, test_dataset = (
    ObjectDetectionDataset(dataset_dir, split_name, len(object_names), split_transform)
    for split_name, split_transform in (
        ("train", train_transform),
        ("val", val_transform),
        ("test", test_transform),
    )
)

# Wrap the datasets in DataLoaders; only the training loader shuffles
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    for eval_dataset in (val_dataset, test_dataset)
)

# Summarize how many images and batches each split provides
print("✅ DataLoaders created")
print(
    *(
        f"  • {title}: {len(split_dataset)} images ({len(split_loader)} batches)"
        for title, split_dataset, split_loader in (
            ("Train", train_dataset, train_loader),
            ("Val", val_dataset, val_loader),
            ("Test", test_dataset, test_loader),
        )
    ),
    sep="\n",
)

# %% [markdown]
# ## 6. Visualize Training Samples

# %%
# Visualize dataset samples
def visualize_dataset_sample(dataset, num_samples=4):
    """Show the first few dataset samples side by side with their boxes.

    Each sample is denormalized, drawn in its own subplot, titled with its
    class name, and overlaid with a red rectangle for the bounding box.
    Prints a message and returns early when the dataset is empty.
    """
    if len(dataset) == 0:
        print("❌ No images in dataset to visualize")
        return

    # Same per-channel mean/std as the Normalize step of the transforms
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    plt.figure(figsize=(15, 5))
    shown = min(num_samples, len(dataset))
    for slot in range(shown):
        tensor_img, box, label = dataset[slot]

        # CHW tensor -> HWC array, then undo the normalization for display
        pixels = tensor_img.permute(1, 2, 0).numpy()
        pixels = np.clip(pixels * channel_std + channel_mean, 0, 1)
        img_h, img_w = pixels.shape[0], pixels.shape[1]

        # Normalized center/size box -> pixel corner coordinates
        cx, cy, box_w, box_h = box.numpy()
        left = (cx - box_w/2) * img_w
        top = (cy - box_h/2) * img_h
        right = (cx + box_w/2) * img_w
        bottom = (cy + box_h/2) * img_h

        plt.subplot(1, num_samples, slot + 1)
        plt.imshow(pixels)
        plt.title(f"Class: {object_names[label]}")

        outline = plt.Rectangle((left, top), right - left, bottom - top,
                                fill=False, edgecolor='red', linewidth=2)
        plt.gca().add_patch(outline)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize the first few samples of the training dataset
# (only the training split is shown here)
print("🖼️ Visualizing training samples:")
visualize_dataset_sample(train_dataset)

# %% [markdown]
# ## 7. Build Custom Object Detection Model

# %%
# Define the custom object detection model with a pre-trained backbone
class CustomObjectDetectionModel(nn.Module):
    """ResNet18-based single-object detector with separate class and box heads.

    The ResNet18 convolutional trunk (everything except its final avgpool and
    fc layers) produces a 512-channel feature map. A 1x1 convolution reduces it
    to 256 channels, an adaptive max pool fixes the spatial size at 5x5, and
    the flattened features feed two MLP heads:

    * classification head -> ``num_classes`` raw logits
    * regression head -> 4 sigmoid-bounded values
      ``[x_center, y_center, width, height]``
    """

    def __init__(self, num_classes=3, pretrained=True):
        """
        Parameters:
            num_classes: Number of output logits of the classification head
            pretrained: Whether to load ImageNet weights into the backbone
        """
        super().__init__()

        # torchvision >= 0.13 replaced ``pretrained=`` with ``weights=``;
        # use the new API when available and fall back for older installs.
        if hasattr(models, "ResNet18_Weights"):
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet18(weights=weights)
        else:
            resnet = models.resnet18(pretrained=pretrained)

        # Remove the final fully connected layer and avgpool
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # 1x1 convolution to reduce channels (512 -> 256)
        self.conv1x1 = nn.Conv2d(512, 256, kernel_size=1)

        # Pool to a fixed 5x5 grid so the heads see a constant feature size
        # regardless of the backbone's output resolution
        self.spp = nn.Sequential(
            nn.AdaptiveMaxPool2d(5),
            nn.Flatten()
        )

        # Feature size after pooling and flattening
        feature_size = 5 * 5 * 256

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Add dropout for regularization
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )

        # Bounding box regression head
        self.regression_head = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Add dropout for regularization
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 4),  # [x_center, y_center, width, height]
            nn.Sigmoid()  # Bound outputs between 0 and 1 for normalized coordinates
        )

        # Initialize weights of the newly added layers only
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the layers added on top of the backbone.

        The backbone is deliberately left untouched: iterating over
        ``self.modules()`` would also re-initialize every backbone layer and
        throw away the pretrained weights that were just loaded.
        """
        for block in (self.conv1x1, self.classification_head, self.regression_head):
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(class_logits, bbox_coords)`` for a batch of images ``x``."""
        # Extract features from the backbone
        features = self.backbone(x)

        # Apply 1x1 convolution to reduce channels
        features = self.conv1x1(features)

        # Pool to 5x5 and flatten
        features_flat = self.spp(features)

        # Process features through the classification and regression heads
        class_logits = self.classification_head(features_flat)
        bbox_coords = self.regression_head(features_flat)

        return class_logits, bbox_coords

# Create the model
# One classification logit per entry in object_names; the model is moved to
# the device selected at the top of the notebook
num_classes = len(object_names)
model = CustomObjectDetectionModel(num_classes=num_classes, pretrained=True).to(device)

# Print model info
def count_parameters(model):
    """Return the total number of elements across parameters with requires_grad=True."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the model size; only parameters with requires_grad=True are counted
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")
print(f"Model created with {len(object_names)} classes")

# %% [markdown]
# ## 8. Define Loss Function and Optimizer

# %%
# Define the loss functions
classification_loss_fn = nn.CrossEntropyLoss()  # For class probabilities
regression_loss_fn = nn.SmoothL1Loss()  # Better choice for bounding box regression than MSE

# Define the optimizer
learning_rate = 0.001
weight_decay = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Learning rate scheduler to reduce LR (x0.1) when the monitored loss plateaus.
# ``verbose`` is intentionally not passed: it has been deprecated since
# PyTorch 2.2 and only triggers a warning (or an error once removed).
# NOTE(review): patience=5 is larger than the early_stopping_patience=3 used
# for training below, so early stopping will normally end the run before an
# LR reduction can fire; confirm this is intended.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5
)

# Define a custom combined loss function
def object_detection_loss(class_pred, bbox_pred, class_target, bbox_target):
    """
    Combine the classification and box-regression losses with equal weight.

    Args:
        class_pred: predicted class scores [batch_size, num_classes]
        bbox_pred: predicted bounding boxes [batch_size, 4]
        class_target: ground truth class indices [batch_size]
        bbox_target: ground truth bounding boxes [batch_size, 4]

    Returns:
        A 3-tuple ``(total_loss, cls_loss, reg_loss)`` where ``total_loss`` is
        the unweighted sum of the two components.
    """
    # Each component comes from the module-level loss objects
    cls_loss = classification_loss_fn(class_pred, class_target)
    reg_loss = regression_loss_fn(bbox_pred, bbox_target)

    # Balanced weighting: plain sum of both terms
    return cls_loss + reg_loss, cls_loss, reg_loss

# %% [markdown]
# ## 9. Train the Model

# %%
# Function to train the object detection model
def _run_epoch(model, loader, loss_fn, device, optimizer=None, desc="Training"):
    """Run one pass over ``loader`` and return mean (total, cls, reg) batch losses.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch; without one the pass runs in eval mode with gradients disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)

    total_loss = 0.0
    total_cls_loss = 0.0
    total_reg_loss = 0.0

    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(is_training):
        for images, bboxes, class_ids in progress:
            # Move data to device
            images = images.to(device)
            bboxes = bboxes.to(device)
            class_ids = class_ids.to(device)

            if is_training:
                optimizer.zero_grad()

            # Forward pass and loss
            class_pred, bbox_pred = model(images)
            loss, cls_loss, reg_loss = loss_fn(class_pred, bbox_pred, class_ids, bboxes)

            if is_training:
                loss.backward()
                optimizer.step()

            # Update running losses
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_reg_loss += reg_loss.item()

            progress.set_postfix(loss=f"{loss.item():.4f}")

    # Averages are per batch (the last, possibly smaller, batch counts equally)
    num_batches = len(loader)
    return (total_loss / num_batches,
            total_cls_loss / num_batches,
            total_reg_loss / num_batches)


def train_model(model, train_loader, val_loader, loss_fn, optimizer, scheduler, 
                num_epochs=10, early_stopping_patience=5, device=device):
    """
    Train the custom object detection model
    
    Args:
        model: The model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        loss_fn: Combined loss function returning (total, cls, reg) losses
        optimizer: Optimizer for parameter updates
        scheduler: Learning rate scheduler, stepped with the validation loss
        num_epochs: Maximum number of epochs to train
        early_stopping_patience: Number of epochs to wait before early stopping
        device: Device to train on (cuda/cpu)
        
    Returns:
        model: Trained model, with the best checkpoint of this run loaded
            (left as-is if no epoch of this run produced a checkpoint)
        history: Training history (losses, metrics)
    """
    # Initialize history dictionary to track metrics
    history = {
        'train_loss': [], 'val_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_reg_loss': [], 'val_reg_loss': []
    }
    
    # Variables for early stopping and best model tracking
    best_val_loss = float('inf')
    early_stopping_counter = 0
    best_model_path = 'best_model.pth'
    # Tracks whether THIS run wrote best_model_path, so that a stale file from
    # an earlier run is never loaded (e.g. when the loss is NaN every epoch)
    checkpoint_saved = False
    
    print(f"🏋️ Starting training for {num_epochs} epochs...")
    
    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Training phase
        avg_train_loss, avg_train_cls_loss, avg_train_reg_loss = _run_epoch(
            model, train_loader, loss_fn, device, optimizer=optimizer, desc="Training"
        )
        history['train_loss'].append(avg_train_loss)
        history['train_cls_loss'].append(avg_train_cls_loss)
        history['train_reg_loss'].append(avg_train_reg_loss)
        
        # Validation phase (no optimizer -> eval mode, no gradients)
        avg_val_loss, avg_val_cls_loss, avg_val_reg_loss = _run_epoch(
            model, val_loader, loss_fn, device, desc="Validation"
        )
        history['val_loss'].append(avg_val_loss)
        history['val_cls_loss'].append(avg_val_cls_loss)
        history['val_reg_loss'].append(avg_val_reg_loss)
        
        # Update learning rate scheduler
        scheduler.step(avg_val_loss)
        
        # Print epoch summary
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f} (Cls: {avg_train_cls_loss:.4f}, Reg: {avg_train_reg_loss:.4f})")
        print(f"  Val Loss: {avg_val_loss:.4f} (Cls: {avg_val_cls_loss:.4f}, Reg: {avg_val_reg_loss:.4f})")
        
        # Check if this is the best model so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stopping_counter = 0
            
            # Save the best model
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'history': history
            }, best_model_path)
            checkpoint_saved = True
            
            print(f"  ✅ Model improved! Saved checkpoint to {best_model_path}")
        else:
            early_stopping_counter += 1
            print(f"  ⚠️ Model did not improve. Early stopping counter: {early_stopping_counter}/{early_stopping_patience}")
            
            # Check if we should stop early
            if early_stopping_counter >= early_stopping_patience:
                print(f"  🛑 Early stopping triggered after {epoch+1} epochs")
                break
    
    if not checkpoint_saved:
        # Nothing was saved in this run (no epochs, or the validation loss never
        # compared below infinity); keep the current weights instead of loading
        # whatever file may be lying around from a previous run.
        print("\n⚠️ Training finished without saving a checkpoint; returning the model as-is")
        return model, history
    
    # Load the best model; map_location keeps this working when the checkpoint
    # was written on a different device than the one currently in use
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n✅ Training complete! Best model from epoch {checkpoint['epoch']} loaded (Val Loss: {checkpoint['val_loss']:.4f})")
    
    return model, history

# Set training parameters - reduced for Colab/Kaggle
num_epochs = 10  # Use 20-30 if running on more powerful hardware
# Stop once the validation loss has failed to improve this many epochs in a row
early_stopping_patience = 3

# Start training
# `model` is rebound to the returned model; `history` holds the per-epoch
# average losses keyed by 'train_*' / 'val_*'
model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=object_detection_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    early_stopping_patience=early_stopping_patience
)

# %% [markdown]
# ## 10. Visualize Training Metrics

# %%
# Plot the training history
def plot_training_metrics(history):
    """Plot training and validation metrics with analysis.

    Draws the total loss (top) and the classification/regression components
    (bottom), marks the epoch with the lowest validation loss, and prints a
    short convergence / overfitting report.

    Args:
        history: dict with the list-valued keys 'train_loss', 'val_loss',
            'train_cls_loss', 'val_cls_loss', 'train_reg_loss', 'val_reg_loss'
            (one entry per completed epoch).

    Prints a message and returns early when no epoch has been recorded.
    """
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    if len(train_losses) == 0 or len(val_losses) == 0:
        print("❌ No training history to plot")
        return

    # Best epoch (0-based index) is computed once and shared by plot and report
    best_epoch = int(np.argmin(val_losses))
    best_val_loss = val_losses[best_epoch]
    last_epoch = len(val_losses) - 1

    # Create a figure with 2 subplots
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 10), sharex=True)
    
    # Plot total loss (main metric)
    axes[0].plot(train_losses, label="Training Loss", color="blue", marker="o")
    axes[0].plot(val_losses, label="Validation Loss", color="orange", marker="o")
    axes[0].set_ylabel("Total Loss")
    axes[0].set_title("Training and Validation Loss Over Time")
    
    # Add annotations for best model
    axes[0].axvline(x=best_epoch, color='r', linestyle='--', alpha=0.5)
    axes[0].scatter(best_epoch, best_val_loss, s=100, c='red', label=f'Best Model (Epoch {best_epoch+1})')
    
    # The legend must be built after the scatter so "Best Model" is listed
    axes[0].legend()
    axes[0].grid(True)
    
    # Plot component losses (classification and regression)
    axes[1].plot(history['train_cls_loss'], label="Train Classification Loss", color="blue", linestyle="-")
    axes[1].plot(history['val_cls_loss'], label="Val Classification Loss", color="blue", linestyle="--")
    axes[1].plot(history['train_reg_loss'], label="Train Regression Loss", color="green", linestyle="-")
    axes[1].plot(history['val_reg_loss'], label="Val Regression Loss", color="green", linestyle="--")
    axes[1].set_xlabel("Epochs")
    axes[1].set_ylabel("Component Losses")
    axes[1].set_title("Classification and Regression Loss Components")
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Analyze convergence and provide text report
    print("📊 Model Convergence Analysis:")
    
    # Calculate training and validation loss reduction (in percent)
    initial_train_loss = train_losses[0]
    final_train_loss = train_losses[last_epoch]
    train_reduction = ((initial_train_loss - final_train_loss) / initial_train_loss) * 100
    
    initial_val_loss = val_losses[0]
    final_val_loss = val_losses[last_epoch]
    val_reduction = ((initial_val_loss - best_val_loss) / initial_val_loss) * 100
    
    # Check if loss is still decreasing at the end of training
    if best_epoch == last_epoch:
        print(f"  • The model was STILL IMPROVING when training stopped at epoch {last_epoch+1}")
        print(f"  • Consider training for more epochs to potentially achieve better performance")
    elif best_epoch < last_epoch - 2:
        print(f"  • The model CONVERGED around epoch {best_epoch+1} (best validation loss)")
        print(f"  • Early stopping prevented overfitting by loading the best model")
    else:
        print(f"  • The model appears to have CONVERGED near the end of training (best at epoch {best_epoch+1})")
    
    print(f"\n  • Training loss reduced by {train_reduction:.2f}% (from {initial_train_loss:.4f} to {final_train_loss:.4f})")
    print(f"  • Validation loss reduced by {val_reduction:.2f}% (from {initial_val_loss:.4f} to {best_val_loss:.4f})")
    
    # Check for overfitting: final training loss below 70% of final validation loss
    if final_train_loss < final_val_loss * 0.7:
        print("\n  ⚠️ OVERFITTING DETECTED: The training loss is much lower than validation loss")
    else:
        print("\n  ✅ HEALTHY CONVERGENCE: Training and validation losses decreased together")
        print("  • The model appears to generalize well to unseen data")

# Plot training metrics and print the convergence report for the run above
plot_training_metrics(history)

# %% [markdown]
# ## 11. Evaluate the Custom Model

# %%
# Function to calculate Intersection over Union (IoU)
def calculate_iou(pred_box, true_box):
    """Intersection over Union of two boxes given as [x_center, y_center, width, height].

    Returns 0.0 when the union area is zero.
    """
    def to_corners(box):
        # Center/size format -> (x1, y1, x2, y2) corner format
        return (
            box[0] - box[2] / 2,
            box[1] - box[3] / 2,
            box[0] + box[2] / 2,
            box[1] + box[3] / 2,
        )

    px1, py1, px2, py2 = to_corners(pred_box)
    tx1, ty1, tx2, ty2 = to_corners(true_box)

    # Overlap rectangle; a negative extent means the boxes do not intersect
    overlap_w = max(0, min(px2, tx2) - max(px1, tx1))
    overlap_h = max(0, min(py2, ty2) - max(py1, ty1))
    inter_area = overlap_w * overlap_h

    # Union = sum of both areas minus the part counted twice
    union_area = (px2 - px1) * (py2 - py1) + (tx2 - tx1) * (ty2 - ty1) - inter_area

    # Avoid division by zero
    if union_area == 0:
        return 0.0
    return inter_area / union_area

# Function to evaluate the model on the test set
def evaluate_model(model, test_loader, device, iou_threshold=0.5):
    """
    Evaluate the model on test data with multiple metrics

    Args:
        model: Model returning (class_logits, bbox_coords) for a batch of images
        test_loader: DataLoader yielding (images, true_boxes, true_classes)
        device: Device to run inference on (torch.device or device string)
        iou_threshold: A prediction counts as a correct detection only when
            its IoU is above this value AND the predicted class is correct

    Returns:
        mean_metrics: dict with the keys 'precision', 'recall', 'f1', 'iou',
            'class_accuracy' and 'inference_time' (seconds per image)
        class_metrics: dict mapping each class name to its 'correct', 'total'
            and 'accuracy' values
    """
    model.eval()  # Set model to evaluation mode
    
    # Initialize metrics
    metrics = {
        'precision': [], 'recall': [], 'f1': [], 'iou': [],
        'class_accuracy': [], 'inference_times': []
    }
    
    # Class-specific metrics
    class_metrics = {class_name: {'correct': 0, 'total': 0} 
                    for class_name in object_names}
    
    # CUDA kernels run asynchronously, so the GPU must be synchronized around
    # the forward pass or the timer only measures the kernel launch
    sync_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    
    with torch.no_grad():
        for images, true_boxes, true_classes in test_loader:
            images = images.to(device)
            
            # Measure inference time
            if sync_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            class_logits, pred_boxes = model(images)
            if sync_cuda:
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            
            # Calculate inference time per image
            batch_inference_time = (end_time - start_time) / len(images)
            metrics['inference_times'].append(batch_inference_time)
            
            # Get predicted classes
            _, pred_classes = torch.max(class_logits, 1)
            
            # Bring the whole batch to the host once instead of once per sample;
            # the targets are only compared on the CPU, so they never need to
            # visit the device
            pred_boxes_np = pred_boxes.cpu().numpy()
            true_boxes_np = true_boxes.cpu().numpy()
            pred_class_list = pred_classes.cpu().tolist()
            true_class_list = true_classes.cpu().tolist()
            
            # Calculate metrics for each image in the batch
            for pred_box, true_box, pred_class, true_class in zip(
                    pred_boxes_np, true_boxes_np, pred_class_list, true_class_list):
                # Calculate IoU
                iou = calculate_iou(pred_box, true_box)
                metrics['iou'].append(iou)
                
                # Class prediction accuracy
                class_correct = (pred_class == true_class)
                metrics['class_accuracy'].append(float(class_correct))
                
                # Update class-specific metrics
                class_name = object_names[true_class]
                class_metrics[class_name]['total'] += 1
                if class_correct:
                    class_metrics[class_name]['correct'] += 1
                
                # Detection is correct if IoU > threshold AND class is correct.
                # With exactly one prediction and one target per image,
                # precision, recall and F1 are all 1.0 for a hit and 0.0 for a miss.
                correct_detection = (iou > iou_threshold) and class_correct
                score = 1.0 if correct_detection else 0.0
                
                metrics['precision'].append(score)
                metrics['recall'].append(score)
                metrics['f1'].append(score)
    
    # Calculate mean metrics
    mean_metrics = {
        'precision': np.mean(metrics['precision']),
        'recall': np.mean(metrics['recall']),
        'f1': np.mean(metrics['f1']),
        'iou': np.mean(metrics['iou']),
        'class_accuracy': np.mean(metrics['class_accuracy']),
        'inference_time': np.mean(metrics['inference_times'])
    }
    
    # Calculate class-specific accuracy
    for class_name in class_metrics:
        total = class_metrics[class_name]['total']
        if total > 0:
            class_metrics[class_name]['accuracy'] = class_metrics[class_name]['correct'] / total
        else:
            class_metrics[class_name]['accuracy'] = 0.0
    
    # Calculate model size (parameter storage only)
    model_size_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    model_size_mb = model_size_bytes / (1024 * 1024)
    
    # Print metrics summary
    print("\n📊 Model Evaluation Metrics:")
    print(f"  • Mean Precision: {mean_metrics['precision']:.4f}")
    print(f"  • Mean Recall: {mean_metrics['recall']:.4f}")
    print(f"  • Mean F1-Score: {mean_metrics['f1']:.4f}")
    print(f"  • Mean IoU: {mean_metrics['iou']:.4f}")
    print(f"  • Class Prediction Accuracy: {mean_metrics['class_accuracy']:.4f}")
    print(f"  • Average Inference Time: {mean_metrics['inference_time']*1000:.2f} ms per image")
    print(f"  • Model Size: {model_size_mb:.2f} MB")
    
    # Print class-specific metrics
    print("\n📊 Class-Specific Metrics:")
    for class_name in class_metrics:
        accuracy = class_metrics[class_name]['accuracy']
        total = class_metrics[class_name]['total']
        print(f"  • {class_name}: Accuracy = {accuracy:.4f} (from {total} samples)")
    
    return mean_metrics, class_metrics

# Run evaluation on the test set (uses the default IoU threshold of 0.5)
mean_metrics, class_metrics = evaluate_model(model, test_loader, device)

# %% [markdown]
# ## 12. Visualize Custom Model Predictions

# %%
# Visualize predictions on test images
def visualize_predictions(model, test_loader, device, num_images=8):
    """
    Visualize model predictions vs ground truth with detailed metrics

    Takes the first batch of ``test_loader``, runs the model on it and draws,
    for up to ``num_images`` samples, the ground-truth box (green) and the
    predicted box (red) together with both class names and the IoU.

    Args:
        model: Model returning (class_logits, bbox_coords)
        test_loader: DataLoader yielding (images, true_boxes, true_classes)
        device: Device to run inference on
        num_images: Maximum number of samples to draw (capped at the batch size)
    """
    model.eval()  # Set model to evaluation mode
    
    # Get a batch from the test loader
    data_iter = iter(test_loader)
    images, true_boxes, true_classes = next(data_iter)
    
    # Ensure we don't try to visualize more images than we have
    num_images = min(num_images, len(images))
    
    # Make predictions
    images = images.to(device)
    with torch.no_grad():
        class_logits, pred_boxes = model(images)
        _, pred_classes = torch.max(class_logits, 1)
    
    # Size the grid from num_images so requesting more than 8 samples does not
    # index past the axes array; at least 2 rows keeps the usual 2x4 layout
    num_cols = 4
    num_rows = max(2, (num_images + num_cols - 1) // num_cols)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5 * num_rows))
    axes = axes.flatten()
    
    for i in range(num_images):
        # Get image and convert for display (CHW -> HWC, undo normalization)
        img = images[i].cpu()
        img = img.permute(1, 2, 0).numpy()
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        
        # Get ground truth box and class
        true_box = true_boxes[i].cpu().numpy()
        true_class = true_classes[i].item()
        true_class_name = object_names[true_class]
        
        # Get predicted box and class
        pred_box = pred_boxes[i].cpu().numpy()
        pred_class = pred_classes[i].item()
        pred_class_name = object_names[pred_class]
        
        # Calculate IoU
        iou = calculate_iou(pred_box, true_box)
        
        # Plot the image
        axes[i].imshow(img)
        
        # Draw ground truth box (green) and predicted box (red)
        for box, color, label in ((true_box, 'green', 'True'), (pred_box, 'red', 'Pred')):
            # Normalized center/size box -> pixel corner coordinates
            x_center, y_center, width, height = box
            x_min = (x_center - width/2) * img.shape[1]
            y_min = (y_center - height/2) * img.shape[0]
            x_max = (x_center + width/2) * img.shape[1]
            y_max = (y_center + height/2) * img.shape[0]
            
            rect = patches.Rectangle(
                (x_min, y_min), x_max-x_min, y_max-y_min,
                linewidth=2, edgecolor=color, facecolor='none', label=label
            )
            axes[i].add_patch(rect)
        
        # Add detailed title with metrics
        axes[i].set_title(
            f"True: {true_class_name}, Pred: {pred_class_name}\nIoU: {iou:.2f}",
            fontsize=10
        )
        axes[i].axis('off')
        axes[i].legend()
    
    # Hide any unused subplots
    for i in range(num_images, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.suptitle("Model Predictions vs Ground Truth", fontsize=16, y=1.02)
    plt.show()

# Visualize predictions for the first batch of the test loader
# (up to the default of 8 images)
visualize_predictions(model, test_loader, device)

# Visualize predictions on test images with detailed metrics
def visualize_predictions_with_metrics(model, test_loader, iou_threshold=0.5):
    """
    Show the first test batch with true/predicted boxes and per-image metrics.

    Up to eight images are laid out on a 2x4 grid. Each subplot gets the
    ground-truth box in green, the predicted box in red, and a title with the
    IoU plus a per-image precision/recall that is 1.0 only when the predicted
    class matches and the IoU is above ``iou_threshold`` (0.0 otherwise).

    Relies on the module-level ``device``, ``object_names`` and
    ``calculate_iou``.
    """
    def build_rect(box, img_h, img_w, color, legend_text):
        # Box is (x_center, y_center, width, height); multiplying by the image
        # size gives the pixel corners that matplotlib needs
        cx, cy, bw, bh = box
        left = (cx - bw/2) * img_w
        top = (cy - bh/2) * img_h
        right = (cx + bw/2) * img_w
        bottom = (cy + bh/2) * img_h
        return plt.Rectangle(
            (left, top), right-left, bottom-top,
            linewidth=2, edgecolor=color, facecolor="none",
            label=legend_text
        )

    # Put the model in evaluation mode before predicting
    model.eval()

    # Only the first batch of the loader is visualized
    batch_images, batch_boxes, batch_classes = next(iter(test_loader))
    batch_images = batch_images.to(device)
    batch_boxes = batch_boxes.to(device)
    batch_classes = batch_classes.to(device)

    # Forward pass without gradient tracking; keep the arg-max class per image
    with torch.no_grad():
        class_logits, box_preds = model(batch_images)
        class_preds = torch.max(class_logits, 1)[1]

    # Fixed 2x4 grid, flattened so subplots can be addressed by position
    shown = min(8, len(batch_images))
    _, grid = plt.subplots(2, 4, figsize=(15, 10))
    grid = grid.flatten()

    # Per-channel std / mean used to undo the input normalization
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    for idx in range(shown):
        ax = grid[idx]

        # CHW tensor -> HWC array, denormalized and clipped to [0, 1]
        picture = batch_images[idx].cpu().permute(1, 2, 0).numpy()
        picture = np.clip(picture * channel_std + channel_mean, 0, 1)
        ax.imshow(picture)
        img_h, img_w = picture.shape[0], picture.shape[1]

        # Pull ground truth and prediction for this image off the device
        gt_box = batch_boxes[idx].cpu().numpy()
        predicted_box = box_preds[idx].cpu().numpy()
        gt_label = batch_classes[idx].item()
        predicted_label = class_preds[idx].item()

        gt_name = object_names[gt_label]
        predicted_name = object_names[predicted_label]

        # Ground truth in green, prediction in red
        ax.add_patch(
            build_rect(gt_box, img_h, img_w, "green", f"True: {gt_name}")
        )
        ax.add_patch(
            build_rect(predicted_box, img_h, img_w, "red", f"Pred: {predicted_name}")
        )

        # A detection counts only if the class matches and the IoU clears the
        # threshold; precision and recall share that single 0/1 outcome
        iou = calculate_iou(predicted_box, gt_box)
        labels_match = (predicted_label == gt_label)
        if (iou > iou_threshold) and labels_match:
            precision = recall = 1.0
        else:
            precision = recall = 0.0

        ax.set_title(
            f"IoU: {iou:.2f} | Precision: {precision:.1f} | Recall: {recall:.1f}\n"
            f"True: {gt_name} | Pred: {predicted_name}",
            fontsize=9
        )
        ax.axis("off")
        ax.legend(loc="upper right", fontsize=8)

    # Blank out grid cells that received no image
    for ax in grid[shown:]:
        ax.axis("off")

    plt.tight_layout()
    plt.suptitle("Model Predictions vs Ground Truth with Metrics", fontsize=14, y=1.02)
    plt.show()

# Run the detailed visualization on the first batch of the test loader,
# leaving iou_threshold at its default of 0.5
visualize_predictions_with_metrics(model, test_loader)

# %% [markdown]
# ## 13. Fine-tune YOLOv8 on the Same Dataset

# %%
# Prepare the data.yaml file for YOLO training
train_images_path = os.path.join(dataset_dir, "train", "images")
val_images_path = os.path.join(dataset_dir, "val", "images")
test_images_path = os.path.join(dataset_dir, "test", "images")

# Create the data.yaml file in the dataset directory
data_yaml_path = os.path.join(dataset_dir, "data.yaml")

# Ultralytics joins the train/val/test entries onto `path`, and resolves a
# relative `path` against its own datasets directory instead of the current
# working directory. Writing "dataset/train/images" under "path: dataset"
# therefore points at a non-existent "<root>/dataset/train/images". Use an
# absolute dataset root and split paths relative to that root instead.
dataset_root = os.path.abspath(dataset_dir)
split_dirs = {
    "train": train_images_path,
    "val": val_images_path,
    "test": test_images_path,
}

yaml_lines = [f"path: {dataset_root}"]
for split_name, split_path in split_dirs.items():
    # Forward slashes keep the YAML portable across platforms
    relative_split = os.path.relpath(split_path, dataset_dir).replace(os.sep, "/")
    yaml_lines.append(f"{split_name}: {relative_split}")

yaml_lines.append(f"nc: {len(object_names)}")  # Number of classes from object_names

# Write class names dynamically from object_names (index -> name mapping)
yaml_lines.append("names:")
yaml_lines.extend(f"  {i}: {name}" for i, name in enumerate(object_names))

with open(data_yaml_path, "w", encoding="utf-8") as f:
    f.write("\n".join(yaml_lines) + "\n")

print(f"✅ data.yaml file created at {data_yaml_path}")
print("📄 Content preview:")
with open(data_yaml_path, "r", encoding="utf-8") as f:
    print(f.read())

# Fine-tune a pre-trained YOLOv8 nano model on the generated dataset.
# The whole step is best-effort: any failure (missing package, download or
# training error) is reported and the notebook keeps going.
try:
    import shutil

    from ultralytics import YOLO

    # Load a pre-trained YOLO model
    yolo_model = YOLO("yolov8n.pt")  # Using YOLOv8 nano version

    # Print model information
    print("\n🔍 Using pre-trained model: YOLOv8n")
    print("  • Architecture: YOLOv8 (Detection)")
    print("  • Size: Nano (compact)")
    print("  • Pre-trained on: COCO dataset (80 classes)")

    # Fine-tune the YOLO model on the custom dataset
    print("\n🏋️ Starting YOLO model fine-tuning...")
    results = yolo_model.train(
        data=data_yaml_path,
        epochs=5,  # Reduced for Colab/Kaggle
        imgsz=640,
        batch=8,  # Reduced for Colab/Kaggle
        patience=3,  # Early stopping patience
        save=True,  # Save best model
        device=0 if torch.cuda.is_available() else "cpu",
        project="yolo_training",
        name="run1"
    )

    # Save the fine-tuned YOLO model.
    # "pytorch" is not a supported Ultralytics export format and the YOLO
    # wrapper has no `best` attribute, so the original export/print calls
    # raised after a successful training run. The trainer records where its
    # checkpoints were written: copy the best one (or the last one when no
    # best checkpoint exists) to a stable file name.
    fine_tuned_model_path = "yolo_fine_tuned.pt"
    trainer = yolo_model.trainer
    best_weights_path = trainer.best if trainer.best.exists() else trainer.last
    shutil.copy(best_weights_path, fine_tuned_model_path)

    print("✅ Fine-tuned YOLO model saved")
    print(f"  • Best model path: {best_weights_path}")
    print(f"  • Exported model: {fine_tuned_model_path}")

except Exception as e:
    print(f"❌ Error during YOLO training: {str(e)}")
    print("You can continue with the rest of the notebook.")

# %% [markdown]
# ## 14. Evaluate YOLO Model

# %%
# Function to evaluate the YOLO model on the test set
def evaluate_yolo_model(yolo_model, test_loader):
    """
    Evaluate YOLO model on test data and calculate standard metrics
    """
    yolo_model.eval()  # Set YOLO model to evaluation mode
    
    # Initialize metrics
    metrics = {
        'precision': [], 'recall': [], 'f1': [], 'iou': [], 
        'inference_times': []
    }

    with torch.no_grad():
        for images, true_boxes, true_classes in test_loader:
            # Process each image individually to avoid batching issues
            for i in range(len(images)):
                # Get single image and convert to numpy with correct format
                image = images[i].cpu().numpy().transpose(1, 2, 0)  # CHW to HWC
                
                # Denormalize the image from [-1,1] to [0,255]
                image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                image = np.clip(image, 0, 1) * 255
                image = image.astype(np.uint8)
                
                # Get ground truth
                true_box = true_boxes[i].cpu().numpy()
                true_class = true_classes[i].item()
                
                # Measure inference time for single image
                start_time = time.time()
                # Run inference with YOLO - process one image at a time
                try:
                    result = yolo_model(image, conf=0.25, iou=0.45)[0]  # Get first (only) result
                    end_time = time.time()
                    inference_time = end_time - start_time
                    metrics['inference_times'].append(inference_time)
                    
                    # Get YOLO predictions (boxes in xywh normalized format)
                    if len(result.boxes) > 0:
                        # Convert to xywh normalized format
                        pred_box = result.boxes.xywhn[0].cpu().numpy()
                        pred_class = int(result.boxes.cls[0].item())
                        
                        # Calculate IoU between predicted and ground truth box
                        iou = calculate_iou(pred_box, true_box)
                        metrics['iou'].append(iou)
                        
                        # Determine if detection is correct (IoU > 0.5 and correct class)
                        class_correct = (pred_class == true_class)
                        detection_correct = (iou > 0.5) and class_correct
                        
                        # Calculate precision and recall
                        precision = 1.0 if detection_correct else 0.0
                        recall = 1.0 if detection_correct else 0.0
                        
                        # Calculate F1 score
                        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
                    else:
                        # No detection
                        iou = 0.0
                        precision = 0.0
                        recall = 0.0
                        f1 = 0.0
                        metrics['iou'].append(iou)
                    
                    # Store metrics
                    metrics['precision'].append(precision)
                    metrics['recall'].append(recall)
                    metrics['f1'].append(f1)
                    
                except Exception as e:
                    print(f"Error processing image {i}: {str(e)}")
                    continue
    
    # Calculate mean metrics if we have data
    if len(metrics['iou']) > 0:
        mean_metrics = {
            'precision': np.mean(metrics['precision']),
            'recall': np.mean(metrics['recall']),
            'f1': np.mean(metrics['f1']),
            'iou': np.mean(metrics['iou']),
            'inference_time': np.mean(metrics['inference_times'])
        }
        
        # Print detailed evaluation results
        print("\n📊 YOLO Model Evaluation Results:")
        print(f"  • Mean Precision: {mean_metrics['precision']:.4f}")
        print(f"  • Mean Recall: {mean_metrics['recall']:.4f}")
        print(f"  • Mean F1-Score: {mean_metrics['f1']:.4f}")
        print(f"  • Mean IoU: {mean_metrics['iou']:.4f}")
        print(f"  • Average Inference Time: {mean_metrics['inference_time']*1000:.2f} ms per image")
    else:
        print("❌ No detections made by the YOLO model")
        mean_metrics = None 