In [1]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
# Get rid of annoying Python deprecation warnings from built-in JSON encoder
warnings.filterwarnings("ignore", category=DeprecationWarning)   

In [3]:
# Decode text as UTF-8
# Remove diacritical signs and convert to Latin alphabet
# Separate punctuation as separate "words"
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    lines = unidecode(open(fname).read().decode('utf-8')).split('\n')
    lines = [x for x in lines if x]
    lines = map(lambda x: re.findall(r"\w+|[^\w\s]", x, re.UNICODE), lines)    
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

Dataset = namedtuple(
    'Dataset', 
    ['src_sent', 'src_vocab', 'inv_src_vocab', 'targ_sent', 'targ_vocab', 'inv_targ_vocab'])

def invert_dict(d):
    return {v: k for k, v in d.iteritems()}


def get_data(src_path, targ_path, start_label=1, invalid_label=0, pad_symbol='<PAD>'):
    src_sent, src_vocab = tokenize_text(src_path, start_label=start_label,
                                invalid_label=invalid_label)
    
    src_vocab[pad_symbol] = invalid_label
    inv_src_vocab = invert_dict(src_vocab)

    targ_sent, targ_vocab = tokenize_text(targ_path, start_label=start_label, #new_start+1,
                                          invalid_label=invalid_label)
    
    targ_vocab[pad_symbol] = invalid_label
    inv_targ_vocab = invert_dict(targ_vocab)
    
    return Dataset(
        src_sent=src_sent, src_vocab=src_vocab, inv_src_vocab=inv_src_vocab,
        targ_sent=targ_sent, targ_vocab=targ_vocab, inv_targ_vocab=inv_targ_vocab)

In [4]:
def persist_dataset(dataset, path):
    with open(path, 'wb+') as fileobj:
        pickle.dump(dataset, fileobj)
        
def load_dataset(path):
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)

In [5]:
dataset = \
    get_data(
        src_path='./data/europarl-v7.es-en.en_small',
        targ_path='./data/europarl-v7.es-en.es_small',
        start_label=1,
        invalid_label=0
    )

