# v2.1 Exploration: Trying to Make the Model Work Better

In [None]:
# Environment
import os
import os.path as osp
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import sys
# Local modules
sys.path.append('..')
import reproducibility
import pandas as pd
from utils import print_dict_summary
from data_funcs import rmse
from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, create_rnn_data2
from moisture_rnn_pkl import pkl2train
from tensorflow.keras.callbacks import Callback
from utils import hash2
import copy
import logging
import pickle
from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights
import yaml
import copy

In [None]:
# Presumably configures logging (handlers/format) for the project modules — see utils.logging_setup.
logging_setup()

## Test Batch Reset

In [None]:
# Load the training dictionary from train.pkl and display its keys
# (keys look like case ids such as 'PLFI1_202401' — confirm against the pickle's producer).
train = read_pkl('train.pkl')
train.keys()

In [None]:
import importlib
import moisture_rnn
importlib.reload(moisture_rnn)
from moisture_rnn import RNN, RNNData

In [None]:
# Build RNN params from the "rnn" section of params.yaml with a small batch size (7),
# then prepare case 'PLFI1_202401': split, scale and reshape into batches.
params = RNNParams(read_yml("params.yaml", subkey="rnn"))
params.update({'batch_size': 7})
rnn_dat = RNNData(
    train['PLFI1_202401'],
    scaler=params['scaler'],
    features_list=params['features_list'],
)
rnn_dat.train_test_split(train_frac=.9, val_frac=.05)
rnn_dat.scale_data()
rnn_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])

In [None]:
import importlib
import moisture_rnn
importlib.reload(moisture_rnn)
from moisture_rnn import RNN, ResetStatesCallback

In [None]:
# Two-layer ReLU RNN on case 'PLFI1_202401' for 2 epochs, driven end-to-end by run_model.
params.update({
    'epochs': 2,
    'verbose_fit': True,
    'batch_size': 32,
    'rnn_layers': 2,
    'activation': ['relu', 'relu'],
})
rnn_dat = RNNData(
    train['PLFI1_202401'],
    scaler=params['scaler'],
    features_list=params['features_list'],
)
rnn_dat.train_test_split(train_frac=.9, val_frac=.05)
rnn_dat.scale_data()
rnn_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])
reproducibility.set_seed()
rnn = RNN(params)
m, errs = rnn.run_model(rnn_dat, plot_period="predict")

In [None]:
class ResetStatesCallback(Callback):
    """
    Custom callback to reset the states of RNN layers at the end of each epoch and optionally after a specified number of batches.

    Parameters:
    -----------
    batch_reset : int, optional
        If provided, resets the states of RNN layers after every `batch_reset` training batches
        (after the 10th, 20th, ... batch of an epoch for ``batch_reset=10``). Default is None.
    loc_batch_reset : int, optional
        If provided, resets the states of RNN layers after every `loc_batch_reset` batches, both
        in training and in validation/evaluation. Intended to mark the point where batches switch
        to a new location so a hidden state is not carried into the wrong location. Default is None.
    """
    def __init__(self, batch_reset=None, loc_batch_reset=None):
        """
        Store the optional reset intervals.

        Parameters:
        -----------
        batch_reset : int, optional
            The interval of batches after which to reset the states of RNN layers. Default is None.
        loc_batch_reset : int, optional
            The number of batches per location; states are reset each time that many batches
            have completed. Default is None.
        """
        super(ResetStatesCallback, self).__init__()
        self.batch_reset = batch_reset
        self.loc_batch_reset = loc_batch_reset

    @staticmethod
    def _interval_elapsed(batch, interval):
        """Return True if `interval` is set and a whole multiple of it has just been completed."""
        # Keras passes the 0-based index of the batch that just finished, so
        # `batch + 1` batches are done. Testing `batch % interval == 0` would fire
        # after batches 1, interval+1, 2*interval+1, ... (1-based) — i.e. one batch
        # into each new location instead of at its boundary.
        return interval is not None and (batch + 1) % interval == 0

    def _reset_rnn_states(self):
        """Call `reset_states()` on every model layer that provides it."""
        for layer in self.model.layers:
            if hasattr(layer, 'reset_states'):
                layer.reset_states()

    def on_epoch_end(self, epoch, logs=None):
        """
        Resets the states of RNN layers at the end of each epoch.

        Parameters:
        -----------
        epoch : int
            The index of the current epoch.
        logs : dict, optional
            A dictionary containing metrics from the epoch. Default is None.
        """
        self._reset_rnn_states()

    def on_train_batch_end(self, batch, logs=None):
        """
        Resets the states of RNN layers during training once `batch_reset` or `loc_batch_reset`
        batches have completed. The `batch_reset` is used for stability and to avoid exploding
        gradients early in training, when a hidden state is passed through weights that haven't
        learned yet. The `loc_batch_reset` resets the states when the next batch comes from a new
        location, so a hidden state is not passed across locations.

        Parameters:
        -----------
        batch : int
            The 0-based index of the batch that just finished.
        logs : dict, optional
            A dictionary containing metrics from the batch. Default is None.
        """
        if (self._interval_elapsed(batch, self.batch_reset)
                or self._interval_elapsed(batch, self.loc_batch_reset)):
            print(f"Resetting states after batch {batch + 1}")
            self._reset_rnn_states()

    def on_test_batch_end(self, batch, logs=None):
        """
        Resets the states of RNN layers during validation/evaluation after every `loc_batch_reset`
        batches, to demarcate a new location and avoid passing a hidden state to the wrong one.

        Parameters:
        -----------
        batch : int
            The 0-based index of the batch that just finished.
        logs : dict, optional
            A dictionary containing metrics from the batch. Default is None.
        """
        if self._interval_elapsed(batch, self.loc_batch_reset):
            self._reset_rnn_states()

