In [1]:
#! git clone https://github.com/kaiweic/Poor-man-s-GPT-3.git

In [2]:
import os,sys,time,math,textwrap

import numpy as np

import torch
import torch.nn as nn

# Put the 'Poor-man-s-GPT-3' directory (the repository cloned by the
# commented-out command in the first cell) on the import path so the bare-name
# imports below can resolve.
sys.path.append('Poor-man-s-GPT-3')
import dataset, datasetxl, transformer, transformerxl, transformerxl3

from data_utils import get_lm_corpus

# Data directory handed to the dataset constructors and to get_lm_corpus below.
root = 'Poor-man-s-GPT-3/data/wikitext-2'

In [3]:
# ---------------------------------------------------------------------------
# Settings shared by every stage.
# ---------------------------------------------------------------------------
lr = .00035          # learning rate the training cell switches to at epoch 1
context = 150        # passed to the datasets and the model; also the number of sampled tokens
batch_size = 32
log_interval = 50    # training batches between two progress lines
tied_weights = False
shuffle = False      # forwarded to every DataLoader (training and evaluation)

# Which hyper-parameter preset to use: one of the keys of _STAGE_HPARAMS.
stage = 'd'

# Per-stage model/training presets. heads, depth, k, d and m are forwarded
# positionally to transformerxl.Transformer; their exact meaning is defined by
# that class, not by this notebook.
_STAGE_HPARAMS = {
    'b': dict(heads=2,  depth=2,  epochs=10, dropout=0,   dropoutio=0,   k=40, d=400, m=900),
    'c': dict(heads=10, depth=16, epochs=10, dropout=0,   dropoutio=0,   k=40, d=400, m=900),
    'd': dict(heads=8,  depth=12, epochs=80, dropout=0.2, dropoutio=0.6, k=64, d=400, m=900),
}

# Fail here, with a clear message, instead of with a NameError several cells
# later when an undefined hyper-parameter is first used.
if stage not in _STAGE_HPARAMS:
    raise ValueError('Unknown stage {!r}; expected one of {}'.format(
        stage, sorted(_STAGE_HPARAMS)))

_hparams = _STAGE_HPARAMS[stage]
heads = _hparams['heads']
depth = _hparams['depth']
epochs = _hparams['epochs']
dropout = _hparams['dropout']
dropoutio = _hparams['dropoutio']
k = _hparams['k']
d = _hparams['d']
m = _hparams['m']

torch.manual_seed(0)
# Prefer the GPU, but keep the notebook runnable on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
corpus = get_lm_corpus(root, "wt2")

Loading cached dataset...


In [4]:
# Our dataset: one datasetxl.WikiText2 per split, built in train/valid/test order.
train_data, valid_data, test_data = (
    datasetxl.WikiText2(root, context, batch_size, split)
    for split in (
        dataset.DatasetSplit.train,
        dataset.DatasetSplit.valid,
        dataset.DatasetSplit.test,
    )
)

# Alternative data sources that were tried and are currently disabled:
#
#   iterators from the corpus object
#train_iter = corpus.get_iterator('train', batch_size, context, device=device, ext_len=0)
#validate_iter = corpus.get_iterator('valid', batch_size, context, device=device, ext_len=0)
#
#   John's dataset
#train_data = dataset.WikiText2(root, context, dataset.DatasetSplit.train)
#valid_data = dataset.WikiText2(root, context, dataset.DatasetSplit.valid)
#test_data = dataset.WikiText2(root, context, dataset.DatasetSplit.test)

