In [148]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [3]:
Args = namedtuple(
    'Args', 
    ('test load_epoch num_layers num_hidden num_embed bidirectional gpus '
     'kv_store num_epochs optimizer mom wd lr batch_size disp_batches '
     'stack_rnn dropout model_prefix'))

args = Args(
    test=          False,
    load_epoch=    0,
    num_layers=    2,
    num_hidden=    200,
    num_embed=     200,
    bidirectional= False,
    gpus=          '0,1',
    kv_store=      'device',
    num_epochs=    1,
    optimizer=    'adam',
    mom=           0.9,
    wd=           0.00001,
    lr = 0.001,
    batch_size=    32,
    disp_batches=  50,
    stack_rnn=     False,
    dropout=       0.5,
    model_prefix= 'foo'
)

In [5]:
# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

Dataset = namedtuple('Dataset', ['src_sent', 'src_vocab', 'targ_sent', 'targ_vocab'])

def get_data(src_path, targ_path, start_label=1, invalid_label=0):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=start_label, #new_start+1,
                                          invalid_label=invalid_label)

    return Dataset(src_sent=src_sent, src_vocab=src_vocab, targ_sent=targ_sent, targ_vocab=targ_vocab)

def persist_dataset(dataset, path):
    with open(path, 'wb+') as fileobj:
        pickle.dump(dataset, fileobj)
        
def load_dataset(path):
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)

In [6]:
dataset = \
    get_data(
        src_path='./data/europarl-v7.es-en.en_small',
        targ_path='./data/europarl-v7.es-en.es_small',
        start_label=1,
        invalid_label=0
    )

In [7]:
path = './foo.pickle'
persist_dataset(dataset, path)

In [8]:
del dataset
dataset = load_dataset(path)

In [9]:
def gen_buckets(dataset, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
    src_sent = dataset.src_sent
    targ_sent = dataset.targ_sent
    length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
    counts = list(Counter(length_pairs).items())
    c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
    buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than and 
               (max_sent_len is None or i[0][0] <= max_sent_len) and
               (max_sent_len is None or i[0][1] <= max_sent_len) and
               (min_sent_len is None or i[0][0] >= min_sent_len) and
               (min_sent_len is None or i[0][1] >= min_sent_len)]
    return buckets

In [10]:
max_len = 60

all_buckets = gen_buckets(dataset, filter_smaller_counts_than=None, max_sent_len=None, min_sent_len=None)
batch_size = 32
batched_buckets = gen_buckets(dataset, filter_smaller_counts_than=batch_size, max_sent_len=max_len, min_sent_len=1)

In [11]:
all_buckets

