<h1 style="text-align:center;color:mediumvioletred;">Optimize TensorFlow Pipeline Performance</h1>

In [1]:
import tensorflow as tf
import time

In [2]:
# TensorFlow version used for the timings recorded below.
tf.version.VERSION

'2.19.0'

In [26]:
class FileDataset(tf.data.Dataset):
    """Synthetic dataset that simulates slow file I/O with ``time.sleep``.

    Note: ``FileDataset(num_samples)`` does not produce a ``FileDataset``
    instance. ``__new__`` returns the plain dataset built by
    ``tf.data.Dataset.from_generator``, whose elements are int64 tensors of
    shape ``(1,)`` holding ``0 .. num_samples - 1``.
    """

    @staticmethod
    def read_files_in_batches(num_samples):
        """Yield ``(sample_idx,)`` for each sample, sleeping to mimic I/O.

        Args:
            num_samples: Number of samples to yield.

        Yields:
            One-element tuples ``(sample_idx,)``, matching the ``(1,)`` int64
            ``output_signature`` declared in ``__new__``.
        """
        time.sleep(0.03)  # simulated file-open latency, paid once per generator run
        for sample_idx in range(num_samples):
            time.sleep(0.015)  # simulated per-sample read latency
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Return a generator-backed ``tf.data.Dataset`` of ``num_samples`` elements."""
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [27]:
def benchmark(dataset, num_epochs=2, step_time=0.01):
    """Consume ``dataset`` for ``num_epochs`` epochs, sleeping per element.

    The per-element sleep stands in for a training step, so the total time
    reflects how well the input pipeline overlaps with the consumer.

    Args:
        dataset: Any re-iterable (here a ``tf.data.Dataset``); it is iterated
            once per epoch.
        num_epochs: Number of full passes over ``dataset``.
        step_time: Seconds slept per element (default 10 ms, as before).

    Returns:
        None. Time the call externally, e.g. with ``%%timeit``.
    """
    for _epoch in range(num_epochs):
        for _sample in dataset:
            time.sleep(step_time)

## Using sequential processing

In [56]:
%%timeit
benchmark(FileDataset())

295 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Using prefetching (overlapping data reads with the training step)

In [57]:
%%timeit
benchmark(FileDataset().prefetch(1))

295 ms ± 20.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [58]:
%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

284 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


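#### *Prefetching gave little speedup in this run (about 295 ms vs. 284 ms). This is not caused by the missing GPU — `prefetch` overlaps producing and consuming elements using CPU background threads. More likely, with only 3 samples per epoch, the file-open delay and iterator setup dominate, leaving little work to overlap.*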

In [63]:
dataset = tf.data.Dataset.range(5)
# as_numpy_iterator yields NumPy scalars directly, so no .numpy() call is needed.
for value in dataset.as_numpy_iterator():
    print(value)

0
1
2
3
4


In [64]:
# Rebuild from the source range rather than mapping the previous `dataset`,
# so re-running this cell does not square already-squared values.
dataset = tf.data.Dataset.range(5).map(lambda x: x**2)
for d in dataset:
    print(d.numpy())

0
1
4
9
16


In [66]:
# Wrap the squared dataset in an in-memory cache; the first full pass fills it.
dataset = dataset.cache()
for value in dataset.as_numpy_iterator():
    print(value)

0
1
4
9
16


In [67]:
# A second pass over the cached dataset, collected into a list for display.
cached_values = list(dataset.as_numpy_iterator())
cached_values

[0, 1, 4, 9, 16]

In [78]:
def mapped_function(s):
    """Return ``s`` unchanged after simulating 30 ms of preprocessing.

    The sleep is wrapped in ``tf.py_function`` so it runs as a Python call
    each time the op executes, rather than only once while ``Dataset.map``
    traces this function.
    """
    def _simulate_preprocessing():
        time.sleep(0.03)

    tf.py_function(_simulate_preprocessing, [], ())
    return s

In [79]:
# Build (without iterating) the mapped pipeline to inspect its element_spec.
FileDataset().map(map_func=mapped_function)

<_MapDataset element_spec=TensorSpec(shape=(1,), dtype=tf.int64, name=None)>

## Without caching

In [81]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function), 5)

1.15 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


## With `.cache()` after `.map()`

In [82]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function).cache(), 5)

446 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
