In [1]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder.
# Note that the filter is not scoped to the encoder: it ignores every warning
# whose category is DeprecationWarning.
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [3]:
# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    """Read a text file and encode every non-empty line as a list of token ids.

    The file is decoded as UTF-8, transliterated with ``unidecode``, split on
    newlines, tokenized with a regular expression and finally handed to
    ``mx.rnn.encode_sentences``.

    Args:
        fname: Path of the text file to read.
        vocab: Optional existing vocabulary, forwarded unchanged to
            ``mx.rnn.encode_sentences``.
        invalid_label: Forwarded unchanged to ``mx.rnn.encode_sentences``.
        start_label: Forwarded unchanged to ``mx.rnn.encode_sentences``.
        sep_punctuation: When True (the default) each punctuation character
            becomes a token of its own. When False, tokens are runs of
            non-whitespace characters, so punctuation stays attached to the
            neighbouring word.

    Returns:
        The ``(sentences, vocab)`` pair returned by ``mx.rnn.encode_sentences``.
    """
    # Local import keeps this function self-contained; io.open decodes while
    # reading on both Python 2 and Python 3.
    import io

    # The context manager closes the handle, which the previous bare
    # open(...).read() never did.
    with io.open(fname, encoding='utf-8') as fileobj:
        text = unidecode(fileobj.read())

    token_pattern = r"\w+|[^\w\s]" if sep_punctuation else r"\S+"
    # A real list (not a lazy map) so the result can be iterated more than once.
    lines = [re.findall(token_pattern, line, re.UNICODE)
             for line in text.split('\n') if line]
    sentences, vocab = mx.rnn.encode_sentences(
        lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

# Record holding both sides of a parallel corpus. For each side (src / targ)
# it carries the sentences, the vocabulary and the inverted vocabulary.
Dataset = namedtuple(
    'Dataset',
    'src_sent src_vocab inv_src_vocab targ_sent targ_vocab inv_targ_vocab')

def invert_dict(d):
    """Return a new dict that maps each value of ``d`` back to its key.

    If several keys share the same value, only one of them survives in the
    result (whichever is visited last during iteration).

    Args:
        d: Mapping whose values are hashable.

    Returns:
        A dict ``{value: key}``.
    """
    # items() exists on both Python 2 and Python 3; iteritems() is Python 2 only.
    return {v: k for k, v in d.items()}


def get_data(src_path, targ_path, start_label=1, invalid_label=0, pad_symbol='<PAD>'):
    """Tokenize a source/target file pair and bundle the result in a ``Dataset``.

    Each file is encoded independently with ``tokenize_text``, so the two
    sides get separate vocabularies. ``pad_symbol`` is registered in both
    vocabularies under the id ``invalid_label``.

    Args:
        src_path: Path of the source-language text file.
        targ_path: Path of the target-language text file.
        start_label: Forwarded to ``tokenize_text``.
        invalid_label: Forwarded to ``tokenize_text``; also the id given to
            ``pad_symbol``.
        pad_symbol: Token stored in each vocabulary for the padding id.

    Returns:
        A ``Dataset`` with sentences, vocabulary and inverted vocabulary for
        both sides.
    """
    def _load_side(path):
        # Encode one file and attach the padding symbol to its vocabulary.
        sentences, vocab = tokenize_text(
            path, start_label=start_label, invalid_label=invalid_label)
        vocab[pad_symbol] = invalid_label
        inverse = invert_dict(vocab)
        # If another key already maps to invalid_label (for example a
        # placeholder inserted by the encoder), plain inversion keeps an
        # arbitrary one of them. Pin the padding id to pad_symbol so the
        # reverse lookup is deterministic.
        inverse[invalid_label] = pad_symbol
        return sentences, vocab, inverse

    src_sent, src_vocab, inv_src_vocab = _load_side(src_path)
    targ_sent, targ_vocab, inv_targ_vocab = _load_side(targ_path)

    return Dataset(
        src_sent=src_sent, src_vocab=src_vocab, inv_src_vocab=inv_src_vocab,
        targ_sent=targ_sent, targ_vocab=targ_vocab, inv_targ_vocab=inv_targ_vocab)

In [4]:
def persist_dataset(dataset, path):
    """Serialize ``dataset`` with pickle into the file at ``path``.

    The file is opened in ``'wb+'`` mode, so existing content is truncated.
    """
    with open(path, 'wb+') as out_file:
        pickle.dump(dataset, out_file)
        
def load_dataset(path):
    """Return the object unpickled from the file at ``path``.

    NOTE: unpickling can execute arbitrary code, so only load files that come
    from a trusted source.
    """
    with open(path, 'rb') as in_file:
        restored = pickle.load(in_file)
    return restored

In [5]:
# Tokenize the source and target corpus files named below into one Dataset.
dataset = get_data(
    src_path='./data/europarl-v7.es-en.en_small',
    targ_path='./data/europarl-v7.es-en.es_small',
    start_label=1,
    invalid_label=0,
)

In [99]:
class TwoDBisect(object):
    """Find a bucket large enough for a (source, target) sentence pair.

    Buckets are ``(source_len, target_len)`` tuples. They are kept sorted
    lexicographically in ``self.buckets``; ``self.x`` and ``self.y`` hold the
    two coordinates of those sorted buckets as numpy arrays.
    """

    def __init__(self, buckets):
        """Sort ``buckets`` and cache their coordinates.

        Args:
            buckets: Iterable of ``(source_len, target_len)`` tuples, in any
                order.
        """
        self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
        # Build the coordinate arrays from the *sorted* list so that an index
        # into self.x / self.y is also a valid index into self.buckets, and so
        # that np.searchsorted receives the sorted input it requires.
        self.x = np.array([bucket[0] for bucket in self.buckets])
        self.y = np.array([bucket[1] for bucket in self.buckets])

    def twod_bisect(self, source, target):
        """Return the first sorted bucket that fits both sequences.

        Args:
            source: Sized source sequence.
            target: Sized target sequence.

        Returns:
            The first bucket, in sorted order, whose source length is at least
            ``len(source)`` and whose target length is at least ``len(target)``.

        Raises:
            IndexError: If no bucket is large enough.
        """
        # First bucket whose source length can hold `source`; every later
        # bucket has a source length at least as large.
        offset1 = np.searchsorted(self.x, len(source), side='left')
        # Among those, positions whose target length can hold `target`.
        offset2 = np.where(self.y[offset1:] >= len(target))[0]
        if len(offset2) == 0:
            raise IndexError(
                "no bucket can hold a pair of lengths (%d, %d)"
                % (len(source), len(target)))
        return self.buckets[offset1 + offset2[0]]

class Seq2SeqIterator(object):
    """Group a parallel corpus into padded, fixed-shape buckets.

    Every (source, target) sentence pair is assigned to a bucket, a
    ``(source_len, target_len)`` tuple, and stored in a pair of int32 arrays
    of that shape. Source rows are padded on the left, target rows on the
    right, both with the id of ``'<PAD>'`` in the source vocabulary.

    Batch iteration is not implemented yet: ``next`` is a placeholder and
    ``reset`` only reshuffles the order in which buckets would be visited.
    """

    def __init__(self, dataset, buckets=None, batch_size=32, max_sent_len=None):
        """Filter the corpus, choose the buckets and build the padded arrays.

        Args:
            dataset: Object exposing ``src_sent``, ``targ_sent``, ``src_vocab``,
                ``targ_vocab``, ``inv_src_vocab`` and ``inv_targ_vocab``.
                ``src_vocab`` must contain the key ``'<PAD>'``.
            buckets: Optional list of ``(source_len, target_len)`` tuples. When
                empty or None, buckets are generated from the data.
            batch_size: Stored on the instance; not used yet because ``next``
                is a placeholder.
            max_sent_len: Longest sentence to keep on either side. Ignored
                when ``buckets`` is given; the largest bucket dimension is
                used instead.
        """
        self.src_sent = dataset.src_sent
        self.targ_sent = dataset.targ_sent
        self.batch_size = batch_size
        if buckets:
            # list() keeps the columns indexable on Python 3, where zip is lazy.
            columns = list(zip(*buckets))
            self.max_sent_len = max(max(columns[0]), max(columns[1]))
        else:
            self.max_sent_len = max_sent_len
        if self.max_sent_len:
            self.src_sent, self.targ_sent = self.filter_long_sent(
                self.src_sent, self.targ_sent, self.max_sent_len)
        self.src_vocab = dataset.src_vocab
        self.targ_vocab = dataset.targ_vocab
        self.inv_src_vocab = dataset.inv_src_vocab
        self.inv_targ_vocab = dataset.inv_targ_vocab
        # Can't filter smaller counts per bucket if those sentences still exist!
        self.buckets = buckets if buckets else self.gen_buckets(
            self.src_sent, self.targ_sent, filter_smaller_counts_than=1, max_sent_len=max_sent_len)
        self.bisect = TwoDBisect(self.buckets)
        # self.max_sent_len is deliberately NOT reassigned here: it must keep
        # the value derived from the buckets that was used for filtering.
        # The source-side padding id is used for both the source and target arrays.
        self.pad_id = self.src_vocab['<PAD>']
        # After bucketization, we should probably del self.src_sent and self.targ_sent
        # to free up memory.
        self.bucketed_data, self.bucket_idx_to_key = self.bucketize()
        self.bucket_key_to_idx = {
            key: idx for idx, key in enumerate(self.bucket_idx_to_key)}
        self.interbucket_idx = 0
        self.intrabucket_idx = 0
        self.bucket_iterator_indices = list(range(len(self.bucket_idx_to_key)))

    def bucketize(self):
        """Assign every sentence pair to a bucket and build the padded arrays.

        Returns:
            A pair ``(bucketed_data, bucket_idx_to_key)``. ``bucketed_data[i]``
            is a ``(src_array, targ_array)`` tuple of int32 arrays with one row
            per sentence pair, and ``bucket_idx_to_key[i]`` is the
            ``(source_len, target_len)`` bucket those arrays belong to. Only
            buckets that received at least one pair appear.
        """
        tuples = [(src, targ, self.bisect.twod_bisect(src, targ))
                  for src, targ in zip(self.src_sent, self.targ_sent)]
        # groupby only merges adjacent items, so sort by bucket key first.
        sorted_tuples = sorted(tuples, key=operator.itemgetter(2))
        bucketed_data = []
        bucket_idx_to_key = []

        for key, members in groupby(sorted_tuples, operator.itemgetter(2)):
            # Materialize the group (a one-shot iterator) and drop the bucket
            # key from each member, leaving (src, targ) pairs.
            pairs = [member[:2] for member in members]

            # create padded representation
            new_src = np.full((len(pairs), key[0]), self.pad_id, dtype=np.int32)
            new_targ = np.full((len(pairs), key[1]), self.pad_id, dtype=np.int32)

            for idx, (curr_src, curr_targ) in enumerate(pairs):
                # Right-align the source: padding goes on the left, tokens
                # keep their original order.
                # NOTE(review): the earlier code reversed the sentence and then
                # wrote it through a reversed slice, which yields exactly this
                # layout. Confirm whether a reversed source was intended.
                new_src[idx, key[0] - len(curr_src):] = curr_src
                # Left-align the target: padding goes on the right.
                new_targ[idx, :len(curr_targ)] = curr_targ

            bucketed_data.append((new_src, new_targ))
            bucket_idx_to_key.append(key)
        return bucketed_data, bucket_idx_to_key

    def current_bucket_key(self):
        """Return the ``(source_len, target_len)`` key of the current bucket."""
        # Go through current_bucket_index() so the key follows the shuffled
        # visiting order set by reset().
        return self.bucket_idx_to_key[self.current_bucket_index()]

    def current_bucket_index(self):
        """Return the index into ``bucketed_data`` of the current bucket."""
        return self.bucket_iterator_indices[self.interbucket_idx]

    def reset(self):
        """Rewind the iterator and reshuffle the bucket visiting order.

        TODO: also shuffle the examples inside each bucket (for instance with
        ``np.random.shuffle`` on each bucket array pair, keeping source and
        target rows aligned).
        """
        self.interbucket_idx = 0
        self.intrabucket_idx = 0
        shuffle(self.bucket_iterator_indices)

    def next(self):
        """Placeholder for batch iteration; currently does nothing.

        TODO: return up to ``batch_size`` rows of the current bucket together
        with its bucket key, advance ``intrabucket_idx`` / ``interbucket_idx``,
        and raise StopIteration once every bucket is exhausted.
        """
        pass

    @staticmethod
    def filter_long_sent(src_sent, targ_sent, max_len):
        """Drop sentence pairs in which either side is longer than ``max_len``.

        Args:
            src_sent: Sequence of source sentences.
            targ_sent: Sequence of target sentences, aligned with ``src_sent``.
            max_len: Maximum allowed length for each side.

        Returns:
            A two-element list ``[kept_src, kept_targ]`` of tuples. Both tuples
            are empty when no pair survives.
        """
        kept = [(src, targ) for src, targ in zip(src_sent, targ_sent)
                if len(src) <= max_len and len(targ) <= max_len]
        if not kept:
            # zip(*[]) yields nothing, which callers could not unpack into two.
            return [(), ()]
        return list(zip(*kept))

    @staticmethod
    def gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
        """Derive buckets from the length pairs that occur in the corpus.

        Args:
            src_sent: Sequence of source sentences.
            targ_sent: Sequence of target sentences, aligned with ``src_sent``.
            filter_smaller_counts_than: Minimum number of sentence pairs a
                length pair needs in order to become a bucket. None disables
                this filter.
            max_sent_len: Upper bound applied to both lengths. None disables it.
            min_sent_len: Lower bound applied to both lengths. None disables it.

        Returns:
            Sorted list of ``(source_len, target_len)`` tuples.
        """
        counts = Counter(
            (len(src), len(targ)) for src, targ in zip(src_sent, targ_sent))

        def keep(lengths, count):
            # True when the length pair passes every enabled filter.
            if filter_smaller_counts_than is not None and count < filter_smaller_counts_than:
                return False
            if max_sent_len is not None and max(lengths) > max_sent_len:
                return False
            if min_sent_len is not None and min(lengths) < min_sent_len:
                return False
            return True

        return [lengths for lengths, count in sorted(counts.items())
                if keep(lengths, count)]

In [100]:
# No explicit buckets are passed, so the iterator generates its buckets from
# the sentence-length pairs found in `dataset`.
i1 = Seq2SeqIterator(dataset)

In [98]:
# 
# Scratch cell for inspecting bucket keys and bucket contents.
# NOTE(review): `i2` is only created in the last cell of this notebook, so this
# cell raises NameError when the notebook is run top to bottom on a fresh kernel.
# i1.bucket_idx_to_key[10]
# i1.bucketed_data[10]
# i1.bucket_idx_to_key[10]
# len(i2.bucket_idx_to_key)
i2.bucket_idx_to_key

[(1, 1),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 21),
 (1, 26),
 (1, 31),
 (1, 36),
 (1, 41),
 (1, 46),
 (1, 51),
 (1, 56),
 (1, 61),
 (1, 66),
 (6, 1),
 (6, 6),
 (6, 11),
 (6, 16),
 (6, 21),
 (6, 26),
 (6, 31),
 (6, 36),
 (6, 41),
 (6, 46),
 (6, 51),
 (6, 56),
 (6, 61),
 (6, 66),
 (11, 1),
 (11, 6),
 (11, 11),
 (11, 16),
 (11, 21),
 (11, 26),
 (11, 31),
 (11, 36),
 (11, 41),
 (11, 46),
 (11, 51),
 (11, 56),
 (11, 61),
 (11, 66),
 (16, 1),
 (16, 6),
 (16, 11),
 (16, 16),
 (16, 21),
 (16, 26),
 (16, 31),
 (16, 36),
 (16, 41),
 (16, 46),
 (16, 51),
 (16, 56),
 (16, 61),
 (16, 66),
 (21, 1),
 (21, 6),
 (21, 11),
 (21, 16),
 (21, 21),
 (21, 26),
 (21, 31),
 (21, 36),
 (21, 41),
 (21, 46),
 (21, 51),
 (21, 56),
 (21, 61),
 (21, 66),
 (26, 1),
 (26, 6),
 (26, 11),
 (26, 16),
 (26, 21),
 (26, 26),
 (26, 31),
 (26, 36),
 (26, 41),
 (26, 46),
 (26, 51),
 (26, 56),
 (26, 61),
 (26, 66),
 (31, 1),
 (31, 6),
 (31, 11),
 (31, 16),
 (31, 21),
 (31, 26),
 (31, 31),
 (31, 36),
 (31, 41),
 (31, 46),
 (31, 51

In [101]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent


def sent_len(sentences):
    """Return the length of every sentence in ``sentences`` as a list."""
    return [len(sentence) for sentence in sentences]


# The grid starts at the shortest sentence found on either side of the corpus.
min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# The upper bound is a fixed value rather than the longest sentence in the data.
max_len = 65
increment = 5

# Every (source_len, target_len) combination on a grid that starts at min_len,
# advances by `increment`, and stops before max_len + increment.
all_pairs = [
    (src_len, targ_len)
    for src_len in range(min_len, max_len + increment, increment)
    for targ_len in range(min_len, max_len + increment, increment)
]

In [102]:
# Use the fixed grid of bucket shapes in `all_pairs` instead of generating
# buckets from the data.
i2 = Seq2SeqIterator(dataset, buckets=all_pairs)