# Emergency Vehicle Detection Testing

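This notebook tests the trained YOLOv8 model on individual test images, on aggregate test-set performance (inference time, confidence, class counts), on a sample video, and in real time.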

In [None]:
import sys
sys.path.append('..')

from utils.detection import VehicleDetector
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time
from tqdm.notebook import tqdm

%matplotlib inline

## 1. Load Model

In [None]:
# Build the detector from the custom-trained YOLOv8 weights and report
# which device it ended up on.
custom_weights = '../models/yolov8_custom.pt'
detector = VehicleDetector(model_path=custom_weights)
print(f"Model loaded on device: {detector.device}")

## 2. Test on Single Images

In [None]:
def visualize_detection(image_path):
    """Detect vehicles in one image and display the annotated result.

    Args:
        image_path: Path or str pointing at the image file.

    Returns:
        The detections returned by ``detector.detect_frame`` for the image.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as an opaque cvtColor error.
    img_bgr = cv2.imread(str(image_path))
    if img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Feed the detector the frame exactly as OpenCV loaded it (BGR), consistent
    # with the performance-analysis cell, which passes imread output directly.
    # NOTE(review): assumes detect_frame expects BGR like Ultralytics' numpy
    # input path - confirm in utils.detection.
    detections = detector.detect_frame(img_bgr)

    # Convert to RGB only for display, because matplotlib expects RGB.
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    for det in detections:
        # Plain Python ints: some OpenCV builds reject numpy integer points.
        x1, y1, x2, y2 = (int(v) for v in det.bbox[:4])
        label = f"{det.class_name} {det.confidence:.2f}"

        # Draw box
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Draw label above the box, but keep it on-canvas when the box
        # touches the top edge of the image.
        cv2.putText(
            img, label,
            (x1, max(y1 - 10, 15)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5, (0, 255, 0), 2,
        )

    # Display
    plt.figure(figsize=(12, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'Detections in {Path(image_path).name}')
    plt.show()

    return detections

# Test on a few sample images from the test set. glob() yields files in
# arbitrary filesystem order, so sort first to make "the first five" the same
# on every run and machine.
test_images = sorted(Path('../Dataset/prepared/test/images').glob('*.jpg'))[:5]
if not test_images:
    print("No .jpg images found in ../Dataset/prepared/test/images")
for img_path in test_images:
    detections = visualize_detection(img_path)
    print(f"\nDetections in {img_path.name}:")
    for det in detections:
        print(f"- {det.class_name}: {det.confidence:.2f}")

## 3. Performance Analysis

In [None]:
def _collect_detection_metrics(image_paths):
    """Run the detector on each image; return (times, confidences, class counts, skipped paths)."""
    inference_times = []
    confidence_scores = []
    class_distributions = {}
    skipped = []

    for img_path in tqdm(image_paths, desc='Analyzing performance'):
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            skipped.append(img_path)
            continue

        # perf_counter is monotonic and high-resolution; time.time() is
        # neither and can jump if the wall clock is adjusted.
        start_time = time.perf_counter()
        detections = detector.detect_frame(img)
        inference_times.append(time.perf_counter() - start_time)

        for det in detections:
            confidence_scores.append(det.confidence)
            class_distributions[det.class_name] = class_distributions.get(det.class_name, 0) + 1

    return inference_times, confidence_scores, class_distributions, skipped


def _plot_detection_metrics(inference_times, confidence_scores, class_distributions):
    """Draw inference-time and confidence histograms plus a class-count bar chart."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle('Detection Performance Analysis')

    # Inference time distribution
    axes[0].hist(inference_times, bins=20)
    axes[0].set_title('Inference Time Distribution')
    axes[0].set_xlabel('Time (seconds)')
    axes[0].set_ylabel('Count')

    # Confidence score distribution
    axes[1].hist(confidence_scores, bins=20)
    axes[1].set_title('Confidence Score Distribution')
    axes[1].set_xlabel('Confidence')
    axes[1].set_ylabel('Count')

    # Class distribution
    axes[2].bar(list(class_distributions.keys()), list(class_distributions.values()))
    axes[2].set_title('Detected Class Distribution')
    axes[2].set_ylabel('Count')
    axes[2].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()


def analyze_performance(num_images=100):
    """Time the detector over test-set images and summarise what it finds.

    Reads up to ``num_images`` ``.jpg`` files (in sorted order, so the subset
    is reproducible) from the prepared test split, runs
    ``detector.detect_frame`` on each, plots the distributions of per-image
    inference time, detection confidence and detected classes, and prints
    summary statistics. Unreadable images are skipped and counted.

    The timing includes whatever one-off warm-up cost the first call may
    incur; nothing here excludes it.

    Args:
        num_images: Maximum number of test images to analyze.
    """
    test_images = sorted(Path('../Dataset/prepared/test/images').glob('*.jpg'))[:num_images]

    inference_times, confidence_scores, class_distributions, skipped = (
        _collect_detection_metrics(test_images)
    )

    if skipped:
        print(f"Skipped {len(skipped)} unreadable image(s).")
    if not inference_times:
        # Nothing was timed: avoid nan means and empty plots.
        print("No readable test images found; nothing to analyze.")
        return

    _plot_detection_metrics(inference_times, confidence_scores, class_distributions)

    # Print summary statistics
    print("\nPerformance Summary:")
    print(f"Images analyzed: {len(inference_times)}")
    print(f"Average inference time: {np.mean(inference_times):.3f}s")
    if confidence_scores:
        print(f"Average confidence score: {np.mean(confidence_scores):.3f}")
    else:
        # np.mean([]) would print nan and emit a RuntimeWarning.
        print("Average confidence score: n/a (no detections)")
    print("\nClass Distribution:")
    for cls, count in class_distributions.items():
        print(f"- {cls}: {count}")

# Benchmark the detector on up to 100 test-set images.
analyze_performance(num_images=100)

## 4. Video Processing Test

In [None]:
def process_test_video(video_path, output_path=None):
    """Run the detector over a video file.

    If ``output_path`` is truthy, the work is delegated to
    ``detector.process_video_with_visualization``, which is given both paths.
    Otherwise the per-frame results from ``detector.process_video`` are
    consumed and the number of detections in each frame is plotted.
    """
    if output_path:
        detector.process_video_with_visualization(video_path, output_path)
        return

    # One entry per item yielded by process_video: that frame's detection count.
    counts_per_frame = [
        len(frame_dets) for frame_dets in detector.process_video(video_path)
    ]

    # Plot detection counts over time
    plt.figure(figsize=(12, 4))
    plt.plot(counts_per_frame)
    plt.title('Number of Detections per Frame')
    plt.xlabel('Frame Number')
    plt.ylabel('Number of Detections')
    plt.grid(True)
    plt.show()

# Test on a sample video if available. The annotated copy is written next to
# the input with a "_detected" suffix, derived from video_path so the two
# paths cannot drift apart.
video_path = Path('../Dataset/test_videos/traffic.mp4')
if video_path.exists():
    annotated_path = video_path.with_name(f"{video_path.stem}_detected{video_path.suffix}")
    process_test_video(
        str(video_path),
        output_path=str(annotated_path),
    )
else:
    # Say so explicitly instead of silently producing no output.
    print(f"Sample video not found at {video_path}; skipping video test.")

## 5. Real-time Detection Test

In [None]:
# Test real-time detection (press 'q' to quit)
# NOTE(review): start_realtime_detection belongs to VehicleDetector
# (utils.detection) and is not visible here. Presumably it blocks this cell
# until the loop exits and opens its own capture source/display window -
# confirm before running on a headless kernel.
detector.start_realtime_detection()