In [148]:
class Seq2SeqIterator:    

    class TwoDBisect:
        def __init__(self, buckets):
            self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
            self.x, self.y = zip(*buckets)
            self.x, self.y = np.array(list(self.x)), np.array(list(self.y))

        def twod_bisect(self, source, target):    
            offset1 = np.searchsorted(self.x, len(source), side='left')
            offset2 = np.where(self.y[offset1:] >= len(target))[0]        
            return self.buckets[offset1 + offset2[0]]     
    
    def __init__(self, dataset, buckets=None, batch_size=32, max_sent_len=None):
        self.batch_size = batch_size
        self.src_sent = dataset.src_sent
        self.targ_sent = dataset.targ_sent
        if buckets:
            z = zip(*buckets)
            self.max_sent_len = max(max(z[0]), max(z[1]))
        else:
            self.max_sent_len = max_sent_len
        if self.max_sent_len:
            self.src_sent, self.targ_sent = self.filter_long_sent(
                self.src_sent, self.targ_sent, self.max_sent_len) 
        self.src_vocab = dataset.src_vocab
        self.targ_vocab = dataset.targ_vocab
        self.inv_src_vocab = dataset.inv_src_vocab
        self.inv_targ_vocab = dataset.inv_targ_vocab
        # Can't filter smaller counts per bucket if those sentences still exist!
        self.buckets = buckets if buckets else self.gen_buckets(
            self.src_sent, self.targ_sent, filter_smaller_counts_than=1, max_sent_len=max_sent_len)
        self.bisect = Seq2SeqIterator.TwoDBisect(self.buckets)
        self.max_sent_len = max_sent_len
        self.pad_id = self.src_vocab['<PAD>']
        # After bucketization, we should probably del self.src_sent and self.targ_sent
        # to free up memory.
        self.bucketed_data, self.bucket_idx_to_key = self.bucketize()
        self.bucket_key_to_idx = invert_dict(dict(enumerate(self.bucket_idx_to_key)))
        self.interbucket_idx = -1
        self.curr_chunks = None
        self.curr_buck = None
        self.switch_bucket = True
        self.num_buckets = len(self.bucket_idx_to_key)
        self.bucket_iterator_indices = list(range(self.num_buckets))

    
    def bucketize(self):
        tuples = []
        ctr = 0
        for src, targ in zip(self.src_sent, self.targ_sent):
            len_tup = self.bisect.twod_bisect(src, targ)
            rev_src = src[::-1] 
            tuples.append((src, targ, len_tup))
            
        sorted_tuples = sorted(tuples, key=operator.itemgetter(2))
        grouped = groupby(sorted_tuples, lambda x: x[2])
        bucketed_data = [] 
        bucket_idx_to_key = []
        
        for group in grouped:
            
            # get src and targ sentences, ignore the last elem of the tuple 
            # (the grouping key of (src_len, targ_len))
            key, value = group[0], map(lambda x: x[:2], group[1])
            if len(value) < self.batch_size:
                continue

            # create padded representation
            new_src = np.full((len(value), key[0]), self.pad_id, dtype=np.int32)
            new_targ = np.full((len(value), key[1]), self.pad_id, dtype=np.int32)
            
            for idx, example in enumerate(value):
                curr_src, curr_targ = example
                rev_src = curr_src[::-1]
                new_src[idx, :-(len(rev_src)+1):-1] = rev_src
                new_targ[idx, :len(curr_targ)] = curr_targ
                                
            bucketed_data.append((new_src, new_targ))
            bucket_idx_to_key.append(key)
        return bucketed_data, bucket_idx_to_key
    
    def current_bucket_key(self):
        return self.bucket_idx_to_key[self.interbucket_idx]
    
    def current_bucket_index(self):
        return self.bucket_iterator_indices[self.interbucket_idx]

    # shuffle the data within buckets, and reset iterator
    def reset(self):
        self.interbucket_idx = -1
        for idx in xrange(len(self.bucketed_data)):
            current = self.bucketed_data[idx]
            src, targ = current
            indices = np.array(range(len(src)))
            np.random.shuffle(indices)
            src = src[indices]
            targ = targ[indices]
            self.bucketed_data[idx] = (src, targ)
        shuffle(self.bucket_iterator_indices)

    # iterate over data
    def next(self):
        try:
            if self.switch_bucket:
                self.interbucket_idx += 1
                bucket_idx = self.bucket_iterator_indices[self.interbucket_idx]
                print("bucket_idx: %d" % bucket_idx)
                print("bucket_tuple: %s" % str(self.bucket_idx_to_key[bucket_idx]))
                self.curr_buck = self.bucketed_data[bucket_idx]
                buck_len = len(self.curr_buck[0])
                print("buck_len: %d" % buck_len)
                self.curr_chunks = self.chunks(range(buck_len), self.batch_size)
                print("curr chunks: %s" % str(self.chunks(range(buck_len), self.batch_size)))
                print("")
                self.switch_bucket = False
            current = self.curr_chunks.next()
            print("inter-bucket idx: %d" % self.interbucket_idx)
            print("bucket size: %d" % len(self.curr_buck[0]))
            print("current indices in bucket: %s" % str(current))
            src_ex = self.curr_buck[0][current]
            targ_ex = self.curr_buck[1][current]
            return (src_ex, targ_ex)
        except StopIteration as si:
            if self.interbucket_idx == self.num_buckets - 1:
                self.reset()
                self.switch_bucket = True
                raise si
            else:
                self.switch_bucket = True
                return self.next()

    @staticmethod
    def chunks(iterable, batch_size, trim_incomplete_batches=True):
        n = max(1, batch_size)
        end = len(iterable)/n*n if trim_incomplete_batches else len(iterable)
        return (iterable[i:i+n] for i in xrange(0, end, n))
    
    @staticmethod 
    def filter_long_sent(src_sent, targ_sent, max_len):
        result = filter(lambda x: len(x[0]) <= max_len and len(x[1]) <= max_len, zip(src_sent, targ_sent))
        return zip(*result)

    @staticmethod
    def gen_buckets(src_sent, targ_sent, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
        length_pairs = map(lambda x: (len(x[0]), len(x[1])), zip(src_sent, targ_sent))
        counts = list(Counter(length_pairs).items())
        c_sorted = sorted(counts, key=operator.itemgetter(0, 1))
        buckets = [i[0] for i in c_sorted if i[1] >= filter_smaller_counts_than and 
                   (max_sent_len is None or i[0][0] <= max_sent_len) and
                   (max_sent_len is None or i[0][1] <= max_sent_len) and
                   (min_sent_len is None or i[0][0] >= min_sent_len) and
                   (min_sent_len is None or i[0][1] >= min_sent_len)]
        return buckets

In [77]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent

sent_len = lambda x: map(lambda y: len(y), x)
max_len = lambda x: max(sent_len(x))
min_len = lambda x: min(sent_len(x))

min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))

