In [1]:
from collections import defaultdict
from pathlib import Path

from neuralhydrology.datasetzoo import get_dataset
from neuralhydrology.datautils.utils import load_scaler
from neuralhydrology.modelzoo.cudalstm import CudaLSTM
from neuralhydrology.modelzoo.customlstm import CustomLSTM
from neuralhydrology.utils.config import Config
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Record the PyTorch build and accelerator used for this run so the saved
# output documents the environment that produced the extracted states.
print(f"PyTorch version: {torch.__version__}")
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Only query the device name when a CUDA device is actually present.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4080 SUPER


# LSTM State Extraction for All Gauges

This notebook properly extracts LSTM hidden states (h_n) and cell states (c_n) for **each gauge** separately.

**Key insight**: The previous approach only extracted states from the first batch (≈1 gauge).
This notebook iterates through ALL samples and organizes states by gauge ID.

**Requirements**: This notebook requires PyTorch and CUDA. Run on a machine with GPU support.

Reference: https://neuralhydrology.readthedocs.io/en/latest/tutorials/inspect-lstm.html

In [2]:
def best_epoch_finder(validation_dir: Path) -> int:
    """Find the best epoch based on median NSE across validation metrics.

    Every ``model_epoch*`` subdirectory of ``validation_dir`` that contains a
    ``validation_metrics.csv`` file is read and the median of its ``NSE``
    column is computed. Epochs whose median is NaN (an empty file, or an
    ``NSE`` column with no valid values) are skipped so that they can never be
    selected as the best epoch.

    Args:
        validation_dir: Path to the validation directory containing epoch subdirectories

    Returns:
        The epoch number with the highest median NSE. If several epochs share
        the highest median, the first one in sorted directory order is returned.

    Raises:
        ValueError: If no epoch directory yields a usable (non-NaN) median NSE.
    """
    epoch_median_nse = {}

    # Iterate through all epoch directories
    for epoch_dir in sorted(validation_dir.glob("model_epoch*")):
        metrics_file = epoch_dir / "validation_metrics.csv"

        if not metrics_file.exists():
            continue

        # Read validation metrics CSV
        df = pd.read_csv(metrics_file)

        # Extract epoch number from directory name (e.g., "model_epoch030" -> 30)
        epoch_num = int(epoch_dir.name.split("model_epoch")[1])

        # Calculate median NSE across all basins
        median_nse = df["NSE"].median()

        # Every comparison against NaN is False, so a NaN entry would make the
        # result of max() depend on iteration order (a NaN seen first always
        # "wins"). Drop such epochs instead of letting them be chosen.
        if pd.isna(median_nse):
            continue

        epoch_median_nse[epoch_num] = median_nse

    if not epoch_median_nse:
        raise ValueError(f"No validation metrics found in {validation_dir}")

    # Return epoch with highest median NSE
    return max(epoch_median_nse, key=epoch_median_nse.get)


In [3]:
# =============================================================================
# Configuration - Update these paths for your setup
# =============================================================================

# Directory of the trained run. The config file and the validation metrics are
# both located relative to it, so it is defined once to keep them in sync.
RUN_DIR = Path(
    "../data/lstm_configs/model_runs/FULL_cudalstm_q_mm_day_256_365_e5l_1201_193434"
)

# Path to your trained model config
CONFIG_FILE = RUN_DIR / "config.yml"

# Model epoch to use. It is auto-detected as the epoch with the highest median
# validation NSE; replace the call with an integer to pin a specific epoch.
MODEL_EPOCH = best_epoch_finder(RUN_DIR / "validation")

# Output directory for extracted states
OUTPUT_DIR = Path("../data/optimization/lstm_states_per_gauge")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Batch size for processing (adjust based on GPU memory)
# Larger = faster but more memory
BATCH_SIZE = 256

# Device selection: prefer CUDA, then Apple Silicon (MPS), then CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"  # For Apple Silicon
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
print(f"Config file: {CONFIG_FILE}")
print(f"Output directory: {OUTPUT_DIR}")

Using device: cuda
Config file: ../data/lstm_configs/model_runs/FULL_cudalstm_q_mm_day_256_365_e5l_1201_193434/config.yml
Output directory: ../data/optimization/lstm_states_per_gauge


In [4]:
# =============================================================================
# Load trained model
# =============================================================================

# Load the run configuration and re-anchor every path it stores one directory
# up ("../"). NOTE(review): presumably the stored paths are relative to the
# project root while this notebook runs from a subdirectory - confirm.
cudalstm_config = Config(CONFIG_FILE)
relocated_keys = (
    "data_dir",
    "run_dir",
    "test_basin_file",
    "validation_basin_file",
    "train_basin_file",
)
cudalstm_config.update_config(
    {key: Path(f"../{getattr(cudalstm_config, key)}") for key in relocated_keys}
)
print(f"Model hidden size: {cudalstm_config.hidden_size}")
print(f"Sequence length: {cudalstm_config.seq_length}")

