# LSTM-based Language Models

In [None]:
import warnings
warnings.filterwarnings('ignore')

import glob
import time
import math

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon.utils import download
import gluonnlp
import nltk
from gluonnlp.data import batchify
import time
import multiprocessing as mp
from mxnet.gluon import Block, nn, rnn

import gluonnlp as nlp

In [None]:
# Number of training batches between two progress reports.
log_interval = 200
# Devices used by the notebook: a one-element list holding GPU 0.
# The training and evaluation code indexes it as context[0].
context = [mx.gpu(0)]

In [None]:
# --- Batching ---
# Samples per mini-batch; used by the data loaders and the throughput logs.
batch_size = 20
bptt = 35

# --- Optimisation ---
# Initial learning rate handed to the trainer when it is constructed.
# NOTE(review): 20 looks like an SGD-style value, while the trainer in this
# notebook is built with 'Adam' -- confirm this is intended.
lr = 20
grad_clip = 0.25
# The train() call at the end of the notebook passes its own epoch count.
epochs = 3

### Loading the dataset

In [None]:
# Tokenizer used to build the vocabulary from the full corpus.
moses_tokenizer = nlp.data.SacreMosesTokenizer()

# Read 'data/all.txt' as one flat token stream: the text is split into
# sentences with NLTK, each sentence is tokenized, and '<eos>' marks the end
# of every sample.
all_datasets = nlp.data.CorpusDataset(
    'data/all.txt',
    tokenizer=moses_tokenizer,
    sample_splitter=nltk.tokenize.sent_tokenize,
    eos='<eos>',
    flatten=True,
)

# Vocabulary built from the token frequencies, without padding or
# beginning-of-sequence tokens.
vocab = nlp.Vocab(
    counter=nlp.data.Counter(all_datasets),
    bos_token=None,
    padding_token=None,
)

In [None]:
# Training and validation splits, one tab-separated sample per line.
# preprocess() below unpacks each sample as a (text, label) pair.
hamlet_train = nlp.data.dataset.TSVDataset(
    'data/hamlet_train.txt')
hamlet_val = nlp.data.dataset.TSVDataset(
    'data/hamlet_val.txt')

In [None]:
# Every sample is brought to exactly 35 tokens: longer ones are clipped,
# shorter ones are padded.
length_pad = nlp.data.PadSequence(35)
length_clip = nlp.data.ClipSequence(35)
# English spaCy tokenizer applied to the raw text of each sample.
tokenizer = nlp.data.SpacyTokenizer('en')

In [None]:
def preprocess(x):
    """Turn one raw ``(text, label)`` sample into ``(tokens, int_label)``.

    The text is tokenized with the module-level ``tokenizer`` and then forced
    to a fixed length by ``length_clip`` followed by ``length_pad``.
    """
    text, raw_label = x
    # Convert the label first so a malformed label fails before tokenizing.
    label = int(raw_label)
    tokens = length_pad(length_clip(tokenizer(text)))
    return tokens, label

In [None]:
def preprocess_dataset(dataset):
    """Run ``preprocess`` over every sample of ``dataset`` in parallel.

    Parameters
    ----------
    dataset : iterable of (text, label) samples
        Raw samples, as accepted by ``preprocess``.

    Returns
    -------
    gluon.data.SimpleDataset
        The preprocessed samples, in the same order as the input.

    Prints the elapsed time and the number of processed samples.
    """
    start = time.time()
    pool = mp.Pool()
    try:
        processed = pool.map(preprocess, dataset)
    finally:
        # Always release the worker processes, even if a sample fails to
        # preprocess; otherwise each call leaves a pool of idle workers behind.
        pool.close()
        pool.join()
    dataset = gluon.data.SimpleDataset(processed)
    end = time.time()
    print('Done! Tokenizing Time={:.2f}s, #Sentences={}'.format(end - start, len(dataset)))
    return dataset

In [None]:
# Tokenize, clip and pad both Hamlet splits. Each call prints its own
# timing / sample-count line.
train_dataset = preprocess_dataset(hamlet_train)
val_dataset = preprocess_dataset(hamlet_val)

In [None]:
def token_to_idx(x):
    """Map the tokens of one ``(tokens, label)`` sample to vocabulary indices.

    The lookup goes through the module-level ``vocab``; the label is returned
    unchanged.
    """
    tokens, label = x[0], x[1]
    indices = vocab[tokens]
    return indices, label

# Convert every sample's tokens to vocabulary indices in worker processes.
# After this cell both datasets are plain lists (the return type of pool.map).
pool = mp.Pool()
train_dataset = pool.map(token_to_idx, train_dataset)
val_dataset = pool.map(token_to_idx, val_dataset)
pool.close()
# Sanity check: show (up to) the first 50 indices of the first training sample.
print(train_dataset[0][0][:50])

