In [3]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [4]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [5]:
# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

Dataset = namedtuple(
    'Dataset', 
    ['src_sent', 'src_vocab', 'inv_src_vocab', 'targ_sent', 'targ_vocab', 'inv_targ_vocab'])

def invert_dict(d):
    return {v: k for k, v in d.iteritems()}


def get_data(src_path, targ_path, start_label=1, invalid_label=0, pad_symbol='<PAD>'):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    src_vocab[pad_symbol] = invalid_label
    inv_src_vocab = invert_dict(src_vocab)

    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=start_label, #new_start+1,
                                          invalid_label=invalid_label)
    
    targ_vocab[pad_symbol] = invalid_label
    inv_targ_vocab = invert_dict(targ_vocab)
    
    return Dataset(
        src_sent=src_sent, src_vocab=src_vocab, inv_src_vocab=inv_src_vocab,
        targ_sent=targ_sent, targ_vocab=targ_vocab, inv_targ_vocab=inv_targ_vocab)

In [6]:
def persist_dataset(dataset, path):
    with open(path, 'wb+') as fileobj:
        pickle.dump(dataset, fileobj)
        
def load_dataset(path):
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)

In [7]:
dataset = \
    get_data(
        src_path='./data/europarl-v7.es-en.en_small',
        targ_path='./data/europarl-v7.es-en.es_small',
        start_label=1,
        invalid_label=0
    )

In [49]:
class TwoDBisect:
    def __init__(self, buckets):
        self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
        self.x, self.y = zip(*buckets)
        self.x, self.y = np.array(list(self.x)), np.array(list(self.y))

    def twod_bisect(self, source, target):    
        offset1 = np.searchsorted(self.x, len(source), side='left')
        offset2 = np.where(self.y[offset1:] >= len(target))[0]        
        return self.buckets[offset1 + offset2[0]] 

class Seq2SeqIterator:    
    
    def __init__(self, dataset, buckets=None, batch_size=32, max_sent_len=None):
        self.src_sent = dataset.src_sent
        self.targ_sent = dataset.targ_sent
        # make this default to the maximum of the ???
        if buckets:
            z = zip(*buckets)
            self.max_sent_len = max(max(z[0]), max(z[1]))
        else:
            self.max_sent_len = max_sent_len
        if self.max_sent_len:
            self.src_sent, self.targ_sent = self.filter_long_sent(self.src_sent, self.targ_sent, self.max_sent_len) 
        self.src_vocab = dataset.src_vocab
        self.targ_vocab = dataset.targ_vocab
        self.inv_src_vocab = dataset.inv_src_vocab
        self.inv_targ_vocab = dataset.inv_targ_vocab
        # Can't filter smaller counts per bucket if those sentences still exist!
        self.buckets = buckets if buckets else self.gen_buckets(
            self.src_sent, self.targ_sent, filter_smaller_counts_than=1, max_sent_len=max_sent_len)
        self.bisect = TwoDBisect(self.buckets)
        self.max_sent_len = max_sent_len
    
    def group_lengths(self):
        tuples = []
        ctr = 0
        for src, targ in zip(self.src_sent, self.targ_sent):
            try:
                len_tup = self.bisect.twod_bisect(src, targ)
                rev_src = src[::-1] 
                tuples.append((src, targ, len_tup))
            except Exception as e:
                print(e)
                print(src)
                print(targ)
                raise e
        sorted_tuples = sorted(tuples, key=operator.itemgetter(2))
        g = groupby(sorted_tuples, lambda x: x[2])
        groups = {}
        for i in g:
            groups[i[0]] = map(lambda x: list(x[:2]), i[1])
        print(groups[(11,11)])

        
    @staticmethod 
    def filter_long_sent(src_sent, targ_sent, max_len):
        result = filter(lambda x: len(x[0]) <= max_len and len(x[1]) <= max_len, zip(src_sent, targ_sent))
        return zip(*result)

    @staticmethod
    def gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
        length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
        counts = list(Counter(length_pairs).items())
        c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
        buckets = [i[0] for i in c_sorted if i[1] >= filter_smaller_counts_than and 
                   (max_sent_len is None or i[0][0] <= max_sent_len) and
                   (max_sent_len is None or i[0][1] <= max_sent_len) and
                   (min_sent_len is None or i[0][0] >= min_sent_len) and
                   (min_sent_len is None or i[0][1] >= min_sent_len)]
        return buckets

