# tf.data Dataset for Neural Machine Translation from json-lines

In [1]:
import tensorflow as tf
import json
from collections import Counter

This notebook shows a simple example of how to use the tf.data API to build an input pipeline that reads a json-lines file for model training.

> Note: For simplicity, this notebook uses a functional programming style. It often makes sense to use an object-oriented style instead.

The dataset contains English-German sentence pairs. Each line in the data looks like this:

In [2]:
data_file = 'data/jsonl/data_train.jsonl'

with open(data_file) as f:
    line = next(f)
    print(json.dumps(json.loads(line), indent=4))

{
    "src": [
        "Two",
        "young",
        ",",
        "White",
        "males",
        "are",
        "outside",
        "near",
        "many",
        "bushes",
        "."
    ],
    "tgt": [
        "Zwei",
        "junge",
        "wei\u00dfe",
        "M\u00e4nner",
        "sind",
        "im",
        "Freien",
        "in",
        "der",
        "N\u00e4he",
        "vieler",
        "B\u00fcsche",
        "."
    ]
}


## Creating the Stoi (string-to-integer) vocabs

#### Native Python dicts

We first create the lookup tables that map words to integers. In practice, this should be done once ahead of training and stored to a file. That way, building the vocabularies does not slow down dataset preparation, which matters especially for larger datasets.

In [3]:
src_vocab = Counter()
tgt_vocab = Counter()

with open(data_file, 'r') as f:
    for line in f:
        obj = json.loads(line)
        # update() adds the counts in place; `+= Counter(...)` also rescans the whole counter after every line
        src_vocab.update(obj['src'])
        tgt_vocab.update(obj['tgt'])

print("Most common tokens in source (English):")
print([w for w, _ in src_vocab.most_common(5)])

print("Most common tokens in target (German):")
print([w for w, _ in tgt_vocab.most_common(5)])

Most common tokens in source (English):
['a', '.', 'A', 'in', 'the']
Most common tokens in target (German):
['.', 'Ein', ',', 'einem', 'mit']


Following best practices for language datasets, we reserve indices for padding, unknown tokens, and start- and end-of-sequence tokens. This is why we only start enumerating the vocabs at 4:

    <PAD> = 0  # Padding index
    <UNK> = 1  # Index for unknown tokens
    <SOS> = 2  # Start of sequence index
    <EOS> = 3  # End of sequence index
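
Below is a minimal sketch of this numbering on a made-up toy corpus. The four special tokens occupy indices 0 to 3, and the words, ordered by frequency, start at 4. Unlike the notebook's dicts, the hypothetical `build_vocab` helper also stores the special tokens as keys. The inverse `itos` mapping is an illustrative addition for decoding.

```python
from collections import Counter

# Reserved indices, matching the table above
SPECIALS = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']

def build_vocab(sentences):
    """Hypothetical helper: map tokens to integers, reserving 0-3 for special tokens."""
    counts = Counter()
    for tokens in sentences:
        counts.update(tokens)
    stoi = {tok: i for i, tok in enumerate(SPECIALS)}
    # The most frequent words get the smallest indices, starting at len(SPECIALS) == 4
    for idx, (word, _) in enumerate(counts.most_common(), len(SPECIALS)):
        stoi[word] = idx
    return stoi

# Made-up toy corpus
stoi = build_vocab([['a', 'dog', '.'], ['a', 'cat', '.'], ['a', 'bird']])
itos = {i: tok for tok, i in stoi.items()}  # inverse mapping for decoding

# Unknown words fall back to the <UNK> index, like default_value in the TF table
ids = [stoi.get(tok, stoi['<UNK>']) for tok in ['a', 'fish']]
```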

In [4]:
src_vocab_stoi = {word: c for c, (word, _) in enumerate(src_vocab.most_common(len(src_vocab)), 4)}
tgt_vocab_stoi = {word: c for c, (word, _) in enumerate(tgt_vocab.most_common(len(tgt_vocab)), 4)}

src_vocab_stoi['one'], tgt_vocab_stoi['Nähe']

(66, 87)

In [5]:
print("Length of the src vocab: {:,}\nLength of the tgt vocab: {:,}".format(len(src_vocab_stoi), len(tgt_vocab_stoi)))

Length of the src vocab: 4,921
Length of the tgt vocab: 6,818
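
As recommended above, the vocabularies are best built once ahead of training and stored to a file. Here is a minimal sketch using JSON. The file name `src_vocab.json` and the helpers `save_vocab` and `load_vocab` are hypothetical.

```python
import json
import os
import tempfile

def save_vocab(stoi, path):
    """Hypothetical helper: write a string-to-integer vocab to a JSON file."""
    with open(path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps umlauts such as 'Nähe' readable in the file
        json.dump(stoi, f, ensure_ascii=False)

def load_vocab(path):
    """Hypothetical helper: read a vocab written by save_vocab (int values stay ints)."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)

vocab = {'a': 4, 'Nähe': 5}  # toy vocab
path = os.path.join(tempfile.mkdtemp(), 'src_vocab.json')
save_vocab(vocab, path)
restored = load_vocab(path)
```

In a training script, the dict returned by `load_vocab` could then be passed straight to `tf_vocab_from_dict` below.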


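#### TensorFlow lookups

We use TensorFlow's `tf.lookup.StaticHashTable` to create a lookup table that operates on tensors. It is essentially a dict for TensorFlow that can also be used inside a tf.data pipeline. Keys that are not in the table are mapped to `default_value`, here the `<UNK>` index 1.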

In [6]:
def tf_vocab_from_dict(vocab, oov_index=1):
    return tf.lookup.StaticHashTable(
                initializer=tf.lookup.KeyValueTensorInitializer(
                    keys=tf.constant(list(vocab.keys())),
                    values=tf.constant(list(vocab.values()))),
                default_value=tf.constant(oov_index))

tf_src_vocab_stoi = tf_vocab_from_dict(src_vocab_stoi)
tf_tgt_vocab_stoi = tf_vocab_from_dict(tgt_vocab_stoi)

In [7]:
tf_src_vocab_stoi.lookup(tf.constant(['one', 'two']))  # One is index 66, two is index 77

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([66, 77], dtype=int32)>

In [8]:
tf_tgt_vocab_stoi.lookup(tf.constant(['Nähe', 'Distanz']))  # Nähe is index 87, Distanz is not in Vocab

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([87,  1], dtype=int32)>
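
The same pattern can be inverted to turn predicted indices back into words. This sketch assumes an inverse vocab `itos` and int64 keys. Here `itos` holds only toy values based on the most common German tokens printed above; in the notebook it could be built as `{i: w for w, i in tgt_vocab_stoi.items()}`. Indices missing from the table, including the special tokens, come back as `'<UNK>'`.

```python
import tensorflow as tf

# Toy inverse vocab (index -> word)
itos = {4: '.', 5: 'Ein', 6: ','}

tf_itos = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(itos.keys()), dtype=tf.int64),
        values=tf.constant(list(itos.values()))),
    default_value=tf.constant('<UNK>'))

# Lookup keys must match the table's key dtype, so cast the int32 ids to int64
ids = tf.constant([5, 6, 99], dtype=tf.int32)
words = tf_itos.lookup(tf.cast(ids, tf.int64))

# Join the tokens back into one space-separated string
sentence = tf.strings.reduce_join(words, separator=' ')
```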

## Step-by-step: Creating the tf.data.Dataset

Because the data is stored as json-lines, we first have to read and decode it in Python before feeding it into tf.data. The generator function below reads the data file and yields one `(source string, target string)` tuple per line.

In [9]:
def json_generator(file_name):
    with open(file_name, 'r') as f:
        for line in f:
            obj = json.loads(line)
            yield " ".join(obj['src']), " ".join(obj['tgt'])

In [10]:
gen = json_generator(data_file)

In [11]:
next(iter(gen))

('Two young , White males are outside near many bushes .',
 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche .')

#### Step 1: Reading from Generator

In [12]:
dataset_1 = tf.data.Dataset.from_generator(
    lambda: map(tuple, json_generator(data_file)), 
    output_types=(tf.string, tf.string))

In [13]:
next(iter(dataset_1))

(<tf.Tensor: shape=(), dtype=string, numpy=b'Two young , White males are outside near many bushes .'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'Zwei junge wei\xc3\x9fe M\xc3\xa4nner sind im Freien in der N\xc3\xa4he vieler B\xc3\xbcsche .'>)
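
In TF 2.4 and later, the `output_types` argument used above is deprecated in favor of `output_signature`, which also declares shapes. The sketch below shows the same call with an in-memory generator standing in for `json_generator`. The `toy_pairs` data are made up. Because the generator already yields tuples, the `map(tuple, ...)` wrapper is not needed either.

```python
import tensorflow as tf

# Made-up sentence pairs standing in for the json-lines file
toy_pairs = [
    ('Two dogs run .', 'Zwei Hunde rennen .'),
    ('A man sleeps .', 'Ein Mann schläft .'),
]

def pair_generator():
    """Yields (source string, target string) tuples, like json_generator."""
    yield from toy_pairs

dataset = tf.data.Dataset.from_generator(
    pair_generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),  # source sentence
        tf.TensorSpec(shape=(), dtype=tf.string),  # target sentence
    ))

first_src, first_tgt = next(iter(dataset))
```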

#### Step 2: Preprocessing the strings

In [14]:
src_stoi = tf_src_vocab_stoi
tgt_stoi = tf_tgt_vocab_stoi
sos_index = tf.constant([2])
eos_index = tf.constant([3])

def preprocess_function(src, tgt):
    # Breaking string into tokens
    src = tf.strings.split(tf.expand_dims(src, axis=0), sep=' ')[0]
    tgt = tf.strings.split(tf.expand_dims(tgt, axis=0), sep=' ')[0]
    
    # Converting strings to ints
    src = src_stoi.lookup(src)
    tgt = tgt_stoi.lookup(tgt)
    
    # Adding SOS and EOS
    src = tf.concat([sos_index, src, eos_index], axis=0)
    tgt = tf.concat([sos_index, tgt, eos_index], axis=0)
    
    # Lengths of the src and tgt sequences (including SOS and EOS), as 1-element vectors
    src_len = tf.expand_dims(tf.shape(src)[0], axis=0)
    tgt_len = tf.expand_dims(tf.shape(tgt)[0], axis=0)

    return src, tgt, src_len, tgt_len

In [15]:
dataset_2 = dataset_1.map(preprocess_function)

In [16]:
next(iter(dataset_2))

(<tf.Tensor: shape=(13,), dtype=int32, numpy=
 array([   2,   18,   27,   15,  833,  699,   16,   58,   78,  389, 1042,
           5,    3], dtype=int32)>,
 <tf.Tensor: shape=(15,), dtype=int32, numpy=
 array([   2,   22,  118,  201,   35,  100,   21,   86,    9,   14,   87,
        1841, 2729,    4,    3], dtype=int32)>,
 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([13], dtype=int32)>,
 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([15], dtype=int32)>)
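
For a single sentence, `preprocess_function` does roughly what the following plain-Python sketch does. The toy `stoi` dict and the `preprocess_py` name are made up; 1, 2 and 3 are the `<UNK>`, `<SOS>` and `<EOS>` indices from above. The returned length includes the two added tokens. This is why the first source sentence, which has 11 tokens, has length 13 in the output above.

```python
UNK, SOS, EOS = 1, 2, 3

def preprocess_py(sentence, stoi):
    """Illustrative plain-Python mirror of preprocess_function for one sentence."""
    tokens = sentence.split(' ')                   # tf.strings.split(..., sep=' ')
    ids = [stoi.get(tok, UNK) for tok in tokens]   # table.lookup with default_value=1
    ids = [SOS] + ids + [EOS]                      # tf.concat([sos_index, ids, eos_index])
    return ids, [len(ids)]                         # length kept as a 1-element list

stoi = {'a': 4, '.': 5, 'dog': 6}  # toy vocab
ids, length = preprocess_py('a dog barks .', stoi)
```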

#### Step 3: Batching and Padding

In [17]:
batch_size = 4

In [18]:
dataset_3 = dataset_2.padded_batch(batch_size, padded_shapes=([None], [None], [1], [1]))

In [19]:
next(iter(dataset_3))

(<tf.Tensor: shape=(4, 17), dtype=int32, numpy=
 array([[   2,   18,   27,   15,  833,  699,   16,   58,   78,  389, 1042,
            5,    3,    0,    0,    0,    0],
        [   2,  184,   38,    7,  307,  298,   16,  760,    4,  761, 1465,
         1863,    5,    3,    0,    0,    0],
        [   2,    6,   52,   30,  172,   61,    4,  198, 2575,    5,    3,
            0,    0,    0,    0,    0,    0],
        [   2,    6,   10,    7,    4,   36,   25,   11,   35,    9,    4,
          526,  605,    4,  199,    5,    3]], dtype=int32)>,
 <tf.Tensor: shape=(4, 17), dtype=int32, numpy=
 array([[   2,   22,  118,  201,   35,  100,   21,   86,    9,   14,   87,
         1841, 2729,    4,    3,    0,    0],
        [   2,  107,   35,    8,  651, 1141,   16, 2730,    4,    3,    0,
            0,    0,    0,    0,    0,    0],
        [   2,    5,   72,   26,  174,    9,   16, 2731,   55,  395,    4,
            3,    0,    0,    0,    0,    0],
        [   2,    5,   12,    9,    7,   
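
`padded_batch` pads every component to the longest element *within each batch*. By default it uses 0, the `<PAD>` index, as the fill value, which is why different batches can have different widths. Here is a small sketch with made-up id sequences:

```python
import tensorflow as tf

# Made-up id sequences of different lengths (2 = <SOS>, 3 = <EOS>)
sequences = [[2, 5, 3], [2, 5, 6, 7, 3], [2, 3], [2, 8, 3]]

def seq_generator():
    yield from sequences

ds = tf.data.Dataset.from_generator(
    seq_generator,
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32))

# padded_shapes=[None] pads each batch only up to its own longest sequence
batches = [b.numpy().tolist() for b in ds.padded_batch(2, padded_shapes=[None])]
```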

## Putting it all together

### Simple dataloader

In [20]:
batch_size = 16

In [21]:
dataset = tf.data.Dataset.from_generator(
                lambda: map(tuple, json_generator(data_file)), 
                output_types=(tf.string,)*2)\
            .map(preprocess_function)\
            .padded_batch(batch_size, padded_shapes=([None], [None], [1], [1]))

In [22]:
next(iter(dataset))

(<tf.Tensor: shape=(16, 19), dtype=int32, numpy=
 array([[   2,   18,   27,   15,  833,  699,   16,   58,   78,  389, 1042,
            5,    3,    0,    0,    0,    0,    0,    0],
        [   2,  184,   38,    7,  307,  298,   16,  760,    4,  761, 1465,
         1863,    5,    3,    0,    0,    0,    0,    0],
        [   2,    6,   52,   30,  172,   61,    4,  198, 2575,    5,    3,
            0,    0,    0,    0,    0,    0,    0,    0],
        [   2,    6,   10,    7,    4,   36,   25,   11,   35,    9,    4,
          526,  605,    4,  199,    5,    3,    0,    0],
        [   2,   18,   38,   16,   20,    8,  762,  324,  132,    5,    3,
            0,    0,    0,    0,    0,    0,    0,    0],
        [   2,    6,   10,    7,   53,  135,    4,  105,   29,    8,   90,
           10, 1466,   33,   25,    5,    3,    0,    0],
        [   2,    6,   10,   11,  126,   20,    4,  763, 1864,    3,    0,
            0,    0,    0,    0,    0,    0,    0,    0],
        [   2,    6,

### Adding shuffling, prefetching, caching, and parallel mapping

In [23]:
dataset = tf.data.Dataset.from_generator(
                lambda: map(tuple, json_generator(data_file)), 
                output_types=(tf.string,)*2)\
            .map(preprocess_function, num_parallel_calls=4)\
            .cache()\
            .shuffle(buffer_size=10000)\
            .padded_batch(4, padded_shapes=([None], [None], [1], [1]))\
            .prefetch(4)

In [24]:
it = iter(dataset)

In [25]:
next(it)

(<tf.Tensor: shape=(4, 16), dtype=int32, numpy=
 array([[   2,   18,  236,   60,  342,   54,  680,    5,    3,    0,    0,
            0,    0,    0,    0,    0],
        [   2,   18,  857,  705,  121,   95,    4, 1616,  176, 2944,    7,
            4,  978,  176,    5,    3],
        [   2,    6,   50,  143,  521,  178,    4, 4536,  726,  490,    5,
            3,    0,    0,    0,    0],
        [   2,   46,  882,  218,   16,    7,    8, 1075,   13,    4, 1683,
           53,  165,    5,    3,    0]], dtype=int32)>,
 <tf.Tensor: shape=(4, 15), dtype=int32, numpy=
 array([[   2,   22,  198,   64, 2164,   37,  106,    4,    3,    0,    0,
            0,    0,    0,    0],
        [   2,   22,  807,  300,   96,   13, 1561,  819,    9,    7, 2086,
          190,  179,    4,    3],
        [   2,    5,   42,   71,  713,   51,   19, 6052,  676,    4,    3,
            0,    0,    0,    0],
        [   2,  124,  924, 1837,   21,  486,   45, 6622,   94, 2121,    4,
            3,    0,    0,
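
Instead of hard-coding `num_parallel_calls=4` and `prefetch(4)`, recent TF 2.x releases let tf.data tune both at runtime with `tf.data.AUTOTUNE`; older releases expose it as `tf.data.experimental.AUTOTUNE`. The sketch below uses a toy `range` dataset in place of the sentence pairs. It keeps the notebook's ordering of `cache()` before `shuffle()`: the map runs only during the first full pass, and each epoch still gets a fresh shuffle order.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

ds = (tf.data.Dataset.range(10)
      .map(lambda x: x * 2, num_parallel_calls=AUTOTUNE)  # parallelism chosen at runtime
      .cache()                                           # mapped elements cached in memory
      .shuffle(buffer_size=10)                           # reshuffled on every iteration
      .batch(4)
      .prefetch(AUTOTUNE))                               # prefetch buffer chosen at runtime

epoch_1 = [b.numpy().tolist() for b in ds]
epoch_2 = [b.numpy().tolist() for b in ds]
```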

### Adding steps to create batches of sequences with similar lengths

In [26]:
def sort_batch(src, tgt, src_len, tgt_len):
    """ Takes a large batch and sorts it in descending order of src length """
    sorting_order = tf.argsort(tf.squeeze(src_len, axis=-1), direction='DESCENDING')
    src = tf.gather(src, sorting_order, axis=0)
    tgt = tf.gather(tgt, sorting_order, axis=0)
    src_len = tf.gather(src_len, sorting_order, axis=0)
    # tgt_len must be reordered too, otherwise cut_length pairs each tgt with the wrong length
    tgt_len = tf.gather(tgt_len, sorting_order, axis=0)
    return src, tgt, src_len, tgt_len

def cut_length(src, tgt, src_len, tgt_len):
    """ Cuts off the padding from a padded sequence """
    src = src[:tf.squeeze(src_len)]
    tgt = tgt[:tf.squeeze(tgt_len)]
    return src, tgt, src_len, tgt_len

In [27]:
dataset = tf.data.Dataset.from_generator(
                lambda: map(tuple, json_generator(data_file)), 
                output_types=(tf.string,)*2)\
            .map(preprocess_function, num_parallel_calls=4)\
            .cache()\
            .shuffle(buffer_size=10000)

# Create buffer to be sorted and pad to max length
dataset = dataset.padded_batch(10000, padded_shapes=([None], [None], [1], [1]))

# Sort buffer by length
dataset = dataset.map(
    sort_batch,
    num_parallel_calls=4)

# Resolve buffer and cut each element to its original length
dataset = dataset.unbatch()
dataset = dataset.map(
    cut_length,
    num_parallel_calls=4)

dataset = dataset\
            .padded_batch(batch_size, padded_shapes=([None], [None], [1], [1]))\
            .prefetch(4)

In [28]:
it = iter(dataset)

In [29]:
next(it)

(<tf.Tensor: shape=(16, 40), dtype=int32, numpy=
 array([[   2,    6,   10,    7,    4,  742,  308,   82,   76,   19,  596,
          675,   19,  269,    4,  742, 3424,   43, 1371,   22,   15,   14,
            4, 3425,    7,    8,  996,   15,   12,   90,  742, 3426,   12,
          742,  954,    7,    8,  113,    5,    3],
        [   2,   59,   27,  400,   16,   32,   92,   15,   14,    8,   30,
         1929,   19, 1080,   66,   13,    8,  391,    7,    8,  168,   29,
          335,   15,   12,    8,  194,   95,  140,  111,  326,  277,  100,
            7, 2665,    5,    3,    0,    0,    0],
        [   2,    6,   17,    7,    4,   53,   25,   12,  215, 1482,    9,
            4, 1924,   14,  433,    7,  142,   15,   49,  606,  337,   49,
           12,    4,   55,  543,   95,   49,   15,   73,  364, 2660,   75,
         1925,    5,    3,    0,    0,    0,    0],
        [   2,    6,  124,  154,    7,    4,   53,  308,   14,    4,   64,
            7,   11,  159,   11,  208,  620, 
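
The sort-within-a-buffer trick can be illustrated in plain Python by tracking only sequence lengths. In this simplified sketch, the shuffle, the large `padded_batch`, `sort_batch`, and the final re-batching each become a list operation. `sorted_batches` and `padding_count` are hypothetical helpers, and the lengths are made up. The baseline batches the same shuffled order without sorting.

```python
import random

def sorted_batches(lengths, sort_buffer_size, batch_size, seed=0):
    """Hypothetical helper: return batches of example indices with similar lengths."""
    indices = list(range(len(lengths)))
    random.Random(seed).shuffle(indices)                      # .shuffle(...)
    batches = []
    for start in range(0, len(indices), sort_buffer_size):
        buffer = indices[start:start + sort_buffer_size]      # .padded_batch(sort_buffer_size)
        buffer.sort(key=lambda i: lengths[i], reverse=True)   # sort_batch
        for b in range(0, len(buffer), batch_size):           # .unbatch().padded_batch(batch_size)
            batches.append(buffer[b:b + batch_size])
    return batches

def padding_count(batches, lengths):
    """Number of <PAD> entries when each batch is padded to its longest sequence."""
    return sum(max(lengths[i] for i in batch) * len(batch) - sum(lengths[i] for i in batch)
               for batch in batches)

rng = random.Random(42)
lengths = [rng.randint(5, 40) for _ in range(960)]  # made-up sequence lengths
batch_size = 16

bucketed = sorted_batches(lengths, sort_buffer_size=160, batch_size=batch_size)

# Baseline: the same shuffled order, batched without sorting
order = list(range(len(lengths)))
random.Random(0).shuffle(order)
unsorted = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

print('padding with sorting:', padding_count(bucketed, lengths),
      'without sorting:', padding_count(unsorted, lengths))
```

One side effect is visible in the output above. Without a batch-level shuffle after re-batching, the batches leave each buffer in descending length order. This is why the first printed batch is padded to length 40. The CSV pipeline below shuffles the batches afterwards with `.shuffle(shuffle_buffer_size // batch_size)`, which avoids this.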

## Extended dataset from sharded CSVs with batch-length bucketing

This dataloader reads directly from sharded CSV files. The files can also live in a GCS bucket and be passed as a `gs://` path, provided the machine has read access to the bucket. In addition, the dataset groups sequences of the same source length together. This minimizes the number of padding elements and the computation spent on them.
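
To try the sharded-CSV reading without the notebook's data files, the sketch below writes two shards to a temporary directory. It then reads them back with `list_files`, `interleave`, and `decode_csv`. The shard contents are invented. Each field is quoted so that sentences can contain commas, which `tf.io.decode_csv` handles by default.

```python
import os
import tempfile
import tensorflow as tf

# Write two made-up CSV shards, each with a header row
tmp_dir = tempfile.mkdtemp()
shards = {
    'train_0.csv': [('A dog runs .', 'Ein Hund rennt .')],
    'train_1.csv': [('Two men , talking .', 'Zwei Männer , redend .'),
                    ('A cat sleeps .', 'Eine Katze schläft .')],
}
for name, rows in shards.items():
    with open(os.path.join(tmp_dir, name), 'w', encoding='utf-8') as f:
        f.write('src,tgt\n')
        for src, tgt in rows:
            f.write('"{}","{}"\n'.format(src, tgt))

def decode_line(line):
    """Split one CSV line into (source, target) strings; both fields are required."""
    fields = tf.io.decode_csv(line, record_defaults=[tf.constant([], dtype=tf.string)] * 2)
    return tuple(fields)

dataset = tf.data.Dataset.list_files(os.path.join(tmp_dir, 'train_*.csv'), shuffle=False)
dataset = dataset.interleave(
    lambda fp: tf.data.TextLineDataset(fp).skip(1),  # skip each file's header row
    cycle_length=2,
    num_parallel_calls=2)
dataset = dataset.map(decode_line)

pairs = sorted((s.numpy().decode('utf-8'), t.numpy().decode('utf-8')) for s, t in dataset)
```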

In [30]:
# Pattern to find files. TensorFlow will automatically pick up all files that match the pattern.
file_pattern = 'data/csv/train_*.csv'

batch_size = 8                    # Final batch size
parallel_processes = 4            # Number of parallel workers
shuffle_buffer_size = 10000       # Buffer size for shuffling
length_sort_buffer_size = 10000   # Buffer size for sorting into equal length
prefetch_length = 4               # Number of batches to pre-fetch
skip_header_rows = 1              # Number of header rows to skip in each file

In [31]:
def decode_csv(inp):
    """ Splits a csv line into its (source, target) string fields """
    fields = tf.io.decode_csv(records=inp,
                              record_defaults=[tf.constant([], dtype=tf.string)] * 2,
                              field_delim=',')
    return fields

In [40]:
# Get all the files that match the pattern
dataset = tf.data.Dataset.list_files(file_pattern, seed=42)

# Reading from csv and interleaving the files
# Note: list_files already shuffles the file names by default (seeded above), so the shards are read in random order
dataset = dataset.interleave(
    lambda fp: tf.data.TextLineDataset(fp).skip(skip_header_rows),
    cycle_length=4,
    num_parallel_calls=4)

# Decode csv input
dataset = dataset.map(decode_csv, num_parallel_calls=parallel_processes)

# Same preprocessing as before
dataset = dataset\
            .map(preprocess_function, num_parallel_calls=parallel_processes)\
            .shuffle(buffer_size=shuffle_buffer_size)

# Create buffer to be sorted and pad to max length
dataset = dataset.padded_batch(length_sort_buffer_size, 
                               padded_shapes=([None], [None], [1], [1]))

# Sort buffer by length
dataset = dataset.map(
    sort_batch,
    num_parallel_calls=parallel_processes)

# Resolve buffer and cut each element to its original length
dataset = dataset.unbatch()
dataset = dataset.map(
    cut_length,
    num_parallel_calls=parallel_processes)

# Batch to desired train-batch size and prefetch
dataset = dataset\
            .padded_batch(batch_size, padded_shapes=([None], [None], [1], [1]))\
            .shuffle(shuffle_buffer_size // batch_size)\
            .prefetch(prefetch_length)

In [41]:
it = iter(dataset)

Source lengths (`src_len`, including `<SOS>` and `<EOS>`) of the sequences in the next 20 batches:

In [62]:
for i in range(20):
    src, tgt, src_len, tgt_len = next(it)
    print("Batch {: >2d}:   ".format(i) + "  ".join(["{: >2d}".format(x[0]) for x in src_len.numpy()]))

Batch  0:   15  15  15  15  15  15  15  15
Batch  1:   13  13  13  13  13  13  13  13
Batch  2:   13  13  13  13  13  13  13  13
Batch  3:   24  24  24  24  24  24  24  24
Batch  4:   21  21  21  21  21  21  21  21
Batch  5:   21  21  21  21  21  21  21  21
Batch  6:   10  10  10  10  10  10  10  10
Batch  7:   21  21  21  21  21  21  21  21
Batch  8:   16  16  16  16  16  16  16  16
Batch  9:   17  17  17  17  17  17  17  17
Batch 10:   16  16  16  16  16  16  16  16
Batch 11:   10  10  10  10  10  10  10  10
Batch 12:    9   9   9   9   9   9   9   9
Batch 13:    8   8   8   8   8   8   8   8
Batch 14:   10  10  10  10  10  10  10  10
Batch 15:   10  10  10  10  10  10  10  10
Batch 16:   14  14  14  14  14  14  14  14
Batch 17:   13  13  13  13  13  13  13  13
Batch 18:   18  18  18  18  18  18  18  18
Batch 19:   16  16  16  16  16  16  16  16


In [63]:
next(it)

(<tf.Tensor: shape=(8, 15), dtype=int32, numpy=
 array([[   2,   18,   23,   16,  121,   12,   66,   13,  180,   11,  234,
            4,  217,    5,    3],
        [   2,    6,  544,   12, 1195,   35,    7,   41,   13,    4,  561,
          246,  123,    5,    3],
        [   2,   91,   32,    7,    4,  176,  161,   12,   66,   11,  177,
            8,  286,    5,    3],
        [   2,  281,   23,   15,   21,  181,   36,  259,   15, 1271,    4,
         1429,  328,    5,    3],
        [   2,   18,   23,  153,    7,   41,   13, 1004,   12,    4,   55,
          181,  199,    5,    3],
        [   2,    6, 1598,  797,  128,   33, 2902,    7,   41,   13,    4,
           40,   13,   23,    3],
        [   2,    6,   10,   12,    4,   17,  186,  502,   16,   44,   39,
            8,  107,    5,    3],
        [   2,    6,   55,   26,   12,   24,   22,   11,   72,   57,    4,
          214,   83,    5,    3]], dtype=int32)>,
 <tf.Tensor: shape=(8, 19), dtype=int32, numpy=
 array([[   2,  