# Notebook for model testing

# Initialization

In [None]:
# CHANGE THIS PART TO TEST DIFFERENT MODELS!!!
# The parameters to change are the directory path of the model to test,
# the Activation Function type, the hidden scale and the number of
# temporal aware blocks (N_TAB).
# ****---------------------------------------****
actual_dir = "./results_tmp/CkTIM_TAB6_60epochs"
AF_TYPE = 'sin'
HIDDEN_SCALE = 1
N_TAB = 6
# ****---------------------------------------****
# The remaining flags are derived automatically from the NAME of the run
# directory, e.g. "CkkTIM_aug", "CkTIM" (no augmentation) ...
# Only the last path component is inspected, so a parent folder whose name
# happens to contain "Ck" or "aug" cannot silently flip a flag.
import os

_run_name = os.path.basename(os.path.normpath(actual_dir))

GEN_TYPE = 'convKan' if 'Ckk' in _run_name else 'conv'
# NOTE: 'Ck' is also a substring of 'Ckk', so CK is True for both "Ck..." and
# "Ckk..." run names.
CK = 'Ck' in _run_name
AUG = 'aug' in _run_name

In [None]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn

import dataloader.utils as utils
import dataloader.dataset as dataset
from dataloader.preprocessing import Preprocessing
from model.ctim import CTIM
from testing.test import test as test

from torch.utils.data import Dataset, DataLoader

# Hydra keeps a single global instance per process. Clear any state left over
# from a previous run of this cell before initializing again.
hydra_state = GlobalHydra.instance()
if not hydra_state.is_initialized():
    # Nothing to clean up: this is the first run in this kernel
    print("Hydra now initialized!")
else:
    hydra_state.clear()
    print("Hydra instance was initialized and has been cleared.")

# Point Hydra at the local configuration folder and compose the main config
initialize(version_base=None, config_path="./conf", job_name="notebook_nn_exam")
cfg = compose(config_name="config")

# Select the compute device: prefer CUDA (Nvidia GPU), then MPS (Apple
# Silicon), and fall back to the CPU when neither backend is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print("Your device is:", device)

# Read the model hyperparameters from the composed configuration
model_cfg = cfg.model
DROPOUT_RATE = model_cfg.dropout_rate
KERNEL_SIZE = model_cfg.kernel_size
N_FILTER = model_cfg.n_filter
NUM_FEATURES = model_cfg.num_features
OMEGA_0 = model_cfg.omega_0

# Build the network: architecture switches come from the settings cell,
# the remaining hyperparameters from the configuration file.
# NOTE(review): in_channels=39 presumably matches the number of input
# features per frame; num_classes=7 matches the 7 class labels used for the
# confusion matrix -- confirm against the dataset.
model = CTIM(
    in_channels=39,
    num_classes=7,
    num_features=NUM_FEATURES,
    n_filter=N_FILTER,
    kernel_size=KERNEL_SIZE,
    dropout_rate=DROPOUT_RATE,
    n_temporal_aware_block=N_TAB,
    ck=CK,
    generator_type=GEN_TYPE,
    omega_0=OMEGA_0,
    af_type=AF_TYPE,
    hidden_scale=HIDDEN_SCALE,
    device=device,
)

# File inside the run directory that is handed to the test routine
model_path = f"{actual_dir}/model.pt"

# Model Summary

In [None]:
# Count the parameters of the model; the result is reported later under the
# "trainable_parameters" key of the test metrics dictionary.
# NOTE(review): presumably `count_parameters` also prints a summary of the
# model -- confirm in dataloader/utils.
parameters = utils.count_parameters(model)

# TEST 

In [None]:
# The two evaluation setups differ only in the test folder and in how the
# dataset is asked to provide its samples; everything else is shared.
if AUG:
    test_path = './dataset/EMOVO_aug/test'
    dataset_options = {"feature_extract": True, "mfcc_np": False}
else:
    test_path = './dataset/EMOVO_split_MFCC/test'
    dataset_options = {"feature_extract": False, "mfcc_np": True}

test_dataset = dataset.EMOVO_Dataset(test_path, device=device, **dataset_options)

# Initialization of the dataloader (one sample per batch)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Loss Function
loss_fn = nn.CrossEntropyLoss()

# Initialization of the test metrics dictionary. Only the model name and the
# parameter count are known up front; the other lists start empty.
# NOTE(review): presumably `test` fills the empty lists -- confirm in
# testing/test.py.
test_metrics_dict = {
    "model": [model.model_name],
    "epoch": [],
    "loss": [],
    "accuracy": [],
    "recall": [],
    "precision": [],
    "f1_score": [],
    "inference_time_one_sample": [],
    "inference_time_tot": [],
    "trainable_parameters": [parameters],
}

# `test` returns the confusion matrix and the list of inference times.
# NOTE(review): presumably it loads the weights stored at `model_path` and
# records one inference time per sample -- confirm in testing/test.py.
cm, inference_time_list = test(
    model,
    model_path,
    test_dataloader,
    test_metrics_dict,
    loss_fn,
)
    
# Summarise the inference times: mean and total, rounded to three decimals
inference_times = torch.tensor(inference_time_list)
inference_time_one_sample = round(inference_times.mean().item(), 3)
inference_time_tot = round(inference_times.sum().item(), 3)

print("Inference Time for 1 sample (average value):")
print(inference_time_one_sample)
print("Inference Time for all samples of test set:")
print(inference_time_tot)

# Loss and Accuracy (Training and Validation)

In [None]:
# Metric files stored in the run directory
path_training_metrics = f"{actual_dir}/training_metrics.json"
path_validation_metrics = f"{actual_dir}/validation_metrics.json"

# Each file is loaded as a dictionary; the keys read below are 'epoch',
# 'loss' and 'accuracy'.
# NOTE(review): presumably each key maps to one value per epoch, alongside
# further metrics such as recall, precision and f1 -- confirm the exact
# schema in dataloader/utils.
data_training = utils.load_metrics(path_training_metrics)
data_validation = utils.load_metrics(path_validation_metrics)

# Series needed by the plots; the epoch axis is taken from the training file
epochs = data_training['epoch']
training_loss, training_accuracy = (
    data_training[key] for key in ('loss', 'accuracy')
)
validation_loss, validation_accuracy = (
    data_validation[key] for key in ('loss', 'accuracy')
)

# Plot section
# 1st Subplot => Loss in Training and Validation
# 2nd Subplot => Accuracy in Training and Validation
utils.plot_loss_acc(
    epochs,
    training_loss,
    validation_loss,
    training_accuracy,
    validation_accuracy,
)

# Confusion Matrix

In [None]:
# Plot the confusion matrix returned by the test run.
# `class_names` lists the 7 class labels passed to the plotting helper.
# NOTE(review): presumably these are the EMOVO emotion abbreviations, in the
# same order as the label indices used by the dataset -- confirm.
class_names = ['dis', 'gio', 'neu', 'pau', 'rab', 'sor', 'tri']
# NOTE(review): 'rocket' is presumably a seaborn colormap name -- confirm it
# is resolved by the helper in dataloader/utils.
utils.plot_confusion_matrix(cm, class_names, cmap='rocket')