# LSTM (Long Short-Term Memory)

In this notebook, we train an LSTM model to predict whether a given HJ Andrews river site is dry (0) or wet (1) on a given date.

We use the following advanced techniques:
- **Optuna** for hyperparameter optimization
- **ADASYN** for handling class imbalance
- **Permutation importance** for feature analysis
- **Early stopping** to prevent overfitting

At the bottom of the notebook, we provide a function that practitioners can use to run inference with the trained model for a given site and date.

## Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, roc_auc_score
)

from imblearn.over_sampling import ADASYN

import optuna
from optuna.visualization import plot_param_importances
from optuna.visualization.matplotlib import plot_contour

import matplotlib.pyplot as plt
import seaborn as sns
# Seed both NumPy and PyTorch: NumPy alone does not control PyTorch's weight
# initialisation or DataLoader shuffling.
np.random.seed(42)
torch.manual_seed(42)

## Data Preprocessing

In [None]:
#Load static variables and observations
# The data directory can be overridden with the USGS_DATA_DIR environment
# variable; when it is unset the original local path is used.
import os
from pathlib import Path

DATA_DIR = Path(os.environ.get('USGS_DATA_DIR', '/Users/michaelmurphy/Desktop/usgs_data'))
static_vars_df = pd.read_parquet(DATA_DIR / 'static_vars.parquet')
obs_df = pd.read_parquet(DATA_DIR / 'obs.parquet')

print('static_vars_df head:')
print(static_vars_df.head())
print('\nobs_df head:')
print(obs_df.head())

In [None]:
#Pivot static vars (each variable as a column)
static_wide = static_vars_df.pivot(index='NHDPlusID', columns='variable', values='value').reset_index()

#Pivot obs_df (Date as rows, variable as columns)
obs_wide = obs_df.pivot_table(index=['NHDPlusID', 'Date'], columns='variable', values='value').reset_index()

#Merge static features into each site's obs
merged_df = obs_wide.merge(static_wide, on='NHDPlusID', how='left')

#Sort and forward-fill
# Forward-fill within each site only: a frame-wide ffill would copy the last
# row of one site into the leading gaps of the next site.
merged_df = merged_df.sort_values(['NHDPlusID', 'Date'])
fill_cols = [c for c in merged_df.columns if c not in ('NHDPlusID', 'Date')]
merged_df[fill_cols] = merged_df.groupby('NHDPlusID')[fill_cols].ffill()
merged_df.head()

In [None]:
#Discretization: Convert continuous discharge into wet/dry
DRY_THRESHOLD = 0.00014   # threshold for discharge

#Create a new binary column from discharge
# A missing discharge compares False against the threshold and would be
# labelled "dry"; keep it as NaN so a missing measurement never becomes a label.
discharge = merged_df["Discharge_CMS"]
merged_df["wetdry_discharge"] = (discharge >= DRY_THRESHOLD).astype(float).where(discharge.notna())

#Combine HOBO + discharge discretization
# The HOBO label takes precedence; the discharge-derived label only fills gaps.
merged_df["wetdry_final"] = merged_df["HoboWetDry0.05"]
merged_df["wetdry_final"] = merged_df["wetdry_final"].fillna(merged_df["wetdry_discharge"])

print("Final combined wet/dry distribution:")
print(merged_df["wetdry_final"].value_counts())

In [None]:
#Pick sites that have both classes (wet and dry)
valid_obs = merged_df[merged_df["wetdry_final"].notna()]
site_variation = valid_obs.groupby("NHDPlusID")["wetdry_final"].nunique()
sites_with_both = site_variation.index[site_variation > 1].tolist()

#Pick site with most valid observations
# count() tallies the non-null labels per site
candidate_rows = merged_df.loc[merged_df["NHDPlusID"].isin(sites_with_both)]
site_counts = candidate_rows.groupby("NHDPlusID")["wetdry_final"].count()
best_site = site_counts.idxmax()
print(f"Using site {best_site} with {site_counts[best_site]} valid samples")

