# v2.1 Run RNN with Spatial Training

This notebook is intended to set up a test in which the RNN is run serially by location and compared to the spatial training scheme. Additionally, the ODE model with the augmented Kalman filter (KF) will be run as a comparison. Note that the RNN models will predict entirely without knowledge of the held-out locations, while the augmented KF will be run directly on the test locations.


## Environment Setup

In [None]:
import numpy as np
from utils import print_dict_summary, print_first, str2time, logging_setup
import pickle
import logging
import os.path as osp
from moisture_rnn_pkl import pkl2train
from moisture_rnn import RNNParams, RNNData, RNN 
from utils import hash2, read_yml, read_pkl, retrieve_url, Dict
from moisture_rnn import RNN
import reproducibility
from data_funcs import rmse, to_json, combine_nested
from moisture_models import run_augmented_kf
import copy
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import time

In [None]:
# Configure logging through the project's utils.logging_setup helper
# (the actual handler/level configuration lives in utils).
logging_setup()

In [None]:
# Download the test data pickle into the data directory, so it lands at the
# same path later built from file_dir/file_names (data/fmda_nw_202401-05_f05.pkl).
# NOTE(review): the remote file is test_CA_202401.pkl but it is saved under the
# name fmda_nw_202401-05_f05.pkl; confirm this renaming is intentional.
retrieve_url(
    url = "https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl",
    dest_path = osp.join("data", "fmda_nw_202401-05_f05.pkl"))

In [None]:
# Input locations: the directory holding the raw FMDA pickles, the pickle names,
# their full paths, and the reproducibility reference pickle.
file_dir = 'data'
file_names = ['fmda_nw_202401-05_f05.pkl']
file_paths = [osp.join(file_dir, fname) for fname in file_names]
repro_file = "data/reproducibility_dict_v2_TEST.pkl"

In [None]:
# read/write control
train_file='train.pkl'
train_create=False   # if false, read
train_write=False
train_read=True

In [None]:
# Load the reproducibility reference pickle, then create, cache and/or read the
# training cases according to the train_create / train_write / train_read flags.
repro = read_pkl(repro_file)

if train_write and not train_create:
    # Writing before (re)creating would dump an undefined `train` on a fresh
    # kernel, or a stale one left over from an earlier run of this notebook.
    raise ValueError('train_write=True requires train_create=True')

if train_create:
    logging.info('creating the training cases from files %s',file_paths)
    # file_paths was built with osp.join, so this also works on Windows (\ or /)
    train = pkl2train(file_paths)
if train_write:
    with open(train_file, 'wb') as file:
        logging.info('Writing the train cases into file %s',train_file)
        pickle.dump(train, file)
if train_read:
    # Note: when both train_create and train_read are set, the cases read from
    # train_file replace the freshly created ones.
    logging.info('Reading the train cases from file %s',train_file)
    train = read_pkl(train_file)

In [None]:
# Load the 'rnn' section of params.yaml and display it.
params = read_yml("params.yaml", subkey="rnn")
params

In [None]:
# Merge the per-location training cases with combine_nested and wrap the result
# in the project's Dict type.
# For a quick debug run, `train` can first be reduced to a subset of its keys
# (e.g. the first 100 via itertools.islice).
combined_cases = combine_nested(train)
dat = Dict(combined_cases)

In [None]:
# Set up output dictionaries
outputs_kf = {}
outputs_rnn_serial = {}
outputs_rnn_spatial = {}

## Spatial Data Training

In [None]:
# Wrap the raw YAML parameters in the project's RNNParams container,
# rebinding `params` to the wrapped object.
params = RNNParams(params)

In [None]:
# Start timer
start_time = time.time()

In [None]:
# Build the RNN dataset from the combined cases with the selected input features
# (scaler="standard"), split it in time and across locations, scale it, and
# reshape it into batches using the timesteps/batch_size from params.
feature_names = ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind']
rnn_dat = RNNData(dat, scaler="standard", features_list=feature_names)

# Split fractions along time and across locations.
# NOTE(review): presumably ordered (train, validation, test); confirm in RNNData.
time_split = [0.9, 0.05, 0.05]
space_split = [0.6, 0.2, 0.2]
rnn_dat.train_test_split(time_fracs=time_split, space_fracs=space_split)
rnn_dat.scale_data()

rnn_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])

In [None]:
# Training hyperparameter overrides. bmax, loc_batch_reset and features_list are
# taken from the prepared dataset, so this cell must run after rnn_dat is built.
training_overrides = {
    'batch_schedule_type': 'exp',
    'bmin': 20,
    'bmax': rnn_dat.hours,
    'loc_batch_reset': rnn_dat.n_seqs,
    'epochs': 100,
    'learning_rate': 0.0001,
    'recurrent_layers': 2,
    'recurrent_units': 40,
    'dense_layers': 2,
    'dense_units': 20,
    'features_list': rnn_dat.features_list,
}
params.update(training_overrides)

In [None]:
# Seed the RNGs via the project's reproducibility helper, then build and run
# the RNN on the spatially split data; run_model returns predictions and errors.
SEED = 123
reproducibility.set_seed(SEED)
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat)

In [None]:
# Average of the error summary returned by rnn.run_model.
# NOTE(review): the structure of `errs` is defined by the project; presumably one
# entry per held-out location, so confirm before interpreting this mean.
errs.mean()

In [None]:
# End Timer
end_time = time.time()

# Calculate Code Runtime
elapsed_time = end_time - start_time
print(f"Spatial Training Elapsed time: {elapsed_time:.4f} seconds")

## Run ODE + KF and Compare

In [None]:
# Run the ODE model with the augmented Kalman filter directly on each test
# location and record its RMSE against the observations y.
# NOTE: this rebinds `m` (previously the RNN output from run_model) to the KF
# output of the last test location.
outputs_kf = {}
for case in rnn_dat.loc['test_locs']:
    print("~"*50)
    print(case)
    # Run Augmented KF
    print('Running Augmented KF')
    case_data = train[case]
    # These keys are written into the shared `train` dict and persist after this cell.
    # NOTE(review): presumably h2 marks the end of the assimilation window and
    # scale_fm a fuel-moisture scale factor read by run_augmented_kf; confirm.
    case_data['h2'] = case_data['hours'] // 2
    case_data['scale_fm'] = 1
    m, Ec = run_augmented_kf(case_data)
    y = case_data['y']
    case_data['m'] = m
    # Compute the error once and reuse it for the printout and the summary table.
    kf_err = rmse(m, y)
    print(f"KF RMSE: {kf_err}")
    outputs_kf[case] = {'case': case, 'errs': kf_err}

In [None]:
# Tabulate the KF results: one row per test location, with columns 'case' and 'errs'.
# orient='index' builds the rows directly, so 'errs' keeps a numeric dtype. The
# previous from_dict(...).transpose() left every column as object dtype because
# each intermediate column mixed a string with a float.
df2 = pd.DataFrame.from_dict(outputs_kf, orient='index')
df2.head()

## Compare

In [None]:
# Mean KF RMSE across the test locations.
df2['errs'].mean()

In [None]:
# (rows, columns) of the KF results table; rows correspond to test locations.
df2.shape

In [None]:
# Shape of the RNN error summary, to check it lines up with df2 above.
# NOTE(review): `errs` comes from rnn.run_model; confirm what its dimensions mean.
errs.shape

In [None]:
# Mean RNN error, repeated here for side-by-side comparison with the KF mean above.
errs.mean()