In [None]:
!pip install transformers

## Import libraries

In [None]:
import pickle
import numpy as np
import torch
import fine_tuning_functions
import test_models_functions

# Select the compute device: use the GPU when CUDA is available,
# otherwise fall back to the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load data

In [None]:
# Read the pickled dataset from the working directory
# (presumably the Dutch / English / French data, judging by the file name).
# NOTE: unpickling can execute arbitrary code — only load this file if it comes from a trusted source.
with open("data_nl_english_french_23300", mode="rb") as dataset_file:
    data_english_french_nl = pickle.load(dataset_file)

## Define model and hyperparameters

In [None]:
# Name of the model to evaluate and the folder name derived from it
model_name = "mBERT_all"
folder_name = f"TPE_{model_name}"

# Hyperparameters; together with the names above they are passed to the
# loss-history and model helpers used later in this notebook
lr = 5e-4
dropout = 0.4
epochs = 60

## Split dataset into train, validation and test sets

In [None]:
# Partition the loaded dataset into train / validation / test texts and labels.
# model_name is forwarded too; how the helper uses it is not visible from this notebook.
(
    train_text, train_labels,
    val_text, val_labels,
    test_text, test_labels,
) = fine_tuning_functions.split_dataset(data_english_french_nl, model_name)

## Plot losses

In [None]:
# Fetch the epochs together with the training and validation losses
# (presumably recorded during the run identified by these hyperparameters and folder — TODO confirm)
loss_history = test_models_functions.get_epochs_train_val_losses(
    lr, dropout, epochs, folder_name
)
all_epochs, train_losses, valid_losses = loss_history

In [None]:
# Plot the training and validation losses against the epochs
test_models_functions.plot_train_val_losses(
    all_epochs,
    train_losses,
    valid_losses,
)

In [None]:
# Show the position of the lowest validation loss.
# NOTE(review): np.argmin returns a 0-based index — confirm how it maps onto the values in all_epochs.
best_epoch_index = np.argmin(valid_losses)
print(best_epoch_index)

## Metrics on test set

In [None]:
# Obtain the model for this configuration
# (presumably restoring the fine-tuned weights stored under folder_name — TODO confirm)
model = test_models_functions.get_model(
    model_name, folder_name, lr, dropout, epochs
)

In [None]:
# A model whose name contains "test" or "all" is treated as multilingual
is_multilingual = any(tag in model_name for tag in ("test", "all"))

if not is_multilingual:
    # Unilingual case: a single set of test tensors
    # (sequences, masks and labels, going by the variable names)
    test_seq, test_mask, test_y = test_models_functions.get_test_tensors(model_name, test_text, test_labels)

    # Print accuracy, then precision / recall / F1
    unilingual_args = (model, test_seq, test_mask, test_y, device, model_name)
    test_models_functions.print_accuracy_test_unilingual(*unilingual_args)
    test_models_functions.print_precision_recall_f1_test_unilingual(*unilingual_args)

else:
    # Multilingual case: one sequence/mask pair per language
    # (nl, en and fr, going by the variable names) plus the labels
    (
        test_seq_nl, test_mask_nl,
        test_seq_en, test_mask_en,
        test_seq_fr, test_mask_fr,
        test_y,
    ) = test_models_functions.get_test_tensors(model_name, test_text, test_labels)

    # Print accuracy, then precision / recall / F1
    multilingual_args = (
        model,
        test_seq_nl, test_mask_nl,
        test_seq_en, test_mask_en,
        test_seq_fr, test_mask_fr,
        test_y, device, model_name,
    )
    test_models_functions.print_accuracy_test_multilingual(*multilingual_args)
    test_models_functions.print_precision_recall_f1_test_multilingual(*multilingual_args)