## Imports

In [None]:
from src.model_training import RNNLanguageModel, LSTMLanguageModel, train, plot_and_save_training_metrics
import numpy as np
import os
import pickle
import random
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Setup

In [2]:
# Reproducibility: a single seed shared by every random number generator used here
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# cuDNN configuration: the backend is switched off entirely, the autotuner
# (benchmark mode) is disabled and deterministic algorithms are requested.
# NOTE(review): turning cuDNN off altogether presumably trades training speed
# for reproducibility -- confirm this is intentional.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Device agnostic code: use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load Data

In [None]:
def _load_pickle(filename):
    """Unpickle and return the object stored in the file ``data/<filename>``.

    The path is resolved relative to the current working directory.
    Unpickling can execute arbitrary code, so only trusted files should be
    loaded this way.
    """
    path = os.path.normpath(os.path.join("data", filename))
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Word-level tokenisation: numericalised documents and the training vocabulary
word_tokenised_numericalised_docs = _load_pickle("word_tokenisation_reuters_data.pkl")
word_tokenisation_train_vocab = _load_pickle("word_tokenisation_reuters_train_vocab.pkl")

# Subword-level tokenisation: numericalised documents and the training vocabulary
subword_tokenised_numericalised_docs = _load_pickle("subword_tokenisation_reuters_data.pkl")
subword_tokenisation_train_vocab = _load_pickle("subword_tokenisation_reuters_train_vocab.pkl")

## Set Model Hyperparameters

In [4]:
# Architecture settings shared by the RNN and LSTM language models
EMBEDDING_SIZE = 128    # forwarded to the models as `embed_size`
HIDDEN_SIZE = 256       # forwarded to the models as `hidden_size`
NUM_LAYERS = 2          # forwarded to the models as `num_layers`
DROPOUT = 0.0           # forwarded to the models as `dropout`

# Tokenisation switch: True -> word tokenisation, False -> subword tokenisation
USE_WORD_TOKENISATION = True

In [5]:
# Pick the (documents, vocabulary) pair matching the chosen tokenisation scheme
converted_tokenised_docs, train_vocab = (
    (word_tokenised_numericalised_docs, word_tokenisation_train_vocab)
    if USE_WORD_TOKENISATION
    else (subword_tokenised_numericalised_docs, subword_tokenisation_train_vocab)
)

## Set Training Hyperparameters

In [None]:
# Optimisation settings forwarded to `train` for both models
SEQUENCE_LENGTH = 32                # passed as `seq_len`; also embedded in the save names
BATCH_SIZE = 32                     # passed as `batch_size`
NUM_EPOCHS = 10                     # passed as `num_epochs`
LEARNING_RATE = 1e-3                # passed as `lr`
GRADIENT_CLIPPING_MAX_NORM = 1.0    # passed as `grad_clipping_max_norm`

# Early stopping: maximum number of consecutive epochs in which the validation
# loss may fail to improve before training is stopped early
PATIENCE = 5

## RNN

In [7]:
# Label for this run, encoding the tokenisation scheme and the context length
rnn_save_name = f"rnn_word_tokens_{USE_WORD_TOKENISATION}_context_{SEQUENCE_LENGTH}"

# Initialise the RNN language model from the shared hyperparameters ...
rnn = RNNLanguageModel(
    vocab_size=len(train_vocab),
    pad_idx=train_vocab["<pad>"],   # index of the padding token in the vocabulary
    embed_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)
# ... and place it on the selected device
rnn = rnn.to(device)

In [None]:
# Train the RNN. `train` returns three values, bound here as the trained model
# and the train / validation loss histories (presumably one entry per epoch --
# TODO confirm against src.model_training).
(rnn_trained, rnn_train_loss_history, rnn_val_loss_history) = train(
    # model and data
    model=rnn,
    converted_tokenised_docs=converted_tokenised_docs,
    train_vocab=train_vocab,
    device=device,
    # batching
    seq_len=SEQUENCE_LENGTH,
    batch_size=BATCH_SIZE,
    # optimisation schedule
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    grad_clipping_max_norm=GRADIENT_CLIPPING_MAX_NORM,
    patience=PATIENCE,
    # run label
    save_name=rnn_save_name,
)

Epoch 0/9
----------


100%|██████████| 7603/7603 [36:50<00:00,  3.44it/s]  


Train Loss: 3.6771


100%|██████████| 824/824 [01:36<00:00,  8.51it/s]


Val Loss: 8.1882
Best val loss has improved. Counter: 0 | Best val loss: 8.188167695570918

Epoch 1/9
----------


100%|██████████| 7603/7603 [1:00:01<00:00,  2.11it/s]   


Train Loss: 1.5684


100%|██████████| 824/824 [01:12<00:00, 11.30it/s]


Val Loss: 9.6689
Best val loss did not improve. Counter: 1 | Best val loss: 8.188167695570918

Epoch 2/9
----------


 62%|██████▏   | 4724/7603 [42:32<16:17,  2.94it/s]     

In [None]:
# View and save the RNN loss curves under the RNN run label
plot_and_save_training_metrics(
    train_loss_history=rnn_train_loss_history,
    val_loss_history=rnn_val_loss_history,
    save_name=rnn_save_name,
)

## LSTM

In [None]:
# Label for this run, encoding the tokenisation scheme and the context length
lstm_save_name = f"lstm_word_tokens_{USE_WORD_TOKENISATION}_context_{SEQUENCE_LENGTH}"

# Initialise the LSTM language model from the shared hyperparameters ...
lstm = LSTMLanguageModel(
    vocab_size=len(train_vocab),
    pad_idx=train_vocab["<pad>"],   # index of the padding token in the vocabulary
    embed_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)
# ... and place it on the selected device
lstm = lstm.to(device)

In [None]:
# Train the LSTM. `train` returns three values, bound here as the trained model
# and the train / validation loss histories (presumably one entry per epoch --
# TODO confirm against src.model_training).
(lstm_trained, lstm_train_loss_history, lstm_val_loss_history) = train(
    # model and data
    model=lstm,
    converted_tokenised_docs=converted_tokenised_docs,
    train_vocab=train_vocab,
    device=device,
    # batching
    seq_len=SEQUENCE_LENGTH,
    batch_size=BATCH_SIZE,
    # optimisation schedule
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    grad_clipping_max_norm=GRADIENT_CLIPPING_MAX_NORM,
    patience=PATIENCE,
    # run label
    save_name=lstm_save_name,
)

In [None]:
# View and save the LSTM loss curves under the LSTM run label
plot_and_save_training_metrics(
    train_loss_history=lstm_train_loss_history,
    val_loss_history=lstm_val_loss_history,
    save_name=lstm_save_name,
)