# LSTM-based Language Models

In [1]:
import warnings
warnings.filterwarnings('ignore')

import glob
import time
import math

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon.utils import download

import gluonnlp as nlp

In [2]:
# Run on the first GPU when MXNet can see one; otherwise fall back to the CPU
# so the notebook still executes (more slowly) on machines without a GPU.
context = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
# Number of training batches between two progress log lines in train().
log_interval = 200

In [3]:
# Batching: both values are handed to CorpusBPTTBatchify further down.
batch_size, bptt = 20, 35

# Optimisation: initial SGD learning rate and the threshold passed to
# clip_global_norm inside train().
lr, grad_clip = 20, 0.25

# Epoch count; note that the train(...) call at the end of the notebook
# passes its own `epochs` value explicitly.
epochs = 3

### Loading the dataset

In [4]:
import nltk

moses_tokenizer = nlp.data.SacreMosesTokenizer()


def _load_corpus(name):
    """Load ``data/<name>.txt`` as a flattened CorpusDataset.

    Text is split into samples with NLTK's ``sent_tokenize`` and tokenized
    with the shared Moses tokenizer; ``'<eos>'`` is used as the
    end-of-sequence token.
    """
    return nlp.data.CorpusDataset(
        'data/{}.txt'.format(name),
        sample_splitter=nltk.tokenize.sent_tokenize,
        tokenizer=moses_tokenizer,
        flatten=True,
        eos='<eos>')


# NOTE: despite its name this list holds all three splits, in the order
# train, validation, test (it is unpacked in that order after batchifying).
train_datasets = [_load_corpus(split) for split in ('train', 'val', 'test')]

# The combined corpus; it is only used to build the vocabulary.
all_datasets = _load_corpus('all')

In [5]:
# Count tokens over the combined corpus and build the vocabulary from those
# counts. No padding token and no beginning-of-sequence token are requested.
vocab = nlp.Vocab(nlp.data.Counter(all_datasets),
                  padding_token=None,
                  bos_token=None)

# Batchify helper for truncated BPTT: configured with the vocabulary, the
# window length `bptt` and `batch_size`; last_batch='discard' is requested.
bptt_batchify = nlp.data.batchify.CorpusBPTTBatchify(
    vocab, bptt, batch_size, last_batch='discard')

# Apply the same batchify function to the train / val / test splits in order.
train_data, val_data, test_data = map(bptt_batchify, train_datasets)

In [6]:
# Build the network by name on top of our own vocabulary. dataset_name=None:
# no pre-training dataset is named, and the weights are initialised below.
model_name = 'standard_lstm_lm_200'
model, vocab = nlp.model.get_model(
    model_name,
    vocab=vocab,
    dataset_name=None,
)
print(model)
print(vocab)

# Xavier-initialise every parameter on the configured device(s).
model.initialize(mx.init.Xavier(), ctx=context)

# Softmax cross-entropy is the training and evaluation criterion.
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# Plain SGD (no momentum, no weight decay) starting from learning rate `lr`.
trainer = gluon.Trainer(
    model.collect_params(),
    'sgd',
    {'learning_rate': lr, 'momentum': 0, 'wd': 0},
)

StandardRNN(
  (embedding): HybridSequential(
    (0): Embedding(30896 -> 200, float32)
    (1): Dropout(p = 0.2, axes=())
  )
  (encoder): LSTM(200 -> 200, TNC, num_layers=2, dropout=0.2)
  (decoder): HybridSequential(
    (0): Dense(200 -> 30896, linear)
  )
)
Vocab(size=30896, unk="<unk>", reserved="['<eos>']")


### Training the LM

In [7]:
def detach(hidden):
    """Recursively call ``.detach()`` on every leaf of ``hidden``.

    ``hidden`` may be a single state object or an arbitrarily nested
    tuple/list of them. Tuples and lists are rebuilt as lists with the same
    nesting; any other object is replaced by the result of its ``.detach()``.
    """
    if not isinstance(hidden, (tuple, list)):
        return hidden.detach()
    return [detach(state) for state in hidden]

In [8]:
# Note that ctx is short for context
def evaluate(model, data_source, batch_size, ctx):
    """Return the average loss of ``model`` over ``data_source``.

    Uses the notebook-level ``loss`` function. The recurrent state starts at
    zeros and is carried from one batch to the next, detached after every
    batch.

    Parameters
    ----------
    model : the language model; called as ``model(data, hidden)``.
    data_source : iterable of ``(data, target)`` batches.
    batch_size : batch size used to create the initial hidden state.
    ctx : the single device on which evaluation runs.

    Returns
    -------
    float
        Sum of all loss values divided by the number of loss elements.

    Raises
    ------
    ZeroDivisionError
        If ``data_source`` yields no batches.
    """
    loss_sum = 0.0
    element_count = 0
    state = model.begin_state(
        batch_size=batch_size, func=mx.nd.zeros, ctx=ctx)
    for data, target in data_source:
        data, target = data.as_in_context(ctx), target.as_in_context(ctx)
        output, state = model(data, state)
        state = detach(state)
        # MXNet reshape codes: -3 merges the first two axes, -1 infers the
        # remaining one, so each output row lines up with one target entry.
        batch_loss = loss(output.reshape(-3, -1), target.reshape(-1))
        loss_sum += mx.nd.sum(batch_loss).asscalar()
        element_count += batch_loss.size
    return loss_sum / element_count

