# Simple LSTM/GRU language modelling

Revised and fixed code from https://github.com/SudaMonster/lstm-language-model.

NB we pack examples sentence wise unlike many other models that do not account for periods.

Possible improvements: pack similar length sentences into batches to reduce the number of padding tokens.

Also note that we tie the input and output weights (https://arxiv.org/abs/1608.05859)

Essentially this is a playground to familiarize oneself with language modelling.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from torch.nn.utils import clip_grad_norm_ as clip
from torch.optim.lr_scheduler import StepLR

import const
from data import *
from rnn import *

In [2]:
ptb_datapath_train = 'data/penn/train.txt'
ptb_datapath_valid = 'data/penn/valid.txt'
ptb_datapath_test = 'data/penn/test.txt'

batch_size = 128

ptb_train = DataSet(ptb_datapath_train, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_valid = DataSet(ptb_datapath_valid, batch_size, display_freq=0, max_len=90, trunc_len=90)
ptb_test = DataSet(ptb_datapath_test, batch_size, display_freq=0, max_len=90, trunc_len=90)

Loading data from data/penn/train.txt ...
Loading data from data/penn/valid.txt ...
Loading data from data/penn/test.txt ...


In [3]:
ptb_train.build_dict()
ptb_valid.change_dict(ptb_train.dictionary)
ptb_test.change_dict(ptb_train.dictionary)

Building dictionary...
Done.
Save dictionary at data/penn/train.txt.dict
Index tokens ...
42068 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/train.txt
Number of sentence : 42068
Number of tokens : 887521
Vocabulary size : 10000
Number of batches : 328
Batch size : 128
Done.
Index tokens ...
3370 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/valid.txt
Number of sentence : 3370
Number of tokens : 70390
Vocabulary size : 10000
Number of batches : 26
Batch size : 128
Done.
Index tokens ...
3761 sentences were processed, 0 longer than maximum length,0 were ignored because zero length
Data discription:
Data name : data/penn/test.txt
Number of sentence : 3761
Number of tokens : 78669
Vocabulary size : 10000
Number of batches : 29
Batch size : 128
Done.


In [4]:
# Plot sentence length histogram
# ss = []
# for s in ptb_train.sentence:
#     ss.append(len(s))
# ss = np.array(ss)

# hist = np.histogram(ss)
# plt.hist(ss, bins=100)

In [5]:
rnn_type = 'LSTM'
voc_size = ptb_train.num_vocb
emb_dim = 512
hid_dim = 512
n_layers = 3

lr = 1e-3
n_epochs = 30
clip_grad = 5

In [6]:
model = rnn_model(rnn_type, voc_size, emb_dim, hid_dim, n_layers, tie_weights=True)
criterion = nn.CrossEntropyLoss(ignore_index=const.PAD)

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

opt = optim.Adam(model.parameters(), lr=lr)
# lr_lambda = lambda epoch: 0.99 ** epoch
lrsched = StepLR(opt, step_size=10, gamma=0.5)

In [7]:
for epoch in range(n_epochs):
    lrsched.step()
    acc_loss = 0
    print('Start epoch %d, learning rate %f '%(epoch + 1, opt.state_dict()['param_groups'][0]['lr']))
    start_time = time.time()
    model.train()
    ptb_train.shuffle()
    for batch_idx in tqdm(range(ptb_train.num_batch), unit='batches'):
        data, lengths, target = ptb_train.get_batch(batch_idx)
        
        opt.zero_grad()
        output_flat = model.forward(data, lengths)
        loss = criterion(output_flat, target.view(-1))
        loss.backward()
        clip(model.parameters(), clip_grad)
        opt.step()
        acc_loss += loss.item()

    avg_loss = acc_loss / ptb_train.num_batch
    print('Epoch : %d, Batch : %d / %d, Loss : %f, Perplexity : %f, Time : %f' 
          % (epoch + 1, batch_idx, ptb_train.num_batch,
             avg_loss, math.exp(avg_loss),
             time.time() - start_time))

    acc_loss = 0
    model.eval()
    for batch_idx in tqdm(range(ptb_test.num_batch), unit='batches'):
        data, lengths, target = ptb_test[batch_idx]
        output = model.forward(data, lengths)
        loss = criterion(output, target.view(-1))
        acc_loss += loss.item()

    val_loss = acc_loss / ptb_test.num_batch
    print('Validation Loss : %f' % val_loss)
    print('Validation Perplexity : %f' % math.exp(val_loss))

  0%|          | 1/328 [00:00<00:45,  7.20batches/s]

Start epoch 1, learning rate 0.001000 
2


100%|██████████| 328/328 [00:30<00:00, 10.68batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.04batches/s]

Epoch : 1, Batch : 327 / 328, Loss : 6.395267, Perplexity : 599.003170, Time : 30.777261


100%|██████████| 29/29 [00:00<00:00, 30.68batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.789270
Validation Perplexity : 326.774518
Start epoch 2, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.55batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.70batches/s]

Epoch : 2, Batch : 327 / 328, Loss : 5.642224, Perplexity : 282.089434, Time : 31.139646


100%|██████████| 29/29 [00:00<00:00, 31.86batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 5.365528
Validation Perplexity : 213.904081
Start epoch 3, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.52batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.71batches/s]

Epoch : 3, Batch : 327 / 328, Loss : 5.306388, Perplexity : 201.620592, Time : 31.251887


100%|██████████| 29/29 [00:00<00:00, 32.21batches/s]
  1%|          | 2/328 [00:00<00:29, 10.93batches/s]

Validation Loss : 5.144202
Validation Perplexity : 171.434651
Start epoch 4, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.38batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 30.66batches/s]

Epoch : 4, Batch : 327 / 328, Loss : 5.092792, Perplexity : 162.843856, Time : 31.640079


100%|██████████| 29/29 [00:00<00:00, 30.99batches/s]
  1%|          | 2/328 [00:00<00:29, 10.96batches/s]

Validation Loss : 5.006068
Validation Perplexity : 149.316473
Start epoch 5, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.53batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.11batches/s]

