In [1]:
import numpy as np

# Description
In this tutorial, we'll see how we can train a language model using the built-in datasets in torchtext.

We'll also take a look at some more practical features of torchtext that you might want to use when training your own practical models.
This tutorial assumes that you have access to a GPU for the sake of training speed. If you don't have a GPU, you can change the following variable `USE_GPU` to False. Be warned though, since the training will be very slow.

In [2]:
USE_GPU = False
BATCH_SIZE = 32

# 1. What is Language Modeling?
Language modeling is a task where we build a model that can take a sequence of words as input and determine how likely that sequence is to be actual human language. For instance, we would want our model to predict "This is a sentence" to be a likely sequence and "cold his book her" to be unlikely.

The way we generally train language models is by training them to predict the next word given all previous words in a sentence or multiple sentences. Therefore, all we need to do language modeling is a large amount of language data (called a corpus).

In this tutorial, we'll be using the famous WikiText2 dataset.

# 2. Preparing the Data

In [3]:
import torchtext
from torchtext import data

In the last tutorial we tokenized on spaces. This time, we'll use a slightly more sophisticated tokenizer: the spacy tokenizer.

[Spacy](https://spacy.io/) is a framework that handles many natural language processing tasks, and torchtext is designed to work closely with it.

Using the tokenizer is easy with torchtext: all we have to do is pass in the tokenizer function!

In [4]:
import spacy

from spacy.symbols import ORTH
my_tok = spacy.load('en')
my_tok.tokenizer.add_special_case('<eos>', [{ORTH: '<eos>'}])
my_tok.tokenizer.add_special_case('<bos>', [{ORTH: '<bos>'}])
my_tok.tokenizer.add_special_case('<unk>', [{ORTH: '<unk>'}])
def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]

`add_special_case` simply tells the tokenizer to parse a certain string in a certain way. The list after the special case string represents how we want the string to be tokenized. 

If we wanted to tokenize "don't" into "do" and "'nt", then we would write

`my_tok.tokenizer.add_special_case("don't", [{ORTH: "do"}, {ORTH: "n't"}])`

We need to initialize the text field by ourselves.

In [5]:
TEXT = data.Field(lower=True, tokenize=spacy_tok)

Now we'll load the built-in datasets.
There are two effective ways of using these datasets: one is loading as a Dataset split into the train, validation, and test sets, and the other is loading as an Iterator. The dataset offers more flexibility, so we'll use that approach here.

There is currently one built-in dataset for language modeling: the WikiText2 dataset. (I've sent a pull request for the also commonly used and slightly smaller dataset called the Penn Treebank dataset. If you install the version on [my fork](https://github.com/keitakurita/text@penn_treebank), you can use it in place and have the code run faster!)

In [6]:
from torchtext.datasets import WikiText2

In [7]:
train, valid, test = WikiText2.splits(TEXT) # loading custom datasets requires passing in the field, but nothing else.

downloading wikitext-2-v1.zip


wikitext-2-v1.zip: 100%|██████████| 4.48M/4.48M [00:02<00:00, 1.92MB/s]


extracting


Let's take a quick look inside. Remember, datasets behave largely like normal lists, so we can measure the length using the `len` function.

In [11]:
len(train)

1

Only one training example?! Did we do something wrong?

Turns out not. It's just that the entire corpus of the dataset is contained within a single example. We'll see how this example gets batched and processed later.

Now that we have our data, let's build the vocabulary. This time, let's try using precomputed word embeddings.

We'll use GloVe vectors with 200 dimensions this time. There are various other precomputed word embeddings in torchtext (including GloVe vectors with 100 and 300 dimensions) as well which can be loaded in mostly the same way.

In [12]:
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

In [13]:
TEXT.build_vocab(train, vectors="glove.6B.200d")

.vector_cache/glove.6B.zip: 862MB [07:35, 1.89MB/s]                              
100%|█████████▉| 399506/400000 [00:40<00:00, 14151.35it/s]

Now we can build our iterator. This is the climax of this tutorial!
It turns out that torchtext has a very handy iterator that does most of the heavy lifting for us. It's called the `BPTTIterator`.
The `BPTTIterator` does the following for us:
- Divide the corpus into batches of sequence length `bptt`

For instance, suppose we have the following corpus: 

*"Machine learning is a field of computer science that gives computers the ability to learn without being explicitly programmed."*

Though this sentence is short, the actual corpus is thousands of words long, so we can't possibly feed it in all at once. We'll want to divide the corpus into sequences of a shorter length. In the above example, if we wanted to divide the corpus into batches of sequence length 5, we would get the following sequences:

["*Machine*", "*learning*", "*is*", "*a*", "*field*"],

["*of*", "*computer*", "*science*", "*that*", "*gives*"],

["*computers*", "*the*", "*ability*", "*to*", "*learn*"],

["*without*", "*being*", "*explicitly*", "*programmed*", EOS]


- Generate batches that are the input sequences offset by one

In language modeling, the supervision data is the next word in a sequence of words. We, therefore, want to generate the sequences that are the input sequences offset by one. In the above example, we would get the following sequence that we train the model to predict:

["*learning*", "*is*", "*a*", "*field*", "*of*"],

["*computer*", "*science*", "*that*", "*gives*", "*computers*"],

["*the*", "*ability*", "*to*", "*learn*", "*without*"],

["*being*", "*explicitly*", "*programmed*", EOS, EOS]

In [14]:
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=BATCH_SIZE,
    bptt_len=30, # this is where we specify the sequence length
    device=(0 if USE_GPU else -1),
    repeat=False)

