# Timelag Differential Equation Tutorial

This notebook demonstrates the physics-based fuel moisture model used with `wrfxpy`, which serves as a benchmark for the machine learning methods. The model is a timelag ODE combined with a Kalman filter that assimilates fuel moisture observations.

## Model Background

The physics-based model used within WRF-SFIRE is a timelag ODE. Data assimilation is done with an augmented Kalman filter: the model state is extended to include an equilibrium bias correction term.

* **Inputs**: wetting and drying equilibrium moisture content, hourly rainfall, and optional FMC observations for data assimilation
* **Spinup**: the model is run with data assimilation for a set number of spinup hours so that the equilibrium bias can stabilize; this is analogous to training an ML model
* **Forecast**: after the spinup hours, the model is run with no data assimilation
* **How the Model is Applied**: the ODE+KF is applied pointwise, that is, independently at each node of some set of grid nodes. In this project, the ODE+KF is run at the locations of RAWS sites, using the observed RAWS data for spinup data assimilation. NOTE: this is a "best case" scenario for the model, since in production spatially interpolated FMC is used for spinup data assimilation
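
As a rough illustration of the timelag idea, the sketch below relaxes fuel moisture toward an equilibrium with a fixed time constant. It is a simplified sketch under stated assumptions, not the `ODE_FMC` implementation: the helper name `timelag_step`, the 10-hour timelag `T`, and the exact exponential update are chosen for illustration, and rainfall and the Kalman filter are left out.

```python
import math

def timelag_step(m, Ed, Ew, dt=1.0, T=10.0):
    """Advance fuel moisture m by dt hours (hypothetical helper).

    Solves dm/dt = (E - m) / T exactly over one step, where E is the
    drying equilibrium Ed if m is above it, the wetting equilibrium Ew
    if m is below it, and m is left unchanged in between.
    T = 10 hours is an assumed timelag, not a value from the source.
    """
    if m > Ed:
        E = Ed
    elif m < Ew:
        E = Ew
    else:
        return m
    return E + (m - E) * math.exp(-dt / T)

# Fuel that starts wetter than the drying equilibrium dries toward it
m = 20.0
for hour in range(3):
    m = timelag_step(m, Ed=10.0, Ew=8.0)
    print(f"hour {hour + 1}: m = {m:.2f}")
```

Using the exact exponential solution over each step keeps the update stable for any step size, which is one reason it is a common choice for this kind of relaxation model.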

For more info, see ___

## ODE+KF in this Project

**Workflow:**
- Retrieve fmda data: gets data from API or stash, interpolates missing observations to regular hourly intervals
- Build fmda ML data: merges data sources and applies filters
- Define a cross-validation test period and test locations (RAWS station IDs, STIDs)
- Based on CV above, get needed data from built ML data

**ODE Modeling:**
* Run on 72-hour stretches (24 hours of spinup, 48 hours of validation)
* Get the test station list used by the other models
* For those test stations, use `get_sts_and_times`, accounting for the spinup period
    * That is, shift the test start times back by 24 hours to cover spinup
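
The spinup adjustment above can be sketched with plain `datetime` arithmetic. The helper `ode_window` is hypothetical and is not part of `data_funcs`; it only shows how a 48-hour forecast period maps to the 72-hour window the ODE runs on.

```python
from datetime import datetime, timedelta, timezone

def ode_window(forecast_start, spinup_hours=24, forecast_hours=48):
    """Return hourly timestamps covering spinup plus forecast (hypothetical helper)."""
    # Shift the start back so the spinup hours come before the forecast start
    start = forecast_start - timedelta(hours=spinup_hours)
    total = spinup_hours + forecast_hours
    return [start + timedelta(hours=h) for h in range(total)]

ft = datetime(2023, 6, 3, tzinfo=timezone.utc)
times = ode_window(ft)
print(len(times), times[0], times[24], times[-1])
```

With these defaults, index 24 of the returned list is the forecast start, so the last 48 entries line up with the test period used by the other models.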

## Setup

In [None]:
import os.path as osp
import json
import sys
import numpy as np
import pandas as pd
sys.path.append('../src')
from utils import Dict, read_yml, read_pkl, str2time, print_dict_summary, time_range, rename_dict
import data_funcs
from models.moisture_ode import ODE_FMC
import matplotlib.pyplot as plt
import reproducibility

## Create Data

In [None]:
ml_data = read_pkl("../data/test_data/test_ml_dat.pkl")
# ml_data = read_pkl("../outputs/report_materials/ml_data.pkl")

In [None]:
ft = str2time("2023-06-03T00:00:00Z")

In [None]:
reproducibility.set_seed(123)
train, val, test = data_funcs.cv_data_wrap(ml_data, ft, 
                                           train_hours=720,
                                           forecast_hours=48)

In [None]:
test.keys()

In [None]:
te_sts = [*test.keys()]
test_times = test[te_sts[0]]["times"]
ode_data = data_funcs.get_ode_data(ml_data, te_sts, test_times)

In [None]:
print(ode_data.keys())

## Run Model

The model class `ODE_FMC` is defined in `models/moisture_ode`. It carries the hyperparameters associated with the model, such as the fixed covariance matrices.
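
To make the role of the covariance matrices concrete, here is a minimal sketch of one step of an extended Kalman filter on an augmented state `[m, dE]`, where `dE` is an equilibrium bias correction. Everything in it is an assumption for illustration: the function name `ekf_step`, the single equilibrium `E`, the 10-hour timelag, and the values of `Q` and `R`. It is not the code inside `ODE_FMC`, whose parameters are shown in the cells that follow.

```python
import numpy as np

def ekf_step(u, P, E, d=None, T=10.0, dt=1.0, Q=None, R=1e-3):
    """One forecast step, plus an analysis step if an observation d is given.

    u = [m, dE]: fuel moisture and equilibrium bias (hypothetical layout).
    P: 2x2 state covariance. Q, R: assumed process and observation noise.
    """
    Q = np.diag([1e-3, 1e-3]) if Q is None else Q
    a = np.exp(-dt / T)
    m, dE = u
    # Forecast: relax m toward the bias-corrected equilibrium; dE persists
    u_f = np.array([(E + dE) + (m - (E + dE)) * a, dE])
    # Jacobian of the forecast map with respect to [m, dE]
    J = np.array([[a, 1.0 - a],
                  [0.0, 1.0]])
    P_f = J @ P @ J.T + Q
    if d is None:
        return u_f, P_f  # forecast mode: no data assimilation
    # Analysis: only the moisture component is observed
    H = np.array([[1.0, 0.0]])
    S = H @ P_f @ H.T + R          # innovation variance, shape (1, 1)
    K = P_f @ H.T / S              # Kalman gain, shape (2, 1)
    u_a = u_f + (K * (d - u_f[0])).ravel()
    P_a = (np.eye(2) - K @ H) @ P_f
    return u_a, P_a

u, P = np.array([10.0, 0.0]), np.eye(2)
u, P = ekf_step(u, P, E=12.0, d=15.0)   # spinup-style step with an observation
u, P = ekf_step(u, P, E=12.0)           # forecast-style step without one
print(u)
```

Because the bias term is part of the state, observations during spinup adjust `dE` through the cross-covariance in `P`, and the forecast then carries that correction forward with no further data.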

In [None]:
ode = ODE_FMC()

In [None]:
ode.params

### Run Single Case

In [None]:
u = ode.run_model_single(ode_data[te_sts[0]], hours=72, h2=24)

In [None]:
print(u.shape)

In [None]:
plt.plot(u[0,:])

In [None]:
# Print RMSE for Period
ode.eval(u[0,:], ode_data[te_sts[0]]["data"].fm.to_numpy())

## Run Whole Dictionary

In [None]:
m, errs = ode.run_model(ode_data, hours=72, h2=24)

In [None]:
# Should be shape (n_locations, forecast_hours, 1)
print(m.shape)

In [None]:
print(errs)

In [None]:
from utils import hash_ndarray
hash_ndarray(m)

### Analyze Error Over Time

As we go from forecast hour 1 to 48, does error accumulate?

In [None]:
fstart = str2time("2023-06-03T00:00:00Z")
fend = str2time("2023-06-29T23:00:00Z")

# Handle Forecast Periods
# Define Forecast start times, 48hr spacing
forecast_periods = time_range(
    start = fstart,
    end = fend,
    freq = "2d"
)

In [None]:
m = []
y_test = []
reproducibility.set_seed(123)
for ft in forecast_periods:
    train, val, test = data_funcs.cv_data_wrap(ml_data, ft, 
                                               train_hours=720,
                                               forecast_hours=48)
    te_sts = [*test.keys()]
    test_times = test[te_sts[0]]["times"]
    ode_data = data_funcs.get_ode_data(ml_data, te_sts, test_times)    
    mi, errsi = ode.run_model(ode_data, hours=72, h2=24)

    y_list = []
    for loc in ode_data:
        y = ode_data[loc]["data"]["fm"][(72-48):72]
        y = np.array(y).reshape(48, 1)  # Ensure shape is (48, 1)
        y_list.append(y)
    yi = np.stack(y_list) 
    
    m.append(mi)
    y_test.append(yi)

In [None]:
preds = np.concatenate(m, axis=0)
y = np.concatenate(y_test, axis=0)
print(preds.shape)
print(y.shape)

In [None]:
err2 = ((preds - y) ** 2).squeeze()
err2.shape

In [None]:
err48 = np.mean(err2, axis=0)  # shape (48,)
se48 = np.std(err2, axis=0, ddof=1) / np.sqrt(err2.shape[0])
print(err48.shape)
print(se48.shape)

In [None]:
time_steps = np.arange(err48.shape[0])

plt.figure(figsize=(10, 4))
plt.plot(time_steps, err48, label='Mean Squared Error', color='blue')
plt.fill_between(
    time_steps,
    err48 - se48,
    err48 + se48,
    color='blue',
    alpha=0.3,
    label='±1 SE'
)

plt.xlabel('Time Step (Hour)')
plt.ylabel('Average Squared Error')
plt.title('ODE - Mean Squared Error Over Time')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
#y = np.concatenate(y_test, axis=0)
y = y.squeeze()
ymean = np.mean(y, axis=0)  # shape (48,)
yse = np.std(y, axis=0, ddof=1) / np.sqrt(y.shape[0])
print(ymean.shape)
print(yse.shape)

preds = preds.squeeze()
predmean = np.mean(preds, axis=0)  # shape (48,)
predse = np.std(preds, axis=0, ddof=1) / np.sqrt(preds.shape[0])


In [None]:
plt.figure(figsize=(10, 4))
plt.plot(time_steps, ymean, label='FMC Mean', color='green')
# plt.fill_between(
#     time_steps,
#     ymean - yse,
#     ymean + yse,
#     color='green',
#     alpha=0.3,
#     label='±1 SE'
# )

plt.xlabel('Time Step (Hour)')
plt.ylabel('FMC')
plt.title('Mean FMC over 48 Hours')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

## Test Single Runs

In [None]:
from viz import plot_one

In [None]:
st = "MRLS2"
d = ml_data[st]
u = ode.run_model_single(d, hours=720, h2=720-48)

plot_one(ml_data, st, m=u[0,:], start_time="2023-06-01", 
         end_time='2023-06-30 23:00:00+0000')