<h3 align="center" style='color:blue'>Optimize TensorFlow pipeline performance with prefetch and caching</h3>

In [1]:
import tensorflow as tf
import time

In [2]:
tf.__version__

'2.17.0'

<h3 style='color:purple'>Prefetch</h3>

In [3]:
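class FileDataset(tf.data.Dataset):
    """Simulated file-backed dataset; `time.sleep` stands in for I/O latency."""

    def read_file_in_batches(num_samples):
        # Accessed through the class (no `self`), so it acts as a plain generator function.
        # Simulate opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Simulate reading one record (line) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        # Return a generator-backed Dataset instead of a FileDataset instance
        return tf.data.Dataset.from_generator(
            cls.read_file_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,)
        )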

In [4]:
def benchmark(dataset, num_epochs=2):
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)

In [5]:
%%timeit
benchmark(FileDataset())

291 ms ± 4.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
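`%%timeit` is an IPython cell magic, so it is only available inside a notebook. As a rough sketch of how a comparable measurement could be taken in a plain Python script, the hypothetical helper `time_call` below (not part of the original notebook) repeats a callable and reports the mean and standard deviation of the wall-clock time. The workload here is a stand-in sleep; in the notebook it would be `lambda: benchmark(FileDataset())`.

```python
import statistics
import time


def time_call(fn, repeats=7):
    """Call fn() `repeats` times; return (mean, std dev) of wall-clock seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    # A standard deviation needs at least two measurements
    spread = statistics.stdev(durations) if repeats > 1 else 0.0
    return statistics.mean(durations), spread


# Stand-in workload (assumption); replace with the real benchmark call
mean_s, std_s = time_call(lambda: time.sleep(0.01), repeats=3)
print(f"{mean_s * 1e3:.1f} ms ± {std_s * 1e3:.1f} ms per loop")
```

The default of 7 repeats matches the "7 runs" reported by `%%timeit` above, but the statistics are not guaranteed to be computed in exactly the same way.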


In [6]:
%%timeit
benchmark(FileDataset().prefetch(1))

294 ms ± 9.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

284 ms ± 3.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


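**In this run, prefetching made little measurable difference: 291 ms without prefetch, 294 ms with `prefetch(1)`, and 284 ms with `prefetch(tf.data.AUTOTUNE)`. Prefetching helps when loading the next element can overlap with the current training step, so the gain depends on the machine and the workload.**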
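To put the timings in context, the simulated delays can be added up by hand. The sketch below is a back-of-the-envelope model, not a measurement: it assumes the only costs are the `time.sleep` calls in `FileDataset` and `benchmark`, that each epoch starts a fresh generator (so nothing overlaps across epochs), and that an idealized prefetch fully overlaps reading the next sample with training on the current one. The function names are illustrative.

```python
OPEN_S = 0.03    # sleep before the read loop in read_file_in_batches
READ_S = 0.015   # sleep per sample read
TRAIN_S = 0.01   # sleep per training step in benchmark


def sequential_time(num_samples=3, num_epochs=2):
    """Reading and training strictly alternate, so all delays add up."""
    return num_epochs * (OPEN_S + num_samples * (READ_S + TRAIN_S))


def overlapped_time(num_samples=3, num_epochs=2):
    """Idealized prefetch: the slower of read/train sets the pace after the first read."""
    per_epoch = (
        OPEN_S
        + READ_S                                       # first sample cannot overlap with anything
        + (num_samples - 1) * max(READ_S, TRAIN_S)     # steady state
        + TRAIN_S                                      # last training step runs alone
    )
    return num_epochs * per_epoch


print(f"sequential: {sequential_time() * 1e3:.0f} ms")  # sequential: 210 ms
print(f"overlapped: {overlapped_time() * 1e3:.0f} ms")  # overlapped: 170 ms
```

Under these assumptions the best possible saving is about 40 ms out of 210 ms. The measured values (291 ms, 294 ms, 284 ms) are all above the sequential estimate, which suggests that framework and generator overhead dominates at this tiny scale.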

<h3 style='color:purple'>Cache</h3>

In [8]:
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache("mycache.txt")
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [9]:
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]
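The idea behind caching can be illustrated without TensorFlow. The class below is a minimal pure-Python sketch, not how `tf.data` implements `cache()`: `CachedIterable` is a hypothetical name, and it keeps elements in memory rather than in a file. Like `tf.data`, it only keeps the cache if the first pass reads the data completely; a partial read is discarded.

```python
class CachedIterable:
    """Replay elements from memory after one complete pass over make_iter()."""

    def __init__(self, make_iter):
        self._make_iter = make_iter  # zero-argument callable returning a fresh iterator
        self._cache = None           # filled only after a complete first pass

    def __iter__(self):
        if self._cache is not None:
            yield from self._cache
            return
        seen = []
        for item in self._make_iter():
            seen.append(item)
            yield item
        # Reached only if the consumer exhausted the iterator; after a
        # partial read the cache stays empty and the next pass recomputes.
        self._cache = seen


calls = []


def squares():
    for x in range(5):
        calls.append(x)  # record every time the "expensive" step runs
        yield x ** 2


dataset = CachedIterable(squares)
print(list(dataset))  # [0, 1, 4, 9, 16]
print(list(dataset))  # [0, 1, 4, 9, 16]
print(len(calls))     # 5
```

The second pass returns the same values while `squares` runs only once, which is the effect the two notebook cells above demonstrate with `dataset.cache(...)`.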

In [10]:
def mapped_function(s):
    # Simulate expensive pre-processing with a 30 ms sleep
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

In [11]:
%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function), 5)

1.22 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [12]:
%%timeit -r1 -n1
benchmark(FileDataset().map(mapped_function).cache(), 5)

453 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
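Caching reduces the time from 1.22 s to 453 ms in this run. The simulated delays explain most of that gap. The sketch below is an estimate under stated assumptions, not a measurement: it counts only the `time.sleep` calls, treats read, map and training as strictly sequential, and assumes that after the first epoch a cached dataset skips both the file read and the mapped function. The function names are illustrative.

```python
OPEN_S = 0.03    # sleep before the read loop in read_file_in_batches
READ_S = 0.015   # sleep per sample read
MAP_S = 0.03     # sleep inside mapped_function
TRAIN_S = 0.01   # sleep per training step in benchmark


def uncached_time(num_samples=3, num_epochs=5):
    """Every epoch pays for opening, reading, mapping and training."""
    return num_epochs * (OPEN_S + num_samples * (READ_S + MAP_S + TRAIN_S))


def cached_time(num_samples=3, num_epochs=5):
    """Only the first epoch pays for open/read/map; later epochs only train."""
    first_epoch = OPEN_S + num_samples * (READ_S + MAP_S + TRAIN_S)
    later_epochs = (num_epochs - 1) * num_samples * TRAIN_S
    return first_epoch + later_epochs


print(f"without cache: {uncached_time() * 1e3:.0f} ms")  # without cache: 975 ms
print(f"with cache:    {cached_time() * 1e3:.0f} ms")    # with cache:    315 ms
```

The estimated ratio (about 3x) is close to the measured one (1.22 s vs. 453 ms, about 2.7x); the remainder is overhead that this simple model ignores.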


**Further reading:** https://www.tensorflow.org/guide/data_performance#caching