In [21]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from mxnet.io import DataIter

import operator
import pickle
import re
import warnings

In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [3]:
Args = namedtuple(
    'Args', 
    ('test load_epoch num_layers num_hidden num_embed bidirectional gpus '
     'kv_store num_epochs optimizer mom wd lr batch_size disp_batches '
     'stack_rnn dropout model_prefix'))

args = Args(
    test=          False,
    load_epoch=    0,
    num_layers=    2,
    num_hidden=    200,
    num_embed=     200,
    bidirectional= False,
    gpus=          '0,1',
    kv_store=      'device',
    num_epochs=    1,
    optimizer=    'adam',
    mom=           0.9,
    wd=           0.00001,
    lr = 0.001,
    batch_size=    32,
    disp_batches=  50,
    stack_rnn=     False,
    dropout=       0.5,
    model_prefix= 'foo'
)

In [29]:
# Do a word count to get the number of words 

# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

Dataset = namedtuple('Dataset', ['src_sent', 'src_vocab', 'targ_sent', 'targ_vocab'])

def get_data(src_path, targ_path, start_label=1, invalid_label=0):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=start_label, #new_start+1,
                                          invalid_label=invalid_label)

    return Dataset(src_sent=src_sent, src_vocab=src_vocab, targ_sent=targ_sent, targ_vocab=targ_vocab)

def persist_dataset(dataset, path):
    with open(path, 'wb+') as fileobj:
        pickle.dump(dataset, fileobj)
        
def load_dataset(path):
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)

In [30]:
dataset = \
    get_data(
        src_path='./data/europarl-v7.es-en.en_vsmall',
        targ_path='./data/europarl-v7.es-en.es_vsmall',
        start_label=1,
        invalid_label=0
    )

In [72]:
dataset.src_vocab

