In [1]:
import tensorflow as tf
from tensorflow.data import Dataset

import numpy as np

from time import sleep, time

# Dataset
TensorFlow uses `Dataset` as its data-loading interface. Parallelization and preprocessing are all handled via map functions. For more details, read the [documentation](https://www.tensorflow.org/guide/data#reading_input_data).

In PyTorch we could choose between an iterable-style (generator) dataset and a map-style dataset. In TensorFlow everything is assumed to be a generator dataset, which makes the map concept natural.

In [2]:
class MyDataset(Dataset):
    """Toy generator-backed dataset that yields ``(d, -d)`` sample pairs.

    Each sample ``d`` is a float32 array of shape ``(1, 1, 1)`` holding one
    value ``i`` from ``range(start, stop, step)``; the second element of the
    pair is its negation. Producing every sample is delayed by 0.5 seconds to
    simulate data preparation.

    Instantiating the class does not return a ``MyDataset`` instance:
    ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator``.
    """

    @staticmethod
    def _generator(start, step, stop=44):
        """Yield ``(d, -d)`` for every ``i`` in ``range(start, stop, step)``.

        Args:
            start: First value to produce.
            step: Increment between consecutive values.
            stop: Exclusive upper bound of the values (default 44).

        Yields:
            Tuple of two float32 NumPy arrays of shape ``(1, 1, 1)``.
        """
        for i in range(start, stop, step):
            # small data-prep. simulation delay
            sleep(0.5)
            # Yield NumPy directly: from_generator converts the yielded values
            # to the declared output_signature itself, so wrapping them in a
            # tensor first only adds a redundant conversion.
            d = np.array([i], dtype=np.float32).reshape(1, 1, 1)
            yield d, -d

    def __new__(cls, start=0, step=1, stop=44):
        """Build the generator dataset.

        Args:
            start: First value to produce (default 0).
            step: Increment between consecutive values (default 1). Use a
                distinct ``start`` per shard together with ``step`` equal to
                the number of shards to split the range across shards.
            stop: Exclusive upper bound of the values (default 44).

        Returns:
            A ``tf.data.Dataset`` whose elements are pairs of float32 tensors
            of shape ``(1, 1, 1)``.
        """
        sample_spec = tf.TensorSpec(shape=(1, 1, 1), dtype=tf.float32)
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=(sample_spec, sample_spec),
            # `args` are evaluated and handed to the generator as NumPy values.
            args=(start, step, stop),
        )


If we want to use the same amount of data (i.e. 44 numbers) and later parallelize loading across 4 workers, we have to split the generator into 4 sub-generators, such that each one processes only 1/4 of the data. Otherwise we would simply duplicate the data four times.

In [3]:
# One sequential generator over the whole range (no sharding yet).
ds = MyDataset(
    start=0,
    step=1,
)

If we want to get batches, we just chain `.batch(batch_size)` to our `ds`.

In [4]:
start_time = time()
for x, y in ds.batch(8):
    sleep(0.2)  # simulated processing time of one batch
    # Single print call; the emitted text is the same as five separate prints.
    print(
        f'{x.shape} {y.shape}',
        '-----------------------------------------------',
        f'x: {x[:, 0, 0, 0]}',
        f'y: {y[:, 0, 0, 0]}',
        '-----------------------------------------------',
        sep='\n',
    )

# Total wall-clock time for loading and "processing" all batches.
end_time = time()
print(end_time - start_time)

