# RNN Transfer Demonstration with Steady States

The goal of this notebook is to analyze time-rescaling of a pre-trained RNN when predicting constant data. When the RNN is "well behaved" and approaches an equilibrium at some characteristic rate, scaling the bias terms of the forget and input gates can change the rate at which the system equilibrates.

The RNN was pre-trained on 10h FMC sensor data, with the atmospheric inputs standard-scaled. The 100 replications of the RNN varied the train/validation split and the initial weights.

The replications are analyzed to see whether they exhibit steady-state behavior and whether the time-warp is consistent with a change in the characteristic rate.

## Setup

In [None]:
import numpy as np
import h5py
import os
import tensorflow as tf
import pandas as pd
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from src.models import moisture_rnn as mrnn
from src.utils import read_yml, time_intp, plot_styles

In [None]:
# Restore the artifacts saved with the trained model.
# NOTE(review): joblib.load unpickles its input -- only load files from a trusted source.
scaler = joblib.load("models/scaler.joblib")  # presumably the scaler fitted on the inputs; confirm

# Hyperparameters, and an RNN instance constructed from them (weights are loaded separately)
params = read_yml("models/params.yaml")
rnn = mrnn.RNN_Flexible(params=params)

In [None]:
# Load the saved weights into the network that was built from `params`
rnn.load_weights('models/rnn.keras')

In [None]:
# Station data, indexed below as ml_data[<station id>]["data"] and ml_data[<station id>]["times"].
# NOTE(review): read_pickle can execute arbitrary code -- only load files from a trusted source.
ml_data = pd.read_pickle("models/ml_data.pkl")

In [None]:
# Replication outputs, indexed below as results[seed][<"base" | "fast" | "slow">]["p0"].
# NOTE(review): read_pickle can execute arbitrary code -- only load files from a trusted source.
results = pd.read_pickle("outputs/steady_reps/results_reps.pkl")

In [None]:
# Length of the time grid used for plotting (np.arange(nsteps))
nsteps = 168
# Number of leading steps sliced off before plotting or computing tail metrics
spinup = 12

In [None]:
# Shared figure settings so every plot is sized consistently for documents
DPI = 300              # figure resolution passed to plt.subplots
LABEL_SIZE = 14        # axis-label and title font size
TICK_SIZE = 12         # tick-label and legend font size
CBAR_LABEL_SIZE = 13   # colorbar-label font size

### Constant Mean Case

### Real Observation Cases

In [None]:
# Station in Rocky Mountain Arsenal: take the last available record
st1 = "AENC2"
st1_entry = ml_data[st1]
dat1 = st1_entry["data"].iloc[-1]

print(f"Station: {st1}")
print(f"Time (UTC): {st1_entry['times'][-1]}")
# Show only the columns the model uses as features
print(dat1[params["features_list"]])

In [None]:
# Station SW of Denver
st2 = "BAWC2"
# Positional index of the record used for this case. Named once so the data row
# and the printed timestamp can never refer to different records.
st2_row = 12345
dat2 = ml_data[st2]["data"].iloc[st2_row]

print(f"Station: {st2}")
print(f"Time (UTC): {ml_data[st2]['times'][st2_row]}")
# Show only the columns the model uses as features
print(dat2[params["features_list"]])

## Constant Mean - Behavior Across Reps

In [None]:
# One row per replication: the baseline ("base") prediction for the constant-mean case ("p0")
base_mean_case = np.stack(
    [results[seed]["base"]["p0"] for seed in results],
    axis=0,
)

# Pick the replication whose final value is closest to the median final value;
# it is used as the representative curve in the plots.
col = base_mean_case[:, -1]
median_val = np.median(col)
row_idx = np.abs(col - median_val).argmin()

### Viz

In [None]:
# 10h plotting colors, since these weights are un-warped.
# Build a new dict instead of assigning into plot_styles["model"]: that dict is shared
# module state, and adding a "label" to it would leak into every later use of the style.
pstyle = {**plot_styles["model"], "label": "Median Curve"}

In [None]:
spinup = 12

# Post-spin-up predictions: axis 0 runs over replications, axis 1 over time
Y = base_mean_case[:, spinup:]
t = np.arange(Y.shape[1])

# Representative curve: the replication selected by its median final value
center_curve = Y[row_idx]

# Interquartile band across replications at each timestep
lower, upper = np.percentile(Y, [25, 75], axis=0)

fig, ax = plt.subplots(dpi=DPI)
ax.plot(t, center_curve, linewidth=3, **pstyle)
ax.fill_between(
    t,
    lower,
    upper,
    color=pstyle["color"],
    alpha=0.3,
    label="25–75 percentile (across seeds)",
)

ax.set_xlabel("Timestep (Post-Spinup)", fontsize=LABEL_SIZE)
ax.set_ylabel("FM10 (%)", fontsize=LABEL_SIZE)
ax.tick_params(axis="both", labelsize=TICK_SIZE)
ax.legend(fontsize=TICK_SIZE)