The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


As always, it's a good idea to take a look into what is actually happening behind the scenes

In [15]:
b = next(iter(train_iter))

In [16]:
vars(b).keys()

dict_keys(['batch_size', 'dataset', 'fields', 'text', 'target'])

We never specified a target field, so it must have been automatically generated. Hopefully, it's the original text offset by one. Let's see...

In [17]:
b.text[:, :3]

tensor([[    9,   953,     0],
        [   10,   324,  5909],
        [    9,    11, 20014],
        [   12,  5906,    27],
        [ 3872, 10434,     2],
        [ 3892,     3, 10780],
        [  886,    11,  3273],
        [   12,  9357,     0],
        [   10,  8826, 23499],
        [    9,  1228,     4],
        [   10,     7,   569],
        [    9,     2,   235],
        [20059,  2592,  5909],
        [   90,     3,    20],
        [ 3872,   141,     2],
        [   95,     8,  1450],
        [   49,  6794,   369],
        [    0,  9046,     5],
        [ 3892,  1497,     2],
        [   24,    13,  2168],
        [  786,     4,   488],
        [   49,    26,  5967],
        [28867,    25,   656],
        [    3, 18430,    14],
        [ 6213,    58,    48],
        [    4,  4886,  4364],
        [ 3872,   217,     4],
        [    5,     5,    22],
        [    2,     2,  1936],
        [ 5050,   593,    59]])

In [18]:
b.target[:, :3]

tensor([[   10,   324,  5909],
        [    9,    11, 20014],
        [   12,  5906,    27],
        [ 3872, 10434,     2],
        [ 3892,     3, 10780],
        [  886,    11,  3273],
        [   12,  9357,     0],
        [   10,  8826, 23499],
        [    9,  1228,     4],
        [   10,     7,   569],
        [    9,     2,   235],
        [20059,  2592,  5909],
        [   90,     3,    20],
        [ 3872,   141,     2],
        [   95,     8,  1450],
        [   49,  6794,   369],
        [    0,  9046,     5],
        [ 3892,  1497,     2],
        [   24,    13,  2168],
        [  786,     4,   488],
        [   49,    26,  5967],
        [28867,    25,   656],
        [    3, 18430,    14],
        [ 6213,    58,    48],
        [    4,  4886,  4364],
        [ 3872,   217,     4],
        [    5,     5,    22],
        [    2,     2,  1936],
        [ 5050,   593,    59],
        [   95,     7,    14]])

Be careful, the first dimension of the text and target is the sequence, and the next is the batch.
We see that the target is indeed the original text offset by 1 (shifted downwards by 1). Which means we have all the we need to start training a language model!

In [20]:
import pickle
split_fname = 'split_iter.p'
split_f = open(split_fname, 'wb')
pickle.dump((train, valid, test), split_f)

TypeError: 'generator' object is not callable

# 3. Training the Language Model

With the above iterators, training the language model is easy. 

First, we need to prepare the model. We'll be borrowing and customizing the model from the [examples](https://github.com/pytorch/examples/tree/master/word_language_model) in pytorch.

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as V

