In [None]:
import os
import sys
from pathlib import Path

# Database
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
import rawpy
import numpy as np
from skimage.transform import rescale

# Add project to path
sys.path.insert(0, str(Path.cwd().parent))
from home_media_ai.models import Media, MediaType
from home_media_ai.exif_extractor import ExifExtractor
from home_media_ai.media_query import MediaQuery

import HIP

# Configure plotting: the style and palette set here apply to every figure below.
# NOTE(review): 'seaborn-v0_8-*' style names exist only in newer matplotlib releases
# (older ones used 'seaborn-darkgrid') — confirm against the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# Inline is already the default backend in recent Jupyter; kept explicit for older kernels.
%matplotlib inline


## Connect to the database

In [None]:
# Database connection. The URI is read from the environment so that no
# connection string or credentials are stored in the notebook.
DATABASE_URI = os.getenv('HOME_MEDIA_AI_URI')
if not DATABASE_URI:
    raise ValueError("Set HOME_MEDIA_AI_URI environment variable")

engine = create_engine(DATABASE_URI)

# create_engine() is lazy and opens no connection. Run a trivial round-trip so that
# bad credentials or an unreachable host fail here, not in a later query cell.
# NOTE(review): "SELECT 1" works on SQLite/MySQL/PostgreSQL; adjust for e.g. Oracle.
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))

Session = sessionmaker(bind=engine)
session = Session()

print(f"✓ Connected to database: {engine.url.database}")

# Chainable media query builder (project helper from home_media_ai.media_query),
# bound to this session.
query = MediaQuery(session)


## Learn: decode a RAW file and inspect its pixels

In [None]:
# Filter criteria for the working set: change them here rather than inside the chain.
MIN_RATING = 4  # passed to rating_min(); presumably inclusive — confirm in MediaQuery
YEAR = 2024

# Canon RAW files matching the criteria above; .all() materialises the result so it can be indexed.
images = query.canon().raw().rating_min(MIN_RATING).year(YEAR).all()
print(f"{len(images)} images matched")


In [None]:
# Fail with a clear message instead of an opaque IndexError on images[0].
if not images:
    raise ValueError("No images matched the query; loosen the filters in the cell above.")

MAX_16BIT = 65535.0   # largest value when postprocess() is called with output_bps=16
DISPLAY_SCALE = 0.25  # fraction of the original size kept for display


def describe_array(arr: np.ndarray) -> None:
    """Print the type, shape, dtype and memory footprint (MB) of an image array."""
    print(type(arr))
    print(f"Image shape: {arr.shape}")
    print(f"Data type: {arr.dtype}")
    print(f"Memory: {arr.nbytes / 1024**2:.1f} MB")


with rawpy.imread(images[0].file_path) as raw:
    # Process RAW to RGB
    # use_camera_wb=True uses camera white balance
    # output_bps=8 for 8-bit, output_bps=16 for 16-bit
    rgb = raw.postprocess(
        use_camera_wb=True,
        output_bps=16,  # Keep 16-bit for quality
        no_auto_bright=False,
        output_color=rawpy.ColorSpace.sRGB,
    )

describe_array(rgb)

# Rescale for display. preserve_range=True keeps values in 0..65535. Without it,
# rescale() already converts the uint16 input to floats in [0, 1], and the division
# below would shrink them ~65535x more, giving an effectively black image.
rgb_small = rescale(rgb, DISPLAY_SCALE, channel_axis=2, preserve_range=True)
rgb_display = rgb_small / MAX_16BIT  # floats in [0, 1], ready for imshow

describe_array(rgb_display)


In [None]:
# Show the decoded image and its per-channel histogram as two separate figures.
# imshow gets a strided float32 copy, normalised to [0, 1]. An 8x8-inch figure cannot
# show full resolution anyway, and a full-size float64 copy of a multi-megapixel frame
# costs hundreds of MB. Errors are left to surface rather than being swallowed.
preview_step = 4  # keep every 4th pixel along each axis
preview = rgb[::preview_step, ::preview_step].astype(np.float32) / 65535.0

fig1, ax1 = plt.subplots(figsize=(8, 8))
ax1.imshow(preview)
ax1.set_title(f"Processed RAW (every {preview_step}th pixel)")
ax1.axis('off')
plt.show()

# The histogram uses every pixel of the full-resolution 16-bit data.
fig2, ax2 = plt.subplots(figsize=(8, 8))
for channel, color in enumerate(['r', 'g', 'b']):
    ax2.hist(rgb[:, :, channel].ravel(), bins=256, range=(0, 65535),
             color=color, alpha=0.5, label=color.upper())
ax2.set(title="Per-channel histogram (16-bit)", xlabel="Pixel value", ylabel="Pixel count")
ax2.legend()
plt.show()


## Clean up

In [None]:
# Close database connection
session.close()
print("✓ Database connection closed")
