## RNN with Adaptive Softmax

[Efficient softmax approximation for GPUs](https://arxiv.org/pdf/1609.04309v3.pdf)


In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchtext
from torchtext import data

TEXT = data.Field(
    tokenizer_language='en',
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)

Now, we can load our dataset

In [2]:
from torchtext.datasets import WikiText2
 
train, valid, test = WikiText2.splits(TEXT) 

TEXT.build_vocab(train, vectors="glove.6B.300d", min_freq=5)

print(f"We have {len(TEXT.vocab)} tokens in our vocabulary")

We have 20490 tokens in our vocabulary


## Iterator


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 32
BPTT_LEN = 35

train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN, # this is where we specify the sequence length
    device=device,
    repeat=False)

In [4]:
import warnings
warnings.filterwarnings('ignore')
from torch import nn
import torch

seq_len, batch_size, hidden_size = 7, 20, 256
size = (seq_len, batch_size, hidden_size)
vocab_size = len(TEXT.vocab)

X = torch.randn(seq_len, batch_size, hidden_size)

trg = torch.randint(0, vocab_size, (seq_len * batch_size,))

out = nn.AdaptiveLogSoftmaxWithLoss(
    hidden_size, vocab_size, cutoffs=[10, 100, 1000]
)

X = X.view(-1, X.shape[-1])

print(X.shape, trg.shape)
asm_out = out(X, trg)

asm_out.loss

torch.Size([140, 256]) torch.Size([140])


In [6]:
# %load ../pytorch_lm/models/qrnn.py
import torch.nn as nn
from torchqrnn import QRNN


class RNNAdapLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, pad_idx, hidden_size,
                 cell_class=nn.GRU, dropout=0.20, zoneout=.0):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_IDX)

        #self.qrnn = QRNN(embedding_dim, hidden_size, num_layers=2, window=2, dropout=dropout, zoneout=zoneout)
        self.rnn = cell_class(embedding_dim, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.AdaptiveLogSoftmaxWithLoss(
            hidden_size, vocab_size, 
            cutoffs=[500, 2000, 10000],
        )

    def forward(self, inp, target=None, hidden=None):
        """
        Inputs are supposed to be just one step (i.e. one letter)
        """

        # inputs = [batch_size, seqlen]
        emb = self.embedding(inp)
        # emb = [batch, seqlen, embedding_dim]
        outputs, hidden = self.rnn(emb, hidden)
        # outputs = [batch, seqlen, hidden_size]
        outputs = outputs.contiguous().view(-1, outputs.shape[-1])
        # outputs = [seqlen * batch, hidden_size]
        outputs = self.dropout(outputs)
        asm_out = self.out(outputs, target)

        return asm_out, hidden


Create the Language Model

In [12]:
# %load ../pytorch_lm/training.py
from tqdm.auto import tqdm
import torch
import math
import numpy as np


def display_lr(lr):
    pow = math.floor(math.log10(lr))
    return f"{lr*(10**(-pow)):.2f}e{pow}"

def train(model, iterator, optimizer, criterion, clip_norm=None, ncols=None):
    """
    Trains the model for one full epoch
    """
    epoch_loss = 0
    epoch_perplexity = 0

    model.train()

    epoch_bar = tqdm(iterator, total=len(iterator), ncols=ncols)

    i = 0
    for batch in epoch_bar:
        i += 1
        optimizer.zero_grad()
        text = batch.text
        trg = batch.target.view(-1)

        out, _ = model(text, target=trg)
        if type(out) is torch.Tensor:
            # Not adaptive softmax
            # Use criterion
            preds = out.view(-1, out.shape[-1])
            loss = criterion(preds, trg)
        else:
            # Adaptive softmax
            loss = out.loss
        
        loss.backward()

        total_norm = 0

        for p in model.parameters():
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2

        total_norm = total_norm ** (1. / 2)

        if clip_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

        optimizer.step()

        epoch_loss += loss.item()
        epoch_perplexity += np.exp(loss.item())

        lr = optimizer.param_groups[0]["lr"]

        epoch_bar.set_description(f"norm = {total_norm:.5f} loss = {epoch_loss / i:.4f} LR = {display_lr(lr)}")

    return epoch_loss / len(iterator), epoch_perplexity / len(iterator)


def evaluate(model, iterator, criterion=None):
    """
    Evaluates the model on the given iterator
    """
    epoch_loss = .0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.text
            trg = batch.target.view(-1)

            out, _ = model(text, target=trg)
            if type(out) is torch.Tensor:
                # Not adaptive softmax
                # Use criterion
                preds = out.view(-1, out.shape[-1])
                loss = criterion(preds, trg)
            else:
                # Adaptive softmax
                loss = out.loss
            
            epoch_loss += loss.item()

        loss = epoch_loss / len(iterator)

        perplexity = np.exp(loss)

    return loss, perplexity

def training_cycle(model, train_iter, valid_iter, epochs,
                   optimizer, criterion, scheduler, model_path,
                   early_stopping_tolerance=None, ncols=None):

    best_valid_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")

        train_loss, train_perplexity = train(model, train_iter, optimizer, criterion, ncols=ncols)
        valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)

        scheduler.step(valid_loss)

        desc = f' Train Loss: {train_loss:.5f} Perp: {train_perplexity:.3f}'
        desc += f' Val. Loss: {valid_loss:.5f} Perp: {valid_perplexity:.3f}'

        print(desc)
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), model_path)
            print(f"Best model so far (Loss {best_valid_loss:.5f} Perp {valid_perplexity:.2f}) saved at {model_path}")
        else:
            epochs_without_improvement += 1
            if early_stopping_tolerance and epochs_without_improvement >= early_stopping_tolerance:
                print("Early stopping")
                break