In [None]:
# Two-layer ReLU RNN on case 'PLFI1_202401', trained with rnn.fit and the ResetStatesCallback
# in scope, which resets hidden states periodically (batch_reset=10).
params.update({
    'epochs': 2,
    'verbose_fit': True,
    'batch_size': 32,
    'rnn_layers': 2,
    'activation': ['relu', 'relu'],
})
rnn_dat = RNNData(
    train['PLFI1_202401'],
    scaler=params['scaler'],
    features_list=params['features_list'],
)
rnn_dat.train_test_split(train_frac=.9, val_frac=.05)
rnn_dat.scale_data()
rnn_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])
reproducibility.set_seed()
rnn = RNN(params)
rnn.fit(
    rnn_dat.X_train,
    rnn_dat.y_train,
    validation_data=(rnn_dat.X_val, rnn_dat.y_val),
    callbacks=[ResetStatesCallback(batch_reset=10)],
)

In [None]:
def calc_exp_intervals(bmin, bmax, n_epochs, force_bmax = True):
    """
    Return per-epoch batch-reset intervals growing geometrically from `bmin` toward `bmax`.

    Epoch ``i`` (0-based) gets ``bmin * (bmax / bmin) ** (i / n_epochs)``, truncated to int.

    Parameters:
    -----------
    bmin : number
        Interval for the first epoch. Must be positive (it is used as a divisor).
    bmax : number
        Target interval for the final epoch. Must be positive.
    n_epochs : int
        Number of epochs, i.e. the length of the returned array.
    force_bmax : bool, optional
        If True (default), set the last entry to exactly `bmax`. The formula alone stops
        short of it because the largest exponent is ``(n_epochs - 1) / n_epochs``.

    Returns:
    --------
    numpy.ndarray of int with ``max(n_epochs, 0)`` entries (empty when ``n_epochs <= 0``).

    Raises:
    -------
    ValueError
        If `bmin` or `bmax` is not positive.
    """
    if bmin <= 0 or bmax <= 0:
        raise ValueError(f"bmin and bmax must be positive, got bmin={bmin}, bmax={bmax}")
    if n_epochs <= 0:
        # Nothing to schedule; avoid indexing intervals[-1] on an empty array.
        return np.zeros(0, dtype=int)
    epochs = np.arange(n_epochs)
    factors = epochs / n_epochs
    intervals = bmin * (bmax / bmin) ** factors
    if force_bmax:
        intervals[-1] = bmax  # Ensure the last value is exactly bmax
    return intervals.astype(int)

def calc_log_intervals(bmin, bmax, n_epochs, force_bmax = True):
    """
    Return per-epoch batch-reset intervals growing logarithmically from `bmin` toward `bmax`.

    Epoch ``i`` (0-based) gets ``bmin + (bmax - bmin) * log(1 + i) / log(1 + n_epochs)``,
    truncated to int, so the interval grows quickly in early epochs and then flattens.

    Parameters:
    -----------
    bmin : number
        Interval for the first epoch.
    bmax : number
        Target interval for the final epoch.
    n_epochs : int
        Number of epochs, i.e. the length of the returned array.
    force_bmax : bool, optional
        If True (default), set the last entry to exactly `bmax`. The formula alone stops
        short of it because the largest factor is ``log(n_epochs) / log(n_epochs + 1)``.

    Returns:
    --------
    numpy.ndarray of int with ``max(n_epochs, 0)`` entries (empty when ``n_epochs <= 0``).
    """
    if n_epochs <= 0:
        # Nothing to schedule; also avoids log(<=1) denominators and intervals[-1] on empty.
        return np.zeros(0, dtype=int)
    epochs = np.arange(n_epochs)
    factors = np.log(1 + epochs) / np.log(1 + n_epochs)
    intervals = bmin + (bmax - bmin) * factors
    if force_bmax:
        intervals[-1] = bmax  # Ensure the last value is exactly bmax
    return intervals.astype(int)

