# Compare Batch Resetting Schedules


In [None]:
import numpy as np
from utils import print_dict_summary, print_first, str2time, logging_setup
import pickle
import logging
import os.path as osp
from moisture_rnn_pkl import pkl2train
from moisture_rnn import RNNParams, RNNData, RNN 
from utils import hash2, read_yml, read_pkl, retrieve_url
from moisture_rnn import RNN
import reproducibility
from data_funcs import rmse, to_json, combine_nested, process_train_dict
from moisture_models import run_augmented_kf
import copy
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import time
import reproducibility
import tensorflow as tf

In [None]:
# Call the project's logging_setup helper (imported from utils above);
# presumably this configures log output for the rest of the notebook -- confirm in utils.
logging_setup()

In [None]:
# Download the training pickle into the data/ directory: the later cells read
# it from "data/fmda_nw_202401-05_f05.pkl", so saving it to the working
# directory (as before) left that path unpopulated.
# NOTE(review): the remote file is named test_CA_202401.pkl but is stored under
# an fmda_nw_... name -- confirm this is the intended dataset.
retrieve_url(
    url = "https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl", 
    dest_path = "data/fmda_nw_202401-05_f05.pkl")

In [None]:
# Path to the reproducibility pickle (not referenced by the other visible cells).
repro_file = "data/reproducibility_dict_v2_TEST.pkl"

# Directory holding the training pickles, and the pickle file names inside it.
file_dir = 'data'
file_names = ['fmda_nw_202401-05_f05.pkl']

# Join the directory with each file name to get the paths to load.
file_paths = [osp.join(file_dir, pkl_name) for pkl_name in file_names]

In [None]:
# Read both YAML files, in the same order as before: params.yaml first
# (-> params_all), then params_data.yaml (-> params_data).
params_all, params_data = map(read_yml, ("params.yaml", "params_data.yaml"))

In [None]:
# Load the data-processing parameters and override three entries for this
# experiment. The overrides must be applied to `params_data`: that is the
# object handed to process_train_dict below and the one whose 'hours' value is
# read later when setting 'bmax'. Previously they were applied to a separate
# `data_params` dict that was never used, so they had no effect.
params_data = read_yml("params_data.yaml")
params_data.update({
    'hours': 3000,
    'max_intp_time': 24,
    'zero_lag_threshold': 24
})
data_params = params_data  # alias kept so any existing reference to data_params still resolves

# Build the training dictionary from the pickle path(s) assembled in `file_paths`
# (same "data/fmda_nw_202401-05_f05.pkl" path that was hard-coded here before).
train = process_train_dict(file_paths, params_data=params_data, verbose=True)

In [None]:
from itertools import islice

# Keep only the first 250 entries of `train`, in its iteration order.
kept_keys = islice(train, 250)
train = {case_key: train[case_key] for case_key in kept_keys}

In [None]:
# Report how many cases are in `train` (the previously referenced name
# `train_cases` is never defined in this notebook and raised NameError).
print(f"Number of Training Cases: {len(train)}")

In [None]:
# Build the RNN parameter object from the "rnn" section of params.yaml.
# This line had been commented out, which left `params` undefined and made the
# update below fail with NameError.
params = RNNParams(read_yml("params.yaml", subkey="rnn"))

# Override the settings used for this comparison of batch reset schedules.
params.update({'epochs': 200, 
               'learning_rate': 0.001,
               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.
               'recurrent_layers': 1, 'recurrent_units': 30, 
               'dense_layers': 1, 'dense_units': 30,
               'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping
               'batch_schedule_type': 'exp', # Hidden state batch reset schedule
               'bmin': 20, # Lower bound of hidden state batch reset, 
               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours
               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],
               'timesteps': 24
              })

## Handle Data

In [None]:
# Combine the nested training dictionary and wrap it in an RNNData object,
# using standard scaling and the feature list chosen in `params`.
train_sp = combine_nested(train)
rnn_dat = RNNData(train_sp, scaler="standard", features_list=params['features_list'])

# Split into train/val/test: fractions of total time steps, and fractions of
# total timeseries, respectively.
rnn_dat.train_test_split(
    time_fracs=[.9, .05, .05],
    space_fracs=[.40, .30, .30],
)
rnn_dat.scale_data()

# Reshape into batches: `timesteps` is the sequence length of the RNN input and
# `batch_size` the number of such sequences per gradient-descent step. Every
# training location gets a start time of zero.
n_train_locs = len(rnn_dat.loc['train_locs'])
rnn_dat.batch_reshape(
    timesteps=params['timesteps'],
    batch_size=params['batch_size'],
    start_times=np.zeros(n_train_locs),
)

# Record the number of sequences; used to reset the hidden state when the
# location changes for a given batch.
params.update({'loc_batch_reset': rnn_dat.n_seqs})

## Non-Stateful

In [None]:
# Non-stateful configuration: turn statefulness off and clear the batch schedule type.
params.update(dict(stateful=False, batch_schedule_type=None))

In [None]:
# Set the seed to 123 (the same value used before every model run in this
# notebook) and build the model from the current `params`.
reproducibility.set_seed(123)
rnn = RNN(params)
# Run on rnn_dat; with return_epochs=True three values come back, stored here
# as m0, errs0 and epochs0 for the non-stateful configuration.
m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)

In [None]:
# Show the mean of the errors returned by the non-stateful run.
errs0.mean()

## Constant Batch Schedule (Stateful)

In [None]:
# Stateful configuration with the 'constant' batch schedule type and bmin of 20.
params.update(dict(stateful=True, batch_schedule_type='constant', bmin=20))

In [None]:
# Set the seed to 123 (the same value used before every model run in this
# notebook) and build the model from the current `params`.
reproducibility.set_seed(123)
rnn = RNN(params)
# Run on rnn_dat; with return_epochs=True three values come back, stored here
# as m2, errs2 and epochs2 for the stateful, constant-schedule configuration.
m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)

## Exp Batch Schedule (Stateful)

In [None]:
# Stateful configuration with the 'exp' batch schedule type: bmin of 20 and
# bmax taken from rnn_dat.hours.
params.update(dict(
    stateful=True,
    batch_schedule_type='exp',
    bmin=20,
    bmax=rnn_dat.hours,
))

In [None]:
# Set the seed to 123 (the same value used before every model run in this
# notebook) and build the model from the current `params`.
reproducibility.set_seed(123)
rnn = RNN(params)
# Run on rnn_dat; with return_epochs=True three values come back, stored here
# as m3, errs3 and epochs3 for the stateful, exp-schedule configuration.
m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)

In [None]:
# Show the mean of the errors returned by the exp-schedule run.
errs3.mean()