# Visualization Gallery

This notebook demonstrates all visualization capabilities of the precip-index package. We'll explore:

1. **Basic drought index plots** - Time series with thresholds
2. **Event timeline plots** - Individual events highlighted
3. **Drought evolution (5-panel)** - Comprehensive monitoring view
4. **Spatial statistics maps** - Gridded drought characteristics
5. **Magnitude comparison** - Cumulative vs instantaneous
6. **Period comparison** - Historical vs recent
7. **Custom styling** - Publication-ready figures
8. **Batch processing** - Multiple locations

**Learning Objectives:**
1. Master all plotting functions
2. Customize visualizations for different purposes
3. Create publication-quality figures
4. Understand appropriate plot types for different analyses

## 1. Setup and Imports

In [None]:
# Add src directory to Python path
import sys
sys.path.insert(0, '../src')

# Core libraries
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from datetime import datetime
import os

# Drought analysis
from runtheory import (
    identify_events,
    calculate_timeseries,
    calculate_period_statistics,
    compare_periods
)

# Visualization functions
from visualization import (
    plot_index,
    plot_events,
    plot_event_characteristics,
    plot_event_timeline,
    plot_spatial_stats,
    generate_location_filename
)

# Plotting settings
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 10

# Create output directories
os.makedirs('../output/plots/single', exist_ok=True)
os.makedirs('../output/plots/spatial', exist_ok=True)

print("✓ All imports successful!")
print("✓ Output directories ready")

## 2. Load Data and Prepare Examples

In [None]:
# Load SPI-12
spi_file = '../output/netcdf/spi_12.nc'

if not os.path.exists(spi_file):
    print("❌ SPI file not found. Please run notebook 01 first.")
    raise FileNotFoundError(f"File not found: {spi_file}")

spi = xr.open_dataset(spi_file)['spi_gamma_12_month']
print(f"✓ SPI-12 loaded: {spi.shape}")

# Select sample location
lat_idx = len(spi.lat) // 2
lon_idx = len(spi.lon) // 2
spi_loc = spi.isel(lat=lat_idx, lon=lon_idx)
lat_val = float(spi.lat.values[lat_idx])
lon_val = float(spi.lon.values[lon_idx])

print(f"✓ Sample location: {lat_val:.2f}°N, {lon_val:.2f}°E")

# Calculate drought characteristics
threshold = -1.2
events = identify_events(spi_loc, threshold=threshold, min_duration=3)
ts = calculate_timeseries(spi_loc, threshold=threshold)
stats_2023 = calculate_period_statistics(spi, threshold=threshold,
                                         start_year=2023, end_year=2023)

print(f"✓ Found {len(events)} events")
print(f"✓ Time series: {len(ts)} months")
print("✓ 2023 statistics calculated")
print("\n✓ All data ready for visualization!")

---
# PLOT GALLERY
---

## 3. Plot Type 1: Basic Drought Index

**Function:** `plot_index()`

**Shows:**
- SPI/SPEI values over time
- Threshold line
- Color-coded severity
- Drought/wet periods highlighted

**Best for:** Simple time series visualization, initial exploration
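This notebook does not spell out the color breakpoints `plot_index()` uses. As a minimal sketch, assuming the widely used SPI categories of McKee et al. (1993), each monthly value can be mapped to a severity class as follows. The package's own breakpoints and colors may differ.

```python
import numpy as np

# Severity classes assumed from the common SPI categories (McKee et al., 1993).
# Dry classes include their upper bound (e.g. -1.0 is "moderately dry");
# wet classes include their lower bound (e.g. 1.0 is "moderately wet").
SPI_LABELS = [
    "extremely dry", "severely dry", "moderately dry", "near normal",
    "moderately wet", "very wet", "extremely wet",
]


def classify_spi(values):
    """Map SPI values to severity labels; NaN becomes 'missing'."""
    x = np.asarray(values, dtype=float)
    conditions = [
        x <= -2.0,
        x <= -1.5,
        x <= -1.0,
        x < 1.0,
        x < 1.5,
        x < 2.0,
        x >= 2.0,
    ]
    # np.select returns the label of the first condition that is True;
    # NaN fails every comparison and falls through to the default
    return np.select(conditions, SPI_LABELS, default="missing")


print(classify_spi([-2.3, -1.2, 0.4, 1.0, 2.0, np.nan]))
```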

In [None]:
# Basic plot
fig = plot_index(spi_loc, threshold=threshold,
                 title=f'SPI-12 at {lat_val:.2f}°N, {lon_val:.2f}°E')

filename = generate_location_filename('plot_basic_index', lat_val, lon_val, 'png')
plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
print(f"✓ Saved: {filename}")

plt.show()

### Variations with Different Thresholds

In [None]:
# Compare different thresholds
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 12), sharex=True)

plot_index(spi_loc, threshold=-1.0, ax=ax1,
           title='Moderate Drought (Threshold: -1.0)')

plot_index(spi_loc, threshold=-1.5, ax=ax2,
           title='Severe Drought (Threshold: -1.5)')

plot_index(spi_loc, threshold=-2.0, ax=ax3,
           title='Extreme Drought (Threshold: -2.0)')

plt.suptitle('Threshold Sensitivity Comparison', fontsize=14, y=0.995)
plt.tight_layout()

filename = generate_location_filename('threshold_comparison', lat_val, lon_val, 'png')
plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
print(f"✓ Saved: {filename}")

plt.show()
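Why do the three panels differ so much? SPI is standardized, so its values are roughly standard normal. For a long record, about 16%, 7%, and 2% of months therefore fall at or below -1.0, -1.5, and -2.0. The minimal sketch below checks this on a synthetic series. The random draws stand in for SPI and are not data from this notebook.

```python
import numpy as np


def share_below(spi, thresholds):
    """Fraction of valid (non-NaN) months at or below each threshold."""
    x = np.asarray(spi, dtype=float)
    valid = x[~np.isnan(x)]
    return {t: float(np.mean(valid <= t)) for t in thresholds}


# Synthetic stand-in for a 50-year monthly SPI record (illustrative only)
rng = np.random.default_rng(42)
synthetic_spi = rng.standard_normal(600)

for t, share in share_below(synthetic_spi, [-1.0, -1.5, -2.0]).items():
    print(f"threshold {t:+.1f}: {share:.1%} of months")
```

A stricter threshold can only reduce the number of drought months. The number of *events*, however, can move either way, because a stricter threshold may split one long event into several shorter ones.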

## 4. Plot Type 2: Drought Events Timeline

**Function:** `plot_events()`

**Shows:**
- Individual events with unique colors
- Peak markers
- Event boundaries
- Event shading

**Best for:** Event identification, comparing event characteristics
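The events drawn here come from `identify_events()`, which applies run theory. Its exact rules are not shown in this notebook. The minimal sketch below rests on two assumptions, neither confirmed as the package's definition: an event is a run of consecutive months strictly below the threshold lasting at least `min_duration` months, and its magnitude is the summed deficit below the threshold.

```python
import numpy as np


def find_events(spi, threshold, min_duration=1):
    """Run-theory sketch: runs of consecutive months below `threshold`."""
    x = np.asarray(spi, dtype=float)
    below = x < threshold  # NaN compares False, so a data gap ends a run

    runs, start = [], None
    for i, flag in enumerate(below):
        if flag and start is None:
            start = i                      # a run begins
        elif not flag and start is not None:
            runs.append((start, i))        # run ended at i (exclusive)
            start = None
    if start is not None:
        runs.append((start, len(x)))       # run still open at the end of the record

    events = []
    for s, e in runs:
        if e - s < min_duration:
            continue                       # too short to count as an event
        seg = x[s:e]
        magnitude = float(np.sum(threshold - seg))  # assumed: summed deficit
        events.append({
            "start": s,
            "end": e - 1,
            "duration": e - s,
            "magnitude": magnitude,
            "intensity": magnitude / (e - s),
            "peak": float(seg.min()),
            "peak_index": s + int(np.argmin(seg)),
        })
    return events


spi = [0.5, -1.5, -2.0, -1.3, 0.2, -1.4, 0.1, -1.6, -1.8, -1.5, -1.3]
for ev in find_events(spi, threshold=-1.2, min_duration=3):
    print(ev)
```

With `min_duration=3`, the single-month dip at index 5 is discarded, leaving two events.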

In [None]:
# Events timeline
fig = plot_events(spi_loc, events, threshold=threshold,
                  title=f'Drought Events at {lat_val:.2f}°N, {lon_val:.2f}°E')

filename = generate_location_filename('plot_events', lat_val, lon_val, 'png')
plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
print(f"✓ Saved: {filename}")

plt.show()

print(f"\nShowing {len(events)} discrete drought events")
print("Each event is shown in a different color, with its peak marked")

### Add Custom Annotations

In [None]:
# Plot with custom annotations
fig, ax = plt.subplots(figsize=(16, 6))
plot_events(spi_loc, events, threshold=threshold, ax=ax)

# Add duration labels at peaks
for idx, event in events.iterrows():
    ax.text(event['peak_date'], event['peak'] - 0.2,
            f"D={int(event['duration'])}m",
            ha='center', fontsize=8,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='wheat', alpha=0.7))

ax.set_title('Drought Events with Duration Annotations', fontsize=13)

filename = generate_location_filename('events_annotated', lat_val, lon_val, 'png')
plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
print(f"✓ Saved: {filename}")

plt.show()

## 5. Plot Type 3: Event Characteristics Analysis

**Function:** `plot_event_characteristics()`

**Shows:**
- Multi-panel analysis
- Distribution histograms
- Relationship scatter plots
- Time evolution

**Best for:** Understanding event patterns, comparing characteristics
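The histogram and scatter panels summarize the events table. The minimal NumPy sketch below computes the underlying numbers, assuming per-event durations and magnitudes are available as arrays. The example values are made up for illustration.

```python
import numpy as np


def summarize_events(durations, magnitudes, bins=5):
    """Per-event intensity, a duration histogram, and duration-magnitude correlation."""
    d = np.asarray(durations, dtype=float)
    m = np.asarray(magnitudes, dtype=float)
    intensity = m / d                      # magnitude per month of duration
    counts, edges = np.histogram(d, bins=bins)
    # Pearson correlation needs at least two events
    r = float(np.corrcoef(d, m)[0, 1]) if d.size > 1 else float("nan")
    return {"intensity": intensity, "hist_counts": counts,
            "hist_edges": edges, "r_duration_magnitude": r}


# Hypothetical events, not results from this notebook
summary = summarize_events([3, 4, 6, 8], [2.0, 3.0, 5.0, 7.0], bins=3)
print(summary["intensity"].round(3), summary["hist_counts"], summary["r_duration_magnitude"])
```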

In [None]:
# Characteristics analysis
if len(events) > 0:
    fig = plot_event_characteristics(events, characteristic='magnitude')
    
    filename = generate_location_filename('characteristics', lat_val, lon_val, 'png')
    plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {filename}")
    
    plt.show()
else:
    print("No events found; skipping the characteristics plot")

## 6. Plot Type 4: Drought Evolution Timeline (5-Panel)

**Function:** `plot_event_timeline()`

**Shows 5 panels:**
1. Index value (SPI/SPEI)
2. Duration (current drought length)
3. Magnitude - Cumulative (blue; increases monotonically within each event)
4. Magnitude - Instantaneous (red; rises and falls with current severity)
5. Intensity (cumulative magnitude / duration)

**Best for:** Real-time monitoring, understanding drought evolution
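The five panels come from `calculate_timeseries()`. The minimal sketch below shows how the per-month series could be built. It assumes that instantaneous magnitude is the current month's deficit below the threshold and that cumulative magnitude is the running sum of those deficits since the event began. Intensity is cumulative magnitude divided by duration, as the panel list above states. The first two definitions are assumptions, not the package's confirmed formulas.

```python
import numpy as np


def monitoring_series(spi, threshold):
    """Per-month duration, cumulative/instantaneous magnitude, and intensity."""
    x = np.asarray(spi, dtype=float)
    n = x.size
    duration = np.zeros(n, dtype=int)
    cumulative = np.zeros(n)
    instantaneous = np.zeros(n)
    intensity = np.zeros(n)
    for i in range(n):
        if x[i] < threshold:               # month is in drought
            deficit = threshold - x[i]     # assumed instantaneous magnitude
            prev_dur = duration[i - 1] if i > 0 else 0
            prev_cum = cumulative[i - 1] if i > 0 else 0.0
            duration[i] = prev_dur + 1
            cumulative[i] = prev_cum + deficit  # only grows while the event lasts
            instantaneous[i] = deficit          # rises and falls with SPI
            intensity[i] = cumulative[i] / duration[i]
        # outside drought all four stay at zero, so the next event starts fresh
    return duration, cumulative, instantaneous, intensity


dur, cum, inst, inten = monitoring_series([0.0, -1.5, -2.0, -1.3, 0.5, -1.4], -1.2)
print(dur)
print(cum.round(2))
print(inst.round(2))
```

In the output, the cumulative magnitude keeps growing through months 1-3 even though the instantaneous deficit drops in month 3. This is the same contrast the blue and red panels show.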

In [None]:
# 5-panel drought evolution
fig = plot_event_timeline(ts, title=f'Drought Evolution at {lat_val:.2f}°N, {lon_val:.2f}°E')

filename = generate_location_filename('timeline_5panel', lat_val, lon_val, 'png')
plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
print(f"✓ Saved: {filename}")

plt.show()

print("\n5-Panel Breakdown:")
print("  Panel 1: SPI-12 with threshold")
print("  Panel 2: Current drought duration")
print("  Panel 3: Cumulative magnitude (blue, never decreases within an event)")
print("  Panel 4: Instantaneous magnitude (red, rises and falls with current severity)")
print("  Panel 5: Intensity (cumulative magnitude / duration)")

### Focus on Recent Period

In [None]:
# Zoom into recent years
recent_ts = ts[ts['time'] >= '2020-01-01']

if len(recent_ts) > 0:
    fig = plot_event_timeline(
        recent_ts,
        title=f'Recent Drought Evolution (2020-present) at {lat_val:.2f}°N, {lon_val:.2f}°E')
    
    filename = generate_location_filename('timeline_recent', lat_val, lon_val, 'png')
    plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {filename}")
    
    plt.show()
else:
    print("No recent data available")

## 7. Plot Type 5: Spatial Drought Statistics

**Function:** `plot_spatial_stats()`

**Shows:**
- Maps of drought statistics
- Nine mappable statistics (this notebook uses `num_events`, `worst_peak`, `total_magnitude`, and `pct_time_in_drought`)
- Customizable colormaps

**Best for:** Regional analysis, spatial patterns, decision support
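`calculate_period_statistics()` reduces a (time, lat, lon) cube to one value per grid cell. The minimal NumPy sketch below reproduces two such reductions: percent of time in drought and a count of drought runs. It assumes that months strictly below the threshold count as drought, and it applies no minimum-duration filter. For both reasons its counts may not match the package's.

```python
import numpy as np


def grid_drought_stats(spi, threshold):
    """spi: array shaped (time, lat, lon). Returns per-cell run count and % time in drought."""
    x = np.asarray(spi, dtype=float)
    below = x < threshold                      # NaN months never count as drought
    n_valid = np.sum(~np.isnan(x), axis=0)

    # A run starts where a month is in drought and the previous month is not
    starts = below[1:] & ~below[:-1]
    num_runs = below[0].astype(int) + starts.sum(axis=0)

    # Cells with no valid data get NaN instead of a division-by-zero result
    with np.errstate(invalid="ignore", divide="ignore"):
        pct_time = np.where(n_valid > 0, 100.0 * below.sum(axis=0) / n_valid, np.nan)
    return num_runs, pct_time


# Synthetic cube: 5 months, 1 latitude, 2 longitudes
cube = np.array([
    [[-1.5, 0.0]],
    [[-1.5, 0.0]],
    [[0.0, 0.0]],
    [[-1.5, 0.0]],
    [[0.0, 0.0]],
])
runs, pct = grid_drought_stats(cube, threshold=-1.2)
print(runs, pct)
```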

### Map 1: Number of Events

In [None]:
# Event count map
fig = plot_spatial_stats(stats_2023, variable='num_events',
                         title='Number of Drought Events in 2023',
                         cmap='YlOrRd')

plt.savefig('../output/plots/spatial/map_num_events_2023.png', dpi=300, bbox_inches='tight')
print("✓ Saved: map_num_events_2023.png")

plt.show()

### Map 2: Worst Severity

In [None]:
# Worst peak map
fig = plot_spatial_stats(stats_2023, variable='worst_peak',
                         title='Worst Drought Severity in 2023',
                         cmap='RdYlBu_r')

plt.savefig('../output/plots/spatial/map_worst_peak_2023.png', dpi=300, bbox_inches='tight')
print("✓ Saved: map_worst_peak_2023.png")

plt.show()

### Map 3: Total Magnitude

In [None]:
# Total magnitude map
fig = plot_spatial_stats(stats_2023, variable='total_magnitude',
                         title='Total Drought Magnitude in 2023',
                         cmap='YlOrRd')

plt.savefig('../output/plots/spatial/map_total_magnitude_2023.png', dpi=300, bbox_inches='tight')
print("✓ Saved: map_total_magnitude_2023.png")

plt.show()

### Map 4: Percent Time in Drought

In [None]:
# Percent time map
fig = plot_spatial_stats(stats_2023, variable='pct_time_in_drought',
                         title='% Time in Drought (2023)',
                         cmap='Reds')

plt.savefig('../output/plots/spatial/map_pct_drought_2023.png', dpi=300, bbox_inches='tight')
print("✓ Saved: map_pct_drought_2023.png")

plt.show()

### Multi-Variable Panel

In [None]:
# 2x2 panel of key statistics
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

variables = ['num_events', 'worst_peak', 'total_magnitude', 'pct_time_in_drought']
titles = ['Event Count', 'Worst Severity', 'Total Magnitude', '% Time in Drought']
cmaps = ['YlOrRd', 'RdYlBu_r', 'YlOrRd', 'Reds']

for ax, var, title, cmap in zip(axes.flat, variables, titles, cmaps):
    stats_2023[var].plot(ax=ax, cmap=cmap, cbar_kwargs={'label': title})
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

plt.suptitle('Drought Statistics Summary (2023)', fontsize=14, y=0.995)
plt.tight_layout()

plt.savefig('../output/plots/spatial/map_panel_2023.png', dpi=300, bbox_inches='tight')
print("✓ Saved: map_panel_2023.png")

plt.show()

## 8. Magnitude Comparison Plots

Visualize both magnitude types to understand their differences.

### Stacked Comparison

In [None]:
# Dual magnitude - stacked
# Rows where the location is in drought (also used by the twin-axis cell below)
drought_periods = ts[ts['is_event']]

if len(drought_periods) > 0:
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True)

    # Mask non-drought months with NaN so lines and fills break between events
    # instead of bridging the gaps between them
    cumulative = ts['magnitude_cumulative'].where(ts['is_event'])
    instantaneous = ts['magnitude_instantaneous'].where(ts['is_event'])

    # Cumulative (blue)
    ax1.plot(ts['time'], cumulative, color='steelblue', linewidth=2, label='Cumulative')
    ax1.fill_between(ts['time'], 0, cumulative, alpha=0.3, color='steelblue')
    ax1.set_ylabel('Cumulative Magnitude', fontsize=11)
    ax1.set_title('Cumulative Magnitude (Total Deficit - Monotonic)', fontsize=12)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Instantaneous (red)
    ax2.plot(ts['time'], instantaneous, color='darkred', linewidth=2, label='Instantaneous')
    ax2.fill_between(ts['time'], 0, instantaneous, alpha=0.3, color='red')
    ax2.set_ylabel('Instantaneous Magnitude', fontsize=11)
    ax2.set_title('Instantaneous Magnitude (Current Severity - Variable)', fontsize=12)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.suptitle('Magnitude Types Comparison', fontsize=14, y=0.995)
    plt.tight_layout()
    
    filename = generate_location_filename('magnitude_stacked', lat_val, lon_val, 'png')
    plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {filename}")
    
    plt.show()
else:
    print("No drought periods to plot")

### Twin-Axis Comparison

In [None]:
# Dual magnitude - twin axes
if len(drought_periods) > 0:
    fig, ax1 = plt.subplots(figsize=(14, 6))
    
    # Cumulative (left)
    color1 = 'steelblue'
    ax1.set_xlabel('Time', fontsize=11)
    ax1.set_ylabel('Cumulative Magnitude', color=color1, fontsize=11)
    # Mask non-drought months with NaN so the lines break between events
    # instead of bridging the gaps
    ax1.plot(ts['time'], ts['magnitude_cumulative'].where(ts['is_event']),
             color=color1, linewidth=2, label='Cumulative')
    ax1.tick_params(axis='y', labelcolor=color1)
    ax1.grid(True, alpha=0.3)
    
    # Instantaneous (right)
    ax2 = ax1.twinx()
    color2 = 'darkred'
    ax2.set_ylabel('Instantaneous Magnitude', color=color2, fontsize=11)
    ax2.plot(ts['time'], ts['magnitude_instantaneous'].where(ts['is_event']),
             color=color2, linewidth=2, linestyle='--', label='Instantaneous')
    ax2.tick_params(axis='y', labelcolor=color2)
    
    plt.title('Dual Magnitude Evolution (Twin Axes)', fontsize=13)
    
    # Combined legend
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
    
    fig.tight_layout()
    
    filename = generate_location_filename('magnitude_twin', lat_val, lon_val, 'png')
    plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {filename}")
    
    plt.show()
    
    print("\nKey Observation:")
    print("  Blue (cumulative) always increases during drought")
    print("  Red (instantaneous) varies with SPI pattern (peaks and valleys)")
    print("  See docs/user-guide/magnitude.md for detailed explanation")

## 9. Period Comparison Visualizations

In [None]:
# Calculate comparison
print("Comparing historical vs recent periods...")
comparison = compare_periods(
    spi,
    periods=[(1991, 2020), (2021, 2024)],
    period_names=['Historical (1991-2020)', 'Recent (2021-2024)'],
    threshold=threshold,
    min_duration=3
)
print("✓ Comparison calculated")
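The two periods differ greatly in length (30 years vs 4), so raw event counts are not directly comparable. The minimal sketch below converts a count into a rate per decade. It assumes `num_events` is a raw count over each whole period, and the numbers in the example are hypothetical.

```python
def events_per_decade(num_events, start_year, end_year):
    """Convert an event count over an inclusive year range into events per decade.

    Uses only arithmetic, so it also works element-wise on NumPy or xarray arrays.
    """
    n_years = end_year - start_year + 1     # 1991-2020 is 30 years, not 29
    return 10.0 * num_events / n_years


# Hypothetical counts for one grid cell
hist = events_per_decade(9, 1991, 2020)    # 3.0 per decade
recent = events_per_decade(2, 2021, 2024)  # 5.0 per decade
print(hist, recent, recent - hist)
```

In this example, the raw counts (9 vs 2) suggest fewer droughts recently, while the rates (3.0 vs 5.0 per decade) suggest more frequent ones. Keep in mind that a rate estimated from only 4 years is noisy.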

### Side-by-Side Comparison

In [None]:
# Plot both periods as events per decade.
# The periods differ in length (30 vs 4 years), so raw counts are not comparable.
# Assumes num_events is a raw count over each whole period.
period_years = {'Historical (1991-2020)': 30, 'Recent (2021-2024)': 4}
rates = {name: comparison.sel(period=name).num_events * 10 / n_years
         for name, n_years in period_years.items()}
# Shared color limits so both panels use the same scale
vmax = max(float(rate.max()) for rate in rates.values()) or 1.0

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

for ax, (name, rate) in zip((ax1, ax2), rates.items()):
    rate.plot(ax=ax, cmap='YlOrRd', vmin=0, vmax=vmax,
              cbar_kwargs={'label': 'Events per decade'})
    ax.set_title(name, fontsize=12)

plt.suptitle('Drought Event Frequency Comparison', fontsize=14, y=0.98)
plt.tight_layout()

plt.savefig('../output/plots/spatial/comparison_sidebyside.png', dpi=300, bbox_inches='tight')
print("✓ Saved: comparison_sidebyside.png")

plt.show()

### Difference Map

In [None]:
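# Calculate and plot the difference in event frequency.
# Normalize each period to events per decade first: the historical period spans
# 30 years and the recent one only 4, so raw counts would be dominated by period
# length. Assumes num_events is a raw count over each whole period; a 4-year
# window gives a noisy rate.
hist_rate = comparison.sel(period='Historical (1991-2020)').num_events * 10 / 30
recent_rate = comparison.sel(period='Recent (2021-2024)').num_events * 10 / 4
diff_rate = recent_rate - hist_rate

fig, ax = plt.subplots(figsize=(12, 6))

diff_rate.plot(ax=ax, cmap='RdBu_r', center=0,
               cbar_kwargs={'label': 'Change in events per decade'})
ax.set_title('Change in Drought Event Frequency (Recent - Historical)', fontsize=13)

plt.tight_layout()
plt.savefig('../output/plots/spatial/comparison_difference.png', dpi=300, bbox_inches='tight')
print("✓ Saved: comparison_difference.png")

plt.show()

# Print summary
mean_change = float(diff_rate.mean().values)
print(f"\nAverage change: {mean_change:+.2f} events per decade")
if mean_change > 0:
    print("  → More frequent drought events in recent period")
elif mean_change < 0:
    print("  → Less frequent drought events in recent period")
else:
    print("  → No change in average event frequency")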

## 10. Custom Styling and Publication Quality

### Publication Settings

In [None]:
# Set publication-quality parameters
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.family': 'serif',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 14
})

print("✓ Publication settings activated")

In [None]:
# Create publication-ready figure
fig = plot_events(spi_loc, events, threshold=threshold,
                          title=f'Drought Events Analysis\n{lat_val:.2f}°N, {lon_val:.2f}°E')

# Save in multiple formats
base = generate_location_filename('publication', lat_val, lon_val, '').rstrip('.')
plt.savefig(f'../output/plots/single/{base}.png', dpi=300, bbox_inches='tight')
plt.savefig(f'../output/plots/single/{base}.pdf', bbox_inches='tight')  # Vector
plt.savefig(f'../output/plots/single/{base}.svg', bbox_inches='tight')  # Vector

print("✓ Saved publication figures:")
print(f"  {base}.png (raster, 300 dpi)")
print(f"  {base}.pdf (vector)")
print(f"  {base}.svg (vector)")

plt.show()
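The minimal, self-contained sketch below repeats the multi-format export above but renders into in-memory buffers instead of files. It uses Matplotlib's object-oriented `Figure`, which carries no pyplot state and so leaves the notebook's inline backend untouched. The plotted curve is placeholder data.

```python
import io

import numpy as np
from matplotlib.figure import Figure


def render_formats(fig, formats=("png", "pdf", "svg"), dpi=300):
    """Render one figure to several formats; returns {format: bytes}."""
    blobs = {}
    for fmt in formats:
        buf = io.BytesIO()
        # dpi sets the pixel density of raster output (and of any rasterized
        # parts of vector output)
        fig.savefig(buf, format=fmt, dpi=dpi, bbox_inches="tight")
        blobs[fmt] = buf.getvalue()
    return blobs


fig = Figure(figsize=(6, 3))
ax = fig.subplots()
months = np.arange(36)
ax.plot(months, np.sin(months / 5.0), color="steelblue")  # placeholder series
ax.axhline(-1.2, linestyle="--", color="gray", label="threshold")
ax.legend()

for fmt, data in render_formats(fig).items():
    print(f"{fmt}: {len(data):,} bytes")
```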

In [None]:
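# Reset to defaults. plt.rcdefaults() keeps the active backend, whereas
# plt.rcParams.update(plt.rcParamsDefault) can switch Jupyter away from inline plotting
plt.rcdefaults()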
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 10
print("✓ Reset to default settings")
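Instead of changing global `rcParams` and resetting them afterwards, you can scope the publication style with `matplotlib.rc_context`, which restores the previous settings when the `with` block exits. The minimal sketch below demonstrates this with a placeholder figure and a subset of the settings used above.

```python
import io

import matplotlib as mpl
from matplotlib.figure import Figure

# Subset of the publication settings used earlier in this notebook
PUBLICATION_RC = {
    "font.family": "serif",
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "savefig.dpi": 300,
}

size_before = mpl.rcParams["font.size"]

with mpl.rc_context(PUBLICATION_RC):
    family_inside = list(mpl.rcParams["font.family"])
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.set_title("Styled only inside the context")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")   # picks up savefig.dpi = 300 from the context

size_after = mpl.rcParams["font.size"]
print(family_inside, size_before, size_after)
```

Because the settings revert automatically, even if the cell raises an error, no separate reset cell is needed.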

## 11. Batch Processing Multiple Locations

In [None]:
# Select multiple sample locations.
# Sort latitude north-to-south so the quarter-point indices match the quadrant
# labels however the file orders latitude (longitude is assumed to increase eastward).
spi_ns = spi.sortby('lat', ascending=False)
n_lat = len(spi_ns.lat)
n_lon = len(spi_ns.lon)

locations = [
    (n_lat // 4, n_lon // 4, 'Northwest'),
    (n_lat // 4, 3 * n_lon // 4, 'Northeast'),
    (3 * n_lat // 4, n_lon // 4, 'Southwest'),
    (3 * n_lat // 4, 3 * n_lon // 4, 'Southeast'),
]

print(f"Generating plots for {len(locations)} locations...")
print()

for lat_i, lon_i, name in locations:
    # Extract location
    loc_spi = spi_ns.isel(lat=lat_i, lon=lon_i)
    loc_lat = float(spi_ns.lat.values[lat_i])
    loc_lon = float(spi_ns.lon.values[lon_i])
    
    # Calculate events
    loc_events = identify_events(loc_spi, threshold=threshold, min_duration=3)
    
    # Plot
    fig = plot_events(loc_spi, loc_events, threshold=threshold,
                              title=f'{name}: {loc_lat:.2f}°N, {loc_lon:.2f}°E')
    
    # Save
    filename = generate_location_filename(f'batch_{name.lower()}', loc_lat, loc_lon, 'png')
    plt.savefig(f'../output/plots/single/{filename}', dpi=300, bbox_inches='tight')
    plt.close()  # Close to save memory
    
    print(f"✓ {name:12s}: {len(loc_events)} events → {filename}")

print()
print(f"✓ All {len(locations)} location plots generated!")
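Hard-coded quadrant labels are only correct if the grid's coordinate order matches the assumption behind them. The minimal sketch below shows a hypothetical helper (not part of the package) that picks the same quarter-point indices but derives each label from the actual coordinate values.

```python
import numpy as np


def quadrant_locations(lat, lon):
    """Hypothetical helper: quarter-point grid indices labelled by their real coordinates."""
    lat = np.asarray(lat, dtype=float)
    lon = np.asarray(lon, dtype=float)
    lat_mid, lon_mid = np.median(lat), np.median(lon)
    picks = []
    for lat_i in (len(lat) // 4, 3 * len(lat) // 4):
        for lon_i in (len(lon) // 4, 3 * len(lon) // 4):
            ns = "North" if lat[lat_i] >= lat_mid else "South"
            ew = "east" if lon[lon_i] >= lon_mid else "west"
            picks.append((lat_i, lon_i, ns + ew))
    return picks


# Ascending latitude: index n//4 is in the south, so the labels follow the values
for pick in quadrant_locations(np.linspace(-10, 10, 9), np.linspace(100, 120, 9)):
    print(pick)
```

With real data, `quadrant_locations(spi.lat.values, spi.lon.values)` could replace the hard-coded `locations` list.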

## 12. Summary

### All Plot Types Covered

| # | Plot Type | Function | Best For |
|---|-----------|----------|----------|
| 1 | Basic Index | `plot_index()` | Simple time series |
| 2 | Event Timeline | `plot_events()` | Event identification |
| 3 | Characteristics | `plot_event_characteristics()` | Event analysis |
| 4 | Evolution (5-panel) | `plot_event_timeline()` | Real-time monitoring |
| 5 | Spatial Maps | `plot_spatial_stats()` | Regional patterns |

### Additional Techniques

- ✅ Magnitude comparison (cumulative vs instantaneous)
- ✅ Period comparison (historical vs recent)
- ✅ Multi-panel layouts
- ✅ Custom annotations
- ✅ Publication-quality settings
- ✅ Batch processing
- ✅ Multiple output formats (PNG, PDF, SVG)

### Best Practices

1. **DPI**: 300 for publications, 150 for presentations, 72 for web
2. **Format**: PNG for raster, PDF/SVG for vector graphics
3. **Bbox**: Use `bbox_inches='tight'` to trim surrounding whitespace and keep suptitles, labels, and legends from being clipped
4. **Filenames**: Use `generate_location_filename()` for consistency
5. **Memory**: Use `plt.close()` in loops to free memory
6. **Colormaps**: 
   - YlOrRd, Reds for counts/magnitudes
   - RdYlBu_r, RdBu_r for diverging (peaks, changes)
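For diverging colormaps such as `RdBu_r`, zero should sit at the colormap's midpoint so that equal increases and decreases get equal color weight. xarray's `center=0` (used in the difference map) handles this for a single panel. When several panels must share one scale, you can compute symmetric limits explicitly, as in this minimal sketch:

```python
import numpy as np


def symmetric_limits(*fields):
    """Return (vmin, vmax) centred on zero that cover every finite value in all fields."""
    finite = np.concatenate([np.ravel(np.asarray(f, dtype=float)) for f in fields])
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return -1.0, 1.0                   # arbitrary fallback for all-NaN input
    bound = float(np.max(np.abs(finite)))
    bound = bound if bound > 0 else 1.0    # avoid a zero-width color range
    return -bound, bound


vmin, vmax = symmetric_limits([[-3.0, 1.0], [2.0, np.nan]], [0.5, -0.25])
print(vmin, vmax)  # pass to e.g. DataArray.plot(vmin=vmin, vmax=vmax, cmap='RdBu_r')
```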

### Output Structure

```
output/plots/
├── single/           # Location-specific plots (lat/lon in filename)
│   ├── *_lat*.##_lon*.##.png
│   ├── *_lat*.##_lon*.##.pdf
│   └── *_lat*.##_lon*.##.svg
└── spatial/          # Regional maps
    ├── map_*.png
    └── comparison_*.png
```

### Next Steps

- Apply these visualizations to your own data
- Customize colormaps for your region
- Create figure panels for reports
- Explore interactive plots with Plotly (optional)
- See `docs/user-guide/visualization.md` for more details

In [None]:
# Final summary
import glob

print("\n" + "="*60)
print("VISUALIZATION GALLERY COMPLETE")
print("="*60)

# Count outputs
single_plots = len(glob.glob('../output/plots/single/*'))
spatial_plots = len(glob.glob('../output/plots/spatial/*'))

print(f"\n✓ {single_plots} files in ../output/plots/single/ (single-location plots, all formats)")
print(f"✓ {spatial_plots} files in ../output/plots/spatial/ (spatial maps)")
print("\nAll outputs saved to ../output/plots/")
print("\nPlot types demonstrated:")
print("  1. Basic drought index")
print("  2. Event timeline")
print("  3. Event characteristics")
print("  4. 5-panel evolution")
print("  5. Spatial statistics")
print("  + Magnitude comparison")
print("  + Period comparison")
print("  + Publication quality")
print("  + Batch processing")