In [5]:
def evaluate(data):
    """Return the mean per-batch loss of the global ``model`` on ``data``.

    The function relies on notebook-level globals: ``model``, ``criterion``,
    ``batch_size``, ``shuffle``, ``device`` and ``train_data`` (for the
    vocabulary size used to flatten the model output).

    Args:
        data: a dataset accepted by ``torch.utils.data.DataLoader`` that yields
            ``(x, y)`` pairs.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: if ``data`` produces no batches.

    Side effects:
        Prints one empty line and leaves ``model`` in training mode.
    """
    loader = torch.utils.data.DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle)
    if len(loader) == 0:
        # Without this guard the final division fails with an opaque ZeroDivisionError.
        raise ValueError('evaluate() needs at least one batch, but the dataset is empty')

    vocab_size = train_data.word_count()
    total_loss = 0.

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            # Swap the two axes of each batch before the forward pass
            # (presumably (batch, seq) -> (seq, batch); confirm against datasetxl).
            x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
            yhat = model(x).contiguous().view(-1, vocab_size)
            # .item() keeps the running total a plain Python float, so callers
            # get a float back rather than a 0-d tensor living on `device`.
            total_loss += criterion(yhat, y.contiguous().view(-1)).item()

    print()
    model.train()
    return total_loss / len(loader)

In [6]:
# Instantiate the network from the hyper-parameters chosen above and move it to `device`.
model = transformerxl.Transformer(
    context,
    train_data.word_count(),
    d,
    k,
    m,
    heads,
    depth,
    tied_weights=tied_weights,
    dropout=dropout,
    dropoutio=dropoutio,
).to(device)

# Number of scalar entries across all parameters that require gradients.
count = sum(
    [np.prod(weight.shape) for weight in model.parameters() if weight.requires_grad]
)
print('Initialized graph with {} parameters'.format(count))

Initialized graph with 47653582 parameters


In [7]:
# Training driver.
#
# Every pass of the outer loop first validates the current weights and saves
# them to `save` when they beat the best validation loss so far, then trains
# for one epoch. The loop runs `epochs + 1` times: the extra, final pass only
# validates/checkpoints, so the weights produced by the last training epoch
# are also considered for the best checkpoint (previously they were never
# validated or saved).
criterion = nn.NLLLoss()
curr_lr = .0001        # warm-up learning rate, used for epoch 0 only
clip = .25             # max gradient norm passed to clip_grad_norm_
best_val_loss = None   # None until the first validation pass has run
save = 'model.pt'      # checkpoint path; read back by the restore cell

train_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=shuffle)
print('Initiating training, {} iterations/epoch.'.format(len(train_loader)))

