# Image Explorer - Interactive Media Processing

This notebook provides an interactive environment for selecting and processing images from your media database.

## Features

- Query images by various criteria (rating, GPS location, camera, date range)
- Display images with metadata overlays
- Basic image processing (histogram analysis, color analysis, edge detection)
- Side-by-side RAW vs JPEG comparison
- Batch processing capabilities

In [None]:
# Standard imports
import os
import sys
from pathlib import Path
from typing import List, Dict, Optional
import random

# Scientific computing
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
from PIL import Image
import cv2

# Database
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

# Add project to path
sys.path.insert(0, str(Path.cwd().parent))
from home_media_ai.media import Media, MediaType
from home_media_ai.exif_extractor import ExifExtractor

# Configure plotting: apply the matplotlib 'seaborn-v0_8-darkgrid' style sheet
# and make seaborn's 'husl' palette the default for the figures drawn below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Confirmation that the import/setup cell ran to completion.
print("✓ Imports complete")


In [None]:
# Database connection.
# The connection URI comes from the environment so that credentials are never
# stored in the notebook; fail fast with a clear message when it is missing.
DATABASE_URI = os.environ.get('HOME_MEDIA_AI_URI')
if DATABASE_URI is None or DATABASE_URI == '':
    raise ValueError("Set HOME_MEDIA_AI_URI environment variable")

# One engine for the notebook, a session factory bound to it, and a single
# shared session that the query helpers below use.
engine = create_engine(DATABASE_URI)
Session = sessionmaker(bind=engine)
session = Session()

# Only the database name is printed, not the full URI.
print(f"✓ Connected to database: {engine.url.database}")


## Helper Functions

In [None]:
def query_images(rating: Optional[int] = None,
                 camera_make: Optional[str] = None,
                 has_gps: bool = False,
                 year: Optional[int] = None,
                 limit: int = 10,
                 random_sample: bool = True) -> List[Media]:
    """Query original media rows from the database with optional filters.

    All filters are combined with AND; a filter left at its default is not
    applied.

    Args:
        rating: Exact rating to match. ``None`` disables the filter.
        camera_make: Case-insensitive substring to match against
            ``Media.camera_make``. Empty/``None`` disables the filter.
        has_gps: When True, keep only rows that have both a latitude and a
            longitude.
        year: Calendar year of the ``created`` column. Falsy values disable
            the filter.
        limit: Maximum number of rows returned.
        random_sample: When True rows are returned in random order, otherwise
            newest ``created`` first.

    Returns:
        A list of ``Media`` objects (possibly empty).

    Raises:
        ValueError / TypeError: if ``year`` cannot be converted to ``int``.
    """
    # NOTE(review): YEAR() and RAND() below are database-specific SQL
    # functions (MySQL/MariaDB spelling) — confirm the target backend.
    query = session.query(Media).filter(Media.is_original == True)

    if rating is not None:
        query = query.filter(Media.rating == rating)

    if camera_make:
        query = query.filter(Media.camera_make.ilike(f'%{camera_make}%'))

    if has_gps:
        query = query.filter(
            Media.gps_latitude.isnot(None),
            Media.gps_longitude.isnot(None)
        )

    if year:
        # Bind the value instead of formatting it into the SQL text, so a
        # non-integer argument can never be interpreted as SQL.
        year_filter = text('YEAR(created) = :year').bindparams(year=int(year))
        query = query.filter(year_filter)

    if random_sample:
        query = query.order_by(text('RAND()'))
    else:
        query = query.order_by(Media.created.desc())

    return query.limit(limit).all()


