In [1]:
import os,sys,time,math,textwrap

import numpy as np

import torch
import torch.nn as nn

import dataset, transformer

# Path passed as the first argument to dataset.WikiText2 for every split
# (presumably the directory holding the WikiText-2 corpus -- confirm in dataset.py).
root = 'data/wikitext-2'

In [2]:
# --- Run configuration -------------------------------------------------------
# Learning rate the optimizer is switched to once the first (warm-up) epoch is over.
lr = .00035
# Sequence length handed to both the dataset and the model; sample() also draws
# this many tokens per generated text.
context = 150
# Number of sequences per DataLoader batch (training and evaluation).
batch_size = 32
# Number of training batches between two progress lines.
log_interval = 50

# Forwarded to transformer.Transformer (presumably the number of attention
# heads and of stacked layers -- confirm against its signature).
heads = 10
depth = 16

# Fixed seed so parameter initialisation and CPU-side sampling are repeatable.
torch.manual_seed(0)
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Build the three corpus splits with the same root path and context length.
# The generator is unpacked in order, so the splits are constructed as
# train -> valid -> test.
train_data, valid_data, test_data = (
    dataset.WikiText2(root, context, split)
    for split in (
        dataset.DatasetSplit.train,
        dataset.DatasetSplit.valid,
        dataset.DatasetSplit.test,
    )
)

In [4]:
def evaluate(data):
    """Return the average loss of the global ``model`` over ``data``.

    The dataset is walked once, unshuffled, in batches of ``batch_size``.
    Each batch is transposed with ``permute(1, 0)`` before being fed to the
    model, and the model output is flattened to ``(-1, vocabulary size)`` for
    the global ``criterion``.

    Args:
        data: a dataset accepted by ``torch.utils.data.DataLoader`` that
            yields ``(x, y)`` pairs.

    Returns:
        float: the mean of the per-batch losses. Every batch carries equal
        weight, so a shorter final batch counts as much as a full one.

    Raises:
        ValueError: if ``data`` produces no batches.

    Side effects:
        Prints one blank line and always leaves ``model`` in training mode,
        even when evaluation fails part-way through.
    """
    loader = torch.utils.data.DataLoader(dataset=data,batch_size=batch_size,shuffle=False)
    if len(loader) == 0:
        raise ValueError('evaluate() needs at least one batch of data')

    total = 0.
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
                yhat = model(x).view(-1, train_data.word_count())
                # .item() keeps the running sum a plain Python float instead
                # of a tensor living on the compute device.
                total += criterion(yhat, y.contiguous().view(-1)).item()
    finally:
        # Hand the model back in training mode no matter what happened above.
        model.train()

    print()
    return total / len(loader)

In [5]:
# Instantiate the network and move it to the configured device.
# The three literals (400, 40, 900) are positional size arguments whose
# meaning is defined by transformer.Transformer -- TODO confirm and name them.
model = transformer.Transformer(
    context,
    train_data.word_count(),
    400,
    40,
    900,
    heads,
    depth,
    tied_weights=True,
)
model = model.to(device)

# Total number of scalar entries across all trainable parameter tensors.
count = sum(np.prod(p.shape) for p in model.parameters() if p.requires_grad)
print(f'Initialized graph with {count} parameters')

Initialized graph with 35198479 parameters


In [6]:
criterion = nn.NLLLoss()
curr_lr = .0001        # learning rate used during epoch 0 (warm-up)
clip = .25             # max_norm passed to clip_grad_norm_
best_val_loss = None   # lowest validation loss seen so far; None until the first check
epochs = 10
save = 'model.pt'      # file the best model is written to

train_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
print('Initiating training, {} iterations/epoch.'.format(len(train_loader)))


def _validate_and_checkpoint(epoch):
    """Score the model on the validation split and keep it if it is the best so far.

    Prints the checkpoint banner for ``epoch``, and overwrites ``save`` when the
    validation loss improves on ``best_val_loss`` (or when there is no best yet).

    Returns:
        The validation loss as returned by ``evaluate``.
    """
    global best_val_loss

    t0 = time.time()
    val_loss = evaluate(valid_data)
    print('-' * 100)
    print('| checkpoint | epoch {:3d} | time: {:5.2f}s | validation loss {:5.2f} | '
            'validation perplexity {:8.2f}'.format(epoch, (time.time() - t0),
                                                   val_loss, math.exp(val_loss)))
    print('-' * 100)

    # Compare against None explicitly: a truthiness test would mistake a
    # best loss of exactly 0.0 for "nothing recorded yet".
    if best_val_loss is None or val_loss < best_val_loss:
        # The whole module is pickled (not just a state_dict); the restore
        # cell relies on that when it calls torch.load.
        with open(save, 'wb') as f:
            torch.save(model, f)
        best_val_loss = val_loss
    return val_loss


