In [1]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from mxnet.io import DataIter

import operator
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [3]:
Args = namedtuple(
    'Args', 
    ('test load_epoch num_layers num_hidden num_embed bidirectional gpus '
     'kv_store num_epochs optimizer mom wd lr batch_size disp_batches '
     'stack_rnn dropout model_prefix'))

args = Args(
    test=          False,
    load_epoch=    0,
    num_layers=    2,
    num_hidden=    200,
    num_embed=     200,
    bidirectional= False,
    gpus=          '0,1',
    kv_store=      'device',
    num_epochs=    1,
    optimizer=    'adam',
    mom=           0.9,
    wd=           0.00001,
    lr = 0.001,
    batch_size=    32,
    disp_batches=  50,
    stack_rnn=     False,
    dropout=       0.5,
    model_prefix= 'foo'
)

In [4]:
# Do a word count to get the number of words 

start_label = 1
invalid_label = 0

# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

def get_data(src_path, targ_path, start_label=1, invalid_label=0):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    new_start = max(src_vocab.values())
    
    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=new_start+1,
                                          invalid_label=invalid_label)

    return src_sent, src_vocab, targ_sent, targ_vocab

In [5]:
src_sent, src_vocab, targ_sent, targ_vocab = \
    get_data(
        src_path='./data/europarl-v7.pl-en.small.en', # pl-en.small.en
        targ_path='./data/europarl-v7.pl-en.small.pl', # pl-en.small.pl
        start_label=1,
        invalid_label=0
    )

In [6]:
def gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None):
    length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
    counts = list(Counter(length_pairs).items())
    c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
    buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than]
    return buckets

In [7]:
all_buckets = gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None)

In [8]:
bucket_filter = 32
batched_buckets = gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=bucket_filter)

In [9]:
num_sent = lambda sent: reduce(lambda a, b: a + b, map(lambda x: x[1], sent))
num_all = num_sent(all_buckets)
num_batched = num_sent(batched_buckets)
print("# of all buckets: %d (# sent: %d)" % (len(all_buckets), num_all))
print("# of buckets with counts < %d filtered out: %d (num sent: %d)" % (bucket_filter, len(batched_buckets), num_batched))

# of all buckets: 4858 (# sent: 49773)
# of buckets with counts < 32 filtered out: 440 (num sent: 18984)


In [10]:
def print_sorted_buckets(buckets):
    b = sorted(buckets, key=operator.itemgetter(0, 1))
    for i in b:
        print(i)

print_sorted_buckets(batched_buckets)

((2, 2), 411)
((3, 2), 41)
((3, 3), 39)
((4, 2), 48)
((4, 4), 93)
((5, 5), 79)
((6, 5), 48)
((6, 6), 39)
((6, 7), 34)
((7, 6), 41)
((7, 7), 81)
((7, 8), 33)
((8, 7), 54)
((8, 9), 32)
((8, 10), 37)
((8, 11), 33)
((8, 12), 42)
((8, 13), 33)
((8, 16), 35)
((8, 20), 37)
((9, 7), 33)
((9, 8), 78)
((9, 9), 69)
((9, 10), 37)
((9, 11), 63)
((9, 13), 37)
((9, 14), 33)
((9, 16), 42)
((9, 28), 32)
((10, 8), 48)
((10, 9), 67)
((10, 10), 42)
((10, 11), 37)
((10, 12), 45)
((10, 13), 34)
((10, 15), 39)
((10, 17), 37)
((10, 18), 34)
((10, 19), 39)
((10, 22), 32)
((10, 24), 35)
((10, 25), 39)
((10, 28), 34)
((11, 9), 39)
((11, 10), 43)
((11, 11), 38)
((11, 12), 47)
((11, 13), 49)
((11, 14), 35)
((11, 15), 32)
((11, 16), 39)
((11, 17), 35)
((11, 18), 36)
((11, 19), 41)
((11, 20), 37)
((11, 21), 35)
((11, 22), 33)
((11, 24), 37)
((11, 25), 36)
((12, 9), 40)
((12, 10), 52)
((12, 11), 47)
((12, 12), 50)
((12, 13), 41)
((12, 14), 54)
((12, 15), 57)
((12, 16), 39)
((12, 17), 44)
((12, 18), 38)
((12, 19), 49)