#Subset data for that site
# valid_obs already excludes rows without a label
site_df = valid_obs[valid_obs["NHDPlusID"] == best_site].sort_values("Date")

In [None]:
#Fill all NaNs
# Forward-fill, then back-fill leading gaps, then zero whatever is still empty
# (columns that are entirely NaN for this site). The ffill()/bfill() methods
# replace the deprecated fillna(method=...) form.
df = site_df.copy()
df = df.ffill().bfill().fillna(0)

target_col = "wetdry_final"
y_original = df[target_col].astype(int).values

#Select numeric features only (remove date, ID, target)
df_numeric = df.select_dtypes(include=[np.number]).drop(columns=[target_col], errors='ignore')

#Scale features
# NOTE(review): the scaler is fitted on every row of the site, including rows
# that later land in the validation split.
scaler = StandardScaler()
scaled = scaler.fit_transform(df_numeric)
# The new frame gets a fresh 0..n-1 index; labels are re-attached by position
df_scaled = pd.DataFrame(scaled, columns=df_numeric.columns)

#Add label back
df_scaled_with_label = df_scaled.copy()
df_scaled_with_label[target_col] = y_original

print("df_scaled_with_label shape:", df_scaled_with_label.shape)

## Sequence Creation

In [None]:
#Create sequences for LSTM (30-day window)
def create_sequences(df, seq_len=30, target_col="wetdry_final"):
    """Slice a frame into overlapping windows with next-step labels.

    Each sample holds ``seq_len`` consecutive rows of every column except
    ``target_col``; its label is the ``target_col`` value of the row right
    after the window.

    Parameters
    ----------
    df : pd.DataFrame
        Rows in the order the windows should follow.
    seq_len : int
        Number of rows per window.
    target_col : str
        Column holding the label.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Stacked windows and their labels. Both arrays are empty when ``df``
        has no more than ``seq_len`` rows.
    """
    feature_matrix = df.drop(columns=target_col).values
    label_vector = df[target_col].values
    n_windows = len(df) - seq_len

    windows = [feature_matrix[start:start + seq_len] for start in range(n_windows)]
    targets = [label_vector[start + seq_len] for start in range(n_windows)]

    return np.array(windows), np.array(targets)

# Build the windows from the scaled frame
seq_len = 30
X, y = create_sequences(df_scaled_with_label, seq_len=seq_len, target_col=target_col)

# Sanity checks on the generated arrays
for nan_label, nan_source in (("NaNs in X:", X), ("NaNs in y:", y)):
    print(nan_label, np.isnan(nan_source).sum())
print(f"Generated {len(X)} sequences with seq_len={seq_len}")
print("Shapes:", X.shape, y.shape)

## ADASYN for Class Imbalance

In [None]:
#Apply ADASYN to fix class imbalance
# NOTE(review): X_res / y_res are generated from the full dataset, i.e. before
# any train/validation split - confirm that downstream code does not validate
# on synthetic samples.
n, T, d = X.shape
flat_width = T * d
y_int = y.astype(int)

# Each (T, d) window is collapsed into a single row before resampling
X_flat = X.reshape(n, flat_width)

adasyn = ADASYN(random_state=42)
X_flat_res, y_res = adasyn.fit_resample(X_flat, y_int)

# Restore the (samples, T, d) layout
X_res = X_flat_res.reshape(-1, T, d)

print("Original class counts:", np.bincount(y_int))
print("After ADASYN:", np.bincount(y_res.astype(int)))
print("Original shape:", X.shape, "Resampled shape:", X_res.shape)

In [None]:
#Convert to tensors (use ADASYN output)
X_tensor = torch.tensor(X_res, dtype=torch.float32)
# Labels become a float column vector
y_tensor = torch.tensor(y_res, dtype=torch.float32).reshape(-1, 1)

for tensor_name, tensor_value in (("X_tensor:", X_tensor), ("y_tensor:", y_tensor)):
    print(tensor_name, tensor_value.shape)

## Train/Validation Split