Epoch : 5, Batch : 327 / 328, Loss : 4.932322, Perplexity : 138.701269, Time : 31.199519


100%|██████████| 29/29 [00:00<00:00, 31.46batches/s]
  0%|          | 1/328 [00:00<00:32,  9.92batches/s]

Validation Loss : 4.908475
Validation Perplexity : 135.432778
Start epoch 6, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.52batches/s]
 10%|█         | 3/29 [00:00<00:00, 26.47batches/s]

Epoch : 6, Batch : 327 / 328, Loss : 4.803298, Perplexity : 121.911761, Time : 31.213892


100%|██████████| 29/29 [00:00<00:00, 31.66batches/s]
  1%|          | 2/328 [00:00<00:28, 11.58batches/s]

Validation Loss : 4.833810
Validation Perplexity : 125.688974
Start epoch 7, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.52batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.48batches/s]

Epoch : 7, Batch : 327 / 328, Loss : 4.695808, Perplexity : 109.487201, Time : 31.232909


100%|██████████| 29/29 [00:00<00:00, 30.33batches/s]
  0%|          | 1/328 [00:00<00:36,  8.91batches/s]

Validation Loss : 4.783526
Validation Perplexity : 119.525005
Start epoch 8, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.46batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.47batches/s]

Epoch : 8, Batch : 327 / 328, Loss : 4.598695, Perplexity : 99.354541, Time : 31.406680


100%|██████████| 29/29 [00:00<00:00, 29.92batches/s]
  1%|          | 2/328 [00:00<00:28, 11.53batches/s]

Validation Loss : 4.737699
Validation Perplexity : 114.171167
Start epoch 9, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.50batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.66batches/s]

Epoch : 9, Batch : 327 / 328, Loss : 4.514423, Perplexity : 91.324857, Time : 31.300902


100%|██████████| 29/29 [00:00<00:00, 30.89batches/s]
  1%|          | 2/328 [00:00<00:26, 12.33batches/s]

Validation Loss : 4.710345
Validation Perplexity : 111.090494
Start epoch 10, learning rate 0.001000 
2


100%|██████████| 328/328 [00:31<00:00, 10.54batches/s]
 10%|█         | 3/29 [00:00<00:00, 26.58batches/s]

Epoch : 10, Batch : 327 / 328, Loss : 4.438671, Perplexity : 84.662379, Time : 31.170156


100%|██████████| 29/29 [00:00<00:00, 30.81batches/s]
  1%|          | 2/328 [00:00<00:29, 11.20batches/s]

Validation Loss : 4.682787
Validation Perplexity : 108.070852
Start epoch 11, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.53batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.85batches/s]

Epoch : 11, Batch : 327 / 328, Loss : 4.327469, Perplexity : 75.752315, Time : 31.194883