#    def reset(self):
#         self.curr_idx = 0
#         random.shuffle(self.idx)
#         for buck in self.data:
#             np.random.shuffle(buck)

#         self.nddata = []
#         self.ndlabel = []
#         for buck in self.data:
#             label = np.empty_like(buck)
#             label[:, :-1] = buck[:, 1:]
#             label[:, -1] = self.invalid_label
#             self.nddata.append(ndarray.array(buck, dtype=self.dtype))
#             self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

#     def next(self):
#         if self.curr_idx == len(self.idx):
#             raise StopIteration
#         i, j = self.idx[self.curr_idx]
#         self.curr_idx += 1

#         if self.major_axis == 1:
#             data = self.nddata[i][j:j+self.batch_size].T
#             label = self.ndlabel[i][j:j+self.batch_size].T
#         else:
#             data = self.nddata[i][j:j+self.batch_size]
#             label = self.ndlabel[i][j:j+self.batch_size]

#         return DataBatch([data], [label], pad=0,
#                          bucket_key=self.buckets[i],
#                          provide_data=[(self.data_name, data.shape)],
#                          provide_label=[(self.label_name, label.shape)])

In [50]:
i1 = Seq2SeqIterator(dataset)

In [51]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# max_len = max(max(sent_len(src_sent)), max(sent_len(targ_sent)))

# min_len = min
max_len = 65
increment = 5

all_pairs = [(i, j) for i in range(
        min_len,max_len+increment,increment
    ) for j in range(
        min_len,max_len+increment,increment
    )]

In [52]:
i2 = Seq2SeqIterator(dataset, buckets=all_pairs, batch_size=32, max_sent_len=None)

In [53]:
i2.group_lengths()

[[[99, 100, 16, 11, 25, 101, 2, 102, 36], [101, 102, 11, 103, 104, 4, 105, 30]], [[99, 100, 16, 11, 25, 101, 2, 102, 36], [101, 102, 11, 103, 104, 4, 105, 30]], [[205, 213, 3, 316, 2, 320, 321, 36], [325, 8, 321, 4, 326, 327, 30]], [[436, 325, 108, 401, 448, 455, 154], [129, 439, 24, 41, 420, 485, 66, 321, 4, 480, 129]], [[120, 508, 509, 16, 302, 165, 468, 80, 62, 36], [236, 11, 552, 19, 553, 143, 539, 30]], [[259, 144, 230, 223, 228, 213, 549, 36], [595, 596, 78, 24, 152, 597, 587, 30]], [[93, 9, 761, 3, 264, 98, 100, 36], [97, 98, 9, 848, 70, 264, 100, 98, 134, 30]], [[5, 18, 19, 22, 1233, 225, 564, 101, 36], [540, 1391, 915, 11, 419, 1245, 8, 1324, 30]], [[1536, 302, 325, 802, 520, 213, 1537, 25, 1538, 36], [310, 1029, 214, 103, 1773, 19, 301, 30]], [[2848, 2, 3, 876, 2, 3, 496, 763], [98, 994, 4, 70, 363, 535, 43, 52, 3486, 30]], [[3542, 3543, 325, 968, 649, 3544, 245, 234, 62, 1501, 36], [45, 8, 594, 4446, 68, 3418, 30]], [[1600, 62, 1587, 16, 2595, 119, 1871, 36], [4511, 2589, 42