In [None]:
#Train/val split
# Split the ORIGINAL sequences first and oversample only the training part.
# Splitting the already-resampled tensors puts synthetic samples (interpolated
# from points that end up in the training set) into validation, which inflates
# every validation metric.
y_labels = y.astype(int)
X_train_np, X_val_np, y_train_np, y_val_np = train_test_split(
    X, y_labels, test_size=0.2, shuffle=True, random_state=42,
    stratify=y_labels,  # keep both classes present in the validation set
)

# ADASYN works on 2-D input: flatten each window, resample, restore the shape
n_train, T_train, d_train = X_train_np.shape
X_train_flat_res, y_train_res = ADASYN(random_state=42).fit_resample(
    X_train_np.reshape(n_train, T_train * d_train), y_train_np
)

X_train = torch.tensor(X_train_flat_res.reshape(-1, T_train, d_train), dtype=torch.float32)
y_train = torch.tensor(y_train_res.reshape(-1, 1), dtype=torch.float32)
# Validation keeps only real observations at their natural class balance
X_val = torch.tensor(X_val_np, dtype=torch.float32)
y_val = torch.tensor(y_val_np.reshape(-1, 1), dtype=torch.float32)

print("Train:", X_train.shape, "Val:", X_val.shape)

## Model Architecture and Hyperparameter Optimization with Optuna

In [None]:
#Define LSTM model class
class LSTMModel(nn.Module):
    """LSTM binary classifier that returns a probability per sequence.

    ``forward`` ends with a sigmoid, so the output is already a probability in
    [0, 1]. Pair it with a probability-based loss such as ``nn.BCELoss`` and do
    not apply another sigmoid to its output.

    Parameters
    ----------
    input_size : int
        Number of features per time step.
    hidden_size : int
        Size of the LSTM hidden state.
    num_layers : int
        Number of stacked LSTM layers.
    dropout : float
        Dropout applied between stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers and warns when it
        # is non-zero for a single layer, where it has no effect anyway.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, seq_len, input_size) tensor to (batch, 1) probabilities."""
        out, _ = self.lstm(x)
        # Classify from the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return self.sigmoid(out)

In [None]:
#Define Optuna objective
def objective(trial):
    """Train a short-lived LSTMModel for one Optuna trial and score it.

    Reads the notebook-level ``X_train``, ``y_train``, ``X_val`` and ``y_val``
    tensors.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample hidden size, layer count, dropout, learning rate
        and batch size.

    Returns
    -------
    float
        ``1 - AUC`` on the validation split (lower is better), or ``1.0`` when
        the validation labels contain a single class and AUC is undefined.
    """
    hidden_size = trial.suggest_int("hidden_size", 32, 128)
    num_layers  = trial.suggest_int("num_layers", 1, 3)
    dropout     = trial.suggest_float("dropout", 0.0, 0.4)
    lr          = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size  = trial.suggest_categorical("batch_size", [16, 32, 64])
    epochs      = 7 # prev 7  # short tuning run

    model = LSTMModel(
        input_size=X_train.shape[2],
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout
    )
    # LSTMModel outputs probabilities, so plain BCE is the matching loss
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

    # Training
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()

    # Validation
    model.eval()
    preds, actuals = [], []
    with torch.no_grad():
        for xb, yb in val_loader:
            # reshape(-1) keeps a 1-D array even for a final batch of one
            # sample; squeeze() would give a 0-d array that np.concatenate
            # rejects.
            preds.append(model(xb).reshape(-1).cpu().numpy())
            actuals.append(yb.reshape(-1).cpu().numpy())

    preds = np.concatenate(preds)
    actuals = np.concatenate(actuals)

    if len(np.unique(actuals)) < 2:
        return 1.0  # meaningless trial

    auc = roc_auc_score(actuals, preds)
    return 1 - auc  # minimize (1 - AUC)

#Run Optuna study
# TPE is Optuna's default sampler; seeding it makes the sampled
# hyperparameters repeatable from run to run.
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(objective, n_trials=7) # prev 7

# The objective returns 1 - AUC, so the best AUC is 1 - best_value
print("\n✅ Best Hyperparameters:")
print(study.best_params)
print(f"Best Validation AUC: {1 - study.best_value:.4f}")