In [None]:
# Compare how the exponential and logarithmic schedules grow the batch-reset
# interval from bmin to bmax over the epochs.
ep = 15
bmin = 10
bmax = 500
xgrid = np.arange(0, ep)
plt.plot(xgrid, calc_exp_intervals(bmin, bmax, ep), label='exp')
plt.plot(xgrid, calc_log_intervals(bmin, bmax, ep), label='log')
plt.xlabel('Epoch')
plt.ylabel('Batch reset interval')
plt.legend()

In [None]:
class ResetStatesCallback(Callback):
    """
    Custom callback that resets the states of RNN layers at the end of each epoch and, during
    training, after every `k` batches. The interval `k` follows a per-epoch schedule that
    grows from `bmin` (first epoch) to `bmax` (last scheduled epoch).

    Parameters:
    -----------
    bmin : int
        Batch-reset interval for the first epoch. Must be at least 1.
    bmax : int
        Batch-reset interval for the last scheduled epoch. Must be at least 1.
    epochs : int
        Number of epochs the schedule is built for, normally the `epochs` passed to `fit`.
        If training runs longer, the last scheduled interval is reused.
    batch_schedule_type : str, optional
        Shape of the schedule: 'linear', 'exp' or 'log'. Default is 'linear'.
    verbose : bool, optional
        If True, print the schedule and every state reset. Default is True.
    """
    def __init__(self, bmin, bmax, epochs, batch_schedule_type='linear', verbose=True):
        """
        Validate the arguments and precompute the per-epoch reset intervals.

        Raises:
        -------
        ValueError
            If `bmin` or `bmax` is below 1, `epochs` is below 1, or `batch_schedule_type`
            is not one of 'linear', 'exp', 'log'.
        """
        super(ResetStatesCallback, self).__init__()
        # Intervals are used as a modulus, so they must stay >= 1 after int truncation.
        if bmin < 1 or bmax < 1:
            raise ValueError(f"bmin and bmax must be >= 1, got bmin={bmin}, bmax={bmax}")
        if epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs}")
        self.bmin = bmin
        self.bmax = bmax
        self.epochs = epochs
        self.verbose = verbose
        # Chosen per epoch in `on_epoch_begin`; None means no mid-epoch resets yet.
        self.current_batch_reset = None
        self.batch_reset_intervals = self._calc_reset_intervals(batch_schedule_type)
        if self.verbose:
            print(f"epochs: {epochs}")
            print(f"Using ResetStatesCallback with Batch Reset Schedule: {batch_schedule_type}")
            print(f"batch_reset_intervals: {self.batch_reset_intervals}")

    def _reset_rnn_states(self):
        """Call `reset_states()` on every model layer that provides it."""
        for layer in self.model.layers:
            if hasattr(layer, 'reset_states'):
                layer.reset_states()

    def _calc_reset_intervals(self, batch_schedule_type):
        """Return the int array of per-epoch reset intervals for `batch_schedule_type`."""
        schedules = {
            'linear': lambda: np.linspace(self.bmin, self.bmax, self.epochs).astype(int),
            'exp': lambda: calc_exp_intervals(self.bmin, self.bmax, self.epochs),
            'log': lambda: calc_log_intervals(self.bmin, self.bmax, self.epochs),
        }
        if batch_schedule_type not in schedules:
            raise ValueError(f"Batch schedule method {batch_schedule_type} not recognized. \n Available methods: {list(schedules)}")
        return schedules[batch_schedule_type]()

    def on_epoch_begin(self, epoch, logs=None):
        """Select the batch-reset interval for the 0-based `epoch`."""
        # Reuse the last interval if fit runs longer than the schedule, instead of
        # raising IndexError.
        idx = min(epoch, len(self.batch_reset_intervals) - 1)
        self.current_batch_reset = int(self.batch_reset_intervals[idx])

    def on_epoch_end(self, epoch, logs=None):
        """
        Resets the states of RNN layers at the end of each epoch.

        Parameters:
        -----------
        epoch : int
            The index of the current epoch.
        logs : dict, optional
            A dictionary containing metrics from the epoch. Default is None.
        """
        if self.verbose:
            print(f" Resetting hidden state after epoch: {epoch+1}", flush=True)
        self._reset_rnn_states()

    def on_train_batch_end(self, batch, logs=None):
        """
        Resets the states of RNN layers during training once every `current_batch_reset`
        batches. Early in training the interval is small, for stability: it avoids exploding
        gradients when a hidden state is passed through weights that haven't learned yet.
        The interval then grows according to the schedule.

        Parameters:
        -----------
        batch : int
            The 0-based index of the batch that just finished.
        logs : dict, optional
            A dictionary containing metrics from the batch. Default is None.
        """
        interval = self.current_batch_reset
        # `batch` is 0-based, so `batch + 1` batches are done; reset after every full interval.
        if interval is not None and (batch + 1) % interval == 0:
            if self.verbose:
                print(f" Resetting states after batch {batch + 1}")
            self._reset_rnn_states()

