## RNN Language Model

Sources

[1] 
[2] https://mlexplained.com/2018/02/15/language-modeling-tutorial-in-torchtext-practical-torchtext-part-2/



We have the following datasets available for this task:

- Penn Trebank (originally created for POS tagging)
- WikiText

Before loading our dataset, define how it will be tokenized and preprocessed. To do this, `torchtext` uses `data.Field`. By default, it uses [`spaCy`](https://spacy.io/api/tokenizer) tokenization.

Also, we set an `init_token` and `eos_token` for the begin and end of sentence characters.

In [1]:
import torch
import torchtext
from torchtext import data

TEXT = data.Field(
    tokenizer_language='en',
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)

Now, we can load our dataset

In [2]:
from torchtext.datasets import WikiText2
 
train, valid, test = WikiText2.splits(TEXT) 

TEXT.build_vocab(train, vectors="glove.6B.300d")

print(f"We have {len(TEXT.vocab)} tokens in our vocabulary")

We have 28914 tokens in our vocabulary


## Iterator


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 32
BPTT_LEN = 35

train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN, # this is where we specify the sequence length
    device=device,
    repeat=False)

In [4]:
import warnings
warnings.filterwarnings('ignore')

import torch
from torchqrnn import QRNN

seq_len, batch_size, hidden_size = 7, 20, 256
size = (seq_len, batch_size, hidden_size)
X = torch.randn(size).to(device)
                
qrnn = QRNN(hidden_size, hidden_size, num_layers=2, dropout=0.4)
qrnn = qrnn.to(device)
output, hidden = qrnn(X)

print(output.size(), hidden.size())


torch.Size([7, 20, 256]) torch.Size([2, 20, 256])


In [5]:
import torch.nn as nn
from torchqrnn import QRNN


class QRNNLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, pad_idx, hidden_size,
                 cell_class=nn.GRU, dropout=0.20, zoneout=.0):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_IDX)
        
        self.qrnn = QRNN(embedding_dim, hidden_size, num_layers=2, window=2, dropout=dropout, zoneout=zoneout)
        #self.rnn = cell_class(embedding_dim, hidden_size, batch_first=True)
        
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, inp, hidden=None):
        """
        Inputs are supposed to be just one step (i.e. one letter)
        """
        
        # inputs = [batch_size, seqlen]
        emb = self.embedding(inp)
        # emb = [batch, seqlen, embedding_dim]
        emb = emb.permute(1, 0, 2)
        # emb = [seqlen, batch, embedding_dim]
        outputs, hidden = self.qrnn(emb, hidden)
        # outputs = [seqlen, batch, hidden_size]
        outputs = outputs.permute(1, 0, 2)
        # outputs = [batch, seqlen, hidden_size]
        
        # hidden = [batch, hidden_dim]
        
        out = self.fc(outputs)
        # out = [batch, vocab size]

        return out, hidden

Create the Language Model

In [6]:


PAD_IDX = TEXT.vocab.stoi["<pad>"]
UNK_IDX = TEXT.vocab.stoi["<unk>"]
EOS_IDX = TEXT.vocab.stoi["<eos>"]
SOS_IDX = TEXT.vocab.stoi["<sos>"]


## Training 

In [7]:
import torch.optim as optim

HIDDEN_DIM = 640
vocab_size = TEXT.vocab.vectors.shape[0]
embedding_dim = TEXT.vocab.vectors.shape[1]


model = QRNNLanguageModel(
    vocab_size, embedding_dim, dropout=0.1, zoneout=0.1,
    hidden_size=HIDDEN_DIM, pad_idx=PAD_IDX)

# Set weight for UNK to a random normal
model.embedding.weight.data.copy_(TEXT.vocab.vectors)
model.embedding.weight.data[UNK_IDX] = torch.randn(embedding_dim)


optimizer = optim.SGD(model.parameters(), lr=1)
criterion = nn.CrossEntropyLoss()

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, threshold=0.0005)

