# Getting started with TensorFlow's `Dataset` API (continuation)

In this notebook we continue exploring how to divide the dataset across ranks in distributed training, this time by combining sharding and interleaving.

The following steps were done in one of the previous notebooks. If necessary, they can be run again in a new cell (for example, with the `%%bash` cell magic).
```bash
wget https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_setosa.csv
grep setosa iris.csv >> iris_setosa.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_versic.csv
grep versicolor iris.csv >> iris_versic.csv
echo "sepal_length,sepal_width,petal_length,petal_width,species" > iris_virgin.csv
grep virginica iris.csv >> iris_virgin.csv
```
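If `grep` is not available, the same split can be sketched in plain Python. This assumes `iris.csv` has already been downloaded and that the species name is the last column; `split_by_species` and `OUTPUTS` are hypothetical names. Unlike `grep`, which matches the species name anywhere on the line, this compares the last field exactly.

```python
import csv

HEADER = ["sepal_length", "sepal_width", "petal_length", "petal_width", "species"]

# Hypothetical mapping from species name to output file, mirroring the shell commands.
OUTPUTS = {
    "setosa": "iris_setosa.csv",
    "versicolor": "iris_versic.csv",
    "virginica": "iris_virgin.csv",
}


def split_by_species(src="iris.csv", outputs=OUTPUTS):
    """Write one CSV per species, each starting with the same header line."""
    with open(src, newline="") as f:
        rows = list(csv.reader(f))
    counts = {}
    for species, path in outputs.items():
        # The source header's last field is "species", so it never matches here.
        selected = [row for row in rows if row and row[-1].strip() == species]
        with open(path, "w", newline="") as out:
            writer = csv.writer(out, lineterminator="\n")  # Unix line endings, like echo/grep
            writer.writerow(HEADER)
            writer.writerows(selected)
        counts[species] = len(selected)
    return counts
```

Calling `split_by_species()` in the directory that contains `iris.csv` writes the three per-species files and returns how many rows went into each.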

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import tensorflow as tf
import horovod.tensorflow.keras as hvd

In [4]:
%%px --target 1
tf.version.VERSION

[0;31mOut[1:2]: [0m'2.3.0'

In [5]:
%%px
def parse_columns(*row, classes):
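    """Convert the string class to a one-hot encoded label:
    setosa     -> [1, 0, 0]
    virginica  -> [0, 1, 0]
    versicolor -> [0, 0, 1]
    
    and stack all features into a single tensor.
    """
    features = tf.stack(row[:4])
    # tf.where returns the index of the matching class with shape [1, 1],
    # so `label` has shape [1, 1, 3] (hence the [[[...]]] in the printed output).
    label_int = tf.where(tf.equal(classes, row[4]))
    label = tf.one_hot(label_int, 3)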
    return features, label


def get_csv_dataset(filename):
    return tf.data.experimental.CsvDataset(filename, header=True,
                                           record_defaults=[tf.float32,
                                                            tf.float32,
                                                            tf.float32,
                                                            tf.float32,
                                                            tf.string])

In [6]:
%%px
hvd.init()

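By combining sharding and interleaving, it is possible to have every file read by only one worker, instead of having multiple workers access the same file. We do that by sharding right after `list_files`, so that the elements split across ranks are file names rather than CSV rows: `shard(hvd.size(), hvd.rank())` keeps every `hvd.size()`-th file, starting at position `hvd.rank()`. Passing `shuffle=False` to `list_files` keeps the file order deterministic, so all ranks shard the same list. Here we use `interleave` only to pass each file name to the `get_csv_dataset` function. In other setups, `interleave` can be used to mix several datasets within the same worker.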
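The shard-then-interleave steps described above can be simulated in plain Python to make the ordering concrete. `shard` keeps the elements whose position `i` satisfies `i % num_shards == index`; `interleave` keeps up to `cycle_length` inputs open and takes `block_length` elements from each in round-robin order. Both are hypothetical helpers on lists that model the deterministic output order of the `tf.data` methods and ignore parallelism (`num_parallel_calls`); `fake_rows` stands in for `get_csv_dataset`.

```python
_END = object()  # sentinel marking an exhausted iterator


def shard(elements, num_shards, index):
    """Mimic Dataset.shard: keep elements whose position i has i % num_shards == index."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


def interleave(inputs, map_fn, cycle_length, block_length):
    """Mimic the deterministic order of Dataset.interleave on plain iterables.

    Up to `cycle_length` inputs are open at once; `block_length` elements are
    taken from each open input in turn. When an input runs out, its slot is
    refilled with the next unopened input the next time the slot is visited.
    """
    inputs = iter(inputs)
    slots = [None] * cycle_length
    cycle_index, block_index = 0, 0
    num_open, end_of_input = 0, False
    out = []
    while not end_of_input or num_open > 0:
        current = slots[cycle_index]
        if current is not None and block_index < block_length:
            value = next(current, _END)
            if value is not _END:
                out.append(value)
                block_index += 1
                continue
            slots[cycle_index] = None  # this input is exhausted
            num_open -= 1
        elif current is None and not end_of_input:
            value = next(inputs, _END)
            if value is _END:
                end_of_input = True
            else:
                slots[cycle_index] = iter(map_fn(value))  # open a new input here
                num_open += 1
            continue
        # Move on to the next slot in the cycle and start a new block.
        cycle_index = (cycle_index + 1) % cycle_length
        block_index = 0
    return out


files = ["iris_setosa.csv", "iris_versic.csv"]  # deterministic list_files order
# Hypothetical stand-in for get_csv_dataset: file name -> its rows.
fake_rows = {"iris_setosa.csv": ["setosa-1", "setosa-2"],
             "iris_versic.csv": ["versic-1", "versic-2"]}

for rank in range(2):  # two engines, as in %ipcluster start -n 2
    my_files = shard(files, num_shards=2, index=rank)
    print(rank, my_files, interleave(my_files, fake_rows.__getitem__, 1, 1))
```

With `cycle_length=1` and `block_length=1`, as in the cell below, each worker simply reads its files one after another; a larger `cycle_length` on a worker that owns several files would alternate rows between them.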

In [7]:
%%px
dataset = tf.data.Dataset.list_files(['iris_setosa.csv',
                                      'iris_versic.csv'],
                                     shuffle=False)
dataset = dataset.shard(hvd.size(), hvd.rank())
dataset = dataset.interleave(get_csv_dataset,
                             cycle_length=1,
                             block_length=1,
                             num_parallel_calls=1)
# dataset = dataset.batch(1)
dataset = dataset.map(lambda *row: parse_columns(*row, classes=['setosa', 'virginica', 'versicolor']))

for i, (x, y) in enumerate(dataset):
    print(f'{i:5} features: {x}    label: {y}')

[stdout:0] 
    0 features: [5.1 3.5 1.4 0.2]    label: [[[1. 0. 0.]]]
    1 features: [4.9 3.  1.4 0.2]    label: [[[1. 0. 0.]]]
    2 features: [4.7 3.2 1.3 0.2]    label: [[[1. 0. 0.]]]
    3 features: [4.6 3.1 1.5 0.2]    label: [[[1. 0. 0.]]]
    4 features: [5.  3.6 1.4 0.2]    label: [[[1. 0. 0.]]]
    5 features: [5.4 3.9 1.7 0.4]    label: [[[1. 0. 0.]]]
    6 features: [4.6 3.4 1.4 0.3]    label: [[[1. 0. 0.]]]
    7 features: [5.  3.4 1.5 0.2]    label: [[[1. 0. 0.]]]
    8 features: [4.4 2.9 1.4 0.2]    label: [[[1. 0. 0.]]]
    9 features: [4.9 3.1 1.5 0.1]    label: [[[1. 0. 0.]]]
   10 features: [5.4 3.7 1.5 0.2]    label: [[[1. 0. 0.]]]
   11 features: [4.8 3.4 1.6 0.2]    label: [[[1. 0. 0.]]]
   12 features: [4.8 3.  1.4 0.1]    label: [[[1. 0. 0.]]]
   13 features: [4.3 3.  1.1 0.1]    label: [[[1. 0. 0.]]]
   14 features: [5.8 4.  1.2 0.2]    label: [[[1. 0. 0.]]]
   15 features: [5.7 4.4 1.5 0.4]    label: [[[1. 0. 0.]]]
   16 features: [5.4 3.9 1.3 0.4]    label: 

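Make sure the data is distributed evenly over all workers. Unbalanced workers can hurt performance and convergence, or even make the program misbehave: for example, a worker that runs out of data early stops taking part in the collective operations (such as Horovod's allreduce) that the other workers are still waiting on.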
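One way to spot an imbalance before training is to count how many records each rank receives. The sketch below assumes, as in this notebook, that whole files are sharded by position; `per_worker_counts`, `common_steps` and the file sizes in the example are hypothetical. With Keras, a value like `common_steps` could be passed as `steps_per_epoch` to `model.fit`, or used to cap each rank's batched dataset with `Dataset.take`, so that every rank runs the same number of steps.

```python
def per_worker_counts(file_sizes, num_workers):
    """Records per worker when whole files are sharded by position (i % num_workers)."""
    counts = [0] * num_workers
    for i, size in enumerate(file_sizes):
        counts[i % num_workers] += size
    return counts


def common_steps(counts, batch_size):
    """Largest number of full batches that every worker can produce."""
    return min(counts) // batch_size


# Hypothetical example: three files with 120, 80 and 100 records, two workers.
counts = per_worker_counts([120, 80, 100], num_workers=2)
print(counts, common_steps(counts, batch_size=16))  # prints [220, 80] 5
```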

<mark>Exercise</mark>: Add the third CSV file `iris_virgin.csv` and see what happens in the previous cell.

In [8]:
%ipcluster stop