# Training a language model with RNNs

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import random
import os

## Data
The data for this notebook can be downloaded from here
`https://github.com/pytorch/examples/tree/master/word_language_model/data/wikitext-2`

In [12]:
PATH=Path("/data/yinterian/wikitext-2")
list(PATH.iterdir())

[PosixPath('/data/yinterian/wikitext-2/valid.txt'),
 PosixPath('/data/yinterian/wikitext-2/train.txt'),
 PosixPath('/data/yinterian/wikitext-2/test.txt')]

In [13]:
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1
        return ids

In [14]:
corpus = Corpus(PATH)
corpus.test[100:110]


 44
 45
 46
 47
 48
 49
 45
 50
 51
 42
[torch.LongTensor of size 10]

In [15]:
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.cuda()

In [16]:
batch_size = 20
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, batch_size)
test_data = batchify(corpus.test, batch_size)

In [17]:
test_data


    0     0    69  ...    973     0  1058
    0     8    63  ...    204   281     0
    0    85   149  ...    974   981   219
       ...          ⋱          ...       
   84   147   241  ...     45  1029  1105
   21    63   242  ...   1009  1056     0
    0   148   243  ...   1010  1057     0
[torch.cuda.LongTensor of size 160x20 (GPU 0)]

## Model
Based on the model [here](https://github.com/pytorch/examples/tree/master/word_language_model)

In [25]:
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        #self.decoder.bias.data.zero_()
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())

## Training

In [19]:
criterion = nn.CrossEntropyLoss()

In [47]:
def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, bptt):
        data, targets = get_batch(data_source, i, bptt, evaluation=True)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = Variable(hidden.data) #.detach()
    return total_loss[0] / len(data_source)

In [39]:
def repackage_hidden(h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    return Variable(h.data)

In [49]:
import time
import math
bptt = 35
clip = 0.25
log_interval = 200 

def get_batch(source, i, bptt, evaluation=False):
    seq_len = min(bptt, len(source) - 1 - i)
    data = Variable(source[i:i+seq_len], volatile=evaluation)
    target = Variable(source[i+1:i+1+seq_len].view(-1))
    return data, target
    
def train(model, lr):
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i, bptt)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = Variable(hidden.data) #.detach()
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // bptt, lr,
                elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

In [51]:
lr = 20.0
best_val_loss = None
epochs = 40
nemb = 200
nhid = 200
nlayers = 2
ntokens = len(corpus.dictionary)
model = RNNModel(ntokens, nemb, nhid, nlayers).cuda()

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    train(model, lr)
    val_loss = evaluate(val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
        val_loss, math.exp(val_loss)))
    print('-' * 89)
    # Save the model if the validation loss is the best we've seen so far.
    if not best_val_loss or val_loss < best_val_loss:
        #with open(args.save, 'wb') as f:
        #torch.save(model, f)
        print("best val loss", best_val_loss)
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0

-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  0.06s | valid loss  8.21 | valid ppl  3679.27
-----------------------------------------------------------------------------------------
best val loss None
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  0.06s | valid loss  8.88 | valid ppl  7210.23
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  0.06s | valid loss  5.69 | valid ppl   296.80
-----------------------------------------------------------------------------------------
best val loss 8.210468292236328
-----------------------------------------------------------------------------------------
| end of epoch   4 | time:  0.06s | valid loss  5.59 | valid ppl   268.96
-----------------------

-----------------------------------------------------------------------------------------
| end of epoch  30 | time:  0.05s | valid loss  4.07 | valid ppl    58.71
-----------------------------------------------------------------------------------------
best val loss 4.1049541473388675
-----------------------------------------------------------------------------------------
| end of epoch  31 | time:  0.04s | valid loss  4.04 | valid ppl    56.90
-----------------------------------------------------------------------------------------
best val loss 4.072663879394531
-----------------------------------------------------------------------------------------
| end of epoch  32 | time:  0.04s | valid loss  4.00 | valid ppl    54.86
-----------------------------------------------------------------------------------------
best val loss 4.041250228881836
-----------------------------------------------------------------------------------------
| end of epoch  33 | time:  0.04s | valid loss  3.9

## References
* https://github.com/pytorch/examples/tree/master/word_language_model
* http://colah.github.io/posts/2015-08-Understanding-LSTMs/