In [25]:
class RNNModel(nn.Module):
    def __init__(self, ntoken, ninp,
                 nhid, nlayers, bsz,
                 dropout=0.5, tie_weights=True):
        super(RNNModel, self).__init__()
        self.nhid, self.nlayers, self.bsz = nhid, nlayers, bsz
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.init_weights()
        self.hidden = self.init_hidden(bsz) # the input is a batched consecutive corpus
                                            # therefore, we retain the hidden state across batches

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input):
        emb = self.drop(self.encoder(input))
        output, self.hidden = self.rnn(emb, self.hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1))

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return (V(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                V(weight.new(self.nlayers, bsz, self.nhid).zero_()))
#         return (V(weight.new(self.nlayers, bsz, self.nhid).zero_().cuda()),
#                 V(weight.new(self.nlayers, bsz, self.nhid).zero_()).cuda())

    
    def reset_history(self):
        """Wraps hidden states in new Variables, to detach them from their history."""
        self.hidden = tuple(V(v.data) for v in self.hidden)

We need to explicitly pass the initial weights of the embedding matrix that are initialize with the GloVe vectors

In [26]:
weight_matrix = TEXT.vocab.vectors

In [27]:
model = RNNModel(weight_matrix.size(0),
                 weight_matrix.size(1), 200, 1, BATCH_SIZE)

In [28]:
model.encoder.weight.data.copy_(weight_matrix);

In [20]:
if USE_GPU:
    model.cuda()

Now we can begin training the language model. We'll use the Adam optimizer here.

For the loss, we'll use the `nn.CrossEntropyLoss` function. This loss takes the index of the correct class as the ground truth instead of a one-hot vector. Unfortunately, it only takes tensors of dimension 2 or 4, so we'll need to do a bit of reshaping.

In [29]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.7, 0.99))

In [30]:
n_epochs = 2

In [31]:
n_tokens = weight_matrix.size(0)

In [32]:
from tqdm import tqdm

Now we can start the training loop.

