## QRNN Language Model



In [1]:
import torch
import torchtext
from torchtext import data

# WikiText-2 is already tokenized (tokens are separated by spaces, e.g. "@,@"),
# so the Field's default whitespace tokenizer is used. `tokenizer_language` is
# only consulted when a spaCy tokenizer is requested, so it is not passed here
# (it previously suggested language-aware tokenization that never happened).
# batch_first=True -> batches are laid out as [batch, seq_len].
TEXT = data.Field(
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)

Next, we load the WikiText-2 dataset and build the vocabulary from the training split, attaching pretrained GloVe vectors.

In [2]:
from torchtext.datasets import WikiText2
 
# Load the standard train/valid/test splits, then build the vocabulary from the
# training split only: tokens seen fewer than 5 times map to <unk>, and each
# vocabulary entry gets its 300-d GloVe vector (zeros if GloVe has no entry).
train, valid, test = WikiText2.splits(TEXT)

TEXT.build_vocab(
    train,
    vectors="glove.6B.300d",
    min_freq=5,
)

print("We have {} tokens in our vocabulary".format(len(TEXT.vocab)))

We have 20490 tokens in our vocabulary


## Iterator


In [3]:
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

BATCH_SIZE = 32
BPTT_LEN = 35  # sequence length of each truncated-backprop window

# One iterator per split; each batch exposes `text` (inputs) and `target`
# (next-token labels). repeat=False makes each pass over a split finite.
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN,
    device=device,
    repeat=False,
)

In [4]:
import warnings
# Silences every warning for the remainder of the kernel session, not just this cell.
warnings.filterwarnings('ignore')

import torch
from torchqrnn import QRNN

# Smoke test: run a random [seq_len, batch, hidden] tensor through a 2-layer
# QRNN on `device` and print the output / final-hidden-state shapes.
seq_len = 7
batch_size = 20
hidden_size = 256
size = (seq_len, batch_size, hidden_size)
X = torch.randn(size).to(device)

qrnn = QRNN(hidden_size, hidden_size, num_layers=2, dropout=0.4).to(device)
output, hidden = qrnn(X)

print(output.shape, hidden.shape)


torch.Size([7, 20, 256]) torch.Size([2, 20, 256])


In [9]:
# %load ../pytorch_lm/models/qrnn.py
import torch.nn as nn
from torchqrnn import QRNN


class QRNNLanguageModel(nn.Module):
    """Word-level language model: embedding -> 2-layer QRNN -> linear layer over the vocabulary.

    Args:
        vocab_size: number of embedding rows / output classes.
        embedding_dim: size of each token embedding (the QRNN input size).
        pad_idx: index of the padding token, passed to ``nn.Embedding(padding_idx=...)``.
        hidden_size: QRNN hidden size (input size of the output projection).
        cell_class: unused; kept so existing calls that pass it keep working.
        dropout: dropout rate passed to the QRNN.
        zoneout: zoneout rate passed to the QRNN.
    """

    def __init__(self, vocab_size, embedding_dim, pad_idx, hidden_size,
                 cell_class=nn.GRU, dropout=0.20, zoneout=.0):
        super().__init__()

        # Use the constructor argument rather than a notebook-level PAD_IDX,
        # which is only defined in a later cell and would break a fresh kernel.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        self.qrnn = QRNN(embedding_dim, hidden_size, num_layers=2, window=2,
                         dropout=dropout, zoneout=zoneout)

        self.fc = nn.Linear(hidden_size, vocab_size)
        # NOTE(review): never applied in forward(); dropout only happens inside
        # the QRNN. Kept so code referring to `model.dropout` keeps working.
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp, hidden=None):
        """Compute next-token logits for every position of a batch of sequences.

        Args:
            inp: LongTensor of token indices, shape [batch, seq_len]
                (seq_len may be 1 when sampling one token at a time).
            hidden: QRNN state returned by a previous call, or None to let the
                QRNN use its default initial state.

        Returns:
            (logits, hidden): logits of shape [batch, seq_len, vocab_size] and the
            QRNN's final state ([2, batch, hidden_size], one entry per layer).
        """
        emb = self.embedding(inp)                 # [batch, seq_len, embedding_dim]
        emb = emb.permute(1, 0, 2)                # QRNN is sequence-first: [seq_len, batch, embedding_dim]
        outputs, hidden = self.qrnn(emb, hidden)  # outputs: [seq_len, batch, hidden_size]
        outputs = outputs.permute(1, 0, 2)        # back to [batch, seq_len, hidden_size]

        out = self.fc(outputs)                    # [batch, seq_len, vocab_size]
        return out, hidden


Look up the special-token indices used by the language model (the model itself is created in the Training section).

In [10]:


# Vocabulary indices of the special tokens, looked up in the order listed.
PAD_IDX, UNK_IDX, EOS_IDX, SOS_IDX = (
    TEXT.vocab.stoi[token] for token in ("<pad>", "<unk>", "<eos>", "<sos>")
)