try:
    optimizer = torch.optim.Adam(model.parameters(), lr=curr_lr)
    for epoch in range(epochs + 1):
        # ---- validate the weights as they stand before this epoch's training ----
        t0 = time.time()
        print('epoch', epoch)
        val_loss = evaluate(valid_data)
        print('-' * 100)
        print('| checkpoint | epoch {:3d} | time: {:5.2f}s | validation loss {:5.2f} | '
                'validation perplexity {:8.2f}'.format(epoch, (time.time() - t0),
                                                       val_loss, math.exp(val_loss)))
        print('-' * 100)

        # `is None` rather than truthiness: a validation loss of exactly 0.0
        # is a real value and must not be mistaken for "no best yet".
        if best_val_loss is None or val_loss < best_val_loss:
            with open(save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss

        if epoch == epochs:
            # Final pass: the last trained weights are now validated; nothing left to train.
            break

        print('epoch\t\tms/batch\tlr\tloss\tperplexity')

        # ---- train for one epoch ----
        model.train()
        #model.reset_memory()
        total_loss = 0.
        t0 = time.time()
        if epoch == 1: optimizer.param_groups[0]['lr'] = curr_lr = lr # finished warmup
        for i, (x,y) in enumerate(train_loader):
            # Progress line every `log_interval` batches; it reports the mean
            # loss of the batches processed since the previous line.
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - t0
                print('{:3d} ({:2.1f}%)\t{:5.2f}\t\t{:1.3}\t{:5.2f}\t{:8.2f}'.format(
                    epoch, 100*i/float(len(train_loader)),
                    elapsed * 1000 / log_interval, curr_lr, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                t0 = time.time()

            # Same axis swap and flattening as in evaluate().
            x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
            model.zero_grad()
            yhat = model(x).contiguous().view(-1, train_data.word_count())
            loss = criterion(yhat, y.contiguous().view(-1))
            loss.backward()

            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            total_loss += loss.item()

except KeyboardInterrupt:
    # Allow stopping a long run by hand; the best checkpoint saved so far stays on disk.
    print('Graceful Exit')

Initiating training, 435 iterations/epoch.
epoch 0

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   0 | time:  9.80s | validation loss 10.50 | validation perplexity 36172.16
----------------------------------------------------------------------------------------------------
epoch		ms/batch	lr	loss	perplexity
  0 (11.5%)	667.78		0.0001	 8.97	 7892.94
  0 (23.0%)	646.23		0.0001	 7.37	 1587.01
  0 (34.5%)	647.34		0.0001	 6.98	 1077.22
  0 (46.0%)	654.41		0.0001	 6.89	  986.55
  0 (57.5%)	648.76		0.0001	 6.82	  917.10
  0 (69.0%)	649.11		0.0001	 6.78	  878.47
  0 (80.5%)	650.76		0.0001	 6.72	  829.48
  0 (92.0%)	654.22		0.0001	 6.67	  785.39
epoch 1

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   1 | time: 10.29s | validation loss  6.52 | validation perplexity   676.04
-----------------------------------------------------------------------

In [8]:
# Reload the checkpoint written to `save` by the training cell and score it on
# the test split. torch.load unpickles a whole module object here, so only use
# it on checkpoint files produced by this notebook.
print('Restoring best checkpointed model...')
with open(save, 'rb') as checkpoint_file:
    model = torch.load(checkpoint_file)

test_loss = evaluate(test_data)
summary_line = '| end of training | test loss {:5.2f} | test perplexity {:8.2f}'.format(
    test_loss, math.exp(test_loss))
print('=' * 89)
print(summary_line)
print('=' * 89)

Restoring best checkpointed model...

| end of training | test loss  6.26 | test perplexity   525.54


In [9]:
print('\nUncurated samples')
print('-' * 89)

# KNOWN ISSUE: this sampling code doesn't work with transformerxl because of
# input shape (the forward pass raises a RuntimeError). The changes below do
# not address that; they only remove unrelated problems in the loop.
def sample():
    """Sample ``context`` tokens from the global ``model`` and return them as text.

    Starts from one random token id, then repeatedly feeds the growing history
    to the model, draws the next token from the exponentiated last output row
    (the model output is treated as log-probabilities; confirm against the
    model definition) and appends it to the history.

    Relies on the notebook globals ``model``, ``train_data``, ``context`` and
    ``device``. Leaves ``model`` in evaluation mode.

    Returns:
        str: the sampled words joined by spaces and wrapped at 80 columns.
    """
    words = []
    model.eval()
    # Draw the seed token on the CPU (as before, so the seeded random stream is
    # unchanged) and move it to the model's device instead of assuming CUDA.
    history = torch.randint(train_data.word_count(), (1, 1), dtype=torch.long).to(device)
    # Inference only: do not record autograd graphs for the forward passes.
    with torch.no_grad():
        for _ in range(context):
            output = model(history)
            word_weights = output[-1].squeeze().exp().cpu()
            # Plain int: safe for indexing idx2word and for building the next input.
            word_idx = torch.multinomial(word_weights, 1)[0].item()
            # Build the index tensor directly as int64 rather than via a float tensor.
            word_tensor = torch.tensor([[word_idx]], dtype=torch.long, device=device)
            history = torch.cat([history, word_tensor], 0)

            words.append(train_data.idx2word[word_idx])

    return '\n'.join(textwrap.wrap(' '.join(words),80))

for i in range(5):
    print('({})'.format(i), sample())


Uncurated samples
-----------------------------------------------------------------------------------------


RuntimeError: ignored