In [None]:
# Ten-epoch run on case 'PLFI1_202401' with early stopping effectively disabled, so the
# log-shaped batch-reset schedule (10 -> 100 batches) runs for every epoch.
params.update({
    'epochs': 10,
    'verbose_fit': True,
    'batch_size': 32,
    'rnn_layers': 2,
    'activation': ['relu', 'relu'],
    'early_stopping_patience': 9999,
})
rnn_dat = RNNData(
    train['PLFI1_202401'],
    scaler=params['scaler'],
    features_list=params['features_list'],
)
rnn_dat.train_test_split(train_frac=.9, val_frac=.05)
rnn_dat.scale_data()
rnn_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])
reproducibility.set_seed()
rnn = RNN(params)
rnn.fit(
    rnn_dat.X_train,
    rnn_dat.y_train,
    validation_data=(rnn_dat.X_val, rnn_dat.y_val),
    callbacks=[
        ResetStatesCallback(
            bmin=10,
            bmax=100,
            epochs=params['epochs'],
            batch_schedule_type="log",
        ),
    ],
)

## Test Spatial Data

In [None]:
# Reload the training dictionary for the spatial (multi-case) experiments.
train = read_pkl('train.pkl')

In [None]:
# Fresh RNN parameters from the "rnn" section of params.yaml.
params = RNNParams(read_yml("params.yaml", subkey="rnn"))

In [None]:
# Number of entries in the training dictionary.
len(train.keys())

In [None]:
from itertools import islice
dat = {k: train[k] for k in islice(train, 100)}

In [None]:
# Show which entries were kept in the subset.
dat.keys()

In [None]:
# Merge the per-entry dicts in `dat` into one structure via data_funcs.combine_nested
# (presumably concatenating fields across entries — confirm in data_funcs).
from data_funcs import combine_nested
dd = combine_nested(dat)

In [None]:
import importlib
import utils
importlib.reload(utils)
from utils import Dict

In [None]:
# Wrap the combined data in utils.Dict (presumably a dict with attribute access — confirm).
dd = Dict(dd)

In [None]:
import importlib
import moisture_rnn
importlib.reload(moisture_rnn)
from moisture_rnn import RNNData

In [None]:
# Build the dataset from the combined data `dd` using a standard scaler on the Ed, Ew and
# rain features, then split off 90% for training and 5% for validation.
rnn_dat = RNNData(dd, scaler="standard", features_list=['Ed', 'Ew', 'rain'])
rnn_dat.train_test_split(train_frac=.9, val_frac=.05)

In [None]:
# Scale the features with the scaler chosen when rnn_dat was constructed.
rnn_dat.scale_data()

In [None]:
# Reshape into sequences of `timesteps` steps grouped into batches of `batch_size`
# (exact layout defined by RNNData.batch_reshape — confirm there).
rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])

In [None]:
import importlib
import moisture_rnn
importlib.reload(moisture_rnn)
from moisture_rnn import RNN

In [None]:
from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback
# Two-layer ReLU RNN trained directly through the underlying Keras model (`model_train`),
# with moisture_rnn's ResetStatesCallback (loc_batch_reset=n_seqs) and early stopping.
params.update({
    'epochs': 20,
    'learning_rate': 0.0001,
    'verbose_fit': True,
    'rnn_layers': 2,
    'rnn_units': 20,
    'dense_layers': 1,
    'dense_units': 10,
    'activation': ['relu', 'relu'],
    'features_list': ['Ed', 'Ew', 'rain'],
})
reproducibility.set_seed(123)
rnn = RNN(params)