## Training 

In [11]:
import torch.optim as optim

HIDDEN_DIM = 640
# Vocabulary size and embedding width are read off the GloVe matrix that
# build_vocab attached to the vocabulary (one row per token).
vocab_size = TEXT.vocab.vectors.shape[0]
embedding_dim = TEXT.vocab.vectors.shape[1]

model_path = "/tmp/qrnn_lang_model.pt"

model = QRNNLanguageModel(
    vocab_size, embedding_dim, dropout=0.1, zoneout=0.1,
    hidden_size=HIDDEN_DIM, pad_idx=PAD_IDX)

# Initialise embeddings from GloVe, then give <unk> a random normal vector
# instead of the row copied from TEXT.vocab.vectors.
model.embedding.weight.data.copy_(TEXT.vocab.vectors)
model.embedding.weight.data[UNK_IDX] = torch.randn(embedding_dim)

optimizer = optim.SGD(model.parameters(), lr=1)
# torchtext pads each split with <pad> up to a multiple of the batch size, so
# the final batch can contain <pad> targets; exclude them from loss/perplexity.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Lower the LR when the loss given to scheduler.step() stops improving for 4 epochs.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, threshold=0.0001)

model = model.to(device)
criterion = criterion.to(device)


## An example of calculating the loss on a single training batch
batch = next(iter(train_iter))

# Inspection only (no backward pass), so skip building the autograd graph.
with torch.no_grad():
    preds, _ = model(batch.text)
    preds = preds.view(-1, preds.shape[-1])  # [batch * seq_len, vocab_size]
    trg = batch.target.view(-1)              # [batch * seq_len]
    example_loss = criterion(preds, trg)

example_loss

tensor(9.9291, device='cuda:0', grad_fn=<NllLossBackward>)

In [None]:
from pytorch_lm.training import training_cycle

# Training budget. These are the values actually passed to training_cycle; the
# previous N_EPOCHS / early_stopping_tolerance variables were defined but unused.
N_EPOCHS = 200
# Forwarded as training_cycle's `early_stopping_tolerance` (presumably the
# number of epochs without a new best model before stopping; see pytorch_lm).
EARLY_STOPPING_TOLERANCE = 7

training_cycle(
    epochs=N_EPOCHS,
    model=model, train_iter=train_iter, valid_iter=valid_iter,
    optimizer=optimizer, criterion=criterion, scheduler=lr_scheduler,
    model_path=model_path, early_stopping_tolerance=EARLY_STOPPING_TOLERANCE,
)

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=200.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 6.04470 Perp 421.87) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.73090 Perp 308.25) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.57417 Perp 263.53) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.45870 Perp 234.79) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.35891 Perp 212.49) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.29066 Perp 198.47) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.22954 Perp 186.71) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.19511 Perp 180.39) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.17795 Perp 177.32) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.13096 Perp 169.18) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.10032 Perp 164.07) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.05902 Perp 157.44) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.02204 Perp 151.72) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.01720 Perp 150.99) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.00438 Perp 149.07) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.95341 Perp 141.66) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.92158 Perp 137.22) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.90667 Perp 135.19) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.89505 Perp 133.63) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.88546 Perp 132.35) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.87038 Perp 130.37) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.86217 Perp 129.30) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.86178 Perp 129.25) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.85390 Perp 128.24) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.84434 Perp 127.02) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.84073 Perp 126.56) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.83316 Perp 125.61) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.83241 Perp 125.51) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.82954 Perp 125.15) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.82718 Perp 124.86) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.81939 Perp 123.89) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.78076 Perp 119.19) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77826 Perp 118.90) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77692 Perp 118.74) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77524 Perp 118.54) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77523 Perp 118.54) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77429 Perp 118.43) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77409 Perp 118.40) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77340 Perp 118.32) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77244 Perp 118.21) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77206 Perp 118.16) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77187 Perp 118.14) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.77135 Perp 118.08) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.74158 Perp 114.61) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73684 Perp 114.07) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73596 Perp 113.97) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73549 Perp 113.92) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73471 Perp 113.83) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73457 Perp 113.81) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73415 Perp 113.77) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73362 Perp 113.71) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73324 Perp 113.66) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73270 Perp 113.60) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73236 Perp 113.56) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73187 Perp 113.51) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73140 Perp 113.45) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73106 Perp 113.42) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73052 Perp 113.35) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.73020 Perp 113.32) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72986 Perp 113.28) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72943 Perp 113.23) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72917 Perp 113.20) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72875 Perp 113.15) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72836 Perp 113.11) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72802 Perp 113.07) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72768 Perp 113.03) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72754 Perp 113.02) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72707 Perp 112.96) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72685 Perp 112.94) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72658 Perp 112.91) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72606 Perp 112.85) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72593 Perp 112.84) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72559 Perp 112.80) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72506 Perp 112.74) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72494 Perp 112.72) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72478 Perp 112.71) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72444 Perp 112.67) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72403 Perp 112.62) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72381 Perp 112.60) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72378 Perp 112.59) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72351 Perp 112.56) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72309 Perp 112.52) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72298 Perp 112.50) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72273 Perp 112.48) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72263 Perp 112.46) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72235 Perp 112.43) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72207 Perp 112.40) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…




HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72177 Perp 112.37) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.72150 Perp 112.34) saved at /tmp/qrnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…

In [10]:
from pytorch_lm.training import evaluate

# Restore the best checkpoint written during training and switch off dropout
# before scoring the validation and test splits.
model.load_state_dict(torch.load(model_path))
model.eval()

valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)
test_loss, test_perplexity = evaluate(model, test_iter, criterion)

print("Valid loss      : {:.2f}".format(valid_loss))
print("Valid perplexity: {:.2f}".format(valid_perplexity), end="\n\n")

print("Test loss      : {:.2f}".format(test_loss))
print("Test perplexity: {:.2f}".format(test_perplexity))

Valid loss      : 4.71
Valid perplexity: 110.87

Test loss      : 4.65
Test perplexity: 104.42


Perplexities of other models on WikiText-2 are listed in [this blog post](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/).

A more complex recurrent network (one that uses a cache of hidden states) achieves a perplexity of 100. This very basic model reaches a test perplexity of about 104 without any hyperparameter optimization, so it performs reasonably well.

## Sampling

In [15]:
import numpy as np
import torch.nn.functional as F

def sample_sentence(init_token="<eos>", temperature=1, max_len=200):
    """Sample a token sequence from the notebook's ``model`` until it emits <eos>.

    Generation starts from ``init_token`` and feeds each sampled token back in
    one step at a time, carrying the QRNN state between steps.

    Args:
        init_token: token the sequence starts from (included in the result).
        temperature: softmax temperature (should be > 0); values below 1 sharpen
            the distribution, values above 1 flatten it.
        max_len: maximum length of the returned sequence, including
            ``init_token``; stops runaway samples that never produce <eos>.
            Pass None to disable the limit.

    Returns:
        List of token strings starting with ``init_token`` and ending with
        <eos>, unless ``max_len`` was reached first.
    """
    seq = [TEXT.vocab.stoi[init_token]]
    hidden = None
    # Sampling needs no gradients; without no_grad the autograd graph would keep
    # growing through `hidden` at every step.
    with torch.no_grad():
        while len(seq) == 1 or seq[-1] != EOS_IDX:
            if max_len is not None and len(seq) >= max_len:
                break
            inp = torch.LongTensor([[seq[-1]]]).to(device)  # [batch=1, seq_len=1]
            out, hidden = model(inp, hidden=hidden)

            # Sample the next token from the temperature-scaled distribution.
            probs = F.softmax(out.view(-1) / temperature, dim=0)
            next_tok_idx = torch.multinomial(probs, num_samples=1)

            # Store a plain int so the <eos> check and itos lookup work on ints, not tensors.
            seq.append(next_tok_idx.item())

    return [TEXT.vocab.itos[t] for t in seq]

In [16]:
# Print one sample starting from "the" for each temperature 0.5, 0.6, ..., 2.4.
for temperature in np.arange(0.5, 2.5, 0.10):
    header = "Sampling with temperature = {:.2f}".format(temperature)
    print("=" * 80, "\n" + header)
    print(" ".join(sample_sentence("the", temperature=temperature)))

Sampling with temperature = 0.50
the <unk> <unk> <unk> <unk> of the city of a number one of the south vietnamese army commander of the 3rd battalion , commanded by the following the following the following the following the following the attack on the following his second season with a " best " . <eos>
Sampling with temperature = 0.60
the <unk> club . <eos>
Sampling with temperature = 0.70
the time . <eos>
Sampling with temperature = 0.80
the shallow form of its name like a <unk> , but he will have been no music is used in a new york city of the county league cup in 1999 . he had 14 – 4 – 1 @,@ 000 @,@ 000 @,@ 000 @,@ 000 acres ( 8 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 500 @,@ 000 37 @,@ 000 @,@ 000 @,@ 000 mi ) was also operated on a time up an extratropical cyclone ( 3 @.@ 5 @.@ 2 – 1 @.@ 2 @.@ 5 @.@ 9 @.@ 6 @.@ 2 @.@ 8 @.@ 9 @.@ 2 @.@ 7 @.@ <unk> @.@ 2 @.@ 7 @.@ 9 @.@ 4 @.@ 6 @.@ 2 @.@ 2 @.@ 0 @.@ 2 @.@ 2 @.@ 0 @.@ 4 @.@ 0 @.@ 2 

We can observe that:

- with the hidden state carried between steps, the samples contain more "meaningful" text
- quotation marks get closed when the hidden state is used