# Setup

## Imports

In [None]:
# Import importlib (to reload project modules) plus sys/os (to extend the import path)
import importlib
import sys
import os
import torch

# Make the parent of the notebook's working directory importable so the project-local
# `utils` and `classes` packages resolve. NOTE(review): this assumes the kernel's cwd is
# one level below the folder containing those packages — confirm for your setup.
# Only add the entry once: a bare append adds a duplicate sys.path entry every time
# this cell is re-run.
PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

# Import the project modules themselves (not just names from them) so they can be reloaded
from utils import setuputil, trainutil, inferutil
from classes.models import Rnn2d

# Reload so edits made to these files since the kernel started are picked up without a
# restart. importlib.reload does not reload the modules *they* import, and names bound by
# an earlier `from ... import` keep pointing at the old objects — which is why the
# `from ... import` lines below must come after the reloads.
importlib.reload(setuputil)
importlib.reload(trainutil)
importlib.reload(inferutil)
importlib.reload(Rnn2d)

# Bind the specific helpers used by this notebook from the freshly reloaded modules
from utils.setuputil import setup_simple_config, display_simple_config
from utils.trainutil import train_model
from utils.inferutil import infer_one, infer_full

# Import the Rnn2dSquare model class
from classes.models.Rnn2d import Rnn2dSquare

## Config

In [None]:
# Raw settings for this Rnn2dSquare run. setup_simple_config expands them into the
# working `config` (data loaders, vocab, word vectors, ...) that later cells unpack.
setup_config = dict(
    # Environment and model identification
    env="gcp",
    approach="rnn",
    model_name="Rnn2dSquare",

    # System
    device="cuda:0",
    threads=12,
    seed=42,

    # Data location and dataset variant
    data_dir="../../data/farzan",
    data_ds="manual",

    # Model parameters
    rows=100,
    cols=100,
    tokens=32,

    # RNN architecture
    hidden_dim=100,        # size of the hidden state vector
    rnn_layers=2,          # number of RNN layers
    dropout_rate=0.05,     # dropout rate used for regularization
    nonlinearity="relu",   # RNN nonlinearity (e.g. "relu" or "tanh")

    # Vocabulary
    vocab_size=150000,
    vocab_space=True,
    vocab_case="both",

    # Training and saving
    batch=10,
    lr=7e-5,
    mu=0.25,
    epochs=20,
    patience=3,
    save_int=5,
    save_dir="../models/",
)

# Expand the raw settings into the working config, then show it
config = setup_simple_config(setup_config)
display_simple_config(config)

## Local Variables

In [None]:
# Unpack the expanded config into notebook-level names used by the cells below.

# --- System ---
DEVICE = config["DEVICE"]
# NOTE(review): THREADS is not used by any later cell in this notebook; confirm whether
# setup_simple_config already applies it.
THREADS = config["THREADS"]

# --- Data loaders, vocabulary and word vectors built by setup_simple_config ---
train_loader = config["train_loader"]
val_loader = config["val_loader"]
test_loader = config["test_loader"]
spreadsheet_vocab = config["vocab"]
spreadsheet_wvs = config["wvs"]  # passed to the model as its embedding matrix

# --- RNN hyperparameters ---
hidden_dim = config["hidden_dim"]
rnn_layers = config["rnn_layers"]
dropout_rate = config["dropout_rate"]
nonlinearity = config["nonlinearity"]

# --- Training and model-saving settings ---
batch_size = config["batch"]
lr = config["lr"]
mu = config["mu"]
epochs = config["epochs"]
patience = config["patience"]
save_int = config["save_int"]
save_dir = config["save_dir"]
# "save_name" is not in setup_config; presumably derived by setup_simple_config
save_name = config["save_name"]

# Build the Rnn2dSquare model from the hyperparameters unpacked above, using the
# config's word-vector matrix (`wvs`) as its embedding matrix, and move it to DEVICE.
untrained_model = Rnn2dSquare(
    embedding_matrix=spreadsheet_wvs,
    hidden_state_dim=hidden_dim,
    rnn_layers=rnn_layers,
    nonlinearity=nonlinearity,
    dropout_rate=dropout_rate,
).to(DEVICE)

# Show the model's structure as a quick sanity check of the construction
print(untrained_model)

## Train the Model

In [None]:
# Train on the training loader with the validation loader also supplied (`patience`
# presumably drives early stopping — confirm in utils.trainutil). Saving-related
# settings and the full config are forwarded unchanged. The returned model is the
# one the evaluation cells below run inference with.
trained_model = train_model(
    model=untrained_model,
    config=config,
    DEVICE=DEVICE,
    # data
    train_data=train_loader,
    val_data=val_loader,
    batch_size=batch_size,
    # optimisation schedule
    lr=lr,
    mu=mu,
    max_epochs=epochs,
    patience=patience,
    # saving
    save_int=save_int,
    save_dir=save_dir,
    save_name=save_name,
)

# Evaluation

In [None]:
# Evaluation settings shared by the inference cells below:
#   thresh   -> `threshold` argument of infer_one and infer_full
#   loc      -> `loc` argument of infer_one
#   cond     -> `condition` argument of infer_one (presumably the comparison applied
#               against thresh — confirm in utils.inferutil)
#   disp_max -> `disp_max` argument of infer_one
thresh, loc, cond, disp_max = 0.5, 0, '>', True

## Single Example

In [None]:
# Inspect a single training example using the shared evaluation settings
infer_one(
    trained_model,
    train_loader,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
    device=DEVICE,
)

In [None]:
# Inspect a single validation example using the shared evaluation settings
infer_one(
    trained_model,
    val_loader,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
    device=DEVICE,
)

In [None]:
# Inspect a single test example using the shared evaluation settings
infer_one(
    trained_model,
    test_loader,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
    device=DEVICE,
)

## All Examples

In [None]:
# Run inference over every file in the training split
infer_full(
    trained_model,
    train_loader,
    batch_size=batch_size,
    threshold=thresh,
    device=DEVICE,
)

In [None]:
# Run inference over every file in the validation split
infer_full(
    trained_model,
    val_loader,
    batch_size=batch_size,
    threshold=thresh,
    device=DEVICE,
)

In [None]:
# Run inference over every file in the test split
infer_full(
    trained_model,
    test_loader,
    batch_size=batch_size,
    threshold=thresh,
    device=DEVICE,
)