In [24]:
import tensorflow as tf
import time

In [26]:
# Record which TensorFlow release the timings below were measured with.
tf.version.VERSION

'2.17.0'

In [39]:
class FileDataset(tf.data.Dataset):
  """Synthetic dataset that simulates the latency of reading records from a file.

  Subclassing ``tf.data.Dataset`` is only a namespacing convenience here:
  ``__new__`` returns a plain dataset built by ``tf.data.Dataset.from_generator``,
  so ``FileDataset(...)`` never produces a ``FileDataset`` instance and no
  ``__init__`` runs.
  """

  @staticmethod
  def read_file_in_batches(num_samples):
    """Yield ``(sample_idx,)`` for each ``sample_idx`` in ``range(num_samples)``.

    Sleeps 30 ms once before the first element (simulated file open) and
    15 ms before every element (simulated record read). Despite the name,
    elements are produced one at a time, not in batches.

    Declared ``@staticmethod`` because it takes neither ``self`` nor ``cls``;
    ``cls.read_file_in_batches`` in ``__new__`` resolves to the same function.
    """
    time.sleep(0.03)  # simulated cost of opening the file
    for sample_idx in range(num_samples):
      time.sleep(0.015)  # simulated cost of reading one record
      yield (sample_idx,)

  def __new__(cls, num_samples = 3):
    """Return a ``tf.data.Dataset`` of ``num_samples`` int64 tensors of shape ``(1,)``.

    ``from_generator`` calls the generator afresh for every iteration of the
    returned dataset, so each epoch re-pays the simulated open/read latency.
    """
    return tf.data.Dataset.from_generator(
        cls.read_file_in_batches,
        output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
        args = (num_samples,)
    )



In [37]:
def benchmark(dataset, num_epochs = 2):
  """Iterate over ``dataset`` ``num_epochs`` times, pausing 10 ms per element.

  The per-element pause stands in for model work (e.g. a training step), so
  timing this function measures input-pipeline cost plus that fixed cost.
  Returns ``None``.
  """
  simulated_step_seconds = 0.01
  for _ in range(num_epochs):
    for _element in dataset:
      time.sleep(simulated_step_seconds)

In [42]:
%%timeit
benchmark(FileDataset())

272 ms ± 21.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
%%timeit
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

257 ms ± 2.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
%%timeit
benchmark(FileDataset().prefetch(1))

270 ms ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
# ---------------------------- Cache ----------------------------
# Start from a small dataset holding the values 0..4.
data = tf.data.Dataset.range(5)
for value in data.as_numpy_iterator():
  print(value)

0
1
2
3
4


In [49]:
# Square every element. Build from Dataset.range(5) directly rather than
# data.map(...): mapping over `data` itself made this cell non-idempotent, since
# each re-run would square already-squared values (0, 1, 16, 81, 256, ...).
data = tf.data.Dataset.range(5).map(lambda x: x ** 2)
for d in data:
  print(d.numpy())

0
1
4
9
16


In [51]:
# Add an in-memory cache after the map: this first full pass computes the squares
# and stores them, and later passes over `data` read from the cache.
data = data.cache()
for value in data.as_numpy_iterator():
  print(value)

0
1
4
9
16


In [53]:
def mapped_function(s):
  """Return ``s`` unchanged after a ~30 ms blocking pause.

  The pause is wrapped in ``tf.py_function`` so it runs as a real Python call
  even when this function is traced by ``Dataset.map``. It stands in for
  expensive per-element preprocessing.
  """
  def _simulate_heavy_preprocessing():
    time.sleep(0.03)

  tf.py_function(func=_simulate_heavy_preprocessing, inp=[], Tout=())
  return s


In [56]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function), 5)

1.09 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [57]:
%%timeit -n1 -r1
benchmark(FileDataset().map(mapped_function).cache(), 5)

497 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
