# Setup

## Imports

In [None]:
# Make the project root (the parent of the current working directory)
# importable, then (re)load the project modules so that edits to them are
# picked up without restarting the kernel.
import importlib
import sys
import os
import torch  # used below for isinstance(..., torch.Tensor) checks

# Add the parent directory to the import path. Only append it once:
# re-running this cell would otherwise add a duplicate sys.path entry
# every time.
_PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

# Import utilities
from utils import setuputil, trainutil, inferutil
from classes.models import SimpleGeluEmbed

# Reload the necessary modules to ensure they are up-to-date. This must happen
# before the `from ... import name` lines below, so that those names are bound
# to the freshly reloaded definitions.
importlib.reload(setuputil)
importlib.reload(trainutil)
importlib.reload(inferutil)
importlib.reload(SimpleGeluEmbed)

# Import the required utils
from utils.setuputil import setup_config, display_config
from utils.trainutil import train_model
from utils.inferutil import infer_one, infer_full

# Import the SimpleGeluEmbedAvg class
from classes.models.SimpleGeluEmbed import SimpleGeluEmbedAvg

## Config Setup

In [None]:
# Experiment settings for the SimpleGeluEmbedAvg run. setup_config() turns
# this dict into the runtime config read below (the source of the "DEVICE",
# "train_loader", "vocab", "save_name", ... entries).
input_config = {
    # --- Environment and model identity ---
    "env": "gcp",
    "approach": "simple",
    "model_name": "SimpleGeluEmbedAvg",

    # --- System ---
    "device": "cuda:0",
    "threads": 14,
    "seed": 42,

    # --- Data ---
    "data_dir": "../../data/farzan",
    "data_ds": "manual",

    # --- Model parameters ---
    "rows": 100,
    "cols": 100,
    "tokens": 32,

    # --- Vocabulary ---
    "vocab_size": 150000,
    "vocab_space": True,
    "vocab_case": "both",

    # --- Training ---
    "batch": 40,
    "lr": 1e-1,
    "mu": 0.25,
    "epochs": 20,
    "patience": 2,
    "save_int": 0,
    "save_dir": "../models/",
}

# Build the full runtime config via setuputil and display it.
config = setup_config(input_config)
display_config(config)

# Expose frequently used config entries as plain module-level names.
# -- system
DEVICE, THREADS = config["DEVICE"], config["THREADS"]

# -- data loaders and vocabulary
train_loader, val_loader, test_loader = (
    config["train_loader"],
    config["val_loader"],
    config["test_loader"],
)
spreadsheet_vocab, spreadsheet_wvs = config["vocab"], config["wvs"]

# -- training hyper-parameters and checkpointing
batch_size, lr, mu = config["batch"], config["lr"], config["mu"]
epochs, patience = config["epochs"], config["patience"]
save_int, save_dir, save_name = (
    config["save_int"],
    config["save_dir"],
    config["save_name"],
)


In [None]:
def _inspect_first_item(loader, vocab, row=10, col=10):
    """Print a sanity-check summary of the first example in ``loader``.

    Prints the example's file path and the shapes of its ``x_tok``, ``y_tok``
    and ``x_masks`` entries. It then prints the token ids stored at cell
    ``[row, col]``, plus the decoded form of the ``x_tok`` ids.

    Args:
        loader: Indexable collection whose items are dicts with the keys
            ``'x_tok'``, ``'x_masks'``, ``'y_tok'`` and ``'file_paths'``.
        vocab: Object exposing ``decode(idx)`` for a single token id.
        row: Row index of the cell to display.
        col: Column index of the cell to display.

    Returns:
        The first item of ``loader``, so the caller can inspect it further.
    """
    item = loader[0]
    x_tok = item['x_tok']
    x_masks = item['x_masks']
    y_tok = item['y_tok']
    filepath = item['file_paths']

    print(f"File: {filepath}\n")

    # x_tok / y_tok are indexed as [row, col, :] below, so they are at least
    # 3-D; presumably (rows, cols, tokens), i.e. (100, 100, 32) with the
    # config above -- TODO confirm against setuputil.
    print("Shapes:")
    print(f"x_tok: {x_tok.shape}")
    print(f"y_tok: {y_tok.shape}")
    print(f"x_masks: {x_masks.shape if isinstance(x_masks, torch.Tensor) else len(x_masks)}\n")

    # Token ids stored in the selected cell.
    x_cell = x_tok[row, col, :]
    y_cell = y_tok[row, col, :]

    print(f"Values at position [{row},{col}]:")
    print(f"\nx_tok: {x_cell.tolist()}")
    print(f"\nx_tok decoded: {[vocab.decode(idx) for idx in x_cell.tolist()]}")
    print(f"\ny_tok: {y_cell.tolist()}")
    return item


# Inspect the first training example, then the first validation example.
# The last result is bound to a name rather than left as a bare expression,
# so the notebook does not echo the whole item dict.
_inspect_first_item(train_loader, spreadsheet_vocab)
first_item = _inspect_first_item(val_loader, spreadsheet_vocab)

# Model Training

## Define the Model

In [None]:
# Instantiate SimpleGeluEmbedAvg from the `wvs` entry of the config
# (spreadsheet_wvs), move it onto the configured device, and show its layers.
untrained_model = SimpleGeluEmbedAvg(spreadsheet_wvs)
untrained_model = untrained_model.to(DEVICE)
print(untrained_model)

## Train the Model

In [None]:
# Train on the train split, monitoring the val split. `patience` and
# `max_epochs` bound the run (presumably early stopping -- see
# utils.trainutil). Checkpoints go under save_dir/save_name at interval
# save_int. The returned model is used by the evaluation cells below.
trained_model = train_model(
    model=untrained_model,
    config=config,
    DEVICE=DEVICE,
    # data
    train_data=train_loader,
    val_data=val_loader,
    batch_size=batch_size,
    # hyper-parameters
    lr=lr,
    mu=mu,
    max_epochs=epochs,
    patience=patience,
    # checkpointing
    save_int=save_int,
    save_dir=save_dir,
    save_name=save_name,
)

# Evaluation

In [None]:
# Settings shared by the inference cells below:
#   thresh   -> passed as `threshold=` to infer_one / infer_full
#   loc      -> passed as `loc=` to infer_one
#   cond     -> passed as `condition=` to infer_one
#   disp_max -> passed as `disp_max=` to infer_one
thresh, loc, cond, disp_max = 0.91, 0, '>', True

## Single Example

In [None]:
# Single-example inference (example selected via `loc`) on each split in turn:
# train, then val, then test. The final call is left as a bare expression so
# the notebook still displays its return value.
infer_one(
    trained_model,
    train_loader,
    device=DEVICE,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
)
infer_one(
    trained_model,
    val_loader,
    device=DEVICE,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
)
infer_one(
    trained_model,
    test_loader,
    device=DEVICE,
    loc=loc,
    threshold=thresh,
    condition=cond,
    disp_max=disp_max,
)

## All Examples

In [None]:
# Full-split evaluation: every example in the training loader.
infer_full(
    trained_model,
    train_loader,
    device=DEVICE,
    threshold=thresh,
    batch_size=batch_size,
)

In [None]:
# Full-split evaluation: every example in the validation loader.
infer_full(
    trained_model,
    val_loader,
    device=DEVICE,
    threshold=thresh,
    batch_size=batch_size,
)

In [None]:
# Full-split evaluation: every example in the test loader.
infer_full(
    trained_model,
    test_loader,
    device=DEVICE,
    threshold=thresh,
    batch_size=batch_size,
)