In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
# import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx

from utils import data
from models.RNNLM import RNNModel

In [2]:
device = torch.device("cpu")

In [3]:
corpus = data.Corpus('./data/wikitext')

In [4]:
print(corpus.valid.shape)
print(corpus.valid.size())

torch.Size([217646])
torch.Size([217646])


In [5]:
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

In [6]:
batch_size = 6
eval_batch_size = 10
bptt = 30
lr = 0.1
clip = 0.25
log_interval = 10
epochs = 1

In [7]:
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

In [8]:
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

In [9]:
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

In [10]:
a = get_batch(train_data,1)

In [13]:
print(a[0].size())

torch.Size([30, 6])


In [14]:
print(a[0])

tensor([[    1,    17,  1007,   676, 60633,  4592],
        [    2,  1640,    39,   527,   119,    15],
        [    3, 12373,    55,   135,    17,  2059],
        [    4, 28130,   147, 21079,   279,    43],
        [    1,  1655,  4775,    62, 19760,    17],
        [    0, 28294,    43,    17,    81,   932],
        [    0,   186,  4054,  7716,   117,    13],
        [    5,    16,    22,    62, 60621,    17],
        [    6,    17,    17,    39,    15,   928],
        [    2, 28295,  1188,    27,   179,   128],
        [    7, 28296, 11616,  9084,   527,  5540],
        [    8,    13,    48,  4338,  2484, 17175],
        [    9,    27,    13,    15,    22,  6921],
        [    3,  3103,  6087,     0,  2436,  2419],
        [   10,    16,   253,     0, 60628,    17],
        [   11, 28297,   183,     1,   119,  2307],
        [    8, 20755,    22,     1,    17, 11191],
        [   12,    10,  1345,     1,   279,  2348],
        [   13,   101, 33091, 51495, 19760,    22],
        [   

In [13]:
def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            hidden = repackage_hidden(hidden)
    return total_loss / (len(data_source) - 1)

In [14]:
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // bptt, lr,
                elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
    start_time = time.time()

In [21]:
ntokens = len(corpus.dictionary)
model = RNNModel(ntokens).to(device)

criterion = nn.CrossEntropyLoss()

best_val_loss = None

  "num_layers={}".format(dropout, num_layers))


In [22]:
try:
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                           val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            lr /= 4.0
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

| epoch   1 |    10/ 1087 batches | lr 0.10 | ms/batch 2256.37 | loss 12.48 | ppl 263284.83
| epoch   1 |    20/ 1087 batches | lr 0.10 | ms/batch 4305.87 | loss 11.33 | ppl 83400.38
| epoch   1 |    30/ 1087 batches | lr 0.10 | ms/batch 6366.69 | loss 11.32 | ppl 82282.31
| epoch   1 |    40/ 1087 batches | lr 0.10 | ms/batch 8444.80 | loss 11.30 | ppl 81143.83
| epoch   1 |    50/ 1087 batches | lr 0.10 | ms/batch 10523.16 | loss 11.29 | ppl 79959.88
| epoch   1 |    60/ 1087 batches | lr 0.10 | ms/batch 12590.56 | loss 11.27 | ppl 78825.20
| epoch   1 |    70/ 1087 batches | lr 0.10 | ms/batch 14657.28 | loss 11.26 | ppl 77454.02
| epoch   1 |    80/ 1087 batches | lr 0.10 | ms/batch 16744.95 | loss 11.25 | ppl 76514.76
| epoch   1 |    90/ 1087 batches | lr 0.10 | ms/batch 18827.69 | loss 11.23 | ppl 75292.04
| epoch   1 |   100/ 1087 batches | lr 0.10 | ms/batch 20901.19 | loss 11.22 | ppl 74368.92
-----------------------------------------------------------------------------------