In [11]:
class TwoDBisect:
    def __init__(self, buckets):
        self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
        self.x, self.y = zip(*buckets)
        self.x, self.y = np.array(list(self.x)), np.array(list(self.y))
        
    def twod_bisect(self, source, target):    
        offset1 = np.searchsorted(self.x, len(source), side='left')
        offset2 = np.where(self.y[offset1:] >= len(target))[0]        
        return self.buckets[offset2[0]]   

In [22]:
sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
max_len = max(max(sent_len(src_sent)), max(sent_len(targ_sent)))

# min_len = 5
# max_len = 65
increment = 5

all_pairs = [(i, j) for i in range(
        min_len,max_len+increment,increment
    ) for j in range(
        min_len,max_len+increment,increment
    )]


bisect = TwoDBisect(buckets=all_pairs)

In [24]:
all_pairs

[(1, 1),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 21),
 (1, 26),
 (1, 31),
 (1, 36),
 (1, 41),
 (1, 46),
 (1, 51),
 (1, 56),
 (1, 61),
 (1, 66),
 (1, 71),
 (1, 76),
 (1, 81),
 (1, 86),
 (1, 91),
 (1, 96),
 (1, 101),
 (1, 106),
 (1, 111),
 (1, 116),
 (1, 121),
 (1, 126),
 (1, 131),
 (1, 136),
 (1, 141),
 (1, 146),
 (1, 151),
 (1, 156),
 (1, 161),
 (1, 166),
 (1, 171),
 (1, 176),
 (1, 181),
 (1, 186),
 (1, 191),
 (1, 196),
 (1, 201),
 (1, 206),
 (1, 211),
 (1, 216),
 (1, 221),
 (1, 226),
 (1, 231),
 (1, 236),
 (1, 241),
 (1, 246),
 (1, 251),
 (1, 256),
 (1, 261),
 (1, 266),
 (1, 271),
 (1, 276),
 (1, 281),
 (1, 286),
 (1, 291),
 (1, 296),
 (1, 301),
 (1, 306),
 (1, 311),
 (1, 316),
 (1, 321),
 (1, 326),
 (1, 331),
 (1, 336),
 (1, 341),
 (6, 1),
 (6, 6),
 (6, 11),
 (6, 16),
 (6, 21),
 (6, 26),
 (6, 31),
 (6, 36),
 (6, 41),
 (6, 46),
 (6, 51),
 (6, 56),
 (6, 61),
 (6, 66),
 (6, 71),
 (6, 76),
 (6, 81),
 (6, 86),
 (6, 91),
 (6, 96),
 (6, 101),
 (6, 106),
 (6, 111),
 (6, 116),
 (6, 121),
 (6, 126),

In [26]:
tuples = []

for src, targ in zip(src_sent, targ_sent):
    try:
        len_tup = bisect.twod_bisect(src, targ)
        tuples.append(len_tup)
    except Exception as e:
        print("src length: %d" % len(src))
        print("targ length: %d" % len(targ))

In [27]:
tuples

[(1, 11),
 (1, 6),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 21),
 (1, 6),
 (1, 21),
 (1, 6),
 (1, 21),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 21),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 6),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 16),
 (1, 11),
 (1, 26),
 (1, 6),
 (1, 11),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 6),
 (1, 16),
 (1, 11),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 6),
 (1, 6),
 (1, 11),
 (1, 11),
 (1, 11),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 21),
 (1, 6),
 (1, 21),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 16),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6),
 (1, 11),
 (1, 6

In [28]:
counts = list(Counter(tuples).items())
batch_size = 32
filter_smaller_counts_than = batch_size

c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than]
print(buckets)

[((1, 1), 41), ((1, 6), 3627), ((1, 11), 6782), ((1, 16), 8231), ((1, 21), 7856), ((1, 26), 7043), ((1, 31), 5317), ((1, 36), 3692), ((1, 41), 2572), ((1, 46), 1641), ((1, 51), 1142), ((1, 56), 668), ((1, 61), 429), ((1, 66), 269), ((1, 71), 163), ((1, 76), 106), ((1, 81), 64), ((1, 86), 42)]


