### Import libraries and classes

In [None]:
import os
import sys
import random
import itertools

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from tqdm import tqdm

# Make the project's 'src' directory importable.
# Assumes the notebook is run from a direct subdirectory of the project root.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_dir = os.path.join(project_root, 'src')
if src_dir not in sys.path:  # avoid piling up duplicate entries when the cell is re-run
    sys.path.append(src_dir)

from my_classes import EmbeddingDataset, EmbeddingClassifier

### Set the task

In [None]:
# Speaker subset to classify; expected values: 'all', 'Sheldon_Penny', 'Sheldon_Leonard'.
task = "all"

### Create dataset

In [None]:
# Speakers kept for each supported task; None means "keep every speaker".
TASK_SPEAKERS = {
    "all": None,
    "Sheldon_Leonard": ["Sheldon", "Leonard"],
    "Sheldon_Penny": ["Sheldon", "Penny"],
}

# Load embeddings from the pickle file.
# NOTE: pickle can execute code on load -- only read files produced by this project.
df = pd.read_pickle("../data/processed/sbert_mini_embeddings.pkl")
if task not in TASK_SPEAKERS:
    print("Task not recognized, using all data.")
    # Keep `task` truthful: later cells build model/metric paths from it.
    task = "all"
speakers = TASK_SPEAKERS[task]
if speakers is not None:
    df = df[df["Person"].isin(speakers)]

# Convert embeddings to a tensor (one row per utterance)
X = np.stack(df["Embedding"].values)
X = torch.tensor(X, dtype=torch.float32)

# Label encoding for the 'Person' column
label_encoder = LabelEncoder()
y = torch.tensor(label_encoder.fit_transform(df["Person"]), dtype=torch.long)
y_np = y.numpy()
all_classes = np.unique(y_np)

# Split into train (80%), validation (10%) and test (10%), stratified by speaker
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

train_dataset = EmbeddingDataset(X_train, y_train)
val_dataset = EmbeddingDataset(X_val, y_val)

# Create loaders for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

### Set seeds for reproducibility

In [None]:
# Seed every RNG in play -- Python, NumPy, PyTorch CPU, current and all CUDA devices.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.benchmark = False     # disable benchmarking for reproducibility
torch.backends.cudnn.deterministic = True  # disable non-deterministic optimizations

### Hyperparameter grid search (with early stopping)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define hyperparameters for the grid search
learning_rates = [1e-5, 5e-5, 1e-4, 5e-4]
dropout_rates = [0, 0.1, 0.2]
weight_decays = [0, 1e-4, 5e-4, 1e-3]
losses = ["cross_entropy", "weighted_cross_entropy"]

# Settings shared by every configuration
grid_seed = 42      # re-applied before each configuration (see loop below)
patience = 10       # early stopping: epochs without val-loss improvement before stopping
num_epochs = 1000   # upper bound; early stopping normally ends training much sooner

# 'balanced' class weights from the TRAINING labels only, so the validation/test
# label distribution does not leak into training. Computed once for all configs.
class_weights = torch.tensor(
    compute_class_weight(class_weight='balanced', classes=all_classes, y=np.asarray(y_train)),
    dtype=torch.float,
).to(device)


def compute_accuracy(logits, labels):
    """Return the fraction of rows of *logits* whose argmax equals *labels*."""
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).float().mean().item()


def _seed_everything(seed):
    """Reset Python, NumPy and PyTorch RNGs so each configuration starts identically."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _make_loss_fn(loss_name):
    """Return the criterion for *loss_name*; raise ValueError for unknown names."""
    if loss_name == "weighted_cross_entropy":
        return nn.CrossEntropyLoss(weight=class_weights)
    if loss_name == "cross_entropy":
        return nn.CrossEntropyLoss()
    raise ValueError(f"Unknown loss name: {loss_name!r}")


def _run_epoch(model, loader, loss_fn, desc, optimizer=None):
    """Make one pass over *loader*: train when *optimizer* is given, else evaluate.

    Returns (mean loss, accuracy). Both are weighted by the number of samples in
    each batch, so a smaller final batch does not skew the averages. (For the
    weighted loss this is a sample-count-weighted mean of the per-batch losses.)
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            inputs = batch["embedding"].to(device)
            labels = batch["label"].to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += compute_accuracy(outputs, labels) * batch_size
            total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples


for lr, dropout_rate, weight_decay, loss_name in itertools.product(learning_rates, dropout_rates, weight_decays, losses):
    # Re-seed so each configuration gets the same initial weights and batch order,
    # regardless of how many configurations were trained before it.
    _seed_everything(grid_seed)

    model = EmbeddingClassifier(input_dim=X_train.shape[1], num_classes=len(all_classes), dropout_rate=dropout_rate).to(device)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = _make_loss_fn(loss_name)

    # To monitor training and validation losses and accuracies
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    best_val_loss = float("inf")  # any finite validation loss improves on this
    no_improvement = 0

    best_model_path = f"../models/classifier_only/{task}/{loss_name}/best_classifier_{lr}_{dropout_rate}_{weight_decay}.pt"
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_acc = _run_epoch(model, train_loader, loss_fn, f"Epoch {epoch+1} [Train]", optimizer)
        avg_val_loss, avg_val_acc = _run_epoch(model, val_loader, loss_fn, f"Epoch {epoch+1} [Val]")

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(avg_train_acc)
        val_accuracies.append(avg_val_acc)

        # ---- Early Stopping: checkpoint on every improvement ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"💾 Best model saved (val loss: {best_val_loss:.4f})")
            no_improvement = 0
        else:
            no_improvement += 1
            print(f"⚠️ No improvement for {no_improvement} epoch(s)")

        # ---- Print Summary for the epoch ----
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"  🔹 Train Loss: {avg_train_loss:.4f} | Accuracy: {avg_train_acc:.4f}")
        print(f"  🔸 Val   Loss: {avg_val_loss:.4f} | Accuracy: {avg_val_acc:.4f}\n")

        # ---- Early Stopping Check ----
        if no_improvement >= patience:
            print(f"⏹️ Early stopping triggered after {epoch+1} epochs.")
            break

    # ---- Save Metrics (one row per completed epoch) ----
    metrics_df = pd.DataFrame({
        "epoch": list(range(1, len(train_losses) + 1)),
        "train_loss": train_losses,
        "val_loss": val_losses,
        "train_accuracy": train_accuracies,
        "val_accuracy": val_accuracies,
        "learning_rate": lr,
        "dropout_rate": dropout_rate,
        "weight_decay": weight_decay,
    })

    csv_path = f"../metrics/classifier_only/{task}/{loss_name}/metrics_{lr}_{dropout_rate}_{weight_decay}.csv"
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    metrics_df.to_csv(csv_path, index=False)