In [1]:
import mxnet as mx
from mxnet import gluon, autograd
from mxnet import ndarray as nd
from mxnet.gluon import nn, rnn
import math

import time
from datetime import timedelta

from preprocessing_zh import Corpus, LMDataset

In [2]:
class LMConfig(object):
    rnn_type = 'LSTM'
    embedding_dim = 64
    hidden_dim = 128
    num_layers = 2
    dropout = 0.5
    
    batch_size = 20
    seq_len = 30
    learning_rate = 1.
    optimizer = 'sgd'
    grad_clip = 0.2
    
    num_epochs = 10
    print_per_batch = 50

In [3]:
class RNNModel(nn.Block):
    def __init__(self, config, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        
        vocab_size = config.vocab_size
        embedding_dim = config.embedding_dim
        hidden_dim = config.hidden_dim
        dropout = config.dropout
        num_layers = config.num_layers
        rnn_type = config.rnn_type
        
        self.hidden_dim = hidden_dim
        
        with self.name_scope():
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            
            if rnn_type in ['RNN', 'LSTM', 'GRU']:
                self.rnn = getattr(rnn, rnn_type)(hidden_dim, num_layers, dropout=dropout)
            else:
                raise ValueError("Invalid rnn_type %s. Options are RNN, LSTM, GRU" % rnn_type)
            
            self.decoder = nn.Dense(vocab_size)
            
    def forward(self, inputs, hidden):
        embedded = self.embedding(inputs)
        output, hidden = self.rnn(embedded, hidden)
        decoded = self.decoder(output.reshape((-1, self.hidden_dim)))
        return decoded, hidden
    
    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)

In [4]:
train_dir = 'data/weicheng.txt'
vocab_dir = 'data/weicheng.vocab.txt'

corpus = Corpus(train_dir, vocab_dir)
print(corpus)

Corpus length: 242052, Vocabulary size: 3423.


In [5]:
config = LMConfig()
config.vocab_size = len(corpus.words)
train_data = LMDataset(corpus.data, config.batch_size, config.seq_len)
print(train_data)

Num of batches: 403, Batch Shape: (30, 20)


In [6]:
model = RNNModel(config)
model.collect_params().initialize(mx.init.Xavier())

In [7]:
trainer = gluon.Trainer(model.collect_params(), config.optimizer, {'learning_rate': config.learning_rate})
loss_func = gluon.loss.SoftmaxCrossEntropyLoss()

In [8]:
def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
    else:
        hidden = hidden.detach()
    return hidden

In [9]:
def get_time_dif(start_time):
    """
    Return the time used since start_time.
    """
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

In [10]:
context = mx.cpu()

In [12]:
grad_clip = config.grad_clip
seq_len = config.seq_len
batch_size = config.batch_size
start_time = time.time()

for epoch in range(config.num_epochs):
    total_loss = 0.0
    hidden = model.begin_state(func=nd.zeros, batch_size=batch_size)
    for ibatch, (data, label) in enumerate(train_data):
        data = nd.array(data).as_in_context(context)
        label = nd.array(label).as_in_context(context)
        hidden = detach(hidden)
        
        with autograd.record():
            output, hidden = model(data, hidden)
            loss = loss_func(output, label)
            
        loss.backward()
        
        grads = [x.grad(mx.cpu()) for x in model.collect_params().values()]
        gluon.utils.clip_global_norm(grads, grad_clip * seq_len * batch_size)
        
        trainer.step(config.batch_size)
        total_loss += nd.sum(loss).asscalar()
        
        if ibatch % config.print_per_batch == 0 and ibatch > 0:
            cur_loss = total_loss / seq_len / batch_size / config.print_per_batch
            print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (
                    epoch + 1, ibatch, cur_loss, math.exp(cur_loss)))
            total_loss = 0.0

[Epoch 1 Batch 50] loss 6.24, perplexity 515.12


KeyboardInterrupt: 