
add fused rnn cell #5004

Merged
merged 4 commits on Feb 16, 2017
173 changes: 173 additions & 0 deletions example/rnn/cudnn_lstm_bucketing.py
@@ -0,0 +1,173 @@
import numpy as np
import mxnet as mx
import argparse

parser = argparse.ArgumentParser(description="Train RNN on Penn Tree Bank",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--test', default=False, action='store_true',
                    help='whether to do testing instead of training')
parser.add_argument('--model-prefix', type=str, default=None,
                    help='path to save/load model')
parser.add_argument('--load-epoch', type=int, default=0,
                    help='load from epoch')
parser.add_argument('--num-layers', type=int, default=2,
                    help='number of stacked RNN layers')
parser.add_argument('--num-hidden', type=int, default=200,
                    help='hidden layer size')
parser.add_argument('--num-embed', type=int, default=200,
                    help='embedding layer size')
parser.add_argument('--gpus', type=str,
                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ' \
                         'Increase batch size when using multiple gpus for best performance.')
parser.add_argument('--kv-store', type=str, default='device',
                    help='key-value store type')
parser.add_argument('--num-epochs', type=int, default=25,
                    help='max num of epochs')
parser.add_argument('--lr', type=float, default=0.01,
                    help='initial learning rate')
parser.add_argument('--optimizer', type=str, default='sgd',
                    help='the optimizer type')
parser.add_argument('--mom', type=float, default=0.0,
                    help='momentum for sgd')
parser.add_argument('--wd', type=float, default=0.00001,
                    help='weight decay for sgd')
parser.add_argument('--batch-size', type=int, default=32,
                    help='the batch size.')
parser.add_argument('--disp-batches', type=int, default=50,
                    help='show progress for every n batches')


#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]

start_label = 1
invalid_label = 0

def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    lines = open(fname).readlines()
    # wrap in list() so sentences support len() under Python 3, where filter is lazy
    lines = [list(filter(None, i.split(' '))) for i in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label,
                                               start_label=start_label)
    return sentences, vocab

def get_data(layout):
    train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label,
                                      invalid_label=invalid_label)
    val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label,
                                invalid_label=invalid_label)

    data_train = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets,
                                           invalid_label=invalid_label, layout=layout)
    data_val = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets,
                                         invalid_label=invalid_label, layout=layout)
    return data_train, data_val, vocab


def train(args):
    data_train, data_val, vocab = get_data('TN')

    cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, mode='lstm')

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC')

        pred = mx.sym.Reshape(output, shape=(-1, args.num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen            = sym_gen,
        default_bucket_key = data_train.default_bucket_key,
        context            = contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            cell, args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    model.fit(
        train_data         = data_train,
        eval_data          = data_val,
        eval_metric        = mx.metric.Perplexity(invalid_label),
        kvstore            = args.kv_store,
        optimizer          = args.optimizer,
        optimizer_params   = { 'learning_rate': args.lr,
                               'momentum': args.mom,
                               'wd': args.wd },
        initializer        = mx.init.Xavier(factor_type="in", magnitude=2.34),
        arg_params         = arg_params,
        aux_params         = aux_params,
        begin_epoch        = args.load_epoch,
        num_epoch          = args.num_epochs,
        batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches),
        epoch_end_callback = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                             if args.model_prefix else None)

def test(args):
    assert args.model_prefix, "Must specify path to load from"
    _, data_val, vocab = get_data('NT')

    stack = mx.rnn.SequentialRNNCell()
    for i in range(args.num_layers):
        stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i))

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen            = sym_gen,
        default_bucket_key = data_val.default_bucket_key,
        context            = contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    # note here we load using SequentialRNNCell instead of FusedRNNCell.
    _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
    model.set_params(arg_params, aux_params)

    model.score(data_val, mx.metric.Perplexity(invalid_label),
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

if __name__ == '__main__':
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    args = parser.parse_args()
    if args.test:
        # Demonstrates how to load a model trained with CuDNN RNN and predict
        # with non-fused MXNet symbol
        test(args)
    else:
        train(args)
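The crux of this example is that a FusedRNNCell (the whole stacked LSTM as one cuDNN call) and a SequentialRNNCell of per-layer LSTMCells describe the same network, so a checkpoint written by one can be read by the other. A minimal sketch of that correspondence, under the assumption that the mx.rnn API behaves as used in this PR (sizes are illustrative):

import mxnet as mx

num_hidden, num_layers, seq_len = 200, 2, 10

# Fused cell: the whole stacked LSTM as a single cuDNN symbol.
fused = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, mode='lstm')

# Equivalent non-fused stack; the lstm_l%d_ prefixes must line up with the
# names FusedRNNCell's unpack_weights produces for checkpoints to transfer.
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))

data = mx.sym.Variable('data')
# The fused cell consumes time-major (TNC) input; the stack defaults to NTC.
fused_out, _ = fused.unroll(seq_len, inputs=data, merge_outputs=True, layout='TNC')
stack_out, _ = stack.unroll(seq_len, inputs=data, merge_outputs=True)

Training on GPU with the fused cell and scoring on CPU with the stack then reduces to do_rnn_checkpoint(cell, ...) on one side and load_rnn_checkpoint(stack, ...) on the other, exactly as train() and test() above do.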
9 changes: 4 additions & 5 deletions example/rnn/lstm_bucketing.py
@@ -64,16 +64,15 @@ def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
 def sym_gen(seq_len):
     data = mx.sym.Variable('data')
     label = mx.sym.Variable('softmax_label')
-    embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')
+    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
+                             output_dim=args.num_embed, name='embed')

     stack = mx.rnn.SequentialRNNCell()
     for i in range(args.num_layers):
         stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i))
-    outputs, states = mx.rnn.rnn_unroll(stack, seq_len, inputs=embed)
+    outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

-    outputs = [mx.sym.expand_dims(x, axis=1) for x in outputs]
-    pred = mx.sym.Concat(*outputs, dim=1)
-    pred = mx.sym.Reshape(pred, shape=(-1, args.num_hidden))
+    pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden))
     pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

     label = mx.sym.Reshape(label, shape=(-1,))
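The deleted expand_dims/Concat lines were doing by hand what merge_outputs=True now does inside unroll: stitch the per-step outputs into one (batch, seq_len, hidden) symbol. A rough sketch of the equivalence, assuming the unroll API of this era (names illustrative):

import mxnet as mx

cell = mx.rnn.LSTMCell(num_hidden=200, prefix='lstm_l0_')
embed = mx.sym.Variable('embed')
seq_len = 10

# Old style: a Python list of per-step (batch, hidden) outputs, stitched
# together along a new time axis by hand.
step_outputs, _ = cell.unroll(seq_len, inputs=embed)
old = mx.sym.Concat(*[mx.sym.expand_dims(o, axis=1) for o in step_outputs], dim=1)

cell.reset()  # clear the cell's step counter before unrolling it again

# New style: one (batch, seq_len, hidden) symbol straight from unroll.
new, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True)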
46 changes: 44 additions & 2 deletions python/mxnet/initializer.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=too-many-branches
+# pylint: disable=too-many-branches, too-many-arguments
 """Initialization helper for mxnet"""
 from __future__ import absolute_import, print_function

@@ -413,7 +413,7 @@ def __init__(self, factor_type="avg", slope=0.25):

 @register
 class Bilinear(Initializer):
-    """docstring for Bilinear"""
+    """Initialize weight for upsampling layer"""
     def __init__(self):
         super(Bilinear, self).__init__()

@@ -428,3 +428,45 @@ def _init_weight(self, _, arr):
             y = (i / shape[3]) % shape[2]
             weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
         arr[:] = weight.reshape(shape)


+@register
+class FusedRNN(Initializer):
+    """Initialize parameters for a fused RNN layer.
+
+    Parameters
+    ----------
+    init : Initializer
+        initializer applied to the unpacked weights.
+    num_hidden : int
+        must match the value passed to FusedRNNCell.
+    num_layers : int
+        must match the value passed to FusedRNNCell.
+    mode : str
+        must match the value passed to FusedRNNCell.
+    bidirectional : bool
+        must match the value passed to FusedRNNCell.
+    """
+    def __init__(self, init, num_hidden, num_layers, mode, bidirectional=False):
+        if not isinstance(init, Initializer):
+            klass, kwargs = json.loads(init)
+            init = _INITIALIZER_REGISTRY[klass.lower()](**kwargs)
+        super(FusedRNN, self).__init__(init=init.dumps(), num_hidden=num_hidden,
+                                       num_layers=num_layers, mode=mode,
+                                       bidirectional=bidirectional)
+        self._num_hidden = num_hidden
+        self._num_layers = num_layers
+        self._bidirectional = bidirectional
+        self._mode = mode
+        self._init = init
+
+    def _init_weight(self, _, arr):
+        from .rnn import rnn_cell
+        cell = rnn_cell.FusedRNNCell(self._num_hidden, self._num_layers,
+                                     self._mode, self._bidirectional, prefix='')
+        args = cell.unpack_weights({'parameters': arr})
+        for name in args:
+            desc = InitDesc(name)
+            self._init(desc, args[name])
+        arr[:] = cell.pack_weights(args)['parameters']
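cuDNN stores every layer's weights and biases in one flat parameter blob, so a plain initializer would see a single huge 1-D array and, for instance, Xavier would compute its scale from the wrong shape. FusedRNN unpacks the blob into the per-layer, per-gate matrices, runs the wrapped initializer on each, and packs the result back in place. A hedged usage sketch (argument values are illustrative):

import mxnet as mx

# Wrap an ordinary initializer; it is applied to each unpacked weight and
# bias, then the pieces are packed back into the flat cuDNN parameter blob.
init = mx.init.FusedRNN(mx.init.Xavier(factor_type='in', magnitude=2.34),
                        num_hidden=200, num_layers=2,
                        mode='lstm', bidirectional=False)

# e.g. model.fit(..., initializer=init) for a model built on FusedRNNCell.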

53 changes: 38 additions & 15 deletions python/mxnet/rnn/io.py
@@ -79,9 +79,13 @@ class BucketSentenceIter(DataIter):
         name of data
     label_name : str, default 'softmax_label'
         name of label
+    layout : str
+        format of data and label. 'NT' means (batch_size, length)
+        and 'TN' means (length, batch_size).
     """
-    def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32',
-                 buckets=None, data_name='data', label_name='softmax_label'):
+    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
+                 data_name='data', label_name='softmax_label', dtype='float32',
+                 layout='NTC'):
         super(BucketSentenceIter, self).__init__()
         if not buckets:
             buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
@@ -90,7 +94,7 @@ def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,

         ndiscard = 0
         self.data = [[] for _ in buckets]
-        for i in xrange(len(sentences)):
+        for i in range(len(sentences)):
             buck = bisect.bisect_left(buckets, len(sentences[i]))
             if buck == len(buckets):
                 ndiscard += 1
@@ -103,43 +107,62 @@ def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,

print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

self.default_bucket_key = max(buckets)

self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]

self.batch_size = batch_size
self.buckets = buckets
self.data_name = data_name
self.label_name = label_name
self.dtype = dtype
self.invalid_label = invalid_label
self.nddata = []
self.ndlabel = []
self.major_axis = layout.find('N')
self.default_bucket_key = max(buckets)

if self.major_axis == 0:
self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
elif self.major_axis == 1:
self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
else:
raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

self.idx = []
for i, buck in enumerate(self.data):
self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
self.curr_idx = 0

self.reset()

def reset(self):
self.curr_idx = 0
random.shuffle(self.idx)
for buck in self.data:
np.random.shuffle(buck)

self.nddata = []
self.ndlabel = []
for buck in self.data:
label = np.empty_like(buck)
label[:, :-1] = buck[:, 1:]
label[:, -1] = self.invalid_label
self.nddata.append(ndarray.array(buck, dtype=self.dtype))
self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

def next(self):
if self.curr_idx == len(self.idx):
raise StopIteration
i, j = self.idx[self.curr_idx]
self.curr_idx += 1

data = self.data[i][j:j+self.batch_size]
label = np.empty_like(data)
label[:, :-1] = data[:, 1:]
label[:, -1] = self.invalid_label

if self.major_axis == 1:
data = self.nddata[i][j:j+self.batch_size].T
label = self.ndlabel[i][j:j+self.batch_size].T
else:
data = self.nddata[i][j:j+self.batch_size]
label = self.ndlabel[i][j:j+self.batch_size]

return DataBatch([ndarray.array(data, dtype=self.dtype)],
[ndarray.array(label, dtype=self.dtype)],
return DataBatch([data], [label],
bucket_key=self.buckets[i],
provide_data=[(self.data_name, data.shape)],
provide_label=[(self.label_name, label.shape)])
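With the new layout argument the same iterator can feed batch-major ('NT') symbols or the time-major ('TN') layout FusedRNNCell expects, and labels are now precomputed once per epoch in reset() as the data shifted left by one step. A small usage sketch on toy data, assuming the constructor as changed in this diff:

import mxnet as mx

# Toy corpus, already encoded as integer ids; 0 serves as the invalid/pad label.
sentences = [[1, 4, 9, 2], [1, 7, 2], [1, 3, 5, 8, 2]]
it = mx.rnn.BucketSentenceIter(sentences, batch_size=2, buckets=[5],
                               invalid_label=0, layout='TN')
for batch in it:
    # Time-major: data and label come out as (seq_len, batch_size) = (5, 2).
    print(batch.data[0].shape, batch.bucket_key)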