Revised and fixed code from https://github.com/JayParks/transformer (MT) for LM

Stripped the code from the JayParks repo for MT Transformer. Introduced a few updates and changes for speed, but it's still frustratingly slow. Possible improvement - speed it up.

Another issue - hyperparameter search for language modelling (number of heads, number of self-attention layers, etc.). The model does not work well out of the box. This might help: https://arxiv.org/pdf/1804.00247.pdf.

Also consider parallelizing.

# TODO
* Clean up
* Add MoS

# Random sequence length batching

This version of the Transformer LM uses random sequence length batching.

**NB** Make sure the source code does not assume the existence of PAD.

In [1]:
import os

# Select the GPU through environment variables: devices are ordered by PCI
# bus id and only the device with index "1" is made visible.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1",
)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from showprogress import showprogress

import torch
torch.cuda.device(0)
import torch.nn as nn
import torch.optim as optim

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils import clip_grad_norm_ as clip
from torch.optim.lr_scheduler import StepLR

import const
from data import *
from transformer import *
from utils import logging

# FUNCTIONS

In [None]:
def train(model, opt, criterion, ):
    """Placeholder for a training routine; currently does nothing.

    The arguments are accepted but ignored, and the function returns None.
    """
    return None

def evaluate(data_source, model, ntokens, seq_len):
    """Return the mean per-batch loss of ``model`` over ``data_source``.

    The data is walked in consecutive windows of ``seq_len`` rows obtained
    from ``get_batch``; each window is scored with the module-level
    ``criterion`` and the per-window losses are averaged (unweighted).

    Args:
        data_source: Batchified tensor; only ``size(0)`` is read here, the
            rest is delegated to ``get_batch``.
        model: Callable as ``model(data, lengths)`` returning a pair whose
            first element is passed to ``criterion``.
        ntokens: Unused; kept so existing callers do not break.
        seq_len: Window length (stride of the evaluation loop).

    Returns:
        Average of ``loss.item()`` over all evaluated windows.

    Raises:
        ValueError: If ``data_source`` yields no window at all.

    Note:
        The model is switched to eval mode and is NOT switched back; callers
        that continue training must call ``model.train()`` themselves.
    """
    model.eval()
    total_loss = 0
    batch = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, seq_len):
            data, targets = get_batch(data_source, i, seq_len=seq_len)
            # The final window may be shorter than `seq_len`, so take the
            # actual length from the data instead of overwriting the argument.
            cur_len = data.shape[1]
            lengths = torch.full((data.shape[0],), cur_len,
                                 device=device, dtype=torch.long)

            log_prob, _ = model(data, lengths)
            loss = criterion(log_prob, targets.view(-1))

            total_loss += loss.item()
            batch += 1

    if batch == 0:
        raise ValueError('data_source is too short to form a single batch')
    return total_loss / batch

# DATASET

In [None]:
# Batch sizes handed to `batchify` for the training, validation and test splits.
batch_size = 128
eval_batch_size = 10
test_batch_size = 1

In [None]:
data_path = 'data/penn'  # 'wikitext-2'

# Random length sequence batching: load the corpus once, then batchify each
# split with its own batch size (train, validation, test - in that order).
corpus = Corpus(data_path)
train_data, val_data, test_data = (
    batchify(split, split_batch_size)
    for split, split_batch_size in (
        (corpus.train, batch_size),
        (corpus.valid, eval_batch_size),
        (corpus.test, test_batch_size),
    )
)

# Sentence-wise batching (alternative, currently disabled)
# train_data = DataSet(data_path + '/train.txt', batch_size, display_freq=0, max_len=90, trunc_len=90)
# val_data = DataSet(data_path + '/valid.txt', batch_size, display_freq=0, max_len=90, trunc_len=90)
# test_data = DataSet(data_path + '/test.txt', batch_size, display_freq=0, max_len=90, trunc_len=90)
# train_data.build_dict()
# val_data.change_dict(train_data.dictionary)
# test_data.change_dict(train_data.dictionary)

# MODEL PARAMS

In [None]:
# Vocabulary size comes from the loaded corpus.
voc_size = len(corpus.dictionary)  # ptb_train.num_vocb

# Sizes passed to the LMTransformer constructor.
emb_dim, d_ff = 512, 1024
d_k, d_v = 64, 64
n_layers, n_heads = 2, 4
max_tgt_seq_len = 90

# Regularisation and model-variant switches.
dropout = 0.1
weighted_model = False
share_proj_weight = True

# MODEL

In [None]:
# Build the language model from the hyperparameters of the MODEL PARAMS cell.
model = LMTransformer(n_layers, d_k, d_v, emb_dim, d_ff,
                      n_heads, max_tgt_seq_len, voc_size,
                      dropout, weighted_model, share_proj_weight)
