In [1]:
import os

import itertools
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math 

import sys
sys.path.append('../')
import utils
import wiki_utils
%matplotlib inline

In [2]:
from torchtext.datasets import language_modeling, WikiText2
from torchtext.data import BPTTIterator, Field

In [3]:
tokenizer = lambda x: list(x)

In [4]:
TEXT = Field(sequential=True, 
             tokenize=tokenizer, 
             include_lengths=False, 
             use_vocab=True)

In [5]:
train, val, test = WikiText2.splits(text_field=TEXT,
                                    path='./wikitext/', train='train.txt',
                                    validation='valid.txt', test='test.txt')

In [6]:
sequence_length = 30
batch_size = eval_batch_size = 128

grad_clip = 0.1
lr = 4.
best_val_loss = None
log_interval = 100

In [7]:
TEXT.build_vocab(train)

In [8]:
train_iter, val_iter, test_iter = BPTTIterator.splits(
    (train, val, test),
    batch_size=batch_size,
    bptt_len=sequence_length,
    repeat=False, device='cuda')

In [9]:
b = next(iter(train_iter)); vars(b).keys()

dict_keys(['batch_size', 'dataset', 'fields', 'text', 'target'])

In [10]:
b.text[:5, :3]

tensor([[ 2,  9, 16],
        [30, 21, 21],
        [ 2,  2,  4],
        [34, 20, 11],
        [ 2,  5, 10]], device='cuda:0')

In [11]:
b.target[:5, :3]

tensor([[30, 21, 21],
        [ 2,  2,  4],
        [34, 20, 11],
        [ 2,  5, 10],
        [72, 10,  2]], device='cuda:0')

In [12]:
b = next(iter(val_iter)); vars(b).keys()

dict_keys(['batch_size', 'dataset', 'fields', 'text', 'target'])

