# Getting started with TensorFlow's `Dataset` API (continuation)

In this notebook we continue looking at how to divide a dataset among the ranks in distributed training, this time combining sharding and interleaving.

The following steps were done in one of the previous notebooks. If necessary, they can be run again in a new cell.
```bash
wget https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_setosa.csv
grep setosa iris.csv >> iris_setosa.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_versic.csv
grep versicolor iris.csv >> iris_versic.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_virgin.csv
grep virginica iris.csv >> iris_virgin.csv
```

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import tensorflow as tf
import horovod.tensorflow.keras as hvd

In [5]:
%%px --target 1
tf.version.VERSION

Out[1:3]: '2.3.0'

In [6]:
%%px
def parse_columns(*row, classes):
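    """Stack the four features and one-hot encode the species label:
    setosa     -> [1, 0, 0]
    virginica  -> [0, 1, 0]
    versicolor -> [0, 0, 1]
    """
    features = tf.stack(row[:4])
    # tf.where returns the matching index with shape (1, 1), so the
    # one-hot label ends up with shape (1, 1, 3).
    label_int = tf.where(tf.equal(classes, row[4]))
    label = tf.one_hot(label_int, 3)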
    return features, label


def get_csv_dataset(filename):
    return tf.data.experimental.CsvDataset(filename, header=True,
                                           record_defaults=[tf.float32,
                                                            tf.float32,
                                                            tf.float32,
                                                            tf.float32,
                                                            tf.string])

In [15]:
%%px
hvd.init()

In [20]:
%%px
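# List the two files in a fixed order so that every rank builds the same pipeline.
dataset = tf.data.Dataset.list_files(['iris_setosa.csv',
                                      'iris_versic.csv'],
                                     shuffle=False)
# Alternate between the two files, taking one row from each in turn.
dataset = dataset.interleave(get_csv_dataset,
                             cycle_length=2,
                             block_length=1,
                             num_parallel_calls=2)
# Keep every hvd.size()-th row, starting at this rank's index.
dataset = dataset.shard(hvd.size(), hvd.rank())
classes = ['setosa', 'virginica', 'versicolor']
dataset = dataset.map(lambda *row: parse_columns(*row, classes=classes))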

for x, y in dataset.take(5):
    print(f'features: {x}    label: {y}')

[stdout:0] 
features: [5.1 3.5 1.4 0.2]    label: [[[1. 0. 0.]]]
features: [4.9 3.  1.4 0.2]    label: [[[1. 0. 0.]]]
features: [4.7 3.2 1.3 0.2]    label: [[[1. 0. 0.]]]
features: [4.6 3.1 1.5 0.2]    label: [[[1. 0. 0.]]]
features: [5.  3.6 1.4 0.2]    label: [[[1. 0. 0.]]]
[stdout:1] 
features: [7.  3.2 4.7 1.4]    label: [[[0. 0. 1.]]]
features: [6.4 3.2 4.5 1.5]    label: [[[0. 0. 1.]]]
features: [6.9 3.1 4.9 1.5]    label: [[[0. 0. 1.]]]
features: [5.5 2.3 4.  1.3]    label: [[[0. 0. 1.]]]
features: [6.5 2.8 4.6 1.5]    label: [[[0. 0. 1.]]]
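
In the output above, rank 0 sees only setosa rows and rank 1 only versicolor rows. The sketch below is a plain-Python model of the element order, not TensorFlow's implementation, and it may help explain why. The helpers `interleave` and `shard` and the placeholder row names are hypothetical and only mimic the behaviour for `cycle_length` equal to the number of files and `block_length=1`.

```python
from itertools import chain, zip_longest

_MISSING = object()  # sentinel marking exhausted sources


def interleave(sources):
    """Take one element from each source in turn (round-robin)."""
    mixed = chain.from_iterable(zip_longest(*sources, fillvalue=_MISSING))
    return [item for item in mixed if item is not _MISSING]


def shard(elements, num_shards, index):
    """Keep every num_shards-th element, starting at position index."""
    return elements[index::num_shards]


# Placeholder rows standing in for the contents of the two CSV files.
setosa = [f"setosa_{i}" for i in range(4)]
versic = [f"versic_{i}" for i in range(4)]

mixed = interleave([setosa, versic])  # setosa_0, versic_0, setosa_1, ...
rank0 = shard(mixed, 2, 0)            # even positions
rank1 = shard(mixed, 2, 1)            # odd positions

print(rank0)
print(rank1)
```

In this model, the even positions of the interleaved stream all come from the first file and the odd positions from the second, so two ranks each end up with a single species. The split follows from the number of files matching the number of ranks: with three files and two ranks, the same model gives each rank a mix of files.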


In [21]:
%ipcluster stop