# LSTM State Extraction for All Gauges

This notebook properly extracts LSTM hidden states (h_n) and cell states (c_n) for **each gauge** separately.

**Key insight**: The previous approach only extracted states from the first batch, which covers roughly one gauge (≈1 gauge).
This notebook iterates through ALL samples and organizes the states by gauge ID.

**Requirements**: PyTorch and NeuralHydrology. A CUDA GPU is recommended for speed; if CUDA is unavailable, the notebook falls back to Apple Silicon (MPS) or CPU.

Reference: https://neuralhydrology.readthedocs.io/en/latest/tutorials/inspect-lstm.html

In [None]:
from pathlib import Path
from collections import defaultdict
import sys

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from neuralhydrology.datasetzoo import get_dataset
from neuralhydrology.datautils.utils import load_scaler
from neuralhydrology.modelzoo.cudalstm import CudaLSTM
from neuralhydrology.modelzoo.customlstm import CustomLSTM
from neuralhydrology.utils.config import Config

# Report the PyTorch build and, when present, the first CUDA device.
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# =============================================================================
# Configuration - Update these paths for your setup
# =============================================================================

# config.yml of the trained run whose states should be extracted
CONFIG_FILE = Path(
    "../data/lstm_configs/model_runs/cudalstm_q_mm_day_mswep_no_autocorr_static_1103_191754/config.yml"
)

# Checkpoint epoch to load; the weights file is model_epoch<NNN>.pt in the run directory
MODEL_EPOCH = 24

# Destination for the extracted state files (created if it does not exist yet)
OUTPUT_DIR = Path("../data/optimization/lstm_states_per_gauge")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Samples per forward pass: higher is faster but needs more device memory
BATCH_SIZE = 256

# Prefer a CUDA GPU, then Apple Silicon (MPS), and fall back to the CPU
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

for _label, _value in (
    ("Using device", DEVICE),
    ("Config file", CONFIG_FILE),
    ("Output directory", OUTPUT_DIR),
):
    print(f"{_label}: {_value}")

In [None]:
# =============================================================================
# Load trained model
# =============================================================================


def _resolve_checkpoint(run_dir, epoch):
    """Return the path of the checkpoint file to load from ``run_dir``.

    ``epoch`` (int) selects ``model_epoch<NNN>.pt``. ``None`` selects the
    checkpoint with the highest epoch number present in ``run_dir`` (the most
    recent one, not necessarily the best-validating one).

    Raises:
        FileNotFoundError: if the requested checkpoint, or any checkpoint when
            ``epoch`` is None, does not exist.
    """
    run_dir = Path(run_dir)
    prefix = "model_epoch"

    if epoch is not None:
        path = run_dir / f"{prefix}{epoch:03d}.pt"
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    # Compare epochs numerically so e.g. epoch 1000 sorts after epoch 999.
    numbered = [
        (int(p.stem[len(prefix):]), p)
        for p in run_dir.glob(f"{prefix}*.pt")
        if p.stem[len(prefix):].isdigit()
    ]
    if not numbered:
        raise FileNotFoundError(f"No {prefix}<NNN>.pt checkpoints found in {run_dir}")
    return max(numbered)[1]


# Load config
cudalstm_config = Config(CONFIG_FILE)
print(f"Model hidden size: {cudalstm_config.hidden_size}")
print(f"Sequence length: {cudalstm_config.seq_length}")

# Create CudaLSTM and load weights. MODEL_EPOCH=None picks the latest checkpoint.
cuda_lstm = CudaLSTM(cfg=cudalstm_config)
model_path = _resolve_checkpoint(cudalstm_config.run_dir, MODEL_EPOCH)
print(f"Loading model from: {model_path}")

# cuda_lstm stays on the CPU (it only serves as a weight source for CustomLSTM),
# so load the tensors to CPU instead of round-tripping through DEVICE.
model_weights = torch.load(str(model_path), map_location="cpu")
cuda_lstm.load_state_dict(model_weights)

# Create CustomLSTM and copy weights (for full state access)
custom_lstm = CustomLSTM(cfg=cudalstm_config)
custom_lstm.copy_weights(cuda_lstm)
custom_lstm.to(DEVICE)
custom_lstm.eval()

print("Model loaded successfully!")
print(f"Hidden size: {cudalstm_config.hidden_size}")

In [None]:
# =============================================================================
# Load test dataset
# =============================================================================

scaler = load_scaler(cudalstm_config.run_dir)
dataset = get_dataset(cudalstm_config, is_train=False, period="test", scaler=scaler)

# Basin (gauge) IDs covered by the dataset. `basins` is tried first and
# `basin_id_map` is a fallback for other dataset implementations
# (NOTE(review): confirm which attribute your neuralhydrology version exposes).
if hasattr(dataset, "basins"):
    all_basins = list(dataset.basins)
elif hasattr(dataset, "basin_id_map"):
    all_basins = list(dataset.basin_id_map.keys())
else:
    raise AttributeError(
        "Dataset exposes neither `basins` nor `basin_id_map`; cannot list gauge IDs."
    )
