# Cell 1 - Title and description
# PyroQ Data Exploration
# This notebook explores thermal satellite imagery data for wildfire detection.


# Cell 2 - Imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys
sys.path.append('../')

from src.data.dataset import ThermalAnomalyDataset
from src.classical.preprocessing import ThermalPreprocessor

# Set style: apply the Matplotlib style sheet first, then the seaborn palette.
# NOTE(review): the order looks deliberate - a style sheet applied after
# set_palette would presumably reset the colour cycle; confirm before reordering.
plt.style.use('seaborn-v0_8')
sns.set_palette('viridis')


# Cell 3 - Load and Examine Data
# Load the training split from the patch directory. If that directory does not
# exist, generate sample data under ../data/raw and load that instead.
data_path = '../data/patches/'
if Path(data_path).exists():
    dataset = ThermalAnomalyDataset(data_path, split='train')
else:
    print("Data not found. Creating sample data...")
    # Imported lazily: the generator is only needed on the fallback path.
    from src.data.modis_api import create_sample_data
    create_sample_data('../data/raw', num_samples=1000)
    dataset = ThermalAnomalyDataset('../data/raw', split='train')

# Summarise whichever dataset was loaded, so the fallback path reports the
# same information as the normal path.
print(f"Dataset size: {len(dataset)}")
if len(dataset) > 0:
    # Guarded: indexing images[0] / bincount on an empty dataset would raise.
    print(f"Image shape: {dataset.images[0].shape}")
    print(f"Label distribution: {np.bincount(dataset.labels)}")

# Cell 4 - Visualize Sample Images
# Plot up to five samples per class: fire on the top row, no-fire on the bottom.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Hide every axis up front so cells left empty (a class with fewer than five
# samples) do not show a bare frame with ticks.
for ax in axes.flat:
    ax.axis('off')

# Convert once so the element-wise comparisons below also work when labels is
# a plain Python sequence rather than an array.
label_array = np.asarray(dataset.labels)
fire_indices = np.where(label_array == 1)[0][:5]
no_fire_indices = np.where(label_array == 0)[0][:5]

sample_rows = (
    ('Fire', fire_indices),
    ('No-Fire', no_fire_indices),
)
for row, (class_name, class_indices) in enumerate(sample_rows):
    for i, idx in enumerate(class_indices):
        image = dataset.images[idx]
        if len(image.shape) == 3:
            image = image[0]  # Take first channel
        axes[row, i].imshow(image, cmap='hot')
        axes[row, i].set_title(f'{class_name} Sample {i+1}')

plt.tight_layout()
plt.show()

# Cell 5 - Analyze Temperature Distributions
# Gather the pixel values of every image, grouped by label, then compare the
# two groups with a histogram, a box plot and their means.
fire_chunks = []
no_fire_chunks = []

for i in range(len(dataset)):
    image = dataset.images[i]
    if len(image.shape) == 3:
        image = image[0]  # Take first channel

    # Keep one flat array per image instead of extending a Python list pixel
    # by pixel; the arrays are joined once after the loop.
    pixels = np.asarray(image).ravel()
    if dataset.labels[i] == 1:
        fire_chunks.append(pixels)
    else:
        no_fire_chunks.append(pixels)

# np.concatenate rejects an empty list, so fall back to an empty array when a
# class has no images.
fire_temps = np.concatenate(fire_chunks) if fire_chunks else np.array([])
no_fire_temps = np.concatenate(no_fire_chunks) if no_fire_chunks else np.array([])

# Plot distributions
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(fire_temps, bins=50, alpha=0.7, label='Fire', density=True)
plt.hist(no_fire_temps, bins=50, alpha=0.7, label='No Fire', density=True)
plt.xlabel('Temperature Value')
plt.ylabel('Density')
plt.title('Temperature Distributions')
plt.legend()

plt.subplot(1, 2, 2)
plt.boxplot([fire_temps, no_fire_temps])
# Label the boxes through the ticks (boxes sit at positions 1 and 2 by
# default); the boxplot `labels` keyword is deprecated in recent Matplotlib.
plt.xticks([1, 2], ['Fire', 'No Fire'])
plt.ylabel('Temperature Value')
plt.title('Temperature Box Plot')

plt.tight_layout()
plt.show()

# An empty class reports nan without triggering NumPy's empty-mean warning.
fire_mean = fire_temps.mean() if fire_temps.size else float('nan')
no_fire_mean = no_fire_temps.mean() if no_fire_temps.size else float('nan')
print(f"Fire mean temperature: {fire_mean:.2f}")
print(f"No-fire mean temperature: {no_fire_mean:.2f}")