(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [0. 1. 2. 3. 4. 5. 6. 7.]
y: [-0. -1. -2. -3. -4. -5. -6. -7.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 8.  9. 10. 11. 12. 13. 14. 15.]
y: [ -8.  -9. -10. -11. -12. -13. -14. -15.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [16. 17. 18. 19. 20. 21. 22. 23.]
y: [-16. -17. -18. -19. -20. -21. -22. -23.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [24. 25. 26. 27. 28. 29. 30. 31.]
y: [-24. -25. -26. -27. -28. -29. -30. -31.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [32. 33. 34. 35. 36. 37. 38. 39.]
y: [-32. -33. -34. -35. -36. -37. -38. -39.]
--------------------------------------------

## Parallel Loading

By default everything is sequential, i.e. while we wait for a batch to be processed (0.2 seconds) nothing else happens. We can use `prefetch` to load the next data while the current batch is being processed.

`tf.data.AUTOTUNE` lets tf decide how large the `prefetch` buffer should be.

In [5]:
start_time = time()
# Prefetch lets the next batch load while the current one is being processed.
prefetched = ds.batch(8).prefetch(buffer_size=tf.data.AUTOTUNE)
for x, y in prefetched:
    sleep(0.2)  # simulated processing time of one batch
    print(
        f'{x.shape} {y.shape}',
        '-----------------------------------------------',
        f'x: {x[:, 0, 0, 0]}',
        f'y: {y[:, 0, 0, 0]}',
        '-----------------------------------------------',
        sep='\n',
    )

# Total wall-clock time for loading and "processing" all batches.
end_time = time()
print(end_time - start_time)

(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [0. 1. 2. 3. 4. 5. 6. 7.]
y: [-0. -1. -2. -3. -4. -5. -6. -7.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 8.  9. 10. 11. 12. 13. 14. 15.]
y: [ -8.  -9. -10. -11. -12. -13. -14. -15.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [16. 17. 18. 19. 20. 21. 22. 23.]
y: [-16. -17. -18. -19. -20. -21. -22. -23.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [24. 25. 26. 27. 28. 29. 30. 31.]
y: [-24. -25. -26. -27. -28. -29. -30. -31.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [32. 33. 34. 35. 36. 37. 38. 39.]
y: [-32. -33. -34. -35. -36. -37. -38. -39.]
--------------------------------------------

This helped a little bit, but the real speed-up comes from parallel dataloading. 
For this we use `interleave` and create four datasets starting at 0, 1, 2 and 3 with a step-size of 4 and set `num_parallel_calls=4`.

This creates four datasets:
* d_0 = [0, 4, 8, ...]
* d_1 = [1, 5, 9, ...]
* d_2 = [2, 6, 10, ...]
* d_3 = [3, 7, 11, ...]

And interleaves them into [0, 1, 2, 3, 4, 5, ...].

In [6]:
# Four shards (start = 0, 1, 2, 3; step = 4) read in parallel and interleaved
# one element at a time. `cycle_length` is pinned to the number of shards:
# when it is left unset the tf.data runtime derives it from the available CPU,
# so on a machine with fewer cores not all shards would be open at once and
# the elements would no longer come out as 0, 1, 2, 3, 4, 5, ...
ids = Dataset.range(4).interleave(
    lambda x: MyDataset(start=x, step=4),
    cycle_length=4,
    num_parallel_calls=4,
)

In [7]:
start_time = time()
for x, y in ids.batch(8).prefetch(tf.data.AUTOTUNE):
    sleep(0.2)  # simulated processing time of one batch
    print(
        f'{x.shape} {y.shape}',
        '-----------------------------------------------',
        f'x: {x[:, 0, 0, 0]}',
        f'y: {y[:, 0, 0, 0]}',
        '-----------------------------------------------',
        sep='\n',
    )

# Total wall-clock time for loading and "processing" all batches.
end_time = time()
print(end_time - start_time)

(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [0. 1. 2. 3. 4. 5. 6. 7.]
y: [-0. -1. -2. -3. -4. -5. -6. -7.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 8.  9. 10. 11. 12. 13. 14. 15.]
y: [ -8.  -9. -10. -11. -12. -13. -14. -15.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [16. 17. 18. 19. 20. 21. 22. 23.]
y: [-16. -17. -18. -19. -20. -21. -22. -23.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [24. 25. 26. 27. 28. 29. 30. 31.]
y: [-24. -25. -26. -27. -28. -29. -30. -31.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [32. 33. 34. 35. 36. 37. 38. 39.]
y: [-32. -33. -34. -35. -36. -37. -38. -39.]
--------------------------------------------

## Shuffle

True shuffling is not supported since the dataset is an iterator. But TensorFlow can buffer a given number of samples and then shuffle them. N.B. one could set the buffer size equal to the dataset size, but this requires more memory.

In [8]:
# Shuffle each sub-dataset with an 8-element buffer, then interleave the four
# shards. `cycle_length` is pinned to the number of shards: when it is left
# unset the tf.data runtime derives it from the available CPU, so on a machine
# with fewer cores not all shards would be drawn from concurrently.
ids = Dataset.range(4).interleave(
    lambda x: MyDataset(start=x, step=4).shuffle(8),
    cycle_length=4,
    num_parallel_calls=4,
)

In [9]:
start_time = time()
for x, y in ids.batch(8).prefetch(tf.data.AUTOTUNE):
    sleep(0.2)  # simulated processing time of one batch
    print(
        f'{x.shape} {y.shape}',
        '-----------------------------------------------',
        f'x: {x[:, 0, 0, 0]}',
        f'y: {y[:, 0, 0, 0]}',
        '-----------------------------------------------',
        sep='\n',
    )

# Total wall-clock time for loading and "processing" all batches.
end_time = time()
print(end_time - start_time)

(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 0. 25. 10. 31. 16.  5. 30. 19.]
y: [ -0. -25. -10. -31. -16.  -5. -30. -19.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [12. 37. 34. 11.  8. 13. 38.  7.]
y: [-12. -37. -34. -11.  -8. -13. -38.  -7.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [40. 33. 26.  3. 28.  9. 42. 39.]
y: [-40. -33. -26.  -3. -28.  -9. -42. -39.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [20. 29. 18. 43. 32. 21. 14. 15.]
y: [-20. -29. -18. -43. -32. -21. -14. -15.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 4.  1.  6. 27. 36. 17.  2. 35.]
y: [ -4.  -1.  -6. -27. -36. -17.  -2. -35.]
----------------------------

## Augmentation

If we want to use data-augmentation (or do anything more with the data) we can just map another function.

In [10]:
def tf_gaussian_noise(x, y):
    """Add standard-normal noise to ``x``; pass ``y`` through unchanged.

    Each element of the result is drawn from a normal distribution centred on
    the corresponding element of ``x`` with standard deviation 1. The noise is
    generated with native TensorFlow ops rather than by wrapping
    ``np.random.normal`` in ``tf.py_function``, so the mapped function contains
    no Python-side op. Note that the noise therefore comes from TensorFlow's
    random generator, not NumPy's global one.

    Args:
        x: Input tensor to perturb.
        y: Target tensor, returned as-is.

    Returns:
        Tuple ``(x_noisy, y)`` where ``x_noisy`` is a float32 tensor with the
        same shape as ``x``.
    """
    # The result is always float32, whatever the dtype of `x`.
    x_float = tf.cast(x, tf.float32)
    noise = tf.random.normal(tf.shape(x_float), mean=0.0, stddev=1.0, dtype=tf.float32)
    x_noisy = x_float + noise
    # Keep the static shape of the input on the output.
    x_noisy.set_shape(x.shape)
    return x_noisy, y

In [11]:
start_time = time()
# Augment every sample before batching and prefetching.
augmented = ids.map(tf_gaussian_noise).batch(8).prefetch(tf.data.AUTOTUNE)
for x, y in augmented:
    sleep(0.2)  # simulated processing time of one batch
    print(
        f'{x.shape} {y.shape}',
        '-----------------------------------------------',
        f'x: {x[:, 0, 0, 0]}',
        f'y: {y[:, 0, 0, 0]}',
        '-----------------------------------------------',
        sep='\n',
    )

# Total wall-clock time for loading and "processing" all batches.
end_time = time()
print(end_time - start_time)

(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [ 7.45624   12.732648   2.7712266 22.648294  20.965918   5.335439
 14.361323  30.90427  ]
y: [ -8. -13.  -2. -23. -20.  -5. -14. -31.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [22.944181   7.991416  11.431693  20.480597  27.774282  27.849373
  6.2662497  2.111086 ]
y: [-24.  -9. -10. -19. -28. -29.  -6.  -3.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [14.42896    1.5938652 30.513695  15.94892    3.620058  37.46703
 22.975163  35.45491  ]
y: [-16.  -1. -30. -15.  -4. -37. -22. -35.]
-----------------------------------------------
(8, 1, 1, 1) (8, 1, 1, 1)
-----------------------------------------------
x: [35.08035  21.495703 33.10992  38.765636 11.589279 41.005302 23.525087
 25.637701]
y: [-36. -21. -34. -39. -12. -41. -26. -27.]
----------------