n_basins = len(all_basins)

print(f"Number of gauges in dataset: {n_basins}")
print(f"Total samples in dataset: {len(dataset)}")
if n_basins:
    print(f"Samples per gauge (approx): {len(dataset) // n_basins}")
else:
    # Avoid ZeroDivisionError; an empty gauge list means nothing will be extracted.
    print("Warning: dataset contains no gauges; check the test basins in the config.")
print(f"First 5 gauge IDs: {all_basins[:5]}")

In [None]:
# =============================================================================
# Extract states for ALL gauges
# =============================================================================

# Create dataloader. shuffle=False keeps batches in dataset order, which the
# index-based gauge lookup below relies on.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    num_workers=0,  # Avoid multiprocessing issues
)

# Per-gauge storage. Key: gauge ID (str); value: one entry per sample, in
# dataset order.
gauge_h_states = defaultdict(list)  # Final-timestep hidden states
gauge_c_states = defaultdict(list)  # Final-timestep cell states
gauge_dates = defaultdict(list)  # Last date of each input window (the state's date)


def _sample_basin(sample_idx):
    """Return the basin ID of dataset sample ``sample_idx``.

    NOTE(review): assumes ``dataset.lookup_table`` maps a sample index to a
    ``(basin, time_index)`` pair, as in NeuralHydrology's BaseDataset; confirm
    for your version.
    """
    entry = dataset.lookup_table[sample_idx]
    return entry[0] if isinstance(entry, (tuple, list)) else entry


def _as_gauge_id(raw):
    """Normalize a basin identifier (str, numpy scalar, 0-d tensor) to ``str``."""
    if hasattr(raw, "item"):
        raw = raw.item()
    return str(raw)


print(f"Processing {len(dataloader)} batches...")