## Optuna Visualizations

In [None]:
#Contour plot: lr vs dropout
# Objective value over the sampled (lr, dropout) pairs
contour_params = ["lr", "dropout"]
fig = plot_contour(study, params=contour_params)
plt.title("Contour Plot: lr vs dropout")
plt.tight_layout()
plt.show()

In [None]:
#Parameter importances
# Importance of each tuned hyperparameter, as estimated by Optuna
fig = plot_param_importances(study=study)
fig.show()

## Model Training with Best Hyperparameters

In [None]:
#Use best params from Optuna (update these with your results)
best_params = study.best_params

#Compute pos_weight for class balance
# pos / neg are the class counts of the training labels; they are computed
# here but the weight below is fixed at 1.0.
pos = int(y_train.sum().item())
neg = int((1 - y_train).sum().item())
safe_pos_weight = 1.0  # mild weight
pos_weight_tensor = torch.full((1,), safe_pos_weight, dtype=torch.float32)

print("safe_pos_weight =", safe_pos_weight)

In [None]:
#Define final model with best hyperparameters
model = LSTMModel(
    input_size=X_tensor.shape[2],
    hidden_size=best_params['hidden_size'],
    num_layers=best_params['num_layers'],
    dropout=best_params['dropout']
)


def criterion(probs, targets):
    """Binary cross-entropy on probabilities, with positives weighted by pos_weight_tensor.

    LSTMModel.forward already applies a sigmoid, so its output must not go
    through BCEWithLogitsLoss, which would apply a second one. For 0/1 targets
    this matches BCEWithLogitsLoss(pos_weight=...) evaluated on the underlying
    logits.
    """
    sample_weight = torch.where(
        targets > 0.5,
        pos_weight_tensor.expand_as(targets),
        torch.ones_like(targets),
    )
    return nn.functional.binary_cross_entropy(probs, targets, weight=sample_weight)


optimizer = optim.Adam(model.parameters(), lr=best_params['lr'])

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=best_params['batch_size'], shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val, y_val), batch_size=best_params['batch_size'], shuffle=False)

