# RNN Exploration & Analysis
**Author: Jibran**<br>
**Date: 2023-11-01**

## Imports & Setup

In [None]:
import logging
from pathlib import Path

import numpy as np
import torch

from src.data_preprocess.rnn_data_prep import RNNDataPrep
from src.models import rnn_model
from src.models.train_eval import train_eval_model
from src.models.helpers_rnn import plot_predicted_probabilities
from src.utils.utilities import create_config_dict

# Configure root logging for this notebook.
# Change `level` below to one of the following (most to least verbose):
#   DEBUG    - Detailed information, typically of interest only when diagnosing problems.
#   INFO     - Confirmation that things are working as expected.
#   WARNING  - Something unexpected happened, or a problem is likely in the near future
#              (e.g. 'disk space low'). The software is still working as expected.
#   ERROR    - Due to a more serious problem, the software could not perform some function.
#   CRITICAL - A serious error; the program itself may be unable to continue running.
#
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger before applying this configuration. Without it, basicConfig() does
# nothing when the root logger already has handlers, so re-running this cell
# with a different level or format would silently have no effect.
#
# The format ends with a newline on purpose: combined with the handler's own
# line terminator, it leaves a blank line after every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(name)s - %(levelname)s - %(message)s\n',
    force=True,
)

## Process Data and/or Get Train/Test Splits

In [None]:
# Persistence switches for this run.
save_data = False
# NOTE(review): save_train_test is not referenced by any other cell shown in
# this notebook -- confirm whether it was meant to be passed to get_rnn_data().
save_train_test = False

# Build the data-prep helper from the interim pickle path and the
# train/test data parent directory.
rnn_data_prep = RNNDataPrep(
    pickle_path="../data/interim/ff-mw.pkl",
    train_test_data_par_dir="../data/processed/rnn_input/",
)

# Fetch the four arrays used for training and evaluation.
# NOTE(review): load_train_test=False presumably rebuilds the splits instead
# of loading previously saved ones -- verify against RNNDataPrep.
X_train, Y_train, X_test, Y_test = rnn_data_prep.get_rnn_data(
    load_train_test=False,
    sequence_length=5,
    split_ratio=2/3,
    save_data=save_data,
)

In [None]:
# Report the shape of each of the four arrays produced by the data-prep step.
for split_name, split in (
    ("X_train", X_train),
    ("Y_train", Y_train),
    ("X_test", X_test),
    ("Y_test", Y_test),
):
    print(f"{split_name} shape: ", split.shape)

## Running some checks...

### Check data imbalance

In [None]:
# Log the distinct label values and how often each occurs in Y_train and
# Y_test, to expose any class imbalance.
# Per the original note, the Y data holds a single binary label:
#   0: no walk
#   1: walk
logging.info("Checking for data imbalance...")
for split_name, labels in (("Y_train", Y_train), ("Y_test", Y_test)):
    # np.unique(..., return_counts=True) yields a (values, counts) pair.
    logging.info("%s: %s", split_name, np.unique(labels, return_counts=True))

### Check for invalid values (NaNs and Infs)

In [None]:
# num_nan_values_X_train = np.isnan(X_train).sum()
# num_nan_values_X_test = np.isnan(X_test).sum()
# num_nan_values_Y_train = np.isnan(Y_train).sum()
# num_nan_values_Y_test = np.isnan(Y_test).sum()
# print(f"Number of NaN values in train X set: {num_nan_values_X_train}")
# print(f"Number of NaN values in test X set: {num_nan_values_X_test}")
# print(f"Number of NaN values in train Y set: {num_nan_values_Y_train}")
# print(f"Number of NaN values in test Y set: {num_nan_values_Y_test}\n")

# num_inf_values_X_train = np.isinf(X_train).sum()
# num_inf_values_X_test = np.isinf(X_test).sum()
# num_inf_values_Y_train = np.isinf(Y_train).sum()
# num_inf_values_Y_test = np.isinf(Y_test).sum()
# print(f"Number of inf values in train X set: {num_inf_values_X_train}")
# print(f"Number of inf values in test X set: {num_inf_values_X_test}")
# print(f"Number of inf values in train Y set: {num_inf_values_Y_train}")
# print(f"Number of inf values in test Y set: {num_inf_values_Y_test}")

#### Replace NaNs and Infs with 0

In [None]:
# X_train[np.isnan(X_train)] = 0
# X_test[np.isnan(X_test)] = 0
# Y_train[np.isnan(Y_train)] = 0
# Y_test[np.isnan(Y_test)] = 0

# X_train[np.isinf(X_train)] = 0
# X_test[np.isinf(X_test)] = 0
# Y_train[np.isinf(Y_train)] = 0
# Y_test[np.isinf(Y_test)] = 0

# # handle_inf_na(X_train, X_test)

## RNN Model

### Train model