## Training 

In [8]:
import torch.optim as optim


PAD_IDX = TEXT.vocab.stoi["<pad>"]
UNK_IDX = TEXT.vocab.stoi["<unk>"]


HIDDEN_DIM = 512
vocab_size = TEXT.vocab.vectors.shape[0]
embedding_dim = TEXT.vocab.vectors.shape[1]

model_path = "/tmp/qrnn_lang_model.pt"

model = QRNNAdaSoftmaxLanguageModel(
    vocab_size, embedding_dim, dropout=0.1, zoneout=0.1,
    hidden_size=HIDDEN_DIM, pad_idx=PAD_IDX)

# Set weight for UNK to a random normal
model.embedding.weight.data.copy_(TEXT.vocab.vectors)
model.embedding.weight.data[UNK_IDX] = torch.randn(embedding_dim)


optimizer = optim.SGD(model.parameters(), lr=1)
criterion = nn.CrossEntropyLoss()

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, threshold=0.0001)

model = model.to(device)
criterion = criterion.to(device)



In [9]:

N_EPOCHS = 400

best_valid_loss = float('inf')

early_stopping_tolerance = 39



training_cycle(
    epochs=200, ncols=800,
    model=model, train_iter=train_iter, valid_iter=valid_iter, 
    optimizer=optimizer, criterion=criterion, scheduler=lr_scheduler,
    model_path=model_path, early_stopping_tolerance=7
)