{'deferment': 7460,
 'Pronk': 12313,
 'woods': 21082,
 'hanging': 6957,
 'disobeying': 16121,
 'goatmeat': 22477,
 'originality': 11245,
 'Western': 11441,
 'Euro': 7074,
 'Valle': 15683,
 'stipulate': 3125,
 'appropriation': 5000,
 'politician': 5222,
 'bringing': 3401,
 'ENVIRONMENT': 12253,
 'wooded': 21909,
 'Miert': 19260,
 'Multilateral': 18142,
 'stereotypical': 21504,
 'busybody': 14458,
 'immunities': 21967,
 '0059': 21052,
 '0058': 18117,
 'inevitably': 4627,
 '0053': 13295,
 '0052': 12073,
 '0051': 16024,
 '0050': 17159,
 '0057': 12116,
 '0056': 17489,
 '0055': 17163,
 '0054': 17484,
 'feasibility': 3595,
 '272': 20609,
 '275': 21277,
 '276': 19857,
 'sustaining': 7796,
 'consenting': 9696,
 'errors': 16315,
 'cooking': 12231,
 'warmongering': 17673,
 'designing': 10480,
 'College': 471,
 'succumb': 21868,
 'shocks': 12382,
 'Cover': 19861,
 '27o': 15478,
 'china': 9861,
 'affiliated': 7071,
 'doldrums': 22012,
 'kids': 20948,
 'controversy': 3229,
 'Isler': 12063,
 'Isles':

In [73]:
dataset.targ_vocab

{'Alertamos': 20901,
 'bloqueos': 16092,
 'Pronk': 18301,
 'Miembro': 24203,
 'acepcion': 27258,
 'prestadoras': 14446,
 'Deseamos': 13269,
 'impotente': 19448,
 'consensuada': 32601,
 'paternalista': 17478,
 'igual': 1491,
 'ciudades': 2216,
 'arriesgarnos': 36263,
 'Western': 37332,
 'hermana': 35183,
 'hermano': 14806,
 'Euro': 20513,
 'prorrogue': 27992,
 'gastara': 34787,
 'dictada': 19395,
 'especificacion': 36386,
 'Valle': 24183,
 'dictado': 12087,
 'democratacristianos': 13933,
 'compuesta': 8184,
 'compuesto': 6160,
 'Contendra': 32500,
 'declaraban': 31442,
 'ocurridas': 19298,
 'enmarcados': 15155,
 'Miert': 30460,
 'consentimiento': 14484,
 'inhumanidad': 26783,
 'minoristas': 34838,
 'ignoraba': 14760,
 'Multilateral': 28591,
 'fuera': 711,
 'extenderse': 11917,
 'fuere': 18935,
 'impropios': 21492,
 '0053': 19952,
 '0052': 17879,
 '0051': 24831,
 '0050': 26880,
 '0057': 17968,
 'asistir': 4834,
 '0055': 26884,
 '0054': 27398,
 '270': 21422,
 'Tratado': 2669,
 '275': 3405

In [31]:
path = './foo.pickle'
persist_dataset(dataset, path)
dataset2 = load_dataset(path)

In [78]:
def gen_buckets(dataset, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
    src_sent = dataset.src_sent
    targ_sent = dataset.targ_sent
    length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
    counts = list(Counter(length_pairs).items())
    c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
    buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than and 
               (max_sent_len is None or i[0][0] <= max_sent_len) and
               (max_sent_len is None or i[0][1] <= max_sent_len) and
               (min_sent_len is None or i[0][0] >= min_sent_len) and
               (min_sent_len is None or i[0][1] >= min_sent_len)]
    return buckets

In [154]:
max_len = 60

all_buckets = gen_buckets(dataset, filter_smaller_counts_than=None, max_sent_len=None, min_sent_len=None)
batch_size = 32
batched_buckets = gen_buckets(dataset, filter_smaller_counts_than=batch_size, max_sent_len=max_len, min_sent_len=1)

In [155]:
all_buckets

[((1, 1), 5),
 ((1, 4), 2),
 ((1, 5), 3),
 ((1, 6), 1),
 ((1, 7), 4),
 ((1, 8), 3),
 ((1, 9), 2),
 ((1, 10), 1),
 ((1, 11), 3),
 ((1, 12), 6),
 ((1, 13), 2),
 ((1, 14), 5),
 ((1, 15), 7),
 ((1, 16), 4),
 ((1, 17), 4),
 ((1, 18), 6),
 ((1, 19), 2),
 ((1, 20), 5),
 ((1, 21), 2),
 ((1, 22), 7),
 ((1, 23), 5),
 ((1, 24), 3),
 ((1, 25), 4),
 ((1, 26), 4),
 ((1, 27), 3),
 ((1, 28), 2),
 ((1, 29), 4),
 ((1, 30), 2),
 ((1, 31), 4),
 ((1, 32), 1),
 ((1, 33), 2),
 ((1, 34), 2),
 ((1, 35), 2),
 ((1, 36), 3),
 ((1, 37), 2),
 ((1, 38), 5),
 ((1, 39), 3),
 ((1, 40), 3),
 ((1, 42), 2),
 ((1, 43), 1),
 ((1, 44), 1),
 ((1, 45), 1),
 ((1, 46), 2),
 ((1, 48), 2),
 ((1, 49), 3),
 ((1, 50), 1),
 ((1, 51), 4),
 ((1, 52), 2),
 ((1, 53), 1),
 ((1, 54), 2),
 ((1, 55), 2),
 ((1, 56), 4),
 ((1, 58), 1),
 ((1, 60), 1),
 ((1, 61), 1),
 ((1, 63), 1),
 ((1, 64), 2),
 ((1, 70), 1),
 ((1, 73), 1),
 ((1, 74), 1),
 ((1, 75), 1),
 ((1, 83), 1),
 ((1, 84), 1),
 ((1, 87), 1),
 ((1, 89), 1),
 ((1, 90), 1),
 ((1, 99), 1),
 (

In [156]:
batched_buckets

[((8, 13), 33),
 ((8, 18), 36),
 ((8, 20), 32),
 ((8, 26), 32),
 ((9, 20), 34),
 ((9, 22), 35),
 ((9, 28), 35),
 ((10, 12), 32),
 ((10, 13), 35),
 ((10, 15), 32),
 ((10, 17), 33),
 ((10, 21), 37),
 ((10, 26), 37),
 ((11, 12), 38),
 ((11, 13), 34),
 ((11, 14), 37),
 ((11, 15), 34),
 ((11, 20), 34),
 ((11, 21), 32),
 ((11, 22), 34),
 ((11, 23), 33),
 ((11, 24), 34),
 ((11, 25), 41),
 ((12, 9), 32),
 ((12, 11), 34),
 ((12, 13), 38),
 ((12, 14), 36),
 ((12, 15), 52),
 ((12, 17), 38),
 ((12, 18), 32),
 ((12, 19), 43),
 ((12, 20), 39),
 ((12, 21), 38),
 ((12, 22), 33),
 ((12, 23), 38),
 ((12, 24), 42),
 ((12, 25), 33),
 ((12, 27), 37),
 ((12, 29), 35),
 ((13, 12), 34),
 ((13, 14), 33),
 ((13, 15), 34),
 ((13, 16), 36),
 ((13, 18), 35),
 ((13, 19), 35),
 ((13, 20), 41),
 ((13, 21), 43),
 ((13, 22), 41),
 ((13, 23), 38),
 ((13, 24), 33),
 ((13, 28), 34),
 ((14, 9), 36),
 ((14, 10), 34),
 ((14, 12), 35),
 ((14, 13), 41),
 ((14, 14), 38),
 ((14, 15), 34),
 ((14, 17), 43),
 ((14, 18), 38),
 ((14,

In [157]:
num_sent = lambda sent: reduce(lambda a, b: a + b, map(lambda x: x[1], sent))
num_all = num_sent(all_buckets)
num_batched = num_sent(batched_buckets)
print("# of all buckets: %d (# sent: %d)" % (len(all_buckets), num_all))
print("# of buckets with counts < %d filtered out: %d (num sent: %d)" % (batch_size, len(batched_buckets), num_batched))
print("percent of examples remaining after filtering: %.2f" % (100.0*num_batched/num_all))

# of all buckets: 6491 (# sent: 49866)
# of buckets with counts < 32 filtered out: 283 (num sent: 10654)
percent of examples remaining after filtering: 21.37


In [158]:
def print_sorted_buckets(buckets):
    b = sorted(buckets, key=operator.itemgetter(0, 1))
    for i in b:
        print(i)

print_sorted_buckets(batched_buckets)

((8, 13), 33)
((8, 18), 36)
((8, 20), 32)
((8, 26), 32)
((9, 20), 34)
((9, 22), 35)
((9, 28), 35)
((10, 12), 32)
((10, 13), 35)
((10, 15), 32)
((10, 17), 33)
((10, 21), 37)
((10, 26), 37)
((11, 12), 38)
((11, 13), 34)
((11, 14), 37)
((11, 15), 34)
((11, 20), 34)
((11, 21), 32)
((11, 22), 34)
((11, 23), 33)
((11, 24), 34)
((11, 25), 41)
((12, 9), 32)
((12, 11), 34)
((12, 13), 38)
((12, 14), 36)
((12, 15), 52)
((12, 17), 38)
((12, 18), 32)
((12, 19), 43)
((12, 20), 39)
((12, 21), 38)
((12, 22), 33)
((12, 23), 38)
((12, 24), 42)
((12, 25), 33)
((12, 27), 37)
((12, 29), 35)
((13, 12), 34)
((13, 14), 33)
((13, 15), 34)
((13, 16), 36)
((13, 18), 35)
((13, 19), 35)
((13, 20), 41)
((13, 21), 43)
((13, 22), 41)
((13, 23), 38)
((13, 24), 33)
((13, 28), 34)
((14, 9), 36)
((14, 10), 34)
((14, 12), 35)
((14, 13), 41)
((14, 14), 38)
((14, 15), 34)
((14, 17), 43)
((14, 18), 38)
((14, 19), 37)
((14, 20), 39)
((14, 21), 46)
((14, 22), 44)
((14, 24), 46)
((14, 25), 39)
((14, 27), 35)
((14, 28), 39)
((15

In [159]:
sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# max_len = max(max(sent_len(src_sent)), max(sent_len(targ_sent)))

# min_len = min
max_len = 65
increment = 5

all_pairs = [(i, j) for i in range(
        min_len,max_len+increment,increment
    ) for j in range(
        min_len,max_len+increment,increment
    )]

print(all_pairs)

[(1, 1), (1, 6), (1, 11), (1, 16), (1, 21), (1, 26), (1, 31), (1, 36), (1, 41), (1, 46), (1, 51), (1, 56), (1, 61), (1, 66), (6, 1), (6, 6), (6, 11), (6, 16), (6, 21), (6, 26), (6, 31), (6, 36), (6, 41), (6, 46), (6, 51), (6, 56), (6, 61), (6, 66), (11, 1), (11, 6), (11, 11), (11, 16), (11, 21), (11, 26), (11, 31), (11, 36), (11, 41), (11, 46), (11, 51), (11, 56), (11, 61), (11, 66), (16, 1), (16, 6), (16, 11), (16, 16), (16, 21), (16, 26), (16, 31), (16, 36), (16, 41), (16, 46), (16, 51), (16, 56), (16, 61), (16, 66), (21, 1), (21, 6), (21, 11), (21, 16), (21, 21), (21, 26), (21, 31), (21, 36), (21, 41), (21, 46), (21, 51), (21, 56), (21, 61), (21, 66), (26, 1), (26, 6), (26, 11), (26, 16), (26, 21), (26, 26), (26, 31), (26, 36), (26, 41), (26, 46), (26, 51), (26, 56), (26, 61), (26, 66), (31, 1), (31, 6), (31, 11), (31, 16), (31, 21), (31, 26), (31, 31), (31, 36), (31, 41), (31, 46), (31, 51), (31, 56), (31, 61), (31, 66), (36, 1), (36, 6), (36, 11), (36, 16), (36, 21), (36, 26), (36

In [160]:
class TwoDBisect:
    def __init__(self, buckets):
        self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
        self.x, self.y = zip(*buckets)
        self.x, self.y = np.array(list(self.x)), np.array(list(self.y))
        
    def twod_bisect(self, source, target):    
        offset1 = np.searchsorted(self.x, len(source), side='left')
        offset2 = np.where(self.y[offset1:] >= len(target))[0]        
        return self.buckets[offset1 + offset2[0]]  

In [161]:
bisect = TwoDBisect(all_pairs)
bisect.buckets

[(1, 1),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 21),
 (1, 26),
 (1, 31),
 (1, 36),
 (1, 41),
 (1, 46),
 (1, 51),
 (1, 56),
 (1, 61),
 (1, 66),
 (6, 1),
 (6, 6),
 (6, 11),
 (6, 16),
 (6, 21),
 (6, 26),
 (6, 31),
 (6, 36),
 (6, 41),
 (6, 46),
 (6, 51),
 (6, 56),
 (6, 61),
 (6, 66),
 (11, 1),
 (11, 6),
 (11, 11),
 (11, 16),
 (11, 21),
 (11, 26),
 (11, 31),
 (11, 36),
 (11, 41),
 (11, 46),
 (11, 51),
 (11, 56),
 (11, 61),
 (11, 66),
 (16, 1),
 (16, 6),
 (16, 11),
 (16, 16),
 (16, 21),
 (16, 26),
 (16, 31),
 (16, 36),
 (16, 41),
 (16, 46),
 (16, 51),
 (16, 56),
 (16, 61),
 (16, 66),
 (21, 1),
 (21, 6),
 (21, 11),
 (21, 16),
 (21, 21),
 (21, 26),
 (21, 31),
 (21, 36),
 (21, 41),
 (21, 46),
 (21, 51),
 (21, 56),
 (21, 61),
 (21, 66),
 (26, 1),
 (26, 6),
 (26, 11),
 (26, 16),
 (26, 21),
 (26, 26),
 (26, 31),
 (26, 36),
 (26, 41),
 (26, 46),
 (26, 51),
 (26, 56),
 (26, 61),
 (26, 66),
 (31, 1),
 (31, 6),
 (31, 11),
 (31, 16),
 (31, 21),
 (31, 26),
 (31, 31),
 (31, 36),
 (31, 41),
 (31, 46),
 (31, 51

In [196]:
tuples = []

src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

short_sentences = filter(lambda x: len(x[0]) <= max_len and len(x[1]) <= max_len, zip(src_sent, targ_sent))

for src, targ in short_sentences:
    try:
        len_tup = bisect.twod_bisect(src, targ)
        rev_src = src[::-1] 
        tuples.append((src, targ, len_tup))
    except Exception as e:
        print("src length: %d" % len(src))
        print("targ length: %d" % len(targ))

In [247]:
from itertools import groupby

In [258]:
sorted_tuples = sorted(tuples, key=operator.itemgetter(2))



In [279]:

# to prevent the iterator from running out

In [291]:
g = groupby(sorted_tuples, lambda x: x[2])

ctr = 0
num_ex = 100
invalid_symbol = -9999

for tup, grouper in g:
    print("Tuple: %s" % str(tup))
    for grp in grouper:
        source, target, _ = grp
        rev_src = source[::-1]
        print(tup[0])
        if len(rev_src) < tup[0]:
            rev_src_2 = [invalid_symbol] * tup[0]
            rev_src_2[:-(tup[0]+1):-1] = rev_src
        else:
            rev_src_2 = rev_src
        if len(target) < tup[1]:
            targ_2 = [invalid_symbol] * tup[1]
            targ_2[:tup[1]] = target
        else:
            targ_2 = target
        print(rev_src_2, targ_2)        
    print("")
    ctr += 1
    if ctr > num_ex:
        break

Tuple: (1, 1)
1
([11315], [16634])
1
([36], [30])
1
([36], [30])
1
([36], [30])
1
([36], [30])

Tuple: (1, 6)
1
([561], [606, 4, 47, 607])
1
([36], [98, 60, 1245, 1572, 30])
1
([36], [97, 12276, 11985, 100])
1
([36], [19189, 19, 315, 9, 30])
1
([36], [595, 43, 52, 237, 30])
1
([9799], [540, 3878, 41, 42, 16135, 30])

Tuple: (1, 11)
1
([4688], [6071, 360, 6072, 97, 853, 360, 854, 100, 30])
1
([36], [1812, 1085, 86, 2281, 2388, 11, 1316, 30])
1
([36], [540, 7487, 16720, 41, 1783, 2196, 30])
1
([70], [12070, 242, 24, 41, 420, 481, 24, 2555, 66, 1380, 30])
1
([36], [3485, 2, 994, 4, 70, 363, 535])
1
([36], [3483, 1778, 11, 157, 2140, 8428, 30])
1
([9799], [97, 98, 134, 3562, 70, 860, 19, 70, 7480, 100])
1
([9799], [574, 632, 4, 3407, 41, 4234, 388, 1987, 30])
1
([561], [5542, 419, 41, 7720, 467, 3429, 30])
1
([36], [595, 214, 1315, 24, 332, 689, 103, 3542, 4, 5967, 30])
1
([4688], [97, 98, 134, 6128, 1139, 70, 279, 1026, 237, 6129, 100])
1
([6293], [4840, 934, 214, 4580, 19, 835, 2845, 30]

ValueError: attempt to assign sequence of size 4 to extended slice of size 6

In [275]:
counts = list(Counter(tuples).items())
batch_size = 32
filter_smaller_counts_than = batch_size

c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than]
print(buckets)

TypeError: unhashable type: 'list'

In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.
    
    @staticmethod
    def 
    
    Parameters
    ----------
    source_seq : list of list of int
    target_seq : list of list of int
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, source, target, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC', min_sent_len = 1, max_sent_len = 60):
        super(BucketSeq2SeqIter, self).__init__()
        
        # insert call to gen_buckets here if buckets are not provided
        if not buckets:
            all_pairs = [(i, j) for i in range(
                min_len,max_len+increment,increment
            ) for j in range(
                min_len,max_len+increment,increment
            )]
            self.bisect = TwoDBisect(buckets = all_pairs)
            buckets = 
            
#             [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
#                        if j >= batch_size]
            
        # Sorting is kinda pointless because it's first by the first element of the tuple,
        # then the next. So, it could be [(1, 2), (1, 20), (2, 5), (2, 71)]
#         buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            
            # this bisect also won't work because it's now based on a tuple of lengths
            buck = bisect.bisect_left(buckets, len(sent))
            # this test is not appropriate because of the tuple of lengths
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

In [None]:
print(batched_buckets)

In [None]:
buckets = list(range(5, 60, 5))
#[10, 20, 30, 40, 50, 60]

                 data_name='data', label_name='softmax_label', dtype='float32',
start_label = 1
invalid_label = 0

def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    lines = open(fname).readlines()
    lines = [filter(None, i.split(' ')) for i in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

def get_data(layout):
    train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label,
                                      invalid_label=invalid_label)
    val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label,
                                invalid_label=invalid_label)

    data_train  = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    data_val    = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    
    print("default: %s" % data_train.default_bucket_key)
    return data_train, data_val, vocab


def train(args):
    data_train, data_val, vocab = get_data('TN')
    if args.stack_rnn:
        stack = mx.rnn.SequentialRNNCell()
        for layer in range(args.num_layers):
            dropout = 0.0
            if layer < (args.num_layers - 1):
                dropout = args.dropout
            stack.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
                    mode='lstm', prefix='lstm_%d'%layer, dropout=dropout,
                    bidirectional=args.bidirectional))
        cell = stack
    else:
        cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
                mode='lstm', bidirectional=args.bidirectional)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')

        output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC')

        pred = mx.sym.Reshape(output,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_train.default_bucket_key,
        context             = contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            cell, args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    opt_params = {
      'learning_rate': args.lr,
      'wd': args.wd
    }

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    model.fit(
        train_data          = data_train,
        eval_data           = data_val,
        eval_metric         = mx.metric.Perplexity(invalid_label),
        kvstore             = args.kv_store,
        optimizer           = args.optimizer,
        optimizer_params    = opt_params, 
        initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
        arg_params          = arg_params,
        aux_params          = aux_params,
        begin_epoch         = args.load_epoch,
        num_epoch           = args.num_epochs,
        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches),
        epoch_end_callback  = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                              if args.model_prefix else None)

def test(args):
    assert args.model_prefix, "Must specifiy path to load from"
    _, data_val, vocab = get_data('NT')

    if not args.stack_rnn:
        stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
                mode='lstm', bidirectional=args.bidirectionclass BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])al).unfuse()
    else:
        stack = mx.rnn.SequentialRNNCell()
        for i in range(args.num_layers):
            cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
            if args.bidirectional:
                cell = mx.rnn.BidirectionalCell(
                        cell,
                        mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
                        output_prefix='bi_lstm_%d'%i)
            stack.add(cell)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        stack.reset()
        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)
    
class BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_val.default_bucket_key,
        context             = contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    # note here we load using SequentialRNNCell instead of FusedRNNCell.
    _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
    model.set_params(arg_params, aux_params)

    model.score(data_val, mx.metric.Perplexity(invalid_label),
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

In [None]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

# args = parser.parse_args()

if args.num_layers >= 4 and len(args.gpus.split(',')) >= 4 and not args.stack_rnn:
    print('WARNING: stack-rnn is recommended to train complex model on multiple GPUs')

if args.test:
    # Demonstrates how to load a model trained with CuDNN RNN and predict
    # with non-fused MXNet symbol
    test(args)
else:
    train(args)

In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.

    Parameters
    ----------
    sentences : list of list of pairs of int
        encoded sentences (source seq dict int, target seq dict int)
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, input_seq, output_seq, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

In [None]:
class AttentionEncoderCell(BaseRNNCell):
   """Place holder cell that prepare input for attention decoders"""
   def __init__(self, prefix='encode_', params=None):
       super(AttentionEncoderCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return []

   def __call__(self, inputs, states):
       return inputs, states + [symbol.expand_dims(inputs, axis=1)]

   def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
       outputs = _normalize_sequence(length, inputs, layout, merge_outputs)
       if merge_outputs is True:
           states = outputs
       else:
           states = inputs

       # attention cell always use NTC layout for states
       states, _ = _normalize_sequence(None, states, 'NTC', True, layout)
       return outputs, [states]


def _attention_pooling(source, scores):
   # source: (batch_size, seq_len, encoder_num_hidden)
   # scores: (batch_size, seq_len, 1)
   probs = symbol.softmax(scores, axis=1)
   output = symbol.batch_dot(source, probs, transpose_a=True)
   return symbol.reshape(output, shape=(0, 0))


class BaseAttentionCell(BaseRNNCell):
   """Base class for attention cells"""
   def __init__(self, prefix='att_', params=None):
       super(BaseAttentionCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return [(0, 0, 0)]

   def __call__(self, inputs, states):
       raise NotImplementedError


class DotAttentionCell(BaseAttentionCell):
   """Dot attention"""
   def __init__(self, prefix='dotatt_', params=None):
       super(DotAttentionCell, self).__init__(prefix, params=params)

   def __call__(self, inputs, states):
       # inputs: (batch_size, decoder_num_hidden)
       # for dot attention decoder_num_hidden must equal encoder_num_hidden
       if len(states) > 1:
           states = [symbol.concat(*states, dim=1)]

       # source: (batch_size, seq_len, encoder_num_hidden)
       source = states[0]
       # (batch_size, decoder_num_hidden, 1)
       inputs = symbol.expand_dims(inputs, axis=2)
       # (batch_size, seq_len, 1)
       scores = symbol.batch_dot(source, inputs)
       # (batch_size, encoder_num_hidden)
       return _attention_pooling(source, scores), states

In [None]:
class AttentionEncoderCell(BaseRNNCell):
   """Place holder cell that prepare input for attention decoders"""
   def __init__(self, prefix='encode_', params=None):
       super(AttentionEncoderCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return []

   def __call__(self, inputs, states):
       return inputs, states + [symbol.expand_dims(inputs, axis=1)]

   def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
       outputs = _normalize_sequence(length, inputs, layout, merge_outputs)
       if merge_outputs is True:
           states = outputs
       else:
           states = inputs

       # attention cell always use NTC layout for states
       states, _ = _normalize_sequence(None, states, 'NTC', True, layout)
       return outputs, [states]


def _attention_pooling(source, scores):
   # source: (batch_size, seq_len, encoder_num_hidden)
   # scores: (batch_size, seq_len, 1)
   probs = symbol.softmax(scores, axis=1)
   output = symbol.batch_dot(source, probs, transpose_a=True)
   return symbol.reshape(output, shape=(0, 0))


class BaseAttentionCell(BaseRNNCell):
   """Base class for attention cells"""
   def __init__(self, prefix='att_', params=None):
       super(BaseAttentionCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return [(0, 0, 0)]

   def __call__(self, inputs, states):
       raise NotImplementedError


class DotAttentionCell(BaseAttentionCell):
   """Dot attention"""
   def __init__(self, prefix='dotatt_', params=None):
       super(DotAttentionCell, self).__init__(prefix, params=params)

   def __call__(self, inputs, states):
       # inputs: (batch_size, decoder_num_hidden)
       # for dot attention decoder_num_hidden must equal encoder_num_hidden
       if len(states) > 1:
           states = [symbol.concat(*states, dim=1)]

       # source: (batch_size, seq_len, encoder_num_hidden)
       source = states[0]
       # (batch_size, decoder_num_hidden, 1)
       inputs = symbol.expand_dims(inputs, axis=2)
       # (batch_size, seq_len, 1)
       scores = symbol.batch_dot(source, inputs)
       # (batch_size, encoder_num_hidden)
       return _attention_pooling(source, scores), states