100%|██████████| 29/29 [00:00<00:00, 32.17batches/s]
  1%|          | 2/328 [00:00<00:31, 10.31batches/s]

Validation Loss : 4.660024
Validation Perplexity : 105.638585
Start epoch 12, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.44batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.76batches/s]

Epoch : 12, Batch : 327 / 328, Loss : 4.280667, Perplexity : 72.288642, Time : 31.461761


100%|██████████| 29/29 [00:00<00:00, 31.20batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.653208
Validation Perplexity : 104.921055
Start epoch 13, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.56batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.05batches/s]

Epoch : 13, Batch : 327 / 328, Loss : 4.243794, Perplexity : 69.671656, Time : 31.156370


100%|██████████| 29/29 [00:00<00:00, 30.48batches/s]
  0%|          | 1/328 [00:00<00:34,  9.52batches/s]

Validation Loss : 4.644167
Validation Perplexity : 103.976703
Start epoch 14, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.57batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.21batches/s]

Epoch : 14, Batch : 327 / 328, Loss : 4.209471, Perplexity : 67.320885, Time : 31.079974


100%|██████████| 29/29 [00:00<00:00, 31.37batches/s]
  0%|          | 1/328 [00:00<00:36,  9.01batches/s]

Validation Loss : 4.641936
Validation Perplexity : 103.744957
Start epoch 15, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.57batches/s]
 10%|█         | 3/29 [00:00<00:01, 25.35batches/s]

Epoch : 15, Batch : 327 / 328, Loss : 4.175868, Perplexity : 65.096320, Time : 31.090578


100%|██████████| 29/29 [00:00<00:00, 31.31batches/s]
  1%|          | 2/328 [00:00<00:28, 11.43batches/s]

Validation Loss : 4.636537
Validation Perplexity : 103.186351
Start epoch 16, learning rate 0.000500 
2


100%|██████████| 328/328 [00:30<00:00, 10.59batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.51batches/s]

Epoch : 16, Batch : 327 / 328, Loss : 4.146063, Perplexity : 63.184722, Time : 31.006106


100%|██████████| 29/29 [00:00<00:00, 31.04batches/s]
  1%|          | 2/328 [00:00<00:27, 11.83batches/s]

Validation Loss : 4.634082
Validation Perplexity : 102.933414
Start epoch 17, learning rate 0.000500 
2


100%|██████████| 328/328 [00:30<00:00, 10.61batches/s]
 10%|█         | 3/29 [00:00<00:00, 28.20batches/s]

Epoch : 17, Batch : 327 / 328, Loss : 4.115440, Perplexity : 61.279180, Time : 30.950186


100%|██████████| 29/29 [00:00<00:00, 30.47batches/s]
  1%|          | 2/328 [00:00<00:30, 10.84batches/s]

Validation Loss : 4.631405
Validation Perplexity : 102.658238
Start epoch 18, learning rate 0.000500 
2


100%|██████████| 328/328 [00:31<00:00, 10.55batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.64batches/s]

Epoch : 18, Batch : 327 / 328, Loss : 4.086684, Perplexity : 59.542101, Time : 31.140421


100%|██████████| 29/29 [00:00<00:00, 31.88batches/s]
  1%|          | 2/328 [00:00<00:29, 11.10batches/s]

Validation Loss : 4.630307
Validation Perplexity : 102.545569
Start epoch 19, learning rate 0.000500 
2


100%|██████████| 328/328 [00:30<00:00, 10.62batches/s]
 10%|█         | 3/29 [00:00<00:00, 27.12batches/s]

Epoch : 19, Batch : 327 / 328, Loss : 4.060201, Perplexity : 57.985989, Time : 30.939090


100%|██████████| 29/29 [00:00<00:00, 30.83batches/s]
  1%|          | 2/328 [00:00<00:30, 10.69batches/s]

Validation Loss : 4.629208
Validation Perplexity : 102.432876
Start epoch 20, learning rate 0.000500 
2


100%|██████████| 328/328 [00:30<00:00, 10.62batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 33.13batches/s]

Epoch : 20, Batch : 327 / 328, Loss : 4.033714, Perplexity : 56.470247, Time : 30.933452


