# Cache and prefetch
Demonstration of how they influence the training time.

In [1]:
import tensorflow as tf
from time import sleep


OPEN_FILE_TIME = 0.03
READ_SAMPLE_TIME = 0.015
TRAIN_STEP_TIME = 0.01
SAMPLES = 3
EPOCHS = 2


class FileDataset(tf.data.Dataset):
    
    def read_files_in_batches(samples_number):
        # Open file
        sleep(OPEN_FILE_TIME)
        for sample_idx in range(samples_number):
            sleep(READ_SAMPLE_TIME)
            yield (sample_idx,)

    def __new__(cls, samples_number=SAMPLES):
        print("FileDataset created.")
        dataset = tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(samples_number,)
        )

        return dataset

def train(dataset, epochs=EPOCHS):
    for epoch_i in range(epochs):
        for sample in dataset:
            sleep(TRAIN_STEP_TIME)

In [2]:
dataset = FileDataset()

FileDataset created.


##### Without `.prefetch()`:

# ![title](NoPrefetchTimeline.png)

In [3]:
# Estimation if everything runing in CPU
estimated_time = (
    # Time for a single epoch
    (
        OPEN_FILE_TIME 
        + READ_SAMPLE_TIME * SAMPLES
        + TRAIN_STEP_TIME * SAMPLES
    )
    # Times the number of epochs
    * EPOCHS
)
print(f"Estimated time: {estimated_time:.5f} seconds")

Estimated time: 0.21000 seconds


In [4]:
%%timeit
train(dataset)

354 ms ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Required times is higher than estimated time because we only counted the sleep times and completely ignored the processing times for the rest of operations.

##### With `.prefetch()`:

# ![title](PrefetchTimeline.png)

In [5]:
# Estimation
estimated_time = (
    # Time for a single epoch
    (
        OPEN_FILE_TIME 
        + READ_SAMPLE_TIME * SAMPLES
        + TRAIN_STEP_TIME * SAMPLES
    )
    # Times the number of epochs
    * EPOCHS
)
print(f"Estimated time: {estimated_time:.5f} seconds")

Estimated time: 0.21000 seconds


In [6]:
%%timeit
train(dataset.prefetch(buffer_size=1)) # tf.data.AUTOTUNE

294 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%%timeit
train(dataset.prefetch(buffer_size=2)) # tf.data.AUTOTUNE

287 ms ± 16.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
train(dataset.prefetch(buffer_size=tf.data.AUTOTUNE))

293 ms ± 34.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


##### Without `cache()`:

# ![title](NoCacheTimeline.png)

In [9]:
%%timeit
train(dataset)

317 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


##### With `cache()`:

# ![title](CacheTimeline.png)

In [10]:
# Estimation
estimated_time = (
    # Opening and reading only need to be done once
    OPEN_FILE_TIME 
    + READ_SAMPLE_TIME * SAMPLES
    # Training 
    + (TRAIN_STEP_TIME * SAMPLES) * EPOCHS

    
)
print(f"Estimated time: {estimated_time:.5f} seconds")

Estimated time: 0.13500 seconds


In [11]:
%%timeit
train(dataset.cache())

233 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
