In [3]:
import tensorflow as tf
import numpy as np

# [How to use Dataset and Iterators in Tensorflow with code samples](https://medium.com/ymedialabs-innovation/how-to-use-dataset-and-iterators-in-tensorflow-with-code-samples-3bb98b6b74ab)
Ever since I started using TensorFlow, I have fed data to my graph during training, testing and inference using the `feed_dict` mechanism of `Session`. The TensorFlow developers strongly advise against this practice for training or for repeatedly testing on the same dataset. The only scenario in which `feed_dict` should still be used is inference on data during deployment. Its replacement is the `Dataset` and `Iterator` API. A Dataset can be created from NumPy arrays, TFRecords or text files.

In this post, we will explore Datasets and Iterators. We will start by creating Datasets from some source data and then apply various types of transformations to them. We will then demonstrate training with each type of iterator, using the MNIST handwritten digits data and the LeNet-5 model.

__Note__: The TensorFlow Dataset class is easily confused with the everyday word for datasets such as X_train, y_train etc. Hence, for the rest of this article, ‘Dataset’ (capital D) refers to the TensorFlow Dataset class and ‘dataset’ refers to data such as X_train, y_train etc.

## Datasets Creation
Datasets can be generated from multiple types of data sources such as NumPy arrays, TFRecords, text files, CSV files etc. The most common practice is to generate them from NumPy arrays (or Tensors). Let's go through each of the functions TensorFlow provides to generate them.

### from_tensor_slices
This method accepts a single NumPy array (or Tensor) or several of them. If you are feeding multiple objects, pass them as a tuple and make sure that all of them have the same size in the zeroth dimension.

In [5]:
# Assume batch size is 1
dataset1 = tf.data.Dataset.from_tensor_slices(tf.range(10, 15))
# Emits data of 10, 11, 12, 13, 14, (One element at a time)

dataset2 = tf.data.Dataset.from_tensor_slices((tf.range(30, 45, 3), np.arange(60, 70, 2)))
# Emits data of (30, 60), (33, 62), (36, 64), (39, 66), (42, 68)
# Emits one tuple at a time

In [8]:
#dataset3 = tf.data.Dataset.from_tensor_slices((tf.range(10), np.arange(5)))
# Dataset not possible as the zeroth dimension differs: 10 and 5

### from_tensors
Just like from_tensor_slices, this method accepts a single NumPy array (or Tensor) or several of them. However, it does not slice the input along the zeroth dimension: all the data is emitted at once, as a single element. As a result, if you are passing multiple objects, they may have different sizes in the zeroth dimension. This method is useful when the dataset is very small or your learning model needs all the data at once.

In [9]:
dataset1 = tf.data.Dataset.from_tensors(tf.range(10, 15))
# Emits data of [10, 11, 12, 13, 14]
# Holds entire list as one element

dataset2 = tf.data.Dataset.from_tensors((tf.range(30, 45, 3), np.arange(60, 70, 2)))
# Emits data of ([30, 33, 36, 39, 42], [60, 62, 64, 66, 68])
# Holds entire tuple as one element

dataset3 = tf.data.Dataset.from_tensors((tf.range(10), np.arange(5)))
# Possible with from_tensors, regardless of zeroth dimension mismatch of constituent elements.
# Emits data of ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4])
# Holds entire tuple as one element
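
To make the contrast between the two constructors concrete, here is a minimal pure-Python model of the element structure each one produces. It is only an analogy and does not call TensorFlow: `slices_model` and `whole_model` are hypothetical helper names, and plain lists stand in for tensors.

```python
def slices_model(*components):
    """Model of from_tensor_slices: one element per index along dimension 0."""
    lengths = {len(c) for c in components}
    if len(lengths) != 1:
        raise ValueError("all components must have the same size in dimension 0")
    if len(components) == 1:
        return list(components[0])
    return list(zip(*components))


def whole_model(*components):
    """Model of from_tensors: the entire input becomes a single element."""
    if len(components) == 1:
        return [list(components[0])]
    return [tuple(list(c) for c in components)]


a = list(range(30, 45, 3))  # [30, 33, 36, 39, 42]
b = list(range(60, 70, 2))  # [60, 62, 64, 66, 68]

sliced = slices_model(a, b)  # five elements, one pair each
whole = whole_model(a, b)    # one element holding both full lists
# Different lengths are fine here because nothing is sliced.
mismatched = whole_model(list(range(10)), list(range(5)))

print(len(sliced), len(whole), len(mismatched))
```

In this model, calling `slices_model` with lists of length 10 and 5 raises an error, which mirrors why the commented-out `dataset3` in the from_tensor_slices cell cannot be built.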

### from_generator
In this method, a generator function is passed as input. This method is useful when you wish to generate the data at runtime, so that no raw data exists beforehand, or when your training data is so large that it cannot be stored on disk. I would strongly encourage people to __not use__ this method for the purpose of generating data augmentations.

In [10]:
# Assume batch size is 1
def generator(sequence_type):
    if sequence_type == 1:
        for i in range(5):
            yield 10 + i
    elif sequence_type == 2:
        for i in range(5):
            yield (30 + 3 * i, 60 + 2 * i)
    elif sequence_type == 3:
        for i in range(1, 4):
            yield (i, ['Hi'] * i)

dataset1 = tf.data.Dataset.from_generator(generator, (tf.int32), args = ([1]))
# Emits data of 10, 11, 12, 13, 14, (One element at a time)

dataset2 = tf.data.Dataset.from_generator(generator, (tf.int32, tf.int32), args = ([2]))
# Emits data of (30, 60), (33, 62), (36, 64), (39, 66), (42, 68)
# Emits one tuple at a time

dataset3 = tf.data.Dataset.from_generator(generator, (tf.int32, tf.string), args = ([3]))
# Emits data of (1, ['Hi']), (2, ['Hi', 'Hi']), (3, ['Hi', 'Hi', 'Hi'])
# Emits one tuple at a time

## Datasets Transformations
Once you have created the Dataset covering all the data 
(or scenarios, in cases such as runtime data generation), 
it is time to apply various types of transformations. 
Let us go through some of the commonly used transformations.

### Batch
Batch sequentially divides your Dataset into groups of the specified batch size.

<img src="../../images/prasadpai/batch.jpeg" alt="batch" width="300"/>
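
As a rough illustration of what batching does to a stream of elements, the sketch below models it in plain Python, without TensorFlow. The helper name `batch_model` is hypothetical. The `drop_remainder` flag mirrors an argument that, as far as I know, `tf.data.Dataset.batch` offers in newer TensorFlow releases; treat that as an assumption if you are on an older version.

```python
def batch_model(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into lists of at most batch_size items."""
    current = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    # Whatever is left over forms a final, smaller batch unless it is dropped.
    if current and not drop_remainder:
        yield current


all_batches = list(batch_model(range(10), 4))
full_batches_only = list(batch_model(range(10), 4, drop_remainder=True))

print(all_batches)        # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(full_batches_only)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

The smaller final batch matters whenever later code assumes a fixed batch dimension.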

### Repeat
Whatever Dataset you have generated, use this transformation to repeat the existing data in your Dataset a given number of times.

<img src="../../images/prasadpai/repeat.jpeg" alt="repeat" width="300"/>
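
The following pure-Python sketch models repetition as making several passes over the same data. `repeat_model` is a hypothetical helper, not a TensorFlow function. It also models the case where no count is given, in which `tf.data.Dataset.repeat` repeats indefinitely.

```python
import itertools


def repeat_model(make_iterable, count=None):
    """Yield the elements of `count` fresh passes over the data (endless if None)."""
    passes = itertools.count() if count is None else range(count)
    for _ in passes:
        for element in make_iterable():
            yield element


two_epochs = list(repeat_model(lambda: range(3), count=2))
# An endless stream has to be cut off by the consumer.
first_seven = list(itertools.islice(repeat_model(lambda: range(3)), 7))

print(two_epochs)   # [0, 1, 2, 0, 1, 2]
print(first_seven)  # [0, 1, 2, 0, 1, 2, 0]
```

With an endless repeat, the training loop never sees an end-of-data signal, so it has to stop after a fixed number of steps instead.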

### Shuffle
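The Shuffle transformation randomly shuffles the data in your Dataset. It takes a buffer size: elements are drawn at random from a buffer of that many elements, so a buffer at least as large as the Dataset is needed for a uniform shuffle.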

<img src="../../images/prasadpai/shuffle.jpeg" alt="shuffle" width="300"/>
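
The argument passed to `shuffle` is a buffer size, which limits how far an element can move from its original position. The sketch below is a simplified pure-Python model of buffer-based shuffling, not TensorFlow's actual implementation. `buffered_shuffle` and its `seed` parameter are hypothetical names chosen for illustration.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=None):
    """Emit elements in random order, choosing only among buffer_size candidates."""
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        if len(buffer) < buffer_size:
            buffer.append(element)  # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]         # emit a randomly chosen buffered element
        buffer[index] = element     # refill the freed slot with the new element
    rng.shuffle(buffer)             # input exhausted: drain what is left
    yield from buffer


shuffled = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
unshuffled = list(buffered_shuffle(range(5), buffer_size=1, seed=0))

print(shuffled)
print(unshuffled)  # a buffer of one element cannot reorder anything
```

In this model, the element emitted at position `k` always comes from the first `k + buffer_size` inputs. A small buffer therefore only shuffles locally, and a buffer at least as large as the data gives a full shuffle.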

### Map
In the Map transformation, you apply an operation to each individual element of your Dataset. Use this transformation to apply various types of data augmentation.

<img src="../../images/prasadpai/map.jpeg" alt="map" width="300"/>

### Filter
If you wish to filter out some elements from the Dataset during the course of training, use the filter transformation.

<img src="../../images/prasadpai/filter.jpeg" alt="filter" width="300"/>

The code example of various transformations being applied on a Dataset is shown next.

In [15]:
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
# Create a dataset with data of [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

dataset = dataset.repeat(2)
# Duplicate the dataset
# Data will be [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

dataset = dataset.shuffle(5)
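# Shuffle the dataset using a buffer of 5 elements
# Assumed shuffling (illustrative only; a 5-element buffer would keep the real order more local):
# [3, 0, 7, 9, 4, 2, 5, 0, 1, 7, 5, 9, 4, 6, 2, 8, 6, 8, 1, 3]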

def map_fn(x):
    return x * 3

dataset = dataset.map(map_fn)
# Same as dataset = dataset.map(lambda x: x * 3)
# Multiply each element with 3 using map transformation
# Dataset: [9, 0, 21, 27, 12, 6, 15, 0, 3, 21, 15, 27, 12, 18, 6, 24, 18, 24, 3, 9]

def filter_fn(x):
    return tf.reshape(tf.not_equal(x % 5, 1), [])

dataset = dataset.filter(filter_fn)
# Same as dataset = dataset.filter(lambda x: tf.reshape(tf.not_equal(x % 5, 1), []))
# Drop every element whose remainder modulo 5 is 1
# Dataset: [9, 0, 27, 12, 15, 0, 3, 15, 27, 12, 18, 24, 18, 24, 3, 9]

dataset = dataset.batch(4)
# Batch at every 4 elements
# Dataset: [9, 0, 27, 12], [15, 0, 3, 15], [27, 12, 18, 24], [18, 24, 3, 9]
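
The values in the comments above can be checked without TensorFlow. The sketch below starts from the assumed shuffle order given in the comments, since real shuffling is random, and redoes the map, filter and batch steps with plain lists.

```python
# Shuffle order assumed in the comments of the cell above.
assumed_shuffle = [3, 0, 7, 9, 4, 2, 5, 0, 1, 7, 5, 9, 4, 6, 2, 8, 6, 8, 1, 3]

# map: multiply every element by 3
mapped = [x * 3 for x in assumed_shuffle]

# filter: keep only elements whose remainder modulo 5 is not 1
filtered = [x for x in mapped if x % 5 != 1]

# batch: consecutive groups of 4
batched = [filtered[i:i + 4] for i in range(0, len(filtered), 4)]

print(mapped)
print(filtered)
print(batched)  # [[9, 0, 27, 12], [15, 0, 3, 15], [27, 12, 18, 24], [18, 24, 3, 9]]
```

The filter removes the four elements equal to 21 or 6, which leaves 16 elements and therefore four full batches.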

## Ordering of transformation
The order in which transformations are applied is very important. Your model may learn differently from the same Dataset if the transformations are ordered differently. The code sample below shows two orderings that produce different sets of data.

In [4]:
# Ordering #1
dataset1 = tf.data.Dataset.from_tensor_slices(tf.range(10))
# Dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

dataset1 = dataset1.batch(4)
# Dataset: [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]

dataset1 = dataset1.repeat(2)
# Dataset: [0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]
# Notice a 2 element batch in between

dataset1 = dataset1.shuffle(4)
# Shuffles at batch level.
# Dataset: [0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7]



# Ordering #2
dataset2 = tf.data.Dataset.from_tensor_slices(tf.range(10))
# Dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

dataset2 = dataset2.shuffle(4)
# Dataset: [3, 1, 0, 4, 5, 8, 6, 9, 7, 2]

dataset2 = dataset2.repeat(2)
# Dataset: [3, 1, 0, 4, 5, 8, 6, 9, 7, 2, 3, 1, 0, 4, 5, 8, 6, 9, 7, 2]
# (same order shown twice for simplicity; by default shuffle reshuffles on every repetition)

dataset2 = dataset2.batch(4)
# Dataset: [3, 1, 0, 4], [5, 8, 6, 9], [7, 2, 3, 1], [0, 4, 5, 8], [6, 9, 7, 2]
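
Leaving the randomness of shuffle aside, the effect of swapping batch and repeat can be reproduced deterministically. The sketch below uses plain Python lists, and `batch_list` is a hypothetical helper. It shows that batching before repeating leaves a short batch at the end of every pass, while repeating before batching lets batches straddle the boundary between passes.

```python
def batch_list(elements, size):
    """Split a list into consecutive chunks of at most `size` items."""
    return [elements[i:i + size] for i in range(0, len(elements), size)]


data = list(range(10))

# batch(4) then repeat(2): the short batch [8, 9] appears in every pass.
batch_then_repeat = batch_list(data, 4) * 2

# repeat(2) then batch(4): batches cross the epoch boundary, none is short here.
repeat_then_batch = batch_list(data * 2, 4)

print(batch_then_repeat)
print(repeat_then_batch)
```

Which ordering is preferable depends on whether the epoch boundary has to be respected.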

## Building LeNet-5 Model
Before we start the iterators part, let us quickly build our LeNet-5 model and extract the MNIST data. I have used [TensorFlow’s Slim library](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) to build the model in a few lines. This is the common code for all the types of iterators we are going to work with next.

In [1]:
import tensorflow as tf
import tensorflow.contrib.slim as slim

# LeNet-5 model
class Model:
    def __init__(self, data_X, data_y):
        self.n_class = 10
        self._create_architecture(data_X, data_y)

    def _create_architecture(self, data_X, data_y):
        y_hot = tf.one_hot(data_y, depth = self.n_class)
        logits = self._create_model(data_X)
        predictions = tf.argmax(logits, 1, output_type = tf.int32)
        self.loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels = y_hot, 
                                                                              logits = logits))
        self.optimizer = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(self.loss)
        self.accuracy = tf.reduce_sum(tf.cast(tf.equal(predictions, data_y), tf.float32))

    def _create_model(self, X):
        X1 = X - 0.5
        X1 = tf.pad(X1, tf.constant([[0, 0], [2, 2], [2, 2], [0, 0]]))
        with slim.arg_scope([slim.conv2d, slim.fully_connected], 
                            weights_initializer = tf.truncated_normal_initializer(0.0, 0.1)):
            net = slim.conv2d(X1, 6, [5, 5], padding = 'VALID')
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net, 16, [5, 5], padding = 'VALID')
            net = slim.max_pool2d(net, [2, 2])
            
            net = tf.reshape(net, [-1, 400])
            net = slim.fully_connected(net, 120)
            net = slim.fully_connected(net, 84)
            net = slim.fully_connected(net, self.n_class, activation_fn = None)
        return net
        
        
# Extracting MNIST data        
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", reshape=False)
X_train, y_train = mnist.train.images, mnist.train.labels
X_val, y_val     = mnist.validation.images, mnist.validation.labels
X_test, y_test   = mnist.test.images, mnist.test.labels

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


## Iterators
Now, let’s start building the iterators. TensorFlow provides four types of iterators, each with a specific purpose and use case.

Regardless of the type of iterator, the [get_next](https://www.tensorflow.org/api_docs/python/tf/data/Iterator#get_next) function of the iterator is used to create an operation in your TensorFlow graph which, when run in a session, returns the next values from the Dataset fed to the iterator. Also, the iterator doesn’t keep track of how many elements are present in the Dataset. Hence, it is normal to keep running the iterator’s get_next operation until TensorFlow’s [tf.errors.OutOfRangeError](https://www.tensorflow.org/api_docs/python/tf/errors/OutOfRangeError) exception is raised. The usual skeleton code for a Dataset and iterator looks like this.

In [15]:
# Create iterator
iterator = dataset1.make_one_shot_iterator()
next_batch = iterator.get_next()

# Create session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    try: 
        # Keep running next_batch till the Dataset is exhausted
        while True:
            sess.run(next_batch)
            
    except tf.errors.OutOfRangeError:
        pass
        

### One-shot iterator
This is the most basic type of iterator. All the data, with all the transformations it needs, has to be decided before the Dataset is fed into this iterator. A one-shot iterator iterates through all the elements present in the Dataset and, once exhausted, cannot be used anymore. As a result, the Dataset generated for this iterator can occupy a lot of memory.

In [17]:
from tqdm import tqdm_notebook as tqdm


epochs = 10
batch_size = 64
iterations = len(y_train) * epochs

dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
# Generate the complete Dataset required in the pipeline
dataset = dataset.repeat(epochs).batch(batch_size)
iterator = dataset.make_one_shot_iterator()

data_X, data_y = iterator.get_next()
data_y = tf.cast(data_y, tf.int32)
model = Model(data_X, data_y)

with tf.Session() as sess, tqdm(total = iterations) as pbar:
    sess.run(tf.global_variables_initializer())

    tot_accuracy = 0
    try:
        while True:
            accuracy, _ = sess.run([model.accuracy, model.optimizer])
            tot_accuracy += accuracy
            pbar.update(batch_size)
    except tf.errors.OutOfRangeError:
        pass

print('\nAverage training accuracy: {:.4f}'.format(tot_accuracy / iterations))

Average training accuracy: 0.9816


In the example above, we generated the Dataset for a total of 10 epochs. Use this iterator only if your dataset is small or if you would like to test your model only once.
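
Because the one-shot pipeline above calls `repeat(epochs)` before `batch(batch_size)`, the number of training steps follows from simple arithmetic. The sketch below works it out, assuming the 55,000 training examples that `read_data_sets` returns by default. That figure is an assumption about the data split, not something computed here.

```python
import math

num_examples = 55000  # assumed size of the MNIST training split
epochs = 10
batch_size = 64

total_examples = num_examples * epochs
num_steps = math.ceil(total_examples / batch_size)
last_batch_size = total_examples - (num_steps - 1) * batch_size

# The training loop advances the progress bar by batch_size on every step,
# so the bar ends slightly above the true number of examples.
progress_reported = num_steps * batch_size

print(num_steps)          # 8594
print(last_batch_size)    # 48
print(progress_reported)  # 550016
```

Since 55,000 is not a multiple of 64, most epoch boundaries fall in the middle of a batch with this ordering. Accuracy is unaffected, because it is summed over examples and divided by the total example count.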

### Initializable
With the one-shot iterator, the same training dataset was repeated in memory and our code had no way of periodically validating the model on a validation dataset. The initializable iterator overcomes these problems. An initializable iterator has to be initialized with a dataset before it starts running. Take a look at the code.

In [18]:
epochs = 10
batch_size = 64

placeholder_X = tf.placeholder(tf.float32, [None, 28, 28, 1])
placeholder_y = tf.placeholder(tf.int32, [None])

dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()

data_X, data_y = iterator.get_next()
data_y = tf.cast(data_y, tf.int32)
model = Model(data_X, data_y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for epoch_no in range(epochs):
        train_loss, train_accuracy = 0, 0
        val_loss, val_accuracy = 0, 0

        # Initialize iterator with training data
        sess.run(iterator.initializer, feed_dict = {placeholder_X: X_train, placeholder_y: y_train})
        try:
            with tqdm(total = len(y_train)) as pbar:
                while True:
                    _, loss, acc = sess.run([model.optimizer, model.loss, model.accuracy])
                    train_loss += loss 
                    train_accuracy += acc
                    pbar.update(batch_size)
        except tf.errors.OutOfRangeError:
            pass
    
        # Initialize iterator with validation data
        sess.run(iterator.initializer, feed_dict = {placeholder_X: X_val, placeholder_y: y_val})
        try:
            while True:
                loss, acc = sess.run([model.loss, model.accuracy])
                val_loss += loss 
                val_accuracy += acc
        except tf.errors.OutOfRangeError:
            pass
    
        print('\nEpoch No: {}'.format(epoch_no + 1))
        print('Train accuracy = {:.4f}, loss = {:.4f}'.format(train_accuracy / len(y_train), 
                                                        train_loss / len(y_train)))
        print('Val accuracy = {:.4f}, loss = {:.4f}'.format(val_accuracy / len(y_val), 
                                                        val_loss / len(y_val)))

Epoch No: 1
Train accuracy = 0.9150, loss = 0.2886
Val accuracy = 0.9664, loss = 0.1087

Epoch No: 2
Train accuracy = 0.9741, loss = 0.0830
Val accuracy = 0.9788, loss = 0.0714

Epoch No: 3
Train accuracy = 0.9824, loss = 0.0577
Val accuracy = 0.9812, loss = 0.0635

Epoch No: 4
Train accuracy = 0.9868, loss = 0.0428
Val accuracy = 0.9844, loss = 0.0549

Epoch No: 5
Train accuracy = 0.9898, loss = 0.0331
Val accuracy = 0.9858, loss = 0.0474

Epoch No: 6
Train accuracy = 0.9915, loss = 0.0270
Val accuracy = 0.9888, loss = 0.0430

Epoch No: 7
Train accuracy = 0.9936, loss = 0.0205
Val accuracy = 0.9874, loss = 0.0462

Epoch No: 8
Train accuracy = 0.9939, loss = 0.0183
Val accuracy = 0.9890, loss = 0.0448

Epoch No: 9
Train accuracy = 0.9946, loss = 0.0164
Val accuracy = 0.9868, loss = 0.0542

Epoch No: 10
Train accuracy = 0.9952, loss = 0.0140
Val accuracy = 0.9874, loss = 0.0514


As can be seen, using the [initializer operation](https://www.tensorflow.org/api_docs/python/tf/data/Iterator#initializer), we have switched between the training and validation datasets using the same Dataset object.

This iterator is ideal when you have to train your model with datasets that are split across multiple places and cannot be gathered into one place.

### Reinitializable
With the initializable iterator, every dataset had to go through the same pipeline before being fed into the iterator. The reinitializable iterator overcomes this, as it can be fed different Datasets that have gone through different pipelines. The only thing to take care of is that the different Datasets have the same data types (and compatible shapes). Take a look at the code.

In [19]:
def map_fn(x, y):
    # Do transformations here
    return x, y

epochs = 10
batch_size = 64

placeholder_X = tf.placeholder(tf.float32, shape = [None, 28, 28, 1])
placeholder_y = tf.placeholder(tf.int32, shape = [None])

# Create separate Datasets for training and validation
train_dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
train_dataset = train_dataset.batch(batch_size).map(lambda x, y: map_fn(x, y))
val_dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
val_dataset = val_dataset.batch(batch_size)

# Iterator has to have same output types across all Datasets to be used
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
data_X, data_y = iterator.get_next()
data_y = tf.cast(data_y, tf.int32)
model = Model(data_X, data_y)

# Initialize with required Datasets
train_iterator = iterator.make_initializer(train_dataset)
val_iterator = iterator.make_initializer(val_dataset)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_no in range(epochs):
        train_loss, train_accuracy = 0, 0
        val_loss, val_accuracy = 0, 0

        # Start train iterator
        sess.run(train_iterator, feed_dict = {placeholder_X: X_train, placeholder_y: y_train})
        try:
            with tqdm(total = len(y_train)) as pbar:
                while True:
                    _, acc, loss = sess.run([model.optimizer, model.accuracy, model.loss])
                    train_loss += loss
                    train_accuracy += acc
                    pbar.update(batch_size)
        except tf.errors.OutOfRangeError:
            pass

        # Start validation iterator
        sess.run(val_iterator, feed_dict = {placeholder_X: X_val, placeholder_y: y_val})
        try:
            while True:
                acc, loss = sess.run([model.accuracy, model.loss])
                val_loss += loss
                val_accuracy += acc
        except tf.errors.OutOfRangeError:
            pass

        print('\nEpoch: {}'.format(epoch_no + 1))
        print('Train accuracy: {:.4f}, loss: {:.4f}'.format(train_accuracy / len(y_train),
                                                             train_loss / len(y_train)))
        print('Val accuracy: {:.4f}, loss: {:.4f}\n'.format(val_accuracy / len(y_val), 
                                                            val_loss / len(y_val)))

Epoch: 1
Train accuracy: 0.9163, loss: 0.2781
Val accuracy: 0.9646, loss: 0.1165

Epoch: 2
Train accuracy: 0.9756, loss: 0.0786
Val accuracy: 0.9788, loss: 0.0761

Epoch: 3
Train accuracy: 0.9826, loss: 0.0550
Val accuracy: 0.9808, loss: 0.0650

Epoch: 4
Train accuracy: 0.9876, loss: 0.0411
Val accuracy: 0.9858, loss: 0.0527

Epoch: 5
Train accuracy: 0.9906, loss: 0.0320
Val accuracy: 0.9858, loss: 0.0498

Epoch: 6
Train accuracy: 0.9920, loss: 0.0270
Val accuracy: 0.9810, loss: 0.0700

Epoch: 7
Train accuracy: 0.9929, loss: 0.0227
Val accuracy: 0.9860, loss: 0.0538

Epoch: 8
Train accuracy: 0.9939, loss: 0.0193
Val accuracy: 0.9854, loss: 0.0538

Epoch: 9
Train accuracy: 0.9946, loss: 0.0172
Val accuracy: 0.9862, loss: 0.0582

Epoch: 10
Train accuracy: 0.9952, loss: 0.0147
Val accuracy: 0.9862, loss: 0.0547



Notice that the training Dataset object goes through an additional augmentation step (the `map` call) that the validation Dataset does not. You could have fed the training and validation datasets directly into the Dataset objects, but I have used placeholders just to show the flexibility.
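
One detail of the training pipeline above is worth spelling out: `map` is chained after `batch`, so `map_fn` is called once per batch rather than once per example. The pure-Python sketch below illustrates the difference with a function that records its arguments. `batch_list` and `recording_fn` are hypothetical helpers and nothing here calls TensorFlow.

```python
def batch_list(elements, size):
    """Split a list into consecutive chunks of at most `size` items."""
    return [elements[i:i + size] for i in range(0, len(elements), size)]


def recording_fn(log):
    """Return an identity function that records every argument it receives."""
    def fn(value):
        log.append(value)
        return value
    return fn


data = list(range(10))

# batch then map: the function sees whole batches.
seen_batch_then_map = []
fn = recording_fn(seen_batch_then_map)
batch_then_map = [fn(b) for b in batch_list(data, 4)]

# map then batch: the function sees individual elements.
seen_map_then_batch = []
fn = recording_fn(seen_map_then_batch)
map_then_batch = batch_list([fn(x) for x in data], 4)

print(len(seen_batch_then_map))  # 3 calls, each receiving a whole batch
print(len(seen_map_then_batch))  # 10 calls, each receiving one element
```

With an identity function both orderings produce the same batches. For a per-example augmentation the choice matters: the function either has to handle the extra batch dimension or be mapped before batching.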

### Feedable
The reinitializable iterator gave us the flexibility of assigning differently pipelined Datasets to one iterator, but it could not maintain state (i.e. how far each individual Dataset had been consumed when switching). The code sample below shows how to use the feedable iterator.

In [21]:
import numpy as np

def map_fn(x, y):
    # Do transformations here
    return x, y

epochs = 10
batch_size = 64

placeholder_X = tf.placeholder(tf.float32, shape = [None, 28, 28, 1])
placeholder_y = tf.placeholder(tf.int32, shape = [None])

# Create separate Datasets for training, validation and testing
train_dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
train_dataset = train_dataset.batch(batch_size).map(lambda x, y: map_fn(x, y))

val_dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
val_dataset = val_dataset.batch(batch_size)

y_test = np.array(y_test, dtype = np.int32)
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(batch_size)

# A feedable iterator selects which iterator to read from via that iterator's unique string handle
handle = tf.placeholder(tf.string, shape = [])
iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
data_X, data_y = iterator.get_next()
data_y = tf.cast(data_y, tf.int32)
model = Model(data_X, data_y)

# Create Reinitializable iterator for Train and Validation, one shot iterator for Test
train_val_iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
train_iterator = train_val_iterator.make_initializer(train_dataset)
val_iterator = train_val_iterator.make_initializer(val_dataset)
test_iterator = test_dataset.make_one_shot_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Create string handles for above reinitializable and one shot iterators.
    train_val_string = sess.run(train_val_iterator.string_handle())
    test_string = sess.run(test_iterator.string_handle())

    for epoch_no in range(epochs):
        train_loss, train_accuracy = 0, 0
        val_loss, val_accuracy = 0, 0

        # Start reinitializable's train iterator
        sess.run(train_iterator, feed_dict = {placeholder_X: X_train, placeholder_y: y_train})
        try:
            with tqdm(total = len(y_train)) as pbar:
                while True:
                    # Feed to feedable iterator the string handle of reinitializable iterator
                    _, loss, acc = sess.run([model.optimizer, model.loss, model.accuracy],
                                            feed_dict = {handle: train_val_string})
                    train_loss += loss
                    train_accuracy += acc
                    pbar.update(batch_size)
        except tf.errors.OutOfRangeError:
            pass
      
        # Start reinitializable's validation iterator
        sess.run(val_iterator, feed_dict = {placeholder_X: X_val, placeholder_y: y_val})
        try:
            while True:
                loss, acc = sess.run([model.loss, model.accuracy],
                                     feed_dict = {handle: train_val_string})
                val_loss += loss
                val_accuracy += acc
        except tf.errors.OutOfRangeError:
            pass
    
        print('\nEpoch: {}'.format(epoch_no + 1))
        print('Training accuracy: {:.4f}, loss: {:.4f}'.format(train_accuracy / len(y_train), 
                                                                train_loss / len(y_train)))
        print('Val accuracy: {:.4f}, loss: {:.4f}\n'.format(val_accuracy / len(y_val),
                                                            val_loss / len(y_val)))
    
    test_loss, test_accuracy = 0, 0
    try:
        while True:
            # Feed to feedable iterator the string handle of one shot iterator
            loss, acc = sess.run([model.loss, model.accuracy], feed_dict = {handle: test_string})
            test_loss += loss
            test_accuracy += acc
    except tf.errors.OutOfRangeError:
        pass

print('\nTest accuracy: {:.4f}, loss: {:.4f}'.format(test_accuracy / len(y_test), test_loss / len(y_test)))

Epoch: 1
Training accuracy: 0.9218, loss: 0.2656
Val accuracy: 0.9626, loss: 0.1290

Epoch: 2
Training accuracy: 0.9752, loss: 0.0804
Val accuracy: 0.9810, loss: 0.0610

Epoch: 3
Training accuracy: 0.9836, loss: 0.0551
Val accuracy: 0.9812, loss: 0.0570

Epoch: 4
Training accuracy: 0.9871, loss: 0.0415
Val accuracy: 0.9826, loss: 0.0513

Epoch: 5
Training accuracy: 0.9902, loss: 0.0328
Val accuracy: 0.9828, loss: 0.0522

Epoch: 6
Training accuracy: 0.9916, loss: 0.0275
Val accuracy: 0.9832, loss: 0.0535

Epoch: 7
Training accuracy: 0.9933, loss: 0.0224
Val accuracy: 0.9828, loss: 0.0557

Epoch: 8
Training accuracy: 0.9940, loss: 0.0186
Val accuracy: 0.9876, loss: 0.0521

Epoch: 9
Training accuracy: 0.9945, loss: 0.0174
Val accuracy: 0.9844, loss: 0.0612

Epoch: 10
Training accuracy: 0.9949, loss: 0.0156
Val accuracy: 0.9846, loss: 0.0661


Test accuracy: 0.9855, loss: 0.0575


Though not illustrated in the above code sample, the string handle lets us resume each iterator from the exact point where data extraction stopped while alternating between different Datasets.

This iterator is ideal in scenarios where you are training a model on different datasets simultaneously and need finer control over which batch of which dataset is fed to the model next.
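
The idea of resuming from where each iterator stopped can be modelled in a few lines of plain Python. The sketch below is only an analogy for the handle mechanism, not TensorFlow code. `HandleSwitch` and the handle names are hypothetical, and unlike the sample above it assumes a separate iterator per dataset.

```python
class HandleSwitch:
    """Toy stand-in for a feedable iterator: a handle selects which iterator to advance."""

    def __init__(self):
        self._iterators = {}

    def register(self, handle, batches):
        # Each handle owns its own iterator and therefore its own position.
        self._iterators[handle] = iter(batches)

    def get_next(self, handle):
        # Advancing one iterator leaves the position of every other iterator untouched.
        return next(self._iterators[handle])


switch = HandleSwitch()
switch.register("train", [[0, 1], [2, 3], [4, 5]])
switch.register("val", [[100, 101], [102, 103]])

# Alternate between the two streams; each one continues where it left off.
order = ["train", "val", "train", "train", "val"]
emitted = [switch.get_next(handle) for handle in order]

print(emitted)  # [[0, 1], [100, 101], [2, 3], [4, 5], [102, 103]]
```

In the TensorFlow sample above, training and validation share one reinitializable iterator, so switching between them restarts the data. Giving each Dataset its own iterator and handle is what would allow this kind of mid-stream alternation.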