# Calculating Standardized Precipitation Index (SPI)

This notebook demonstrates how to calculate the Standardized Precipitation Index (SPI) using the `precip-index` package.

**What is SPI?**
- Developed by McKee et al. (1993)
- Transforms accumulated precipitation to a standard normal distribution (mean 0, standard deviation 1)
- Allows comparison across different climates and time scales
- Negative values indicate drier-than-normal conditions; positive values indicate wetter-than-normal conditions
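
As a rough illustration of the transformation described above, the sketch below computes SPI for a single 1-D series with NumPy and SciPy. It is not the `precip-index` implementation: the function name `spi_from_series`, the maximum-likelihood gamma fit, and the clipping bounds are assumptions made for this example. Real implementations typically fit each calendar month separately.

```python
import numpy as np
from scipy import stats


def spi_from_series(precip, calib=None):
    """Illustrative SPI for a 1-D series of accumulated precipitation.

    Hypothetical helper, not part of precip-index.
    """
    precip = np.asarray(precip, dtype=float)
    calib = precip if calib is None else np.asarray(calib, dtype=float)
    calib = calib[~np.isnan(calib)]

    # Zeros are handled separately because the gamma distribution is
    # only defined for positive values.
    prob_zero = np.mean(calib == 0)
    nonzero = calib[calib > 0]

    # Fit a two-parameter gamma (location fixed at 0) to the wet values.
    alpha, _, beta = stats.gamma.fit(nonzero, floc=0)

    # Mixed distribution: probability of zero plus the gamma CDF.
    cdf = prob_zero + (1 - prob_zero) * stats.gamma.cdf(
        precip, alpha, loc=0, scale=beta
    )

    # Keep probabilities away from 0 and 1 so the result stays finite
    # (the bounds are an assumption of this sketch).
    cdf = np.clip(cdf, 1e-6, 1 - 1e-6)

    # Map cumulative probability to a standard normal quantile.
    return stats.norm.ppf(cdf)


rng = np.random.default_rng(0)
sample = rng.gamma(2.0, 50.0, size=3000)
z = spi_from_series(sample)
print(f"mean={z.mean():.2f}, std={z.std():.2f}")
```

For gamma-distributed input, the printed mean and standard deviation should be close to 0 and 1.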

**Learning Objectives:**
1. Load and prepare precipitation data
2. Calculate SPI for single and multiple time scales
3. Save and reuse gamma fitting parameters
4. Visualize results
5. Perform drought classification and analysis

## 1. Setup and Imports

In [None]:
# Add src directory to Python path
import sys
sys.path.insert(0, '../src')

# Core libraries
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from datetime import datetime

# Import SPI functions
from indices import (
    spi, 
    spi_multi_scale,
    save_fitting_params,
    load_fitting_params,
    save_index_to_netcdf,
    classify_drought,
    get_drought_area_percentage
)

from config import Periodicity

# Plotting settings
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10

print("✓ All imports successful!")

## 2. Load Precipitation Data

The package expects data in **CF Convention** format with dimensions: `(time, lat, lon)`

### Option A: Load Your Own Data

In [None]:
# Example: Load precipitation data from NetCDF
# Uncomment and modify the path to your data file

# ds = xr.open_dataset('path/to/your/precipitation.nc')
# precip = ds['precip']  # or 'prcp', 'pr', depending on your variable name
# print(f"Data loaded: {precip.shape}")
# print(f"Dimensions: {precip.dims}")
# print(f"Time range: {precip.time[0].values} to {precip.time[-1].values}")

### Option B: Generate Synthetic Test Data

In [None]:
# Generate synthetic precipitation data for demonstration
# This creates 30 years of monthly data for a small grid

np.random.seed(42)

# Parameters
n_years = 30
n_months = n_years * 12  # 360 months
n_lat = 20
n_lon = 30
data_start_year = 1991

# Create time coordinate
time = xr.cftime_range(
    start=f'{data_start_year}-01-01',
    periods=n_months,
    freq='MS',  # Month start
    calendar='standard'
)

# Create spatial coordinates
lat = np.linspace(-10, 10, n_lat)
lon = np.linspace(30, 50, n_lon)

# Generate synthetic precipitation (gamma-distributed)
# Mean: 100 mm/month, with seasonal cycle and spatial variation
precip_data = np.zeros((n_months, n_lat, n_lon))

for t in range(n_months):
    month = t % 12
    # Seasonal cycle (sinusoidal; peaks at month index 3, i.e. April)
    seasonal_factor = 1 + 0.5 * np.sin(2 * np.pi * month / 12)
    
    for i in range(n_lat):
        for j in range(n_lon):
            # Spatial variation
            spatial_factor = 0.5 + 0.5 * (i / n_lat + j / n_lon)
            
            # Gamma-distributed precipitation
            mean_precip = 100 * seasonal_factor * spatial_factor
            alpha = 2.0  # shape parameter
            beta = mean_precip / alpha  # scale parameter
            
            precip_data[t, i, j] = np.random.gamma(alpha, beta)
            
            # Randomly set ~10% of values to zero (dry periods)
            if np.random.rand() < 0.1:
                precip_data[t, i, j] = 0.0

# Create DataArray
precip = xr.DataArray(
    data=precip_data,
    dims=['time', 'lat', 'lon'],
    coords={
        'time': time,
        'lat': lat,
        'lon': lon
    },
    attrs={
        'long_name': 'Monthly precipitation',
        'units': 'mm/month',
        'standard_name': 'precipitation_amount'
    }
)

print("✓ Synthetic data created successfully!")
print(f"  Shape: {precip.shape}")
print(f"  Dimensions: {precip.dims}")
print(f"  Time range: {n_years} years ({data_start_year}-{data_start_year + n_years - 1})")
print(f"  Mean precipitation: {precip.mean().values:.1f} mm/month")
print(f"  Zero values: {int((precip == 0).sum())} ({100 * float((precip == 0).mean()):.1f}%)")
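
The nested loops above are easy to read but slow for larger grids. A possible vectorized alternative using NumPy broadcasting is sketched below. It uses `np.random.default_rng`, so it will not reproduce the exact values generated by the loop version; only the statistical properties are the same.

```python
import numpy as np

rng = np.random.default_rng(42)
n_months, n_lat, n_lon = 360, 20, 30

# Seasonal factor per time step, shape (time,)
month = np.arange(n_months) % 12
seasonal = 1 + 0.5 * np.sin(2 * np.pi * month / 12)

# Spatial factor per grid cell, shape (lat, lon)
spatial = 0.5 + 0.5 * (
    np.arange(n_lat)[:, None] / n_lat + np.arange(n_lon)[None, :] / n_lon
)

# Broadcast to the full (time, lat, lon) mean field
mean_precip = 100 * seasonal[:, None, None] * spatial[None, :, :]

# Gamma draws with a fixed shape and a per-cell scale parameter
alpha = 2.0
precip_data = rng.gamma(shape=alpha, scale=mean_precip / alpha)

# Set roughly 10% of values to zero (dry periods)
precip_data[rng.random(precip_data.shape) < 0.1] = 0.0

print(precip_data.shape, f"{(precip_data == 0).mean():.3f}")
```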

### Quick Data Visualization

In [None]:
# Plot time series at a sample point
sample_lat_idx = n_lat // 2
sample_lon_idx = n_lon // 2

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8))

# Time series
precip[:, sample_lat_idx, sample_lon_idx].plot(ax=ax1, linewidth=0.8)
ax1.set_title(f'Precipitation Time Series (lat={lat[sample_lat_idx]:.1f}°, lon={lon[sample_lon_idx]:.1f}°)')
ax1.set_ylabel('Precipitation (mm/month)')
ax1.grid(True, alpha=0.3)

# Spatial pattern (mean over time)
precip.mean(dim='time').plot(ax=ax2, cmap='YlGnBu')
ax2.set_title('Mean Precipitation (mm/month)')

plt.tight_layout()
plt.show()

## 3. Calculate SPI for a Single Time Scale

Let's calculate SPI-12, which is based on precipitation accumulated over a 12-month window.

In [None]:
# Calculate SPI-12
print("Calculating SPI-12...")

spi_12, params = spi(
    precip,
    scale=12,
    periodicity='monthly',
    calibration_start_year=1991,
    calibration_end_year=2020,
    return_params=True  # Return fitting parameters for saving
)

print("✓ SPI-12 calculation complete!")
print(f"  Output shape: {spi_12.shape}")
print(f"  Valid range: [{spi_12.min().values:.2f}, {spi_12.max().values:.2f}]")
print(f"  Mean: {spi_12.mean().values:.3f}")
print(f"  Std: {spi_12.std().values:.3f}")
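
To make the accumulation step behind SPI-12 concrete, here is a minimal NumPy sketch of a trailing 12-month sum. The helper `rolling_sum` is hypothetical and not part of `precip-index`. It assumes a trailing window, which leaves the first `scale - 1` time steps undefined; the package may treat the start of the record differently.

```python
import numpy as np


def rolling_sum(values, scale):
    """Trailing sum over `scale` time steps; leading steps are NaN."""
    values = np.asarray(values, dtype=float)
    out = np.full(values.shape, np.nan)
    if scale <= len(values):
        # 'valid' mode returns one value per complete window
        out[scale - 1:] = np.convolve(values, np.ones(scale), mode="valid")
    return out


monthly = np.arange(1.0, 25.0)  # 24 months with values 1..24
acc_12 = rolling_sum(monthly, 12)

print(int(np.isnan(acc_12).sum()))  # number of undefined leading steps
print(acc_12[11], acc_12[-1])       # first and last complete 12-month totals
```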

### Visualize SPI-12 Results

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 9))

# Time series at sample point
spi_12[:, sample_lat_idx, sample_lon_idx].plot(ax=ax1, linewidth=0.8, color='steelblue')
ax1.axhline(y=0, color='k', linestyle='-', linewidth=0.8, alpha=0.3)
ax1.axhline(y=-1.0, color='orange', linestyle='--', linewidth=0.8, alpha=0.5, label='Moderate drought')
ax1.axhline(y=-2.0, color='red', linestyle='--', linewidth=0.8, alpha=0.5, label='Extreme drought')
ax1.axhline(y=1.0, color='green', linestyle='--', linewidth=0.8, alpha=0.5, label='Moderately wet')
ax1.fill_between(spi_12.time, -10, 0, alpha=0.1, color='red', label='Drought')
ax1.fill_between(spi_12.time, 0, 10, alpha=0.1, color='blue', label='Wet')
ax1.set_title(f'SPI-12 Time Series (lat={lat[sample_lat_idx]:.1f}°, lon={lon[sample_lon_idx]:.1f}°)')
ax1.set_ylabel('SPI-12')
ax1.set_ylim(-3, 3)
ax1.grid(True, alpha=0.3)
ax1.legend(loc='upper right')

# Spatial pattern (mean over time)
spi_12.mean(dim='time').plot(ax=ax2, cmap='RdYlBu', vmin=-1, vmax=1, cbar_kwargs={'label': 'Mean SPI-12'})
ax2.set_title('Mean SPI-12 (entire period)')

plt.tight_layout()
plt.show()

## 4. Save Fitting Parameters for Reuse

Gamma fitting parameters can be saved and reused to speed up calculations on new data with the same calibration period.

In [None]:
# Save parameters
param_file = 'spi_gamma_params_12_month.nc'

save_fitting_params(
    params,
    param_file,
    scale=12,
    periodicity='monthly',
    index_type='spi',
    calibration_start_year=1991,
    calibration_end_year=2020,
    coords={'lat': precip.lat, 'lon': precip.lon}
)

print(f"✓ Parameters saved to: {param_file}")

# Load parameters
loaded_params = load_fitting_params(param_file, scale=12, periodicity='monthly')
print(f"✓ Parameters loaded successfully")
print(f"  Alpha shape: {loaded_params['alpha'].shape}")
print(f"  Beta shape: {loaded_params['beta'].shape}")
print(f"  Prob_zero shape: {loaded_params['prob_zero'].shape}")
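
The three arrays printed above (`alpha`, `beta`, `prob_zero`) are all that is needed to transform new precipitation values without refitting. The sketch below shows that idea for a single series, with the parameters round-tripped through an in-memory `.npz` buffer instead of a NetCDF file. The helpers `fit_gamma_params` and `apply_gamma_params` are hypothetical and only approximate what the package does internally.

```python
import io

import numpy as np
from scipy import stats


def fit_gamma_params(calib):
    """Fit once on the calibration series (hypothetical helper)."""
    calib = np.asarray(calib, dtype=float)
    calib = calib[~np.isnan(calib)]
    alpha, _, beta = stats.gamma.fit(calib[calib > 0], floc=0)
    return {"alpha": alpha, "beta": beta, "prob_zero": np.mean(calib == 0)}


def apply_gamma_params(precip, params):
    """Transform values with stored parameters; no refitting needed."""
    cdf = params["prob_zero"] + (1 - params["prob_zero"]) * stats.gamma.cdf(
        np.asarray(precip, dtype=float), params["alpha"], scale=params["beta"]
    )
    return stats.norm.ppf(np.clip(cdf, 1e-6, 1 - 1e-6))


rng = np.random.default_rng(0)
calibration = rng.gamma(2.0, 50.0, size=360)
params = fit_gamma_params(calibration)

# Round-trip the parameters through an in-memory file
buffer = io.BytesIO()
np.savez(buffer, **params)
buffer.seek(0)
with np.load(buffer) as stored:
    reloaded = {name: float(stored[name]) for name in stored.files}

new_data = rng.gamma(2.0, 50.0, size=24)
spi_direct = apply_gamma_params(new_data, params)
spi_reloaded = apply_gamma_params(new_data, reloaded)
print(np.allclose(spi_direct, spi_reloaded))
```

Separating fitting from application in this way is what makes reuse cheap: the gamma fit is the expensive part, while applying stored parameters is a single vectorized CDF evaluation.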

### Use Pre-computed Parameters (Faster)

In [None]:
# Calculate SPI using pre-computed parameters
# This is much faster when processing multiple datasets
print("Calculating SPI-12 with pre-computed parameters...")

spi_12_fast = spi(
    precip,
    scale=12,
    periodicity='monthly',
    fitting_params=loaded_params
)

print("✓ Calculation complete!")

# Verify that results match the original within floating-point tolerance
diff = np.abs(spi_12.values - spi_12_fast.values)
print(f"  Max difference from original: {np.nanmax(diff):.10f}")
print("  Results match within tolerance!" if np.nanmax(diff) < 1e-6 else "  Warning: Results differ!")

## 5. Calculate SPI for Multiple Time Scales

Different time scales capture different drought types:
- SPI-1: Meteorological drought (short-term)
- SPI-3: Agricultural drought (seasonal)
- SPI-6: Hydrological drought (medium-term)
- SPI-12: Long-term water resources

In [None]:
# Calculate multiple scales
print("Calculating SPI for scales: 1, 3, 6, 12 months...")

spi_multi = spi_multi_scale(
    precip,
    scales=[1, 3, 6, 12],
    periodicity='monthly',
    calibration_start_year=1991,
    calibration_end_year=2020
)

print("✓ Multi-scale SPI calculation complete!")
print(f"  Variables: {list(spi_multi.data_vars)}")

### Compare Different Time Scales

In [None]:
fig, axes = plt.subplots(4, 1, figsize=(14, 12))

scales = [1, 3, 6, 12]
colors = ['skyblue', 'steelblue', 'navy', 'darkblue']

for i, (scale, color) in enumerate(zip(scales, colors)):
    var_name = f'spi_gamma_{scale}_month'
    spi_data = spi_multi[var_name]
    
    # Plot time series
    spi_data[:, sample_lat_idx, sample_lon_idx].plot(
        ax=axes[i], 
        linewidth=0.8, 
        color=color
    )
    
    # Add reference lines
    axes[i].axhline(y=0, color='k', linestyle='-', linewidth=0.8, alpha=0.3)
    axes[i].axhline(y=-1.0, color='orange', linestyle='--', linewidth=0.6, alpha=0.4)
    axes[i].axhline(y=-2.0, color='red', linestyle='--', linewidth=0.6, alpha=0.4)
    axes[i].fill_between(spi_data.time, -10, 0, alpha=0.1, color='red')
    axes[i].fill_between(spi_data.time, 0, 10, alpha=0.1, color='blue')
    
    axes[i].set_title(f'SPI-{scale}')
    axes[i].set_ylabel(f'SPI-{scale}')
    axes[i].set_ylim(-3, 3)
    axes[i].grid(True, alpha=0.3)
    
    if i < 3:
        axes[i].set_xlabel('')

plt.suptitle(f'SPI Time Series Comparison (lat={lat[sample_lat_idx]:.1f}°, lon={lon[sample_lon_idx]:.1f}°)', 
             fontsize=14, y=0.995)
plt.tight_layout()
plt.show()

print("\nObservations:")
print("  - SPI-1 shows high frequency variability (month-to-month)")
print("  - SPI-12 shows smoother, longer-term trends")
print("  - Drought events persist longer at longer time scales")
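
The persistence noted above follows largely from the accumulation itself: consecutive 12-month windows share 11 of their 12 months. The sketch below illustrates this on raw accumulated values (not SPI) by computing the lag-1 autocorrelation of rolling sums of independent synthetic data. The helper `lag1_autocorr` and the synthetic series are assumptions made for this example.

```python
import numpy as np


def lag1_autocorr(x):
    """Sample lag-1 autocorrelation of a 1-D series."""
    x = x - x.mean()
    return float(np.sum(x[:-1] * x[1:]) / np.sum(x * x))


rng = np.random.default_rng(1)
monthly = rng.gamma(2.0, 50.0, size=1200)  # independent monthly values

autocorr = {}
for scale in (1, 3, 6, 12):
    # Trailing sums over `scale` months (complete windows only)
    accumulated = np.convolve(monthly, np.ones(scale), mode="valid")
    autocorr[scale] = lag1_autocorr(accumulated)

for scale, value in autocorr.items():
    print(f"scale={scale:2d}  lag-1 autocorrelation={value:.2f}")
```

For independent inputs, the theoretical lag-1 autocorrelation of a `k`-month sum is `(k - 1) / k`, so longer scales are smoother even when the underlying monthly data have no memory.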

## 6. Drought Classification

Classify SPI values into categories using the McKee et al. (1993) scheme.

In [None]:
# Classify SPI-12 into drought categories
drought_categories = classify_drought(spi_12, classification='mckee')

print("Drought Classification:")
print("  -2: Extremely dry")
print("  -1: Severely dry")
print("   0: Moderately dry")
print("   1: Near normal")
print("   2: Moderately wet")
print("   3: Very wet")
print("   4: Extremely wet")

# Count occurrences
unique, counts = np.unique(drought_categories.values[~np.isnan(drought_categories.values)], return_counts=True)
print("\nFrequency Distribution:")
for cat, count in zip(unique, counts):
    pct = 100 * count / np.sum(counts)
    print(f"  Category {int(cat):2d}: {count:6d} ({pct:5.1f}%)")
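
For reference, the SPI thresholds commonly cited for the McKee et al. (1993) scheme can be sketched in plain NumPy as below. The function `classify_mckee` is hypothetical; it returns text labels and does not reproduce the integer codes or the exact boundary handling of `classify_drought`.

```python
import numpy as np

# Labels ordered from driest to wettest; the last entry is used for NaN
LABELS = np.array([
    "extremely dry",   # SPI <= -2.0
    "severely dry",    # -2.0 < SPI <= -1.5
    "moderately dry",  # -1.5 < SPI <= -1.0
    "near normal",     # -1.0 < SPI < 1.0
    "moderately wet",  # 1.0 <= SPI < 1.5
    "very wet",        # 1.5 <= SPI < 2.0
    "extremely wet",   # SPI >= 2.0
    "missing",         # NaN input
])


def classify_mckee(spi_values):
    """Map SPI values to category labels (illustrative only)."""
    spi_values = np.asarray(spi_values, dtype=float)
    # np.select picks the first condition that is True for each element
    conditions = [
        spi_values <= -2.0,
        spi_values <= -1.5,
        spi_values <= -1.0,
        spi_values < 1.0,
        spi_values < 1.5,
        spi_values < 2.0,
        spi_values >= 2.0,
    ]
    # NaN fails every comparison, so it gets the default index -1,
    # which selects the final "missing" label.
    codes = np.select(conditions, list(range(7)), default=-1)
    return LABELS[codes]


examples = np.array([-2.3, -1.7, -1.0, 0.0, 1.2, 1.5, 2.5, np.nan])
categories = classify_mckee(examples)
for value, label in zip(examples, categories):
    print(f"{value:5.1f} -> {label}")
```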

### Visualize Drought Categories

In [None]:
# Plot drought category map for a specific time
time_idx = -12  # A single time step: the 12th from the end of the record

fig, ax = plt.subplots(figsize=(12, 6))

# Custom colormap for drought categories
from matplotlib.colors import ListedColormap, BoundaryNorm

colors = ['#8B0000', '#CD5C5C', '#FFA07A', '#FFFACD', '#90EE90', '#20B2AA', '#00008B']
cmap = ListedColormap(colors)
bounds = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]
norm = BoundaryNorm(bounds, cmap.N)

im = ax.pcolormesh(
    drought_categories.lon,
    drought_categories.lat,
    drought_categories[time_idx, :, :],
    cmap=cmap,
    norm=norm,
    shading='auto'
)

cbar = plt.colorbar(im, ax=ax, ticks=[-2, -1, 0, 1, 2, 3, 4])
cbar.set_label('Drought Category')

ax.set_title(f'Drought Classification Map ({drought_categories.time[time_idx].values})')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')

plt.tight_layout()
plt.show()

## 7. Drought Area Percentage

Calculate the percentage of the domain in drought at each time step.

In [None]:
# Calculate drought area percentage time series
# Threshold -1.0 = moderate drought or worse
drought_pct = get_drought_area_percentage(spi_12, threshold=-1.0)

print(f"Drought Area Statistics (SPI <= -1.0):")
print(f"  Mean: {drought_pct.mean().values:.1f}%")
print(f"  Max: {drought_pct.max().values:.1f}%")
print(f"  Min: {drought_pct.min().values:.1f}%")
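
One way such a percentage could be computed is sketched below in plain NumPy. This is an assumption-laden illustration, not the package's method: the function `drought_area_percentage` is hypothetical, and it weights each cell by the cosine of its latitude as a proxy for cell area, whereas `get_drought_area_percentage` may simply count grid cells equally. Cells with missing SPI are excluded from both the numerator and the denominator.

```python
import numpy as np


def drought_area_percentage(spi_values, lat, threshold=-1.0):
    """Percent of valid area with SPI <= threshold, per time step.

    spi_values has shape (time, lat, lon); lat is in degrees.
    """
    spi_values = np.asarray(spi_values, dtype=float)
    # cos(latitude) approximates relative cell area on a regular grid
    weights = np.cos(np.deg2rad(np.asarray(lat, dtype=float)))[None, :, None]

    valid = ~np.isnan(spi_values)
    in_drought = valid & (spi_values <= threshold)

    total = (weights * valid).sum(axis=(1, 2))
    dry = (weights * in_drought).sum(axis=(1, 2))

    # Time steps with no valid cells yield NaN rather than a division error
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(total > 0, 100.0 * dry / total, np.nan)


spi_demo = np.array([
    [[-1.5, 0.0], [0.5, -1.0]],        # two of four cells in drought
    [[np.nan, np.nan], [np.nan, np.nan]],  # no valid data
])
pct_demo = drought_area_percentage(spi_demo, lat=[0.0, 0.0])
print(pct_demo)
```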

In [None]:
# Plot drought area percentage over time
fig, ax = plt.subplots(figsize=(14, 5))

drought_pct.plot(ax=ax, linewidth=1.2, color='darkred')
ax.axhline(y=float(drought_pct.mean()), color='k', linestyle='--', linewidth=0.8, alpha=0.5, label='Mean')
ax.fill_between(drought_pct.time, 0, drought_pct, alpha=0.3, color='red')

ax.set_title('Drought Area Percentage Over Time (SPI-12 ≤ -1.0)', fontsize=12)
ax.set_ylabel('Area under drought (%)')
ax.set_ylim(0, 100)
ax.grid(True, alpha=0.3)
ax.legend()

plt.tight_layout()
plt.show()

## 8. Save Results to NetCDF

In [None]:
# Save single-scale SPI
output_file_single = 'spi_12_output.nc'
save_index_to_netcdf(spi_12, output_file_single, compress=True, complevel=5)
print(f"✓ SPI-12 saved to: {output_file_single}")

# Save multi-scale SPI
output_file_multi = 'spi_multi_scale_output.nc'
save_index_to_netcdf(spi_multi, output_file_multi, compress=True, complevel=5)
print(f"✓ Multi-scale SPI saved to: {output_file_multi}")

# Save drought area percentage
output_file_pct = 'drought_area_percentage.nc'
drought_pct.to_netcdf(output_file_pct)
print(f"✓ Drought area percentage saved to: {output_file_pct}")

## 9. Summary and Best Practices

### Key Takeaways:
1. ✅ Always ensure data follows CF Convention: (time, lat, lon)
2. ✅ Use an appropriate calibration period (at least 30 years is commonly recommended; 1991-2020 is the current WMO standard reference period)
3. ✅ Save fitting parameters for reuse on similar datasets
4. ✅ Choose time scales based on drought type of interest
5. ✅ Validate results visually before using in analysis

### Recommended Workflow:
```python
# 1. Load data
ds = xr.open_dataset('precip.nc')
precip = ds['precip']

# 2. Calculate SPI with parameter saving
spi_12, params = spi(precip, scale=12, return_params=True)

# 3. Save parameters
save_fitting_params(params, 'params.nc', scale=12, periodicity='monthly')

# 4. Reuse parameters on new data (new_precip is a placeholder for another precipitation DataArray)
params = load_fitting_params('params.nc', scale=12, periodicity='monthly')
spi_12_new = spi(new_precip, scale=12, fitting_params=params)

# 5. Analyze results
drought_pct = get_drought_area_percentage(spi_12, threshold=-1.0)
```

### Next Steps:
- See `02_calculate_spei.ipynb` for SPEI calculation
- Explore advanced visualizations with Cartopy
- Apply to your own climate datasets