model = model.to(device)
criterion = criterion.to(device)


## An example of calculating the loss
batch = next(iter(train_iter))

preds, _ = model(batch.text)
preds = preds.view(-1, preds.shape[-1])


trg = batch.target.view(-1)
criterion(preds, trg)

tensor(10.2708, device='cuda:0', grad_fn=<NllLossBackward>)

In [9]:
from tqdm.notebook import tqdm
import time

N_EPOCHS = 400

best_valid_loss = float('inf')

early_stopping_tolerance = 39
epochs_without_improvement = 0

model_path = "/tmp/qrnn_lang_model.pt"

pbar = tqdm(range(N_EPOCHS), ncols=1000)
for epoch in pbar:
    print(f"Epoch {epoch+1}")
    train_loss, train_perplexity = train(model, train_iter, optimizer, criterion, clip_norm=10)
    valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)

    lr_scheduler.step(valid_loss)
    desc = f' Train Loss: {train_loss:.3f} Perp: {train_perplexity:.2f}'
    desc += f' Val. Loss: {valid_loss:.3f} Perp: {valid_perplexity:.2f}'
    pbar.set_description(desc)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), model_path)
        print(f"Best model so far (Loss {best_valid_loss:.5f} Perp {valid_perplexity:.2f}) saved at {model_path}")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stopping_tolerance:
            print("Early stopping")
            break

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=400.0), HTML(value='')), layout=Layout(di…

Epoch 1


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 6.202 Perp 493.56) saved at /tmp/qrnn_lang_model.pt
Epoch 2


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.870 Perp 354.14) saved at /tmp/qrnn_lang_model.pt
Epoch 3


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.711 Perp 302.23) saved at /tmp/qrnn_lang_model.pt
Epoch 4


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.601 Perp 270.60) saved at /tmp/qrnn_lang_model.pt
Epoch 5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.485 Perp 240.93) saved at /tmp/qrnn_lang_model.pt
Epoch 6


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.411 Perp 223.76) saved at /tmp/qrnn_lang_model.pt
Epoch 7


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.352 Perp 211.06) saved at /tmp/qrnn_lang_model.pt
Epoch 8


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.302 Perp 200.69) saved at /tmp/qrnn_lang_model.pt
Epoch 9


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.268 Perp 194.05) saved at /tmp/qrnn_lang_model.pt
Epoch 10


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.237 Perp 188.02) saved at /tmp/qrnn_lang_model.pt
Epoch 11


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.196 Perp 180.64) saved at /tmp/qrnn_lang_model.pt
Epoch 12


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.178 Perp 177.26) saved at /tmp/qrnn_lang_model.pt
Epoch 13


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.148 Perp 172.06) saved at /tmp/qrnn_lang_model.pt
Epoch 14


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.128 Perp 168.63) saved at /tmp/qrnn_lang_model.pt
Epoch 15


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.104 Perp 164.70) saved at /tmp/qrnn_lang_model.pt
Epoch 16


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.082 Perp 161.08) saved at /tmp/qrnn_lang_model.pt
Epoch 17


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.068 Perp 158.85) saved at /tmp/qrnn_lang_model.pt
Epoch 18


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.052 Perp 156.27) saved at /tmp/qrnn_lang_model.pt
Epoch 19


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.031 Perp 153.09) saved at /tmp/qrnn_lang_model.pt
Epoch 20


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.030 Perp 152.92) saved at /tmp/qrnn_lang_model.pt
Epoch 21


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.012 Perp 150.15) saved at /tmp/qrnn_lang_model.pt
Epoch 22


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 5.007 Perp 149.52) saved at /tmp/qrnn_lang_model.pt
Epoch 23


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.984 Perp 146.05) saved at /tmp/qrnn_lang_model.pt
Epoch 24


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.979 Perp 145.36) saved at /tmp/qrnn_lang_model.pt
Epoch 25


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 26


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 27


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.958 Perp 142.25) saved at /tmp/qrnn_lang_model.pt
Epoch 28


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.954 Perp 141.80) saved at /tmp/qrnn_lang_model.pt
Epoch 29


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.951 Perp 141.26) saved at /tmp/qrnn_lang_model.pt
Epoch 30


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.950 Perp 141.21) saved at /tmp/qrnn_lang_model.pt
Epoch 31


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.937 Perp 139.35) saved at /tmp/qrnn_lang_model.pt
Epoch 32


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 33


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.935 Perp 139.04) saved at /tmp/qrnn_lang_model.pt
Epoch 34


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 35


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 36


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.890 Perp 132.89) saved at /tmp/qrnn_lang_model.pt
Epoch 37


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.888 Perp 132.63) saved at /tmp/qrnn_lang_model.pt
Epoch 38


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.886 Perp 132.43) saved at /tmp/qrnn_lang_model.pt
Epoch 39


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.886 Perp 132.37) saved at /tmp/qrnn_lang_model.pt
Epoch 40


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.885 Perp 132.35) saved at /tmp/qrnn_lang_model.pt
Epoch 41


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.885 Perp 132.23) saved at /tmp/qrnn_lang_model.pt
Epoch 42


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.884 Perp 132.19) saved at /tmp/qrnn_lang_model.pt
Epoch 43


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.858 Perp 128.78) saved at /tmp/qrnn_lang_model.pt
Epoch 44


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.855 Perp 128.44) saved at /tmp/qrnn_lang_model.pt
Epoch 45


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.855 Perp 128.40) saved at /tmp/qrnn_lang_model.pt
Epoch 46


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.855 Perp 128.39) saved at /tmp/qrnn_lang_model.pt
Epoch 47


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.855 Perp 128.35) saved at /tmp/qrnn_lang_model.pt
Epoch 48


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.855 Perp 128.34) saved at /tmp/qrnn_lang_model.pt
Epoch 49


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.847 Perp 127.41) saved at /tmp/qrnn_lang_model.pt
Epoch 50


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.843 Perp 126.88) saved at /tmp/qrnn_lang_model.pt
Epoch 51


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.841 Perp 126.65) saved at /tmp/qrnn_lang_model.pt
Epoch 52


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.841 Perp 126.56) saved at /tmp/qrnn_lang_model.pt
Epoch 53


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.52) saved at /tmp/qrnn_lang_model.pt
Epoch 54


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.50) saved at /tmp/qrnn_lang_model.pt
Epoch 55


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.48) saved at /tmp/qrnn_lang_model.pt
Epoch 56


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.48) saved at /tmp/qrnn_lang_model.pt
Epoch 57


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.46) saved at /tmp/qrnn_lang_model.pt
Epoch 58


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.44) saved at /tmp/qrnn_lang_model.pt
Epoch 59


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.42) saved at /tmp/qrnn_lang_model.pt
Epoch 60


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.41) saved at /tmp/qrnn_lang_model.pt
Epoch 61


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.41) saved at /tmp/qrnn_lang_model.pt
Epoch 62


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.840 Perp 126.41) saved at /tmp/qrnn_lang_model.pt
Epoch 63


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.41) saved at /tmp/qrnn_lang_model.pt
Epoch 64


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 65


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 66


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 67


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 68


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 69


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 70


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 71


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 72


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 73


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 74


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 75


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 76


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 77


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 78


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 79


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 80


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 81


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 82


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 83


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 84


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 85


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 86


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 87


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Best model so far (Loss 4.839 Perp 126.40) saved at /tmp/qrnn_lang_model.pt
Epoch 88


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 89


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 90


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 91


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 92


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 93


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 94


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 95


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 96


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 97


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 98


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 99


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…


