# YOLO Image Explorer

Explore your image database and experiment with YOLO object detection.

This notebook connects to your MariaDB database and lets you:
- Query images by rating, location, camera, etc.
- Run YOLO detection with various parameters
- Analyze correlations between detections and your ratings
- Experiment with different YOLO models and confidence thresholds

## Setup

In [None]:
import os
import random
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sqlalchemy import create_engine, text, func
from sqlalchemy.orm import sessionmaker
from ultralytics import YOLO

# Find project root by looking for .git directory
def find_project_root():
    """Locate the repository root by searching upward for a ``.git`` entry.

    The current working directory is checked first, then each of its parents
    in turn. The first directory that contains ``.git`` is returned. If none
    does, a warning is printed and the current working directory is returned.
    """
    start_dir = Path.cwd()

    # Nearest match wins: the starting directory is tested before any parent.
    git_root = next(
        (
            candidate
            for candidate in (start_dir, *start_dir.parents)
            if (candidate / '.git').exists()
        ),
        None,
    )
    if git_root is not None:
        return git_root

    print("⚠ Warning: Could not find .git directory")
    return start_dir

# Work out the repository layout and make src/python importable.
project_root = find_project_root()
python_src = project_root / 'src' / 'python'

if not python_src.exists():
    print(f"⚠ Warning: Could not find src/python at {python_src}")
    print(f"  Using project root: {project_root}")
else:
    # Prepend so this directory is searched before the other sys.path entries.
    sys.path.insert(0, str(python_src))
    print(f"✓ Project root: {project_root}")
    print(f"✓ Python source: {python_src}")

# Imported here, after the sys.path update above, rather than with the other imports.
from home_media_ai.media import Media, MediaType, Base

print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Connect to database
# The connection URI comes from the environment so it is not stored in the notebook.
db_uri = os.getenv('HOME_MEDIA_AI_URI')
if not db_uri:
    raise ValueError("Set HOME_MEDIA_AI_URI environment variable")

engine = create_engine(db_uri)

# create_engine() does not open a connection by itself. Run a trivial query
# now so a bad URI or an unreachable server fails in this cell instead of a
# later one, and so the success message below is actually true.
with engine.connect() as connection:
    connection.execute(text("SELECT 1"))

Session = sessionmaker(bind=engine)
session = Session()

print("✓ Connected to database")

In [None]:
# Initialize YOLO model
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = YOLO('yolov8n.pt')  # Start with nano model
model.to(device)

print(f"✓ YOLO model loaded on {device}")
print(f"Model can detect these classes:")
# Only the first 20 class names are listed to keep the output short.
class_preview = list(model.names.values())[:20]
print(', '.join(class_preview) + '...')

## Database Overview

