## AWD LSTM Language Model

Sources

[1] 
[2] https://mlexplained.com/2018/02/15/language-modeling-tutorial-in-torchtext-practical-torchtext-part-2/



We have the following datasets available for this task:

- Penn Treebank (originally created for POS tagging)
- WikiText

Before loading our dataset, we define how it will be tokenized and preprocessed. To do this, `torchtext` uses `data.Field`. By default, it splits on whitespace (`str.split`); passing `tokenize='spacy'` switches to [`spaCy`](https://spacy.io/api/tokenizer) tokenization, with `tokenizer_language` selecting the spaCy model.

We also set an `init_token` and an `eos_token` to mark the beginning and end of a sentence.

In [1]:
import torch
import torchtext
from torchtext import data

TEXT = data.Field(
    tokenizer_language='en',
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)
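
As a rough illustration of what these options do to a standalone sentence, the sketch below lowercases the text, splits it on whitespace, and wraps the tokens in the start and end markers. The `preprocess` helper is hypothetical and is not part of `torchtext`; it only mimics the effect of `lower`, `init_token` and `eos_token`.

```python
def preprocess(sentence, init_token="<sos>", eos_token="<eos>", lower=True):
    """Hypothetical helper that mimics the Field options; not a torchtext API."""
    if lower:
        sentence = sentence.lower()
    tokens = sentence.split()  # plain whitespace tokenization
    return [init_token] + tokens + [eos_token]


tokens = preprocess("The cat sat on the mat .")
print(tokens)
```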

Now we can load our dataset.

In [2]:
torch.__version__

'1.4.0'

In [3]:
torch.cuda.is_available()

True

In [4]:
from torchtext.datasets import WikiText2
 
device = "cuda" if torch.cuda.is_available() else "cpu"

train, valid, test = WikiText2.splits(TEXT) 

TEXT.build_vocab(train)

print(f"We have {len(TEXT.vocab)} tokens in our vocabulary")

We have 28914 tokens in our vocabulary
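
To make the vocabulary step more concrete, here is a minimal sketch of a token-to-index mapping on a toy corpus. The `build_vocab` helper below is hypothetical (it is not the `torchtext` implementation); it assumes the special tokens come first and the remaining tokens are ordered by frequency, and it maps unseen tokens to `<unk>`.

```python
from collections import Counter


def build_vocab(tokens, specials=("<unk>", "<pad>", "<sos>", "<eos>")):
    """Hypothetical helper: index the specials first, then tokens by frequency."""
    counts = Counter(tokens)
    for special in specials:
        counts.pop(special, None)  # specials are indexed separately
    itos = list(specials)
    # Most frequent tokens first; ties are broken alphabetically.
    itos += [tok for tok, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))]
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos


corpus = "the cat sat on the mat".split()
stoi, itos = build_vocab(corpus)

# Tokens missing from the vocabulary fall back to the <unk> index.
unk = stoi["<unk>"]
ids = [stoi.get(tok, unk) for tok in ["the", "dog"]]
print(ids)
```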


## Iterator


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 32
BPTT_LEN = 30

train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN, # this is where we specify the sequence length
    device=device,
    repeat=False)
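
The idea behind BPTT batching can be sketched in plain Python: the token stream is split into `batch_size` parallel streams, each stream is cut into windows of at most `bptt_len` tokens, and the target is the input shifted by one token. The `bptt_batches` function is a hypothetical illustration under those assumptions, not the `BPTTIterator` implementation.

```python
def bptt_batches(token_ids, batch_size, bptt_len):
    """Hypothetical sketch of BPTT batching; not the torchtext implementation."""
    # Split the stream into batch_size parallel streams of equal length.
    stream_len = len(token_ids) // batch_size
    streams = [token_ids[b * stream_len:(b + 1) * stream_len]
               for b in range(batch_size)]
    # Cut each stream into windows; the target is the input shifted by one.
    for start in range(0, stream_len - 1, bptt_len):
        end = min(start + bptt_len, stream_len - 1)
        text = [s[start:end] for s in streams]
        target = [s[start + 1:end + 1] for s in streams]
        yield text, target


batches = list(bptt_batches(list(range(20)), batch_size=2, bptt_len=4))
first_text, first_target = batches[0]
print(first_text)
print(first_target)
```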

In [6]:
from fastai.text.models import AWD_LSTM, LinearDecoder
import torch.nn as nn

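class AWDLanguageModel(nn.Sequential):
    def __init__(self, vocab_size, embedding_dim, pad_idx, hidden_size,
                 n_layers=3, dropout=0.20):
        # Use the constructor arguments instead of notebook-level globals.
        encoder = AWD_LSTM(
            vocab_sz=vocab_size, emb_sz=embedding_dim, n_hid=hidden_size,
            n_layers=n_layers, pad_token=pad_idx)

        # The last LSTM layer outputs embedding_dim features, so the decoder reads n_hid=embedding_dim.
        decoder = LinearDecoder(n_out=vocab_size, n_hid=embedding_dim, output_p=dropout)
        super().__init__(encoder, decoder)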


Before creating the language model, look up the indices of the special tokens.

In [7]:


PAD_IDX = TEXT.vocab.stoi["<pad>"]
UNK_IDX = TEXT.vocab.stoi["<unk>"]
EOS_IDX = TEXT.vocab.stoi["<eos>"]
SOS_IDX = TEXT.vocab.stoi["<sos>"]


## Training 

In [8]:
import torch.optim as optim

hidden = 1150
vocab_size = len(TEXT.vocab)
embedding_dim = 400
out_dropout = 0.1

n_layers = 3

model = AWDLanguageModel(
    vocab_size=vocab_size, embedding_dim=embedding_dim,
    pad_idx=PAD_IDX, hidden_size=hidden, dropout=out_dropout)


optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

model = model.to(device)
criterion = criterion.to(device)

In [9]:
## An example of calculating the loss
batch = next(iter(train_iter))


preds, _, _= model(batch.text)

In [10]:

preds = preds.view(-1, preds.shape[-1])


trg = batch.target.view(-1)
criterion(preds, trg)

tensor(10.2758, device='cuda:0', grad_fn=<NllLossBackward>)
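
A quick sanity check on this number: an untrained model should be close to a uniform guess over the vocabulary, whose cross-entropy is the natural log of the vocabulary size. The sketch below uses the vocabulary size printed earlier in the notebook.

```python
import math

vocab_size = 28914  # vocabulary size printed earlier in the notebook

# Cross-entropy (in nats) of a uniform guess over the vocabulary.
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.2f}")  # prints 10.27
```

The initial loss of 10.2758 is very close to this value, which suggests the loss is wired up correctly.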

In [11]:
device

'cuda'

In [13]:
from pytorch_lm.training import training_cycle

N_EPOCHS = 20

model_path = "/tmp/awd_lstm_lang_model.pt"

training_cycle(
    epochs=N_EPOCHS,
    model=model, train_iter=train_iter, valid_iter=valid_iter,
    optimizer=optimizer, criterion=criterion, scheduler=lr_scheduler,
    model_path=model_path, early_stopping_tolerance=3
)


Epoch 1




 Train Loss: 6.95985 Perp: 1059.537 Val. Loss: 6.78227 Perp: 882.068
Best model so far (Loss 6.78227 Perp 882.07) saved at /tmp/awd_lstm_lang_model.pt
Epoch 2






KeyboardInterrupt: 

In [None]:
from pytorch_lm.training import evaluate
model.load_state_dict(torch.load(model_path))
model.eval()

valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)
test_loss, test_perplexity = evaluate(model, test_iter, criterion)


print(f"Valid loss      : {valid_loss:.2f}")
print(f"Valid perplexity: {valid_perplexity:.2f}\n")

print(f"Test loss      : {test_loss:.2f}")
print(f"Test perplexity: {test_perplexity:.2f}")

In [None]:
from pytorch_lm.saving import save_model, load_model

save_model(model, TEXT, "../models/rnn.pt")

We can check perplexities for other models in [this blogpost](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/).

A more complex recurrent network (using a cache of hidden states) achieves a perplexity of 100, so this very basic model, trained without any hyperparameter optimization, seems fairly OK.

In [9]:
import torch
import torch.nn as nn
from torchtext import data
from torchtext.datasets import WikiText2
from pytorch_lm import load_model
from pytorch_lm.training import evaluate
 
device = "cuda" if torch.cuda.is_available() else "cpu"

model, TEXT = load_model("../models/rnn.pt", device)


train, valid, test = WikiText2.splits(TEXT) 


BATCH_SIZE = 32
BPTT_LEN = 30

train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN, # this is where we specify the sequence length
    device=device,
    repeat=False)

criterion = nn.CrossEntropyLoss()

model.eval()

valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)
test_loss, test_perplexity = evaluate(model, test_iter, criterion)


print(f"Valid loss      : {valid_loss:.2f}")
print(f"Valid perplexity: {valid_perplexity:.2f}\n")

print(f"Test loss      : {test_loss:.2f}")
print(f"Test perplexity: {test_perplexity:.2f}")

Valid loss      : 4.86
Valid perplexity: 128.92

Test loss      : 4.79
Test perplexity: 120.56
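
The reported numbers are consistent with perplexity being the exponential of the cross-entropy loss. The sketch below assumes that relationship (the `perplexity` helper is hypothetical and is not taken from `pytorch_lm`). Because the printed losses are rounded to two decimals, the recomputed values agree only approximately.

```python
import math


def perplexity(loss):
    """Perplexity as the exponential of a cross-entropy loss given in nats."""
    return math.exp(loss)


valid_ppl = perplexity(4.86)  # rounded validation loss from the output above
test_ppl = perplexity(4.79)   # rounded test loss from the output above
print(f"{valid_ppl:.1f} {test_ppl:.1f}")
```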


## Sampling

In [16]:
import torch.nn.functional as F
import numpy as np

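def sample_sentence(init_token="<eos>", temperature=1):
    # Start from the given token and sample until <eos> is produced.
    seq = [TEXT.vocab.stoi[init_token]]

    with torch.no_grad():  # no gradients are needed while sampling
        while len(seq) == 1 or seq[-1] != EOS_IDX:
            # Feed only the last token, shaped (batch=1, seq_len=1).
            inp = torch.LongTensor([[seq[-1]]]).to(device)
            out, _ = model(inp)

            # Sample the next token from the temperature-scaled distribution.
            probs = F.softmax(out.view(-1) / temperature, dim=0)
            next_tok_idx = torch.multinomial(probs, num_samples=1)

            # Store a plain int so the <eos> check and vocab lookup use ints.
            seq.append(next_tok_idx.item())

    return [TEXT.vocab.itos[t] for t in seq]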

In [17]:
for temperature in np.arange(0.5, 1.5, 0.15):
    print("="*80, f"\nSampling with temperature = {temperature:.2f}")
    
    print(" ".join(sample_sentence("the", temperature=temperature)))

Sampling with temperature = 0.50
the <unk> , with the <unk> , the <unk> <unk> , and <unk> , <unk> and <unk> <unk> of the start of the <unk> , and the <unk> to the <unk> <unk> <unk> , it was <unk> of the majority of the <unk> <unk> of the <unk> , and the <unk> , a few years after the 766th regiment . the <unk> <unk> . <eos>
Sampling with temperature = 0.65
the city , the <unk> , and <unk> . <eos>
Sampling with temperature = 0.80
the alien 3ds version of the top of the <unk> @-@ drawn up costumes of the first to give a large combatants , only a less than one of the kids , although three days . on it was raised in an independent , the amount of the sixth and the a 1 @,@ 000 , with " <unk> then re @-@ sensitive north korean leaders in the jin song " <unk> <unk> and the majority soldiers to be used to the <unk> . " was true territory . <eos>
Sampling with temperature = 0.95
the taking time , with the waving a very shape dark blue wolf ( <unk> freely reduced greater manchester , also flat by

As we raise the temperature, the samples become more varied, but also less meaningful.
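
The effect of temperature can be seen on a toy example. The sketch below applies a temperature-scaled softmax to three made-up scores (the `softmax_with_temperature` helper and the `logits` values are illustrative assumptions): a low temperature concentrates probability on the top score, while a high temperature flattens the distribution.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # toy scores, not taken from the model

sharp = softmax_with_temperature(logits, temperature=0.5)
flat = softmax_with_temperature(logits, temperature=1.5)

print([round(p, 2) for p in sharp])
print([round(p, 2) for p in flat])
```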

### Hidden State

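There is a problem here! We are missing the hidden state: each step feeds only the last token to the model, so nothing the model saw earlier in the sentence is carried forward.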

In [18]:
import torch.nn.functional as F

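def sample_sentence(init_token="<eos>", temperature=1):
    # Start from the given token and sample until <eos> is produced.
    seq = [TEXT.vocab.stoi[init_token]]
    hidden = None  # no state before the first step
    with torch.no_grad():  # no gradients are needed while sampling
        while len(seq) == 1 or seq[-1] != EOS_IDX:
            # Feed the last token together with the state carried so far.
            inp = torch.LongTensor([[seq[-1]]]).to(device)
            out, hidden = model(inp, hidden=hidden)

            # Sample the next token from the temperature-scaled distribution.
            probs = F.softmax(out.view(-1) / temperature, dim=0)
            next_tok_idx = torch.multinomial(probs, num_samples=1)

            # Store a plain int so the <eos> check and vocab lookup use ints.
            seq.append(next_tok_idx.item())

    return [TEXT.vocab.itos[t] for t in seq]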

In [19]:
for temperature in np.arange(0.5, 1.5, 0.15):
    print("="*80, f"\nSampling with temperature = {temperature:.2f}")
    
    print(" ".join(sample_sentence("the", temperature=temperature)))

Sampling with temperature = 0.50
the <unk> and <unk> <unk> , the <unk> of the central <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> <unk> , <unk> <unk> , <unk> <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> <unk> , <unk> and <unk> <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> , <unk> 

We can observe that:

- with the hidden state, the samples are more "meaningful"
- quotation marks are closed when using the hidden state
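
The quotation-mark observation can be illustrated with a toy recurrent step. The `step` function below is a hypothetical stand-in for the model: its hidden state only records whether a quotation mark is currently open. When the state is not carried between calls, that information is lost after a single token.

```python
def step(token, hidden=None):
    """Toy recurrent step (hypothetical): hidden records whether a quote is open."""
    quote_open = bool(hidden)
    if token == '"':
        quote_open = not quote_open
    # The "output" is simply whether a closing quote is still expected.
    return quote_open, quote_open


tokens = ["he", "said", '"', "hello"]

# Without carrying state: every call starts from hidden=None.
stateless = [step(tok)[0] for tok in tokens]

# Carrying state: the hidden value from one step is fed into the next.
hidden = None
stateful = []
for tok in tokens:
    out, hidden = step(tok, hidden)
    stateful.append(out)

print(stateless)
print(stateful)
```

At the last token, only the stateful version still knows that a quotation mark needs to be closed.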