Revised and fixed code from https://github.com/JayParks/transformer (MT) for LM

Stripped down the machine-translation (MT) Transformer code from the JayParks repo so it can be used for language modelling (LM). Introduced a few updates and changes for speed, but it is still frustratingly slow, so speeding it up further is the main possible improvement.

Another issue is the hyperparameter search for language modelling (number of heads, number of self-attention layers, etc.): the model does not work well out of the box. This paper might be of help: https://arxiv.org/pdf/1804.00247.pdf.

Also consider parallelizing.

# TODO

* Speed up
* Tune hyperparams (now it's diverging)
* Add MoS

In [None]:
import os

# Ask CUDA to enumerate devices by PCI bus ID. This must be set before CUDA is
# initialised, which is why it lives in the very first cell.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID"})
# Uncomment to restrict the process to a single physical GPU:
# os.environ["CUDA_VISIBLE_DEVICES"]="4"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

import torch
# NOTE: a bare `torch.cuda.device(0)` only constructs a context manager and
# throws it away, so it selects nothing. Select the device explicitly, and
# only when CUDA is present so the notebook still runs on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
import torch.nn as nn
import torch.optim as optim

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils import clip_grad_norm_ as clip
from torch.optim.lr_scheduler import StepLR

import const
from data import *
from transformer import *

In [None]:
# Locations of the Penn Treebank splits.
ptb_datapath_train = 'data/penn/train.txt'
ptb_datapath_valid = 'data/penn/valid.txt'
ptb_datapath_test = 'data/penn/test.txt'

batch_size = 128

# All three splits are loaded with identical settings; only the path differs.
ptb_train, ptb_valid, ptb_test = (
    DataSet(split_path, batch_size, display_freq=0, max_len=90, trunc_len=90)
    for split_path in (ptb_datapath_train, ptb_datapath_valid, ptb_datapath_test)
)

In [None]:
# Build the dictionary from the training split only, then hand that same
# dictionary to the held-out splits.
ptb_train.build_dict()
for held_out in (ptb_valid, ptb_test):
    held_out.change_dict(ptb_train.dictionary)

In [None]:
############ Optional: get data by tokens ###############
# Alternative data pipeline built on Corpus / batchify / get_batch.
# NOTE: this rebinds `batch_size` (128 in the DataSet cell above) to 64.
corpus = Corpus('data/penn')

eval_batch_size, test_batch_size = 10, 1
batch_size = 64

train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, test_batch_size)

#### how to take a batch ####
# The data is already split by batch_size at this point; what remains is to
# choose the sequence length of each slice.
batch_num = 2
batch = get_batch(train_data, batch_num, seq_len=35)


#### TODO (if needed) ###
# 1) repackage hiddens for learning by tokens
# 2) learn not every step (depends on 1st point)
# 3) add grad clipping

In [None]:
# --- Vocabulary ---
voc_size = ptb_train.num_vocb

# --- Model settings (passed positionally to LMTransformer below) ---
emb_dim = 1024
d_k = d_v = 64
n_layers, n_heads = 2, 4
d_ff = 2048
max_tgt_seq_len = 90      # same value as max_len / trunc_len used for the DataSets
dropout = 0.1
weighted_model = False    # semantics defined by LMTransformer — TODO document
share_proj_weight = True  # semantics defined by LMTransformer — TODO document

# --- Optimisation ---
lr = 1e-6                 # initial Adam lr; the training loop overwrites it every step
n_epochs = 30
clip_grad = 5             # intended gradient-norm clipping threshold
warmup_steps = 3000       # warmup length used by the lr schedule in the training loop

In [None]:
# Instantiate the language model from the hyperparameters defined above.
model = LMTransformer(
    n_layers, d_k, d_v, emb_dim, d_ff,
    n_heads, max_tgt_seq_len, voc_size,
    dropout, weighted_model, share_proj_weight,
)
# Targets equal to const.PAD are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=const.PAD)

# Move both the model and the loss to the GPU when one is available.
if torch.cuda.is_available():
    model, criterion = model.cuda(), criterion.cuda()

# Earlier optimiser / scheduler experiments, kept for reference:
#opt = optim.Adam(model.trainable_params(), lr=lr)
# lr_lambda = lambda epoch: 0.99 ** epoch
#lrsched = StepLR(opt, step_size=10, gamma=0.5)

In [None]:
# Training / evaluation driver.
#
# Each epoch: shuffle the training set, run one optimisation step per batch
# (grouped batches from `ptb_train`), report the mean training loss and
# perplexity, then evaluate on the validation split.
#
# The learning rate follows the warmup schedule from "Attention Is All You Need":
#     lr(step) = emb_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
# so the `lr` passed to Adam below is only a placeholder.
opt = optim.Adam(model.trainable_params(), betas=(0.9, 0.98), eps=1e-09, lr=lr)
i = 0  # global optimisation-step counter; persists across epochs for the schedule
for epoch in range(n_epochs):
    acc_loss = 0
    n_train_batches = 0  # batches actually processed this epoch
    print('Start epoch %d, learning rate %f '%(epoch + 1, opt.state_dict()['param_groups'][0]['lr']))
    start_time = time.time()
    model.train()
    ptb_train.shuffle_epoch()
    for group_id in range(len(ptb_train.groups)):
        for batch_idx in tqdm(range(ptb_train.num_batches[group_id]), unit='batches'):
            data, lengths, target = ptb_train.get_batch_new(batch_idx, group_id)

            # Set the scheduled learning rate *before* stepping, so that step
            # `i` is taken with lr(i) rather than with the previous step's value.
            i += 1
            new_lr = float(np.power(emb_dim, -0.5) * np.min([
                np.power(i, -0.5),
                np.power(warmup_steps, -1.5) * i]))
            for param_group in opt.param_groups:
                param_group['lr'] = new_lr

            opt.zero_grad()
            output, _ = model.forward(data, lengths)  # attention weights are not needed here
            loss = criterion(output, target.view(-1))
            loss.backward()
            # Bound the gradient norm to `clip_grad` to counter divergence.
            clip(model.trainable_params(), clip_grad)
            opt.step()

            acc_loss += loss.item()
            n_train_batches += 1

    # Average over the batches that were really run; guard against an empty epoch.
    avg_loss = acc_loss / max(n_train_batches, 1)
    print('Epoch : %d, Batch : %d / %d, Loss : %f, Perplexity : %f, Time : %f' 
          % (epoch + 1, n_train_batches, ptb_train.num_batch,
             avg_loss, math.exp(avg_loss),
             time.time() - start_time))

    # ---- validation ----
    # Evaluate on the validation split (not the test split), so the numbers
    # printed as "Validation" can be used for tuning without touching test data.
    acc_loss = 0
    model.eval()
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for batch_idx in tqdm(range(ptb_valid.num_batch), unit='batches'):
            data, lengths, target = ptb_valid[batch_idx]
            output, _ = model.forward(data, lengths)
            loss = criterion(output, target.view(-1))
            acc_loss += loss.item()

    val_loss = acc_loss / max(ptb_valid.num_batch, 1)
    print('Validation Loss : %f' % val_loss)
    print('Validation Perplexity : %f' % math.exp(val_loss))