plt.tight_layout()
plt.show()

In [None]:
# Distribution of the last-timestep value across replications
eq = base_mean_case[:, -1]
mu, sigma = eq.mean(), eq.std()

fig, ax = plt.subplots(dpi=DPI)

ax.hist(eq, bins=10, color=pstyle["color"], edgecolor="white", alpha=0.3)

# Mean marker
ax.axvline(mu, linestyle="--", linewidth=3, color=pstyle["color"], label="Mean")

# One-standard-deviation markers; only the first carries a legend entry
sd_line_style = dict(linestyle="--", linewidth=2, color=pstyle["color"], alpha=0.7)
for sd_sign, sd_label in ((1, r"$\pm$1 SD"), (-1, None)):
    ax.axvline(mu + sd_sign * sigma, label=sd_label, **sd_line_style)

ax.set_xlabel("Final FM10 Value (%)", fontsize=LABEL_SIZE)
ax.set_ylabel("Number of Replications", fontsize=LABEL_SIZE)
ax.tick_params(axis="both", labelsize=TICK_SIZE)
ax.legend(fontsize=TICK_SIZE)

plt.tight_layout()
plt.show()

In [None]:
# Tail-end behavior metrics: how much each replication still moves over its last K steps
K = 24
spinup = 12

Y = base_mean_case[:, spinup:]
tail = Y[:, -K:]

# Endpoint slope over the tail window (computed for reference; not plotted below)
slope_end = (tail[:, -1] - tail[:, 0]) / (K - 1)

# Step-to-step increments within the tail: their mean is the drift, their std the "wiggle"
dy = np.diff(tail, axis=1)
slope_mean = dy.mean(axis=1)
wiggle = dy.std(axis=1)

# Side-by-side histograms, one panel per metric
fig, axes = plt.subplots(1, 2, figsize=(12, 4), dpi=DPI)
color = pstyle["color"]

panels = (
    (
        slope_mean,
        f"Tail Mean Slope Across Seeds (Last {K} Steps)",
        "Mean ΔFMC per Timestep\n(percentage points per timestep)",
    ),
    (
        wiggle,
        f"Tail Increment Variability Across Seeds (Last {K} Steps)",
        "Std of ΔFMC per Timestep\n(percentage points per timestep)",
    ),
)

for panel_ax, (panel_values, panel_title, panel_xlabel) in zip(axes, panels):
    panel_ax.hist(
        panel_values,
        bins="auto",
        color=color,
        alpha=0.4,
        edgecolor="white",
        linewidth=1.2,
    )
    # Dashed line marks the across-replication mean of the metric
    panel_ax.axvline(np.mean(panel_values), linestyle="--", linewidth=2, color=color)
    panel_ax.set_title(panel_title, fontsize=LABEL_SIZE)
    panel_ax.set_xlabel(panel_xlabel, fontsize=LABEL_SIZE)
    panel_ax.set_ylabel("Number of Replications", fontsize=LABEL_SIZE)
    panel_ax.tick_params(axis="both", labelsize=TICK_SIZE)

plt.tight_layout()
plt.show()

## Constant Mean - Time Warped Cases

In [None]:
# Same stacking as the baseline case, for the two time-warped variants:
# one row per replication, taken from the "p0" (constant-mean) prediction.
fast_mean_case, slow_mean_case = (
    np.stack([results[seed][variant]["p0"] for seed in results], axis=0)
    for variant in ("fast", "slow")
)

In [None]:
import matplotlib
# Imported in this cell so the palette builds on a fresh kernel: mcolors was otherwise
# first imported further down the notebook, so Restart & Run All raised NameError here.
import matplotlib.colors as mcolors

# Three evenly spaced viridis samples as hex strings; they colour the
# Fast / Baseline / Slow curves in the comparison plots.
cmap = matplotlib.colormaps.get_cmap("viridis")
n = 3
colors = [mcolors.to_hex(cmap(i)) for i in np.linspace(0.2, 0.9, n)]

In [None]:
# Compare the representative replication (row_idx) across the three time scalings
xgrid = np.arange(nsteps)
x_post = xgrid[spinup:]

fig, ax = plt.subplots(dpi=DPI)
curve_specs = zip(
    (fast_mean_case, base_mean_case, slow_mean_case),
    colors,
    ("Fast", "Baseline", "Slow"),
)
for case_values, case_color, case_label in curve_specs:
    ax.plot(x_post, case_values[row_idx, spinup:], c=case_color, label=case_label)

ax.set_xlabel("Timestep (Post-Spinup)", fontsize=LABEL_SIZE)
ax.set_ylabel("FM10 (%)", fontsize=LABEL_SIZE)
ax.tick_params(axis="both", labelsize=TICK_SIZE)
ax.legend(fontsize=TICK_SIZE)

plt.tight_layout()
plt.show()

## Sine Wave Case

## Viz

In [None]:
import matplotlib.colors as mcolors