100%|██████████| 29/29 [00:00<00:00, 31.14batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.627532
Validation Perplexity : 102.261342
Start epoch 21, learning rate 0.000250 
2


100%|██████████| 328/328 [00:31<00:00, 10.56batches/s]
 10%|█         | 3/29 [00:00<00:00, 26.04batches/s]

Epoch : 21, Batch : 327 / 328, Loss : 3.980738, Perplexity : 53.556534, Time : 31.103255


100%|██████████| 29/29 [00:00<00:00, 31.02batches/s]
  1%|          | 2/328 [00:00<00:31, 10.51batches/s]

Validation Loss : 4.629087
Validation Perplexity : 102.420475
Start epoch 22, learning rate 0.000250 
2


100%|██████████| 328/328 [00:31<00:00, 10.48batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 31.16batches/s]

Epoch : 22, Batch : 327 / 328, Loss : 3.961635, Perplexity : 52.543165, Time : 31.324435


100%|██████████| 29/29 [00:00<00:00, 32.00batches/s]
  1%|          | 2/328 [00:00<00:29, 10.89batches/s]

Validation Loss : 4.629796
Validation Perplexity : 102.493177
Start epoch 23, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.66batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.20batches/s]

Epoch : 23, Batch : 327 / 328, Loss : 3.947563, Perplexity : 51.808978, Time : 30.803166


100%|██████████| 29/29 [00:00<00:00, 31.86batches/s]
  1%|          | 2/328 [00:00<00:29, 11.08batches/s]

Validation Loss : 4.629463
Validation Perplexity : 102.459003
Start epoch 24, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.60batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 30.69batches/s]

Epoch : 24, Batch : 327 / 328, Loss : 3.933981, Perplexity : 51.110046, Time : 30.985489


100%|██████████| 29/29 [00:00<00:00, 31.42batches/s]
  0%|          | 0/328 [00:00<?, ?batches/s]

Validation Loss : 4.628575
Validation Perplexity : 102.368073
Start epoch 25, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.84batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.29batches/s]

Epoch : 25, Batch : 327 / 328, Loss : 3.921142, Perplexity : 50.458054, Time : 30.289916


100%|██████████| 29/29 [00:00<00:00, 32.06batches/s]
  1%|          | 2/328 [00:00<00:29, 11.15batches/s]

Validation Loss : 4.630227
Validation Perplexity : 102.537326
Start epoch 26, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.72batches/s]
 10%|█         | 3/29 [00:00<00:01, 25.15batches/s]

Epoch : 26, Batch : 327 / 328, Loss : 3.907506, Perplexity : 49.774652, Time : 30.644748


100%|██████████| 29/29 [00:00<00:00, 30.38batches/s]
  1%|          | 2/328 [00:00<00:30, 10.53batches/s]

Validation Loss : 4.629493
Validation Perplexity : 102.462090
Start epoch 27, learning rate 0.000250 
2


100%|██████████| 328/328 [00:31<00:00, 10.56batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.30batches/s]

Epoch : 27, Batch : 327 / 328, Loss : 3.893612, Perplexity : 49.087895, Time : 31.097881


100%|██████████| 29/29 [00:00<00:00, 31.80batches/s]
  1%|          | 2/328 [00:00<00:29, 11.06batches/s]

Validation Loss : 4.632801
Validation Perplexity : 102.801575
Start epoch 28, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.71batches/s]
 10%|█         | 3/29 [00:00<00:00, 28.71batches/s]

Epoch : 28, Batch : 327 / 328, Loss : 3.881570, Perplexity : 48.500324, Time : 30.676496


100%|██████████| 29/29 [00:00<00:00, 31.38batches/s]
  1%|          | 2/328 [00:00<00:29, 11.09batches/s]

Validation Loss : 4.634040
Validation Perplexity : 102.929061
Start epoch 29, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.62batches/s]
 14%|█▍        | 4/29 [00:00<00:00, 32.34batches/s]

Epoch : 29, Batch : 327 / 328, Loss : 3.870295, Perplexity : 47.956543, Time : 30.929198


100%|██████████| 29/29 [00:00<00:00, 32.32batches/s]
  1%|          | 2/328 [00:00<00:32, 10.08batches/s]

Validation Loss : 4.634754
Validation Perplexity : 103.002540
Start epoch 30, learning rate 0.000250 
2


100%|██████████| 328/328 [00:30<00:00, 10.73batches/s]
 10%|█         | 3/29 [00:00<00:00, 29.85batches/s]

Epoch : 30, Batch : 327 / 328, Loss : 3.858495, Perplexity : 47.393972, Time : 30.623644


100%|██████████| 29/29 [00:00<00:00, 31.04batches/s]

Validation Loss : 4.637422
Validation Perplexity : 103.277793



