In [1]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [3]:
# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

Dataset = namedtuple(
    'Dataset', 
    ['src_sent', 'src_vocab', 'inv_src_vocab', 'targ_sent', 'targ_vocab', 'inv_targ_vocab'])

def invert_dict(d):
    return {v: k for k, v in d.iteritems()}


def get_data(src_path, targ_path, start_label=1, invalid_label=0, pad_symbol='<PAD>'):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    src_vocab[pad_symbol] = invalid_label
    inv_src_vocab = invert_dict(src_vocab)

    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=start_label, #new_start+1,
                                          invalid_label=invalid_label)
    
    targ_vocab[pad_symbol] = invalid_label
    inv_targ_vocab = invert_dict(targ_vocab)
    
    return Dataset(
        src_sent=src_sent, src_vocab=src_vocab, inv_src_vocab=inv_src_vocab,
        targ_sent=targ_sent, targ_vocab=targ_vocab, inv_targ_vocab=inv_targ_vocab)

In [4]:
def persist_dataset(dataset, path):
    with open(path, 'wb+') as fileobj:
        pickle.dump(dataset, fileobj)
        
def load_dataset(path):
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)

In [5]:
dataset = \
    get_data(
        src_path='./data/europarl-v7.es-en.en_small',
        targ_path='./data/europarl-v7.es-en.es_small',
        start_label=1,
        invalid_label=0
    )

In [71]:
class Seq2SeqIterator:    

    class TwoDBisect:
        def __init__(self, buckets):
            self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
            self.x, self.y = zip(*buckets)
            self.x, self.y = np.array(list(self.x)), np.array(list(self.y))

        def twod_bisect(self, source, target):    
            offset1 = np.searchsorted(self.x, len(source), side='left')
            offset2 = np.where(self.y[offset1:] >= len(target))[0]        
            return self.buckets[offset1 + offset2[0]]     
    
    def __init__(self, dataset, buckets=None, batch_size=32, max_sent_len=None):
        self.batch_size = batch_size
        self.src_sent = dataset.src_sent
        self.targ_sent = dataset.targ_sent
        if buckets:
            z = zip(*buckets)
            self.max_sent_len = max(max(z[0]), max(z[1]))
        else:
            self.max_sent_len = max_sent_len
        if self.max_sent_len:
            self.src_sent, self.targ_sent = self.filter_long_sent(
                self.src_sent, self.targ_sent, self.max_sent_len) 
        self.src_vocab = dataset.src_vocab
        self.targ_vocab = dataset.targ_vocab
        self.inv_src_vocab = dataset.inv_src_vocab
        self.inv_targ_vocab = dataset.inv_targ_vocab
        # Can't filter smaller counts per bucket if those sentences still exist!
        self.buckets = buckets if buckets else self.gen_buckets(
            self.src_sent, self.targ_sent, filter_smaller_counts_than=1, max_sent_len=max_sent_len)
        self.bisect = Seq2SeqIterator.TwoDBisect(self.buckets)
        self.max_sent_len = max_sent_len
        self.pad_id = self.src_vocab['<PAD>']
        # After bucketization, we should probably del self.src_sent and self.targ_sent
        # to free up memory.
        self.bucketed_data, self.bucket_idx_to_key = self.bucketize()
        self.bucket_key_to_idx = invert_dict(dict(enumerate(self.bucket_idx_to_key)))
        self.interbucket_idx = 0
        self.chunk_idx = 0
        self.curr_chunks = None
        self.curr_buck = None
        self.num_buckets = len(self.bucket_idx_to_key)
        self.bucket_iterator_indices = list(range(self.num_buckets))

    
    def bucketize(self):
        tuples = []
        ctr = 0
        for src, targ in zip(self.src_sent, self.targ_sent):
            len_tup = self.bisect.twod_bisect(src, targ)
            rev_src = src[::-1] 
            tuples.append((src, targ, len_tup))
            
        sorted_tuples = sorted(tuples, key=operator.itemgetter(2))
        grouped = groupby(sorted_tuples, lambda x: x[2])
        bucketed_data = [] 
        bucket_idx_to_key = []
        
        for group in grouped:
            
            # get src and targ sentences, ignore the last elem of the tuple 
            # (the grouping key of (src_len, targ_len))
            key, value = group[0], map(lambda x: x[:2], group[1])

            # create padded representation
            new_src = np.full((len(value), key[0]), self.pad_id, dtype=np.int32)
            new_targ = np.full((len(value), key[1]), self.pad_id, dtype=np.int32)
            
            for idx, example in enumerate(value):
                curr_src, curr_targ = example
                rev_src = curr_src[::-1]
                new_src[idx, :-(len(rev_src)+1):-1] = rev_src
                new_targ[idx, :len(curr_targ)] = curr_targ
                                
            bucketed_data.append((new_src, new_targ))
            bucket_idx_to_key.append(key)
        return bucketed_data, bucket_idx_to_key
    
    def current_bucket_key(self):
        return self.bucket_idx_to_key[self.interbucket_idx]
    
    def current_bucket_index(self):
        return self.bucket_iterator_indices[self.interbucket_idx]

    # shuffle the data within buckets, and reset iterator
    def reset(self):
        self.interbucket_idx = 0
        for idx in xrange(len(self.bucketed_data)):
            current = self.bucketed_data[idx]
            src, targ = current
            indices = np.array(range(len(src)))
            np.random.shuffle(indices)
            src = src[indices]
            targ = targ[indices]
            self.bucketed_data[idx] = (src, targ)
        shuffle(self.bucket_iterator_indices)
 
    @staticmethod
    def chunks(iterable, batch_size, trim_incomplete_batches=True):
        n = max(1, batch_size)
        end = len(iterable)/n*n if trim_incomplete_batches else len(iterable)
        return list(iterable[i:i+n] for i in xrange(0, end, n))

    # iterate over data
    def next(self):
        try:
            if self.chunk_idx == 0:
                self.curr_buck = self.bucketed_data[self.bucket_iterator_indices[self.interbucket_idx]]
                buck_len = len(self.curr_buck[0])     
                self.curr_chunks = self.chunks(range(buck_len), self.batch_size)
            current = self.curr_chunks.next()
            src_ex = buck[0][current]
            targ_ex = buck[1][current]
            return (src_ex, targ_ex)
        except StopIteration as si:
            if bucket_idx == self.num_buckets - 1:
                self.reset()
                raise si
            else:
                self.interbucket_idx += 1
                self.chunk_idx = 0
    # interbucket_idx / num_buckets / chunk_idx
        