Epoch 101


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1865.0), HTML(value='')), layout=Layout(d…





KeyboardInterrupt: 

In [10]:
#model.load_state_dict(torch.load(model_path))
model.eval()

valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)
test_loss, test_perplexity = evaluate(model, test_iter, criterion)


print(f"Valid loss      : {valid_loss:.2f}")
print(f"Valid perplexity: {valid_perplexity:.2f}\n")

print(f"Test loss      : {test_loss:.2f}")
print(f"Test perplexity: {test_perplexity:.2f}")

Valid loss      : 4.84
Valid perplexity: 126.40

Test loss      : 4.78
Test perplexity: 118.56


We can check perplexities for other models in [this blogpost](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)

A more complex recurrent network (using a cache of hidden states) achieves a perplexity of 100. So this very basic model (without any hyperparameter optimization) seems fairly ok

## Sampling

In [11]:
import torch.nn.functional as F

def sample_sentence(init_token="<eos>", temperature=1):

    seq = [TEXT.vocab.stoi[init_token]]

    while len(seq) == 1 or seq[-1] != EOS_IDX:
        inp = torch.LongTensor([[seq[-1]]]).to(device)
        out, _ = model(inp)

        """
        Sample from probabilities
        """
        probs = F.softmax(out.view(-1) / temperature, dim=0)
        next_tok_idx = torch.multinomial(probs, num_samples=1)
        
        seq.append(next_tok_idx)
        
    return [TEXT.vocab.itos[t] for t in seq]

