In [1]:
%matplotlib inline
import torch
from d2l import torch as d2l
from utils.trainer import training_loop
from models.rnns import RNNScratch, RNNLMScratch

## Device and hyperparameters

In [2]:
# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
# Run label and training budget.
model_name = "rnnlm-scratch"
n_epochs = 1500

# Early-stopping settings handed to training_loop.
# NOTE(review): patience is two thirds of n_epochs, so if it is counted in
# epochs (confirm in utils.trainer) early stopping can only trigger late in
# the run.
early_stopping_params = dict(
    metric='f1',     # quantity monitored for improvement
    mode='max',      # a higher F1 counts as better
    patience=1000,   # non-improving checks tolerated before stopping
    min_delta=0,     # minimum change that qualifies as an improvement
)

## Data

In [21]:
# Time Machine dataset from d2l; the object provides `vocab` and
# `get_dataloader`, both used by later cells.
# NOTE(review): num_steps is presumably the sequence length per example and
# num_train / num_val the split sizes — confirm against the d2l docs.
data = d2l.TimeMachine(
    batch_size=1024,
    num_steps=32,
    num_train=10240,
    num_val=5120,
)

## Model

In [22]:
# Recurrent core: input width equals the vocabulary size, 512 hidden units.
rnn = RNNScratch(
    num_inputs=len(data.vocab),
    num_hiddens=512,
).to(device)

# Language model built around the recurrent core, predicting over the same
# vocabulary.
# NOTE(review): the optimizer cell sets its own learning rate; confirm whether
# the `lr` passed here is actually used by the custom training loop.
model = RNNLMScratch(
    rnn,
    vocab_size=len(data.vocab),
    lr=1e-3,
).to(device)

In [23]:
# Cross-entropy loss with PyTorch's default settings.
loss_fn = torch.nn.CrossEntropyLoss()

# Adam over every parameter the model exposes.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
)

In [24]:
# Launch training. The arguments are grouped by role; the train loader is
# still created before the validation loader, as in the original call.
training_loop(
    # what is being trained and how
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    # data sources
    train_loader=data.get_dataloader(train=True),
    val_loader=data.get_dataloader(train=False),
    # run control
    n_epochs=n_epochs,
    early_stopping_params=early_stopping_params,
    device=device,
    # reporting: class labels are the vocabulary tokens; model_name labels the run
    class_names=data.vocab.idx_to_token,
    model_name=model_name,
)

Training on device cuda
Early stopping enabled: monitoring f1, mode=max, patience=1000
TensorBoard logs will be saved to runs/rnnlm-scratch_RNNLMScratch_20250415-000130


  for X in inputs:  # Shape of inputs: (num_steps, batch_size, num_inputs)


2025-04-15 00:01:31.905152 Epoch 1
  Training:   Loss: 3.1266, Accuracy: 0.2068
  Validation: Loss: 2.8873, Accuracy: 0.1956, F1: 0.0761
2025-04-15 00:01:34.585877 Epoch 5
  Training:   Loss: 2.8342, Accuracy: 0.1876
  Validation: Loss: 2.7939, Accuracy: 0.1945, F1: 0.0750
2025-04-15 00:01:38.081152 Epoch 10
  Training:   Loss: 2.5739, Accuracy: 0.2751
  Validation: Loss: 2.5237, Accuracy: 0.2855, F1: 0.1940
2025-04-15 00:01:41.403699 Epoch 15
  Training:   Loss: 2.3523, Accuracy: 0.3228
  Validation: Loss: 2.3194, Accuracy: 0.3354, F1: 0.2536
2025-04-15 00:01:44.706088 Epoch 20
  Training:   Loss: 2.2124, Accuracy: 0.3463
  Validation: Loss: 2.2225, Accuracy: 0.3465, F1: 0.2741
2025-04-15 00:01:47.757472 Epoch 25
  Training:   Loss: 2.0388, Accuracy: 0.3959
  Validation: Loss: 2.1135, Accuracy: 0.3772, F1: 0.3269
2025-04-15 00:01:51.130064 Epoch 30
  Training:   Loss: 1.8762, Accuracy: 0.4474
  Validation: Loss: 2.0640, Accuracy: 0.4032, F1: 0.3710
2025-04-15 00:01:54.378442 Epoch 35


KeyboardInterrupt: 