sample_offset = 0  # Dataset index of the first sample in the current batch
warned_index_fallback = False

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Extracting states"):
        # Move batch to device
        batch_device = {
            k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

        # Forward pass through CustomLSTM
        output = custom_lstm(batch_device)

        # Extract final timestep states: (batch, seq_len, hidden) -> (batch, hidden)
        h_n = output["h_n"][:, -1, :].cpu().numpy()  # Hidden state
        c_n = output["c_n"][:, -1, :].cpu().numpy()  # Cell state
        n_in_batch = len(h_n)

        # Basin IDs / dates, if the dataset puts them in the batch
        basin_ids = batch.get("basin", batch.get("basin_id"))
        dates = batch.get("date", batch.get("end_date"))

        if basin_ids is None:
            # Map each sample back to its basin by dataset index (valid because
            # shuffle=False). Never skip the batch silently: that would leave
            # the results empty and break every later cell.
            if not hasattr(dataset, "lookup_table"):
                raise RuntimeError(
                    "Basin IDs are not in the batch and the dataset has no "
                    "`lookup_table`; cannot assign samples to gauges."
                )
            if not warned_index_fallback:
                print(
                    "Warning: Basin IDs not in batch. "
                    "Using index-based assignment via dataset.lookup_table."
                )
                warned_index_fallback = True
            basin_ids = [_sample_basin(sample_offset + i) for i in range(n_in_batch)]

        # Store states by gauge
        for i in range(n_in_batch):
            gauge_id = _as_gauge_id(basin_ids[i])
            gauge_h_states[gauge_id].append(h_n[i])
            gauge_c_states[gauge_id].append(c_n[i])
            if dates is not None:
                # The state is taken at the window's last step, so keep only that date.
                gauge_dates[gauge_id].append(np.asarray(dates[i]).reshape(-1)[-1])

        sample_offset += n_in_batch

print(f"\nExtracted states for {len(gauge_h_states)} gauges")

In [None]:
# =============================================================================
# Verify extraction
# =============================================================================

print("Per-gauge sample counts:")
sample_counts = {g: len(states) for g, states in gauge_h_states.items()}

# min()/max() would raise an opaque ValueError on an empty result; fail with a
# message that points at the actual problem instead.
if not sample_counts:
    raise RuntimeError(
        "No states were extracted: no batch could be assigned to a gauge. "
        "Check the extraction cell's output."
    )

count_values = list(sample_counts.values())
print(f"  Min samples per gauge: {min(count_values)}")
print(f"  Max samples per gauge: {max(count_values)}")
print(f"  Mean samples per gauge: {np.mean(count_values):.1f}")

# Show first few gauges
for gauge_id in list(gauge_h_states)[:5]:
    print(f"  {gauge_id}: {sample_counts[gauge_id]} samples")

In [None]:
# =============================================================================
# Save states - Option 1: One file per gauge
# =============================================================================

print("Saving per-gauge state files...")

for gauge_id in tqdm(gauge_h_states, desc="Saving"):
    payload = {
        "h_n": np.array(gauge_h_states[gauge_id]),  # Shape: (n_samples, hidden_size)
        "c_n": np.array(gauge_c_states[gauge_id]),  # Shape: (n_samples, hidden_size)
        "gauge_id": gauge_id,
    }

    # Write the collected dates too, so each state row can be matched to its day.
    # `.get` avoids inserting an empty list into the defaultdict.
    # NOTE(review): if dates are Python objects rather than numpy datetime64,
    # np.load will need allow_pickle=True to read them back.
    dates_for_gauge = gauge_dates.get(gauge_id)
    if dates_for_gauge:
        payload["dates"] = np.asarray(dates_for_gauge)

    # Save as npz
    output_file = OUTPUT_DIR / f"{gauge_id}_states.npz"
    np.savez_compressed(output_file, **payload)

print(f"\nSaved {len(gauge_h_states)} state files to {OUTPUT_DIR}")

In [None]:
# =============================================================================
# Save states - Option 2: Combined file with all gauges
# =============================================================================

# Create combined arrays (useful for the correlation analysis).
# Gauges can have different sample counts, so every gauge is truncated to the
# shortest one, keeping its FIRST `min_timesteps` samples. Row t is therefore
# only date-aligned across gauges if they all start on the same date; the saved
# `dates` array (when available) lets you check that.

if not gauge_h_states:
    raise RuntimeError("No extracted states to combine; run the extraction cell first.")

all_gauge_ids = sorted(gauge_h_states.keys())
hidden_size = len(gauge_h_states[all_gauge_ids[0]][0])

# Find common number of timesteps (use minimum across gauges)
per_gauge_counts = [len(gauge_h_states[g]) for g in all_gauge_ids]
min_timesteps = min(per_gauge_counts)
print(f"Using {min_timesteps} timesteps per gauge (minimum across all gauges)")
n_truncated = sum(count > min_timesteps for count in per_gauge_counts)
if n_truncated:
    print(f"  Note: {n_truncated} gauge(s) truncated to {min_timesteps} timesteps")

# Arrays of shape (n_gauges, n_timesteps, hidden_size)
n_gauges = len(all_gauge_ids)
all_h_states = np.stack(
    [np.asarray(gauge_h_states[g][:min_timesteps], dtype=np.float32) for g in all_gauge_ids]
)
all_c_states = np.stack(
    [np.asarray(gauge_c_states[g][:min_timesteps], dtype=np.float32) for g in all_gauge_ids]
)

# Save the truncated dates alongside the states when every gauge has them.
optional_arrays = {}
if all(len(gauge_dates.get(g, ())) >= min_timesteps for g in all_gauge_ids):
    optional_arrays["dates"] = np.stack(
        [np.asarray(gauge_dates[g][:min_timesteps]) for g in all_gauge_ids]
    )

# Save combined file
combined_file = OUTPUT_DIR / "all_gauges_states.npz"
np.savez_compressed(
    combined_file,
    h_n=all_h_states,  # Shape: (n_gauges, n_timesteps, hidden_size)
    c_n=all_c_states,  # Shape: (n_gauges, n_timesteps, hidden_size)
    gauge_ids=np.array(all_gauge_ids),  # Gauge ID ordering
    hidden_size=hidden_size,
    n_timesteps=min_timesteps,
    **optional_arrays,
)

print(f"\nSaved combined states:")
print(f"  File: {combined_file}")
print(f"  h_n shape: {all_h_states.shape}")
print(f"  c_n shape: {all_c_states.shape}")
print(f"  Gauges: {n_gauges}")

In [None]:
# =============================================================================
# Summary
# =============================================================================

print("=" * 70)
print("EXTRACTION COMPLETE")
print("=" * 70)
print(f"\nExtracted LSTM states for {len(gauge_h_states)} gauges")
print(f"Hidden size: {hidden_size}")
print(f"Timesteps per gauge: {min_timesteps}")
print(f"\nOutput files:")
print(f"  Individual: {OUTPUT_DIR}/<gauge_id>_states.npz")
print(f"  Combined:   {combined_file}")
print(f"\nTo use in analysis notebook:")
# The snippet uses the actual output path (relative to this notebook's working
# directory) instead of a hard-coded path that may not match OUTPUT_DIR.
print(f"""```python
# Load combined states
states = np.load('{combined_file.as_posix()}')
h_n = states['h_n']  # (n_gauges, n_timesteps, hidden_size)
c_n = states['c_n']  # (n_gauges, n_timesteps, hidden_size)
gauge_ids = states['gauge_ids']

# Get states for specific gauge
gauge_idx = np.where(gauge_ids == '12345')[0][0]
gauge_h_states = h_n[gauge_idx]  # (n_timesteps, hidden_size)
```""")

## Next Steps

After running this notebook, you'll have per-gauge LSTM states that can be properly analyzed:

1. **Update `c3_LSTMStates_v2.ipynb`** to load from `all_gauges_states.npz`
2. **Per-gauge correlation**: Correlate each gauge's state trajectory with its own meteo data
3. **Cluster analysis**: Group gauges by which cells dominate, compare to hybrid clusters
4. **h_n vs c_n comparison**: Now valid since states are gauge-specific

The key difference from before:
- **Before**: One state trajectory correlated with 996 different gauges' meteo
- **Now**: Each gauge's state trajectory correlated with its own meteo