In [1]:
import numpy as np
import mxnet as mx
from collections import namedtuple, Counter
from unidecode import unidecode
from itertools import groupby
from mxnet.io import DataIter
from random import shuffle

import deepdish as dd

import operator
import pickle
import re
import warnings

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
# Suppress all DeprecationWarning messages for the rest of the session.
# The motivating noise comes from the built-in JSON encoder.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

In [3]:
def tokenize_text(fname, vocab=None, invalid_label=0, start_label=1, sep_punctuation=True):
    """Read a UTF-8 text file and encode every non-empty line as a sentence.

    The file content is decoded as UTF-8, transliterated to the Latin
    alphabet with ``unidecode`` (which removes diacritical signs), split on
    newlines, and empty lines are dropped. Each remaining line is tokenized
    and the token lists are handed to ``mx.rnn.encode_sentences``.

    Parameters
    ----------
    fname : str
        Path of the text file to read.
    vocab : optional
        Forwarded unchanged to ``mx.rnn.encode_sentences``.
    invalid_label : int
        Forwarded unchanged to ``mx.rnn.encode_sentences``.
    start_label : int
        Forwarded unchanged to ``mx.rnn.encode_sentences``.
    sep_punctuation : bool
        When true (the default), every punctuation character becomes its own
        "word". When false, lines are split on whitespace only, so
        punctuation stays attached to the neighbouring word.

    Returns
    -------
    tuple
        ``(sentences, vocab)`` exactly as returned by
        ``mx.rnn.encode_sentences``.
    """
    # Read bytes and decode explicitly: this behaves the same on Python 2 and
    # Python 3, and the `with` block guarantees the handle is closed.
    with open(fname, 'rb') as fileobj:
        text = unidecode(fileobj.read().decode('utf-8'))

    # Binary mode performs no newline translation, so normalise Windows line
    # endings before splitting.
    text = text.replace('\r\n', '\n')

    # NOTE(review): empty lines are dropped independently per file; for a
    # parallel corpus this assumes both files have blank lines in the same
    # places - confirm against the data.
    lines = [line for line in text.split('\n') if line]

    if sep_punctuation:
        tokenized = [re.findall(r"\w+|[^\w\s]", line, re.UNICODE) for line in lines]
    else:
        tokenized = [line.split() for line in lines]

    sentences, vocab = mx.rnn.encode_sentences(
        tokenized, vocab=vocab, invalid_label=invalid_label, start_label=start_label)
    return sentences, vocab

# Record bundling the sentences and vocabulary of both sides of the corpus.
Dataset = namedtuple(
    'Dataset',
    [
        'src_sent',    # source-side sentences
        'src_vocab',   # source-side vocabulary
        'targ_sent',   # target-side sentences
        'targ_vocab',  # target-side vocabulary
    ],
)

def get_data(src_path, targ_path, start_label=1, invalid_label=0):
    """Tokenize the source and target files and bundle them into a ``Dataset``.

    Each file is passed to ``tokenize_text`` on its own, with no shared
    vocabulary argument, using the same ``start_label`` and ``invalid_label``.

    Returns a ``Dataset`` whose fields are, in order, the source sentences,
    the source vocabulary, the target sentences and the target vocabulary.
    """
    fields = []
    # Source first, then target: the order matches the Dataset field order.
    for path in (src_path, targ_path):
        sentences, vocab = tokenize_text(
            path, start_label=start_label, invalid_label=invalid_label)
        fields.extend((sentences, vocab))

    return Dataset(*fields)

def persist_dataset(dataset, path):
    """Pickle ``dataset`` into the file at ``path``, replacing existing content."""
    with open(path, 'wb+') as sink:
        pickle.dump(dataset, sink)
        
def load_dataset(path):
    """Return the object unpickled from the file at ``path``.

    Unpickling can execute arbitrary code, so only load files that come from
    a trusted source (for example ones written by ``persist_dataset``).
    """
    with open(path, 'rb') as source:
        restored = pickle.load(source)
    return restored

In [4]:
# Build the dataset from the small Europarl es-en sample files
# (presumably English as source and Spanish as target, judging by the
# file suffixes - confirm).
dataset = get_data(
    src_path='./data/europarl-v7.es-en.en_small',
    targ_path='./data/europarl-v7.es-en.es_small',
    start_label=1,
    invalid_label=0,
)