def load_image(file_path: str) -> Optional[np.ndarray]:
    """Load an image file as an RGB numpy array.

    Args:
        file_path: Location of the image. A ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``) is accepted.

    Returns:
        The image as returned by ``cv2.imread`` converted from BGR to RGB
        channel order, or ``None`` when the file could not be read.

    Notes:
        ``cv2.imread`` signals an unreadable file by returning ``None``; that
        case is passed through silently so callers can decide how to report
        it. Any exception (including an argument that is not a path) is
        printed and also results in ``None``.
    """
    try:
        # Normalise path-like objects to a plain string before handing them to
        # OpenCV; done inside the try so a bad argument follows the same
        # "print and return None" path as any other failure.
        path_str = os.fspath(file_path)
        img = cv2.imread(path_str)
        if img is None:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Failed to load {file_path}: {e}")
        return None


def display_image_with_metadata(media: Media, figsize=(12, 8)):
    """Show an image with a text box of its metadata in the top-left corner.

    The image is loaded from ``media.get_full_path()``. If it cannot be loaded
    a message is printed and nothing is drawn.

    Args:
        media: Record whose ``file_path``, ``rating``, ``camera_make``,
            ``camera_model``, ``lens_model``, ``width``, ``height``,
            ``gps_latitude``, ``gps_longitude`` and ``created`` attributes
            are read. Missing (falsy) values are simply left out of the box.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    full_path = media.get_full_path()
    img = load_image(full_path)
    if img is None:
        print(f"Could not load: {full_path}")
        return

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(img)
    ax.axis('off')

    metadata_text = [f"File: {Path(media.file_path).name}"]

    # A rating of 0 or None is treated as "unrated" and not shown.
    if media.rating:
        stars = '●' * media.rating + '○' * (5 - media.rating)
        metadata_text.append(f"Rating: {stars}")

    if media.camera_make and media.camera_model:
        metadata_text.append(f"Camera: {media.camera_make} {media.camera_model}")

    if media.lens_model:
        metadata_text.append(f"Lens: {media.lens_model}")

    if media.width and media.height:
        mp = (media.width * media.height) / 1_000_000
        metadata_text.append(f"Size: {media.width}×{media.height} ({mp:.1f} MP)")

    # Compare against None explicitly: 0.0 is a valid latitude (equator) and
    # longitude (prime meridian) and must not be treated as "missing".
    if media.gps_latitude is not None and media.gps_longitude is not None:
        metadata_text.append(f"GPS: ({media.gps_latitude:.6f}, {media.gps_longitude:.6f})")

    if media.created:
        metadata_text.append(f"Date: {media.created.strftime('%Y-%m-%d %H:%M:%S')}")

    text_str = '\n'.join(metadata_text)
    ax.text(0.02, 0.98, text_str,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment='top',
            # Preferred font first, generic monospace as a fallback for
            # machines where the preferred font is not installed.
            fontfamily=['FiraCode Nerd Font Mono', 'monospace'],
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()


# Confirmation that the helper-function cell ran to completion.
print("✓ Helper functions defined")


## Database Overview

In [None]:
# Get database statistics: total rows, originals, rated rows and rows that
# carry both GPS coordinates.
total_files = session.query(Media).count()
original_files = session.query(Media).filter(Media.is_original == True).count()
rated_files = session.query(Media).filter(Media.rating.isnot(None)).count()
gps_files = session.query(Media).filter(
    Media.gps_latitude.isnot(None),
    Media.gps_longitude.isnot(None)
).count()

# Percentages are relative to all files; report 0.0% for an empty table
# instead of failing with ZeroDivisionError.
rated_pct = rated_files / total_files * 100 if total_files else 0.0
gps_pct = gps_files / total_files * 100 if total_files else 0.0

print("Database Statistics:")
print("="*50)
print(f"Total files:          {total_files:,}")
print(f"Original files:       {original_files:,}")
print(f"Derivative files:     {total_files - original_files:,}")
print(f"Files with ratings:   {rated_files:,} ({rated_pct:.1f}%)")
print(f"Files with GPS:       {gps_files:,} ({gps_pct:.1f}%)")
print("="*50)


In [None]:
# Rating distribution: number of rated files per rating value, as a bar chart.
rating_query = """
SELECT rating, COUNT(*) as count
FROM media
WHERE rating IS NOT NULL
GROUP BY rating
ORDER BY rating
"""

# Fetch all rows while the connection is open, then build the frame.
with engine.connect() as conn:
    result = conn.execute(text(rating_query))
    rating_rows = result.fetchall()

ratings_df = pd.DataFrame(rating_rows, columns=['rating', 'count'])

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(ratings_df['rating'], ratings_df['count'])
ax.set(
    xlabel='Rating (stars)',
    ylabel='Number of files',
    title='Image Rating Distribution',
)
# Always show the full 0-5 star scale, even for ratings with no files.
ax.set_xticks(list(range(6)))
fig.tight_layout()
plt.show()


## Image Selection and Display

In [None]:
# Example 1: display up to three randomly sampled 5-star images
images = query_images(limit=3, rating=5)
print(f"Found {len(images)} images with 5-star rating\n")

# One figure per image, each with its metadata overlay.
for media in images:
    display_image_with_metadata(media)


In [None]:
# Example 2: display up to two geotagged images whose camera make contains 'Canon'
images = query_images(limit=2, has_gps=True, camera_make='Canon')
print(f"Found {len(images)} Canon images with GPS\n")

# One figure per image, each with its metadata overlay.
for media in images:
    display_image_with_metadata(media)


## Image Analysis Functions

In [None]:
def analyze_image_histogram(img: np.ndarray, title: str = "Image Analysis"):
    """Draw a 2x2 summary figure for an RGB image.

    Panels: the image itself, per-channel RGB histograms, a grayscale
    histogram, and a text panel with shape, dtype and per-channel /
    grayscale mean and standard deviation.

    Args:
        img: Image array indexed as ``img[:, :, channel]`` with channels in
            R, G, B order.
        title: Figure-level title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    image_ax, rgb_ax = axes[0]
    gray_ax, stats_ax = axes[1]

    # Top-left: the image.
    image_ax.imshow(img)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Top-right: one 256-bin histogram line per colour channel.
    for channel_index, channel_color in enumerate('rgb'):
        channel_hist = cv2.calcHist([img], [channel_index], None, [256], [0, 256])
        rgb_ax.plot(channel_hist, color=channel_color, alpha=0.7,
                    label=channel_color.upper())
    rgb_ax.set_title('RGB Histogram')
    rgb_ax.set_xlabel('Pixel Value')
    rgb_ax.set_ylabel('Frequency')
    rgb_ax.legend()
    rgb_ax.grid(True, alpha=0.3)

    # Bottom-left: histogram of the grayscale conversion.
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray_ax.hist(gray.ravel(), bins=256, range=[0, 256], color='gray', alpha=0.7)
    gray_ax.set_title('Grayscale Histogram')
    gray_ax.set_xlabel('Pixel Value')
    gray_ax.set_ylabel('Frequency')
    gray_ax.grid(True, alpha=0.3)

    # Bottom-right: numeric summary rendered as monospace text.
    stats_text = [
        "Image Statistics:",
        f"Shape: {img.shape}",
        f"Dtype: {img.dtype}",
        "",
        "Channel Means:",
        f"  R: {img[:,:,0].mean():.2f}",
        f"  G: {img[:,:,1].mean():.2f}",
        f"  B: {img[:,:,2].mean():.2f}",
        "",
        "Channel Std Dev:",
        f"  R: {img[:,:,0].std():.2f}",
        f"  G: {img[:,:,1].std():.2f}",
        f"  B: {img[:,:,2].std():.2f}",
        "",
        f"Grayscale Mean: {gray.mean():.2f}",
        f"Grayscale Std:  {gray.std():.2f}",
    ]
    stats_block = '\n'.join(stats_text)
    stats_ax.text(0.1, 0.5, stats_block,
                  transform=stats_ax.transAxes,
                  fontsize=11,
                  verticalalignment='center',
                  fontfamily='monospace')
    stats_ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def analyze_edges(img: np.ndarray, low_threshold: int = 50, high_threshold: int = 150):
    """Show an RGB image next to its grayscale version and its Canny edge map.

    Args:
        img: RGB image array.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        None. The three-panel figure is displayed with ``plt.show()``.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)

    # (pixels, panel title, extra imshow keyword arguments), left to right.
    panels = (
        (img, 'Original Image', {}),
        (gray, 'Grayscale', {'cmap': 'gray'}),
        (edges, f'Edges (Canny: {low_threshold}, {high_threshold})', {'cmap': 'gray'}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for panel_ax, (panel_pixels, panel_title, imshow_kwargs) in zip(axes, panels):
        panel_ax.imshow(panel_pixels, **imshow_kwargs)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    plt.tight_layout()
    plt.show()


def analyze_color_distribution(img: np.ndarray,
                               n_colors: int = 5,
                               max_pixels: Optional[int] = 100_000,
                               seed: Optional[int] = 42):
    """Visualise the colour content of an RGB image.

    Draws a 2x2 figure: the image, its L* (lightness) channel, a hexbin of a
    random pixel sample in the a*/b* plane, and a strip of the dominant
    colours found by k-means with the share of pixels assigned to each.

    Args:
        img: RGB image array of shape (height, width, 3).
        n_colors: Number of dominant colours (k-means clusters) to extract.
        max_pixels: Upper bound on the number of pixels given to k-means.
            Larger images are randomly subsampled, so the reported
            percentages are then estimates from the sample. ``None`` uses
            every pixel (slow for full-resolution photos).
        seed: Seed for the pixel sampling and for k-means; ``None`` makes
            both non-deterministic.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Local import: scikit-learn is only needed by this function.
    from sklearn.cluster import KMeans

    rng = np.random.default_rng(seed)
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    axes[0, 0].imshow(img)
    axes[0, 0].set_title('Original RGB Image')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(lab[:, :, 0], cmap='gray')
    axes[0, 1].set_title('L* (Lightness)')
    axes[0, 1].axis('off')

    a_channel = lab[:, :, 1].ravel()
    b_channel = lab[:, :, 2].ravel()
    if lab.dtype == np.uint8:
        # For 8-bit input OpenCV stores a* and b* shifted by +128; shift them
        # back so the axes show signed values centred on 0 (neutral).
        a_channel = a_channel.astype(np.int16) - 128
        b_channel = b_channel.astype(np.int16) - 128

    # A random subset keeps the hexbin fast on large images.
    sample_size = min(10000, a_channel.size)
    indices = rng.choice(a_channel.size, sample_size, replace=False)

    axes[1, 0].hexbin(a_channel[indices], b_channel[indices],
                      gridsize=50, cmap='YlOrRd', mincnt=1)
    axes[1, 0].set_xlabel('a* (green-red)')
    axes[1, 0].set_ylabel('b* (blue-yellow)')
    axes[1, 0].set_title('Color Distribution (LAB space)')

    # Dominant colours: cluster (a sample of) the RGB pixels.
    pixels = img.reshape(-1, 3)
    if max_pixels is not None and len(pixels) > max_pixels:
        pixels = pixels[rng.choice(len(pixels), max_pixels, replace=False)]

    # k-means cannot use more clusters than there are samples.
    n_clusters = max(1, min(n_colors, len(pixels)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    kmeans.fit(pixels)

    # minlength keeps one count per cluster even if a cluster ends up empty.
    label_counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    percentages = label_counts / label_counts.sum() * 100

    # Largest cluster first.
    order = np.argsort(percentages)[::-1]
    dominant_colors = np.clip(kmeans.cluster_centers_[order] / 255.0, 0.0, 1.0)
    dominant_percentages = percentages[order]

    # A 1 x n_clusters image: one swatch per dominant colour.
    axes[1, 1].imshow(dominant_colors[np.newaxis, :, :])
    axes[1, 1].set_aspect('auto')
    axes[1, 1].axis('off')
    axes[1, 1].set_title('Dominant Colors')

    for i, (color, pct) in enumerate(zip(dominant_colors, dominant_percentages)):
        # In data coordinates swatch i of the strip is centred at (i, 0), so
        # this puts each label in the middle of its own swatch.
        axes[1, 1].text(i, 0, f'{pct:.1f}%',
                        ha='center', va='center',
                        fontsize=10, fontweight='bold',
                        color='white' if color.mean() < 0.5 else 'black')

    plt.tight_layout()
    plt.show()


# Confirmation that the analysis-function cell ran to completion.
print("✓ Image analysis functions defined")


## Example: Analyze a Single Image

In [None]:
# Select one random 4-star image and run every analysis helper on it
test_images = query_images(rating=4, limit=1)

if test_images:
    media = test_images[0]
    image_name = Path(media.file_path).name
    print(f"Analyzing: {image_name}\n")

    # Load from media.get_full_path(), the same location
    # display_image_with_metadata reads from; media.file_path is only used
    # for the display name.
    image_path = media.get_full_path()
    img = load_image(image_path)

    if img is not None:
        analyze_image_histogram(img, title=image_name)
        analyze_edges(img)
        analyze_color_distribution(img)
    else:
        # Report the failure instead of silently producing no output.
        print(f"Could not load: {image_path}")
else:
    print("No images found matching criteria")


## Example: Compare RAW and JPEG

In [None]:
# Pick one random original RAW file that has a JPEG derivative and show, for
# each of the two files, the image next to its RGB histogram.
raw_with_derivatives_query = """
SELECT m1.id, m1.file_path as raw_path, m2.file_path as jpeg_path
FROM media m1
JOIN media m2 ON m2.origin_id = m1.id
WHERE m1.is_original = TRUE
  AND m1.file_ext IN ('.dng', '.cr2', '.nef', '.arw')
  AND m2.file_ext IN ('.jpg', '.jpeg')
ORDER BY RAND()
LIMIT 1
"""

with engine.connect() as conn:
    result = conn.execute(text(raw_with_derivatives_query))
    pair = result.fetchone()

if pair:
    raw_path, jpeg_path = pair[1], pair[2]
    print(f"RAW:  {Path(raw_path).name}")
    print(f"JPEG: {Path(jpeg_path).name}\n")

    # NOTE(review): the paths are used exactly as stored in media.file_path,
    # while display_image_with_metadata loads via Media.get_full_path() —
    # confirm the stored paths resolve from the notebook's working directory.
    # NOTE(review): load_image relies on cv2.imread, which may not be able to
    # decode camera RAW formats — confirm, or use a dedicated RAW decoder.
    jpeg_img = load_image(jpeg_path)
    raw_img = load_image(raw_path)

    # Same two-panel figure for both files; only the label and data differ,
    # so the titles always name the file that is actually shown.
    for label, path, pair_img in (('JPEG', jpeg_path, jpeg_img),
                                  ('RAW', raw_path, raw_img)):
        if pair_img is None:
            print(f"Could not load {label}: {path}")
            continue

        fig, axes = plt.subplots(1, 2, figsize=(16, 8))

        axes[0].imshow(pair_img)
        axes[0].set_title(f'{label}: {Path(path).name}')
        axes[0].axis('off')

        for i, color in enumerate(['r', 'g', 'b']):
            hist = cv2.calcHist([pair_img], [i], None, [256], [0, 256])
            axes[1].plot(hist, color=color, alpha=0.7, label=f'{color.upper()} channel')

        axes[1].set_title(f'{label} Histogram')
        axes[1].set_xlabel('Pixel Value')
        axes[1].set_ylabel('Frequency')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
else:
    print("No RAW/JPEG pairs found")


## Example: Batch Analysis

In [None]:
# Analyze brightness distribution across multiple images
sample_images = query_images(rating=4, limit=20)

brightness_data = []

for media in sample_images:
    # Load from media.get_full_path(), the same location
    # display_image_with_metadata reads from.
    img = load_image(media.get_full_path())
    if img is None:
        continue  # unreadable files are skipped

    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mean_brightness = float(gray.mean())

    brightness_data.append({
        'filename': Path(media.file_path).name,
        'brightness': mean_brightness,
        'rating': media.rating
    })

# Explicit columns keep the frame well-formed even when nothing was loaded.
df = pd.DataFrame(brightness_data, columns=['filename', 'brightness', 'rating'])

if df.empty:
    # Without this guard an empty result has nothing to plot or describe.
    print("No images could be loaded for brightness analysis")
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].hist(df['brightness'], bins=20, edgecolor='black', alpha=0.7)
    axes[0].set_xlabel('Mean Brightness')
    axes[0].set_ylabel('Number of Images')
    axes[0].set_title('Brightness Distribution')
    axes[0].grid(True, alpha=0.3)

    # Colour encodes the rating; x is just the row position in the sample.
    scatter = axes[1].scatter(df.index, df['brightness'], c=df['rating'], cmap='YlOrRd', s=100)
    axes[1].set_xlabel('Image Index')
    axes[1].set_ylabel('Mean Brightness')
    axes[1].set_title('Brightness vs Rating')
    axes[1].grid(True, alpha=0.3)
    cbar = fig.colorbar(scatter, ax=axes[1])
    cbar.set_label('Rating')

    plt.tight_layout()
    plt.show()

    print(f"\nBrightness Statistics:")
    print(df['brightness'].describe())


## Custom Query Examples

In [None]:
# Example: Find images from a specific location.
# Edit the bounding box here; the values are passed to the database as bound
# parameters rather than being written into the SQL text.
LAT_RANGE = (43.0, 43.5)     # (min, max) latitude in degrees
LON_RANGE = (-89.5, -89.0)   # (min, max) longitude in degrees

location_query = """
SELECT id, file_path, gps_latitude, gps_longitude, rating
FROM media
WHERE gps_latitude BETWEEN :lat_min AND :lat_max
  AND gps_longitude BETWEEN :lon_min AND :lon_max
  AND is_original = TRUE
ORDER BY rating DESC
LIMIT 5
"""

location_params = {
    'lat_min': LAT_RANGE[0],
    'lat_max': LAT_RANGE[1],
    'lon_min': LON_RANGE[0],
    'lon_max': LON_RANGE[1],
}

with engine.connect() as conn:
    result = conn.execute(text(location_query), location_params)
    # Column names are shortened to lat/lon for display.
    location_results = pd.DataFrame(result.fetchall(),
                                   columns=['id', 'file_path', 'lat', 'lon', 'rating'])

print(f"Found {len(location_results)} images in specified area:")
print(location_results[['file_path', 'lat', 'lon', 'rating']])


## Export Results

In [None]:
# Example: write the analysis tables produced above to CSV files in an
# 'analysis_results' folder under the current working directory.
output_dir = Path.cwd().joinpath('analysis_results')
output_dir.mkdir(exist_ok=True)

# Each table is written only if its cell was run and produced rows.
brightness_ready = 'df' in locals() and not df.empty
if brightness_ready:
    output_file = output_dir.joinpath('brightness_analysis.csv')
    df.to_csv(output_file, index=False)
    print(f"Saved brightness analysis to: {output_file}")

location_ready = 'location_results' in locals() and not location_results.empty
if location_ready:
    output_file = output_dir.joinpath('location_query.csv')
    location_results.to_csv(output_file, index=False)
    print(f"Saved location query to: {output_file}")


## Cleanup

In [None]:
# Close database connection
session.close()
# Closing the session only hands its connection back to the engine's pool;
# dispose of the engine as well so the pooled connections are really closed.
engine.dispose()
print("✓ Database connection closed")