Epoch 1


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 6.32825 Perp: 6919978294719.197 Val. Loss: 5.60865 Perp: 272.775
Best model so far (Loss 5.60865 Perp 272.78) saved at /tmp/qrnn_lang_model.pt
Epoch 2


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 5.56768 Perp: 264.571 Val. Loss: 5.35646 Perp: 211.972
Best model so far (Loss 5.35646 Perp 211.97) saved at /tmp/qrnn_lang_model.pt
Epoch 3


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 5.31514 Perp: 205.277 Val. Loss: 5.20139 Perp: 181.525
Best model so far (Loss 5.20139 Perp 181.52) saved at /tmp/qrnn_lang_model.pt
Epoch 4


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 5.13625 Perp: 171.551 Val. Loss: 5.09949 Perp: 163.938
Best model so far (Loss 5.09949 Perp 163.94) saved at /tmp/qrnn_lang_model.pt
Epoch 5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 5.00521 Perp: 150.454 Val. Loss: 5.02668 Perp: 152.427
Best model so far (Loss 5.02668 Perp 152.43) saved at /tmp/qrnn_lang_model.pt
Epoch 6


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.90317 Perp: 135.851 Val. Loss: 4.97945 Perp: 145.394
Best model so far (Loss 4.97945 Perp 145.39) saved at /tmp/qrnn_lang_model.pt
Epoch 7


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.81962 Perp: 124.962 Val. Loss: 4.94721 Perp: 140.782
Best model so far (Loss 4.94721 Perp 140.78) saved at /tmp/qrnn_lang_model.pt
Epoch 8


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.74713 Perp: 116.227 Val. Loss: 4.92355 Perp: 137.490
Best model so far (Loss 4.92355 Perp 137.49) saved at /tmp/qrnn_lang_model.pt
Epoch 9


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.68325 Perp: 109.031 Val. Loss: 4.90530 Perp: 135.004
Best model so far (Loss 4.90530 Perp 135.00) saved at /tmp/qrnn_lang_model.pt
Epoch 10


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.62666 Perp: 103.046 Val. Loss: 4.89835 Perp: 134.068
Best model so far (Loss 4.89835 Perp 134.07) saved at /tmp/qrnn_lang_model.pt
Epoch 11


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.57566 Perp: 97.924 Val. Loss: 4.89012 Perp: 132.969
Best model so far (Loss 4.89012 Perp 132.97) saved at /tmp/qrnn_lang_model.pt
Epoch 12


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.52963 Perp: 93.520 Val. Loss: 4.88810 Perp: 132.701
Best model so far (Loss 4.88810 Perp 132.70) saved at /tmp/qrnn_lang_model.pt
Epoch 13


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.48627 Perp: 89.552 Val. Loss: 4.88872 Perp: 132.784
Epoch 14


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.44684 Perp: 86.087 Val. Loss: 4.89181 Perp: 133.195
Epoch 15


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.40941 Perp: 82.931 Val. Loss: 4.89844 Perp: 134.081
Epoch 16


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.37495 Perp: 80.115 Val. Loss: 4.90873 Perp: 135.468
Epoch 17


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.34154 Perp: 77.487 Val. Loss: 4.91917 Perp: 136.889
Epoch 18


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.28343 Perp: 73.290 Val. Loss: 4.86071 Perp: 129.116
Best model so far (Loss 4.86071 Perp 129.12) saved at /tmp/qrnn_lang_model.pt
Epoch 19


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.25527 Perp: 71.163 Val. Loss: 4.86197 Perp: 129.278
Epoch 20


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.24347 Perp: 70.303 Val. Loss: 4.86389 Perp: 129.527
Epoch 21


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23450 Perp: 69.660 Val. Loss: 4.86601 Perp: 129.802
Epoch 22


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22834 Perp: 69.226 Val. Loss: 4.86787 Perp: 130.044
Epoch 23


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22212 Perp: 68.786 Val. Loss: 4.87052 Perp: 130.388
Epoch 24


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.24400 Perp: 70.388 Val. Loss: 4.84563 Perp: 127.184
Best model so far (Loss 4.84563 Perp 127.18) saved at /tmp/qrnn_lang_model.pt
Epoch 25


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23552 Perp: 69.764 Val. Loss: 4.84405 Perp: 126.983
Best model so far (Loss 4.84405 Perp 126.98) saved at /tmp/qrnn_lang_model.pt
Epoch 26


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23256 Perp: 69.547 Val. Loss: 4.84363 Perp: 126.929
Best model so far (Loss 4.84363 Perp 126.93) saved at /tmp/qrnn_lang_model.pt
Epoch 27


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23136 Perp: 69.459 Val. Loss: 4.84349 Perp: 126.911
Best model so far (Loss 4.84349 Perp 126.91) saved at /tmp/qrnn_lang_model.pt
Epoch 28


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23021 Perp: 69.372 Val. Loss: 4.84357 Perp: 126.921
Epoch 29


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22893 Perp: 69.283 Val. Loss: 4.84353 Perp: 126.917
Epoch 30


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22839 Perp: 69.246 Val. Loss: 4.84380 Perp: 126.951
Epoch 31


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22759 Perp: 69.186 Val. Loss: 4.84396 Perp: 126.971
Epoch 32


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22694 Perp: 69.135 Val. Loss: 4.84408 Perp: 126.987
Epoch 33


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23285 Perp: 69.575 Val. Loss: 4.84074 Perp: 126.563
Best model so far (Loss 4.84074 Perp 126.56) saved at /tmp/qrnn_lang_model.pt
Epoch 34


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23125 Perp: 69.458 Val. Loss: 4.83987 Perp: 126.453
Best model so far (Loss 4.83987 Perp 126.45) saved at /tmp/qrnn_lang_model.pt
Epoch 35


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.23024 Perp: 69.384 Val. Loss: 4.83948 Perp: 126.404
Best model so far (Loss 4.83948 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 36


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22984 Perp: 69.354 Val. Loss: 4.83923 Perp: 126.373
Best model so far (Loss 4.83923 Perp 126.37) saved at /tmp/qrnn_lang_model.pt
Epoch 37


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22938 Perp: 69.318 Val. Loss: 4.83911 Perp: 126.356
Best model so far (Loss 4.83911 Perp 126.36) saved at /tmp/qrnn_lang_model.pt
Epoch 38


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22954 Perp: 69.331 Val. Loss: 4.83899 Perp: 126.342
Best model so far (Loss 4.83899 Perp 126.34) saved at /tmp/qrnn_lang_model.pt
Epoch 39


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22878 Perp: 69.279 Val. Loss: 4.83891 Perp: 126.332
Best model so far (Loss 4.83891 Perp 126.33) saved at /tmp/qrnn_lang_model.pt
Epoch 40


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22928 Perp: 69.309 Val. Loss: 4.83885 Perp: 126.324
Best model so far (Loss 4.83885 Perp 126.32) saved at /tmp/qrnn_lang_model.pt
Epoch 41


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22919 Perp: 69.307 Val. Loss: 4.83877 Perp: 126.314
Best model so far (Loss 4.83877 Perp 126.31) saved at /tmp/qrnn_lang_model.pt
Epoch 42


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22886 Perp: 69.281 Val. Loss: 4.83873 Perp: 126.309
Best model so far (Loss 4.83873 Perp 126.31) saved at /tmp/qrnn_lang_model.pt
Epoch 43


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22871 Perp: 69.272 Val. Loss: 4.83869 Perp: 126.304
Best model so far (Loss 4.83869 Perp 126.30) saved at /tmp/qrnn_lang_model.pt
Epoch 44


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22844 Perp: 69.257 Val. Loss: 4.83866 Perp: 126.300
Best model so far (Loss 4.83866 Perp 126.30) saved at /tmp/qrnn_lang_model.pt
Epoch 45


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22830 Perp: 69.247 Val. Loss: 4.83864 Perp: 126.298
Best model so far (Loss 4.83864 Perp 126.30) saved at /tmp/qrnn_lang_model.pt
Epoch 46


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22827 Perp: 69.239 Val. Loss: 4.83862 Perp: 126.295
Best model so far (Loss 4.83862 Perp 126.30) saved at /tmp/qrnn_lang_model.pt
Epoch 47


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22885 Perp: 69.283 Val. Loss: 4.83861 Perp: 126.293
Best model so far (Loss 4.83861 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 48


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22903 Perp: 69.292 Val. Loss: 4.83861 Perp: 126.293
Best model so far (Loss 4.83861 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 49


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22880 Perp: 69.276 Val. Loss: 4.83860 Perp: 126.293
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 50


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22842 Perp: 69.256 Val. Loss: 4.83860 Perp: 126.293
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 51


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22897 Perp: 69.285 Val. Loss: 4.83860 Perp: 126.293
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 52


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22865 Perp: 69.269 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 53


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22913 Perp: 69.298 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 54


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22865 Perp: 69.270 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 55


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22876 Perp: 69.271 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 56


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22868 Perp: 69.266 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 57


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22869 Perp: 69.270 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 58


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22905 Perp: 69.288 Val. Loss: 4.83860 Perp: 126.292
Epoch 59


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22870 Perp: 69.270 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 60


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22856 Perp: 69.263 Val. Loss: 4.83860 Perp: 126.292
Epoch 61


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22849 Perp: 69.258 Val. Loss: 4.83860 Perp: 126.292
Best model so far (Loss 4.83860 Perp 126.29) saved at /tmp/qrnn_lang_model.pt
Epoch 62


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22874 Perp: 69.273 Val. Loss: 4.83860 Perp: 126.292
Epoch 63


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22902 Perp: 69.291 Val. Loss: 4.83860 Perp: 126.292
Epoch 64


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22853 Perp: 69.260 Val. Loss: 4.83860 Perp: 126.292
Epoch 65


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22849 Perp: 69.252 Val. Loss: 4.83860 Perp: 126.292
Epoch 66


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22904 Perp: 69.294 Val. Loss: 4.83860 Perp: 126.292
Epoch 67


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22859 Perp: 69.261 Val. Loss: 4.83860 Perp: 126.292
Epoch 68


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


 Train Loss: 4.22896 Perp: 69.295 Val. Loss: 4.83860 Perp: 126.292
Early stopping


In [13]:


model.load_state_dict(torch.load(model_path))
model.eval()

valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)
test_loss, test_perplexity = evaluate(model, test_iter, criterion)


print(f"Valid loss      : {valid_loss:.2f}")
print(f"Valid perplexity: {valid_perplexity:.2f}\n")

print(f"Test loss      : {test_loss:.2f}")
print(f"Test perplexity: {test_perplexity:.2f}")

Valid loss      : 4.84
Valid perplexity: 126.29

Test loss      : 4.78
Test perplexity: 119.49


We can check perplexities for other models in [this blogpost](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)

A more complex recurrent network (using a cache of hidden states) achieves a perplexity of 100. So this very basic model (without any hyperparameter optimization) seems fairly ok

## Sampling

In [None]:
import numpy as np
import torch.nn.functional as F

def sample_sentence(init_token="<eos>", temperature=1):

    seq = [TEXT.vocab.stoi[init_token]]
    hidden = None
    while len(seq) == 1 or seq[-1] != EOS_IDX:
        inp = torch.LongTensor([[seq[-1]]]).to(device)
        out, hidden = model(inp, hidden=hidden)

        """
        Sample from probabilities
        """
        probs = F.softmax(out.view(-1) / temperature, dim=0)
        next_tok_idx = torch.multinomial(probs, num_samples=1)
        
        seq.append(next_tok_idx)
        
    return [TEXT.vocab.itos[t] for t in seq]

In [None]:
for temperature in np.arange(0.5, 2.5, 0.10):
    print("="*80, f"\nSampling with temperature = {temperature:.2f}")
    
    print(" ".join(sample_sentence("the", temperature=temperature)))

We can observe that:

- with hidden states there are more "meaningful" stuff
- quotation marks are closed when using the hidden state