Revised and fixed code from https://github.com/JayParks/transformer (MT) for LM

Stripped the code from the JayParks repo for MT Transformer. Introduced a few updates and changes for speed, but it's still frustratingly slow. Possible improvement - speed it up.

Another issue - hyperparameter search for language modelling (number of heads, number of self-attention layers, etc). Does not work well from the box.

# TODO

* Speed up
* Tune hyperparams (now it's diverging)
* Add MoS

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

import torch
torch.cuda.device(2)
import torch.nn as nn
import torch.optim as optim

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils import clip_grad_norm_ as clip
from torch.optim.lr_scheduler import StepLR

import const
from data import *
from transformer import *

In [3]:
ptb_datapath_train = 'data/penn/train.txt'
ptb_datapath_valid = 'data/penn/valid.txt'
ptb_datapath_test = 'data/penn/test.txt'

batch_size = 128

ptb_train = DataSet(ptb_datapath_train, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_valid = DataSet(ptb_datapath_valid, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_test = DataSet(ptb_datapath_test, batch_size, display_freq=0, max_len=90, trunc_len=90)

Loading data from data/penn/train.txt ...
Loading data from data/penn/valid.txt ...
Loading data from data/penn/test.txt ...


In [4]:
ptb_train.build_dict()
ptb_valid.change_dict(ptb_train.dictionary)
ptb_test.change_dict(ptb_train.dictionary)

Building dictionary...
Done.
Save dictionary at data/penn/train.txt.dict
Index tokens ...
42068 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/train.txt
Number of sentence : 42068
Number of tokens : 887521
Vocabulary size : 10000
Number of batches : 328
Batch size : 128
Done.
Index tokens ...
3370 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/valid.txt
Number of sentence : 3370
Number of tokens : 70390
Vocabulary size : 10000
Number of batches : 26
Batch size : 128
Done.
Index tokens ...
3761 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/test.txt
Number of sentence : 3761
Number of tokens : 78669
Vocabulary size : 10000
Number of batches : 29
Batch size : 128
Done.


In [14]:
voc_size = ptb_train.num_vocb
emb_dim = 512
d_k = 64
d_v = 64
n_layers = 2
n_heads = 4
d_ff = 2048
max_tgt_seq_len = 90
dropout = 0.1
weighted_model = False
share_proj_weight = True
lr = 1e-3
n_epochs = 30
clip_grad = 5

In [15]:
model = LMTransformer(n_layers, d_k, d_v, emb_dim, d_ff,
                      n_heads, max_tgt_seq_len, voc_size,
                      dropout, weighted_model, share_proj_weight)
criterion = nn.CrossEntropyLoss(ignore_index=const.PAD)

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

opt = optim.Adam(model.trainable_params(), lr=lr)
# lr_lambda = lambda epoch: 0.99 ** epoch
lrsched = StepLR(opt, step_size=10, gamma=0.5)

Sharing target embedding and projection..


In [None]:
for epoch in range(n_epochs):
    lrsched.step()
    acc_loss = 0
    print('Start epoch %d, learning rate %f '%(epoch + 1, opt.state_dict()['param_groups'][0]['lr']))
    start_time = time.time()
    model.train()
    ptb_train.shuffle()
    for batch_idx in tqdm(range(ptb_train.num_batch), unit='batches'):
        data, lengths, target = ptb_train.get_batch(batch_idx)
        
        opt.zero_grad()
        output, self_attn = model.forward(data, lengths)
        loss = criterion(output, target.view(-1))
        loss.backward()
        opt.step()
        acc_loss += loss.item()

    avg_loss = acc_loss / ptb_train.num_batch
    print('Epoch : %d, Batch : %d / %d, Loss : %f, Perplexity : %f, Time : %f' 
          % (epoch + 1, batch_idx, ptb_train.num_batch,
             avg_loss, math.exp(avg_loss),
             time.time() - start_time))

    acc_loss = 0
    model.eval()
    for batch_idx in tqdm(range(ptb_test.num_batch), unit='batches'):
        data, lengths, target = ptb_test[batch_idx]
        output, self_attn = model.forward(data, lengths)
        loss = criterion(output, target.view(-1))
        acc_loss += loss.item()

    val_loss = acc_loss / ptb_test.num_batch
    print('Validation Loss : %f' % val_loss)
    print('Validation Perplexity : %f' % math.exp(val_loss))

  0%|          | 0/328 [00:00<?, ?batches/s]

Start epoch 1, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.75batches/s]

Epoch : 1, Batch : 327 / 328, Loss : 5.702489, Perplexity : 299.612289, Time : 121.237305


100%|██████████| 29/29 [00:03<00:00,  7.73batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.133783
Validation Perplexity : 169.657661
Start epoch 2, learning rate 0.001000 
2


100%|██████████| 328/328 [02:00<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.92batches/s]

Epoch : 2, Batch : 327 / 328, Loss : 4.874813, Perplexity : 130.949660, Time : 120.998704


100%|██████████| 29/29 [00:03<00:00,  7.52batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.907077
Validation Perplexity : 135.243477
Start epoch 3, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.89batches/s]

Epoch : 3, Batch : 327 / 328, Loss : 4.534898, Perplexity : 93.213973, Time : 121.122315


100%|██████████| 29/29 [00:03<00:00,  7.61batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.829746
Validation Perplexity : 125.179180
Start epoch 4, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.69batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.85batches/s]

Epoch : 4, Batch : 327 / 328, Loss : 4.289405, Perplexity : 72.923101, Time : 121.794339


100%|██████████| 29/29 [00:03<00:00,  7.51batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.818284
Validation Perplexity : 123.752517
Start epoch 5, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.62batches/s]

Epoch : 5, Batch : 327 / 328, Loss : 4.089971, Perplexity : 59.738153, Time : 121.294400


100%|██████████| 29/29 [00:03<00:00,  7.72batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.846339
Validation Perplexity : 127.273587
Start epoch 6, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.83batches/s]

Epoch : 6, Batch : 327 / 328, Loss : 3.920071, Perplexity : 50.404037, Time : 121.199066


100%|██████████| 29/29 [00:03<00:00,  7.72batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.865344
Validation Perplexity : 129.715576
Start epoch 7, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.70batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.28batches/s]

Epoch : 7, Batch : 327 / 328, Loss : 3.767682, Perplexity : 43.279606, Time : 121.382798


100%|██████████| 29/29 [00:03<00:00,  7.76batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.924917
Validation Perplexity : 137.677976
Start epoch 8, learning rate 0.001000 
2


100%|██████████| 328/328 [02:00<00:00,  2.72batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.21batches/s]

Epoch : 8, Batch : 327 / 328, Loss : 3.631104, Perplexity : 37.754487, Time : 120.860556


100%|██████████| 29/29 [00:03<00:00,  7.73batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.981597
Validation Perplexity : 145.706839
Start epoch 9, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.70batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.16batches/s]

Epoch : 9, Batch : 327 / 328, Loss : 3.504425, Perplexity : 33.262320, Time : 121.500200


100%|██████████| 29/29 [00:03<00:00,  7.79batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.043568
Validation Perplexity : 155.022185
Start epoch 10, learning rate 0.001000 
2


100%|██████████| 328/328 [02:01<00:00,  2.70batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.19batches/s]

Epoch : 10, Batch : 327 / 328, Loss : 3.388829, Perplexity : 29.631234, Time : 121.482503


100%|██████████| 29/29 [00:03<00:00,  7.84batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.107677
Validation Perplexity : 165.286014
Start epoch 11, learning rate 0.000500 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  7.90batches/s]

Epoch : 11, Batch : 327 / 328, Loss : 3.041171, Perplexity : 20.929747, Time : 121.131451


100%|██████████| 29/29 [00:03<00:00,  7.67batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.212470
Validation Perplexity : 183.546788
Start epoch 12, learning rate 0.000500 
2


100%|██████████| 328/328 [02:00<00:00,  2.73batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.37batches/s]

Epoch : 12, Batch : 327 / 328, Loss : 2.912203, Perplexity : 18.397279, Time : 120.404881


100%|██████████| 29/29 [00:03<00:00,  7.83batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.307598
Validation Perplexity : 201.864733
Start epoch 13, learning rate 0.000500 
2


100%|██████████| 328/328 [02:01<00:00,  2.71batches/s]
  3%|▎         | 1/29 [00:00<00:03,  8.02batches/s]

Epoch : 13, Batch : 327 / 328, Loss : 2.829475, Perplexity : 16.936575, Time : 121.257781


100%|██████████| 29/29 [00:03<00:00,  7.47batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.365556
Validation Perplexity : 213.910039
Start epoch 14, learning rate 0.000500 
2


 52%|█████▏    | 170/328 [01:02<00:58,  2.72batches/s]