# Build the CudaLSTM and restore the weights of the selected epoch
model_path = cudalstm_config.run_dir / f"model_epoch{MODEL_EPOCH:03d}.pt"
cuda_lstm = CudaLSTM(cfg=cudalstm_config)
print(f"Loading model from: {model_path}")

model_weights = torch.load(str(model_path), map_location=DEVICE)
cuda_lstm.load_state_dict(model_weights)

# Mirror the weights into a CustomLSTM (for full state access), then place it
# on the target device in inference mode
custom_lstm = CustomLSTM(cfg=cudalstm_config)
custom_lstm.copy_weights(cuda_lstm)
custom_lstm.to(DEVICE).eval()

print("Model loaded successfully!")
print(f"Hidden size: {cudalstm_config.hidden_size}")

Model hidden size: 256


Sequence length: 365
Loading model from: ../data/lstm_configs/model_runs/FULL_cudalstm_q_mm_day_256_365_e5l_1201_193434/model_epoch011.pt
Model loaded successfully!
Hidden size: 256


In [5]:
# =============================================================================
# Load test dataset
# =============================================================================

# Reuse the scaler stored with the run when building the test-period dataset
scaler = load_scaler(cudalstm_config.run_dir)
dataset = get_dataset(cudalstm_config, is_train=False, period="test", scaler=scaler)

# Collect the basin (gauge) IDs: use the dataset's `basins` attribute when it
# exists, otherwise fall back to the keys of `basin_id_map`
if hasattr(dataset, "basins"):
    all_basins = dataset.basins
else:
    all_basins = list(dataset.basin_id_map.keys())
n_basins = len(all_basins)
n_total_samples = len(dataset)

print(f"Number of gauges in dataset: {n_basins}")
print(f"Total samples in dataset: {n_total_samples}")
print(f"Samples per gauge (approx): {n_total_samples // n_basins}")
print(f"First 5 gauge IDs: {all_basins[:5]}")

Number of gauges in dataset: 996
Total samples in dataset: 727080
Samples per gauge (approx): 730
First 5 gauge IDs: ['10042', '10044', '10048', '10058', '10059']


In [11]:
# =============================================================================
# Extract states for ALL gauges
# =============================================================================


