Revised and fixed code from https://github.com/JayParks/transformer (MT) for LM

Stripped the code from the JayParks repo for MT Transformer. Introduced a few updates and changes for speed, but it's still frustratingly slow. Possible improvement - speed it up.

Another issue - hyperparameter search for language modelling (number of heads, number of self-attention layers, etc). Does not work well from the box. This might be of help https://arxiv.org/pdf/1804.00247.pdf.

Also consider parallelizing.

# TODO

* Speed up
* Tune hyperparams (now it's diverging)
* Add MoS

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="4"

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

import torch
torch.cuda.device(0)
import torch.nn as nn
import torch.optim as optim

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils import clip_grad_norm_ as clip
from torch.optim.lr_scheduler import StepLR

import const
from data import *
from transformer import *

In [3]:
ptb_datapath_train = 'data/penn/train.txt'
ptb_datapath_valid = 'data/penn/valid.txt'
ptb_datapath_test = 'data/penn/test.txt'

batch_size = 128

ptb_train = DataSet(ptb_datapath_train, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_valid = DataSet(ptb_datapath_valid, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_test = DataSet(ptb_datapath_test, batch_size, display_freq=0, max_len=90, trunc_len=90)

Loading data from data/penn/train.txt ...
Loading data from data/penn/valid.txt ...
Loading data from data/penn/test.txt ...


In [4]:
ptb_train.build_dict()
ptb_valid.change_dict(ptb_train.dictionary)
ptb_test.change_dict(ptb_train.dictionary)

Building dictionary...
Done.
Save dictionary at data/penn/train.txt.dict
Index tokens ...
42068 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/train.txt
Number of sentence : 42068
Number of tokens : 887521
Vocabulary size : 10000
Number of batches : 328
Batch size : 128
Done.
Index tokens ...
3370 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/valid.txt
Number of sentence : 3370
Number of tokens : 70390
Vocabulary size : 10000
Number of batches : 26
Batch size : 128
Done.
Index tokens ...
3761 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/test.txt
Number of sentence : 3761
Number of tokens : 78669
Vocabulary size : 10000
Number of batches : 29
Batch size : 128
Done.


In [5]:
############ Optional: get data by tokens ###############
corpus = Corpus('data/penn')
eval_batch_size = 10
test_batch_size = 1
batch_size = 128
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, test_batch_size)

#### how to take a batch ####
# the data is already splitten into batch_size(now we need to decide about seq length)
batch_num = 2
batch = get_batch(train_data, batch_num, seq_len=35)


#### TODO (if needed) ###
# 1) repackage hiddens for learning by tokens
# 2) learn not every step (depends on 1st point)
# 3) add grad clipping

torch.Size([7262, 128])
torch.Size([7376, 10])
torch.Size([82430, 1])


In [6]:
voc_size = ptb_train.num_vocb
emb_dim = 512
d_k = 64
d_v = 64
n_layers = 2
n_heads = 4
d_ff = 2048
max_tgt_seq_len = 90
dropout = 0.1
weighted_model = False
share_proj_weight = True
lr = 1e-6
n_epochs = 10
clip_grad = 5
warmup_steps = 2000

In [7]:
model = LMTransformer(n_layers, d_k, d_v, emb_dim, d_ff,
                      n_heads, max_tgt_seq_len, voc_size,
                      dropout, weighted_model, share_proj_weight)
criterion = nn.CrossEntropyLoss(ignore_index=const.PAD)

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

#opt = optim.Adam(model.trainable_params(), lr=lr)
# lr_lambda = lambda epoch: 0.99 ** epoch
#lrsched = StepLR(opt, step_size=10, gamma=0.5)

Sharing target embedding and projection..


In [8]:
torch.cuda.is_available()

True

In [9]:
opt = optim.Adam(model.trainable_params(),betas=(0.9, 0.98), eps=1e-09, lr=lr)
i=0
for epoch in range(n_epochs):
    #lrsched.step()
    acc_loss = 0
    print('Start epoch %d, learning rate %f '%(epoch + 1, opt.state_dict()['param_groups'][0]['lr']))
    start_time = time.time()
    model.train()
    ptb_train.shuffle()
    for batch_idx in tqdm(range(ptb_train.num_batch), unit='batches'):
        data, lengths, target = ptb_train.get_batch(batch_idx)
        
        opt.zero_grad()
        output, self_attn = model.forward(data, lengths)
        loss = criterion(output, target.view(-1))
        loss.backward()
        opt.step()
        acc_loss += loss.item()
        i+=1
        new_lr = np.power(emb_dim, -0.5) * np.min([
            np.power((i), -0.5),
            np.power(warmup_steps, -1.5) * (i)])
        for param_group in opt.param_groups:
            param_group['lr'] = new_lr
        
    avg_loss = acc_loss / ptb_train.num_batch
    print('Epoch : %d, Batch : %d / %d, Loss : %f, Perplexity : %f, Time : %f' 
          % (epoch + 1, batch_idx, ptb_train.num_batch,
             avg_loss, math.exp(avg_loss),
             time.time() - start_time))

    acc_loss = 0
    model.eval()
    for batch_idx in tqdm(range(ptb_test.num_batch), unit='batches'):
        data, lengths, target = ptb_test[batch_idx]
        output, self_attn = model.forward(data, lengths)
        loss = criterion(output, target.view(-1))
        acc_loss += loss.item()

    val_loss = acc_loss / ptb_test.num_batch
    print('Validation Loss : %f' % val_loss)
    print('Validation Perplexity : %f' % math.exp(val_loss))

  0%|          | 0/328 [00:00<?, ?batches/s]

Start epoch 1, learning rate 0.000001 
2


100%|██████████| 328/328 [00:40<00:00,  8.12batches/s]
  7%|▋         | 2/29 [00:00<00:01, 16.72batches/s]

Epoch : 1, Batch : 327 / 328, Loss : 7.083226, Perplexity : 1191.807269, Time : 40.480996


100%|██████████| 29/29 [00:01<00:00, 18.50batches/s]
  0%|          | 1/328 [00:00<00:37,  8.80batches/s]

Validation Loss : 5.930350
Validation Perplexity : 376.286262
Start epoch 2, learning rate 0.000162 
2


100%|██████████| 328/328 [00:40<00:00,  8.16batches/s]
  7%|▋         | 2/29 [00:00<00:01, 17.07batches/s]

Epoch : 2, Batch : 327 / 328, Loss : 5.664921, Perplexity : 288.565326, Time : 40.266509


100%|██████████| 29/29 [00:01<00:00, 17.71batches/s]
  0%|          | 1/328 [00:00<00:39,  8.18batches/s]

Validation Loss : 5.358341
Validation Perplexity : 212.372231
Start epoch 3, learning rate 0.000324 
2


100%|██████████| 328/328 [00:40<00:00,  8.19batches/s]
  7%|▋         | 2/29 [00:00<00:01, 16.49batches/s]

Epoch : 3, Batch : 327 / 328, Loss : 5.199912, Perplexity : 181.256210, Time : 40.095078


100%|██████████| 29/29 [00:01<00:00, 19.23batches/s]
  0%|          | 1/328 [00:00<00:38,  8.52batches/s]

Validation Loss : 5.111100
Validation Perplexity : 165.852742
Start epoch 4, learning rate 0.000486 
2


100%|██████████| 328/328 [00:39<00:00,  8.24batches/s]
  7%|▋         | 2/29 [00:00<00:01, 15.20batches/s]

Epoch : 4, Batch : 327 / 328, Loss : 4.897650, Perplexity : 133.974540, Time : 39.823247


100%|██████████| 29/29 [00:01<00:00, 18.52batches/s]
  0%|          | 1/328 [00:00<00:36,  8.90batches/s]

Validation Loss : 4.953933
Validation Perplexity : 141.731328
Start epoch 5, learning rate 0.000648 
2


100%|██████████| 328/328 [00:39<00:00,  8.24batches/s]
  7%|▋         | 2/29 [00:00<00:01, 17.19batches/s]

Epoch : 5, Batch : 327 / 328, Loss : 4.667583, Perplexity : 106.440218, Time : 39.824336


100%|██████████| 29/29 [00:01<00:00, 20.39batches/s]
  0%|          | 1/328 [00:00<00:42,  7.72batches/s]

Validation Loss : 4.876707
Validation Perplexity : 131.197856
Start epoch 6, learning rate 0.000810 
2


100%|██████████| 328/328 [00:40<00:00,  8.13batches/s]
  7%|▋         | 2/29 [00:00<00:01, 17.11batches/s]

Epoch : 6, Batch : 327 / 328, Loss : 4.482169, Perplexity : 88.426283, Time : 40.380551


100%|██████████| 29/29 [00:01<00:00, 20.25batches/s]
  0%|          | 1/328 [00:00<00:34,  9.56batches/s]

Validation Loss : 4.867549
Validation Perplexity : 130.001846
Start epoch 7, learning rate 0.000972 
2


100%|██████████| 328/328 [00:41<00:00,  7.98batches/s]
  7%|▋         | 2/29 [00:00<00:01, 17.03batches/s]

Epoch : 7, Batch : 327 / 328, Loss : 4.308748, Perplexity : 74.347347, Time : 41.156543


100%|██████████| 29/29 [00:01<00:00, 17.82batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.832151
Validation Perplexity : 125.480522
Start epoch 8, learning rate 0.000922 
2


100%|██████████| 328/328 [00:40<00:00,  8.15batches/s]
  7%|▋         | 2/29 [00:00<00:01, 16.72batches/s]

Epoch : 8, Batch : 327 / 328, Loss : 4.106600, Perplexity : 60.739875, Time : 40.293201


100%|██████████| 29/29 [00:01<00:00, 20.27batches/s]
  1%|          | 2/328 [00:00<00:32, 10.12batches/s]

Validation Loss : 4.827914
Validation Perplexity : 124.950029
Start epoch 9, learning rate 0.000863 
2


100%|██████████| 328/328 [00:40<00:00,  8.11batches/s]
  7%|▋         | 2/29 [00:00<00:01, 17.29batches/s]

Epoch : 9, Batch : 327 / 328, Loss : 3.927658, Perplexity : 50.787889, Time : 40.461197


100%|██████████| 29/29 [00:01<00:00, 19.58batches/s]
  0%|          | 1/328 [00:00<00:34,  9.40batches/s]

Validation Loss : 4.859546
Validation Perplexity : 128.965611
Start epoch 10, learning rate 0.000813 
2


100%|██████████| 328/328 [00:40<00:00,  8.11batches/s]
  7%|▋         | 2/29 [00:00<00:01, 18.19batches/s]

Epoch : 10, Batch : 327 / 328, Loss : 3.767748, Perplexity : 43.282469, Time : 40.479215


100%|██████████| 29/29 [00:01<00:00, 20.35batches/s]

Validation Loss : 4.905952
Validation Perplexity : 135.091419



