## RNN Language Model

Sources

[1] 
[2] https://mlexplained.com/2018/02/15/language-modeling-tutorial-in-torchtext-practical-torchtext-part-2/



In [1]:
import torch
import torchtext


We have the following datasets available for this task:

- Penn Treebank (originally created for POS tagging)
- WikiText

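Before loading our dataset, define how it will be tokenized and preprocessed. To do this, `torchtext` uses `data.Field`. By default, a `Field` tokenizes by splitting the text on whitespace (`str.split`); [`spaCy`](https://spacy.io/api/tokenizer) tokenization is used only when `tokenize='spacy'` is passed.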

We also set an `init_token` and an `eos_token`: the special tokens that mark the beginning and the end of a sentence.

In [37]:
from torchtext import data

# Field describing how raw text is turned into token-id tensors.
# NOTE: no `tokenize` argument is given, so the Field's default tokenizer is
# used; `tokenizer_language` is only consulted by the spaCy tokenizer
# (`tokenize='spacy'`) and is therefore inert in this configuration.
TEXT = data.Field(
    lower=True,               # lowercase every token
    batch_first=True,         # produce tensors with the batch dimension first
    init_token='<sos>',       # special token registered for sentence start
    eos_token='<eos>',        # special token registered for sentence end
    tokenizer_language='en',
)

Now, we can load our dataset

In [38]:
from torchtext.datasets import WikiText2
 
# Read the three WikiText-2 splits, processed with the field defined above.
train, valid, test = WikiText2.splits(text_field=TEXT)

# The vocabulary is built from the training split only; each entry is paired
# with its pre-trained 300-d GloVe vector.
TEXT.build_vocab(train, vectors="glove.6B.300d")

print(f"We have {len(TEXT.vocab)} tokens in our vocabulary")

We have 28914 tokens in our vocabulary


## Iterator


In [39]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 32
BPTT_LEN = 30  # sequence length of every training window

# One iterator per split; `repeat=False` makes each iterator stop after a
# single pass over its data, i.e. one pass == one epoch.
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    datasets=(train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=BPTT_LEN,
    device=device,
    repeat=False,
)

In [61]:
import torch.nn as nn

class RNNLanguageModel(nn.Module):
    """
    Recurrent language model: embedding -> RNN -> dropout -> linear projection.

    Args:
        vocab_size: number of tokens in the vocabulary (rows of the embedding
            table and width of the output layer).
        embedding_dim: size of each token embedding.
        pad_idx: index of the padding token; passed to ``nn.Embedding`` as
            ``padding_idx`` so that row receives no gradient updates.
        hidden_size: number of features in the recurrent state.
        cell_class: constructor of the recurrent layer (``nn.GRU`` by default);
            it is called as ``cell_class(embedding_dim, hidden_size, batch_first=True)``.
        dropout: dropout probability applied to the RNN outputs before the
            output projection.
    """

    def __init__(self, vocab_size, embedding_dim, pad_idx, hidden_size,
                 cell_class=nn.GRU, dropout=0.20):
        super().__init__()

        # Use the constructor argument rather than a notebook-level global so
        # the class works regardless of which cells have been executed.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.rnn = cell_class(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        """
        Computes next-token logits for every position of every sequence.

        Args:
            inp: integer tensor of token ids, shape [batch, seq_len].

        Returns:
            Unnormalised scores of shape [batch, seq_len, vocab_size].

        No initial state is given to the RNN, so each call starts from the
        layer's default initial state; state is not carried between calls.
        """
        # inp = [batch, seq_len]
        emb = self.embedding(inp)
        # emb = [batch, seq_len, embedding_dim]

        # As all examples in a batch have the same length, there is no use
        # in packing the input to the RNN.
        # The first value returned by the RNN holds the output at every time
        # step; the final state (second value) is not needed here.
        outputs, _ = self.rnn(emb)
        # outputs = [batch, seq_len, hidden_size]

        out = self.fc(self.dropout(outputs))
        # out = [batch, seq_len, vocab_size]

        return out

Create the Language Model

In [62]:
HIDDEN_DIM = 256

# The pre-trained vector matrix has one row per vocabulary entry.
vocab_size, embedding_dim = TEXT.vocab.vectors.shape

# Vocabulary indices of the special tokens.
PAD_IDX, UNK_IDX, EOS_IDX, SOS_IDX = (
    TEXT.vocab.stoi[token] for token in ("<pad>", "<unk>", "<eos>", "<sos>")
)

model = RNNLanguageModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    pad_idx=PAD_IDX,
    hidden_size=HIDDEN_DIM,
)

# Initialise the embedding table with the pre-trained vectors ...
model.embedding.weight.data.copy_(TEXT.vocab.vectors)
# ... then overwrite the <unk> row with a random normal vector.
model.embedding.weight.data[UNK_IDX] = torch.randn(embedding_dim)


## Training 

In [63]:
import torch.optim as optim

# Place the model and the loss module on the target device first, then build
# the optimiser and the scheduler on top of the model's parameters.
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Lowers the learning rate when the metric passed to `lr_scheduler.step(...)`
# has stopped decreasing ('min' mode) for `patience` steps.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)


## An example of calculating the loss
# Sanity check on a single batch: an untrained model should score close to
# ln(vocabulary size).
batch = next(iter(train_iter))

# Collapse every leading dimension so that each row of `preds` holds the
# scores for one target token.
preds = model(batch.text)
preds = preds.view(-1, preds.size(-1))

# Targets flattened the same way: one class index per row of `preds`.
trg = batch.target.view(-1)

criterion(preds, trg)

tensor(10.2681, device='cuda:0', grad_fn=<NllLossBackward>)

In [64]:
import torch
import numpy as np

def train(model, iterator, optimizer, criterion):
    """
    Trains the model for one full epoch.

    Args:
        model: module called as ``model(batch.text)``; its output is flattened
            to two dimensions before being passed to ``criterion``.
        iterator: sized iterable of batches exposing ``.text`` and ``.target``.
        optimizer: optimizer that updates the parameters of ``model``.
        criterion: loss called as ``criterion(predictions, targets)``. It is
            expected to return a mean negative log-likelihood in nats (as
            ``nn.CrossEntropyLoss`` does) for the perplexity to be meaningful.

    Returns:
        Tuple ``(loss, perplexity)``: the mean of the per-batch losses and
        ``exp(loss)``, i.e. the same definition ``evaluate`` uses, so train
        and validation perplexities are directly comparable.

    Note:
        Defining this function rebinds the notebook-level name ``train``,
        which earlier held the training split of the dataset; the iterators
        must therefore be built before this cell is run.
    """
    epoch_loss = 0.0

    model.train()

    for batch in iterator:
        optimizer.zero_grad()
        text = batch.text
        trg = batch.target.view(-1)

        preds = model(text)
        # One row of scores per target token
        preds = preds.view(-1, preds.shape[-1])

        loss = criterion(preds, trg)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    mean_loss = epoch_loss / len(iterator)

    # The loss is measured in nats, so the matching perplexity is e**loss
    # (not 2**loss), and it is taken from the epoch-level mean loss rather
    # than averaged batch by batch.
    perplexity = float(np.exp(mean_loss))

    return mean_loss, perplexity


def evaluate(model, iterator, criterion):
    """
    Scores the model on every batch of ``iterator`` without updating it.

    The model is switched to evaluation mode and gradients are disabled.
    Returns ``(loss, perplexity)`` where ``loss`` is the sum of the per-batch
    losses divided by ``len(iterator)`` and ``perplexity`` is ``exp(loss)``.
    """
    model.eval()

    batch_losses = []
    with torch.no_grad():
        for batch in iterator:
            scores = model(batch.text)
            # One row of scores per target token
            flat_scores = scores.view(-1, scores.shape[-1])
            flat_targets = batch.target.view(-1)
            batch_losses.append(criterion(flat_scores, flat_targets).item())

    mean_loss = sum(batch_losses) / len(iterator)
    return mean_loss, np.exp(mean_loss)

In [65]:
from tqdm.notebook import tqdm
import time

N_EPOCHS = 100

best_valid_loss = float('inf')

# Stop once the validation loss has failed to improve this many epochs in a row
early_stopping_tolerance = 10
epochs_without_improvement = 0

# Where the weights of the best model (lowest validation loss) are stored
model_path = "/tmp/rnn_lang_model.pt"

pbar = tqdm(range(N_EPOCHS), ncols=1000)
for epoch in pbar:

    # Inner bar tracks progress over the batches of the current epoch
    epoch_bar = tqdm(train_iter)
    train_loss, train_perplexity = train(model, epoch_bar, optimizer, criterion)
    valid_loss, valid_perplexity = evaluate(model, valid_iter, criterion)

    # ReduceLROnPlateau only acts when it is fed the monitored metric; without
    # this call the learning rate would stay at its initial value forever.
    lr_scheduler.step(valid_loss)

    desc = f' Train Loss: {train_loss:.3f} Perp: {train_perplexity:.2f}'
    desc += f' Val. Loss: {valid_loss:.3f} Perp: {valid_perplexity:.2f}'
    pbar.set_description(desc)

    if valid_loss < best_valid_loss:
        # New best model: reset the patience counter and checkpoint the weights
        best_valid_loss = valid_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), model_path)
        print(f"Best model so far (Loss {best_valid_loss:.3f} Perp {valid_perplexity:.2f}) saved at {model_path}")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stopping_tolerance:
            print("Early stopping")
            break

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2')), HTML(value='')), layout=Layout(display='inli…

HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))


Best model so far (Loss 5.233 Perp 187.31) saved at /tmp/rnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))


Best model so far (Loss 5.039 Perp 154.29) saved at /tmp/rnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))


Best model so far (Loss 4.988 Perp 146.63) saved at /tmp/rnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))


Best model so far (Loss 4.984 Perp 146.08) saved at /tmp/rnn_lang_model.pt


HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2176.0), HTML(value='')))

KeyboardInterrupt: 

In [67]:
# Restore the best checkpoint written by the training loop. `map_location`
# remaps the stored tensors onto the current device, so a checkpoint saved on
# a GPU can also be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
test_loss, test_perplexity = evaluate(model, test_iter, criterion)


print(f"Test loss      : {test_loss:.2f}")
print(f"Test perplexity: {test_perplexity:.2f}")

Test loss      : 4.92
Test perplexity: 137.64


## TODO: Add sampling with temperature