In [1]:
import tensorflow as tf
import time

In [2]:
# Show the TensorFlow version the timings in this notebook were recorded with.
tf.__version__

'2.12.0'

In [17]:
class FileDataset(tf.data.Dataset):
  """Synthetic dataset that mimics the latency of reading samples from a file.

  Calling `FileDataset(...)` does not create a `FileDataset` instance:
  `__new__` returns the dataset produced by `tf.data.Dataset.from_generator`,
  so the class acts as a factory for a generator-backed dataset.
  """

  @staticmethod
  def read_files_in_batches(num_samples):
    """Yield `num_samples` one-element tuples, sleeping to simulate file I/O.

    Declared as a static method because it takes neither `self` nor `cls`
    and is only ever looked up through the class.

    Args:
      num_samples: number of samples to generate.

    Yields:
      `(sample_idx,)` for each `sample_idx` in `range(num_samples)`.
    """
    # open file
    time.sleep(0.03)
    for sample_idx in range(num_samples):
      # Simulated cost of reading a single sample.
      time.sleep(0.015)
      yield (sample_idx, )

  def __new__(cls, num_samples=3):
    """Build the generator-backed dataset.

    Args:
      num_samples: forwarded to `read_files_in_batches` through `args`.

    Returns:
      The dataset created by `tf.data.Dataset.from_generator`, whose elements
      are declared as int64 tensors of shape `(1,)`.
    """
    return tf.data.Dataset.from_generator(
        cls.read_files_in_batches,
        output_signature=tf.TensorSpec(shape = (1, ), dtype = tf.int64),
        args=(num_samples,)
    )




In [18]:
def benchmark(dataset, num_epochs=2):
  """Walk through `dataset` once per epoch, pausing 10 ms on every element.

  Args:
    dataset: any iterable; it is iterated from the start on each epoch.
    num_epochs: how many full passes to make over `dataset`.
  """
  for _epoch in range(num_epochs):
    for _element in dataset:
      # Fixed delay standing in for the work done per element.
      time.sleep(0.01)

In [19]:
%%timeit
# Baseline: default FileDataset (3 samples), default 2 epochs, no prefetch/map/cache.
benchmark(FileDataset())

316 ms ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
%%timeit
# Same run as the baseline, with prefetch(1) appended to the pipeline.
benchmark(FileDataset().prefetch(1))

332 ms ± 54.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# Same run again, passing tf.data.AUTOTUNE as the prefetch buffer size instead of a fixed 1.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

292 ms ± 30.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
# Small reference dataset (integers 0..4) reused by the map/cache cells below.
dataset = tf.data.Dataset.range(5)
for d in dataset:
  # .numpy() converts each element to a NumPy value so it prints as a plain number.
  print(d.numpy())

0
1
2
3
4


In [25]:
# Square every element of `dataset`.
# NOTE(review): `df` is a tf.data dataset, not a DataFrame; the name is easy to misread.
df = dataset.map(lambda x: x**2)

In [26]:
# Print each element of `df` as a NumPy value.
for d in df:
  print(d.numpy())

0
1
4
9
16


In [27]:
# NOTE(review): `df` is rebound here to a cached view of the original `dataset`
# (0..4), so it no longer refers to the squared dataset from the earlier cell.
df = dataset.cache()
# as_numpy_iterator() yields NumPy values directly, so no .numpy() call is needed.
for d in df.as_numpy_iterator():
  print(d)

0
1
2
3
4


In [28]:
# Materialise all elements of `df` into a Python list.
list(df.as_numpy_iterator())

[0, 1, 2, 3, 4]

In [30]:
def mapped_function(s):
  """Pass `s` through unchanged after a simulated 30 ms preprocessing step.

  The delay runs inside `tf.py_function`, called with no inputs and an empty
  output spec; its result is not used.

  Args:
    s: the dataset element being mapped.

  Returns:
    `s`, unmodified.
  """
  def _simulate_preprocessing():
    # Sleep stands in for real per-element work; returns None like the sleep call.
    return time.sleep(0.03)

  tf.py_function(_simulate_preprocessing, [], ())
  return s

In [31]:
# Build the mapped dataset and display its repr to inspect the resulting element_spec.
FileDataset().map(mapped_function)

<_MapDataset element_spec=TensorSpec(shape=(1,), dtype=tf.int64, name=None)>

In [32]:
%%timeit -n1 -r1
# 5 epochs over the mapped dataset without caching; timed once (-n1 -r1).
benchmark(FileDataset().map(mapped_function), 5)

2.42 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [33]:
%%timeit -n1 -r1
# Same 5-epoch run with .cache() appended after .map().
# Presumably later epochs replay cached elements instead of re-running the
# generator and mapped_function; the lower time recorded below fits that, but confirm.
benchmark(FileDataset().map(mapped_function).cache(), 5)

1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
