# LSTM States Spatial Analysis (v2)

## Scientific Context

This notebook analyzes how LSTM hidden states (`h_n`) and cell states (`c_n`) internally represent
**meteorological processes NOT included as model inputs**:
- Evaporation (GLEAM)
- Snow Water Equivalent - SWE (ERA5-Land)
- Snow Depth (ERA5-Land)
- Subsurface Runoff (ERA5-Land)

## LSTM Theory: h_n vs c_n

| State | Name | Theoretical Role | Expected Behavior |
|-------|------|------------------|-------------------|
| **c_n** | Cell State | Long-term memory | Captures **slowly-varying, persistent** patterns (seasonal cycles) |
| **h_n** | Hidden State | Short-term output | More **responsive to recent inputs** (quick fluctuations) |

**Hypothesis**: 
- `h_n` should correlate better with quick/variable processes (evaporation, subsurface)
- `c_n` should correlate better with seasonal/persistent patterns (SWE, snow_depth)
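
To make the distinction between the two states concrete, the sketch below runs one LSTM cell forward in NumPy using the standard gate equations. It is illustrative only: the weights are random, `lstm_step` is a hypothetical helper, and the gate ordering `(i, f, g, o)` is an assumption rather than a description of the model analyzed here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM step; stacked weights are ordered (i, f, g, o) by assumption."""
    hidden = h_prev.shape[0]
    z = W @ x + U @ h_prev + b               # shape: (4 * hidden,)
    i = sigmoid(z[:hidden])                  # input gate
    f = sigmoid(z[hidden:2 * hidden])        # forget gate
    g = np.tanh(z[2 * hidden:3 * hidden])    # candidate values
    o = sigmoid(z[3 * hidden:])              # output gate
    c = f * c_prev + i * g                   # cell state: additive update, can accumulate
    h = o * np.tanh(c)                       # hidden state: gated view of c, bounded by 1
    return h, c


rng = np.random.default_rng(0)
n_in, hidden = 3, 4
W = rng.normal(size=(4 * hidden, n_in))
U = rng.normal(size=(4 * hidden, hidden))
b = np.zeros(4 * hidden)

h, c = np.zeros(hidden), np.zeros(hidden)
for _ in range(50):
    h, c = lstm_step(rng.normal(size=n_in), h, c, W, U, b)

print("max |h|:", np.abs(h).max())
print("max |c|:", np.abs(c).max())
```

Because `h` is a gated, squashed view of `c`, the two states are closely related; Pearson correlation is scale-invariant, so the bounded range of `h` alone does not bias the comparison.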

## Data Structure

**States array**: `(996 gauges, 730 timesteps, 256 hidden units)`

Each gauge has its **own** state trajectory over 730 days (2019-2020 test period).
This allows proper per-gauge correlation analysis.
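
As a quick illustration of this layout, the sketch below indexes a small synthetic array with the same axis order. The array contents and the gauge IDs are made up for the example.

```python
import numpy as np

# Small synthetic stand-in for the (996, 730, 256) states array
n_gauges, n_timesteps, hidden_size = 5, 730, 8
toy_states = np.random.default_rng(0).normal(size=(n_gauges, n_timesteps, hidden_size))
toy_gauge_ids = ["g01", "g02", "g03", "g04", "g05"]  # hypothetical IDs

gauge_idx = toy_gauge_ids.index("g03")
gauge_trajectory = toy_states[gauge_idx]      # all cells for one gauge
cell_series = toy_states[gauge_idx, :, 2]     # one cell's time series for that gauge

print(gauge_trajectory.shape, cell_series.shape)
```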

## Key Methodological Notes

1. **Per-gauge correlation**: Each gauge's state trajectory is correlated with its OWN meteo data
2. **Signed correlations**: Preserves positive/negative relationships (cells can inhibit or activate)
3. **Z-score ranking**: Gives processes with weaker correlations a fair chance at cell representation (reduces snow dominance)
4. **Hybrid cluster comparison**: Tests if cell-process mappings align with hydrological regimes

In [None]:
"""
LSTM States Spatial Analysis (v2 - Per-Gauge Analysis)
=======================================================
Analyze how LSTM hidden states (h_n) and cell states (c_n) internally represent
meteorological processes NOT included as model inputs.

Key Features:
- Per-gauge state trajectories: Each gauge's states correlated with its own meteo
- Preserves correlation SIGN (positive vs negative groups)
- Analyzes BOTH h_n and c_n with comparison
- Z-score ranking for balanced process representation
"""
from pathlib import Path
import sys

import cartopy.crs as ccrs
import geopandas as gpd
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm

sys.path.append("../")

from src.plots.hex_maps import hexes_plots_n
from src.readers.geom_reader import load_geodata
from src.utils.logger import setup_logger

plt.rcParams["font.family"] = "DejaVu Serif"
plt.rcParams["font.serif"] = ["Times New Roman"]
log = setup_logger("chapter_three_v2", log_file="../logs/chapter_three_v2.log")

# Output directories
table_dir = Path("../res/chapter_three/tables")
table_dir.mkdir(parents=True, exist_ok=True)
image_dir = Path("../res/chapter_three/images")
image_dir.mkdir(parents=True, exist_ok=True)

print("Imports complete. Output directories ready.")

## 1. Data Loading

Load:
- **LSTM states**: `(996, 730, 256)` - per-gauge state trajectories
- **Gauge geometries**: For spatial visualization
- **Hybrid clusters**: Pre-computed hydrological regime classifications
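
Since every later step relies on the gauge IDs lining up with the first axis of the state arrays, a small consistency check can be useful before the analysis starts. The sketch below is one possible version, assuming the archive exposes `h_n`, `c_n`, and `gauge_ids` keys; `check_states_archive` is a hypothetical helper, and the in-memory archive only stands in for the real file.

```python
import io

import numpy as np


def check_states_archive(archive):
    """Return gauge IDs as strings after checking they match the state arrays."""
    h_n, c_n = archive["h_n"], archive["c_n"]
    if h_n.shape != c_n.shape:
        raise ValueError(f"h_n {h_n.shape} and c_n {c_n.shape} differ")
    gauge_ids = [str(g) for g in archive["gauge_ids"]]
    if len(gauge_ids) != h_n.shape[0]:
        raise ValueError("gauge_ids length does not match the first state axis")
    if len(set(gauge_ids)) != len(gauge_ids):
        raise ValueError("gauge_ids contains duplicates")
    return gauge_ids


# Tiny in-memory archive standing in for the real .npz file
buffer = io.BytesIO()
np.savez(
    buffer,
    h_n=np.zeros((3, 10, 4)),
    c_n=np.zeros((3, 10, 4)),
    gauge_ids=np.array(["1001", "1002", "1003"]),  # hypothetical IDs
)
buffer.seek(0)
toy_ids = check_states_archive(np.load(buffer))
print(toy_ids)
```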

In [None]:
# =============================================================================
# Load Data: LSTM states, gauge geometries, hybrid clusters
# =============================================================================

# Load gauge geometries
ws, gauges = load_geodata(folder_depth="../")
basemap_data = gpd.read_file("../data/geometry/basemap_2023.gpkg")

# Load hybrid cluster assignments
gauge_mapping = pd.read_csv(
    "../res/chapter_one/gauge_hybrid_mapping.csv",
    index_col="gauge_id",
    dtype={"gauge_id": str},
)
hybrid_clusters = gauge_mapping["hybrid_class"]

# Load LSTM states - NEW FORMAT: (n_gauges, n_timesteps, hidden_size)
states_path = Path("../data/optimization/lstm_states_per_gauge/all_gauges_states.npz")
print(f"Loading LSTM states from: {states_path}")
states = np.load(states_path, allow_pickle=True)
print(f"Available keys: {list(states.keys())}")

# Extract state arrays
# Shape: (n_gauges, n_timesteps, hidden_size) = (996, 730, 256)
h_states_all = states["h_n"]  # Hidden states for all gauges
c_states_all = states["c_n"]  # Cell states for all gauges

# Get gauge IDs from states file (ordering matches array indices)
if "gauge_ids" in states.keys():
    state_gauge_ids = [str(g) for g in states["gauge_ids"]]
else:
    # Fallback: assume the states follow the order of the gauge geometry index
    state_gauge_ids = gauges.index.to_list()

n_gauges, n_timesteps, hidden_size = h_states_all.shape

print(f"\nh_n shape: {h_states_all.shape}")
print(f"c_n shape: {c_states_all.shape}")
print(f"Gauges in states: {n_gauges}")
print(f"Timesteps per gauge: {n_timesteps}")
print(f"Hidden size: {hidden_size}")
print(f"Hybrid clusters: {hybrid_clusters.nunique()}")

## 2. Meteorological Parameters

These are processes the LSTM **never saw during training**, but may have learned to represent internally:

| Parameter | Source | Physical Meaning |
|-----------|--------|------------------|
| Evaporation | GLEAM | Water loss to atmosphere |
| SWE | ERA5-Land | Snow water storage |
| Snow Depth | ERA5-Land | Snow accumulation |
| Subsurface | ERA5-Land | Groundwater/baseflow contribution |

In [None]:
# =============================================================================
# Define meteorological parameters (NOT used as model inputs)
# =============================================================================
METEO_PARAMS = {
    "evaporation": {
        "path": Path("../data/meteo_grids_2024/gleam/E"),
        "column": "E",
        "description": "Evaporation (GLEAM)",
        "temporal_scale": "short-term",  # Quick/variable process
    },
    "swe": {
        "path": Path(
            "../data/meteo_grids_2024/snow_and_subsurface/era5_land/snow_depth_water_equivalent"
        ),
        "column": "swe_e5l",
        "description": "Snow Water Equivalent (ERA5-Land)",
        "temporal_scale": "seasonal",  # Slow/persistent process
    },
    "snow_depth": {
        "path": Path("../data/meteo_grids_2024/snow_and_subsurface/era5_land/snow_depth"),
        "column": None,
        "description": "Snow Depth (ERA5-Land)",
        "temporal_scale": "seasonal",  # Slow/persistent process
    },
    "subsurface": {
        "path": Path(
            "../data/meteo_grids_2024/snow_and_subsurface/era5_land/sub_surface_runoff"
        ),
        "column": None,
        "description": "Subsurface Runoff (ERA5-Land)",
        "temporal_scale": "short-term",  # Quick/variable process
    },
}

# Color scheme for signed process groups
PROCESS_COLORS = {
    "evaporation_pos": "#27ae60",  # Green
    "evaporation_neg": "#1e8449",  # Dark Green
    "swe_pos": "#3498db",  # Blue
    "swe_neg": "#1a5276",  # Dark Blue
    "snow_depth_pos": "#9b59b6",  # Purple
    "snow_depth_neg": "#6c3483",  # Dark Purple
    "subsurface_pos": "#e67e22",  # Orange
    "subsurface_neg": "#a04000",  # Dark Orange
    "inactive": "#bdc3c7",  # Gray
}

# Minimum correlation threshold
MIN_CORRELATION = 0.3

# Verify paths
for param, info in METEO_PARAMS.items():
    exists = info["path"].exists()
    n_files = len(list(info["path"].glob("*.csv"))) if exists else 0
    print(f"{param}: exists={exists}, files={n_files}")

## 3. Per-Gauge Correlation Analysis

**Critical improvement over previous version**:

- **Before**: One state trajectory correlated with 996 different gauges' meteo data
- **Now**: Each gauge's state trajectory correlated with its OWN meteo data

For each gauge:
1. Get gauge's state trajectory: `(730 timesteps, 256 cells)`
2. Load gauge's meteo data: `(730 timesteps,)`
3. Compute correlation for each cell: `r = corr(cell_trajectory, meteo_trajectory)`

Result: Per-gauge correlation matrix `(n_gauges, n_cells)` for each meteo parameter.
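
The three steps above can be sketched on synthetic data as follows. This is a vectorized variant that computes all cells at once instead of looping; `pearson_per_cell` is a hypothetical helper and the toy series are made up for the example.

```python
import numpy as np


def pearson_per_cell(cell_states, target):
    """Signed Pearson r between each column of cell_states and target."""
    x = cell_states - cell_states.mean(axis=0)
    y = target - target.mean()
    denom = np.sqrt((x ** 2).sum(axis=0) * (y ** 2).sum())
    # Constant columns give 0/0; let them become NaN without warnings
    with np.errstate(divide="ignore", invalid="ignore"):
        return (x * y[:, None]).sum(axis=0) / denom


rng = np.random.default_rng(42)
toy_meteo = np.sin(np.linspace(0, 4 * np.pi, 730))   # seasonal-looking signal
toy_states = rng.normal(size=(730, 4))
toy_states[:, 0] += 3 * toy_meteo                    # cell 0 tracks the signal
toy_states[:, 1] -= 3 * toy_meteo                    # cell 1 tracks it inversely

r = pearson_per_cell(toy_states, toy_meteo)
print(np.round(r, 2))
```

Cells 0 and 1 come out with strong correlations of opposite sign, which is the information the signed analysis is meant to keep.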

In [None]:
# =============================================================================
# Core correlation functions (preserving sign)
# =============================================================================


def compute_cell_correlations(
    cell_states: np.ndarray,  # Shape: (n_timesteps, hidden_size)
    meteo_array: np.ndarray,  # Shape: (n_timesteps,)
) -> np.ndarray:
    """
    Compute SIGNED Pearson correlation between each LSTM cell and meteo data.

    Args:
        cell_states: LSTM states for one gauge, shape (n_timesteps, hidden_size)
        meteo_array: Meteo values for same gauge, shape (n_timesteps,)

    Returns:
        Array of shape (hidden_size,) with signed correlation coefficients.
    """
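    # Align lengths by truncating the longer series. This is positional, not
    # date-based, so it assumes both series start on the same day. Note that
    # 2019-01-01..2020-12-31 spans 731 days (2020 is a leap year), so a complete
    # meteo series is one value longer than a 730-step state trajectory.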
    min_len = min(len(cell_states), len(meteo_array))
    cell_states = cell_states[:min_len]
    meteo_array = meteo_array[:min_len]

    # Remove NaN values
    valid_mask = ~np.isnan(meteo_array)
    if valid_mask.sum() < 30:  # Minimum samples for meaningful correlation
        return np.full(cell_states.shape[1], np.nan)

    cell_states = cell_states[valid_mask]
    meteo_array = meteo_array[valid_mask]

    # Compute correlation for each cell
    n_cells = cell_states.shape[1]
    correlations = np.zeros(n_cells)

    for i in range(n_cells):
        try:
            corr_matrix = np.corrcoef(cell_states[:, i], meteo_array)
            correlations[i] = corr_matrix[0, 1]  # SIGNED correlation
        except Exception:
            correlations[i] = np.nan

    return correlations


def load_meteo_data(
    gauge_id: str,
    param_name: str,
    start_date: str = "2019-01-01",
    end_date: str = "2020-12-31",
) -> np.ndarray:
    """
    Load meteorological data for a specific gauge.
    Filters to test period dates to match LSTM states.
    """
    info = METEO_PARAMS[param_name]
    file_path = info["path"] / f"{gauge_id}.csv"

    if not file_path.exists():
        return np.array([])

    try:
        df = pd.read_csv(file_path, index_col="date", parse_dates=["date"])

        # Filter to test period
        df = df.loc[start_date:end_date]

        col = info["column"] if info["column"] else df.columns[0]
        if col not in df.columns:
            col = df.columns[0]
        return df[col].values
    except Exception:
        # Treat unreadable or malformed files as missing data
        return np.array([])


print("Correlation functions defined (sign-preserving).")

In [None]:
# =============================================================================
# Compute per-gauge correlations for h_n and c_n
# =============================================================================


def compute_all_gauge_correlations(states_array, state_name="h_n"):
    """
    Compute correlation matrices for all gauges and all meteo params.

    Args:
        states_array: Shape (n_gauges, n_timesteps, hidden_size)
        state_name: For logging

    Returns:
        Dict[param_name -> DataFrame(gauge_id x cell_id)]
    """
    correlation_matrices = {}

    for param_name in METEO_PARAMS.keys():
        print(f"\n{state_name} - Processing: {param_name}")

        # Initialize correlation DataFrame
        corr_df = pd.DataFrame(
            index=state_gauge_ids,
            columns=range(hidden_size),
            dtype=float,
        )

        success_count = 0
        for gauge_idx, gauge_id in enumerate(tqdm(state_gauge_ids, desc=param_name)):
            # Get THIS gauge's state trajectory
            gauge_states = states_array[gauge_idx]  # Shape: (n_timesteps, hidden_size)

            # Load THIS gauge's meteo data
            meteo_data = load_meteo_data(gauge_id, param_name)

            if len(meteo_data) == 0:
                continue

            # Compute correlations between gauge's states and gauge's meteo
            correlations = compute_cell_correlations(gauge_states, meteo_data)

            if not np.all(np.isnan(correlations)):
                corr_df.loc[gauge_id] = correlations
                success_count += 1

        correlation_matrices[param_name] = corr_df.dropna(how="all")
        print(
            f"  Valid gauges: {success_count}, Shape: {correlation_matrices[param_name].shape}"
        )

    return correlation_matrices


# Compute for h_n (hidden state)
print("=" * 70)
print("Computing SIGNED correlations for h_n (hidden state)")
print("Each gauge's states correlated with its OWN meteo data")
print("=" * 70)
h_correlation_matrices = compute_all_gauge_correlations(h_states_all, "h_n")

# Compute for c_n (cell state)
print("\n" + "=" * 70)
print("Computing SIGNED correlations for c_n (cell state)")
print("Each gauge's states correlated with its OWN meteo data")
print("=" * 70)
c_correlation_matrices = compute_all_gauge_correlations(c_states_all, "c_n")

print("\nCorrelation computation complete!")

## 4. Cell Assignment to Process Groups

Assign each of the 256 LSTM cells to a "signed process group" based on which
meteorological process it best represents.

**Z-score ranking approach**:
- Problem: Snow processes often have higher raw correlations, dominating assignments
- Solution: Normalize correlations within each process using Z-scores
- This makes it more likely that processes with weaker correlations (subsurface, evaporation) still get cells; the `MIN_CORRELATION` threshold is applied afterwards

Groups: `{process}_pos` (positive correlation) or `{process}_neg` (negative correlation)
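
A toy example may help show what the normalization changes. The numbers below are invented, and the sketch only handles positive correlations, whereas the notebook builds separate `_pos` and `_neg` columns by clipping.

```python
import pandas as pd

# Invented mean correlations for four cells and two processes
toy_mean_r = pd.DataFrame(
    {
        "swe": [0.80, 0.75, 0.70, 0.78],
        "evaporation": [0.10, 0.45, 0.05, 0.08],
    }
)

# Raw ranking: every cell picks the process with the largest |r|
raw_winner = toy_mean_r.abs().idxmax(axis=1)

# Z-score ranking: standardize each process column first, then pick the best
toy_z = (toy_mean_r - toy_mean_r.mean()) / (toy_mean_r.std() + 1e-8)
z_winner = toy_z.idxmax(axis=1)

print(pd.DataFrame({"raw": raw_winner, "zscore": z_winner}))
```

With raw values, all four cells go to `swe`. After standardization, cells 1 and 2 go to `evaporation`. Cell 2 wins on a correlation of only 0.05, which is why a minimum-correlation check after the ranking still matters.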

In [None]:
# =============================================================================
# Assign cells to SIGNED process groups using Z-SCORE ranking
# =============================================================================


def assign_cells_to_signed_groups(correlation_matrices, state_name="h_n"):
    """
    Assign each cell to a signed process group using Z-score ranking.

    This method normalizes correlations within each process so that
    processes with inherently weaker correlations still get cells assigned.
    """
    all_params = list(METEO_PARAMS.keys())

    # Compute mean correlation per cell across all gauges
    cell_corr_matrix = pd.DataFrame(index=range(hidden_size))

    for param_name in all_params:
        corr_df = correlation_matrices[param_name]
        mean_corr = corr_df.mean()  # Mean across gauges, preserves sign!
        cell_corr_matrix[param_name] = mean_corr.values

    # Compute Z-scores for each cell within each process
    # This normalizes so we compare "how good is this cell for this process"
    z_scores = pd.DataFrame(index=range(hidden_size))

    for param in all_params:
        col = cell_corr_matrix[param]
        # Z-score for positive correlations
        z_scores[f"{param}_pos"] = (col.clip(lower=0) - col.clip(lower=0).mean()) / (
            col.clip(lower=0).std() + 1e-8
        )
        # Z-score for negative correlations (use absolute value)
        z_scores[f"{param}_neg"] = (
            col.clip(upper=0).abs() - col.clip(upper=0).abs().mean()
        ) / (col.clip(upper=0).abs().std() + 1e-8)

    # For each cell, find which signed group has highest Z-score
    cell_assignment = []

    for cell_id in range(hidden_size):
        cell_zscores = z_scores.loc[cell_id]
        cell_corrs = cell_corr_matrix.loc[cell_id]

        # Find best group by Z-score
        best_group = cell_zscores.idxmax()
        best_zscore = cell_zscores.max()

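        # Get the actual (signed) mean correlation for the winning group's process
        base_process = best_group.rsplit("_", 1)[0]
        is_positive = best_group.endswith("_pos")
        actual_r = cell_corrs[base_process]

        # The Z-score winner only counts if the sign of r matches the group's sign;
        # a clipped-to-zero entry can still win when all of a cell's Z-scores are low
        sign_matches = (actual_r > 0) == is_positive

        # Check if correlation meets minimum threshold
        if sign_matches and abs(actual_r) >= MIN_CORRELATION: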
            primary = best_group
            primary_r = actual_r
        else:
            # Fall back to absolute best if Z-score winner doesn't meet threshold
            all_corrs = cell_corrs.abs()
            if all_corrs.max() >= MIN_CORRELATION:
                best_param = all_corrs.idxmax()
                best_r = cell_corrs[best_param]
                primary = f"{best_param}_pos" if best_r > 0 else f"{best_param}_neg"
                primary_r = best_r
            else:
                primary = "inactive"
                primary_r = 0

        cell_assignment.append(
            {
                "cell_id": cell_id,
                "primary_group": primary,
                "primary_r": primary_r,
                "zscore": best_zscore,
            }
        )

    cell_assignment_df = pd.DataFrame(cell_assignment)

    # Create cell groups dictionary
    cell_groups = {}
    for group in cell_assignment_df["primary_group"].unique():
        cell_groups[group] = cell_assignment_df[
            cell_assignment_df["primary_group"] == group
        ]["cell_id"].tolist()

    return cell_assignment_df, cell_groups, cell_corr_matrix


# Apply to h_n
print("=" * 70)
print("Assigning h_n cells to SIGNED process groups (Z-score ranking)")
print("=" * 70)
h_cell_assignment, h_cell_groups, h_cell_corr_matrix = assign_cells_to_signed_groups(
    h_correlation_matrices, "h_n"
)

# Print mean |r| per process
print("\nMean |r| per process across all cells:")
for param in METEO_PARAMS.keys():
    mean_abs_r = h_cell_corr_matrix[param].abs().mean()
    print(f"  {param}: {mean_abs_r:.3f}")

print("\nh_n Cell Group Distribution:")
print(h_cell_assignment["primary_group"].value_counts())

## 5. Cell Grid Visualization

Visualize the 256 LSTM cells as a 16x16 grid, colored by their assigned process group.

This shows how the assigned process groups are distributed across cells. The grid is only a row-major layout of cell indices, so adjacency between cells carries no meaning.
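
The mapping from a cell index to its grid position is plain integer division, as in this small sketch (`cell_to_grid` is a hypothetical helper name):

```python
import numpy as np

hidden_size = 256
grid_size = int(np.sqrt(hidden_size))  # 16 for 256 cells


def cell_to_grid(cell_id, grid_size=grid_size):
    """Return (row, column) of a cell in a row-major square grid."""
    return divmod(cell_id, grid_size)


# Same layout as reshaping the index vector into a square
index_grid = np.arange(hidden_size).reshape(grid_size, grid_size)
row, col = cell_to_grid(37)
print(row, col, index_grid[row, col])
```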

In [None]:
# =============================================================================
# 16x16 Grid Visualization of Cell Assignments
# =============================================================================


def plot_cell_grid(cell_assignment_df, title_suffix=""):
    """Plot 16x16 grid of cells colored by process group."""
    fig, ax = plt.subplots(figsize=(14, 12))

    grid_size = int(np.sqrt(hidden_size))
    cell_grid = np.zeros((grid_size, grid_size, 3))

    # Convert hex colors to RGB
    def hex_to_rgb(hex_color):
        hex_color = hex_color.lstrip("#")
        return tuple(int(hex_color[i : i + 2], 16) / 255 for i in (0, 2, 4))

    for _, row in cell_assignment_df.iterrows():
        cell_id = int(row["cell_id"])
        group = row["primary_group"]
        color = PROCESS_COLORS.get(group, PROCESS_COLORS["inactive"])

        i, j = cell_id // grid_size, cell_id % grid_size
        cell_grid[i, j] = hex_to_rgb(color)

    ax.imshow(cell_grid, aspect="equal")

    # Grid lines
    for i in range(grid_size + 1):
        ax.axhline(i - 0.5, color="white", linewidth=0.5)
        ax.axvline(i - 0.5, color="white", linewidth=0.5)

    ax.set_xticks(range(0, grid_size, 2))
    ax.set_yticks(range(0, grid_size, 2))
    ax.set_xlabel("Cell Column", fontsize=12)
    ax.set_ylabel("Cell Row", fontsize=12)
    ax.set_title(f"LSTM Cell Process Groups{title_suffix}", fontsize=14)

    # Legend
    legend_elements = [
        Patch(facecolor=color, label=group.replace("_", " ").title())
        for group, color in PROCESS_COLORS.items()
        if group in cell_assignment_df["primary_group"].values
    ]
    ax.legend(
        handles=legend_elements,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        fontsize=10,
    )

    plt.tight_layout()
    return fig


# Plot h_n cell grid
fig = plot_cell_grid(h_cell_assignment, " (h_n - Hidden State)")
fig.savefig(image_dir / "lstm_v2_hn_cell_grid.png", dpi=150, bbox_inches="tight")
plt.show()

## 6. Per-Gauge Dominant Process

For each gauge, determine which process has the strongest representation.

**Method**: Use Z-score normalized correlations to find which process dominates
for each gauge, ensuring fair comparison across processes with different
inherent correlation magnitudes.
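
Before the normalization, each gauge's row of cell correlations is reduced to its strongest cell and that cell's sign. A minimal sketch of this reduction, using an invented correlation table and made-up gauge IDs:

```python
import numpy as np
import pandas as pd

# Invented correlations: rows are gauges, columns are cell indices
toy_corr = pd.DataFrame(
    [[0.2, -0.8, 0.1], [0.6, -0.3, 0.4]],
    index=["g01", "g02"],  # hypothetical gauge IDs
    columns=[0, 1, 2],
)

best_cell = toy_corr.abs().idxmax(axis=1)   # cell with the largest |r| per gauge
max_abs_r = toy_corr.abs().max(axis=1)      # strength of that cell
best_sign = pd.Series(
    [np.sign(toy_corr.loc[g, c]) for g, c in best_cell.items()],
    index=toy_corr.index,
)

print(pd.DataFrame({"cell": best_cell, "max_abs_r": max_abs_r, "sign": best_sign}))
```

Gauge `g01` is summarized by cell 1 with a negative sign and `g02` by cell 0 with a positive sign; the per-process Z-scores are then computed on the `max_abs_r` column.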

In [None]:
# =============================================================================
# Per-Gauge Dominant Process (with Z-score normalization)
# =============================================================================


def compute_gauge_dominant_process(correlation_matrices):
    """
    For each gauge, find which signed process has the strongest representation.
    Uses Z-score normalization for fair comparison across processes.
    """
    # Get common gauges across all params (sorted for a reproducible order)
    common_gauges = set(state_gauge_ids)
    for corr_df in correlation_matrices.values():
        common_gauges &= set(corr_df.index)
    common_gauges = sorted(common_gauges)

    # Compute max absolute correlation per gauge per process
    gauge_max_corr = pd.DataFrame(index=common_gauges)

    for param_name, corr_df in correlation_matrices.items():
        # For each gauge, get max absolute correlation across all cells
        gauge_max_corr[f"{param_name}_max"] = corr_df.loc[common_gauges].abs().max(axis=1)
        # Also track the sign of the best cell
        best_cell_idx = corr_df.loc[common_gauges].abs().idxmax(axis=1)
        gauge_max_corr[f"{param_name}_sign"] = [
            np.sign(corr_df.loc[g, idx]) for g, idx in zip(common_gauges, best_cell_idx)
        ]

    # Z-score normalize within each process
    gauge_zscores = pd.DataFrame(index=common_gauges)
    for param_name in METEO_PARAMS.keys():
        col = gauge_max_corr[f"{param_name}_max"]
        gauge_zscores[param_name] = (col - col.mean()) / (col.std() + 1e-8)

    # Find dominant process for each gauge
    gauge_dominant = []

    for gauge_id in common_gauges:
        # Best process by Z-score
        best_param = gauge_zscores.loc[gauge_id].idxmax()
        best_r = gauge_max_corr.loc[gauge_id, f"{best_param}_max"]
        best_sign = gauge_max_corr.loc[gauge_id, f"{best_param}_sign"]

        if best_r >= MIN_CORRELATION:
            group = f"{best_param}_pos" if best_sign > 0 else f"{best_param}_neg"
        else:
            group = "inactive"
            best_r = 0

        gauge_dominant.append(
            {
                "gauge_id": gauge_id,
                "dominant_group": group,
                "max_r": best_r * best_sign,
                "zscore": gauge_zscores.loc[gauge_id, best_param]
                if group != "inactive"
                else 0,
            }
        )

    return pd.DataFrame(gauge_dominant).set_index("gauge_id")


# Compute for h_n
h_gauge_dominant = compute_gauge_dominant_process(h_correlation_matrices)

print("h_n Gauge Dominant Process Distribution:")
print(h_gauge_dominant["dominant_group"].value_counts())

## 7. Spatial Visualization

Map the dominant process for each gauge to see spatial patterns.

Expected patterns:
- Snow processes should dominate in northern/mountain regions
- Evaporation should dominate in southern/agricultural regions
- Subsurface should appear in regions with significant groundwater contribution

In [None]:
# =============================================================================
# Spatial Map of Dominant Processes
# =============================================================================

# Merge with gauge geometries
gauges_with_process = gauges.copy()
gauges_with_process = gauges_with_process.join(h_gauge_dominant, how="inner")

# Create map
fig, ax = plt.subplots(
    figsize=(14, 10),
    subplot_kw={
        "projection": ccrs.AlbersEqualArea(central_longitude=100, central_latitude=60)
    },
)

# Plot basemap
basemap_data.to_crs(ax.projection.proj4_init).plot(
    ax=ax, color="lightgray", edgecolor="white", linewidth=0.5
)

# Plot gauges by dominant process
for group in gauges_with_process["dominant_group"].unique():
    subset = gauges_with_process[gauges_with_process["dominant_group"] == group]
    color = PROCESS_COLORS.get(group, PROCESS_COLORS["inactive"])
    subset.to_crs(ax.projection.proj4_init).plot(
        ax=ax,
        color=color,
        markersize=15,
        alpha=0.8,
        label=group.replace("_", " ").title(),
    )

ax.set_title(
    "Dominant LSTM Process (h_n) by Gauge\n(Z-score normalized for fair comparison)",
    fontsize=14,
)
ax.legend(loc="lower left", fontsize=10)

plt.tight_layout()
fig.savefig(
    image_dir / "lstm_v2_hn_dominant_process_map.png", dpi=150, bbox_inches="tight"
)
plt.show()

print(f"Gauges with valid data: {len(gauges_with_process)}")

## 8. h_n vs c_n Comparison

Compare hidden state (h_n) and cell state (c_n) to test the hypothesis:

- **h_n**: Should specialize in quick/variable processes (evaporation, subsurface)
- **c_n**: Should specialize in seasonal/persistent patterns (SWE, snow_depth)

We compare:
1. Cell group distributions
2. Overlap between h_n and c_n assignments
3. Mean correlation strengths per process

In [None]:
# =============================================================================
# Apply same analysis to c_n (cell state)
# =============================================================================

print("=" * 70)
print("Assigning c_n cells to SIGNED process groups")
print("=" * 70)

c_cell_assignment, c_cell_groups, c_cell_corr_matrix = assign_cells_to_signed_groups(
    c_correlation_matrices, "c_n"
)

print("\nc_n Cell Group Distribution:")
print(c_cell_assignment["primary_group"].value_counts())

In [None]:
# =============================================================================
# Side-by-side Grid Comparison: h_n vs c_n
# =============================================================================

fig, axes = plt.subplots(1, 2, figsize=(24, 10))

grid_size = int(np.sqrt(hidden_size))


def hex_to_rgb(hex_color):
    hex_color = hex_color.lstrip("#")
    return tuple(int(hex_color[i : i + 2], 16) / 255 for i in (0, 2, 4))


for ax_idx, (cell_assign, title) in enumerate(
    [
        (h_cell_assignment, "Hidden State (h_n)\nShort-term memory"),
        (c_cell_assignment, "Cell State (c_n)\nLong-term memory"),
    ]
):
    ax = axes[ax_idx]
    cell_grid = np.zeros((grid_size, grid_size, 3))

    for _, row in cell_assign.iterrows():
        cell_id = int(row["cell_id"])
        group = row["primary_group"]
        color = PROCESS_COLORS.get(group, PROCESS_COLORS["inactive"])
        i, j = cell_id // grid_size, cell_id % grid_size
        cell_grid[i, j] = hex_to_rgb(color)

    ax.imshow(cell_grid, aspect="equal")

    for i in range(grid_size + 1):
        ax.axhline(i - 0.5, color="white", linewidth=0.5)
        ax.axvline(i - 0.5, color="white", linewidth=0.5)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Cell Column")
    ax.set_ylabel("Cell Row")

# Shared legend
legend_elements = [
    Patch(facecolor=color, label=group.replace("_", " ").title())
    for group, color in PROCESS_COLORS.items()
]
fig.legend(
    handles=legend_elements, loc="center right", bbox_to_anchor=(1.08, 0.5), fontsize=10
)

plt.tight_layout()
fig.savefig(image_dir / "lstm_v2_hn_cn_comparison_grid.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# =============================================================================
# Comparison Statistics: h_n vs c_n
# =============================================================================

comparison_data = []

all_groups = set(h_cell_groups.keys()) | set(c_cell_groups.keys())

for group in sorted(all_groups):
    h_cells = set(h_cell_groups.get(group, []))
    c_cells = set(c_cell_groups.get(group, []))

    overlap = h_cells & c_cells
    h_only = h_cells - c_cells
    c_only = c_cells - h_cells

    comparison_data.append(
        {
            "Group": group,
            "h_n": len(h_cells),
            "c_n": len(c_cells),
            "Overlap": len(overlap),
            "h_n only": len(h_only),
            "c_n only": len(c_only),
        }
    )

comparison_df = pd.DataFrame(comparison_data)
print("=" * 70)
print("h_n vs c_n COMPARISON")
print("=" * 70)
print(comparison_df.to_string(index=False))

# Overall agreement
same_assignment = sum(
    1
    for i in range(hidden_size)
    if h_cell_assignment.loc[i, "primary_group"]
    == c_cell_assignment.loc[i, "primary_group"]
)
print(
    f"\nCells with same assignment: {same_assignment}/{hidden_size} ({100 * same_assignment / hidden_size:.1f}%)"
)

## 9. Hypothesis Test: Temporal Scale Specialization

Test whether h_n and c_n show different preferences for temporal scales:

| Process | Expected Temporal Scale | Expected State |
|---------|------------------------|----------------|
| Evaporation | Short-term (daily fluctuations) | h_n |
| Subsurface | Short-term (event-driven) | h_n |
| SWE | Seasonal (slow accumulation/melt) | c_n |
| Snow Depth | Seasonal (winter pattern) | c_n |
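
Comparing two mean values does not say whether the difference is larger than chance. One possible complement, sketched below, is a paired sign-flip permutation test on the per-cell differences in |r|. This is an assumption-laden illustration, not part of the original analysis: `paired_sign_flip_pvalue` is a hypothetical helper, the data are synthetic, and LSTM cells are not independent, so any p-value should be read as a rough indication only.

```python
import numpy as np


def paired_sign_flip_pvalue(a, b, n_perm=5000, seed=0):
    """Two-sided p-value for mean(a - b) == 0 under random sign flips."""
    rng = np.random.default_rng(seed)
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    observed = abs(diff.mean())
    # Under the null, each paired difference is equally likely to be +d or -d
    signs = rng.choice([-1.0, 1.0], size=(n_perm, diff.size))
    permuted = np.abs((signs * diff).mean(axis=1))
    return (np.sum(permuted >= observed) + 1) / (n_perm + 1)


# Synthetic |r| values for 256 cells: the "h" values are systematically lower
rng = np.random.default_rng(1)
toy_c_abs_r = rng.uniform(0.2, 0.6, size=256)
toy_h_abs_r = toy_c_abs_r - 0.05 + rng.normal(scale=0.02, size=256)

p_value = paired_sign_flip_pvalue(toy_h_abs_r, toy_c_abs_r)
print(f"p = {p_value:.4f}")
```

In the notebook, the two inputs would be the `.abs()` columns of the h_n and c_n cell correlation matrices for one process.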

In [None]:
# =============================================================================
# Hypothesis Test: Which state better captures which process?
# =============================================================================

# Compute mean absolute correlation per process for h_n and c_n
process_comparison = []

for param_name, info in METEO_PARAMS.items():
    h_mean_r = h_cell_corr_matrix[param_name].abs().mean()
    c_mean_r = c_cell_corr_matrix[param_name].abs().mean()

    # Which state has stronger correlation?
    better_state = "h_n" if h_mean_r > c_mean_r else "c_n"
    expected_state = "h_n" if info["temporal_scale"] == "short-term" else "c_n"
    matches_hypothesis = better_state == expected_state

    process_comparison.append(
        {
            "Process": param_name,
            "Temporal Scale": info["temporal_scale"],
            "h_n Mean |r|": h_mean_r,
            "c_n Mean |r|": c_mean_r,
            "Better State": better_state,
            "Expected": expected_state,
            "Matches Hypothesis": matches_hypothesis,
        }
    )

hypothesis_df = pd.DataFrame(process_comparison)

print("=" * 70)
print("HYPOTHESIS TEST: h_n (short-term) vs c_n (long-term)")
print("=" * 70)
print(hypothesis_df.to_string(index=False))

matches = hypothesis_df["Matches Hypothesis"].sum()
total = len(hypothesis_df)
print(f"\nHypothesis matches: {matches}/{total} processes")

## 10. Cluster Generalization Analysis

Compare LSTM cell-process mappings with hybrid clusters to test:
- Do gauges in the same hybrid cluster have similar dominant processes?
- Does the LSTM's internal organization align with hydrological regimes?
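
One way to summarize the alignment in a single number is Cramér's V, computed from the cluster-by-process contingency table (0 means no association, 1 means each cluster maps to exactly one process). The sketch below is a plain NumPy version on invented data; `cramers_v` is a hypothetical helper, and it expects a table without the `All` margins and without empty rows or columns.

```python
import numpy as np
import pandas as pd


def cramers_v(table):
    """Cramér's V for a contingency table without margins."""
    observed = np.asarray(table, dtype=float)
    n = observed.sum()
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / n
    chi2 = ((observed - expected) ** 2 / expected).sum()
    k = min(observed.shape) - 1
    return float(np.sqrt(chi2 / (n * k)))


# Invented assignments for eight gauges
toy = pd.DataFrame(
    {
        "hybrid_cluster": [1, 1, 1, 1, 2, 2, 2, 2],
        "dominant_group": ["swe_pos"] * 3 + ["evaporation_pos"] * 5,
    }
)
toy_table = pd.crosstab(toy["hybrid_cluster"], toy["dominant_group"])

print(toy_table)
print(f"Cramér's V = {cramers_v(toy_table):.2f}")
```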

In [None]:
# =============================================================================
# Cluster vs Dominant Process Cross-tabulation
# =============================================================================

# Merge hybrid clusters with gauge dominant process
cluster_process = h_gauge_dominant.copy()
cluster_process["hybrid_cluster"] = hybrid_clusters.reindex(cluster_process.index)
cluster_process = cluster_process.dropna(subset=["hybrid_cluster"])

# Cross-tabulation
crosstab = pd.crosstab(
    cluster_process["hybrid_cluster"],
    cluster_process["dominant_group"],
    margins=True,
)

print("=" * 70)
print("HYBRID CLUSTER vs DOMINANT PROCESS CROSS-TABULATION")
print("=" * 70)
print(crosstab.to_string())

# Save
crosstab.to_csv(table_dir / "lstm_v2_cluster_process_crosstab.csv")

In [None]:
# =============================================================================
# Heatmap of Cluster-Process Proportions
# =============================================================================

# Compute proportions within each cluster
crosstab_norm = crosstab.drop("All", axis=0).drop("All", axis=1)
crosstab_pct = crosstab_norm.div(crosstab_norm.sum(axis=1), axis=0) * 100

fig, ax = plt.subplots(figsize=(14, 10))
sns.heatmap(
    crosstab_pct,
    annot=True,
    fmt=".0f",
    cmap="YlOrRd",
    ax=ax,
    cbar_kws={"label": "% of cluster"},
)
ax.set_xlabel("Dominant Process", fontsize=12)
ax.set_ylabel("Hybrid Cluster", fontsize=12)
ax.set_title(
    "Dominant LSTM Process by Hybrid Cluster (h_n)\nPercentage within each cluster",
    fontsize=14,
)

plt.tight_layout()
fig.savefig(
    image_dir / "lstm_v2_cluster_process_heatmap.png", dpi=150, bbox_inches="tight"
)
plt.show()

## 11. Save Results

In [None]:
# =============================================================================
# Save all results
# =============================================================================

# Cell assignments
h_cell_assignment.to_csv(table_dir / "lstm_v2_hn_cell_assignment.csv", index=False)
c_cell_assignment.to_csv(table_dir / "lstm_v2_cn_cell_assignment.csv", index=False)

# Comparison
comparison_df.to_csv(table_dir / "lstm_v2_hn_cn_comparison.csv", index=False)

# Gauge dominant process
h_gauge_dominant.to_csv(table_dir / "lstm_v2_gauge_dominant.csv")

# Hypothesis test
hypothesis_df.to_csv(table_dir / "lstm_v2_hypothesis_test.csv", index=False)

print("All results saved!")
print(f"\nOutput directory: {table_dir}")

## Summary

### Key Steps

1. **Per-gauge state analysis**: Each gauge's LSTM states were correlated with its own meteo data
2. **Signed correlations**: Preserved positive/negative relationships
3. **Z-score ranking**: Reduced the dominance of snow processes in cell and gauge assignments
4. **h_n vs c_n comparison**: Tested hypothesis about temporal scale specialization

### Output Files

**Tables**:
- `lstm_v2_hn_cell_assignment.csv` - h_n cell assignments
- `lstm_v2_cn_cell_assignment.csv` - c_n cell assignments
- `lstm_v2_hn_cn_comparison.csv` - h_n vs c_n comparison statistics
- `lstm_v2_gauge_dominant.csv` - Per-gauge dominant process
- `lstm_v2_hypothesis_test.csv` - Temporal scale hypothesis results
- `lstm_v2_cluster_process_crosstab.csv` - Cluster x process cross-tabulation

**Figures**:
- `lstm_v2_hn_cell_grid.png` - 16x16 grid of h_n cell assignments
- `lstm_v2_hn_dominant_process_map.png` - Spatial map of dominant processes
- `lstm_v2_hn_cn_comparison_grid.png` - Side-by-side h_n vs c_n grids
- `lstm_v2_cluster_process_heatmap.png` - Cluster x process heatmap

### Interpretation Notes

- **h_n (hidden state)**: Should theoretically capture short-term dynamics
- **c_n (cell state)**: Should theoretically capture long-term memory/seasonal patterns
- Check `lstm_v2_hypothesis_test.csv` to see if this pattern holds in the data