# Batch Reset Hyperparameter Tutorial

When the training data are very long, a stateful model is prone to instability: early in training, the network produces an unreasonable hidden state, and that state is then propagated through many subsequent batches.
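The toy example below illustrates how a carried hidden state can persist. It is a sketch, not the model used in this tutorial: a two-unit tanh RNN with fixed, hand-picked weights (`W`, `U`) and random inputs. The helper `run_batches` is hypothetical. Two runs that differ only in their initial hidden state are compared, once with the state carried across batches (stateful) and once with it reset at the start of every batch.

```python
import numpy as np

def run_batches(batches, h0, W, U, reset_each_batch):
    """Toy tanh RNN, h_t = tanh(W @ h_{t-1} + U @ x_t), run over several batches.

    Returns the final hidden state of each batch. With reset_each_batch=False
    the state is carried into the next batch (stateful); otherwise every
    batch starts again from zeros.
    """
    h = h0.copy()
    finals = []
    for batch in batches:
        if reset_each_batch:
            h = np.zeros_like(h0)
        for x_t in batch:
            h = np.tanh(W @ h + U @ x_t)
        finals.append(h.copy())
    return finals

rng = np.random.default_rng(0)
batches = [rng.normal(scale=0.1, size=(25, 1)) for _ in range(8)]
W = 1.5 * np.eye(2)       # fixed, strongly recurrent weights (not trained)
U = np.full((2, 1), 0.1)  # weak input weights

h0_pos, h0_neg = np.ones(2), -np.ones(2)
carried = [run_batches(batches, h0, W, U, reset_each_batch=False) for h0 in (h0_pos, h0_neg)]
reset = [run_batches(batches, h0, W, U, reset_each_batch=True) for h0 in (h0_pos, h0_neg)]

# Carried state: the sign of the initial state still dominates the last batch.
print(carried[0][-1], carried[1][-1])
# Reset state: the last batch no longer depends on the initial state.
print(np.array_equal(reset[0][-1], reset[1][-1]))  # True
```

With `W = 1.5 * I`, each hidden unit has two stable states, so in this toy a carried initial state never washes out; with weaker recurrence its influence would decay over time instead of persisting.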

We introduce the hyperparameter `batch_reset`, which resets the hidden state after a fixed number of batches. Future work will turn this into a schedule that increases the number of batches between resets as the network learns and becomes less subject to exploding/vanishing gradients.
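A minimal sketch of where resets fall, assuming a reset fires after every `batch_reset` consecutive batches. The geometric schedule is a hypothetical illustration of the future-work idea, not something this tutorial implements; both helper names are made up for this example.

```python
import math

def fixed_reset_points(n_batches, batch_reset):
    """0-based indices of the batches after which the hidden state is reset.

    Assumes a reset fires after every `batch_reset` consecutive batches.
    """
    return [i for i in range(n_batches) if (i + 1) % batch_reset == 0]

def growing_reset_points(n_batches, start, growth):
    """Hypothetical schedule: the gap between resets grows by `growth` each time.

    Requires start >= 1 and growth >= 1 so the gap never shrinks.
    """
    points = []
    gap = start
    next_reset = start  # number of batches seen when the next reset fires
    while next_reset <= n_batches:
        points.append(next_reset - 1)
        gap = math.ceil(gap * growth)
        next_reset += gap
    return points

print(fixed_reset_points(100, 20))         # [19, 39, 59, 79, 99]
print(growing_reset_points(200, 10, 2.0))  # [9, 29, 69, 149]
```

The growing schedule resets often at first, when the hidden state is least trustworthy, and lets the state run over progressively longer stretches later.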

## Environment and Data Setup

```python
# Standard library
import os
import os.path as osp
import pickle
import sys

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml

# Local modules live in the parent directory
sys.path.append('..')
import reproducibility
from data_funcs import rmse
from moisture_rnn import RNNParams, RNNData, RNN
from moisture_rnn_pkl import pkl2train
from utils import print_dict_summary, read_yml, read_pkl
```

```python
# Load the example case used throughout this tutorial
dat = read_pkl("batch_reset_tutorial_case.pkl")
```

```python
# Read the RNN hyperparameters from the shared config file
params = read_yml("../params.yaml", subkey="rnn")
params = RNNParams(params)
params.update({'epochs': 10})
```

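```python
# Build the RNN dataset, split it, and scale the features
rnn_dat = RNNData(dat, scaler=params['scaler'], features_list=params['features_list'])
rnn_dat.train_test_split(
    train_frac=0.9,  # 90% of the data for training
    val_frac=0.05,   # 5% for validation
)
rnn_dat.scale_data()
```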
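To make the fractions concrete, here is a sketch of a 90/5 split followed by scaling on a synthetic series. It assumes, without having checked `RNNData`, that the split is chronological, that the remainder becomes the test set, and that the scaler is fit on the training portion only; `split_bounds` is a hypothetical helper.

```python
import numpy as np

def split_bounds(n_timesteps, train_frac, val_frac):
    """Hypothetical chronological split: train first, then validation, then test."""
    n_train = int(n_timesteps * train_frac)
    n_val = int(n_timesteps * val_frac)
    return n_train, n_train + n_val

series = np.arange(1000, dtype=float)  # stand-in for one feature's time series
train_end, val_end = split_bounds(len(series), train_frac=0.9, val_frac=0.05)
train, val, test = series[:train_end], series[train_end:val_end], series[val_end:]

# Standardize with statistics from the training portion only, so no
# information from the validation or test periods leaks into training.
mu, sigma = train.mean(), train.std()
train_s, val_s, test_s = [(part - mu) / sigma for part in (train, val, test)]

print(len(train), len(val), len(test))  # 900 50 50
```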

## Train a Stateless Model

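```python
# Baseline: stateless model. batch_reset is set to a huge value (disabled);
# it only matters when stateful=True.
params.update({'verbose_fit': True, 'stateful': False, 'batch_reset': 9999})
reproducibility.set_seed(123)
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat)
```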

```python
# Predict on the first 500 training samples
rnn.predict(rnn_dat.X_train[:500])
```

## Train a Stateful Model without Batch Reset

We effectively disable `batch_reset` by setting it to a very large value (9999), so the hidden state is never reset during training.
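The arithmetic behind "a very large value" can be checked with a short sketch. The sample count and batch size below are hypothetical, not taken from the tutorial data, and the helper assumes the batch counter restarts every epoch with a reset after every `batch_reset` batches.

```python
import math

def resets_per_epoch(n_samples, batch_size, batch_reset):
    """Count hidden-state resets in one epoch.

    Assumes the batch counter restarts every epoch and a reset fires after
    every `batch_reset` batches.
    """
    n_batches = math.ceil(n_samples / batch_size)
    return n_batches // batch_reset

# Hypothetical sizes: 5000 samples in batches of 32 gives 157 batches per epoch
print(resets_per_epoch(5000, 32, batch_reset=20))    # 7
print(resets_per_epoch(5000, 32, batch_reset=9999))  # 0
```

As long as an epoch has fewer batches than `batch_reset`, no reset is ever triggered, which is equivalent to turning the feature off.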

```python
# Stateful model with batch reset disabled
params.update({'verbose_fit': True, 'stateful': True, 'batch_reset': 9999, 'epochs': 10})
reproducibility.set_seed(123)
rnn = RNN(params)
```

```python
m, errs = rnn.run_model(rnn_dat)
```

## Train a Stateful Model with Batch Reset

Here the hidden state is reset every 20 batches, and training runs for 30 epochs.

```python
# Stateful model, resetting the hidden state every 20 batches
params.update({'verbose_fit': True, 'stateful': True, 'batch_reset': 20, 'epochs': 30})
reproducibility.set_seed(123)
rnn = RNN(params)
```

```python
m, errs = rnn.run_model(rnn_dat, plot_period="predict")
```