try:
    optimizer = torch.optim.Adam(model.parameters(), lr=curr_lr)
    for epoch in range(epochs):
        # Validation runs *before* each epoch, so the line for epoch k
        # reports the model after k epochs of training.
        val_loss = _validate_and_checkpoint(epoch)
        print('epoch\t\tms/batch\tlr\tloss\tperplexity')

        model.train()
        total_loss = 0.
        t0 = time.time()
        if epoch == 1: optimizer.param_groups[0]['lr'] = curr_lr = lr # finished warmup
        for i, (x,y) in enumerate(train_loader):
            # Every log_interval batches, report the mean loss and speed of
            # the batches processed since the previous report.
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - t0
                print('{:3d} ({:2.1f}%)\t{:5.2f}\t\t{:1.3}\t{:5.2f}\t{:8.2f}'.format(
                    epoch, 100*i/float(len(train_loader)),
                    elapsed * 1000 / log_interval, curr_lr, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                t0 = time.time()

            x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
            model.zero_grad()
            yhat = model(x).view(-1, train_data.word_count())
            loss = criterion(yhat, y.contiguous().view(-1))
            loss.backward()

            # Rescale gradients whose global norm exceeds `clip` before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            total_loss += loss.item()

    # The in-loop check only ever sees the model *before* an epoch, so without
    # this extra pass the weights produced by the last epoch would never be
    # validated or saved.
    val_loss = _validate_and_checkpoint(epochs)

except KeyboardInterrupt:
    # Ctrl-C / "interrupt kernel" stops training early; the best checkpoint
    # written so far stays on disk.
    print('Graceful Exit')

Initiating training, 436 iterations/epoch.

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   0 | time:  3.40s | validation loss 10.42 | validation perplexity 33440.77
----------------------------------------------------------------------------------------------------
epoch		ms/batch	lr	loss	perplexity
  0 (11.5%)	217.25		0.0001	 8.75	 6289.62
  0 (22.9%)	218.28		0.0001	 7.22	 1364.06
  0 (34.4%)	219.41		0.0001	 6.69	  804.71
  0 (45.9%)	220.80		0.0001	 6.49	  659.89
  0 (57.3%)	221.84		0.0001	 6.38	  589.55
  0 (68.8%)	222.36		0.0001	 6.30	  545.21
  0 (80.3%)	222.58		0.0001	 6.21	  498.07
  0 (91.7%)	222.84		0.0001	 6.17	  476.76

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   1 | time:  3.31s | validation loss  5.95 | validation perplexity   383.55
---------------------------------------------------------------------------------------

In [7]:
# Replace the in-memory model with the checkpoint stored at `save`, then score
# it once on the test split.
# NOTE(review): the checkpoint is a fully pickled module. Recent PyTorch
# releases reportedly default torch.load to weights_only=True, which rejects
# such files -- confirm against the installed version before re-running.
print('Restoring best checkpointed model...')
model = torch.load(save)  # given a path, torch.load opens the file itself

test_loss = evaluate(test_data)
print('=' * 89)
print(f'| end of training | test loss {test_loss:5.2f} | test perplexity {math.exp(test_loss):8.2f}')
print('=' * 89)

Restoring best checkpointed model...

| end of training | test loss  5.50 | test perplexity   245.04


In [8]:
# Heading for the generated-text section: a blank line, the title, then an
# 89-character rule.
print('\nUncurated samples', '-' * 89, sep='\n')

def sample():
    """Generate ``context`` words from the global ``model`` and return them as text.

    Generation starts from one uniformly random vocabulary index (which is not
    itself part of the returned text). At every step the model is run on the
    whole history so far, the last position's output is exponentiated into
    sampling weights, one index is drawn with ``torch.multinomial`` and
    appended to the history.

    Returns:
        str: the sampled words joined by spaces and wrapped to 80 columns.

    Side effects:
        Leaves ``model`` in evaluation mode.
    """
    words = []
    model.eval()
    # Drawn on the CPU and then moved, so the seed token comes from the CPU
    # random generator regardless of which device the model lives on.
    history = torch.randint(train_data.word_count(), (1, 1), dtype=torch.long).to(device)
    # Pure inference: no autograd graph is needed while generating.
    with torch.no_grad():
        for _ in range(context):
            output = model(history)
            # exp() turns the model's log-scores for the last position into
            # non-negative sampling weights.
            word_weights = output[-1].squeeze().exp().cpu()
            word_idx = torch.multinomial(word_weights, 1).item()
            word_tensor = torch.tensor([[word_idx]], dtype=torch.long, device=device)
            history = torch.cat([history, word_tensor], 0)

            words.append(train_data.idx2word[word_idx])

    return '\n'.join(textwrap.wrap(' '.join(words),80))

# Print five independently generated texts, each prefixed with its index.
for i in range(5):
    print(f'({i})', sample())


Uncurated samples
-----------------------------------------------------------------------------------------
(0) accepted van observed each <unk> which <unk> New because <unk> the the was the
the wanting and along in unsuccessfully where in minesweepers the rhymes
obsessive cottages as for Dublin Archangel one to which which also as was and
who including known quickly 2010 including shit to reached because scientific
other graduation to without are in which Glenn a was the advancing on including
and and and the Croatia accurately a stated and behind called is musicians which
when a the and Governor he he most alongside and and a featuring who ahead he
her Lisa including and hearts beads and Gordon 13 and but from another any " 1
who that and which thus mined BMI leaving featured crumbling and <unk> coaching
a the which codenamed nectar Canada had and the " <unk> and one and Owen
including with forbs is this but vocals Honduras and the <unk> all
(1) after 400 Street and noting Jr songwr