# RNN Transfer Demonstration with Steady States

The goal of this notebook is to analyze time-rescaling of a pre-trained RNN when it is given constant (steady-state) inputs. When the RNN is "well behaved" and approaches an equilibrium at some characteristic rate, shifting the bias terms of the LSTM forget and input gates can change the rate at which the system equilibrates.

The RNN was pre-trained on data from 10-h fuel moisture content (FMC) sensors, with the atmospheric inputs standard-scaled.

## Setup

In [1]:
import numpy as np
import h5py
import os
import tensorflow as tf
import pandas as pd
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from src.models import moisture_rnn as mrnn
from src.utils import read_yml, time_intp, plot_styles

In [2]:
# Read trained model: build the RNN from the saved hyperparameters (its
# weights are loaded in the next cell) and load the fitted input scaler.
params = read_yml("models/params.yaml")
rnn = mrnn.RNN_Flexible(params=params)
# NOTE(review): joblib.load unpickles arbitrary objects (see the scikit-learn
# persistence warning printed below), so only load files from a trusted source.
scaler = joblib.load("models/scaler.joblib")

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [3]:
# Load the pre-trained weights into the RNN built from params above.
# NOTE(review): the warning printed below is raised by Keras while restoring
# variables; confirm it does not affect the LSTM/dense weights used here.
rnn.load_weights('models/rnn.keras')

  saveable.load_own_variables(weights_store.get(inner_path))


In [4]:
# Grab the LSTM layer of the pre-trained (10h) RNN, its width, and a snapshot
# of its parameters. For a Keras LSTM, get_weights() returns NumPy arrays
# [kernel, recurrent_kernel, bias].
# NOTE(review): lstm_units / weights10 are not used later in this notebook;
# the baseline snapshot actually used is taken again further down.
lstm = rnn.get_layer("lstm")
lstm_units, weights10 = lstm.units, lstm.get_weights()

In [5]:
# Per-station data keyed by station ID; each entry provides "data" (a frame
# with one row per time, indexed positionally below) and "times".
# NOTE(review): read_pickle unpickles arbitrary objects; only load trusted files.
ml_data = pd.read_pickle("models/ml_data.pkl")

## Construct Predictors

In [6]:
# Model input feature order: column i of every input array built below
# corresponds to features_list[i] (presumably also the column order the
# scaler was fitted with; verify against the training pipeline).
print(params["features_list"])

['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain', 'hod', 'doy']


### Constant Mean Case

In [None]:
# Sequence length: a 12-hour spin-up (not plotted) followed by a 48-hour
# (2-day) window that is shown in the figures.
spinup = 12
psteps = 2 * 24
nsteps = psteps + spinup

# Every feature held at its training-set mean, so the *scaled* input is all
# zeros at every time step. Shape: (nsteps, n_features).
x0 = np.tile(scaler.mean_, (nsteps, 1))

### Sine Wave Case based on Constant Mean and Var

In [None]:
# Training-set mean and variance of each unscaled feature.
mu, var = scaler.mean_, scaler.var_

# One full sine period spread over the whole sequence, spin-up included.
t = np.linspace(0.0, 2.0 * np.pi, num=nsteps, endpoint=False)

# Amplitude for the first two features (Ed, Ew), chosen so the sine has the
# training variance: the mean of A**2 * sin(t)**2 over a full period is A**2 / 2.
A = np.sqrt(2 * var[:2])

# Constant-mean sequence with only columns 0 and 1 oscillating about their mean.
xs = np.tile(mu, (nsteps, 1))
xs[:, :2] = mu[:2] + np.outer(np.sin(t), A)

### Real Observation Cases

In [None]:
# Station in Rocky Mountain Arsenal; use its most recent observation.
st1 = "AENC2"
dat1 = ml_data[st1]["data"].iloc[-1]
feat1 = dat1[params["features_list"]]
# Hold that single observation constant for all nsteps hours.
x1 = np.tile(feat1.to_numpy(dtype=np.float32), (nsteps, 1))

print(f"Station: {st1}")
print(f"Time (UTC): {ml_data[st1]['times'][-1]}")
print(feat1)

In [None]:
# Station SW of Denver; use a single fixed historical observation.
st2 = "BAWC2"
# Row position of the observation. Named once so the data row and the printed
# time cannot drift apart.
# NOTE(review): document why this particular row was chosen.
idx2 = 12345
dat2 = ml_data[st2]["data"].iloc[idx2]
x2 = dat2[params["features_list"]].to_numpy(dtype=np.float32)
# Hold that single observation constant for all nsteps hours.
x2 = np.repeat(x2[None, :], nsteps, axis=0)
print(f"Station: {st2}")
print(f"Time (UTC): {ml_data[st2]['times'][idx2]}")
print(dat2[params["features_list"]])

In [None]:
# Standard-scale every raw input sequence with the training scaler, then add a
# leading batch axis: (nsteps, n_features) -> (1, nsteps, n_features).
# The .copy() keeps the raw arrays (plotted later) untouched even if the
# scaler was configured with copy=False.
X0, Xs, X1, X2 = (
    scaler.transform(raw.copy())[np.newaxis, ...]
    for raw in (x0, xs, x1, x2)
)

## Generate RNN Predictions

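The constant-mean case corresponds to an all-zeros input after scaling (every feature at its training mean). The other cases hold the inputs at different constant values (or vary Ed and Ew sinusoidally about their means), so the equilibria reached by the RNN can be compared.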

### 10h-Weight Predictions

Predictions with the unmodified pre-trained weights. The baseline LSTM weights are then copied so they can be reused and restored.

In [None]:
# Baseline predictions with the pre-trained weights: one flattened prediction
# sequence per input case (constant mean, sine, station 1, station 2).
# NOTE(review): must run while the LSTM holds its pre-trained weights,
# i.e. not after a set_weights call that left modified weights loaded.
p0, ps, p1, p2 = (rnn.predict(X).flatten() for X in (X0, Xs, X1, X2))

In [None]:
# Snapshot the baseline LSTM parameters ([kernel, recurrent_kernel, bias] for
# a Keras LSTM) as independent copies, so later edits and set_weights calls
# cannot alias them.
_lstm_layer = rnn.get_layer("lstm")
units = _lstm_layer.units
weights0 = [arr.copy() for arr in _lstm_layer.get_weights()]

### Speed up

In [None]:
# Faster equilibration. Keras lays out the LSTM bias as
# [input | forget | cell | output] gates, each `units` wide. Raising the
# input-gate bias and lowering the forget-gate bias lets new input overwrite
# the cell state more quickly.
# Built from the baseline snapshot (not the model's current weights), so the
# result is the same no matter which weights happen to be loaded.
weights_fast = [w.copy() for w in weights0]
weights_fast[2][0:units]       = weights0[2][0:units] + 0.5        # input gate
weights_fast[2][units:2*units] = weights0[2][units:2*units] - 0.5  # forget gate

### Slow Down

In [None]:
# Slower equilibration. Keras lays out the LSTM bias as
# [input | forget | cell | output] gates, each `units` wide. Lowering the
# input-gate bias and raising the forget-gate bias makes the cell state
# retain its memory longer.
# Built from the baseline snapshot (not the model's current weights), so the
# result is the same no matter which weights happen to be loaded.
weights_slow = [w.copy() for w in weights0]
weights_slow[2][0:units]       = weights0[2][0:units] - 0.5        # input gate
weights_slow[2][units:2*units] = weights0[2][units:2*units] + 0.5  # forget gate

### Predictions with modified weights

In [None]:
# Predict every input case with the sped-up and slowed-down LSTM biases.
# The baseline weights are restored in `finally` so that any later (or
# re-run) cell sees the pre-trained model, even if a prediction fails.
try:
    rnn.get_layer("lstm").set_weights(weights_fast)
    p0_fast = rnn.predict(X0).flatten()
    ps_fast = rnn.predict(Xs).flatten()
    p1_fast = rnn.predict(X1).flatten()
    p2_fast = rnn.predict(X2).flatten()

    rnn.get_layer("lstm").set_weights(weights_slow)
    p0_slow = rnn.predict(X0).flatten()
    ps_slow = rnn.predict(Xs).flatten()
    p1_slow = rnn.predict(X1).flatten()
    p2_slow = rnn.predict(X2).flatten()
finally:
    rnn.get_layer("lstm").set_weights(weights0)

## Viz

In [None]:
import matplotlib.colors as mcolors

# Light / base / dark shades of the FMC line colour from plot_styles.
# NOTE(review): `colors` is reassigned to a viridis palette in the next cell,
# so these shades are not what the plots below actually use.
c0 = plot_styles["fm"]["color"]  # '#468a29'
rgb0 = np.array(mcolors.to_rgb(c0))

light = rgb0 + 0.35 * (1.0 - rgb0)  # blend 35% of the way toward white
dark = rgb0 * 0.65                  # scale toward black

# Clamp each shade to the valid [0, 1] RGB range before hex conversion.
colors = [mcolors.to_hex(np.clip(shade, 0, 1)) for shade in (light, rgb0, dark)]

In [None]:
import matplotlib
# Imported here so this cell does not depend on the previous cell's import.
import matplotlib.colors as mcolors

# Three evenly spaced viridis colours, used for the Fast / Baseline / Slow
# curves. The 0.2-0.9 range skips the extreme ends of the colormap.
cmap = matplotlib.colormaps["viridis"]
n = 3
colors = [mcolors.to_hex(cmap(i)) for i in np.linspace(0.2, 0.9, n)]

In [None]:
# document-safe defaults
DPI = 300
LABEL_SIZE = 14
TICK_SIZE = 12
CBAR_LABEL_SIZE = 13

# Constant-mean case: raw Ed/Ew inputs alongside the fast, baseline and slow
# predictions. The spin-up hours are not shown.
xgrid = np.arange(nsteps)
hours = xgrid[spinup:]
fig, ax = plt.subplots(dpi=DPI)

ax.plot(hours, x0[spinup:, 0], **plot_styles["Ed"])
ax.plot(hours, x0[spinup:, 1], **plot_styles["Ew"])

ax.plot(hours, p0_fast[spinup:], c=colors[0], label="Fast")
ax.plot(hours, p0[spinup:],      c=colors[1], label="Baseline")
ax.plot(hours, p0_slow[spinup:], c=colors[2], label="Slow")
ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),  # legend outside the axes, on the right
    fontsize=TICK_SIZE
)
ax.grid()
# Trailing ';' suppresses the Text repr in the notebook output.
ax.set_title("Constant Artificial Input - Constant Mean", fontsize=LABEL_SIZE);

In [None]:
# document-safe defaults
DPI = 300
LABEL_SIZE = 14
TICK_SIZE = 12
CBAR_LABEL_SIZE = 13

# Station 1 case: observed FMC at the chosen time (dashed reference line)
# against the fast, baseline and slow predictions. Spin-up hours not shown.
xgrid = np.arange(nsteps)
hours = xgrid[spinup:]
fig, ax = plt.subplots(dpi=DPI)

# Per-plot copy of the FMC style. Mutating plot_styles["fm"] in place would
# leak the dashed linestyle into every later use of the shared style dict.
fm_obs_style = {**plot_styles["fm"], "linestyle": "dashed"}
ax.plot(hours, np.repeat(dat1.fm, len(hours)), **fm_obs_style)
ax.plot(hours, p1_fast[spinup:], c=colors[0], label="Fast")
ax.plot(hours, p1[spinup:],      c=colors[1], label="Baseline")
ax.plot(hours, p1_slow[spinup:], c=colors[2], label="Slow")
ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),  # legend outside the axes, on the right
    fontsize=TICK_SIZE
)
ax.grid()
# Trailing ';' suppresses the Text repr in the notebook output.
ax.set_title(
    f"Constant Input from {st1} at {dat1.date_time.strftime('%Y-%m-%d %H:%M:%S')}",
    fontsize=LABEL_SIZE,
);

In [None]:
# document-safe defaults
DPI = 300
LABEL_SIZE = 14
TICK_SIZE = 12
CBAR_LABEL_SIZE = 13

# Station 2 case: observed FMC at the chosen time (dashed reference line)
# against the fast, baseline and slow predictions. Spin-up hours not shown.
xgrid = np.arange(nsteps)
hours = xgrid[spinup:]
fig, ax = plt.subplots(dpi=DPI)

# Per-plot copy of the FMC style. Mutating plot_styles["fm"] in place would
# leak the dashed linestyle into every later use of the shared style dict.
fm_obs_style = {**plot_styles["fm"], "linestyle": "dashed"}
ax.plot(hours, np.repeat(dat2.fm, len(hours)), **fm_obs_style)
ax.plot(hours, p2_fast[spinup:], c=colors[0], label="Fast")
ax.plot(hours, p2[spinup:],      c=colors[1], label="Baseline")
ax.plot(hours, p2_slow[spinup:], c=colors[2], label="Slow")
ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_xlabel("Hour", fontsize=LABEL_SIZE)
ax.tick_params(labelsize=TICK_SIZE)
ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, .7),  # legend outside the axes, on the right
    fontsize=TICK_SIZE
)
ax.grid()
# Trailing ';' suppresses the Text repr in the notebook output.
ax.set_title(
    f"Constant Input from {st2} at {dat2.date_time.strftime('%Y-%m-%d %H:%M:%S')}",
    fontsize=LABEL_SIZE,
);