In [35]:
def train_epoch(epoch):
    """One epoch of a training loop"""
    epoch_loss = 0
    for batch in tqdm(train_iter):
        # reset the hidden state or else the model will try to backpropagate to the
        # beginning of the dataset, requiring lots of time and a lot of memory
        model.reset_history()
        
        optimizer.zero_grad()
        
        text, targets = batch.text, batch.target
        prediction = model(text)
        # pytorch currently only supports cross entropy loss for inputs of 2 or 4 dimensions.
        # we therefore flatten the predictions out across the batch axis so that it becomes
        # shape (batch_size * sequence_length, n_tokens)
        # in accordance to this, we reshape the targets to be
        # shape (batch_size * sequence_length)
        loss = criterion(prediction.view(-1, n_tokens), targets.view(-1))
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item() * prediction.size(0) * prediction.size(1)

    epoch_loss /= len(train.examples[0].text)

    # monitor the loss
    val_loss = 0
    model.eval()
    for batch in valid_iter:
        model.reset_history()
        text, targets = batch.text, batch.target
        prediction = model(text)
        loss = criterion(prediction.view(-1, n_tokens), targets.view(-1))
        val_loss += loss.item() * text.size(0)
    val_loss /= len(valid.examples[0].text)
    
    print('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, epoch_loss, val_loss))

In [None]:
for epoch in range(1, n_epochs + 1):
    train_epoch(epoch)



  0%|          | 0/2217 [00:00<?, ?it/s][A[A

  0%|          | 1/2217 [00:01<46:43,  1.27s/it][A[A

  0%|          | 2/2217 [00:01<40:06,  1.09s/it][A[A

  0%|          | 3/2217 [00:02<35:12,  1.05it/s][A[A

  0%|          | 4/2217 [00:03<31:44,  1.16it/s][A[A

  0%|          | 5/2217 [00:03<29:17,  1.26it/s][A[A

  0%|          | 6/2217 [00:04<27:31,  1.34it/s][A[A

  0%|          | 7/2217 [00:05<26:19,  1.40it/s][A[A

  0%|          | 8/2217 [00:05<25:29,  1.44it/s][A[A

  0%|          | 9/2217 [00:06<25:23,  1.45it/s][A[A

  0%|          | 10/2217 [00:07<24:48,  1.48it/s][A[A

  0%|          | 11/2217 [00:07<24:27,  1.50it/s][A[A

  1%|          | 12/2217 [00:08<24:18,  1.51it/s][A[A

  1%|          | 13/2217 [00:09<23:50,  1.54it/s][A[A

  1%|          | 14/2217 [00:09<23:36,  1.56it/s][A[A

  1%|          | 15/2217 [00:10<24:23,  1.50it/s][A[A

  1%|          | 16/2217 [00:11<24:47,  1.48it/s][A[A

  1%|          | 17/2217 [00:11<24:27,  1.50it/

  6%|▋         | 143/2217 [01:32<21:49,  1.58it/s][A[A

  6%|▋         | 144/2217 [01:33<21:46,  1.59it/s][A[A

  7%|▋         | 145/2217 [01:34<21:39,  1.59it/s][A[A

  7%|▋         | 146/2217 [01:34<21:46,  1.59it/s][A[A

  7%|▋         | 147/2217 [01:35<22:11,  1.56it/s][A[A

  7%|▋         | 148/2217 [01:36<22:01,  1.57it/s][A[A

  7%|▋         | 149/2217 [01:36<21:58,  1.57it/s][A[A

  7%|▋         | 150/2217 [01:37<21:57,  1.57it/s][A[A

  7%|▋         | 151/2217 [01:37<21:55,  1.57it/s][A[A

  7%|▋         | 152/2217 [01:38<21:55,  1.57it/s][A[A

  7%|▋         | 153/2217 [01:39<21:55,  1.57it/s][A[A

  7%|▋         | 154/2217 [01:39<21:53,  1.57it/s][A[A

  7%|▋         | 155/2217 [01:40<21:49,  1.57it/s][A[A

  7%|▋         | 156/2217 [01:41<22:19,  1.54it/s][A[A

  7%|▋         | 157/2217 [01:41<22:24,  1.53it/s][A[A

  7%|▋         | 158/2217 [01:42<22:18,  1.54it/s][A[A

  7%|▋         | 159/2217 [01:43<22:08,  1.55it/s][A[A

  7%|▋        

 13%|█▎        | 284/2217 [03:07<21:53,  1.47it/s][A[A

 13%|█▎        | 285/2217 [03:08<21:48,  1.48it/s][A[A

 13%|█▎        | 286/2217 [03:08<21:47,  1.48it/s][A[A

 13%|█▎        | 287/2217 [03:09<21:43,  1.48it/s][A[A

 13%|█▎        | 288/2217 [03:10<21:38,  1.49it/s][A[A

 13%|█▎        | 289/2217 [03:10<21:42,  1.48it/s][A[A

 13%|█▎        | 290/2217 [03:11<21:41,  1.48it/s][A[A

 13%|█▎        | 291/2217 [03:12<21:32,  1.49it/s][A[A

 13%|█▎        | 292/2217 [03:12<21:37,  1.48it/s][A[A

 13%|█▎        | 293/2217 [03:13<21:41,  1.48it/s][A[A

 13%|█▎        | 294/2217 [03:14<21:34,  1.48it/s][A[A

 13%|█▎        | 295/2217 [03:14<21:27,  1.49it/s][A[A

 13%|█▎        | 296/2217 [03:15<21:37,  1.48it/s][A[A

 13%|█▎        | 297/2217 [03:16<21:57,  1.46it/s][A[A

 13%|█▎        | 298/2217 [03:17<22:17,  1.43it/s][A[A

 13%|█▎        | 299/2217 [03:17<22:46,  1.40it/s][A[A

 14%|█▎        | 300/2217 [03:18<22:45,  1.40it/s][A[A

 14%|█▎       

 19%|█▉        | 425/2217 [04:48<20:50,  1.43it/s][A[A

 19%|█▉        | 426/2217 [04:49<20:47,  1.44it/s][A[A

 19%|█▉        | 427/2217 [04:49<20:43,  1.44it/s][A[A

 19%|█▉        | 428/2217 [04:50<20:39,  1.44it/s][A[A

 19%|█▉        | 429/2217 [04:51<21:00,  1.42it/s][A[A

 19%|█▉        | 430/2217 [04:51<21:43,  1.37it/s][A[A

 19%|█▉        | 431/2217 [04:52<21:30,  1.38it/s][A[A

 19%|█▉        | 432/2217 [04:53<21:07,  1.41it/s][A[A

 20%|█▉        | 433/2217 [04:54<21:01,  1.41it/s][A[A

 20%|█▉        | 434/2217 [04:54<21:03,  1.41it/s][A[A

 20%|█▉        | 435/2217 [04:55<20:54,  1.42it/s][A[A

 20%|█▉        | 436/2217 [04:56<20:47,  1.43it/s][A[A

 20%|█▉        | 437/2217 [04:56<20:38,  1.44it/s][A[A

 20%|█▉        | 438/2217 [04:57<20:34,  1.44it/s][A[A

 20%|█▉        | 439/2217 [04:58<20:34,  1.44it/s][A[A

 20%|█▉        | 440/2217 [04:58<20:33,  1.44it/s][A[A

 20%|█▉        | 441/2217 [04:59<20:29,  1.45it/s][A[A

 20%|█▉       

 26%|██▌       | 566/2217 [06:33<20:23,  1.35it/s][A[A

 26%|██▌       | 567/2217 [06:34<20:48,  1.32it/s][A[A

 26%|██▌       | 568/2217 [06:35<21:09,  1.30it/s][A[A

 26%|██▌       | 569/2217 [06:35<21:23,  1.28it/s][A[A

 26%|██▌       | 570/2217 [06:36<21:14,  1.29it/s][A[A

 26%|██▌       | 571/2217 [06:37<21:35,  1.27it/s][A[A

 26%|██▌       | 572/2217 [06:38<22:31,  1.22it/s][A[A

 26%|██▌       | 573/2217 [06:39<22:03,  1.24it/s][A[A

 26%|██▌       | 574/2217 [06:39<22:00,  1.24it/s][A[A

 26%|██▌       | 575/2217 [06:40<22:07,  1.24it/s][A[A

 26%|██▌       | 576/2217 [06:41<21:57,  1.25it/s][A[A

 26%|██▌       | 577/2217 [06:42<22:04,  1.24it/s][A[A

 26%|██▌       | 578/2217 [06:43<21:39,  1.26it/s][A[A

 26%|██▌       | 579/2217 [06:43<21:31,  1.27it/s][A[A

 26%|██▌       | 580/2217 [06:44<21:03,  1.30it/s][A[A

 26%|██▌       | 581/2217 [06:45<21:18,  1.28it/s][A[A

 26%|██▋       | 582/2217 [06:46<21:28,  1.27it/s][A[A

 26%|██▋      

 32%|███▏      | 707/2217 [08:22<18:40,  1.35it/s][A[A

 32%|███▏      | 708/2217 [08:23<18:32,  1.36it/s][A[A

 32%|███▏      | 709/2217 [08:24<18:24,  1.37it/s][A[A

 32%|███▏      | 710/2217 [08:25<18:33,  1.35it/s][A[A

 32%|███▏      | 711/2217 [08:25<18:31,  1.35it/s][A[A

 32%|███▏      | 712/2217 [08:26<18:15,  1.37it/s][A[A

 32%|███▏      | 713/2217 [08:27<18:03,  1.39it/s][A[A

 32%|███▏      | 714/2217 [08:27<18:35,  1.35it/s][A[A

 32%|███▏      | 715/2217 [08:28<18:36,  1.35it/s][A[A

 32%|███▏      | 716/2217 [08:29<18:42,  1.34it/s][A[A

 32%|███▏      | 717/2217 [08:30<18:32,  1.35it/s][A[A

 32%|███▏      | 718/2217 [08:30<18:41,  1.34it/s][A[A

 32%|███▏      | 719/2217 [08:31<18:44,  1.33it/s][A[A

 32%|███▏      | 720/2217 [08:32<18:46,  1.33it/s][A[A

 33%|███▎      | 721/2217 [08:33<18:51,  1.32it/s][A[A

 33%|███▎      | 722/2217 [08:34<18:59,  1.31it/s][A[A

 33%|███▎      | 723/2217 [08:34<18:56,  1.31it/s][A[A

 33%|███▎     

 38%|███▊      | 848/2217 [10:06<16:56,  1.35it/s][A[A

 38%|███▊      | 849/2217 [10:07<17:05,  1.33it/s][A[A

 38%|███▊      | 850/2217 [10:07<17:38,  1.29it/s][A[A

 38%|███▊      | 851/2217 [10:08<17:13,  1.32it/s][A[A

 38%|███▊      | 852/2217 [10:09<17:00,  1.34it/s][A[A

 38%|███▊      | 853/2217 [10:10<16:56,  1.34it/s][A[A

 39%|███▊      | 854/2217 [10:10<16:46,  1.35it/s][A[A

 39%|███▊      | 855/2217 [10:11<16:43,  1.36it/s][A[A

 39%|███▊      | 856/2217 [10:12<16:41,  1.36it/s][A[A

 39%|███▊      | 857/2217 [10:13<16:40,  1.36it/s][A[A

 39%|███▊      | 858/2217 [10:13<16:39,  1.36it/s][A[A

 39%|███▊      | 859/2217 [10:14<16:31,  1.37it/s][A[A

 39%|███▉      | 860/2217 [10:15<16:29,  1.37it/s][A[A

 39%|███▉      | 861/2217 [10:15<16:33,  1.36it/s][A[A

 39%|███▉      | 862/2217 [10:16<16:39,  1.36it/s][A[A

 39%|███▉      | 863/2217 [10:17<16:40,  1.35it/s][A[A

 39%|███▉      | 864/2217 [10:18<17:19,  1.30it/s][A[A

 39%|███▉     

 45%|████▍     | 989/2217 [11:56<16:10,  1.27it/s][A[A

 45%|████▍     | 990/2217 [11:57<16:12,  1.26it/s][A[A

 45%|████▍     | 991/2217 [11:57<16:10,  1.26it/s][A[A

 45%|████▍     | 992/2217 [11:58<15:59,  1.28it/s][A[A

 45%|████▍     | 993/2217 [11:59<16:00,  1.27it/s][A[A

 45%|████▍     | 994/2217 [12:00<16:08,  1.26it/s][A[A

 45%|████▍     | 995/2217 [12:01<16:07,  1.26it/s][A[A

 45%|████▍     | 996/2217 [12:01<15:52,  1.28it/s][A[A

 45%|████▍     | 997/2217 [12:02<15:39,  1.30it/s][A[A

 45%|████▌     | 998/2217 [12:03<15:30,  1.31it/s][A[A

 45%|████▌     | 999/2217 [12:04<15:28,  1.31it/s][A[A

 45%|████▌     | 1000/2217 [12:04<15:31,  1.31it/s][A[A

 45%|████▌     | 1001/2217 [12:05<15:28,  1.31it/s][A[A

 45%|████▌     | 1002/2217 [12:06<15:28,  1.31it/s][A[A

 45%|████▌     | 1003/2217 [12:07<15:21,  1.32it/s][A[A

 45%|████▌     | 1004/2217 [12:07<15:30,  1.30it/s][A[A

 45%|████▌     | 1005/2217 [12:08<15:25,  1.31it/s][A[A

 45%|███

 51%|█████     | 1128/2217 [13:43<13:39,  1.33it/s][A[A

 51%|█████     | 1129/2217 [13:44<13:35,  1.33it/s][A[A

 51%|█████     | 1130/2217 [13:45<13:36,  1.33it/s][A[A

 51%|█████     | 1131/2217 [13:46<13:31,  1.34it/s][A[A

 51%|█████     | 1132/2217 [13:46<13:33,  1.33it/s][A[A

 51%|█████     | 1133/2217 [13:47<13:31,  1.34it/s][A[A

 51%|█████     | 1134/2217 [13:48<13:28,  1.34it/s][A[A

 51%|█████     | 1135/2217 [13:49<13:28,  1.34it/s][A[A

 51%|█████     | 1136/2217 [13:49<13:35,  1.33it/s][A[A

 51%|█████▏    | 1137/2217 [13:50<13:32,  1.33it/s][A[A

 51%|█████▏    | 1138/2217 [13:51<13:28,  1.33it/s][A[A

 51%|█████▏    | 1139/2217 [13:52<13:30,  1.33it/s][A[A

 51%|█████▏    | 1140/2217 [13:52<13:28,  1.33it/s][A[A

 51%|█████▏    | 1141/2217 [13:53<13:26,  1.33it/s][A[A

 52%|█████▏    | 1142/2217 [13:54<13:22,  1.34it/s][A[A

 52%|█████▏    | 1143/2217 [13:55<13:22,  1.34it/s][A[A

 52%|█████▏    | 1144/2217 [13:55<13:22,  1.34it/s][A[

 57%|█████▋    | 1266/2217 [15:32<12:12,  1.30it/s][A[A

 57%|█████▋    | 1267/2217 [15:33<12:10,  1.30it/s][A[A

 57%|█████▋    | 1268/2217 [15:34<12:09,  1.30it/s][A[A

 57%|█████▋    | 1269/2217 [15:35<12:11,  1.30it/s][A[A

 57%|█████▋    | 1270/2217 [15:35<12:17,  1.28it/s][A[A

 57%|█████▋    | 1271/2217 [15:36<12:21,  1.28it/s][A[A

 57%|█████▋    | 1272/2217 [15:37<12:23,  1.27it/s][A[A

 57%|█████▋    | 1273/2217 [15:38<12:27,  1.26it/s][A[A

 57%|█████▋    | 1274/2217 [15:39<12:41,  1.24it/s][A[A

 58%|█████▊    | 1275/2217 [15:39<12:34,  1.25it/s][A[A

 58%|█████▊    | 1276/2217 [15:40<12:29,  1.25it/s][A[A

 58%|█████▊    | 1277/2217 [15:41<12:26,  1.26it/s][A[A

 58%|█████▊    | 1278/2217 [15:42<12:34,  1.24it/s][A[A

 58%|█████▊    | 1279/2217 [15:43<12:29,  1.25it/s][A[A

 58%|█████▊    | 1280/2217 [15:43<12:27,  1.25it/s][A[A

 58%|█████▊    | 1281/2217 [15:44<12:27,  1.25it/s][A[A

 58%|█████▊    | 1282/2217 [15:45<12:37,  1.23it/s][A[

 63%|██████▎   | 1404/2217 [17:27<11:11,  1.21it/s][A[A

 63%|██████▎   | 1405/2217 [17:28<11:28,  1.18it/s][A[A

 63%|██████▎   | 1406/2217 [17:29<11:20,  1.19it/s][A[A

 63%|██████▎   | 1407/2217 [17:30<11:40,  1.16it/s][A[A

 64%|██████▎   | 1408/2217 [17:31<11:34,  1.16it/s][A[A

 64%|██████▎   | 1409/2217 [17:32<11:24,  1.18it/s][A[A

 64%|██████▎   | 1410/2217 [17:32<11:07,  1.21it/s][A[A

 64%|██████▎   | 1411/2217 [17:33<10:55,  1.23it/s][A[A

 64%|██████▎   | 1412/2217 [17:34<10:47,  1.24it/s][A[A

 64%|██████▎   | 1413/2217 [17:35<10:42,  1.25it/s][A[A

 64%|██████▍   | 1414/2217 [17:35<10:35,  1.26it/s][A[A

 64%|██████▍   | 1415/2217 [17:36<10:28,  1.28it/s][A[A

 64%|██████▍   | 1416/2217 [17:37<10:28,  1.27it/s][A[A

 64%|██████▍   | 1417/2217 [17:38<10:29,  1.27it/s][A[A

 64%|██████▍   | 1418/2217 [17:39<10:31,  1.27it/s][A[A

 64%|██████▍   | 1419/2217 [17:39<10:27,  1.27it/s][A[A

 64%|██████▍   | 1420/2217 [17:40<10:22,  1.28it/s][A[

 70%|██████▉   | 1542/2217 [19:17<09:10,  1.23it/s][A[A

 70%|██████▉   | 1543/2217 [19:18<09:15,  1.21it/s][A[A

 70%|██████▉   | 1544/2217 [19:19<09:16,  1.21it/s][A[A

 70%|██████▉   | 1545/2217 [19:20<09:07,  1.23it/s][A[A

 70%|██████▉   | 1546/2217 [19:21<09:04,  1.23it/s][A[A

 70%|██████▉   | 1547/2217 [19:21<08:59,  1.24it/s][A[A

 70%|██████▉   | 1548/2217 [19:22<08:58,  1.24it/s][A[A

 70%|██████▉   | 1549/2217 [19:23<08:57,  1.24it/s][A[A

 70%|██████▉   | 1550/2217 [19:24<08:54,  1.25it/s][A[A

 70%|██████▉   | 1551/2217 [19:25<08:59,  1.23it/s][A[A

 70%|███████   | 1552/2217 [19:25<09:01,  1.23it/s][A[A

 70%|███████   | 1553/2217 [19:26<08:57,  1.24it/s][A[A

 70%|███████   | 1554/2217 [19:27<08:54,  1.24it/s][A[A

 70%|███████   | 1555/2217 [19:28<08:47,  1.26it/s][A[A

 70%|███████   | 1556/2217 [19:29<08:46,  1.26it/s][A[A

 70%|███████   | 1557/2217 [19:30<09:04,  1.21it/s][A[A

 70%|███████   | 1558/2217 [19:30<09:18,  1.18it/s][A[

 76%|███████▌  | 1680/2217 [21:09<07:17,  1.23it/s][A[A

 76%|███████▌  | 1681/2217 [21:10<07:18,  1.22it/s][A[A

 76%|███████▌  | 1682/2217 [21:10<07:17,  1.22it/s][A[A

 76%|███████▌  | 1683/2217 [21:11<07:14,  1.23it/s][A[A

 76%|███████▌  | 1684/2217 [21:12<07:14,  1.23it/s][A[A

 76%|███████▌  | 1685/2217 [21:13<07:08,  1.24it/s][A[A

 76%|███████▌  | 1686/2217 [21:14<07:06,  1.25it/s][A[A

 76%|███████▌  | 1687/2217 [21:14<07:10,  1.23it/s][A[A

 76%|███████▌  | 1688/2217 [21:15<07:08,  1.23it/s][A[A

 76%|███████▌  | 1689/2217 [21:16<07:03,  1.25it/s][A[A

 76%|███████▌  | 1690/2217 [21:17<07:00,  1.25it/s][A[A

 76%|███████▋  | 1691/2217 [21:18<06:58,  1.26it/s][A[A

 76%|███████▋  | 1692/2217 [21:18<06:57,  1.26it/s][A[A

 76%|███████▋  | 1693/2217 [21:19<06:54,  1.26it/s][A[A

 76%|███████▋  | 1694/2217 [21:20<06:52,  1.27it/s][A[A

 76%|███████▋  | 1695/2217 [21:21<07:00,  1.24it/s][A[A

 76%|███████▋  | 1696/2217 [21:22<07:11,  1.21it/s][A[

 82%|████████▏ | 1818/2217 [23:01<05:21,  1.24it/s][A[A

Let's examine the output at 2 epochs

In [27]:
b = next(iter(valid_iter))

In [28]:
def word_ids_to_sentence(id_tensor, vocab, join=None):
    """Converts a sequence of word ids to a sentence"""
    if isinstance(id_tensor, torch.LongTensor):
        ids = id_tensor.transpose(0, 1).contiguous().view(-1)
    elif isinstance(id_tensor, np.ndarray):
        ids = id_tensor.transpose().reshape(-1)

    batch = [vocab.itos[ind] for ind in ids]  # denumericalize
    if join is None:
        return batch
    else:
        return join.join(batch)

In [29]:
word_ids_to_sentence(b.text.cpu().data, TEXT.vocab, join=' ')[:210]

'  <eos>   = homarus gammarus = <eos>   <eos>   homarus gammarus , known as the european lobster or common lobster , is a species of <unk> lobster from . <unk> ceo hiroshi <unk> referred to <unk> as one of his f'

In [30]:
arrs = model(b.text).cpu().data.numpy()

In [31]:
word_ids_to_sentence(np.argmax(arrs, axis=2), TEXT.vocab, join=' ')[:210]

'<unk>   <eos> = = ( <eos>   <eos>   = = ( <unk> as the <unk> @-@ ( <unk> species , <unk> a <unk> of the <unk> ( the <eos> was <unk> <unk> <unk> to the the a of the first " , the , <eos>   <eos> reviewers were t'

Hmm.. doesn't seem to be making much sense yet.
Let's train for another 2 epochs and see how the results change

In [32]:
for epoch in range(n_epochs + 1, n_epochs * 4 + 1):
    train_epoch(epoch)

100%|██████████| 2217/2217 [01:59<00:00, 18.56it/s]
  0%|          | 0/2217 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 4.9020, Validation Loss: 0.1568


100%|██████████| 2217/2217 [01:59<00:00, 18.61it/s]


Epoch: 4, Training Loss: 4.6959, Validation Loss: 0.1549


In [33]:
arrs = model(b.text).cpu().data.numpy()
word_ids_to_sentence(np.argmax(arrs, axis=2), TEXT.vocab, join=' ')[:210]

'<unk>   <eos> = = ( <eos>   <eos>   <eos> ( ( is as the <unk> union <unk> <unk> starling <unk> <unk> the <unk> of the <unk> , the <eos> , <unk> <unk> , to the the a of the " " , the , <eos>   <eos> reviewers ha'

Is this getting better? The loss is certainly getting better.
This just goes to show how difficult it is to match a loss value with the quality of the predictions in language modeling.