# Simplified Jupyter notebook

In [6]:
import os
import sys
from importlib import reload
from pathlib import Path

# Disable Hugging Face tokenizers parallelism to avoid its fork warnings from
# transformers. Set early, before the project modules (which presumably load
# transformers) are imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Import torch directly instead of relying on utils.helper re-exporting it.
import torch

# The notebook's working directory is treated as the project root; later cells
# build data paths from it.
current_dir = os.getcwd()

# Make modules inside ./utils importable by bare name as well. The membership
# check keeps this cell idempotent: re-running it no longer appends duplicate
# entries to sys.path.
utils_dir = os.path.join(current_dir, "utils")
if utils_dir not in sys.path:
    sys.path.append(utils_dir)

# Reload project modules so edits to utils/*.py take effect without a kernel
# restart. utils.model is reloaded before utils.helper so that, if helper imports
# from model, it binds to the freshly reloaded definitions.
import utils.model
reload(utils.model)
from utils.model import NRMSModel

import utils.helper
reload(utils.helper)
from utils.helper import (
    HParams,
    prepare_training_data,
    train_and_evaluate,
    evaluate_model,
)


In [7]:
# Reproducibility: seed the RNGs the pipeline may draw from (data subsampling,
# negative sampling, weight init, dropout) so re-running gives the same numbers.
# NOTE(review): which RNGs prepare_training_data / NRMSModel actually use is not
# visible from this notebook — confirm all sources are covered.
import random

import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators on all devices, incl. CUDA

# Setting hyperparameters
hparams = HParams()
hparams.data_fraction = 0.01  # presumably the share of the split to load; small for fast iteration
hparams.batch_size = 32

# Early-stopping patience in epochs. With hparams.epochs == 1 (see the printed
# hyperparameters) early stopping cannot trigger; it matters once epochs > 1.
EARLY_STOPPING_PATIENCE = 3

# Preprocessing and Loading Data
PATH = Path(current_dir) / "data"
DATASPLIT = "ebnerd_small"

print("Loading data from ", PATH)
train_loader, val_loader, word_embeddings = prepare_training_data(
    hparams, PATH, DATASPLIT
)

# Initialize and train model
print("Training model with ", hparams)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the model is not moved to `device` here; presumably
# train_and_evaluate handles placement — confirm.
model = NRMSModel(hparams, word_embeddings)

model = train_and_evaluate(
    device,
    model,
    train_loader,
    val_loader,
    hparams,
    patience=EARLY_STOPPING_PATIENCE,
)

# Final evaluation on the validation split with the model returned by training.
metrics = evaluate_model(model, val_loader, device)
print("\nValidation Metrics:")
for metric_name, value in metrics.items():
    print(f"{metric_name}: {value:.4f}")

Loading data from  /Users/kevinmoore/Git Repositories/dtu-02456-deep-learning-ebnerd/our_implementation/data
Training model with  
 title_size: 30
 head_num: 20
 head_dim: 20
 attention_hidden_dim: 200
 dropout: 0.2
 batch_size: 32
 verbose: False
 data_fraction: 0.01
 sampling_nratio: 4
 history_size: 20
 epochs: 1
 learning_rate: 0.001
 transformer_model_name: facebookai/xlm-roberta-base


Epoch 1/1: 100%|██████████| 63/63 [00:52<00:00,  1.20batch/s]
Validation: 100%|██████████| 11/11 [00:01<00:00,  6.06batch/s]


Epoch 1/1, Train Loss: 1.6022, Val Loss: 1.6025, Val AUC: 0.5396, Improvement from Previous Epoch: 0.5396
Checkpoint saved to: checkpoints/nrms_checkpoint_1.pth

Validation Metrics:
auc: 0.5396
