This notebook explains what the `tf.data` transformations `prefetch` and `cache` do; it does not cover how to apply them in a real input pipeline.

In [1]:
import tensorflow as tf
import time

In [2]:
# Record the TensorFlow version in use; tf.data performance can differ between releases.
tf.version.VERSION

'2.17.0'

# Prefetch

In [10]:
class FileDataset(tf.data.Dataset):
    """Toy dataset that simulates slow file I/O for benchmarking ``tf.data``.

    Calling ``FileDataset(num_samples)`` does NOT produce a ``FileDataset``
    instance: ``__new__`` builds and returns a plain ``tf.data.Dataset`` backed
    by :meth:`read_files_in_batches`, so the class is used as a dataset factory.
    """

    @staticmethod
    def read_files_in_batches(num_samples):
        """Yield ``(index,)`` tuples for ``num_samples`` samples with artificial delays.

        A 0.03 s sleep stands in for opening the file, and a 0.015 s sleep per
        sample stands in for reading one record from it.

        Args:
            num_samples: Number of samples to yield. ``from_generator`` passes
                it in from ``args`` as a NumPy value.
        """
        # Simulate opening the file (paid once per pass over the dataset).
        time.sleep(0.03)
        for sample_ind in range(num_samples):
            # Simulate reading one record from the file.
            time.sleep(0.015)
            yield (sample_ind,)

    def __new__(cls, num_samples=3):
        # from_generator calls the generator afresh each time the dataset is
        # iterated, so every epoch pays the simulated open + read cost again.
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [4]:
# NOTE(review): the saved output of this cell ("new_called") predates the current
# FileDataset, whose __new__ prints nothing; re-run the cell to refresh it.
obj = FileDataset(num_samples=3)

new_called


In [11]:
def benchmark(dataset, num_epochs=2, step_time=0.01):
    """Iterate over ``dataset`` for several epochs, simulating a training step per element.

    Meant to be wrapped in ``%%timeit`` to compare input-pipeline variants: the
    consumer side costs ``step_time`` seconds per element, and any overlap the
    pipeline achieves (e.g. via ``prefetch``) shows up as a shorter total time.

    Args:
        dataset: A re-iterable object (e.g. a ``tf.data.Dataset``); it is
            iterated from the start once per epoch.
        num_epochs: Number of full passes over ``dataset``.
        step_time: Seconds slept per element to stand in for a training step.
            Defaults to 0.01, the value the timings in this notebook used.
    """
    for _ in range(num_epochs):
        for _sample in dataset:
            time.sleep(step_time)  # simulated per-element training step

In [12]:
%%timeit
benchmark(FileDataset())

326 ms ± 33.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
benchmark(FileDataset().prefetch(1))

324 ms ± 67.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

315 ms ± 40.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Cache

In [15]:
# A tiny source dataset: the integers 0..4.
ds = tf.data.Dataset.range(5)
for d in ds.as_numpy_iterator():
    print(d)

0
1
2
3
4


In [16]:
# Square every element. Without `.cache()`, each pass over `dataset` applies
# the map function to every element again.
dataset = ds.map(lambda x: x**2)
for d in dataset.as_numpy_iterator():
    print(d)

0
1
4
9
16


In [18]:
# Up to this point every pass over the mapped dataset re-ran the map function.
# `.cache()` is lazy: nothing is stored when it is called. The first full pass
# (the `list(...)` call below) still runs the map function and fills the
# in-memory cache; later passes read elements from that cache instead.
# Rebuilding from `ds` rather than writing `dataset = dataset.cache()` keeps
# this cell idempotent: re-running it does not stack another `.cache()`.
dataset = ds.map(lambda x: x**2).cache()
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [19]:
# Second pass: elements come from the cache filled by the previous pass.
[*dataset.as_numpy_iterator()]

[0, 1, 4, 9, 16]

In [21]:
def mapped_function(s):
    """Return ``s`` unchanged after a simulated 0.03 s preprocessing delay.

    The sleep is wrapped in ``tf.py_function`` so it runs as Python code each
    time the mapped dataset produces an element, rather than only once while
    ``Dataset.map`` traces this function into a graph.
    """
    tf.py_function(func=lambda: time.sleep(0.03), inp=[], Tout=())
    return s

In [23]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function), 5)

1.27 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [24]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function).cache(), 5)

547 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