#         for bucket_idx in xrange(self.num_buckets):
#             print("bucket_idx: %d" % bucket_idx)
#             buck = self.bucketed_data[self.bucket_iterator_indices[bucket_idx]]
#             buck_len = len(buck[0])
#             buck_chx = chunks(range(buck_len), self.batch_size)
            
#             try:
#                 buck_chk = buck_chx.next()
#                 print(buck_chk)
#                 src_ex = buck[0][buck_chk]
#                 targ_ex = buck[1][buck_chk]
#                 yield (src_ex, targ_ex)
#             except StopIteration as si:
#                 if bucket_idx == self.num_buckets - 1:
#                     raise si
#             for buck_chk in buck_chx:
#                 src_ex = buck[0][buck_chk]
#                 targ_ex = buck[1][buck_chk]
#                 yield (src_ex, targ_ex)
# #             if bucket_idx == self.num_buckets - 1:
# #                 raise StopIteration
#             return (src_ex, targ_ex)
            
        
        # pick example from current group if not done
        # if done, move to the next group 
        # if the end of the last group, raise StopError
        
        # raise StopIteration when done       

    
# From my iterator:

# def iterate_groups(groups, batch_size=32):
    
#     def chunks(l, n, trim_incomplete_batches=True):
#         n = max(1, n)
#         end = len(l)/n*n if trim_incomplete_batches else len(l)
#         return (l[i:i+n] for i in xrange(0, end, n))
    
#     for key, group in groups.items():
#         num_examples = len(group[0])
#         indices = list(xrange(num_examples))
#         shuffle(indices)
#         src_sent, targ_sent = group
#         src_sent = src_sent[indices]
#         targ_sent = targ_sent[indices]        
#         for chunk in chunks(list(xrange(num_examples)), batch_size):
#             yield [src_sent[chunk], targ_sent[chunk]]
    
# From BucketSentenceIterator:    
    
#     def next(self):
#         if self.curr_idx == len(self.idx):
#             raise StopIteration
#         i, j = self.idx[self.curr_idx]
#         self.curr_idx += 1

