# Batch Reset Hyperparameter Tutorial

When the training data are very long, a stateful model is prone to instability: in the early iterations of training it generates an unreasonable hidden state, which is then propagated through many subsequent batches of training.

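We introduce the hyperparameter `batch_reset`, which resets the hidden state after a fixed number of batches (in the code below this is configured through `batch_schedule_type` together with `bmin`/`bmax`). Future work will turn this into a schedule in which the number of batches before a reset increases as the network learns and becomes less susceptible to exploding/vanishing gradients; prototype schedules are shown in the *Batch Reset Schedules* section below.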

We demonstrate with linear activation because the benefit of batch resetting is much more pronounced there. With more typical nonlinear activations such as ReLU or tanh, the batch-reset schedule is still empirically useful for speeding up and improving training, but its effect is much easier to see with linear activation, so we use linear activation here for educational purposes.

## Environment and Data Setup

In [None]:
# Environment
import os
import os.path as osp
import matplotlib.pyplot as plt
import sys
import numpy as np
import pandas as pd
# Local modules
sys.path.append('..')
import reproducibility
from utils import print_dict_summary
from data_funcs import rmse
from moisture_rnn import RNNParams, RNNData, RNN
from moisture_rnn_pkl import pkl2train
from utils import read_yml, read_pkl
import yaml
import pickle

In [None]:
# Load the pickled case-study data that every experiment below trains on.
case_pkl = "batch_reset_tutorial_case.pkl"
dat = read_pkl(case_pkl)

In [None]:
# Base configuration: the "rnn" section of ../params.yaml, then overridden
# with a short run (10 epochs, 5 timesteps) and linear activation for both
# entries of `activation`.
params = RNNParams(read_yml("../params.yaml", subkey="rnn"))
params.update({
    'epochs': 10,
    'timesteps': 5,
    'activation': ['linear', 'linear'],
})

In [None]:
# Wrap the raw case data with the configured scaler and feature list,
# split it by time fractions .9/.05/.05 (presumably train/val/test —
# confirm against RNNData), scale it, and reshape it into batches.
rnn_dat = RNNData(
    dat,
    scaler=params['scaler'],
    features_list=params['features_list'],
)
rnn_dat.train_test_split(time_fracs=[.9, .05, .05])
rnn_dat.scale_data()
rnn_dat.batch_reshape(
    timesteps=params['timesteps'],
    batch_size=params['batch_size'],
)

## Train without Stateful

In [None]:
# Baseline run with stateful=False and no batch-reset schedule.
params.update({'verbose_fit': True, 'stateful': False,
               'batch_schedule_type': None})
reproducibility.set_seed(123)
rnn = RNN(params)
try:
    m, errs = rnn.run_model(rnn_dat)
except Exception as e:
    # Training is allowed to fail in this tutorial: report the failure and
    # keep going so later cells still run. Include the exception class,
    # because str(e) alone can be empty or ambiguous (e.g. KeyError shows
    # only the missing key).
    print("*"*50)
    print(f"Caught Error {type(e).__name__}: {e}")
    print("*"*50)

In [None]:
# Run the model over the full scaled feature matrix, reshaped to a leading
# batch axis of size 1, and display the [0:150] slice of the predictions.
X_all = rnn_dat.scale_all_X()
X = X_all.reshape((1, X_all.shape[0], X_all.shape[1]))
preds = rnn.predict(X)
preds[0:150]

## Train with Stateful, without Batch Reset


In [None]:
# Stateful model with no batch-reset schedule, trained for 30 epochs.
params.update({
    'verbose_fit': True,
    'stateful': True,
    'batch_schedule_type': None,
})
params.update({'epochs': 30})
reproducibility.set_seed(123)
rnn = RNN(params)

In [None]:
# Without batch resets the hidden state is carried through every batch, so
# training may become unstable and raise. Catch and report the failure so
# the tutorial can continue to the batch-reset comparison.
try:
    m, errs = rnn.run_model(rnn_dat)
except Exception as e:
    # Report the exception class as well as its message: str(e) alone can
    # be empty or ambiguous (e.g. KeyError shows only the missing key).
    print("*"*50)
    print(f"Caught Error {type(e).__name__}: {e}")
    print("*"*50)

## Train with Stateful, with Periodic Batch Reset

In [None]:
# Stateful model with a 'constant' batch-reset schedule (bmin=20; presumably
# the hidden state is reset every 20 batches — confirm in moisture_rnn).
params.update({
    'verbose_fit': True,
    'stateful': True,
    'batch_schedule_type': 'constant',
    'bmin': 20,
})
params.update({'epochs': 30})
reproducibility.set_seed(123)
rnn = RNN(params)

In [None]:
# Train the periodically reset stateful model, plotting with
# plot_period="predict". Failures are reported rather than raised so the
# remaining schedule experiments can still run.
try:
    m, errs = rnn.run_model(rnn_dat, plot_period="predict")
except Exception as e:
    # Report the exception class as well as its message: str(e) alone can
    # be empty or ambiguous (e.g. KeyError shows only the missing key).
    print("*"*50)
    print(f"Caught Error {type(e).__name__}: {e}")
    print("*"*50)

## Batch Reset Schedules

In [None]:
from moisture_rnn import calc_exp_intervals, calc_log_intervals, calc_step_intervals

In [None]:
# Settings for the schedule comparison plot: number of epochs and the
# lower/upper batch-reset values.
epochs, bmin, bmax = 50, 10, 200

# Epoch indices 0 .. epochs-1, used as the x-axis.
egrid = np.arange(epochs)

In [None]:
# Plot the candidate batch-reset schedules (reset value per epoch) for the
# epochs/bmin/bmax settings defined above.
reset_schedules = {
    'Linear Increase': np.linspace(bmin, bmax, epochs),
    'Exponential Increase': calc_exp_intervals(bmin, bmax, epochs),
    'Logarithmic Increase': calc_log_intervals(bmin, bmax, epochs),
    'Step Increase': calc_step_intervals(bmin, bmax, epochs, estep=25),
}
for sched_label, sched_values in reset_schedules.items():
    plt.plot(egrid, sched_values, label=sched_label)
plt.xlabel('Epoch')
plt.ylabel('Batch Reset Value')
plt.legend()
plt.title('Batch Reset Value vs Epoch')
plt.show()

### Linear Schedule

In [None]:
# Linear batch-reset schedule between bmin=20 and bmax=rnn_dat.hours,
# trained for 40 epochs (presumably analogous to the 'Linear Increase'
# curve in the schedule plot — confirm in moisture_rnn).
params.update({
    'verbose_fit': False,
    'stateful': True,
    'batch_schedule_type': 'linear',
    'bmin': 20,
    'bmax': rnn_dat.hours,
})
params.update({'epochs': 40})
reproducibility.set_seed(123)
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat, plot_period="predict")

### Exponential Increase

In [None]:
# Exponential batch-reset schedule between bmin=20 and bmax=rnn_dat.hours,
# trained for 40 epochs with early_stopping_patience=10.
# Note: `params` is mutated in place, so this patience value stays set for
# any later cell that does not override it.
params.update({
    'verbose_fit': True,
    'stateful': True,
    'batch_schedule_type': 'exp',
    'bmin': 20,
    'bmax': rnn_dat.hours,
    'early_stopping_patience': 10,
})
params.update({'epochs': 40})
reproducibility.set_seed(123)
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat, plot_period="predict")

### Log Increase

In [None]:
# Logarithmic batch-reset schedule between bmin=20 and bmax=rnn_dat.hours,
# trained for 40 epochs.
# early_stopping_patience is set explicitly. The exponential-schedule cell
# sets it to 10, and because `params` is mutated in place this run used to
# inherit that value silently. Its result would then change if cells were
# run out of order. Stating it here keeps this run reproducible on its own.
params.update({
    'verbose_fit': False,
    'stateful': True,
    'batch_schedule_type': 'log',
    'bmin': 20,
    'bmax': rnn_dat.hours,
    'early_stopping_patience': 10,
})
params.update({'epochs': 40})
reproducibility.set_seed(123)
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat, plot_period="predict")