# Getting started with TensorFlow's `Dataset` API (continuation)

In this notebook we will learn how to divide the dataset over the ranks in distributed training.

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import numpy as np
import tensorflow as tf
import horovod.tensorflow.keras as hvd

In [4]:
%%px
hvd.init()

In [5]:
%%px
def dataset_generator():
    """A data-producing logic"""
    for i in range(8):
        yield (i, i)

In [6]:
%%px
for x, y in dataset_generator():
    print((f'    x: {x}    y: {y}'))

[stdout:0] 
    x: 0    y: 0
    x: 1    y: 1
    x: 2    y: 2
    x: 3    y: 3
    x: 4    y: 4
    x: 5    y: 5
    x: 6    y: 6
    x: 7    y: 7
[stdout:1] 
    x: 0    y: 0
    x: 1    y: 1
    x: 2    y: 2
    x: 3    y: 3
    x: 4    y: 4
    x: 5    y: 5
    x: 6    y: 6
    x: 7    y: 7


In [7]:
%%px
# batch after shard or shard after bash?

dataset = tf.data.Dataset.from_generator(dataset_generator, output_types=(tf.int32, tf.int32))
dataset = dataset.batch(2)
dataset = dataset.shard(hvd.size(), hvd.rank())
dataset = dataset.repeat(2)

for x, y in dataset:
    print(f'    x: {x}    y: {y}')

[stdout:0] 
    x: [0 1]    y: [0 1]
    x: [4 5]    y: [4 5]
    x: [0 1]    y: [0 1]
    x: [4 5]    y: [4 5]
[stdout:1] 
    x: [2 3]    y: [2 3]
    x: [6 7]    y: [6 7]
    x: [2 3]    y: [2 3]
    x: [6 7]    y: [6 7]


In [8]:
%%px
# The shuffle is different for every epoch

dataset = tf.data.Dataset.from_generator(dataset_generator, output_types=(tf.int32, tf.int32))
dataset = dataset.shuffle(8)
dataset = dataset.batch(1)
dataset = dataset.shard(hvd.size(), hvd.rank())
dataset = dataset.repeat(2)

for x, y in dataset:
    print(f'    x: {x}    y: {y}')

[stdout:0] 
    x: [3]    y: [3]
    x: [5]    y: [5]
    x: [4]    y: [4]
    x: [7]    y: [7]
    x: [6]    y: [6]
    x: [1]    y: [1]
    x: [5]    y: [5]
    x: [7]    y: [7]
[stdout:1] 
    x: [3]    y: [3]
    x: [4]    y: [4]
    x: [5]    y: [5]
    x: [6]    y: [6]
    x: [7]    y: [7]
    x: [3]    y: [3]
    x: [1]    y: [1]
    x: [0]    y: [0]


In [9]:
%ipcluster stop