In [None]:
# Collate function for (indices, label) samples: the index sequences are
# padded along axis 0 and the labels are stacked.
batchify_fn = batchify.Tuple(
    batchify.Pad(axis=0),
    batchify.Stack(),
)

In [None]:
# Training batches are drawn in shuffled order.
train_dataloader = gluon.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    shuffle=True,
)
# Validation batches keep the dataset order.
val_dataloader = gluon.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    shuffle=False,
)

In [None]:
class RNNModel(gluon.Block):
    """Embedding layer followed by a multi-layer LSTM encoder.

    Parameters
    ----------
    vocab_size : int
        Number of rows of the embedding table.
    num_embed : int
        Embedding dimension; also the LSTM input size.
    num_hidden : int
        Hidden size of the LSTM.
    num_layers : int
        Number of stacked LSTM layers.
    dropout : float, default 0.2
        Dropout rate applied after the embedding and inside the LSTM.
    tie_weights : bool, default False
        Accepted for interface compatibility; not used by this class.
    """
    def __init__(self, vocab_size, num_embed, num_hidden, num_layers, dropout=0.2, tie_weights=False, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.embedding = nn.HybridSequential()
        with self.embedding.name_scope():
            self.embedding.add(nn.Embedding(vocab_size, num_embed))
            # Use the configured rate rather than a literal 0.2, so the
            # `dropout` argument controls both the embedding and the LSTM.
            self.embedding.add(nn.Dropout(dropout))
        with self.name_scope():
            self.encoder = rnn.LSTM(num_hidden, num_layers, dropout=dropout, input_size=num_embed)

    def forward(self, inputs, hidden):
        """Embed ``inputs`` and run them through the LSTM.

        Returns the encoder output and the updated hidden state. Callers in
        this notebook pass ``inputs`` transposed from the data loader's layout
        and read the last step of the output along axis 0.
        """
        output, hidden = self.encoder(self.embedding(inputs), hidden)
        return output, hidden

    def begin_state(self, *args, **kwargs):
        """Return the encoder's initial state; arguments are forwarded as-is."""
        return self.encoder.begin_state(*args, **kwargs)

# Encoder sized to the vocabulary: 200-d embeddings, 2 LSTM layers of 200 units.
model = RNNModel(
    vocab_size=len(vocab),
    num_embed=200,
    num_hidden=200,
    num_layers=2,
)
# Load weights from the parameter file onto the configured devices.
# ignore_extra=True skips entries in the file that this model does not define.
model.load_parameters(
    "standard_lstm_lm_200-6.params",
    ctx=context,
    ignore_extra=True,
)

In [None]:
class DenseLayer(gluon.Block):
    """Classification head made of a single fully connected layer.

    Parameters
    ----------
    num_hidden : int
        Accepted for interface compatibility but not used: the Dense layer is
        created without an explicit input size.
    num_classes : int, default 10
        Number of output units (one score per class).
    """
    def __init__(self, num_hidden, num_classes=10, **kwargs):
        super(DenseLayer, self).__init__(**kwargs)
        self.decoder = nn.HybridSequential()
        with self.decoder.name_scope():
            self.decoder.add(nn.Dense(units=num_classes, flatten=True))

    def forward(self, inputs):
        """Return the class scores computed by the dense layer for ``inputs``."""
        return self.decoder(inputs)

# Classification head applied to the encoder output; Xavier-initialised on
# the configured devices.
dense = DenseLayer(num_hidden=200)
dense.collect_params().initialize(
    init=mx.init.Xavier(magnitude=2.24),
    ctx=context,
)

In [None]:
# Initialize the trainer and optimizer and specify some hyperparameters.
# The classifier head `dense` is a separate Block from `model`, so its
# parameters must be handed to the trainer explicitly; otherwise only the
# encoder would be updated and the head would keep its initial weights.
trainable_params = model.collect_params()
trainable_params.update(dense.collect_params())
# NOTE(review): `lr` here is the module-level value (20), which looks tuned
# for SGD rather than Adam -- confirm.
trainer = gluon.Trainer(trainable_params, 'Adam', {
    'learning_rate': lr,
    'wd': 0.001
})

# Specify the loss function, in this case, cross-entropy with softmax.
loss = gluon.loss.SoftmaxCrossEntropyLoss()

### Training the LM

In [None]:
def detach(hidden):
    """Detach a hidden state from the computation graph.

    ``hidden`` is either a single object exposing ``.detach()`` or an
    arbitrarily nested tuple/list of such objects. Containers are always
    returned as lists.
    """
    if not isinstance(hidden, (tuple, list)):
        return hidden.detach()
    return [detach(state) for state in hidden]

In [None]:
# Note that ctx is short for context
def evaluate(model, data_source, batch_size, ctx):
    """Compute the average classification loss of ``model`` over ``data_source``.

    Uses the module-level ``dense`` head and ``loss`` function.

    Parameters
    ----------
    model : RNNModel
        Encoder producing the sequence representation.
    data_source : iterable of (data, target) batches
        Batches to evaluate on.
    batch_size : int
        Accepted for interface compatibility; the actual batch size is read
        from each batch, so a smaller final batch is handled.
    ctx : mxnet Context
        Device on which the evaluation is run.

    Returns
    -------
    float
        Sum of the per-sample losses divided by the number of samples.
    """
    total_L = 0.0
    ntotal = 0
    # Iterate the loader that was passed in (not a module-level one), so the
    # result really describes `data_source`.
    for data, target in data_source:
        # Move to the requested device and transpose so the batch is on axis 1.
        data = mx.nd.transpose(data.as_in_context(ctx))
        target = target.as_in_context(ctx)
        hidden = model.begin_state(batch_size=data.shape[1], func=mx.nd.zeros, ctx=ctx)
        output, hidden = model(data, hidden)
        # Classify from the encoder output at the last position along axis 0.
        classes = dense(output[-1,:,:])
        L = loss(classes, target)
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal

In [None]:
# Function for actually training the model
def train(model, train_data, val_data, epochs, lr):
    """Train ``model`` together with the module-level ``dense`` head.

    Relies on the module-level ``trainer``, ``loss``, ``dense``, ``context``,
    ``batch_size`` and ``log_interval``.

    Parameters
    ----------
    model : RNNModel
        Encoder being trained.
    train_data : iterable of (data, target) batches with a length
        Training batches.
    val_data : iterable of (data, target) batches
        Validation batches, evaluated after every epoch.
    epochs : int
        Number of passes over ``train_data``.
    lr : float
        Learning rate to train with. It is multiplied by 0.25 whenever the
        validation loss fails to improve on the best value seen so far.
    """
    best_val = float("Inf")
    start_train_time = time.time()
    # The trainer was constructed with its own learning rate; apply the one
    # requested here before the first optimisation step.
    trainer.set_learning_rate(lr)

    for epoch in range(epochs):
        total_L = 0.0
        ntotal = 0
        start_epoch_time = time.time()
        start_log_interval_time = time.time()

        for i, (data, target) in enumerate(train_data):
            # Move to the device and transpose so the batch is on axis 1.
            data = mx.nd.transpose(data.as_in_context(context[0]))
            target = target.as_in_context(context[0])
            # A fresh zero state per batch: state is not carried across batches.
            hidden = model.begin_state(batch_size=data.shape[1], func=mx.nd.zeros, ctx=context[0])
            with autograd.record():
                output, hidden = model(data, hidden)
                # Classify from the encoder output at the last position.
                classes = dense(output[-1,:,:])
                L = loss(classes, target)
            L.backward()
            # step(1): the gradients are applied without rescaling by batch size.
            trainer.step(1)

            total_L += mx.nd.sum(L).asscalar()
            ntotal += L.size

            if i % log_interval == 0 and i > 0:
                # Average per-sample loss since the last report.
                cur_L = total_L / ntotal
                print('[Epoch %d Batch %d/%d] loss %.2f, ppl %.2f, '
                      'throughput %.2f samples/s'%(
                    epoch, i, len(train_data), cur_L, math.exp(cur_L),
                    batch_size * log_interval / (time.time() - start_log_interval_time)))
                total_L = 0.0
                ntotal = 0
                start_log_interval_time = time.time()

        # Wait for all pending asynchronous work so the timings are accurate.
        mx.nd.waitall()

        print('[Epoch %d] throughput %.2f samples/s'%(
                    epoch, len(train_data)*batch_size / (time.time() - start_epoch_time)))

        val_L = evaluate(model, val_data, batch_size, context[0])
        print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
            epoch, time.time()-start_epoch_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            # NOTE(review): checkpointing is disabled, so nothing is written
            # to disk despite the message below.
            #model.save_parameters('cr-{}.params'.format(epoch))
            print('Model saved!')
        else:
            # No improvement on validation: shrink the learning rate.
            lr = lr*0.25
            print('Learning rate now %f'%(lr))
            trainer.set_learning_rate(lr)

    print('Total training throughput %.2f samples/s'%(
                            (batch_size * len(train_data) * epochs) /
                            (time.time() - start_train_time)))

In [None]:
# Baseline: validation loss before any training.
evaluate(
    model,
    data_source=val_dataloader,
    batch_size=batch_size,
    ctx=context[0],
)

In [None]:
# Run 12 epochs of training, validating on the held-out split after each one.
train(model, train_data=train_dataloader, val_data=val_dataloader, epochs=12, lr=0.0001)