In [12]:
for temperature in np.arange(0.5, 1.5, 0.15):
    print("="*80, f"\nSampling with temperature = {temperature:.2f}")
    
    print(" ".join(sample_sentence("the", temperature=temperature)))

Sampling with temperature = 0.50
the first time . the <unk> , which are also <unk> to the four @-@ <unk> . " . <eos>
Sampling with temperature = 0.65
the <unk> , <unk> <unk> to the <unk> , the <unk> <unk> and the original song , with the final game , and the <unk> ( saxon of the <unk> ' and the <unk> <unk> began to be regarded as well . <eos>
Sampling with temperature = 0.80
the region of the next to be seen on the war , and coloring piano week . in the music . <eos>
Sampling with temperature = 0.95
the village of the indicators in straight faced a heavy heavy pens plentiful the new york homes were partially occupied the <unk> <unk> in the final . this would be vigorously desire the <unk> destroy ' items . one species are actually survived the team , refused to create for 18 states , bound false applications for fostered their nbc 's <unk> and <unk> — also performed in 2004 signing for 140 @,@ 000 marks would have a nearby , now integrated , sea and <unk> grip flaws , all god and india

As we rise temperature, we have more variety at the cost of meaningless stuff..

### Hidden State

There is a problem here! We are missing the hidden state

In [13]:
import torch.nn.functional as F

def sample_sentence(init_token="<eos>", temperature=1):

    seq = [TEXT.vocab.stoi[init_token]]
    hidden = None
    while len(seq) == 1 or seq[-1] != EOS_IDX:
        inp = torch.LongTensor([[seq[-1]]]).to(device)
        out, hidden = model(inp, hidden=hidden)

        """
        Sample from probabilities
        """
        probs = F.softmax(out.view(-1) / temperature, dim=0)
        next_tok_idx = torch.multinomial(probs, num_samples=1)
        
        seq.append(next_tok_idx)
        
    return [TEXT.vocab.itos[t] for t in seq]

In [16]:
for temperature in np.arange(0.5, 1.5, 0.10):
    print("="*80, f"\nSampling with temperature = {temperature:.2f}")
    
    print(" ".join(sample_sentence("the", temperature=temperature)))

Sampling with temperature = 0.50
the <unk> <unk> <unk> of the following the second game 2 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 @,@ 000 km ) . <eos>
Sampling with temperature = 0.60
the general treaty of the minnesota government by the city of the town of the following the following the first world war day of the war war . it was at the second world war in the first century in early on august 25 october july 14 in may be destroyed by a new york city in which was brought them in a new york city . <eos>
Sampling with temperature = 0.70
the spike of rolling stone is the mexican 

We can observe that:

- with hidden states there are more "meaningful" stuff
- quotation marks are closed when using the hidden state