In [0]:
%tensorflow_version 2.x
import tensorflow as tf
import time

Making reproducible performance benchmarks can be difficult because several factors affect the results:
- the current CPU load,
- the network traffic,
- complex mechanisms like cache, etc.
Hence, to provide a reproducible benchmark, build an artificial example.


Define a class inheriting from `tf.data.Dataset` called `ArtificialDataset`.
This dataset:
- Generates `num_samples` samples (default is 3)
- Sleeps for some time before the first item to simulate opening a file
- Sleeps for some time before producing each item to simulate reading data from a file

In [0]:
class ArtificialDataset(tf.data.Dataset):
  def _generator(num_samples):
    # Opening the file
    time.sleep(0.03)

    for sample_idx in range(num_samples):
      # Reading data (line, record) from the file
      time.sleep(0.015)

      yield(sample_idx, )
  
  def __new__(cls, num_samples=3):
    return tf.data.Dataset.from_generator(
        cls._generator,
        output_types=tf.dtypes.int64,
        output_shapes=(1, ),
        args=(num_samples, ))

In [0]:
def benchmark(dataset, num_epochs=2):
  """Run a dummy training loop and print how long iterating over `dataset` takes.
  The training step is simulated with a short sleep."""
  start_time = time.perf_counter()
  for epoch_num in range(num_epochs):
    for sample in dataset:
      # Performing a training step
      time.sleep(0.01)
  tf.print('Execution time:', time.perf_counter() - start_time)

## Optimize performance
To exhibit how performance can be optimized, you will improve the performance of the `ArtificialDataset`.

### The naive approach
Start with a naive pipeline using no tricks, iterating over the dataset as-is.

In [0]:
benchmark(ArtificialDataset())
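
As a rough sanity check on the number printed above, the sleeps in `ArtificialDataset` and `benchmark` can be added up by hand. The sketch below is only an idealized estimate: it assumes nothing overlaps and ignores all framework overhead, so the measured time should come out somewhat higher. The helper name `naive_lower_bound` is hypothetical and not part of `tf.data`.

```python
# Timings copied from the sleeps in ArtificialDataset and benchmark.
OPEN_TIME = 0.03    # simulated file open, paid once per epoch
READ_TIME = 0.015   # simulated read, paid once per sample
TRAIN_TIME = 0.01   # simulated training step, paid once per sample


def naive_lower_bound(num_samples=3, num_epochs=2):
    """Idealized execution time when no work overlaps (hypothetical helper)."""
    per_epoch = OPEN_TIME + num_samples * (READ_TIME + TRAIN_TIME)
    return num_epochs * per_epoch


print(f"Estimated lower bound: {naive_lower_bound():.3f} s")
# prints "Estimated lower bound: 0.210 s"
```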

### Prefetching
Prefetching overlaps the preprocessing and model execution of a training step. While the model is executing training step `s`, the input pipeline is reading the data for step `s+1`. Doing so reduces the step time to the maximum (as opposed to the sum) of the training time and the time it takes to extract the data.

The `tf.data` API provides the `tf.data.Dataset.prefetch` transformation. It can be used to decouple the time when data is produced from the time when data is consumed. In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested. The number of elements to prefetch should be equal to (or possibly greater than) the number of batches consumed by a single training step. You could either manually tune this value, or set it to `tf.data.experimental.AUTOTUNE`, which prompts the `tf.data` runtime to tune the value dynamically.

Note that the prefetch transformation provides benefits any time there is an opportunity to overlap the work of a "producer" with the work of a "consumer".
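
The "maximum instead of sum" effect can be illustrated with a small timeline model in plain Python. This is a sketch under simplifying assumptions, not how `tf.data` is implemented: the producer is assumed never to block on a full buffer, the file-open cost is left out, and there is no scheduling overhead. The function name `simulate_steps` and its parameters are hypothetical.

```python
def simulate_steps(num_steps, extract_time, train_time, prefetch):
    """Return the time at which the last training step finishes.

    Idealized model: with prefetching, a background producer extracts
    elements back to back; without it, extraction of the next element
    starts only after the previous training step has finished.
    """
    ready = 0.0  # time at which the most recent element became available
    done = 0.0   # time at which the most recent training step finished
    for _ in range(num_steps):
        if prefetch:
            ready = ready + extract_time  # producer keeps running in the background
        else:
            ready = done + extract_time   # producer waits for the consumer
        done = max(ready, done) + train_time
    return done


# Per-sample timings taken from the sleeps in ArtificialDataset and benchmark.
sequential = simulate_steps(3, extract_time=0.015, train_time=0.01, prefetch=False)
overlapped = simulate_steps(3, extract_time=0.015, train_time=0.01, prefetch=True)
print(f"without prefetch: {sequential:.3f} s, with prefetch: {overlapped:.3f} s")
# prints "without prefetch: 0.075 s, with prefetch: 0.055 s"
```

In this model the sequential pipeline pays `extract_time + train_time` per step, while the overlapped one approaches `max(extract_time, train_time)` per step once the first element has been extracted.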

In [0]:
benchmark(
    ArtificialDataset()
    .prefetch(tf.data.experimental.AUTOTUNE))

### Parallelizing data extraction
In a real-world setting, the input data may be stored remotely (for example, GCS or HDFS). A dataset pipeline that works well when reading data locally might become bottlenecked on I/O when reading data remotely because of the following differences between local and remote storage:
- **Time-to-first-byte**: Reading the first byte of a file from remote storage can take orders of magnitude longer than from local storage.
- **Read throughput**: While remote storage typically offers large aggregate bandwidth, reading a single file might only be able to utilize a small fraction of this bandwidth.


### Sequential interleave
The default arguments of the `tf.data.Dataset.interleave` transformation make it interleave single samples from two datasets sequentially.

In [0]:
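# Fetching samples alternately from the two datasets available.
# No performance improvement is involved here.
# The lambda ignores the range element so that it is not passed to
# `ArtificialDataset` as `num_samples`.
benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())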
)
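
The element order produced by a sequential interleave of single samples can be modeled in plain Python. The sketch below only illustrates the ordering described above, assuming two sources of three samples each; it says nothing about how `tf.data` schedules the work. The names `sequential_interleave` and `tagged_samples` are hypothetical helpers, and each sample is tagged with its source index so the alternation is visible.

```python
from itertools import zip_longest

_EXHAUSTED = object()  # sentinel marking a source that has run out


def sequential_interleave(*sources):
    """Yield one element from each source in turn until all are exhausted."""
    for group in zip_longest(*sources, fillvalue=_EXHAUSTED):
        for item in group:
            if item is not _EXHAUSTED:
                yield item


def tagged_samples(source_idx, num_samples=3):
    """Stand-in for one dataset: (source index, sample index) pairs."""
    return [(source_idx, sample_idx) for sample_idx in range(num_samples)]


order = list(sequential_interleave(tagged_samples(0), tagged_samples(1)))
print(order)
# prints [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
```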

### Parallel interleave

In [0]:
# Now use the `num_parallel_calls` argument of the `interleave` transformation. This loads
# multiple datasets in parallel, reducing the time waiting for the files to be opened.
benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        # Ignore the range element so each dataset keeps the default `num_samples`.
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)