[((1, 1), 13),
 ((1, 2), 1),
 ((1, 3), 8),
 ((1, 4), 11),
 ((1, 5), 18),
 ((1, 6), 26),
 ((1, 7), 26),
 ((1, 8), 43),
 ((1, 9), 37),
 ((1, 10), 38),
 ((1, 11), 43),
 ((1, 12), 54),
 ((1, 13), 37),
 ((1, 14), 49),
 ((1, 15), 43),
 ((1, 16), 46),
 ((1, 17), 46),
 ((1, 18), 65),
 ((1, 19), 49),
 ((1, 20), 69),
 ((1, 21), 51),
 ((1, 22), 53),
 ((1, 23), 56),
 ((1, 24), 50),
 ((1, 25), 53),
 ((1, 26), 47),
 ((1, 27), 59),
 ((1, 28), 39),
 ((1, 29), 51),
 ((1, 30), 44),
 ((1, 31), 38),
 ((1, 32), 35),
 ((1, 33), 43),
 ((1, 34), 34),
 ((1, 35), 42),
 ((1, 36), 34),
 ((1, 37), 28),
 ((1, 38), 28),
 ((1, 39), 34),
 ((1, 40), 28),
 ((1, 41), 28),
 ((1, 42), 27),
 ((1, 43), 25),
 ((1, 44), 17),
 ((1, 45), 19),
 ((1, 46), 14),
 ((1, 47), 20),
 ((1, 48), 21),
 ((1, 49), 18),
 ((1, 50), 9),
 ((1, 51), 15),
 ((1, 52), 10),
 ((1, 53), 11),
 ((1, 54), 11),
 ((1, 55), 16),
 ((1, 56), 7),
 ((1, 57), 9),
 ((1, 58), 11),
 ((1, 59), 8),
 ((1, 60), 11),
 ((1, 61), 6),
 ((1, 62), 6),
 ((1, 63), 9),
 ((1, 64),

In [12]:
batched_buckets

[((1, 8), 43),
 ((1, 9), 37),
 ((1, 10), 38),
 ((1, 11), 43),
 ((1, 12), 54),
 ((1, 13), 37),
 ((1, 14), 49),
 ((1, 15), 43),
 ((1, 16), 46),
 ((1, 17), 46),
 ((1, 18), 65),
 ((1, 19), 49),
 ((1, 20), 69),
 ((1, 21), 51),
 ((1, 22), 53),
 ((1, 23), 56),
 ((1, 24), 50),
 ((1, 25), 53),
 ((1, 26), 47),
 ((1, 27), 59),
 ((1, 28), 39),
 ((1, 29), 51),
 ((1, 30), 44),
 ((1, 31), 38),
 ((1, 32), 35),
 ((1, 33), 43),
 ((1, 34), 34),
 ((1, 35), 42),
 ((1, 36), 34),
 ((1, 39), 34),
 ((3, 11), 36),
 ((3, 14), 37),
 ((3, 18), 39),
 ((3, 19), 33),
 ((3, 20), 34),
 ((3, 25), 32),
 ((4, 8), 32),
 ((4, 10), 33),
 ((4, 11), 56),
 ((4, 12), 38),
 ((4, 13), 33),
 ((4, 14), 36),
 ((4, 15), 40),
 ((4, 16), 39),
 ((4, 17), 34),
 ((4, 18), 38),
 ((4, 19), 54),
 ((4, 20), 63),
 ((4, 21), 45),
 ((4, 22), 43),
 ((4, 23), 43),
 ((4, 24), 40),
 ((4, 25), 42),
 ((4, 26), 50),
 ((4, 27), 45),
 ((4, 28), 49),
 ((4, 29), 41),
 ((4, 30), 35),
 ((4, 31), 38),
 ((4, 33), 36),
 ((4, 38), 37),
 ((5, 5), 44),
 ((5, 6), 59

In [13]:
num_sent = lambda sent: reduce(lambda a, b: a + b, map(lambda x: x[1], sent))
num_all = num_sent(all_buckets)
num_batched = num_sent(batched_buckets)
print("# of all buckets: %d (# sent: %d)" % (len(all_buckets), num_all))
print("# of buckets with counts < %d filtered out: %d (num sent: %d)" % (batch_size, len(batched_buckets), num_batched))
print("percent of examples remaining after filtering: %.2f" % (100.0*num_batched/num_all))

# of all buckets: 11751 (# sent: 498378)
# of buckets with counts < 32 filtered out: 2964 (num sent: 432626)
percent of examples remaining after filtering: 86.81


In [14]:
def print_sorted_buckets(buckets):
    b = sorted(buckets, key=operator.itemgetter(0, 1))
    for i in b:
        print(i)

print_sorted_buckets(batched_buckets)

((1, 8), 43)
((1, 9), 37)
((1, 10), 38)
((1, 11), 43)
((1, 12), 54)
((1, 13), 37)
((1, 14), 49)
((1, 15), 43)
((1, 16), 46)
((1, 17), 46)
((1, 18), 65)
((1, 19), 49)
((1, 20), 69)
((1, 21), 51)
((1, 22), 53)
((1, 23), 56)
((1, 24), 50)
((1, 25), 53)
((1, 26), 47)
((1, 27), 59)
((1, 28), 39)
((1, 29), 51)
((1, 30), 44)
((1, 31), 38)
((1, 32), 35)
((1, 33), 43)
((1, 34), 34)
((1, 35), 42)
((1, 36), 34)
((1, 39), 34)
((3, 11), 36)
((3, 14), 37)
((3, 18), 39)
((3, 19), 33)
((3, 20), 34)
((3, 25), 32)
((4, 8), 32)
((4, 10), 33)
((4, 11), 56)
((4, 12), 38)
((4, 13), 33)
((4, 14), 36)
((4, 15), 40)
((4, 16), 39)
((4, 17), 34)
((4, 18), 38)
((4, 19), 54)
((4, 20), 63)
((4, 21), 45)
((4, 22), 43)
((4, 23), 43)
((4, 24), 40)
((4, 25), 42)
((4, 26), 50)
((4, 27), 45)
((4, 28), 49)
((4, 29), 41)
((4, 30), 35)
((4, 31), 38)
((4, 33), 36)
((4, 38), 37)
((5, 5), 44)
((5, 6), 59)
((5, 7), 61)
((5, 8), 79)
((5, 9), 81)
((5, 10), 89)
((5, 11), 102)
((5, 12), 111)
((5, 13), 103)
((5, 14), 125)
((5, 15), 

In [15]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# max_len = max(max(sent_len(src_sent)), max(sent_len(targ_sent)))

# min_len = min
max_len = 65
increment = 5

all_pairs = [(i, j) for i in range(
        min_len,max_len+increment,increment
    ) for j in range(
        min_len,max_len+increment,increment
    )]

print(all_pairs)

[(1, 1), (1, 6), (1, 11), (1, 16), (1, 21), (1, 26), (1, 31), (1, 36), (1, 41), (1, 46), (1, 51), (1, 56), (1, 61), (1, 66), (6, 1), (6, 6), (6, 11), (6, 16), (6, 21), (6, 26), (6, 31), (6, 36), (6, 41), (6, 46), (6, 51), (6, 56), (6, 61), (6, 66), (11, 1), (11, 6), (11, 11), (11, 16), (11, 21), (11, 26), (11, 31), (11, 36), (11, 41), (11, 46), (11, 51), (11, 56), (11, 61), (11, 66), (16, 1), (16, 6), (16, 11), (16, 16), (16, 21), (16, 26), (16, 31), (16, 36), (16, 41), (16, 46), (16, 51), (16, 56), (16, 61), (16, 66), (21, 1), (21, 6), (21, 11), (21, 16), (21, 21), (21, 26), (21, 31), (21, 36), (21, 41), (21, 46), (21, 51), (21, 56), (21, 61), (21, 66), (26, 1), (26, 6), (26, 11), (26, 16), (26, 21), (26, 26), (26, 31), (26, 36), (26, 41), (26, 46), (26, 51), (26, 56), (26, 61), (26, 66), (31, 1), (31, 6), (31, 11), (31, 16), (31, 21), (31, 26), (31, 31), (31, 36), (31, 41), (31, 46), (31, 51), (31, 56), (31, 61), (31, 66), (36, 1), (36, 6), (36, 11), (36, 16), (36, 21), (36, 26), (36

In [16]:
class TwoDBisect:
    def __init__(self, buckets):
        self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
        self.x, self.y = zip(*buckets)
        self.x, self.y = np.array(list(self.x)), np.array(list(self.y))
        
    def twod_bisect(self, source, target):    
        offset1 = np.searchsorted(self.x, len(source), side='left')
        offset2 = np.where(self.y[offset1:] >= len(target))[0]        
        return self.buckets[offset1 + offset2[0]]  

In [17]:
bisect = TwoDBisect(all_pairs)
bisect.buckets

[(1, 1),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 21),
 (1, 26),
 (1, 31),
 (1, 36),
 (1, 41),
 (1, 46),
 (1, 51),
 (1, 56),
 (1, 61),
 (1, 66),
 (6, 1),
 (6, 6),
 (6, 11),
 (6, 16),
 (6, 21),
 (6, 26),
 (6, 31),
 (6, 36),
 (6, 41),
 (6, 46),
 (6, 51),
 (6, 56),
 (6, 61),
 (6, 66),
 (11, 1),
 (11, 6),
 (11, 11),
 (11, 16),
 (11, 21),
 (11, 26),
 (11, 31),
 (11, 36),
 (11, 41),
 (11, 46),
 (11, 51),
 (11, 56),
 (11, 61),
 (11, 66),
 (16, 1),
 (16, 6),
 (16, 11),
 (16, 16),
 (16, 21),
 (16, 26),
 (16, 31),
 (16, 36),
 (16, 41),
 (16, 46),
 (16, 51),
 (16, 56),
 (16, 61),
 (16, 66),
 (21, 1),
 (21, 6),
 (21, 11),
 (21, 16),
 (21, 21),
 (21, 26),
 (21, 31),
 (21, 36),
 (21, 41),
 (21, 46),
 (21, 51),
 (21, 56),
 (21, 61),
 (21, 66),
 (26, 1),
 (26, 6),
 (26, 11),
 (26, 16),
 (26, 21),
 (26, 26),
 (26, 31),
 (26, 36),
 (26, 41),
 (26, 46),
 (26, 51),
 (26, 56),
 (26, 61),
 (26, 66),
 (31, 1),
 (31, 6),
 (31, 11),
 (31, 16),
 (31, 21),
 (31, 26),
 (31, 31),
 (31, 36),
 (31, 41),
 (31, 46),
 (31, 51

In [35]:
tuples = []

src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

short_sentences = filter(lambda x: len(x[0]) <= max_len and len(x[1]) <= max_len, zip(src_sent, targ_sent))
# def print_sorted_buckets(buckets):
#     b = sorted(buckets, key=operator.itemgetter(0, 1))
#     for i in b:
#         print(i)

# print_sorted_buckets(batched_buckets)
for src, targ in short_sentences:
    try:
        len_tup = bisect.twod_bisect(src, targ)
        rev_src = src[::-1] 
        tuples.append((src, targ, len_tup))
    except Exception as e:
        print("src length: %d" % len(src))
        print("targ length: %d" % len(targ))

In [36]:
sorted_tuples = sorted(tuples, key=operator.itemgetter(2))
sorted_tuples

[([11315], [16634], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([36], [30], (1, 1)),
 ([561], [606, 4, 47, 607], (1, 6)),
 ([36], [98, 60, 1245, 1572, 30], (1, 6)),
 ([36], [97, 12276, 11985, 100], (1, 6)),
 ([36], [19189, 19, 315, 9, 30], (1, 6)),
 ([36], [595, 43, 52, 237, 30], (1, 6)),
 ([9799], [540, 3878, 41, 42, 16135, 30], (1, 6)),
 ([36], [3629, 285, 7001, 120, 2609, 30], (1, 6)),
 ([36], [97, 799, 100], (1, 6)),
 ([36], [48827, 74, 8, 1837, 4940, 30], (1, 6)),
 ([36], [4641, 1013, 19, 648, 30], (1, 6)),
 ([6193], [98, 60, 1245, 1572, 30], (1, 6)),
 ([561], [345, 277, 11, 883, 3079, 30], (1, 6)),
 ([20197], [1112, 75, 37188, 4, 4963, 30], (1, 6)),
 ([36], [606, 4, 47, 607], (1, 6)),
 ([36], [98, 60, 1245, 1572, 30], (1, 6)),
 ([36], [213, 24, 214, 2213, 30], (1,

In [163]:
g = groupby(sorted_tuples, lambda x: x[2])
groups = {}

for i in g:
    lst = list(i[1])
    count = len(lst)
    if count < batch_size:
        continue
    groups[i[0]] = map(lambda x: list(x[:2]), lst)
    print("tuple: %s, count: %d" % (str(i[0]), count))

tuple: (1, 6), count: 64
tuple: (1, 11), count: 187
tuple: (1, 16), count: 229
tuple: (1, 21), count: 280
tuple: (1, 26), count: 259
tuple: (1, 31), count: 231
tuple: (1, 36), count: 188
tuple: (1, 41), count: 146
tuple: (1, 46), count: 102
tuple: (1, 51), count: 83
tuple: (1, 56), count: 55
tuple: (1, 61), count: 45
tuple: (6, 1), count: 63
tuple: (6, 6), count: 445
tuple: (6, 11), count: 1294
tuple: (6, 16), count: 1738
tuple: (6, 21), count: 1858
tuple: (6, 26), count: 1766
tuple: (6, 31), count: 1583
tuple: (6, 36), count: 1266
tuple: (6, 41), count: 1024
tuple: (6, 46), count: 828
tuple: (6, 51), count: 595
tuple: (6, 56), count: 442
tuple: (6, 61), count: 307
tuple: (6, 66), count: 180
tuple: (11, 1), count: 207
tuple: (11, 6), count: 1479
tuple: (11, 11), count: 4224
tuple: (11, 16), count: 5910
tuple: (11, 21), count: 6242
tuple: (11, 26), count: 5925
tuple: (11, 31), count: 5142
tuple: (11, 36), count: 4266
tuple: (11, 41), count: 3364
tuple: (11, 46), count: 2427
tuple: (11, 

In [92]:
groups[(1,11)]

[[[4688], [6071, 360, 6072, 97, 853, 360, 854, 100, 30]],
 [[36], [1812, 1085, 86, 2281, 2388, 11, 1316, 30]],
 [[36], [540, 7487, 16720, 41, 1783, 2196, 30]],
 [[70], [12070, 242, 24, 41, 420, 481, 24, 2555, 66, 1380, 30]],
 [[36], [3485, 2, 994, 4, 70, 363, 535]],
 [[36], [3483, 1778, 11, 157, 2140, 8428, 30]],
 [[9799], [97, 98, 134, 3562, 70, 860, 19, 70, 7480, 100]],
 [[9799], [574, 632, 4, 3407, 41, 4234, 388, 1987, 30]],
 [[561], [5542, 419, 41, 7720, 467, 3429, 30]],
 [[36], [595, 214, 1315, 24, 332, 689, 103, 3542, 4, 5967, 30]],
 [[4688], [97, 98, 134, 6128, 1139, 70, 279, 1026, 237, 6129, 100]],
 [[6293], [4840, 934, 214, 4580, 19, 835, 2845, 30]],
 [[13345], [9207, 17703, 74, 240, 288, 17, 34445, 30]],
 [[36], [2644, 11, 86, 2017, 5533, 650, 8, 10700, 1783, 2273, 30]],
 [[11315], [811, 4962, 41, 15934, 86, 1137, 30]],
 [[12536], [439, 78, 440, 11, 3676, 1224, 22392, 3676]],
 [[9799], [574, 978, 152, 6205, 19, 12354, 152, 6094, 30]],
 [[36], [3599, 4576, 35899, 97, 1023, 360

In [105]:
invalid_symbol = -1

groups2 = {}

for tup, grouper in groups.items():
    new_grouper = []
    new_src = np.full((len(grouper), tup[0]), invalid_symbol, dtype=np.int32)
    new_targ = np.full((len(grouper), tup[1]), invalid_symbol, dtype=np.int32)
    
    for idx, grp in enumerate(grouper):
        source = grp[0]
        target = grp[1]
        rev_src = source[::-1]
        new_src[idx, :-(len(rev_src)+1):-1] = rev_src
        new_targ[idx, :len(target)] = target
    new_grouper = (new_src, new_targ)
    groups2[tup] = new_grouper

In [134]:
groups2[(1,11)][1][[2, 1, 0]]

array([[  540,  7487, 16720,    41,  1783,  2196,    30,    -1,    -1,
           -1,    -1],
       [ 1812,  1085,    86,  2281,  2388,    11,  1316,    30,    -1,
           -1,    -1],
       [ 6071,   360,  6072,    97,   853,   360,   854,   100,    30,
           -1,    -1]], dtype=int32)

In [166]:


# foo = list(range(99))
# bar = list(chunks(foo, 5))
# print(bar)

def iterate_groups(groups, batch_size=32):
    
    def chunks(l, n):
        n = max(1, n)
        return (l[i:i+n] for i in xrange(0, len(l), n))
    
    for key, group in groups.items():
        num_examples = len(group[0])
        indices = list(xrange(num_examples))
        shuffle(indices)
        src_sent, targ_sent = group
        src_sent = src_sent[indices]
        targ_sent = targ_sent[indices]        
        for chunk in chunks(list(xrange(num_examples)), batch_size):
            yield [src_sent[chunk], targ_sent[chunk]]

In [167]:
gg = iterate_groups(groups2, batch_size=32)

In [170]:
cnt = 0

for g in gg:
    try:
        cnt += len(g[0])
    except StopIteration:
        print("End of iteration!")
        
print(cnt)

0


In [None]:
preprocessed_dataset = {
    'groups': groups,
    'src_vocab': 
    'targ_vocab'
}

In [85]:
dd.io.save('groups.h5', groups)

In [86]:
del groups

In [88]:
groups = dd.io.load('groups.h5')

In [None]:
# Add dictionary and inverse dictionary to the serialized dataset

In [89]:
groups[(11,11)]

[array([[   -1,    -1,    99, ...,     2,   102,    36],
        [   -1,    -1,    99, ...,     2,   102,    36],
        [   -1,    -1,    -1, ...,   320,   321,    36],
        ..., 
        [   -1,    -1,    94, ...,  1391, 10824,    36],
        [   -1,    70,  5713, ...,  2067,   545,    36],
        [   -1,    70,  4766, ...,     2,   335,    16]], dtype=int32),
 array([[ 101,  102,   11, ...,   -1,   -1,   -1],
        [ 101,  102,   11, ...,   -1,   -1,   -1],
        [ 325,    8,  321, ...,   -1,   -1,   -1],
        ..., 
        [1448,  134,   11, ...,   30,   -1,   -1],
        [5874,   52,   59, ...,   -1,   -1,   -1],
        [ 345,  277,   11, ..., 4679,   30,   -1]], dtype=int32)]

In [82]:
# print some inversed source sentences and target sentences.

'VOTE'

In [None]:
# # ctr = 0
# # num_ex = 10
# invalid_symbol = -9999

# for tup, grouper in groups.items():
# #     print("Tuple: %s" % str(tup))
#     new_grouper = []
#     new_src = np.full((len(grouper), tup[0], 2), invalid_symbol)
#     new_targ = np.full((len(grouper), tup[1], 2), invalid_symbol)
    
#     for grp in grouper:
#         source, target = grp
#         rev_src = source[::-1]
#         if len(rev_src) < tup[0]:
#             rev_src_2 = [invalid_symbol] * tup[0]
#             rev_src_2[:-(len(rev_src)+1):-1] = rev_src
#         else:
#             rev_src_2 = rev_src
#         if len(target) < tup[1]:
#             targ_2 = [invalid_symbol] * tup[1]
#             targ_2[:len(target)] = target
#         else:
#             targ_2 = target
#         new_grouper.append([rev_src_2, targ_2])
#         groups[tup] = new_grouper #np.array(new_grouper)

In [None]:
grp = groups[(11,11)]
print("shape: %s" % str(np.shape(grp)))
np.shape(grp[0,:,:])

[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29], [30, 31, 32, 33, 34], [35, 36, 37, 38, 39], [40, 41, 42, 43, 44], [45, 46, 47, 48, 49], [50, 51, 52, 53, 54], [55, 56, 57, 58, 59], [60, 61, 62, 63, 64], [65, 66, 67, 68, 69], [70, 71, 72, 73, 74], [75, 76, 77, 78, 79], [80, 81, 82, 83, 84], [85, 86, 87, 88, 89], [90, 91, 92, 93, 94], [95, 96, 97, 98]]


In [None]:
ctr = 0
num_cases = 2

for tup, grouper in groups.items():
    print("tuple: %s" % str(tup))
    print(grouper)
    print("")
    ctr += 1
    if ctr >= num_cases:
        break

In [None]:
counts = list(Counter(tuples).items())
batch_size = 32
filter_smaller_counts_than = batch_size

c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
buckets = [i for i in c_sorted if i[1] >= filter_smaller_counts_than]
print(buckets)

In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.
    
    @staticmethod
    def 
    
    Parameters
    ----------
    source_seq : list of list of int
    target_seq : list of list of int
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, source, target, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC', min_sent_len = 1, max_sent_len = 60):
        super(BucketSeq2SeqIter, self).__init__()
        
        # insert call to gen_buckets here if buckets are not provided
        if not buckets:
            all_pairs = [(i, j) for i in range(
                min_len,max_len+increment,increment
            ) for j in range(
                min_len,max_len+increment,increment
            )]
            self.bisect = TwoDBisect(buckets = all_pairs)
            buckets = 
            
#             [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
#                        if j >= batch_size]
            
        # Sorting is kinda pointless because it's first by the first element of the tuple,
        # then the next. So, it could be [(1, 2), (1, 20), (2, 5), (2, 71)]
#         buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            
            # this bisect also won't work because it's now based on a tuple of lengths
            buck = bisect.bisect_left(buckets, len(sent))
            # this test is not appropriate because of the tuple of lengths
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

In [None]:
print(batched_buckets)

In [None]:
buckets = list(range(5, 60, 5))
#[10, 20, 30, 40, 50, 60]

                 data_name='data', label_name='softmax_label', dtype='float32',
start_label = 1
invalid_label = 0

def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    lines = open(fname).readlines()
    lines = [filter(None, i.split(' ')) for i in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

def get_data(layout):
    train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label,
                                      invalid_label=invalid_label)
    val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label,
                                invalid_label=invalid_label)

    data_train  = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    data_val    = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets,
                                            invalid_label=invalid_label, layout=layout)
    
    print("default: %s" % data_train.default_bucket_key)
    return data_train, data_val, vocab


def train(args):
    data_train, data_val, vocab = get_data('TN')
    if args.stack_rnn:
        stack = mx.rnn.SequentialRNNCell()
        for layer in range(args.num_layers):
            dropout = 0.0
            if layer < (args.num_layers - 1):
                dropout = args.dropout
            stack.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
                    mode='lstm', prefix='lstm_%d'%layer, dropout=dropout,
                    bidirectional=args.bidirectional))
        cell = stack
    else:
        cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
                mode='lstm', bidirectional=args.bidirectional)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')

        output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC')

        pred = mx.sym.Reshape(output,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_train.default_bucket_key,
        context             = contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            cell, args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    opt_params = {
      'learning_rate': args.lr,
      'wd': args.wd
    }

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    model.fit(
        train_data          = data_train,
        eval_data           = data_val,
        eval_metric         = mx.metric.Perplexity(invalid_label),
        kvstore             = args.kv_store,
        optimizer           = args.optimizer,
        optimizer_params    = opt_params, 
        initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
        arg_params          = arg_params,
        aux_params          = aux_params,
        begin_epoch         = args.load_epoch,
        num_epoch           = args.num_epochs,
        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches),
        epoch_end_callback  = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                              if args.model_prefix else None)

def test(args):
    assert args.model_prefix, "Must specifiy path to load from"
    _, data_val, vocab = get_data('NT')

    if not args.stack_rnn:
        stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
                mode='lstm', bidirectional=args.bidirectionclass BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])al).unfuse()
    else:
        stack = mx.rnn.SequentialRNNCell()
        for i in range(args.num_layers):
            cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
            if args.bidirectional:
                cell = mx.rnn.BidirectionalCell(
                        cell,
                        mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
                        output_prefix='bi_lstm_%d'%i)
            stack.add(cell)

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        stack.reset()
        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs,
                shape=(-1, args.num_hidden*(1+args.bidirectional)))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)
    
class BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.

    Parameters
    ----------
    sentences : list of list of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of int
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_val.default_bucket_key,
        context             = contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    # note here we load using SequentialRNNCell instead of FusedRNNCell.
    _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
    model.set_params(arg_params, aux_params)

    model.score(data_val, mx.metric.Perplexity(invalid_label),
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

In [None]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

# args = parser.parse_args()

if args.num_layers >= 4 and len(args.gpus.split(',')) >= 4 and not args.stack_rnn:
    print('WARNING: stack-rnn is recommended to train complex model on multiple GPUs')

if args.test:
    # Demonstrates how to load a model trained with CuDNN RNN and predict
    # with non-fused MXNet symbol
    test(args)
else:
    train(args)

In [None]:
class BucketSeq2SeqIter(DataIter):
    """Simple bucketing iterator for sequence-to-sequence models.

    Parameters
    ----------
    sentences : list of list of pairs of int
        encoded sentences (source seq dict int, target seq dict int)
    batch_size : int
        batch_size of data    
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : list of pairs of int (source seq length, target seq length)
        size of data buckets. Automatically generated if None.
    data_name : str, default 'data'
        name of data    print(buckets)
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
    """
    def __init__(self, input_seq, output_seq, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NTC'):
        super(BucketSentenceIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        buckets.sort()

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(sentences):
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]

        print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        if self.major_axis == 0:
            self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        if self.major_axis == 1:
            data = self.nddata[i][j:j+self.batch_size].T
            label = self.ndlabel[i][j:j+self.batch_size].T
        else:
            data = self.nddata[i][j:j+self.batch_size]
            label = self.ndlabel[i][j:j+self.batch_size]

        return DataBatch([data], [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.data_name, data.shape)],
                         provide_label=[(self.label_name, label.shape)])

In [None]:
class AttentionEncoderCell(BaseRNNCell):
   """Place holder cell that prepare input for attention decoders"""
   def __init__(self, prefix='encode_', params=None):
       super(AttentionEncoderCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return []

   def __call__(self, inputs, states):
       return inputs, states + [symbol.expand_dims(inputs, axis=1)]

   def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
       outputs = _normalize_sequence(length, inputs, layout, merge_outputs)
       if merge_outputs is True:
           states = outputs
       else:
           states = inputs

       # attention cell always use NTC layout for states
       states, _ = _normalize_sequence(None, states, 'NTC', True, layout)
       return outputs, [states]


def _attention_pooling(source, scores):
   # source: (batch_size, seq_len, encoder_num_hidden)
   # scores: (batch_size, seq_len, 1)
   probs = symbol.softmax(scores, axis=1)
   output = symbol.batch_dot(source, probs, transpose_a=True)
   return symbol.reshape(output, shape=(0, 0))


class BaseAttentionCell(BaseRNNCell):
   """Base class for attention cells"""
   def __init__(self, prefix='att_', params=None):
       super(BaseAttentionCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return [(0, 0, 0)]

   def __call__(self, inputs, states):
       raise NotImplementedError


class DotAttentionCell(BaseAttentionCell):
   """Dot attention"""
   def __init__(self, prefix='dotatt_', params=None):
       super(DotAttentionCell, self).__init__(prefix, params=params)

   def __call__(self, inputs, states):
       # inputs: (batch_size, decoder_num_hidden)
       # for dot attention decoder_num_hidden must equal encoder_num_hidden
       if len(states) > 1:
           states = [symbol.concat(*states, dim=1)]

       # source: (batch_size, seq_len, encoder_num_hidden)
       source = states[0]
       # (batch_size, decoder_num_hidden, 1)
       inputs = symbol.expand_dims(inputs, axis=2)
       # (batch_size, seq_len, 1)
       scores = symbol.batch_dot(source, inputs)
       # (batch_size, encoder_num_hidden)
       return _attention_pooling(source, scores), states

In [None]:
class AttentionEncoderCell(BaseRNNCell):
   """Place holder cell that prepare input for attention decoders"""
   def __init__(self, prefix='encode_', params=None):
       super(AttentionEncoderCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return []

   def __call__(self, inputs, states):
       return inputs, states + [symbol.expand_dims(inputs, axis=1)]

   def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
       outputs = _normalize_sequence(length, inputs, layout, merge_outputs)
       if merge_outputs is True:
           states = outputs
       else:
           states = inputs

       # attention cell always use NTC layout for states
       states, _ = _normalize_sequence(None, states, 'NTC', True, layout)
       return outputs, [states]


def _attention_pooling(source, scores):
   # source: (batch_size, seq_len, encoder_num_hidden)
   # scores: (batch_size, seq_len, 1)
   probs = symbol.softmax(scores, axis=1)
   output = symbol.batch_dot(source, probs, transpose_a=True)
   return symbol.reshape(output, shape=(0, 0))


class BaseAttentionCell(BaseRNNCell):
   """Base class for attention cells"""
   def __init__(self, prefix='att_', params=None):
       super(BaseAttentionCell, self).__init__(prefix, params=params)

   @property
   def state_shape(self):
       return [(0, 0, 0)]

   def __call__(self, inputs, states):
       raise NotImplementedError


class DotAttentionCell(BaseAttentionCell):
   """Dot attention"""
   def __init__(self, prefix='dotatt_', params=None):
       super(DotAttentionCell, self).__init__(prefix, params=params)

   def __call__(self, inputs, states):
       # inputs: (batch_size, decoder_num_hidden)
       # for dot attention decoder_num_hidden must equal encoder_num_hidden
       if len(states) > 1:
           states = [symbol.concat(*states, dim=1)]

       # source: (batch_size, seq_len, encoder_num_hidden)
       source = states[0]
       # (batch_size, decoder_num_hidden, 1)
       inputs = symbol.expand_dims(inputs, axis=2)
       # (batch_size, seq_len, 1)
       scores = symbol.batch_dot(source, inputs)
       # (batch_size, encoder_num_hidden)
       return _attention_pooling(source, scores), states