In [None]:
#Train with early stopping
best_val_loss = float("inf")
best_state = None
patience = 3
wait = 0
epochs = 15 # prev 15

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for xb, yb in train_loader:
        optimizer.zero_grad()

        # LSTMModel.forward ends with a sigmoid, so these are probabilities;
        # non-finite values are mapped back into [0, 1].
        outputs = model(xb)
        outputs = torch.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)

        loss = criterion(outputs, yb)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    # Validation loss
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            outputs = model(xb)
            outputs = torch.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)
            val_loss += criterion(outputs, yb).item()

    print(f"Epoch {epoch+1}/{epochs} | Train Loss={total_loss/len(train_loader):.4f} | Val Loss={val_loss/len(val_loader):.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict().copy() is shallow: its tensors are the live parameters
        # and keep changing as training continues. Clone them so the snapshot
        # really is the best epoch.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping triggered.")
            break

# best_state stays None when no epoch produced a comparable (non-NaN) loss
if best_state is not None:
    model.load_state_dict(best_state)
    print("Loaded best model state.")
else:
    print("No improving epoch recorded; keeping the final model weights.")

## Metrics and Evaluation

In [None]:
#Get predictions safely
def get_probs(loader, net=None):
    """Collect predicted probabilities and true labels over a loader.

    Parameters
    ----------
    loader : iterable of (inputs, targets)
        Batches to score, e.g. a DataLoader.
    net : nn.Module, optional
        Model to evaluate; defaults to the notebook-level ``model``. Its
        output is used as-is, so it must already be a probability (as
        LSTMModel.forward returns).

    Returns
    -------
    (np.ndarray, np.ndarray)
        1-D arrays of probabilities and labels, in loader order.
    """
    net = model if net is None else net
    net.eval()
    preds, actuals = [], []
    with torch.no_grad():
        for xb, yb in loader:
            # No extra sigmoid: a second one would squash every value into
            # [0.5, 0.73] and make a 0.5 threshold predict "wet" for all rows.
            probs = torch.nan_to_num(net(xb), nan=0.0)

            # reshape(-1) stays 1-D even for a batch holding a single sample
            preds.append(probs.reshape(-1).cpu().numpy())
            actuals.append(yb.reshape(-1).cpu().numpy())

    return np.concatenate(preds), np.concatenate(actuals)

# Score both splits with the trained model
train_probs, train_true = get_probs(train_loader)
val_probs, val_true = get_probs(val_loader)

#Sanitize all values
# Probabilities: NaN -> 0, +inf -> 1, -inf -> 0
prob_cleanup = dict(nan=0.0, posinf=1.0, neginf=0.0)
train_probs = np.nan_to_num(train_probs, **prob_cleanup)
val_probs   = np.nan_to_num(val_probs, **prob_cleanup)

# Labels: NaN -> 0
train_true = np.nan_to_num(train_true, nan=0.0)
val_true   = np.nan_to_num(val_true, nan=0.0)

In [None]:
#Apply threshold
# Probabilities at or above the threshold are classed as 1
threshold = 0.5
train_pred, val_pred = (
    (split_probs >= threshold).astype(int) for split_probs in (train_probs, val_probs)
)

#Compute metrics
def compute_metrics(y_true, y_prob, y_pred):
    """Compute the binary classification metrics reported in this notebook.

    Parameters
    ----------
    y_true : array-like
        True 0/1 labels.
    y_prob : array-like
        Predicted probabilities of class 1, used for AUC.
    y_pred : array-like
        Thresholded 0/1 predictions.

    Returns
    -------
    tuple
        ``(confusion_matrix, accuracy, precision, recall, f1, auc)``; ``auc``
        is NaN when ``y_true`` contains a single class.
    """
    # Fix the label set so the matrix is always 2x2, even when a class is
    # missing from both the labels and the predictions.
    cm   = confusion_matrix(y_true, y_pred, labels=[0, 1])
    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec  = recall_score(y_true, y_pred, zero_division=0)
    # Same zero_division handling as precision/recall
    f1   = f1_score(y_true, y_pred, zero_division=0)
    auc  = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float('nan')
    return cm, acc, prec, rec, f1, auc

# Metrics for both splits
train_cm, train_acc, train_prec, train_rec, train_f1, train_auc = compute_metrics(train_true, train_probs, train_pred)
val_cm, val_acc, val_prec, val_rec, val_f1, val_auc = compute_metrics(val_true, val_probs, val_pred)

# Same report layout for each split
metric_reports = (
    ("TRAIN", train_cm, train_acc, train_prec, train_rec, train_f1, train_auc),
    ("VAL", val_cm, val_acc, val_prec, val_rec, val_f1, val_auc),
)
for split_name, split_cm, split_acc, split_prec, split_rec, split_f1, split_auc in metric_reports:
    print(f"\n================ {split_name} METRICS ================")
    print(f"Accuracy:  {split_acc:.4f}")
    print(f"Precision: {split_prec:.4f}")
    print(f"Recall:    {split_rec:.4f}")
    print(f"F1 Score:  {split_f1:.4f}")
    print(f"AUC:       {split_auc:.4f}")
    print("Confusion Matrix:\n", split_cm)

In [None]:
#Plot confusion matrices
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# One heatmap per split, side by side
cm_panels = (
    (ax[0], train_cm, "Train Confusion Matrix"),
    (ax[1], val_cm, "Validation Confusion Matrix"),
)
for panel_ax, panel_cm, panel_title in cm_panels:
    sns.heatmap(panel_cm, annot=True, fmt='d', cmap='Blues', ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Predicted')
    panel_ax.set_ylabel('Actual')

plt.tight_layout()
plt.show()

## Feature Analysis

In [None]:
#Correlation heatmap
# Pairwise correlations over the scaled numeric columns, target included
numeric_df = df_scaled_with_label.select_dtypes(include=[np.number])
corr_matrix = numeric_df.corr()

plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, cmap='coolwarm', center=0, annot=False)
plt.title("Correlation Heatmap of All Numeric Features", fontsize=16)
plt.show()

In [None]:
#Correlation with target
# Take the target's column of the correlation matrix, strongest positive first
corr_target = numeric_df.corr().loc[:, target_col].sort_values(ascending=False)

plt.figure(figsize=(6, 10))
sns.barplot(x=corr_target.values, y=corr_target.index, palette="viridis")
plt.title("Correlation of Features with Wet/Dry Target")
plt.xlabel("Correlation")
plt.ylabel("Feature")
plt.tight_layout()
plt.show()

In [None]:
#Permutation importance
def perm_importance(model, X_val, y_val, n_repeats=5):
    """Estimate per-feature importance as the F1 drop after shuffling that feature.

    For each feature index, that feature's values are shuffled across samples
    (whole sequences are swapped, time order inside a sequence is kept) and
    the F1 score is recomputed; the mean drop over ``n_repeats`` shuffles is
    the importance.

    Parameters
    ----------
    model : nn.Module
        Trained model whose output is already a probability (as
        LSTMModel.forward returns).
    X_val : torch.Tensor
        Inputs, indexed as (samples, time steps, features).
    y_val : array-like
        True 0/1 labels, one per sample.
    n_repeats : int
        Number of shuffles per feature.

    Returns
    -------
    np.ndarray
        One importance value per feature index.
    """
    labels = np.asarray(y_val).reshape(-1).astype(int)

    def _f1(inputs):
        # eval() disables dropout; no_grad() avoids building a graph.
        # The model output is used directly: another sigmoid would push every
        # value above 0.5.
        model.eval()
        with torch.no_grad():
            probs = model(inputs).reshape(-1).cpu().numpy()
        return f1_score(labels, (probs > 0.5).astype(int), zero_division=0)

    base_f1 = _f1(X_val)

    importances = []
    for col in range(X_val.shape[2]):
        f1_scores = []
        for _ in range(n_repeats):
            X_permuted = X_val.clone()
            X_permuted[:, :, col] = X_val[:, :, col][torch.randperm(X_val.shape[0])]
            f1_scores.append(_f1(X_permuted))
        importances.append(base_f1 - np.mean(f1_scores))

    return np.array(importances)

# Importance of every feature index on the validation split
imps = perm_importance(model, X_val, val_true)
feature_positions = np.arange(len(imps))

plt.figure(figsize=(8, 6))
sns.barplot(x=feature_positions, y=imps)
plt.title("Permutation Importance per Feature Index")
plt.xlabel("Feature Index")
plt.ylabel("Importance (F1 Drop)")
plt.show()

## Additional Visualizations

In [None]:
#Discharge lagged features
df_vis = site_df.copy().sort_values("Date")

# Add one shifted copy of the discharge column per lag
lag_days = (1, 7, 30)
for lag in lag_days:
    df_vis[f"discharge_lag{lag}"] = df_vis["Discharge_CMS"].shift(lag)

# One scatter panel per lag: current discharge against its lagged value
plt.figure(figsize=(14, 4))
for panel, lag in enumerate(lag_days, start=1):
    plt.subplot(1, 3, panel)
    sns.scatterplot(x=df_vis[f"discharge_lag{lag}"], y=df_vis["Discharge_CMS"], s=10)
    plt.title(f"Discharge vs Lag-{lag}")
    plt.xlabel(f"Lag {lag}")
    if panel == 1:
        # Only the first panel gets an explicit y-label
        plt.ylabel("Current")

plt.tight_layout()
plt.show()

In [None]:
#Rolling wet/dry trend
# Mean of the 0/1 target over a 100-row rolling window
rolling_window = df_vis[target_col].rolling(100)
df_vis["rolling_wet"] = rolling_window.mean()

plt.figure(figsize=(12, 4))
plt.plot(df_vis["Date"], df_vis["rolling_wet"])
plt.title("Rolling % of Wet Days (Window=100)")
plt.xlabel("Date")
plt.ylabel("Pct Wet")
plt.grid(True)
plt.show()

## Inference

In [None]:
def predict_site_date_lstm(model, scaler, site_df, seq_len, site_id, date, target_col="wetdry_final"):
    """
    Predict wet/dry status for a given site and date using LSTM.

    The prediction uses the ``seq_len`` rows immediately before the requested
    date, matching how the training windows were built.

    Parameters:
    -----------
    model : LSTMModel
        Trained LSTM model; its output is used directly as P(wet) because
        LSTMModel.forward already applies a sigmoid
    scaler : StandardScaler
        Fitted scaler for feature normalization
    site_df : pd.DataFrame
        Dataset with "NHDPlusID", "Date" and all feature columns
    seq_len : int
        Sequence length used in training (e.g., 30)
    site_id : str
        Site identifier (e.g., "55000900061097"); compared as a string
    date : str
        Date in format "YYYY-MM-DD"
    target_col : str
        Target column name

    Returns:
    --------
    str : Prediction result with probability, or a message explaining why no
        prediction could be made (unknown site, unknown date, too little history)
    """
    # Compare at day resolution as Timestamps: this works whether the Date
    # column holds datetime64 values, date objects or date strings.
    target_day = pd.to_datetime(date).normalize()
    date = target_day.date()
    site_id = str(site_id)

    # Get site data (IDs compared as strings so integer ID columns also match)
    site_data = site_df[site_df["NHDPlusID"].astype(str) == site_id].sort_values("Date")

    if site_data.empty:
        return f"No data found for Site {site_id}"

    # Find the positional index of the requested date
    day_matches = (pd.to_datetime(site_data["Date"]).dt.normalize() == target_day).to_numpy()

    if not day_matches.any():
        return f"No data found for Site {site_id} on {date}"

    date_pos = int(np.flatnonzero(day_matches)[0])

    # Check if we have enough history
    if date_pos < seq_len:
        return f"Not enough historical data (need {seq_len} days) for {date}"

    # Prepare features: fill gaps as in training (ffill, bfill, then 0)
    # so the model never receives NaN inputs.
    numeric = site_data.select_dtypes(include=[np.number]).drop(columns=[target_col], errors='ignore')
    numeric = numeric.ffill().bfill().fillna(0)

    # Extract the seq_len rows that precede the requested date
    df_numeric = numeric.iloc[date_pos - seq_len:date_pos]
    seq_scaled = scaler.transform(df_numeric)

    # Convert to tensor
    seq_tensor = torch.tensor(np.asarray(seq_scaled), dtype=torch.float32).unsqueeze(0)  # (1, seq_len, features)

    # Predict
    model.eval()
    with torch.no_grad():
        # No extra sigmoid: a second one would map every output above 0.5
        prob = model(seq_tensor).item()

    pred_class = 1 if prob >= 0.5 else 0

    return f"Site {site_id} on {date}: {'DRY' if pred_class == 0 else 'WET'}, (P(wet)={prob:.4f})"

In [None]:
#Example inference
# Alternative request, kept for reference:
#     site_id=best_site, date="2020-10-22"

# Arguments for one prediction; the call result is the cell's output
example_request = dict(
    model=model,
    scaler=scaler,
    site_df=merged_df,
    seq_len=30,
    site_id="55000900272714",
    date="2020-10-18",
    target_col="wetdry_final",
)
predict_site_date_lstm(**example_request)

In [None]:
# Scratch cell: repeats the site/date lookup from predict_site_date_lstm at
# notebook level. It rebinds the notebook-level site_id, date and site_df.
site_id = str("55000900272714")
# date="2020-10-18"
date = pd.to_datetime("2020-06-24").date()

site_df = merged_df

# Get site data
site_rows = site_df["NHDPlusID"] == site_id
site_data = site_df[site_rows].sort_values("Date")

# Find the date index
date_rows = site_data["Date"] == date
date_idx = site_data[date_rows].index

print(date)

print(site_data["Date"].iloc[0])

# date_idx



In [None]:
# Display the site ID chosen earlier by the site-selection cell
best_site

In [None]:
# Display the merged observation/static-variable frame
merged_df