def move_to_device(obj, device):
    """Recursively move tensors to device, handling nested structures.

    Args:
        obj: A tensor, a dict, a list, a tuple (including namedtuples), or any
            other object. Containers are traversed recursively.
        device: Target device, passed unchanged to ``torch.Tensor.to``.

    Returns:
        A structure mirroring ``obj`` in which every tensor has been moved to
        ``device``. Dicts come back as plain ``dict``; lists and tuples keep
        their own type; objects of any other type are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [move_to_device(item, device) for item in obj]
        # namedtuple classes take their fields as separate positional
        # arguments, so passing the list as a single argument would raise.
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return type(obj)(moved)
    return obj


# Debug: inspect dataset structure to find basin mapping
print("Inspecting dataset structure...")
print(f"  lookup_table type: {type(dataset.lookup_table)}")
print(f"  lookup_table length: {len(dataset.lookup_table)}")

# Inspect dict structure
if isinstance(dataset.lookup_table, dict):
    keys = list(dataset.lookup_table.keys())[:3]
    print(f"  First 3 keys: {keys}")
    print(f"  Key types: {[type(k) for k in keys]}")
    for k in keys[:2]:
        v = dataset.lookup_table[k]
        print(f"    [{k}] -> {type(v)}: {v}")

# Check for other attributes that might contain basin info
if hasattr(dataset, "basin_id_to_sample_index"):
    print(
        f"  basin_id_to_sample_index: {list(dataset.basin_id_to_sample_index.items())[:3]}"
    )
if hasattr(dataset, "sample_to_basin"):
    print("  sample_to_basin available")
if hasattr(dataset, "_sample_index_to_basin"):
    print("  _sample_index_to_basin available")

# Print all relevant attributes
attrs = [a for a in dir(dataset) if not a.startswith("_") and "basin" in a.lower()]
print(f"  Basin-related attributes: {attrs}")

# Build sample-to-basin mapping based on dataset structure
# lookup_table is {sample_index: (basin_id, date_info)}
print("\nBuilding sample-to-basin mapping...")
n_samples = len(dataset.lookup_table)

# The extraction loop below labels states purely by position, so the dataset
# must yield exactly one sample per lookup-table entry. Fail before the
# expensive forward passes if that does not hold.
if len(dataset) != n_samples:
    raise RuntimeError(
        f"Dataset length ({len(dataset)}) does not match lookup table size "
        f"({n_samples}); cannot map samples to basins by position."
    )

sample_to_basin = [""] * n_samples

# Only the basin ID (first element of each entry) is needed here. A dedicated
# loop variable keeps the running `sample_idx` counter below unambiguous.
for lookup_idx, lookup_entry in dataset.lookup_table.items():
    sample_to_basin[lookup_idx] = str(lookup_entry[0])

print(f"Total samples mapped: {len(sample_to_basin)}")

# Sequential (unshuffled) dataloader: the n-th sample it yields is paired with
# the n-th entry of sample_to_basin
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # Keep order to track gauge IDs
    collate_fn=dataset.collate_fn,
    num_workers=0,  # Avoid multiprocessing issues
)

# Storage for per-gauge states
# Key: gauge_id, Value: list of state arrays
gauge_h_states = defaultdict(list)  # Hidden states
gauge_c_states = defaultdict(list)  # Cell states

print(f"Processing {len(dataloader)} batches...")

sample_idx = 0  # Track global sample index

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Extracting states"):
        # Move batch to device (handles nested tensors)
        batch_device = move_to_device(batch, DEVICE)

        # Forward pass through CustomLSTM
        output = custom_lstm(batch_device)

        # Extract final timestep states: (batch, seq_len, hidden) -> (batch, hidden)
        h_n = output["h_n"][:, -1, :].cpu().numpy()  # Hidden state
        c_n = output["c_n"][:, -1, :].cpu().numpy()  # Cell state

        batch_size_actual = len(h_n)

        # Store states by gauge using the pre-built mapping
        for i in range(batch_size_actual):
            gauge_id = sample_to_basin[sample_idx]
            gauge_h_states[gauge_id].append(h_n[i])
            gauge_c_states[gauge_id].append(c_n[i])
            sample_idx += 1

# Every mapped sample must have been consumed; otherwise the positional
# labelling above is not trustworthy.
if sample_idx != n_samples:
    raise RuntimeError(
        f"Processed {sample_idx} samples but the lookup table lists {n_samples}; "
        "per-gauge states may be mislabelled."
    )

print(f"\nExtracted states for {len(gauge_h_states)} gauges")

Inspecting dataset structure...
  lookup_table type: <class 'dict'>
  lookup_table length: 727080
  First 3 keys: [0, 1, 2]
  Key types: [<class 'int'>, <class 'int'>, <class 'int'>]
    [0] -> <class 'tuple'>: ('10042', [364])
    [1] -> <class 'tuple'>: ('10042', [365])
  Basin-related attributes: ['basins']

Building sample-to-basin mapping...
Total samples mapped: 727080
Processing 2841 batches...


Extracting states:   0%|          | 0/2841 [00:00<?, ?it/s]


Extracted states for 996 gauges


In [12]:
# =============================================================================
# Verify extraction
# =============================================================================

# Number of extracted state vectors per gauge
print("Per-gauge sample counts:")
sample_counts = {g: len(gauge_h_states[g]) for g in gauge_h_states}
counts = list(sample_counts.values())

print(f"  Min samples per gauge: {min(counts)}")
print(f"  Max samples per gauge: {max(counts)}")
print(f"  Mean samples per gauge: {np.mean(counts):.1f}")

# Spot-check the first five gauges in insertion order
for gauge_id in list(sample_counts)[:5]:
    print(f"  {gauge_id}: {sample_counts[gauge_id]} samples")

Per-gauge sample counts:
  Min samples per gauge: 730
  Max samples per gauge: 730
  Mean samples per gauge: 730.0
  10042: 730 samples
  10044: 730 samples
  10048: 730 samples
  10058: 730 samples
  10059: 730 samples


In [13]:
# =============================================================================
# Save states - Option 1: One file per gauge
# =============================================================================

print("Saving per-gauge state files...")

for gauge_id in tqdm(gauge_h_states, desc="Saving"):
    # Destination: one compressed archive per gauge
    output_file = OUTPUT_DIR / f"{gauge_id}_states.npz"

    # Stack the per-sample vectors into 2-D arrays: (n_samples, hidden_size)
    h_states = np.asarray(gauge_h_states[gauge_id])
    c_states = np.asarray(gauge_c_states[gauge_id])

    np.savez_compressed(
        output_file,
        h_n=h_states,
        c_n=c_states,
        gauge_id=gauge_id,
    )

print(f"\nSaved {len(gauge_h_states)} state files to {OUTPUT_DIR}")

Saving per-gauge state files...


Saving:   0%|          | 0/996 [00:00<?, ?it/s]


Saved 996 state files to ../data/optimization/lstm_states_per_gauge


In [14]:
# =============================================================================
# Save states - Option 2: Combined file with all gauges
# =============================================================================

# A single 3-D array per state type, convenient for the correlation analysis

all_gauge_ids = sorted(gauge_h_states)
first_state_vector = gauge_h_states[all_gauge_ids[0]][0]
hidden_size = len(first_state_vector)

# Truncate every gauge to the shortest series so the result is rectangular
min_timesteps = min(len(gauge_h_states[g]) for g in all_gauge_ids)
print(f"Using {min_timesteps} timesteps per gauge (minimum across all gauges)")

# Pre-allocate the output arrays: (n_gauges, n_timesteps, hidden_size)
n_gauges = len(all_gauge_ids)
combined_shape = (n_gauges, min_timesteps, hidden_size)
all_h_states = np.zeros(combined_shape, dtype=np.float32)
all_c_states = np.zeros(combined_shape, dtype=np.float32)

# Fill one row per gauge, in sorted gauge-ID order
for row, gauge_id in enumerate(all_gauge_ids):
    all_h_states[row] = np.asarray(gauge_h_states[gauge_id][:min_timesteps])
    all_c_states[row] = np.asarray(gauge_c_states[gauge_id][:min_timesteps])

# Write everything, including the gauge ordering, into one archive
combined_file = OUTPUT_DIR / "all_gauges_states.npz"
np.savez_compressed(
    combined_file,
    h_n=all_h_states,  # Shape: (n_gauges, n_timesteps, hidden_size)
    c_n=all_c_states,  # Shape: (n_gauges, n_timesteps, hidden_size)
    gauge_ids=np.array(all_gauge_ids),  # Row order of h_n / c_n
    hidden_size=hidden_size,
    n_timesteps=min_timesteps,
)

print("\nSaved combined states:")
print(f"  File: {combined_file}")
print(f"  h_n shape: {all_h_states.shape}")
print(f"  c_n shape: {all_c_states.shape}")
print(f"  Gauges: {n_gauges}")

Using 730 timesteps per gauge (minimum across all gauges)

Saved combined states:
  File: ../data/optimization/lstm_states_per_gauge/all_gauges_states.npz
  h_n shape: (996, 730, 256)
  c_n shape: (996, 730, 256)
  Gauges: 996


In [15]:
# =============================================================================
# Summary
# =============================================================================

banner = "=" * 70

# NOTE(review): the snippet below loads from 'data/optimization/...', whereas
# OUTPUT_DIR in this notebook is '../data/optimization/...'. Confirm which
# working directory the analysis notebook runs from.
usage_snippet = """```python
# Load combined states
states = np.load('data/optimization/lstm_states_per_gauge/all_gauges_states.npz')
h_n = states['h_n']  # (n_gauges, n_timesteps, hidden_size)
c_n = states['c_n']  # (n_gauges, n_timesteps, hidden_size)
gauge_ids = states['gauge_ids']

# Get states for specific gauge
gauge_idx = np.where(gauge_ids == '12345')[0][0]
gauge_h_states = h_n[gauge_idx]  # (n_timesteps, hidden_size)
```"""

# Headline
print(banner)
print("EXTRACTION COMPLETE")
print(banner)

# What was extracted
print(f"\nExtracted LSTM states for {len(gauge_h_states)} gauges")
print(f"Hidden size: {hidden_size}")
print(f"Timesteps per gauge: {min_timesteps}")

# Where it was written
print("\nOutput files:")
print(f"  Individual: {OUTPUT_DIR}/<gauge_id>_states.npz")
print(f"  Combined:   {combined_file}")

# How to consume it
print("\nTo use in analysis notebook:")
print(usage_snippet)

EXTRACTION COMPLETE

Extracted LSTM states for 996 gauges
Hidden size: 256
Timesteps per gauge: 730

Output files:
  Individual: ../data/optimization/lstm_states_per_gauge/<gauge_id>_states.npz
  Combined:   ../data/optimization/lstm_states_per_gauge/all_gauges_states.npz

To use in analysis notebook:
```python
# Load combined states
states = np.load('data/optimization/lstm_states_per_gauge/all_gauges_states.npz')
h_n = states['h_n']  # (n_gauges, n_timesteps, hidden_size)
c_n = states['c_n']  # (n_gauges, n_timesteps, hidden_size)
gauge_ids = states['gauge_ids']

# Get states for specific gauge
gauge_idx = np.where(gauge_ids == '12345')[0][0]
gauge_h_states = h_n[gauge_idx]  # (n_timesteps, hidden_size)
```