#         if self.major_axis == 1:
#             data = self.nddata[i][j:j+self.batch_size].T
#             label = self.ndlabel[i][j:j+self.batch_size].T
#         else:
#             data = self.nddata[i][j:j+self.batch_size]
#             label = self.ndlabel[i][j:j+self.batch_size]

#         return DataBatch([data], [label], pad=0,
#                          bucket_key=self.buckets[i],
#                          provide_data=[(self.data_name, data.shape)],
#                          provide_label=[(self.label_name, label.shape)])    
    
    @staticmethod 
    def filter_long_sent(src_sent, targ_sent, max_len):
        result = filter(lambda x: len(x[0]) <= max_len and len(x[1]) <= max_len, zip(src_sent, targ_sent))
        return zip(*result)

    @staticmethod
    def gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
        length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
        counts = list(Counter(length_pairs).items())
        c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
        buckets = [i[0] for i in c_sorted if i[1] >= filter_smaller_counts_than and 
                   (max_sent_len is None or i[0][0] <= max_sent_len) and
                   (max_sent_len is None or i[0][1] <= max_sent_len) and
                   (min_sent_len is None or i[0][0] >= min_sent_len) and
                   (min_sent_len is None or i[0][1] >= min_sent_len)]
        return buckets

In [72]:
i1 = Seq2SeqIterator(dataset)

In [75]:
i1.next()

AttributeError: 'list' object has no attribute 'next'

In [58]:
i1.reset()
i1.next()

bucket_idx: 0
bucket_idx: 1
bucket_idx: 2
bucket_idx: 3
bucket_idx: 4
bucket_idx: 5


(array([[ 2255,     3,    65, ...,    27, 13654,    36],
        [   99,   100,    16, ...,  3384,   184,    36],
        [   99,   100,    16, ...,    17, 20446,    36],
        ..., 
        [   99,   100,    16, ...,  2261,   860,    36],
        [ 1269,  1014,   460, ...,    43,   235,    36],
        [  884,    64,  5385, ...,   144,    62,    36]], dtype=int32),
 array([[ 6954,   214,    24, ...,    39,  3134,    30],
        [  899,  2073,    11, ...,     4,   625,    30],
        [  723,    24,    49, ...,   103, 18480,    30],
        ..., 
        [  129,  5874,  3955, ...,     8,   594,   129],
        [ 2578,    70,  9320, ...,    41, 15982,    30],
        [ 1362,   120,   706, ...,    70,   632,    30]], dtype=int32))

In [14]:
i1.reset()
i1.bucketed_data[10]

(array([[11315],
        [   36],
        [   36],
        [ 4688],
        [14993],
        [   36],
        [   36],
        [   36],
        [17233],
        [   36],
        [ 9511]], dtype=int32),
 array([[  503, 29654,   835,   319,  1712,  3436,   161,   152,  3467,
            66,  2577,    30],
        [ 1112,    24,  4157,   148,   143,    66,     8,  1062,     4,
            70,   469,    30],
        [18936,  4189,    24,    78,    24,   802,   214,    24, 19004,
            74,  8430,    30],
        [  595,   784, 10958,     8,  3112,  1952,  4940,   273,  1954,
            17,  1953,    30],
        [   98, 25750,   653,   290,   103,  5544,  1221,    24,     8,
         25750,  2827,    30],
        [   37,   595,    11,    37,   285,  8571,    11,    37, 31222,
           774,    30,    37],
        [   97,   938,  3484,    70,   363,    19,    86,   617,    30,
          1357,  1576,   100],
        [   98,  2887,   214,   619,   120,  1763,     4,  2314,     4,
     

In [None]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# max_len = max(max(sent_len(src_sent)), max(sent_len(targ_sent)))

# min_len = min
max_len = 65
increment = 5

all_pairs = [(i, j) for i in xrange(
        min_len,max_len+increment,increment
    ) for j in xrange(
        min_len,max_len+increment,increment
    )]

In [None]:
i2 = Seq2SeqIterator(dataset, buckets=all_pairs)

In [None]:
src, targ = i1.bucketed_data[10]

In [None]:
indices = np.array(range(len(src)))
np.random.shuffle(indices)
indices

In [None]:
src, targ = i1.bucketed_data[10]

In [None]:
targ[indices]