# SPI-6 from CONUS404-BA (monthly precipitation)

This notebook computes the **Standardized Precipitation Index (SPI)** at a **6‑month accumulation (SPI‑6)** from the **CONUS404 bias‑adjusted (CONUS404‑BA)** precipitation field, and then (optionally) summarizes SPI‑6 over a region (e.g., Iowa) and compares it to an external SPI product.

**High-level steps**
1. Load daily CONUS404‑BA precipitation.
2. Aggregate to monthly totals.
3. Clip to a polygon region of interest.
4. Compute 6‑month accumulated precipitation.
5. Fit a **Gamma distribution per calendar month** (handling zeros) and transform to standard normal → SPI.
6. Save SPI‑6 as NetCDF and create a statewide/area-average time series plot.


## 0. Environment and imports

If you are working with the cloud-hosted HyTEST/OSN catalogs, you need internet access and the required Python packages installed.


In [None]:
# Core
import os
import warnings

import numpy as np
import pandas as pd
import xarray as xr

# Catalog + distributed compute (optional but recommended for large grids)
import intake

# Geospatial I/O and clipping
import geopandas as gpd
from shapely.geometry import mapping
import rioxarray  # noqa: F401 (adds .rio accessor)

# SPI math
from scipy.stats import gamma, norm

# Plotting
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

warnings.filterwarnings("ignore")

## 1. Parameters

Update paths for your machine. The `SHP_FILE` should be a polygon boundary (e.g., Iowa).  
The catalog dataset id below matches the original notebook (`conus404-daily-ba-osn`), but you can change it to any compatible CONUS404 dataset that contains `RAIN`.


In [None]:
# --- Region of interest (polygon) ---
SHP_FILE = r"C:\path\to\Iowa_State_Boundary.shp"  # <-- change

# --- Output paths ---
OUT_DIR = r"C:\path\to\outputs"  # <-- change
OUT_PNG = os.path.join(OUT_DIR, "SPI6_Iowa_CONUS404BA_1980_2022.png")
OUT_NC  = os.path.join(OUT_DIR, "CONUS404BA_SPI6_1980_2022.nc")

# --- Data selection ---
DATASET_ID = "conus404-daily-ba-osn"
TIME_START = "1980-01-01"
TIME_END   = "2022-09-30"  # match original notebook window

## 2. Open the HyTEST catalog and load CONUS404‑BA

This uses HyTEST's intake catalog to access CONUS404 on OSN.  
If your environment cannot reach GitHub or OSN, replace this with your local file paths / zarr store.


In [None]:
HYTEST_CATALOG_URL = "https://raw.githubusercontent.com/hytest-org/hytest/main/dataset_catalog/hytest_intake_catalog.yml"

cat_root = intake.open_catalog(HYTEST_CATALOG_URL)
conus_cat = cat_root["conus404-catalog"]

print("Available CONUS404 datasets:", list(conus_cat)[:10], "...")

ds = conus_cat[DATASET_ID].to_dask()
ds

### 2.1 Select precipitation and the historical analysis window

CONUS404 provides multiple variables; we use `RAIN` (daily rainfall amount).  
We subset to the historical window used in the original notebook.


In [None]:
# Some catalog entries may already be CF-compliant; metpy parsing is optional.
# If you previously relied on metpy.parse_cf(), keep it; otherwise skip to keep dependencies light.
try:
    import metpy  # noqa: F401
    ds = ds.metpy.parse_cf()
except Exception:
    pass

# Subset the analysis window
ds = ds.sel(time=slice(TIME_START, TIME_END))

# Keep only rainfall (daily)
if "RAIN" not in ds:
    raise KeyError(f"Expected variable 'RAIN' in dataset. Variables found: {list(ds.data_vars)}")

rain_daily = ds["RAIN"]
rain_daily

## 3. Aggregate daily precipitation to monthly totals

SPI‑6 is commonly computed from **monthly** precipitation totals.  
We resample to month-start frequency (`MS`) and sum all days in each month.
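
As a small illustration of the month-start aggregation, the sketch below uses pandas on made-up daily values rather than the CONUS404 grid. The series `daily` and its values are hypothetical. It also shows one behavior worth knowing: a plain `sum()` returns 0 for a month with no valid days, while `min_count=1` (an optional choice, not something the notebook requires) keeps that month as NaN.

```python
import numpy as np
import pandas as pd

# Synthetic daily precipitation for Jan-Feb 2001 (1 unit per day; values are made up)
days = pd.date_range("2001-01-01", "2001-02-28", freq="D")
daily = pd.Series(1.0, index=days)
daily.iloc[3] = np.nan            # one missing day in January
daily[days.month == 2] = np.nan   # February entirely missing

# Month-start labels ("MS"), summing all days in each month
monthly_default = daily.resample("MS").sum()
monthly_strict = daily.resample("MS").sum(min_count=1)

print(monthly_default)
print(monthly_strict)
```

In this toy case January sums the 30 valid days in both versions; only the all-missing February differs between the two calls.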


In [None]:
rain_monthly = rain_daily.resample(time="MS").sum()
rain_monthly.name = "RAIN"
rain_monthly

## 4. Clip monthly precipitation to the region of interest

We reproject the shapefile to the dataset CRS and clip using `rioxarray`.


In [None]:
# Ensure rioxarray knows which dims are spatial (CONUS404 uses x/y)
rain_monthly = rain_monthly.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# Attach CRS if present in dataset metadata
# HyTEST CONUS404 entries often include metpy_crs; if not, you may need to set it manually.
if "metpy_crs" in ds.variables:
    crs = ds["metpy_crs"].metpy.pyproj_crs
    rain_monthly = rain_monthly.rio.write_crs(crs.to_wkt(), inplace=False)

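# Fail early with a clear message if no CRS could be attached above
if rain_monthly.rio.crs is None:
    raise ValueError("rain_monthly has no CRS; set one with rain_monthly.rio.write_crs(...) before clipping.")

# Read and reproject ROI polygon
gdf = gpd.read_file(SHP_FILE)
gdf = gdf.to_crs(rain_monthly.rio.crs)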

geoms = [mapping(geom) for geom in gdf.geometry if geom is not None and not geom.is_empty]
if len(geoms) == 0:
    raise ValueError("No valid geometries found in SHP_FILE.")

rain_roi = rain_monthly.rio.clip(geoms, rain_monthly.rio.crs, drop=True, from_disk=True)
rain_roi

## 5. Build 6‑month accumulated precipitation

For each month `t`, the 6‑month accumulation is:

\[
P_6(t) = \sum_{k=0}^{5} P(t-k)
\]

The first 5 months are undefined (insufficient history) and are dropped later.
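
The formula can be checked on a toy series. This is a minimal sketch with pandas and invented monthly totals (the name `p` and its values are hypothetical); the rolling call has the same window and `min_periods` semantics as the xarray call in the next cell.

```python
import numpy as np
import pandas as pd

# Eight made-up monthly totals, stamped at month start
months = pd.date_range("2000-01-01", periods=8, freq="MS")
p = pd.Series([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0], index=months)

# Trailing 6-month sum: P6(t) = P(t) + P(t-1) + ... + P(t-5)
p6 = p.rolling(window=6, min_periods=6).sum()
print(p6)
```

The first five entries are NaN because fewer than six months of history exist; the sixth is 10 + 20 + 30 + 40 + 50 + 60 = 210.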


In [None]:
p6 = rain_roi.rolling(time=6, min_periods=6).sum()
month_index = p6["time"].dt.month
p6

## 6. Compute SPI from accumulated precipitation

### 6.1 Gamma fit per calendar month + zero handling

SPI is computed separately for each calendar month (Jan…Dec) to remove seasonality:

1. For a given pixel and calendar month, fit a **Gamma(a, scale)** distribution to the non‑zero values.
2. Let **q0** be the fraction of zeros for that month.
3. Convert an observation `x` to probability

\[
p = q0 + (1-q0)\,F_\Gamma(x)
\]

4. Transform to standard normal quantile:

\[
\mathrm{SPI} = \Phi^{-1}(p)
\]

This implementation applies the 1‑D routine to every grid cell with `xr.apply_ufunc`; `vectorize=True` loops over cells, and dask parallelizes the work across chunks.
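
To make the four steps concrete, here is a minimal sketch for a single pixel and a single calendar month. The sample `x` is synthetic (drawn from a Gamma distribution with arbitrary parameters, with four values forced to zero), so the numbers are illustrative only and do not come from CONUS404.

```python
import numpy as np
from scipy.stats import gamma, norm

rng = np.random.default_rng(0)

# Hypothetical sample: 40 years of 6-month totals for one calendar month
x = rng.gamma(shape=2.0, scale=50.0, size=40)
x[:4] = 0.0  # pretend four of those years were completely dry

# Steps 1-2: fit Gamma to the positive values, and measure the mass at zero
q0 = np.mean(x == 0.0)
pos = x[x > 0.0]
a, loc, scale = gamma.fit(pos, floc=0.0)

# Step 3: mixed CDF, p = q0 + (1 - q0) * F_Gamma(x), with F_Gamma(0) = 0
cdf = np.where(x > 0.0, gamma.cdf(x, a, loc=0.0, scale=scale), 0.0)
p = np.clip(q0 + (1.0 - q0) * cdf, 1e-6, 1.0 - 1e-6)

# Step 4: standard normal quantile
spi = norm.ppf(p)

print("q0 =", q0)
print("SPI for a dry (zero) year =", spi[0])
```

With this convention every zero maps to the same value, `norm.ppf(q0)`, and larger accumulations always map to larger SPI.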


In [None]:
def spi_gamma_by_month_1d(series, month_idx):
    """SPI for a 1D time series using Gamma CDF per calendar month with zero handling.

    Parameters
    ----------
    series : array-like, shape (T,)
        Accumulated precipitation values (e.g., 6-month totals).
    month_idx : array-like, shape (T,)
        Calendar month for each element (1..12).

    Returns
    -------
    out : np.ndarray, shape (T,)
        SPI values.
    """
    x = np.asarray(series, dtype=np.float64)
    months = np.asarray(month_idx, dtype=np.int16)
    out = np.full_like(x, np.nan, dtype=np.float64)

    for m in range(1, 13):
        idx = np.where(months == m)[0]
        if idx.size == 0:
            continue

        xm = x[idx]
        ok = np.isfinite(xm)
        xm_ok = xm[ok]

        if xm_ok.size < 10:  # not enough samples to fit robustly
            continue

        # zeros: separate mass at zero
        q0 = np.mean(xm_ok <= 0.0)
        pos = xm_ok[xm_ok > 0.0]
        if pos.size < 10:
            continue

        try:
            # Fit Gamma with loc fixed at 0 (standard for precip)
            a, loc, scale = gamma.fit(pos, floc=0.0)

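            cdf = np.zeros_like(xm, dtype=np.float64)
            mask = ok & (xm > 0)
            cdf[mask] = gamma.cdf(xm[mask], a, loc=0.0, scale=scale)

            p = np.clip(q0 + (1.0 - q0) * cdf, 1e-6, 1.0 - 1e-6)
            spi_m = norm.ppf(p)
            # Missing inputs stay missing instead of being scored as zero precipitation
            spi_m[~ok] = np.nan
            out[idx] = spi_m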
        except Exception:
            # If a fit fails for a pixel/month, leave NaNs
            pass

    return out

In [None]:
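# dask="parallelized" requires each core dimension ("time") to sit in a single chunk
p6 = p6.chunk({"time": -1})

spi6 = xr.apply_ufunc(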
    spi_gamma_by_month_1d,
    p6,
    month_index,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,                 # broadcast over (y, x)
    dask="parallelized",
    output_dtypes=[np.float64],
)

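# apply_ufunc moves the core dimension to the end; restore time-first order (time, y, x)
spi6 = spi6.transpose("time", ...)

# Drop the first 5 months (warm-up) so SPI6 aligns with valid 6-month accumulations
spi6 = spi6.isel(time=slice(5, None)).rename("SPI6")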
spi6

## 7. Region-average SPI‑6 time series

We compute a simple mean over the clipped grid cells.  
If you want an **area-weighted mean**, you can weight by cell area or latitude-based weights (if lat/lon are available).
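
A minimal sketch of the latitude-weighted option, using NumPy on a tiny made-up grid. The helper `weighted_nanmean` and the arrays `lat` and `spi` are hypothetical. Cosine-of-latitude weights are a common approximation for regular lat/lon grids; on a projected grid with near-uniform cell size, the simple mean is likely already close.

```python
import numpy as np

def weighted_nanmean(values, weights):
    """Weighted mean over cells with finite values (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    valid = np.isfinite(values)
    return np.sum(values[valid] * weights[valid]) / np.sum(weights[valid])

# Toy 2x2 grid: one row at 40N, one at 60N, with one masked (NaN) cell
lat = np.array([[40.0, 40.0], [60.0, 60.0]])
spi = np.array([[1.0, 1.0], [-1.0, np.nan]])

weights = np.cos(np.deg2rad(lat))  # relative cell area on a regular lat/lon grid

simple = np.nanmean(spi)
weighted = weighted_nanmean(spi, weights)
print(simple, weighted)
```

Here the weighted mean is larger than the simple mean because the lower-latitude cells, which carry the positive values, represent more area.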


In [None]:
spi6_mean = spi6.mean(dim=[d for d in spi6.dims if d != "time"], skipna=True)
spi6_mean

## 8. Plot SPI‑6 categories and save a figure
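
Before plotting, it can help to confirm how values on a bin edge are classified. The sketch below repeats the break points and labels used in the plotting cell and applies `pd.cut` to a few made-up SPI values (`vals` is hypothetical). With the default `right=True`, each bin includes its upper edge, so an SPI of exactly -2.0 counts as "Exceptional drought" and exactly 0.5 still counts as "Normal".

```python
import numpy as np
import pandas as pd

breaks = np.array([-np.inf, -2.0, -1.6, -1.3, -0.8, -0.5, 0.5, 0.8, 1.3, 1.6, 2.0, np.inf])
labels = [
    "Exceptional drought", "Extreme drought", "Severe drought", "Moderate drought", "Abnormally dry",
    "Normal",
    "Abnormally wet", "Moderate wetness", "Severe wetness", "Extreme wetness", "Exceptional wetness",
]

# A few made-up SPI values, including some that sit exactly on a break point
vals = np.array([-2.0, -1.0, -0.5, 0.0, 0.5, 0.6, 2.5])
cats = pd.cut(vals, bins=breaks, labels=labels)

for v, c in zip(vals, cats):
    print(f"{v:5.1f} -> {c}")
```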

In [None]:
# Drought/wetness bins (matching your original categories)
breaks = np.array([-np.inf, -2.0, -1.6, -1.3, -0.8, -0.5, 0.5, 0.8, 1.3, 1.6, 2.0, np.inf])
labels = [
    "Exceptional drought","Extreme drought","Severe drought","Moderate drought","Abnormally dry",
    "Normal",
    "Abnormally wet","Moderate wetness","Severe wetness","Extreme wetness","Exceptional wetness"
]

# Colors (optional)
cmap = {
    "Exceptional drought":"#8b0000","Extreme drought":"#d7301f","Severe drought":"#fc8d59",
    "Moderate drought":"#fdd49e","Abnormally dry":"#d9d9d9","Normal":"#bfbfbf",
    "Abnormally wet":"#b7e4c7","Moderate wetness":"#66c2a5","Severe wetness":"#56b1f7",
    "Extreme wetness":"#253494","Exceptional wetness":"#5e3c99",
}

# Convert to pandas for plotting
s = spi6_mean.to_series()
cats = pd.cut(s.values, bins=breaks, labels=labels)

# Plot
os.makedirs(OUT_DIR, exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 3))
ax.axhline(0, linewidth=1)

# Color bars by category
for t, v, cat in zip(s.index, s.values, cats):
    if pd.isna(v) or pd.isna(cat):
        continue
    ax.bar(t, v, width=25, align="center", color=cmap.get(str(cat), "gray"))

ax.set_title("SPI-6 (CONUS404-BA) – Region mean")
ax.set_ylabel("SPI (unitless)")
ax.xaxis.set_major_locator(mdates.YearLocator(base=5))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.set_xlim([s.index.min(), s.index.max()])

# Legend
from matplotlib.patches import Patch
handles = [Patch(facecolor=cmap[l], label=l) for l in labels]
ax.legend(handles=handles, ncol=3, frameon=False, bbox_to_anchor=(1.02, 1), loc="upper left")

plt.tight_layout()
plt.savefig(OUT_PNG, dpi=200, bbox_inches="tight")
plt.show()

OUT_PNG

## 9. Save SPI‑6 grid to NetCDF

The output NetCDF stores the SPI‑6 grid (time, y, x) clipped to your region.


In [None]:
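# Drop MetPy's CRS coordinate if present; it holds a Python object that NetCDF writers typically cannot serialize
spi6_out = spi6.astype("float32").drop_vars("metpy_crs", errors="ignore")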
spi6_out.attrs.update({
    "long_name": "Standardized Precipitation Index (6-month)",
    "description": "Gamma CDF per calendar month with zero handling; inverse-normal transform. Computed from CONUS404-BA precipitation.",
    "units": "1",
})

# Clean up coordinate attributes that sometimes upset NetCDF writers
for c in ("time", "y", "x", "lat", "lon", "south_north", "west_east"):
    if c in spi6_out.coords and "grid_mapping" in spi6_out[c].attrs:
        spi6_out[c].attrs.pop("grid_mapping", None)

ds_out = spi6_out.to_dataset()

# Preserve CRS if present
try:
    ds_out = ds_out.rio.write_crs(rain_roi.rio.crs, inplace=False)
except Exception:
    pass

ds_out.to_netcdf(OUT_NC)
OUT_NC

## 10. Optional: compare CONUS404‑BA SPI‑6 to another SPI product

If you have a second SPI‑6 NetCDF (e.g., a CMIP6-derived SPI for the same region), you can:
- compute region means for both,
- align time stamps,
- compute metrics (bias, RMSE, correlation, NSE, KGE),
- and plot the paired series.

This block is **optional** and can be skipped if you only need CONUS404‑BA SPI.
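
Aligning time stamps matters because two monthly products often label the same month differently (for example, month start versus month end), and an exact-timestamp join then finds no overlap. The sketch below shows one way to handle this with pandas, on two short made-up series (`a` and `b` are hypothetical): convert both indexes to monthly periods before pairing. It assumes both time axes are ordinary datetimes rather than cftime calendars.

```python
import numpy as np
import pandas as pd

# Two made-up monthly series: one stamped at month start, one at month end
a = pd.Series([0.5, -1.0, 0.2],
              index=pd.to_datetime(["2000-01-01", "2000-02-01", "2000-03-01"]))
b = pd.Series([0.4, -0.8, 0.3],
              index=pd.to_datetime(["2000-01-31", "2000-02-29", "2000-03-31"]))

# Joining on exact timestamps leaves nothing after dropna()
naive = pd.concat([a.rename("A"), b.rename("B")], axis=1).dropna()

# Joining on calendar month pairs the values as intended
a_m = a.copy()
b_m = b.copy()
a_m.index = a_m.index.to_period("M")
b_m.index = b_m.index.to_period("M")
paired = pd.concat([a_m.rename("A"), b_m.rename("B")], axis=1).dropna()

bias = float((paired["B"] - paired["A"]).mean())
rmse = float(np.sqrt(((paired["B"] - paired["A"]) ** 2).mean()))
print(len(naive), len(paired), bias, rmse)
```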


In [None]:
# --- OPTIONAL COMPARISON CONFIG ---
COMPARE = False

CMIP6_SPI_PATH = r"C:\path\to\SPI6_Iowa_1980_2021.nc"  # <-- change if COMPARE=True
OUT_PAIRED_CSV = os.path.join(OUT_DIR, "SPI6_Iowa_CONUS404BA_vs_other_paired.csv")
OUT_METRICS_CSV = os.path.join(OUT_DIR, "SPI6_Iowa_CONUS404BA_vs_other_metrics.csv")
OUT_COMPARE_PNG = os.path.join(OUT_DIR, "SPI6_CONUS404BA_vs_other.png")

def nse(obs, sim):
    obs = np.asarray(obs); sim = np.asarray(sim)
    return 1.0 - np.sum((sim-obs)**2) / np.sum((obs - np.mean(obs))**2)

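def kge(obs, sim):
    # Note: beta is a ratio of means. SPI is close to zero-mean by construction,
    # so beta (and therefore KGE) can be unstable here; interpret it with care.
    obs = np.asarray(obs); sim = np.asarray(sim)
    r = np.corrcoef(obs, sim)[0, 1]
    alpha = np.std(sim) / np.std(obs)
    beta  = np.mean(sim) / np.mean(obs)
    return 1.0 - np.sqrt((r-1)**2 + (alpha-1)**2 + (beta-1)**2)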

if COMPARE:
    conus = xr.open_dataset(OUT_NC)["SPI6"]
    other = xr.open_dataset(CMIP6_SPI_PATH)

    # Try to find the SPI variable in the other dataset
    other_var = None
    for cand in ["SPI6", "spi6", "SPI", "spi"]:
        if cand in other:
            other_var = cand
            break
    if other_var is None:
        raise KeyError(f"Could not find an SPI variable in {CMIP6_SPI_PATH}. Variables: {list(other.data_vars)}")

    other_spi = other[other_var]

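    conus_m = conus.mean(dim=[d for d in conus.dims if d != "time"], skipna=True).to_series()
    other_m = other_spi.mean(dim=[d for d in other_spi.dims if d != "time"], skipna=True).to_series()

    # Align on calendar month so month-start and month-end stamps pair up
    # (assumes both time axes decode to standard datetimes, not cftime calendars)
    conus_m.index = pd.DatetimeIndex(conus_m.index).to_period("M")
    other_m.index = pd.DatetimeIndex(other_m.index).to_period("M")

    df = pd.concat(
        [conus_m.rename("SPI6_CONUS404BA"), other_m.rename("SPI6_OTHER")],
        axis=1
    ).dropna()
    df.index = df.index.to_timestamp()  # back to month-start datetimes for the CSV and plot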

    df.to_csv(OUT_PAIRED_CSV, index_label="time")

    # Metrics
    A = df["SPI6_CONUS404BA"].values
    B = df["SPI6_OTHER"].values

    metrics = {
        "n": len(df),
        "bias (OTHER - CONUS404BA)": float(np.mean(B - A)),
        "rmse": float(np.sqrt(np.mean((B - A)**2))),
        "r": float(np.corrcoef(A, B)[0, 1]),
        "nse": float(nse(A, B)),
        "kge": float(kge(A, B)),
    }
    pd.DataFrame([metrics]).to_csv(OUT_METRICS_CSV, index=False)

    # Plot
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(df.index, df["SPI6_CONUS404BA"], label="CONUS404-BA")
    ax.plot(df.index, df["SPI6_OTHER"], label="Other SPI")
    ax.axhline(0, linewidth=1)
    ax.set_title("SPI-6 comparison (region mean)")
    ax.set_ylabel("SPI (unitless)")
    ax.xaxis.set_major_locator(mdates.YearLocator(base=5))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    ax.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(OUT_COMPARE_PNG, dpi=200, bbox_inches="tight")
    plt.show()

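    # A bare expression inside an if-block is not displayed by Jupyter, so print the paths
    print(OUT_PAIRED_CSV, OUT_METRICS_CSV, OUT_COMPARE_PNG, sep="\n")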