# Zero-Shot Transfer Learning with a Pre-Trained RNN via Time-Warp

An RNN trained on 10h fuel moisture data at hourly resolution. The time-warp technique shifts the LSTM forget- and input-gate biases to generate 1h, 100h, and 1000h predictions with no retraining.

The bias shifts applied to the forget and input gates have magnitude 1 for the 1h and 100h classes and 2 for the 1000h class. These values produce a clear separation between the prediction curves. They are chosen only to demonstrate the time-scaling and are not intended as the most accurate fit of the scaling factors.

## Setup

In [None]:
import numpy as np
import tensorflow as tf
import pandas as pd
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from src.models import moisture_rnn as mrnn
from src.utils import read_yml, time_intp, plot_styles

In [None]:
# Feature scaler saved alongside the trained model
scaler = joblib.load("models/scaler.joblib")

# Rebuild the trained RNN's architecture from its saved hyperparameters
params = read_yml("models/params.yaml")
rnn = mrnn.RNN_Flexible(params=params)

In [None]:
# Load the trained parameters into the architecture built above
weights_path = 'models/rnn.keras'
rnn.load_weights(weights_path)

In [None]:
# Extract Info from RNN
lstm = rnn.get_layer("lstm")
lstm_units = lstm.units
weights10 = lstm.get_weights()

## Read Data

In [None]:
# Weather inputs, then fuel moisture observations for each fuel class
weather = pd.read_excel("data/processed_data/dvdk_weather.xlsx")
fm1, fm10, fm100, fm1000 = (
    pd.read_excel(path)
    for path in (
        "data/processed_data/ok_1h.xlsx",
        "data/processed_data/ok_10h.xlsx",
        "data/processed_data/ok_100h.xlsx",
        "data/processed_data/ok_1000h.xlsx",
    )
)

## Get Time Period

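Use the span of the 10h observations (the fuel class the RNN was trained on) as the analysis window, and report how many observations of the fuel classes fall inside it.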

In [None]:
# Analysis window: the full span of the 10h observations (the class the RNN
# was trained on). Every fuel class is summarized over this same window.
t0, t1 = fm10.utc_rounded.min(), fm10.utc_rounded.max()

print(f"Start Date: {t0}")
print(f"End Date: {t1}")

# Report the number of observations (rows) per fuel class inside [t0, t1];
# Series.between is inclusive on both ends, matching >= t0 and <= t1.
for label, fm in (("1h", fm1), ("10h", fm10), ("100h", fm100), ("1000h", fm1000)):
    n_obs = int(fm.utc_rounded.between(t0, t1).sum())
    print("~"*50)
    print(f"{label} FMC Observations")
    print(f"   N Observations: {n_obs}")


## Setup Data

The input (weather) data are the same for all fuel classes.

In [None]:
# Restrict the weather record to the analysis window [t0, t1]
in_window = (weather.utc >= t0) & (weather.utc <= t1)
w2 = weather[in_window]

# Model inputs, one row per weather time step. elev/lon/lat are constants
# for the Slapout station and are broadcast along w2's index.
# NOTE(review): column order presumably must match the feature order the
# scaler was fitted on -- keep it unchanged.
X = pd.DataFrame({
    "Ed": w2["Ed"],
    "Ew": w2["Ew"],
    "solar": w2["solar"],
    "wind": w2["wind"],
    "elev": 774,
    "lon": -100.261920,
    "lat": 36.597490,
    "rain": w2["rain"],
    "hod": w2["hod_utc"],
    "doy": w2["doy_utc"],
})

In [None]:
# Get scaler from RNN Data
XX = scaler.transform(X)
XX = XX.reshape(1, *XX.shape)

## RNN Predictions and Warp

### 10h Predictions

Unmodified pre-trained weights. These were extracted above as `weights10` and serve as the baseline for every warp below.

In [None]:
# Baseline 10h forecast with the unmodified trained weights, as a 1-D array
preds10 = np.ravel(rnn.predict(XX))

### 1h Predictions

Forget gate bias DOWN, Input gate bias UP (speed-up: the cell state forgets faster and takes in new inputs more readily).

In [None]:
# Speed-up warp for the 1h class, built on a copy of the baseline weights so
# weights10 stays untouched. The bias vector is weights1[2]; its first
# lstm_units entries belong to the input gate and the next lstm_units to the
# forget gate. Editing bias1 in place edits weights1[2] directly.
weights1 = [arr.copy() for arr in weights10]
bias1 = weights1[2]

bias1[:lstm_units] += 1                  # input gate UP: admit new inputs faster
bias1[lstm_units:2 * lstm_units] -= 1    # forget gate DOWN: retain less memory

# Load the warped weights into the shared LSTM layer; the next cell predicts with them.
lstm.set_weights(weights1)

In [None]:
# 1h forecast; expects the speed-up warp cell above to have just loaded weights1
preds1 = np.ravel(rnn.predict(XX))

### 100h Predictions

Forget gate bias UP, Input gate bias DOWN (slow-down: the cell state retains more memory and responds more slowly to new inputs).

In [None]:
# Modify biases for slow-down 
weights100 = [w.copy() for w in weights10]
b = weights100[2]

# Input gate biases (i)
b[0:lstm_units] -= 1

# Forget gate biases (f)
b[lstm_units:2*lstm_units] += 1

# Update the bias in the weights list
weights100[2] = b

# Now set these weights into the same layer in rnn2
rnn.get_layer("lstm").set_weights(weights100)

In [None]:
# 100h forecast; expects the slow-down warp cell above to have just loaded weights100
preds100 = np.ravel(rnn.predict(XX))

### 1000h Predictions

Forget gate bias UP by 2, Input gate bias DOWN by 2 (a stronger slow-down than the 100h warp).

In [None]:
# Modify biases for slow-down 
weights1000 = [w.copy() for w in weights10]
b = weights1000[2]

# Input gate biases (i)
b[0:lstm_units] -= 2

# Forget gate biases (f)
b[lstm_units:2*lstm_units] += 2

# Update the bias in the weights list
weights1000[2] = b

# Now set these weights into the same layer in rnn2
rnn.get_layer("lstm").set_weights(weights1000)

preds1000 = rnn.predict(XX).flatten()

### Viz

In [None]:
# document-safe defaults
FIGSIZE = (10, 6)
DPI = 300
LABEL_SIZE = 14
TICK_SIZE = 12
CBAR_LABEL_SIZE = 13

In [None]:
import os

import matplotlib.dates as mdates

# Plot (up to) the first 72 time steps so the separation between the warped
# curves is visible; clip to the available length to avoid an IndexError on
# shorter series.
n_steps = min(72, len(w2), len(preds10))
inds = np.arange(0, n_steps)
dates = w2.utc.iloc[inds]

fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)

# One curve per fuel class; line styles (presumably including legend labels)
# come from plot_styles.
ax.plot(dates, preds1[inds], **plot_styles["model1"])
ax.plot(dates, preds10[inds], **plot_styles["model"])
ax.plot(dates, preds100[inds], **plot_styles["model100"])
ax.plot(dates, preds1000[inds], **plot_styles["model1000"])

# Major ticks every 6 hours with full date-time labels
ax.xaxis.set_major_locator(mdates.HourLocator(interval=6))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d %H:%M"))

ax.tick_params(axis="x", rotation=90, labelsize=TICK_SIZE)
ax.tick_params(axis="y", labelsize=TICK_SIZE)

ax.set_xlabel("Time (UTC)", fontsize=LABEL_SIZE)
ax.set_ylabel("FMC (%)", fontsize=LABEL_SIZE)
ax.set_title("Predicted FMC Classes - Zero-Shot Time-Warp", fontsize=LABEL_SIZE)

# Legend outside the axes, to the right, so it never covers the curves
ax.legend(
    loc="center left",
    bbox_to_anchor=(1.02, 0.5),
    fontsize=TICK_SIZE
)

fig.tight_layout()
# Make sure the output directory exists, then save this specific figure
os.makedirs("outputs", exist_ok=True)
fig.savefig("outputs/transfer_zeroshot_example.png", dpi=DPI, bbox_inches='tight')