In [None]:
# Get database statistics
total_media = session.query(Media).count()
rated_media = session.query(Media).filter(Media.rating.isnot(None)).count()
with_gps = session.query(Media).filter(Media.gps_latitude.isnot(None)).count()


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is zero."""
    return part / whole * 100 if whole else 0.0


print(f"Total media files: {total_media:,}")
# An empty table would otherwise raise ZeroDivisionError here.
print(f"Rated files: {rated_media:,} ({_percent(rated_media, total_media):.1f}%)")
print(f"With GPS: {with_gps:,} ({_percent(with_gps, total_media):.1f}%)")

# Rating distribution: one (rating, count) row per distinct rating value,
# including a row with rating None for unrated files.
rating_dist = session.query(
    Media.rating,
    func.count(Media.id)
).group_by(Media.rating).all()

print("\nRating distribution:")
# The (is None, value) key sorts numeric ratings ascending and places the
# unrated (None) group last; the leading boolean differs between the None row
# and every rated row, so None is never compared with a number.
for rating, count in sorted(rating_dist, key=lambda x: (x[0] is None, x[0])):
    if rating is not None:
        stars = '★' * rating
        print(f"  {stars:6} {count:6,} files")
    else:
        print(f"  {'(none)':6} {count:6,} files")

In [None]:
# Camera distribution - corrected query
# Count images per (make, model) pair, skipping rows where either field is
# missing, and keep the ten most common combinations.
image_count = func.count(Media.id)
camera_counts = (
    session.query(
        Media.camera_make,
        Media.camera_model,
        image_count.label('count'),
    )
    .filter(Media.camera_make.isnot(None), Media.camera_model.isnot(None))
    .group_by(Media.camera_make, Media.camera_model)
    .order_by(image_count.desc())
    .limit(10)
    .all()
)

print("\nTop cameras in your collection:")
for make, cam_model, count in camera_counts:
    print(f"  {make} {cam_model}: {count:,} images")

## Helper Functions

In [None]:
def display_image_with_metadata(media_obj, figsize=(12, 8)):
    """Display an image with a title summarising its metadata.

    Args:
        media_obj: Object providing ``file_path`` and the optional ``rating``,
            ``camera_make``, ``camera_model``, ``width`` and ``height``
            attributes; missing (falsy) values are left out of the title.
        figsize: Matplotlib figure size in inches.

    Returns:
        ``(fig, img)`` — the Matplotlib figure and the opened PIL image — or
        ``(None, None)`` if the image could not be loaded or drawn.
    """
    fig, ax = plt.subplots(figsize=figsize)

    try:
        img = Image.open(media_obj.file_path)
        ax.imshow(img)
        ax.axis('off')

        # Build title with metadata
        title_parts = [Path(media_obj.file_path).name]

        if media_obj.rating:
            title_parts.append(f"{'★' * media_obj.rating}")

        if media_obj.camera_make and media_obj.camera_model:
            title_parts.append(f"{media_obj.camera_make} {media_obj.camera_model}")

        if media_obj.width and media_obj.height:
            mp = (media_obj.width * media_obj.height) / 1_000_000
            title_parts.append(f"{media_obj.width}×{media_obj.height} ({mp:.1f}MP)")

        ax.set_title(' | '.join(title_parts), fontsize=10, pad=10)
        fig.tight_layout()
        return fig, img

    except Exception as e:
        # Best-effort helper: report the problem instead of raising, but close
        # the figure created above so a failed load does not leave an empty
        # figure open.
        plt.close(fig)
        print(f"Error loading image: {e}")
        return None, None


def run_yolo_on_media(media_obj, yolo_model, device, conf=0.25, show_boxes=True, filter_classes=None):
    """Run YOLO detection on a media object and report what was found.

    Prints one line per kept detection (highest confidence first) and, when
    ``show_boxes`` is true and at least one detection was kept, shows the
    annotated image.

    Args:
        media_obj: SQLAlchemy Media object (only ``file_path`` is read)
        yolo_model: YOLO model instance
        device: Device to run on ('cuda' or 'cpu')
        conf: Confidence threshold (0.0-1.0)
        show_boxes: Display image with bounding boxes
        filter_classes: Collection of class names to keep
            (e.g., ['person', 'potted plant']); None keeps every class

    Returns:
        List of detection dictionaries with keys 'class', 'confidence' and
        'box' (xyxy coordinates as a NumPy array), in model output order.
    """
    results = yolo_model(media_obj.file_path, conf=conf, device=device, verbose=False)
    result = results[0]

    # Keep the detections that pass the class filter, remembering their
    # positions so the plot can be restricted to the same boxes.
    detections = []
    kept_indices = []
    for idx, box in enumerate(result.boxes):
        cls_id = int(box.cls[0])
        cls_name = result.names[cls_id]
        confidence = float(box.conf[0])

        if filter_classes is None or cls_name in filter_classes:
            kept_indices.append(idx)
            detections.append({
                'class': cls_name,
                'confidence': confidence,
                'box': box.xyxy[0].cpu().numpy()
            })

    print(f"\nDetected {len(detections)} objects:")
    for det in sorted(detections, key=lambda x: x['confidence'], reverse=True):
        print(f"  - {det['class']:<15} {det['confidence']:.2%}")

    if show_boxes and detections:
        # With a class filter active, draw only the boxes that were kept so
        # the picture agrees with the count in the title.
        shown = result if filter_classes is None else result[kept_indices]
        # Results.plot() returns a BGR array (OpenCV channel order), while
        # matplotlib expects RGB; reverse the channel axis before showing it.
        img_with_boxes = shown.plot()[..., ::-1]

        plt.figure(figsize=(12, 8))
        plt.imshow(img_with_boxes)
        plt.axis('off')
        plt.title(f"{Path(media_obj.file_path).name} - {len(detections)} objects")
        plt.tight_layout()
        plt.show()

    return detections


def compare_models(media_obj, device, models_to_test=('yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt'),
                   conf=0.25, warmup=True):
    """Compare detection results and inference time across YOLO models.

    Args:
        media_obj: SQLAlchemy Media object (only ``file_path`` is read)
        device: Device to run on ('cuda' or 'cpu')
        models_to_test: Iterable of YOLO model names to compare
        conf: Confidence threshold used for every model
        warmup: If True, run one untimed inference per model before the timed
            one, so one-off first-call setup does not skew the comparison

    Returns:
        List of result dictionaries, one per model, with keys 'model',
        'time_ms', 'num_detections' and 'detections' (a list of
        ``(class_name, confidence)`` tuples).
    """
    results_comparison = []

    for model_name in models_to_test:
        print(f"\nTesting {model_name}...")
        test_model = YOLO(model_name)
        test_model.to(device)

        if warmup:
            test_model(media_obj.file_path, conf=conf, device=device, verbose=False)

        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring short durations; time.time() can jump with clock changes.
        start = time.perf_counter()
        results = test_model(media_obj.file_path, conf=conf, device=device, verbose=False)
        elapsed_ms = (time.perf_counter() - start) * 1000

        result = results[0]
        detections = [(result.names[int(box.cls[0])], float(box.conf[0]))
                      for box in result.boxes]

        results_comparison.append({
            'model': model_name,
            'time_ms': elapsed_ms,
            'num_detections': len(detections),
            'detections': detections
        })

        print(f"  Time: {elapsed_ms:.1f}ms, Detections: {len(detections)}")

    return results_comparison

print("✓ Helper functions loaded")

## Query and Explore Images

In [None]:
# Get some highly-rated images
# Fetch enough rows for the batch analysis further down, which looks at up to
# 20 of these images.
TOP_RATED_LIMIT = 20

top_rated = session.query(Media).filter(
    Media.rating >= 4,
    Media.file_path.like('%.jpg')
).order_by(
    Media.rating.desc(),
    Media.id  # tie-breaker so equally rated images come back in a stable order
).limit(TOP_RATED_LIMIT).all()

# The count reflects the rows fetched, which the query caps at TOP_RATED_LIMIT.
print(f"Found {len(top_rated)} images with rating ≥ 4 stars (query limited to {TOP_RATED_LIMIT})")
print("\nFirst few:")
for i, img in enumerate(top_rated[:5], 1):
    print(f"{i}. {Path(img.file_path).name} - {'★' * img.rating}")

In [None]:
# Pick one image to analyze in detail
# Change index to explore different images (0-based position in top_rated;
# negative values count from the end)
test_image_idx = 1
if not -len(top_rated) <= test_image_idx < len(top_rated):
    raise IndexError(
        f"test_image_idx={test_image_idx} is out of range: "
        f"top_rated contains {len(top_rated)} image(s)"
    )
test_img = top_rated[test_image_idx]

print(f"Selected image: {Path(test_img.file_path).name}")
print(f"Rating: {'★' * test_img.rating if test_img.rating else 'Unrated'}")
print(f"Camera: {test_img.camera_make} {test_img.camera_model}")
print(f"Size: {test_img.width}×{test_img.height}")
# Compare against None explicitly: 0.0 is a valid latitude/longitude and
# would be skipped by a plain truthiness test.
if test_img.gps_latitude is not None and test_img.gps_longitude is not None:
    print(f"GPS: {test_img.gps_latitude:.6f}, {test_img.gps_longitude:.6f}")

# Display the image
display_image_with_metadata(test_img)
plt.show()

## Run YOLO Detection

Now let's see what YOLO detects in this image.

In [None]:
# Basic detection with default confidence threshold
# run_yolo_on_media prints the detected classes and, because show_boxes=True,
# shows the annotated image when at least one object is found. The returned
# list of detection dictionaries is kept in `detections`.
detections = run_yolo_on_media(test_img, model, device, conf=0.25, show_boxes=True)

In [None]:
# Try different confidence thresholds
print("Comparing different confidence thresholds:\n")

# Inference is re-run once per threshold; only the number of returned boxes
# is reported.
for conf_threshold in (0.1, 0.25, 0.5, 0.75):
    results = model(
        test_img.file_path,
        conf=conf_threshold,
        device=device,
        verbose=False,
    )
    first_result = results[0]
    num_detections = len(first_result.boxes)
    print(f"Confidence ≥ {conf_threshold:.2f}: {num_detections} detections")

In [None]:
# Filter for nature-related objects only
# Class names that run_yolo_on_media should keep; every other detection is dropped.
nature_classes = [
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
    'bear', 'zebra', 'giraffe', 'potted plant', 'vase', 'bottle',
]

print("Filtering for nature-related objects:\n")
nature_detections = run_yolo_on_media(
    media_obj=test_img,
    yolo_model=model,
    device=device,
    conf=0.25,
    show_boxes=True,
    filter_classes=nature_classes,
)

## Batch Analysis

Let's analyze multiple images and look for patterns.

In [None]:
# Analyze up to 20 of the top-rated images fetched earlier
ANALYSIS_COLUMNS = ['filename', 'rating', 'num_objects', 'objects', 'has_person', 'has_plant']
analysis_results = []

for img in top_rated[:20]:
    try:
        results = model(img.file_path, conf=0.25, device=device, verbose=False)
        result = results[0]

        # Named separately from `detections` so this loop does not overwrite
        # the list of detection dictionaries produced for the single test image.
        detected_classes = [result.names[int(box.cls[0])] for box in result.boxes]

        analysis_results.append({
            'filename': Path(img.file_path).name,
            'rating': img.rating,
            'num_objects': len(detected_classes),
            'objects': detected_classes,
            'has_person': 'person' in detected_classes,
            'has_plant': 'potted plant' in detected_classes,
        })
    except Exception as e:
        # Best-effort batch: report the failing file and carry on with the rest.
        print(f"Skipping {img.file_path}: {e}")

# Naming the columns explicitly keeps them present even when every image was
# skipped, so the cells below can still look them up.
df = pd.DataFrame(analysis_results, columns=ANALYSIS_COLUMNS)
print(f"\nAnalyzed {len(df)} images")
df.head(10)

In [None]:
# Correlation between rating and number of detected objects
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
rating_ax, objects_ax = axes

# Left panel: mean number of detections for each star rating
rating_groups = df.groupby('rating')['num_objects'].mean()
rating_ax.bar(rating_groups.index, rating_groups.values)
rating_ax.set_xlabel('Rating (stars)')
rating_ax.set_ylabel('Average number of detected objects')
rating_ax.set_title('Object Detection vs Rating')
rating_ax.grid(axis='y', alpha=0.3)

# Right panel: the ten most frequently detected classes across all images
all_objects = []
for objs in df['objects']:
    all_objects.extend(objs)
object_counts = pd.Series(all_objects).value_counts().head(10)
bar_positions = range(len(object_counts))
objects_ax.barh(bar_positions, object_counts.values)
objects_ax.set_yticks(bar_positions)
objects_ax.set_yticklabels(object_counts.index)
objects_ax.set_xlabel('Count')
objects_ax.set_title('Most Detected Objects (Top 10)')
objects_ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Summary statistics
print("Summary Statistics:\n")
print(f"Average objects per image: {df['num_objects'].mean():.1f}")

# The sum of a boolean column is the number of True rows; its mean is the fraction.
for label, column in (("people", "has_person"), ("plants", "has_plant")):
    flags = df[column]
    print(f"Images with {label}: {flags.sum()} ({flags.mean()*100:.1f}%)")

print("\nBy rating:")
for rating in sorted(df['rating'].unique()):
    rating_mask = df['rating'] == rating
    subset = df.loc[rating_mask]
    avg_objects = subset['num_objects'].mean()
    print(f"  {'★' * rating}: {len(subset)} images, {avg_objects:.1f} avg objects")

## Model Comparison

Compare detection counts and inference speed across different YOLO model sizes.

In [None]:
# Compare models on one image
comparison = compare_models(test_img, device, models_to_test=['yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt'])

# Visualize comparison
# Pull the per-model values out of the result dictionaries into parallel lists.
models, times, num_dets = [], [], []
for entry in comparison:
    models.append(entry['model'])
    times.append(entry['time_ms'])
    num_dets.append(entry['num_detections'])

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# One bar chart per metric, sharing the same styling.
panels = [
    (times, 'Time (ms)', 'Inference Speed Comparison'),
    (num_dets, 'Number of Detections', 'Detection Count Comparison'),
]
for ax, (values, ylabel, title) in zip(axes, panels):
    ax.bar(models, values)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## Custom Queries

Build your own queries to explore specific subsets.

In [None]:
# Example: Images from a specific camera with high ratings
# All conditions must hold for a row to be returned.
camera_filters = [
    Media.camera_make == 'Canon',  # Change to your camera
    Media.rating >= 3,
    Media.file_path.like('%.cr2'),
]
camera_images = session.query(Media).filter(*camera_filters).limit(5).all()

print(f"Found {len(camera_images)} images")
for img in camera_images:
    star_label = '★' * img.rating
    print(f"  {Path(img.file_path).name} - {star_label}")

In [None]:
# Example: Images from a specific location (GPS bounding box)
# Madison, WI area
lat_min, lat_max = 43.0, 43.2
lon_min, lon_max = -89.5, -89.2

# Each condition is named so the query below reads as a sentence.
in_lat_range = Media.gps_latitude.between(lat_min, lat_max)
in_lon_range = Media.gps_longitude.between(lon_min, lon_max)
is_jpg = Media.file_path.like('%.jpg')

location_images = (
    session.query(Media)
    .filter(in_lat_range, in_lon_range, is_jpg)
    .limit(5)
    .all()
)

print(f"Found {len(location_images)} images in specified area")
for img in location_images:
    file_name = Path(img.file_path).name
    print(f"  {file_name} - ({img.gps_latitude:.4f}, {img.gps_longitude:.4f})")

In [None]:
# Example: Random sample for exploration
# Fetch only the primary keys, sample them in Python, then load just the
# sampled rows; this avoids building an ORM object for every matching image.
jpg_ids = [
    media_id
    for (media_id,) in session.query(Media.id).filter(Media.file_path.like('%.jpg'))
]
sampled_ids = random.sample(jpg_ids, min(10, len(jpg_ids)))

if sampled_ids:
    sampled_rows = session.query(Media).filter(Media.id.in_(sampled_ids)).all()
    # Put the rows back into the sampled order; the database returns them in its own order.
    rows_by_id = {row.id: row for row in sampled_rows}
    random_sample = [rows_by_id[i] for i in sampled_ids if i in rows_by_id]
else:
    random_sample = []

print(f"Random sample of {len(random_sample)} images:")
for img in random_sample[:5]:
    rating_str = f"{'★' * img.rating}" if img.rating else "unrated"
    print(f"  {Path(img.file_path).name} - {rating_str}")

## Experiments to Try

**Ideas for exploration:**

1. **Rating predictor**: Do images with more detected objects tend to have higher ratings?
2. **Subject focus**: Are your highly-rated images more likely to have a single dominant subject?
3. **Composition analysis**: Use bounding boxes to check rule-of-thirds placement
4. **Filtering workflow**: Could YOLO help pre-filter images before manual review?
5. **Context detection**: Do certain objects (people, equipment) correlate with habitat vs specimen shots?

**Try modifying the code above to:**
- Analyze images by date range
- Compare different cameras or lenses
- Find images with specific object combinations
- Build a simple auto-tagger based on detections

In [None]:
# Cleanup
session.close()
# Closing the session only hands its connection back to the engine's pool;
# disposing the engine closes the pooled database connections as well.
engine.dispose()
print("Session closed")