# min_len = min
max_len = 65
increment = 5

all_pairs = [(i, j) for i in xrange(
        min_len,max_len+increment,increment
    ) for j in xrange(
        min_len,max_len+increment,increment
    )]

In [156]:
i2 = Seq2SeqIterator(dataset, buckets=all_pairs)

In [177]:
foo = i2.next()

inter-bucket idx: 7
bucket size: 192
current indices in bucket: [160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191]


In [186]:
while True:
    foo = i2.next()

inter-bucket idx: 0
bucket size: 270
current indices in bucket: [160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191]
inter-bucket idx: 0
bucket size: 270
current indices in bucket: [192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223]
inter-bucket idx: 0
bucket size: 270
current indices in bucket: [224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
bucket_idx: 162
bucket_tuple: (61, 26)
buck_len: 446
curr chunks: <generator object <genexpr> at 0x7f8d57bccb40>

inter-bucket idx: 1
bucket size: 446
current indices in bucket: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
inter-bucket idx: 1
bucket s

StopIteration: 

In [193]:
foo = i2.next()

inter-bucket idx: 1
bucket size: 121
current indices in bucket: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]


In [194]:
foo

(array([[ 263,   62,  101, ...,  279,   11,   36],
        [3708,   17, 1242, ..., 5909, 1253,   36],
        [   0,    0,    5, ...,  120, 9519,   36],
        ..., 
        [   0,  535,    5, ..., 2846, 2260,   36],
        [   0,    0,   71, ..., 1840,   36,  839],
        [   0,    0,    0, ...,   92, 1923,   36]], dtype=int32),
 array([[  116,  2567,  1693, ...,     0,     0,     0],
        [  213,    24,   214, ...,    30,     0,     0],
        [  300,   539,   200, ...,    30,     0,     0],
        ..., 
        [ 5774,    66,    47, ...,     0,     0,     0],
        [ 1394,     4,    47, ...,  1957, 13868,    30],
        [   45,    47,   607, ...,     0,     0,     0]], dtype=int32))

In [202]:
np.shape(foo[0])

(32, 61)

In [201]:
i2.bucket_idx_to_key[i2.bucket_iterator_indices[1]]

(61, 56)

In [200]:
i2.bucket_iterator_indices

[118,
 168,
 154,
 104,
 101,
 14,
 1,
 160,
 39,
 117,
 128,
 28,
 68,
 16,
 97,
 65,
 40,
 32,
 84,
 112,
 120,
 85,
 150,
 62,
 108,
 146,
 134,
 91,
 58,
 110,
 9,
 80,
 74,
 37,
 13,
 27,
 159,
 38,
 94,
 172,
 139,
 100,
 178,
 105,
 51,
 152,
 162,
 142,
 55,
 63,
 151,
 135,
 109,
 66,
 20,
 166,
 10,
 86,
 111,
 77,
 149,
 183,
 87,
 177,
 157,
 45,
 95,
 153,
 76,
 137,
 106,
 165,
 83,
 82,
 71,
 173,
 4,
 26,
 99,
 121,
 12,
 145,
 18,
 33,
 156,
 81,
 130,
 92,
 103,
 122,
 5,
 56,
 123,
 3,
 133,
 19,
 79,
 50,
 23,
 69,
 124,
 36,
 49,
 70,
 143,
 48,
 41,
 107,
 175,
 98,
 42,
 140,
 75,
 167,
 161,
 31,
 54,
 8,
 52,
 116,
 88,
 89,
 136,
 174,
 138,
 125,
 67,
 158,
 60,
 114,
 119,
 180,
 72,
 102,
 163,
 181,
 131,
 90,
 53,
 30,
 169,
 96,
 46,
 2,
 7,
 127,
 43,
 6,
 44,
 148,
 78,
 171,
 182,
 21,
 35,
 22,
 15,
 59,
 34,
 61,
 25,
 11,
 155,
 126,
 129,
 57,
 147,
 141,
 0,
 115,
 170,
 113,
 64,
 17,
 164,
 132,
 29,
 93,
 24,
 144,
 179,
 176,
 47,
 73]