# Base colour of the "fm" plot style, as an RGB vector
c0 = plot_styles["fm"]["color"]  # '#468a29'
rgb0 = np.array(mcolors.to_rgb(c0))

# Lighter variant: move 35% of the way towards white. Darker variant: scale towards black.
light = rgb0 + 0.35 * (1.0 - rgb0)
dark = rgb0 * 0.65

# Light -> base -> dark, each clamped to [0, 1] before conversion to hex
colors = [mcolors.to_hex(np.clip(shade, 0, 1)) for shade in (light, rgb0, dark)]

In [None]:
import matplotlib

# Rebuild `colors` as three evenly spaced viridis samples (hex strings).
# This overwrites whatever `colors` held before this cell ran.
n = 3
cmap = matplotlib.colormaps.get_cmap("viridis")
colors = [mcolors.to_hex(cmap(frac)) for frac in np.linspace(0.2, 0.9, n)]

In [None]:
# Figure: baseline vs. time-warped predictions for the constant artificial input.
# DPI, LABEL_SIZE and TICK_SIZE come from the plot-defaults cell in Setup.
# NOTE(review): x0, p0, p0_fast and p0_slow are not defined in any visible cell of this
# notebook -- confirm where they are produced, otherwise this cell fails on a fresh kernel.
xgrid = np.arange(nsteps)
x_post = xgrid[spinup:]

fig, ax = plt.subplots(dpi=DPI)

# Columns 0 and 1 of x0 are drawn with the "Ed" and "Ew" styles respectively
ax.plot(x_post, x0[spinup:, 0], **plot_styles["Ed"])
ax.plot(x_post, x0[spinup:, 1], **plot_styles["Ew"])

ax.plot(x_post, p0_fast[spinup:], c=colors[0], label="Fast")
ax.plot(x_post, p0[spinup:], c=colors[1], label="Baseline")
ax.plot(x_post, p0_slow[spinup:], c=colors[2], label="Slow")

ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
# Legend is anchored just outside the right edge of the axes
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),
    fontsize=TICK_SIZE,
)
ax.grid()
ax.set_title("Constant Artificial Input - Constant Mean")

# Explicit show, as in the earlier figure cells; also keeps the Text repr out of the output
plt.show()

In [None]:
# Figure: baseline vs. time-warped predictions when the input is held constant at the
# station-1 record (dat1). DPI, LABEL_SIZE and TICK_SIZE come from the plot-defaults cell.
# NOTE(review): p1, p1_fast and p1_slow are not defined in any visible cell of this
# notebook -- confirm where they are produced.
xgrid = np.arange(nsteps)
x_post = xgrid[spinup:]

fig, ax = plt.subplots(dpi=DPI)

# Dashed variant of the "fm" style, built as a copy so the shared plot_styles dict
# is not modified for other users of the style.
fm_style = {**plot_styles["fm"], "linestyle": "dashed"}

# Horizontal reference line at the record's fm value
ax.plot(x_post, np.repeat(dat1.fm, len(x_post)), **fm_style)
ax.plot(x_post, p1_fast[spinup:], c=colors[0], label="Fast")
ax.plot(x_post, p1[spinup:], c=colors[1], label="Baseline")
ax.plot(x_post, p1_slow[spinup:], c=colors[2], label="Slow")

ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
# Legend is anchored just outside the right edge of the axes
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),
    fontsize=TICK_SIZE,
)
ax.grid()
ax.set_title(f"Constant Input from {st1} at {dat1.date_time.strftime('%Y-%m-%d %H:%M:%S')}")

# Explicit show, as in the earlier figure cells; also keeps the Text repr out of the output
plt.show()

In [None]:
# Figure: baseline vs. time-warped predictions when the input is held constant at the
# station-2 record (dat2). DPI, LABEL_SIZE and TICK_SIZE come from the plot-defaults cell.
# NOTE(review): p2, p2_fast and p2_slow are not defined in any visible cell of this
# notebook -- confirm where they are produced.
xgrid = np.arange(nsteps)
x_post = xgrid[spinup:]

fig, ax = plt.subplots(dpi=DPI)

# Dashed variant of the "fm" style, built as a copy so the shared plot_styles dict
# is not modified for other users of the style.
fm_style = {**plot_styles["fm"], "linestyle": "dashed"}

# Horizontal reference line at the record's fm value
ax.plot(x_post, np.repeat(dat2.fm, len(x_post)), **fm_style)
ax.plot(x_post, p2_fast[spinup:], c=colors[0], label="Fast")
ax.plot(x_post, p2[spinup:], c=colors[1], label="Baseline")
ax.plot(x_post, p2_slow[spinup:], c=colors[2], label="Slow")

ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
# Legend is anchored just outside the right edge of the axes
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),
    fontsize=TICK_SIZE,
)
ax.grid()
ax.set_title(f"Constant Input from {st2} at {dat2.date_time.strftime('%Y-%m-%d %H:%M:%S')}")

# Explicit show, as in the earlier figure cells; also keeps the Text repr out of the output
plt.show()