In [None]:
# Announce the start of training.
print("Training RNN Model...\n===============================\n")

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model dimensions. input_size is read from the third axis of X_train.
# NOTE(review): an earlier comment here suggested subtracting 1 "because we
# drop the target column" -- confirm that X_train's last axis already
# excludes the target.
input_size = X_train.shape[2]
# print(f"Input size: {input_size}\n\n")
hidden_size = 64
output_size = 2

# Training settings.
num_epochs = 2
batch_size = 512
learning_rate = 0.01
batch_first = True

# Train and evaluate in one call; the two returned objects are kept as the
# model and the test labels/probabilities used by later cells.
model, test_labels_and_probs = train_eval_model(
    X_train,
    Y_train,
    X_test,
    Y_test,
    input_size,
    hidden_size,
    output_size,
    num_epochs,
    batch_size,
    learning_rate,
    device,
    batch_first=batch_first,
)

### Model Evaluation, Visualization, and Analysis

In [None]:
# Collect what the plotting helper needs from the data-prep object.
test_indices, df = rnn_data_prep.test_indices, rnn_data_prep.df

# Call the plotting helper with the test labels/probabilities from training;
# it hands back two objects, kept here as plot_df and mean_df.
plot_df, mean_df = plot_predicted_probabilities(
    df,
    test_indices,
    test_labels_and_probs,
)


### Save model and config

In [None]:
# Model name has the form "<architecture>_<raw data id>_v<version>".
model_architecture = "rnn"
# Raw data id comes from the data-prep object ('ff-mw' in this case).
raw_data_id = rnn_data_prep.raw_data_id
version_number = 1
model_name = f"{model_architecture}_{raw_data_id}_v{version_number}"

# Details recorded alongside the model: the model's timestamp and the data
# paths exposed by the data-prep object.
rnn_timestamp = model.timestamp
interim_data_path = rnn_data_prep.interim_data_path
processed_data_path = rnn_data_prep.processed_data_path

# Configuration dictionary; the stored model name is prefixed with the
# timestamp.
# NOTE(review): logging_level / logging_format recorded here differ from the
# logging.basicConfig settings used at the top of this notebook (INFO and a
# '%(name)s - ...' format) -- confirm which values are intended.
config = create_config_dict(
    model_name=f"{rnn_timestamp}_{model_name}",
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    raw_data_path=None,
    interim_data_path=interim_data_path,
    processed_data_path=processed_data_path,
    logging_level='DEBUG',
    logging_format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
)

# Output directories are relative, so they resolve against the current
# working directory. Create them (and any missing parents) before saving.
model_dir = Path("models") / model_name
config_dir = Path("config") / model_name
for output_dir in (model_dir, config_dir):
    output_dir.mkdir(parents=True, exist_ok=True)

# Save the trained model and its configuration settings.
rnn_model.save_model_and_config(
    model,
    model_name,
    rnn_timestamp,
    interim_data_path,
    processed_data_path,
    config,
    model_dir,
    config_dir,
)

## Old/Extra/Misc. Code Below

In [None]:
# from sklearn.metrics import f1_score
# from torch.utils.data import DataLoader, Dataset
# import torch.nn as nn
# # Initialize a new model
# input_size = X_test.shape[2]  # Make sure this is correct
# hidden_size = 64
# output_size = 2
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# new_model = rnn_model.RNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size, batch_first=True).to(device)

# # Load the model
# model_path = "models/rnn_ff-mw_v1/20231020_1701_model_2c92c2793be07eaf3765665d6287ded4_971fce5d8c82c2d1bf8db68939c8162d.pt"
# state_dict = torch.load(model_path)
# new_model.load_state_dict(state_dict)
# # loaded_model = torch.load(model_path)
# new_model.eval()  # Set the model to evaluation mode

# def evaluate_f1(model, X_test, Y_test, batch_size, device):
#     test_dataset = rnn_model.WalkDataset(X_test, Y_test)
#     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
#     model.eval()

#     # Initialize running loss and accuracy counters
#     running_loss = 0.0
#     correct = 0
#     total = 0
    
#     y_true = []
#     y_pred = []

#     with torch.no_grad():
#         for i, (inputs, labels) in enumerate(test_loader):
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             # Using CrossEntropyLoss as the loss function
#             criterion = nn.CrossEntropyLoss()
#             loss = criterion(outputs, labels)  # Compute loss
#             running_loss += loss.item()  # Accumulate loss
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)  # Accumulate total number of samples
#             correct += (predicted == labels).sum().item()
#             y_true.extend(labels.cpu().numpy().tolist())
#             y_pred.extend(predicted.cpu().numpy().tolist())
            
#     # Calculate average loss and accuracy over all batches
#     test_loss = running_loss / len(test_loader)
#     test_acc = correct / total

#     print(
#         f"Test Error: \n Accuracy: {(100*test_acc):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
#     f1 = f1_score(y_true, y_pred)  # You can change the "average" parameter to suit your needs

#     return f1

# # Make sure you load your saved model into the variable `new_model`
# # Also, ensure X_test, Y_test, batch_size and device are set

# f1 = evaluate_f1(new_model, X_test, Y_test, batch_size, device)
# print(f"F1 Score: {f1}")