In [13]:
class RNNModel(nn.Module):

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, hidden=None):
        emb = self.drop(self.encoder(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()

In [14]:
ntokens = len(TEXT.vocab.itos)
model = RNNModel('LSTM', ntokens, 128, 128, 2, 0.3)
criterion = nn.CrossEntropyLoss()

In [15]:
def evaluate(data_loader):
    model.cuda()
    model.eval()
    total_loss = 0
    ntokens = len(TEXT.vocab.itos)
    hidden = model.init_hidden(eval_batch_size)
    for i, data in enumerate(data_loader):
        output, hidden = model(data.text.cuda())
        output_flat = output.view(-1, ntokens)
        total_loss += criterion(output_flat, data.target.view(-1)).item()
    return total_loss / len(data_loader)

In [16]:
def train():
    model.cuda()
    model.train()
    total_loss = 0
    ntokens = len(TEXT.vocab.itos)
    for batch, data in enumerate(train_iter):
        model.zero_grad()
        output, hidden = model(data.text.cuda())
        #print(output.shape)
        #print(data.target.shape)
        loss = criterion(output.view(-1, ntokens), data.target.view(-1))
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            print(total_loss)
            cur_loss = total_loss / log_interval
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_iter) // sequence_length, lr, cur_loss, math.exp(cur_loss)))
            total_loss = 0

In [17]:
def generate(n=50, temp=1.):
    model.cuda()
    model.eval()
    x = torch.rand(1, 1).mul(ntokens).long().cuda()
    hidden = None
    out = []
    for i in range(n):
        output, hidden = model(x, hidden)
        s_weights = output.squeeze().data.div(temp).exp()
        s_idx = torch.multinomial(s_weights, 1)[0]
        x.data.fill_(s_idx)
        s = TEXT.vocab.itos[s_idx]
        out.append(s)
    return ''.join(out)

In [18]:
len(val_iter)

292

In [19]:
72584/292

248.57534246575344

In [None]:
with torch.no_grad():
    print('sample:\n', generate(50), '\n')

for epoch in range(1, 10):
    train()
    val_loss = evaluate(val_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(50), '\n')

sample:
 qr〈ñãγтยั₹†åS,tÉăCaá’ḥđ3♭@Rử:к戦.yოвEćCxβ†5礮fà’Ú`‘V 

359.86430406570435
| epoch   1 |   100/   93 batches | lr 4.00 | loss  3.60 | ppl    36.55
328.2866168022156
| epoch   1 |   200/   93 batches | lr 4.00 | loss  3.28 | ppl    26.65
325.0367991924286
| epoch   1 |   300/   93 batches | lr 4.00 | loss  3.25 | ppl    25.80
322.5218710899353
| epoch   1 |   400/   93 batches | lr 4.00 | loss  3.23 | ppl    25.16
322.15682220458984
| epoch   1 |   500/   93 batches | lr 4.00 | loss  3.22 | ppl    25.07
312.7357544898987
| epoch   1 |   600/   93 batches | lr 4.00 | loss  3.13 | ppl    22.81
300.65835642814636
| epoch   1 |   700/   93 batches | lr 4.00 | loss  3.01 | ppl    20.22
290.1759808063507
| epoch   1 |   800/   93 batches | lr 4.00 | loss  2.90 | ppl    18.21
282.0542154312134
| epoch   1 |   900/   93 batches | lr 4.00 | loss  2.82 | ppl    16.79
274.0940980911255
| epoch   1 |  1000/   93 batches | lr 4.00 | loss  2.74 | ppl    15.50
264.21088194847107
| epoch   1 |  1

-----------------------------------------------------------------------------------------
| end of epoch   3 | valid loss  1.57 | valid ppl     4.82
-----------------------------------------------------------------------------------------
sample:
  about 1802 into the way general strerther . <eos> In  

179.8796672821045
| epoch   4 |   100/   93 batches | lr 4.00 | loss  1.80 | ppl     6.04
177.39516258239746
| epoch   4 |   200/   93 batches | lr 4.00 | loss  1.77 | ppl     5.89
177.55783939361572
| epoch   4 |   300/   93 batches | lr 4.00 | loss  1.78 | ppl     5.90
177.5480626821518
| epoch   4 |   400/   93 batches | lr 4.00 | loss  1.78 | ppl     5.90
177.18369114398956
| epoch   4 |   500/   93 batches | lr 4.00 | loss  1.77 | ppl     5.88
176.54995107650757
| epoch   4 |   600/   93 batches | lr 4.00 | loss  1.77 | ppl     5.84
176.91158151626587
| epoch   4 |   700/   93 batches | lr 4.00 | loss  1.77 | ppl     5.87
176.44852447509766
| epoch   4 |   800/   93 batches | lr 4.

169.3004608154297
| epoch   6 |  2600/   93 batches | lr 4.00 | loss  1.69 | ppl     5.44
168.80155491828918
| epoch   6 |  2700/   93 batches | lr 4.00 | loss  1.69 | ppl     5.41
167.75915372371674
| epoch   6 |  2800/   93 batches | lr 4.00 | loss  1.68 | ppl     5.35
-----------------------------------------------------------------------------------------
| end of epoch   6 | valid loss  1.47 | valid ppl     4.35
-----------------------------------------------------------------------------------------
sample:
  elected in licked from European , and merized , t 

170.01217257976532
| epoch   7 |   100/   93 batches | lr 4.00 | loss  1.70 | ppl     5.47
167.61849164962769
| epoch   7 |   200/   93 batches | lr 4.00 | loss  1.68 | ppl     5.35
168.1452248096466
| epoch   7 |   300/   93 batches | lr 4.00 | loss  1.68 | ppl     5.37
168.2704395055771
| epoch   7 |   400/   93 batches | lr 4.00 | loss  1.68 | ppl     5.38
167.9475440979004
| epoch   7 |   500/   93 batches | lr 4.00 | l

In [36]:
t1 = generate(10000, 1.)
t15 = generate(10000, 1.5)
t075 = generate(10000, 0.75)
with open('./generated075.txt', 'w') as outf:
    outf.write(t075)
with open('./generated1.txt', 'w') as outf:
    outf.write(t1)
with open('./generated15.txt', 'w') as outf:
    outf.write(t15)