<a href="https://colab.research.google.com/github/SauravMaheshkar/trax/blob/SauravMaheshkar-example-1/examples/trax_data_Explained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
## Install the Latest Version of Trax
!pip install --upgrade trax

Notebook Author: [@SauravMaheshkar](https://github.com/SauravMaheshkar)

# Introduction

In [None]:
import trax

# Serial Fn

In Trax, we use combinators to build input pipelines, much like building deep learning models. The `Serial` combinator chains generator functions using function composition: the stream returned by one function becomes the input of the next.

Trax has the following definition for a `Serial` combinator.

```
def Serial(*fns):
  def composed_fns(generator=None):
    for f in fastmath.tree_flatten(fns):
      generator = f(generator)
    return generator
  return composed_fns
```

The `Serial` function has the following structure:

* It takes an arbitrary number of functions as **input**
* It flattens that structure into a list
* It iterates through the list and applies the functions serially

---

The [`fastmath.tree_flatten()`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L195) function takes a tree as input and returns a flattened list. This way we can use various generator functions like `Tokenize` and `Shuffle`, and apply them serially by *iterating* through the list.

`generator` starts as `None`, so the first function in the list is called with no input stream and must produce one itself (for example, a dataset reader such as `TFDS`). On each later iteration, `generator` is rebound to the stream returned by the current function, which the next function then consumes.
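The composition described above can be tried without Trax. The following is a minimal sketch, not the library's code: `serial`, `source`, `double` and `drop_small` are hypothetical names, and the flattening step is omitted because the functions are passed as a flat argument list.

```python
def serial(*fns):
    """Simplified stand-in for trax.data.Serial: chain generator functions."""
    def composed(generator=None):
        for f in fns:  # the library flattens nested structures first
            generator = f(generator)
        return generator
    return composed


def source(_=None):
    # First stage: ignores its (None) input and produces the stream.
    yield from [1, 2, 3, 4]


def double(stream):
    # Later stage: consumes the previous stream and yields transformed items.
    for x in stream:
        yield 2 * x


def drop_small(stream):
    # Keeps only items greater than 2.
    for x in stream:
        if x > 2:
            yield x


pipeline = serial(source, double, drop_small)
result = list(pipeline())
print(result)  # [4, 6, 8]
```

Nothing is computed until the returned generator is iterated, which is why `pipeline()` is cheap to call and the work happens inside `list(...)`.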


# Log Function

```
def Log(n_steps_per_example=1, only_shapes=True):
  def log(stream):
    counter = 0
    for example in stream:
      item_to_log = example
      if only_shapes:
        item_to_log = fastmath.nested_map(shapes.signature, example)
      if counter % n_steps_per_example == 0:
        logging.info(str(item_to_log))
        print(item_to_log)
      counter += 1
      yield example
  return log
```

Every deep learning framework needs a logging component for efficient debugging.

The `trax.data.Log` generator uses the `absl` package for logging. It uses the [`fastmath.nested_map`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L80) function, which applies a given function recursively to every element of a nested object. Here it maps `shapes.signature` over each example in the input stream, giving us the shapes of the various objects in the stream.
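To make the recursive mapping concrete, here is a small sketch that assumes only NumPy. `nested_map` below is a simplified re-implementation and `describe` is a hypothetical stand-in for `shapes.signature`; neither is the Trax code.

```python
import numpy as np


def nested_map(f, obj):
    """Apply f to every leaf of a nested tuple/list/dict (simplified stand-in)."""
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_map(f, x) for x in obj)
    if isinstance(obj, dict):
        return {k: nested_map(f, v) for k, v in obj.items()}
    return f(obj)


def describe(x):
    # Hypothetical helper: report shape and dtype instead of the values.
    x = np.asarray(x)
    return (x.shape, str(x.dtype))


# An example shaped like a tokenized review paired with a scalar label.
example = (np.zeros((203,), dtype=np.int64), np.int64(1))
summary = nested_map(describe, example)
print(summary)  # (((203,), 'int64'), ((), 'int64'))
```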

---

The following two cells show the difference between setting the `only_shapes` argument to `False` and to `True`.

In [None]:
data_pipeline = trax.data.Serial(
    trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),
    trax.data.Log(only_shapes=False)
  )
example = data_pipeline()
print(next(example))

(array([ 182,   31,   43, 5981,   67, 6322,  243, 3898,   22,    8, 2138,
          2,   36,   47,   66,  597,  300,   10,   34, 3986, 2613,   64,
       5281, 2367,    2,   46, 1902, 4713, 2942, 3461,    8, 4797,   55,
       1466, 1351,  409,    3,  121,  114, 1622, 5622,   66,  124, 4106,
         47, 1972,   10,  536,    8, 4533,    2,  124, 1466, 3207,   93,
        449,   90,  407, 4860,   76,  114, 3898,   22,   36,    6, 2339,
       5160,  275, 2395, 6293,  181,    8,  182, 3898,   22,   25,   43,
        402, 4423,  794,  995, 3040, 2420, 2128,    2, 5116,    2,    8,
         28,  180, 3166, 3171, 3839,   44,   80,  668,  232,    4, 1743,
       3661,  239, 3082, 4076,   80, 2067,  124, 2700,   35, 3854, 1052,
        221,    8, 6149, 5481, 4607,   12,  547, 2942,   75, 4445, 3054,
         29,    3,    7,  245, 5372, 1135,   75,   14, 3304,    2, 4935,
       1197,   39, 5281, 2367,    2,   31, 5032, 2528,  121,   12, 3166,
       3171, 5888, 5403,    2, 2305,   93,   10,  

In [None]:
data_pipeline = trax.data.Serial(
    trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),
    trax.data.Log(only_shapes=True)
  )
example = data_pipeline()
print(next(example))

(ShapeDtype{shape:(203,), dtype:int64}, ShapeDtype{shape:(), dtype:int64})
(array([ 182,   31,   43, 5981,   67, 6322,  243, 3898,   22,    8, 2138,
          2,   36,   47,   66,  597,  300,   10,   34, 3986, 2613,   64,
       5281, 2367,    2,   46, 1902, 4713, 2942, 3461,    8, 4797,   55,
       1466, 1351,  409,    3,  121,  114, 1622, 5622,   66,  124, 4106,
         47, 1972,   10,  536,    8, 4533,    2,  124, 1466, 3207,   93,
        449,   90,  407, 4860,   76,  114, 3898,   22,   36,    6, 2339,
       5160,  275, 2395, 6293,  181,    8,  182, 3898,   22,   25,   43,
        402, 4423,  794,  995, 3040, 2420, 2128,    2, 5116,    2,    8,
         28,  180, 3166, 3171, 3839,   44,   80,  668,  232,    4, 1743,
       3661,  239, 3082, 4076,   80, 2067,  124, 2700,   35, 3854, 1052,
        221,    8, 6149, 5481, 4607,   12,  547, 2942,   75, 4445, 3054,
         29,    3,    7,  245, 5372, 1135,   75,   14, 3304,    2, 4935,
       1197,   39, 5281, 2367,    2,   31, 5032,

# Shuffling our datasets

Trax offers two functions for adding shuffling to our input pipelines.

1. The `shuffle` function shuffles a given stream directly
2. The `Shuffle` function returns a function that applies `shuffle` with a fixed `queue_size`, for use inside `Serial`

## `shuffle`

```
def shuffle(samples, queue_size):
  if queue_size < 1:
    raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.')
  if queue_size == 1:
    logging.warning('Queue size of 1 results in no shuffling.')
  queue = []
  try:
    # Fill the queue with the first queue_size samples.
    for _ in range(queue_size):
      queue.append(next(samples))
    # Yield a random queued sample, then refill its slot from the stream.
    for sample in samples:
      i = np.random.randint(queue_size)
      yield queue[i]
      queue[i] = sample
  except StopIteration:
    logging.warning(
        'Not enough samples (%d) to fill initial queue (size %d).',
        len(queue), queue_size)
  # Input exhausted: shuffle and drain whatever remains in the queue.
  np.random.shuffle(queue)
  for sample in queue:
    yield sample
```


The `shuffle` function takes two inputs: the data stream and the queue size, i.e. the number of samples held in memory among which the shuffling takes place. It raises an error for queue sizes below 1 and logs a warning for a queue size of 1. Otherwise it fills a queue with the first `queue_size` samples, then repeatedly picks a random slot with [`np.random.randint()`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.randint.html), yields the sample stored there, and refills the slot from the stream. Once the stream is exhausted, the remaining queue is shuffled with [`np.random.shuffle()`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.shuffle.html) and drained.
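The queue-based idea can be sketched with the standard library alone. This is an illustration under assumptions rather than the Trax implementation: `streaming_shuffle` is a hypothetical name, and the `seed` parameter is added here only to make the run repeatable.

```python
import random


def streaming_shuffle(samples, queue_size, seed=0):
    """Queue-based shuffle sketch; `seed` is an added, hypothetical parameter."""
    if queue_size < 1:
        raise ValueError(f'queue_size ({queue_size}) is less than 1.')
    rng = random.Random(seed)
    samples = iter(samples)
    queue = []
    # Fill the queue; stop early if the stream is shorter than queue_size.
    for sample in samples:
        queue.append(sample)
        if len(queue) == queue_size:
            break
    # Emit a randomly chosen queued sample and put the new one in its slot.
    for sample in samples:
        i = rng.randrange(len(queue))
        yield queue[i]
        queue[i] = sample
    # Input exhausted: shuffle and drain what is left.
    rng.shuffle(queue)
    yield from queue


shuffled = list(streaming_shuffle(range(10), queue_size=4))
print(sorted(shuffled) == list(range(10)))  # True: every item appears exactly once
```

Only `queue_size` items are held in memory at a time, so a larger queue mixes the stream more thoroughly at the cost of more memory.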

In [None]:
sentence = ['Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem. Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur, vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?',
            'But I must explain to you how all this mistaken idea of denouncing pleasure and praising pain was born and I will give you a complete account of the system, and expound the actual teachings of the great explorer of the truth, the master-builder of human happiness. No one rejects, dislikes, or avoids pleasure itself, because it is pleasure, but because those who do not know how to pursue pleasure rationally encounter consequences that are extremely painful. Nor again is there anyone who loves or pursues or desires to obtain pain of itself, because it is pain, but because occasionally circumstances occur in which toil and pain can procure him some great pleasure. To take a trivial example, which of us ever undertakes laborious physical exercise, except to obtain some advantage from it? But who has any right to find fault with a man who chooses to enjoy a pleasure that has no annoying consequences, or one who avoids a pain that produces no resultant pleasure?',
            'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',
            'At vero eos et accusamus et iusto odio dignissimos ducimus qui blanditiis praesentium voluptatum deleniti atque corrupti quos dolores et quas molestias excepturi sint occaecati cupiditate non provident, similique sunt in culpa qui officia deserunt mollitia animi, id est laborum et dolorum fuga. Et harum quidem rerum facilis est et expedita distinctio. Nam libero tempore, cum soluta nobis est eligendi optio cumque nihil impedit quo minus id quod maxime placeat facere possimus, omnis voluptas assumenda est, omnis dolor repellendus. Temporibus autem quibusdam et aut officiis debitis aut rerum necessitatibus saepe eveniet ut et voluptates repudiandae sint et molestiae non recusandae. Itaque earum rerum hic tenetur a sapiente delectus, ut aut reiciendis voluptatibus maiores alias consequatur aut perferendis doloribus asperiores repellat.']

def sample_generator(x):
  for i in x:
    yield i

example_shuffle = list(trax.data.inputs.shuffle(sample_generator(sentence), queue_size = 2))
example_shuffle

['Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem. Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur, vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?',
 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut e

## `Shuffle`

```
def Shuffle(queue_size=1024):
  return lambda g: shuffle(g, queue_size)
```

This function returns a function that applies the aforementioned `shuffle` with the given `queue_size`; it is mostly used in input pipelines.


# Batch Generators

## `batch`

This function groups examples from the input generator into batches.

```
def batch(generator, batch_size):
  if batch_size <= 0:
    raise ValueError(f'Batch size must be positive, but is {batch_size}.')
  buf = []
  for example in generator:
    buf.append(example)
    if len(buf) == batch_size:
      batched_example = tuple(np.stack(x) for x in zip(*buf))
      yield batched_example
      buf = []
```

It keeps appending examples from the generator to a list until the list reaches `batch_size`, then stacks them into a batch using the `np.stack()` function.

It also raises an error for non-positive batch sizes.
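A quick way to see the effect of the buffering is to run the same logic on toy data. The sketch below reproduces the loop from the snippet above (without the size check) so that it runs standalone with NumPy; the example data is made up for illustration.

```python
import numpy as np


def batch(generator, batch_size):
    """Same buffering logic as the snippet above, reproduced to run standalone."""
    buf = []
    for example in generator:
        buf.append(example)
        if len(buf) == batch_size:
            # zip(*buf) regroups by position: all inputs together, all labels together.
            yield tuple(np.stack(x) for x in zip(*buf))
            buf = []


# Seven (input, label) pairs; each input is a length-3 vector.
examples = [(np.full(3, i), np.array(i % 2)) for i in range(7)]
batches = list(batch(iter(examples), batch_size=3))

print(len(batches))         # 2
print(batches[0][0].shape)  # (3, 3)
print(batches[0][1].shape)  # (3,)
```

Note that the seventh example never fills a batch, so it is not emitted: with this logic a trailing partial batch is dropped.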


## `Batch`

```
def Batch(batch_size):
  return lambda g: batch(g, batch_size)
```

This function returns a function that applies the aforementioned `batch` with the given batch size.

# Pad to Maximum Dimensions

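This function pads a tuple of tensors to a common shape and returns them stacked as a single batch.

For example, in the cell below each row of the input, `(1, 2)` and `((3, 4), (5, 6))`, has two entries. With `boundary=3`, both rows are padded with a trailing `0` to three entries, giving `(1, 2, 0)` and `((3, 4), (5, 6), 0)`.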

In [None]:
import numpy as np

# dtype=object is needed for ragged input on recent NumPy versions.
tensors = np.array([(1., 2.),
                    ((3., 4.), (5., 6.))], dtype=object)
padded_tensors = trax.data.inputs.pad_to_max_dims(tensors=tensors, boundary=3)
padded_tensors

array([[1.0, 2.0, 0],
       [(3.0, 4.0), (5.0, 6.0), 0]], dtype=object)
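For the more common case of variable-length token sequences, the padding idea can be sketched directly with NumPy. `pad_to_length` is a hypothetical helper written for this illustration; it handles only 1-D arrays and does not reproduce the boundary logic of `pad_to_max_dims`.

```python
import numpy as np


def pad_to_length(arrays, length, pad_value=0):
    """Right-pad 1-D arrays to `length` and stack them (hypothetical helper)."""
    padded = [
        np.pad(a, (0, length - len(a)), constant_values=pad_value)
        for a in arrays
    ]
    return np.stack(padded)


sequences = [np.array([1, 2]), np.array([3, 4, 5]), np.array([6])]
batch = pad_to_length(sequences, length=4)
print(batch)
# [[1 2 0 0]
#  [3 4 5 0]
#  [6 0 0 0]]
```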

# Creating Buckets

When training recurrent neural networks with a large vocabulary, a method called bucketing is usually applied.

The usual technique of padding ensures that all sequences within a mini-batch are of the same length. But this reduces the inter-batch variability and intuitively puts similar sentences into the same batch, thereby reducing the overall robustness of the system.

Thus, we use bucketing: multiple buckets are created according to sentence length, and each sentence is assigned to the bucket that corresponds to its length. We need to ensure that the buckets are large enough to add some variability to the system.

## `bucket_by_length`



```
def bucket_by_length(generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False):
  buckets = [[] for _ in range(len(batch_sizes))]
  boundaries = boundaries + [math.inf]
  for example in generator:
    length = length_fn(example)
    bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b])
    buckets[bucket_idx].append(example)
    if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]:
      batched = zip(*buckets[bucket_idx])
      boundary = boundaries[bucket_idx]
      boundary = None if boundary == math.inf else boundary
      padded_batch = tuple(
          pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched)
      yield padded_batch
      buckets[bucket_idx] = []
```

---

This function can be summarised as:

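* Create one empty bucket per entry in the `batch_sizes` list

* Assign each example to the first bucket whose boundary is at least the example's length; examples longer than every boundary go to a final catch-all bucket

* When a bucket reaches its batch size, pad its examples with the `pad_to_max_dims` function and yield them as one batch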
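The bucket selection step can be isolated into a few lines. The sketch below mirrors the `min(...)` expression from the snippet above; `bucket_index` is a hypothetical helper name, and the boundaries and batch sizes are the ones used in the pipelines in this notebook.

```python
import math


def bucket_index(length, boundaries):
    """Index of the first bucket whose upper boundary is >= length."""
    bounds = list(boundaries) + [math.inf]  # catch-all bucket for long examples
    return min(i for i, b in enumerate(bounds) if length <= b)


boundaries = [32, 128, 512, 2048]
batch_sizes = [512, 128, 32, 8, 1]  # one more entry than boundaries

for length in (10, 32, 33, 203, 2048, 5000):
    idx = bucket_index(length, boundaries)
    print(length, "-> bucket", idx, "with batch size", batch_sizes[idx])
```

Short examples land in buckets with large batch sizes and long examples in buckets with small ones, which keeps the number of tokens per batch roughly comparable.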

---

### Parameters

1. **generator:** The input generator
2. **length_fn:** A custom function for determining the length of an example, not necessarily `len()`
3. **boundaries:** A Python list of bucket boundaries (the maximum length for each bucket)
4. **batch_sizes:** A Python list of batch sizes, one per bucket; it has one more entry than `boundaries`, for the catch-all bucket
5. **strict_pad_on_len:** A Python boolean (`True` or `False`). If set to `True`, the function pads on the length dimension so that dim[0] is strictly a multiple of the boundary.
 

## `BucketByLength`

```
def BucketByLength(boundaries, batch_sizes, length_keys=None, length_axis=0, strict_pad_on_len=False):
  length_keys = length_keys or [0, 1]
  length_fn = lambda x: _length_fn(x, length_axis, length_keys)
  return lambda g: bucket_by_length(g, length_fn, boundaries, batch_sizes, strict_pad_on_len)
```

---

This function is usually used inside input pipelines (*combinators*) and wraps the aforementioned `bucket_by_length`. It applies a predefined `length_fn` that takes the maximum shape along `length_axis` over `length_keys`.

Its use is illustrated below.

In [None]:
data_pipeline = trax.data.Serial(
    trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),
    trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.Log(only_shapes=True)
  )
example = data_pipeline()
print(next(example))

(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64})
(array([[ 155,  452,   29, ...,    0,    0,    0],
       [ 182, 1989, 1826, ...,    0,    0,    0],
       [1389, 2597, 5378, ...,    0,    0,    0],
       ...,
       [4846, 1008,    2, ...,    0,    0,    0],
       [  68,   12,  173, ...,    0,    0,    0],
       [ 186, 3817, 2064, ...,    0,    0,    0]]), array([0, 1, 1, 1, 1, 0, 1, 0]))


# Filter by Length

```
def FilterByLength(max_length, length_keys=None, length_axis=0):
  length_keys = length_keys or [0, 1]
  length_fn = lambda x: _length_fn(x, length_axis, length_keys)
  def filtered(gen):
    for example in gen:
      if length_fn(example) <= max_length:
        yield example
  return filtered
```

---

This function uses the same predefined `length_fn` to keep only those examples whose length is at most the given `max_length`.
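The filtering step is easy to try on plain Python lists. In this sketch, `filter_by_length` is a hypothetical simplification that uses `len` as the length function instead of Trax's `_length_fn`.

```python
def filter_by_length(max_length, length_fn=len):
    """Return a pipeline stage that keeps examples with length <= max_length."""
    def filtered(stream):
        for example in stream:
            if length_fn(example) <= max_length:
                yield example
    return filtered


stream = iter([[1, 2, 3], [4], [5, 6, 7, 8, 9], [10, 11]])
kept = list(filter_by_length(max_length=3)(stream))
print(kept)  # [[1, 2, 3], [4], [10, 11]]
```

An example whose length equals `max_length` is kept, matching the `<=` comparison in the library snippet.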


In [None]:
Filtered = trax.data.Serial(
    trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),
    # Filter individual examples before batching, so that length is
    # measured per example rather than per batch.
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.Log(only_shapes=True)
  )
filtered_example = Filtered()
print(next(filtered_example))

(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64})
(array([[ 155,  452,   29, ...,    0,    0,    0],
       [ 182, 1989, 1826, ...,    0,    0,    0],
       [1389, 2597, 5378, ...,    0,    0,    0],
       ...,
       [4846, 1008,    2, ...,    0,    0,    0],
       [  68,   12,  173, ...,    0,    0,    0],
       [ 186, 3817, 2064, ...,    0,    0,    0]]), array([0, 1, 1, 1, 1, 0, 1, 0]))


# Adding Loss Weights

## `add_loss_weights`

```
def add_loss_weights(generator, id_to_mask=None):
  for example in generator:
    if len(example) > 3 or len(example) < 2:
      assert id_to_mask is None, 'Cannot automatically mask this stream.'
      yield example
    else:
      if len(example) == 2:
        weights = np.ones_like(example[1]).astype(np.float32)
      else:
        weights = example[2].astype(np.float32)
      mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32)
      weights *= mask
      yield (example[0], example[1], weights)
```

---

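This function appends a tensor of loss weights to each example in the input stream. The weights start as ones with the same shape as the targets (or as the existing third element, if the example already has one), and every position where the target equals `id_to_mask` is set to zero. With `id_to_mask=None`, nothing is masked and the weights are left unchanged.

**Masking** is essentially a way to tell sequence-processing layers that certain timesteps in an input are missing, and thus should be skipped when processing the data.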
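The weight computation for a two-element example can be reproduced with NumPy alone. The target values and the choice of `0` as the padding id below are assumptions made for this illustration.

```python
import numpy as np

PAD_ID = 0  # assumed padding id for this illustration

# Two target sequences, right-padded with PAD_ID.
targets = np.array([[5, 7, 2, 0, 0],
                    [3, 9, 0, 0, 0]])

# Start from all-ones weights, then zero out positions equal to PAD_ID.
weights = np.ones_like(targets).astype(np.float32)
mask = 1.0 - np.equal(targets, PAD_ID).astype(np.float32)
weights *= mask

print(weights)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 0. 0. 0.]]
```

Positions with weight 0 contribute nothing to a weighted loss, so the model is not penalized for its predictions on padding.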

---

### Parameters

1. **generator:** The input data generator
2. **id_to_mask:** The id to mask out: positions where the target equals this id get weight 0. In NLP this is typically the id of the `<PAD>` token.

```
train_generator = trax.data.inputs.add_loss_weights(
    data_generator(batch_size, x_train, y_train, vocab['<PAD>'], True),
    id_to_mask=vocab['<PAD>'])
```

For example, in this case I used the `add_loss_weights` function to mask out padding while implementing Named Entity Recognition using the Reformer architecture. You can read more about the project [here](https://www.kaggle.com/sauravmaheshkar/trax-ner-using-reformer).

## `AddLossWeights`

This function returns a pipeline stage that applies the aforementioned `add_loss_weights` to the data stream.

```
def AddLossWeights(id_to_mask=None):
  return lambda g: add_loss_weights(g, id_to_mask=id_to_mask)
```


In [None]:
data_pipeline = trax.data.Serial(
    trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights(),
    trax.data.Log(only_shapes=True)
  )

example = data_pipeline()
print(next(example))

(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64}, ShapeDtype{shape:(8,), dtype:float32})
(array([[4176,  570,  636, ...,    0,    0,    0],
       [3030,    2,    7, ...,    0,    0,    0],
       [  28, 3898,   22, ...,    0,    0,    0],
       ...,
       [ 139,   36,   76, ...,    0,    0,    0],
       [2275,    2, 4198, ...,    0,    0,    0],
       [ 182,  103,  151, ...,    0,    0,    0]]), array([0, 1, 1, 0, 0, 0, 1, 0]), array([1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))