In [12]:
# Function for actually training the model
def train(model, train_data, val_data, test_data, epochs, lr, start_epoch=7):
    """Train ``model`` and validate it after every epoch.

    Relies on the notebook-level globals ``context``, ``trainer``, ``loss``,
    ``batch_size``, ``grad_clip``, ``log_interval`` and ``model_name``.

    After each epoch the validation loss is computed. When it improves on the
    best value seen so far, the test loss is reported and the parameters are
    saved to ``'<model_name>-<epoch>.params'``; otherwise the learning rate is
    multiplied by 0.25.

    Parameters
    ----------
    model : the language model to optimise.
    train_data, val_data, test_data : sized iterables of ``(data, target)``
        batches; the batch dimension is axis 1 (it is the axis that is split
        across devices).
    epochs : exclusive upper bound of the epoch index.
    lr : current learning rate; only used as the starting point for decay.
    start_epoch : index of the first epoch to run. Defaults to 7, the value
        that used to be hard-coded in the loop, so existing calls behave the
        same. Pass 0 to train from scratch.
    """
    best_val = float("Inf")
    start_train_time = time.time()
    parameters = model.collect_params().values()

    for epoch in range(start_epoch, epochs):
        total_L = 0.0
        start_epoch_time = time.time()
        start_log_interval_time = time.time()
        # One fresh zero state per device, each sized for its slice of the batch.
        hiddens = [model.begin_state(batch_size//len(context), func=mx.nd.zeros, ctx=ctx)
                   for ctx in context]

        for i, (data, target) in enumerate(train_data):
            data_list = gluon.utils.split_and_load(data, context,
                                                   batch_axis=1, even_split=True)
            target_list = gluon.utils.split_and_load(target, context,
                                                     batch_axis=1, even_split=True)
            # Keep the state values but detach them before recording this batch.
            hiddens = detach(hiddens)
            L = 0
            Ls = []

            with autograd.record():
                for j, (X, y, h) in enumerate(zip(data_list, target_list, hiddens)):
                    output, h = model(X, h)
                    batch_L = loss(output.reshape(-3, -1), y.reshape(-1,))
                    # Scale so that the per-device losses sum to one mean loss.
                    L = L + batch_L.as_in_context(context[0]) / (len(context) * X.size)
                    Ls.append(batch_L / (len(context) * X.size))
                    hiddens[j] = h
            L.backward()
            grads = [p.grad(x.context) for p in parameters for x in data_list]
            gluon.utils.clip_global_norm(grads, grad_clip)

            # The loss is already normalised above, hence a step size of 1.
            trainer.step(1)

            total_L += sum([mx.nd.sum(l).asscalar() for l in Ls])

            if i % log_interval == 0 and i > 0:
                cur_L = total_L / log_interval
                print('[Epoch %d Batch %d/%d] loss %.2f, ppl %.2f, '
                      'throughput %.2f samples/s'%(
                    epoch, i, len(train_data), cur_L, math.exp(cur_L),
                    batch_size * log_interval / (time.time() - start_log_interval_time)))
                total_L = 0.0
                start_log_interval_time = time.time()

        # Wait for all pending MXNet work before reading the epoch timer.
        mx.nd.waitall()

        print('[Epoch %d] throughput %.2f samples/s'%(
                    epoch, len(train_data)*batch_size / (time.time() - start_epoch_time)))

        val_L = evaluate(model, val_data, batch_size, context[0])
        print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
            epoch, time.time()-start_epoch_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(model, test_data, batch_size, context[0])
            model.save_parameters('{}-{}.params'.format(model_name, epoch))
            print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
        else:
            lr = lr*0.25
            print('Learning rate now %f'%(lr))
            trainer.set_learning_rate(lr)

    # Only epochs start_epoch .. epochs-1 actually ran; multiplying by
    # `epochs` would overstate the overall throughput.
    epochs_run = max(epochs - start_epoch, 0)
    print('Total training throughput %.2f samples/s'%(
                            (batch_size * len(train_data) * epochs_run) /
                            (time.time() - start_train_time)))

In [None]:
train(
    model=model,
    train_data=train_data,  # BPTT batches of the training split built above
    val_data=val_data,      # used after every epoch to track the best model
    test_data=test_data,    # evaluated whenever the validation loss improves
    epochs=12,
    lr=20)

[Epoch 7 Batch 200/1259] loss 4.76, ppl 116.25, throughput 890.04 samples/s
[Epoch 7 Batch 400/1259] loss 4.81, ppl 122.41, throughput 889.08 samples/s
[Epoch 7 Batch 600/1259] loss 4.78, ppl 119.37, throughput 871.87 samples/s
[Epoch 7 Batch 800/1259] loss 4.78, ppl 119.01, throughput 888.32 samples/s
[Epoch 7 Batch 1000/1259] loss 4.77, ppl 117.82, throughput 896.38 samples/s