history = rnn.model_train.fit(
    rnn_dat.X_train,
    rnn_dat.y_train,
    batch_size=params['batch_size'],
    epochs=params['epochs'],
    callbacks=[
        ResetStatesCallback(batch_reset=params['batch_reset'], loc_batch_reset=rnn_dat.n_seqs),
        EarlyStoppingCallback(patience=params['early_stopping_patience']),
    ],
    validation_data=(rnn_dat.X_val, rnn_dat.y_val),
)

In [None]:
# Per-epoch training loss (and validation loss when it was recorded) on a log scale.
plt.figure()
plt.semilogy(history.history['loss'], label='Training loss')
if 'val_loss' in history.history:
    plt.semilogy(history.history['val_loss'], label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper left')
plt.show()

In [None]:
# Predict on the first element of X_test (presumably one test location — confirm layout).
preds = rnn.predict(rnn_dat.X_test[0])

In [None]:
# `preds` were computed from rnn_dat.X_test[0], so compare them with the matching
# observations y_test[0]; plotting y_test[2] mixed two different test series.
plt.plot(rnn_dat.y_test[0], label='Observed')
plt.plot(preds, label='Predicted')
plt.legend()

In [None]:
# Single-layer, non-stateful variant for comparison, trained through `model_train`.
params.update({
    'epochs': 10,
    'verbose_fit': True,
    'rnn_layers': 1,
    'rnn_units': 20,
    'dense_layers': 1,
    'dense_units': 20,
    'activation': ['relu', 'relu'],
    'stateful': False,
})
reproducibility.set_seed(123)
rnn = RNN(params)

rnn.model_train.fit(
    rnn_dat.X_train,
    rnn_dat.y_train,
    batch_size=params['batch_size'],
    epochs=params['epochs'],
    callbacks=[
        ResetStatesCallback(batch_reset=params['batch_reset'], loc_batch_reset=rnn_dat.n_seqs),
        EarlyStoppingCallback(patience=params['early_stopping_patience']),
    ],
    validation_data=(rnn_dat.X_val, rnn_dat.y_val),
)

In [None]:
# Predict one test series and plot it against the observations for the SAME index
# (previously y_test[2] was plotted against predictions for X_test[0]).
preds = rnn.predict(rnn_dat.X_test[0])

plt.plot(rnn_dat.y_test[0], label='Observed')
plt.plot(preds, label='Predicted')
plt.legend()

In [None]:
# `preds` were computed from rnn_dat.X_test[0]; score them against the matching y_test[0].
rmse(rnn_dat.y_test[0], preds)

## LSTM

TODO: FIX BELOW

In [None]:
import importlib 
import moisture_rnn
importlib.reload(moisture_rnn)
from moisture_rnn import RNN_LSTM

In [None]:
# Case used for the single-series LSTM experiments (the same one used for the batch-reset
# tests); `case` must be defined before indexing `train`.
case = 'PLFI1_202401'
with open("params.yaml", encoding="utf-8") as file:
    params = yaml.safe_load(file)["lstm"]

rnn_dat2 = create_rnn_data2(train[case], params)

In [None]:
# Override the epoch count read from the "lstm" section of params.yaml.
params.update({'epochs': 10})

In [None]:
# This section tests the LSTM variant: `params` holds the "lstm" section of params.yaml,
# so build the model with RNN_LSTM rather than the SimpleRNN-based RNN class.
reproducibility.set_seed()
rnn = RNN_LSTM(params)
m, errs = rnn.run_model(rnn_dat2)

In [None]:
import importlib
importlib.reload(moisture_rnn)
from moisture_rnn import RNN_LSTM

In [None]:
# Case used for the single-series LSTM experiments (the same one used for the batch-reset
# tests); `case` must be defined before indexing `train`.
case = 'PLFI1_202401'
with open("params.yaml", encoding="utf-8") as file:
    params = yaml.safe_load(file)["lstm"]

rnn_dat2 = create_rnn_data2(train[case], params)
params

In [None]:
# LSTM overrides: a very small learning rate and gradient clipping (clipvalue) —
# presumably to tame unstable training; confirm how RNN_LSTM consumes these keys.
params.update({
    'learning_rate': 0.000001,
    'epochs': 10,
    'clipvalue':1.0
})

In [None]:
# Seed, build an RNN_LSTM from the "lstm" params and run it on rnn_dat2
# (run_model returns a pair unpacked as `m, errs`).
reproducibility.set_seed()
lstm = RNN_LSTM(params)
m, errs = lstm.run_model(rnn_dat2)