# Targets equal to const.PAD do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=const.PAD)

# Move everything to the GPU *before* creating the optimizer, so that the
# optimizer is guaranteed to hold the parameters that are actually trained.
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

# NOTE: `lr` is assigned in the TRAIN PARAMS cell further down, so that cell
# has to be executed before this one. The value only acts as a placeholder:
# the training loop overwrites the learning rate with its warmup schedule.
opt = optim.Adam(model.trainable_params(), betas=(0.9, 0.98), eps=1e-09, lr=lr)

#opt = optim.Adam(model.trainable_params(), lr=lr)
# lr_lambda = lambda epoch: 0.99 ** epoch
#lrsched = StepLR(opt, step_size=10, gamma=0.5)

# TRAIN PARAMS

In [None]:
# Sequence-length sampling: base window length and the largest allowed
# overshoot above the sampled base length.
bptt0, max_seq_len_delta = 70, 40

# Number of passes over the training data.
n_epochs = 1000

# Optimisation settings: initial learning rate handed to the optimizer,
# gradient-clipping threshold and number of warmup steps for the schedule.
lr, clip_grad, warmup_steps = 1e-6, 5, 2000

# LOGGING PARAMS

In [None]:
# Experiment directory name: "PTB" followed by the current local timestamp.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
exp_dir = "-".join(("PTB", run_stamp))

# Log every `log_interval` training batches; `log_file` names the log file.
log_interval, log_file = 100, 'log.txt'

# RUN

In [16]:
# `i` is the read offset into the training data (the training loop resets it
# at the start of every epoch); `best_val_loss` collects one validation loss
# per finished epoch.
i, best_val_loss = 0, []

In [None]:
# Number of optimizer updates performed so far, counted across ALL epochs.
# The warmup schedule must be driven by this counter: the per-epoch `batch`
# counter restarts at zero every epoch, which would restart the warmup too
# and keep the learning rate from ever leaving the warmup phase.
global_step = 0
# Constant factors of the schedule
#   lr = emb_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
# hoisted out of the training loop.
lr_scale = emb_dim ** -0.5
warmup_scale = warmup_steps ** -1.5

# NOTE(review): `clip_grad` is defined in the TRAIN PARAMS cell but no gradient
# clipping is applied anywhere in this loop - confirm whether that is intended.
try:
    for epoch in range(n_epochs):
        epoch_start_time = time.time()
        total_loss = 0
        print('Start epoch %d, learning rate %f '%(epoch + 1, opt.state_dict()['param_groups'][0]['lr']))
        start_time = time.time()
        model.train()  # evaluate() leaves the model in eval mode
        batch, i = 0, 0
        while i < train_data.size(0) - 2:
            # Mostly use the base length; occasionally (5%) use half of it.
            bptt = bptt0 if np.random.random() < 0.95 else bptt0 / 2.
            # Prevent excessively small or negative sequence lengths
            seq_len = max(5, int(np.random.normal(bptt, 5))) # loc 70, scale 5
            # There's a very small chance that it could select a very long sequence length resulting in OOM
            # (int() because the upper bound is a float when bptt was halved)
            seq_len = int(min(seq_len, bptt + max_seq_len_delta))

            data, targets = get_batch(train_data, i, seq_len=seq_len)
            # The last window of an epoch can be shorter than requested.
            seq_len = data.shape[1]
            lengths = torch.ones(data.shape[0], device=device, dtype=torch.long) * seq_len

            # Set the learning rate for THIS update before stepping, so the
            # very first update already follows the schedule.
            global_step += 1
            new_lr = lr_scale * min(global_step ** -0.5,
                                    global_step * warmup_scale)
            for param_group in opt.param_groups:
                param_group['lr'] = new_lr

            opt.zero_grad()
            # Call the module itself (not .forward) so registered hooks run,
            # matching how evaluate() invokes the model.
            output, self_attn = model(data, lengths)
            loss = criterion(output, targets.view(-1))

            loss.backward()
            opt.step()

            total_loss += loss.item()
            batch += 1
            i += seq_len

            if batch % log_interval == 0:
                # Mean loss over the last `log_interval` batches rather than
                # the loss of the single most recent batch.
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                logging('| epoch {:3d} | {}/{} batches | lr {:02.4f} | ms/batch {:5.2f} | '
                        'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt0, opt.param_groups[0]['lr'],
                    elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()

        val_loss = evaluate(val_data, model, voc_size, bptt0)
        logging('-' * 89)
        logging('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                           val_loss, math.exp(val_loss)))
        logging('-' * 89)

        # Despite its name this list keeps the validation loss of every epoch.
        best_val_loss.append(val_loss)

except KeyboardInterrupt:
    # Allow stopping a long run by hand without a traceback.
    logging('-' * 89)
    logging('Exiting from training early')