In [43]:
class Seq2SeqIterator:
    """Holds the (source length, target length) buckets for a sentence-pair dataset.

    Buckets are either supplied by the caller or derived from the dataset
    with ``gen_buckets``.
    """

    class TwoDBisect:
        """Looks up the first bucket large enough for a (source, target) pair."""

        def __init__(self, buckets):
            """Store ``buckets`` sorted by (source length, target length).

            ``self.x`` and ``self.y`` are arrays of the source and target
            lengths in that same sorted order.
            """
            self.buckets = sorted(buckets, key=operator.itemgetter(0, 1))
            # Derive both axes from the *sorted* list: searchsorted needs an
            # ascending array, and the positions must index self.buckets.
            if self.buckets:
                xs, ys = zip(*self.buckets)
            else:
                xs, ys = (), ()
            self.x, self.y = np.array(list(xs)), np.array(list(ys))

        def twod_bisect(self, source, target):
            """Return the first bucket that fits both ``source`` and ``target``.

            "First" is in the sorted bucket order: the earliest bucket whose
            source length is >= ``len(source)`` and whose target length is
            >= ``len(target)``.

            Raises
            ------
            IndexError
                If no bucket is large enough for the pair.
            """
            # Every bucket from offset1 onwards has a sufficient source length.
            offset1 = np.searchsorted(self.x, len(source), side='left')
            # Among those, positions whose target length is also sufficient.
            offset2 = np.where(self.y[offset1:] >= len(target))[0]
            if offset2.size == 0:
                raise IndexError(
                    "no bucket fits source length %d and target length %d"
                    % (len(source), len(target)))
            return self.buckets[offset1 + offset2[0]]

    def __init__(self, dataset, buckets=None, batch_size=32):
        """Keep the given ``buckets``, or generate them from ``dataset``.

        When ``buckets`` is empty or ``None``, buckets are generated with
        ``gen_buckets`` and ``batch_size`` is used as the minimum number of
        sentence pairs a length combination needs in order to be kept.
        """
        self.buckets = buckets if buckets else self.gen_buckets(
            dataset, filter_smaller_counts_than=batch_size)

    @staticmethod
    def gen_buckets(dataset, filter_smaller_counts_than=None, max_sent_len=60, min_sent_len=1):
        """Return the sorted (source length, target length) pairs found in ``dataset``.

        Parameters
        ----------
        dataset
            Object with ``src_sent`` and ``targ_sent`` sequences, paired by
            position.
        filter_smaller_counts_than : int or None
            Drop length pairs occurring fewer times than this; ``None``
            disables the filter.
        max_sent_len : int or None
            Drop pairs where either length exceeds this; ``None`` disables
            the filter.
        min_sent_len : int or None
            Drop pairs where either length is below this; ``None`` disables
            the filter.

        Returns
        -------
        list of tuple
            The surviving length pairs in ascending order.
        """
        length_pairs = [(len(src), len(targ))
                        for src, targ in zip(dataset.src_sent, dataset.targ_sent)]
        counts = Counter(length_pairs)

        def keep(pair, count):
            # Each limit is checked explicitly against None; comparing an int
            # with None raises TypeError on Python 3.
            if filter_smaller_counts_than is not None and count < filter_smaller_counts_than:
                return False
            if max_sent_len is not None and max(pair) > max_sent_len:
                return False
            if min_sent_len is not None and min(pair) < min_sent_len:
                return False
            return True

        return [pair for pair, count in sorted(counts.items()) if keep(pair, count)]
    

In [44]:
# No buckets supplied, so they are generated from the dataset itself.
i1 = Seq2SeqIterator(dataset=dataset)

In [38]:
src_sent = dataset.src_sent
targ_sent = dataset.targ_sent


def sent_len(sentences):
    """Return a list with the length of every sentence in ``sentences``."""
    return [len(sentence) for sentence in sentences]


# Shortest sentence length found on either side of the corpus.
min_len = min(min(sent_len(src_sent)), min(sent_len(targ_sent)))
# The upper bound is a fixed cap rather than the longest sentence in the data.
max_len = 65
increment = 5

# Regular grid of (source length, target length) buckets: every combination
# of lengths from min_len upwards in steps of `increment`.
bucket_lengths = range(min_len, max_len + increment, increment)
all_pairs = [(i, j) for i in bucket_lengths for j in bucket_lengths]

In [39]:
# Use the explicit grid of length pairs instead of generated buckets.
i2 = Seq2SeqIterator(dataset, buckets=all_pairs)

In [42]:
# Display the buckets stored on i2 (the all_pairs grid passed to the constructor).
i2.buckets

[(1, 1),
 (1, 6),
 (1, 11),
 (1, 16),
 (1, 21),
 (1, 26),
 (1, 31),
 (1, 36),
 (1, 41),
 (1, 46),
 (1, 51),
 (1, 56),
 (1, 61),
 (1, 66),
 (6, 1),
 (6, 6),
 (6, 11),
 (6, 16),
 (6, 21),
 (6, 26),
 (6, 31),
 (6, 36),
 (6, 41),
 (6, 46),
 (6, 51),
 (6, 56),
 (6, 61),
 (6, 66),
 (11, 1),
 (11, 6),
 (11, 11),
 (11, 16),
 (11, 21),
 (11, 26),
 (11, 31),
 (11, 36),
 (11, 41),
 (11, 46),
 (11, 51),
 (11, 56),
 (11, 61),
 (11, 66),
 (16, 1),
 (16, 6),
 (16, 11),
 (16, 16),
 (16, 21),
 (16, 26),
 (16, 31),
 (16, 36),
 (16, 41),
 (16, 46),
 (16, 51),
 (16, 56),
 (16, 61),
 (16, 66),
 (21, 1),
 (21, 6),
 (21, 11),
 (21, 16),
 (21, 21),
 (21, 26),
 (21, 31),
 (21, 36),
 (21, 41),
 (21, 46),
 (21, 51),
 (21, 56),
 (21, 61),
 (21, 66),
 (26, 1),
 (26, 6),
 (26, 11),
 (26, 16),
 (26, 21),
 (26, 26),
 (26, 31),
 (26, 36),
 (26, 41),
 (26, 46),
 (26, 51),
 (26, 56),
 (26, 61),
 (26, 66),
 (31, 1),
 (31, 6),
 (31, 11),
 (31, 16),
 (31, 21),
 (31, 26),
 (31, 31),
 (31, 36),
 (31, 41),
 (31, 46),
 (31, 51