In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.

    Parameters
    ----------
    sentences : list of list of pairs of int
        encoded sentences (source seq dict int, target seq dict int)
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, input_seq, output_seq, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSeq2SeqIter, self).__init__()
        
        # insert call to gen_buckets here if buckets are not provided
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        # Sorting is kinda pointless because it's first by the first element of the tuple,
        # then the next. So, it could be [(1, 2), (1, 20), (2, 5), (2, 71)]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            
            # this bisect also won't work because it's now based on a tuple of lengths
            buck = bisect.bisect_left(buckets, len(sent))
            # this test is not appropriate because of the tuple of lengths
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

In [None]:
print(batched_buckets)

In [None]:
buckets = list(range(5, 60, 5))
#[10, 20, 30, 40, 50, 60]

                 data_name='data', label_name='softmax_label', dtype='float32',
start_label = 1
invalid_label = 0

def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    lines = open(fname).readlines()
    lines = [filter(None, i.split(' ')) for i in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

def get_data(layout):
    train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label,
                                      invalid_label=invalid_label)
    val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label,
                                invalid_label=invalid_label)

    data_train  = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    data_val    = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    
    print("default: %s" % data_train.default_bucket_key)
    return data_train, data_val, vocab


def train(args):
    data_train, data_val, vocab = get_data('TN')
    if args.stack_rnn:
        stack = mx.rnn.SequentialRNNCell()
        for layer in range(args.num_layers):
            dropout = 0.0
            if layer < (args.num_layers - 1):
                dropout = args.dropout
            stack.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
                    mode='lstm', prefix='lstm_%d'%layer, dropout=dropout,
                    bidirectional=args.bidirectional))
        cell = stack
    else:
        cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
                mode='lstm', bidirectional=args.bidirectional)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')

        output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC')

        pred = mx.sym.Reshape(output,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_train.default_bucket_key,
        context             = contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            cell, args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    opt_params = {
      'learning_rate': args.lr,
      'wd': args.wd
    }

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    model.fit(
        train_data          = data_train,
        eval_data           = data_val,
        eval_metric         = mx.metric.Perplexity(invalid_label),
        kvstore             = args.kv_store,
        optimizer           = args.optimizer,
        optimizer_params    = opt_params, 
        initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
        arg_params          = arg_params,
        aux_params          = aux_params,
        begin_epoch         = args.load_epoch,
        num_epoch           = args.num_epochs,
        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches),
        epoch_end_callback  = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                              if args.model_prefix else None)

def test(args):
    assert args.model_prefix, "Must specifiy path to load from"
    _, data_val, vocab = get_data('NT')

    if not args.stack_rnn:
        stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
                mode='lstm', bidirectional=args.bidirectionclass BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])al).unfuse()
    else:
        stack = mx.rnn.SequentialRNNCell()
        for i in range(args.num_layers):
            cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
            if args.bidirectional:
                cell = mx.rnn.BidirectionalCell(
                        cell,
                        mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
                        output_prefix='bi_lstm_%d'%i)
            stack.add(cell)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        stack.reset()
        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)class BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_val.default_bucket_key,
        context             = contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    # note here we load using SequentialRNNCell instead of FusedRNNCell.
    _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
    model.set_params(arg_params, aux_params)

    model.score(data_val, mx.metric.Perplexity(invalid_label),
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

In [None]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

# args = parser.parse_args()

if args.num_layers >= 4 and len(args.gpus.split(',')) >= 4 and not args.stack_rnn:
    print('WARNING: stack-rnn is recommended to train complex model on multiple GPUs')

if args.test:
    # Demonstrates how to load a model trained with CuDNN RNN and predict
    # with non-fused MXNet symbol
    test(args)
else:
    train(args)

In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.

    Parameters
    ----------
    sentences : list of list of pairs of int
        encoded sentences (source seq dict int, target seq dict int)
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, input_seq, output_seq, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])