# LSTM with Static and Dynamic Features
This notebook demonstrates training the `HydroLSTM` model using dynamic meteorological inputs and static physiographic attributes.

In [None]:
import json
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.models.lstm.data_builders import build_dataset_from_folder
from src.models.lstm.losses import HuberLoss
from src.models.lstm.model import HydroLSTM
from src.models.lstm.splits import split_by_ranges


In [None]:
# --- Inputs: which gauges to load and where their files live -----------------
gauge_ids = ['0001', '0002']
data_root = Path('../data')
meteo_dir = data_root / 'MeteoData' / 'ProcessedGauges' / 'era5_land' / 'res'
hydro_dir = data_root / 'HydroFiles'

# Placeholder: replace `...` with a DataFrame of static attributes indexed by
# gauge_id before running this cell.
static_df = ...

# Dynamic (time-varying) input columns passed to the dataset builder.
dyn_cols = ['prcp', 't_mean', 't_min', 't_max']

# Build the windowed dataset (sequence length 90) plus the `dates` used below
# to split it; presumably one date per sample — confirm in the builder.
dataset, dates = build_dataset_from_folder(
    gauge_ids,
    meteo_dir=meteo_dir,
    hydro_dir=hydro_dir,
    temp_dir=None,
    df_static=static_df,
    dyn_feature_cols=dyn_cols,
    seq_len=90,
)

# Non-overlapping date ranges for the train / validation / test subsets.
split_ranges = {
    'train_range': ('2008-01-01', '2018-12-31'),
    'val_range': ('2019-01-01', '2020-12-31'),
    'test_range': ('2021-01-01', '2022-12-31'),
}
train_ds, val_ds, test_ds = split_by_ranges(dataset, dates, **split_ranges)

# Only the training loader shuffles; evaluation loaders keep dataset order.
batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size) for subset in (val_ds, test_ds)
)


In [None]:
# --- Model, loss and optimiser ----------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Input sizes come from the last axis of the dataset's dynamic and static
# tensors. NOTE(review): static_mode='init' presumably feeds the static
# attributes into the LSTM's initial state — confirm in HydroLSTM.
model = HydroLSTM(
    n_dyn=dataset.dyn.shape[-1],
    n_static=dataset.static.shape[-1],
    static_mode='init',
).to(device)
loss_fn = HuberLoss(delta=1.0)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

n_epochs = 5


def _mean_loss(net, loader, criterion, dev):
    """Return the sample-weighted mean loss of ``net`` over ``loader``.

    Each batch's loss is multiplied by its batch size before averaging, which
    assumes ``criterion`` returns a per-batch mean (TODO confirm for
    HuberLoss). Samples are counted as they are yielded, so the result does
    not depend on ``len(dataset)``. Returns NaN for an empty loader instead
    of raising ZeroDivisionError.
    """
    net.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for seq, stat, y in loader:
            pred = net(seq.to(dev), stat.to(dev))
            total += criterion(pred, y.to(dev)).item() * seq.size(0)
            count += seq.size(0)
    return total / count if count else float('nan')


# --- Training ----------------------------------------------------------------
for epoch in range(1, n_epochs + 1):
    model.train()
    # Accumulate on-device so logging does not force a host sync every step.
    train_sum = torch.zeros((), device=device)
    n_train = 0
    for seq, stat, y in train_loader:
        seq, stat, y = seq.to(device), stat.to(device), y.to(device)
        optim.zero_grad()
        pred = model(seq, stat)
        loss = loss_fn(pred, y)
        loss.backward()
        optim.step()
        train_sum += loss.detach() * seq.size(0)
        n_train += seq.size(0)
    train_loss = train_sum.item() / n_train if n_train else float('nan')
    val_loss = _mean_loss(model, val_loader, loss_fn, device)
    print(f'Epoch {epoch}: train loss {train_loss:.4f}, val loss {val_loss:.4f}')

# Saves the weights from the final epoch (not the best validation epoch).
torch.save(model.state_dict(), 'hydrolstm_static.pt')
