# tensorflow.contrib.training.batch_sequences_with_states (Legacy)

## Introduction

As described in the [official documentation](https://www.tensorflow.org/api_docs/python/tf/contrib/training/batch_sequences_with_states), this function splits input sequences into segments and batches them. In addition, for each sequence in a batch, it maintains and updates a copy of its state. This makes it handy for building data input pipelines for stateful models such as RNNs.

In [1]:
import os
from pprint import pprint

import numpy as np

# Set before importing TensorFlow so that it is in place before any C++ logging happens.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow.contrib.training import batch_sequences_with_states

print('TensorFlow version:', tf.VERSION)

TensorFlow version: 1.12.0


## Generating Toy Data
The toy data used in the following sections is an instance of **sequence classification**. Credit goes to Jason Brownlee's illuminating [tutorial](https://machinelearningmastery.com/sequence-prediction-problems-learning-lstm-recurrent-neural-networks/) (Example 5).

_"The problem is defined as a sequence of random values between 0 and 1. This sequence is taken as input for the problem with each number provided one per timestep._

_A binary label (0 or 1) is associated with each input. The output values are all 0. Once the cumulative sum of the input values in the sequence exceeds a threshold, then the output value flips from 0 to 1._

_A threshold of 1/4 the sequence length is used."_
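The labelling rule in the quote can be checked by hand on a tiny input. The following pure-Python sketch mirrors the list comprehension used in the next cell; the helper name `threshold_labels` is our own. It also makes one detail explicit: the comparison is strict, so a running sum that exactly equals the threshold still gets label 0.

```python
from itertools import accumulate


def threshold_labels(seq):
    """Label each step 1 once the running sum strictly exceeds len(seq) / 4."""
    threshold = len(seq) / 4.0
    return [int(total > threshold) for total in accumulate(seq)]


# Threshold is 4 / 4 = 1.0; the running sums are 0.5, 1.0, 1.5, 2.0.
print(threshold_labels([0.5, 0.5, 0.5, 0.5]))  # [0, 0, 1, 1]
```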

In [2]:
list_seqs, list_labels = list(), list()
max_seq_len = 15
min_seq_len = 12
num_seq = 100
for _ in range(num_seq):
    len_seq = np.random.randint(min_seq_len, max_seq_len)  # upper bound is exclusive: lengths 12 to 14
    seq_tokens = np.random.rand(len_seq)
    list_seqs.append(seq_tokens)
    threshold = len_seq / 4.0
    seq_labels = np.array([int(x > threshold) for x in np.cumsum(seq_tokens)])
    list_labels.append(seq_labels)

# Print the first 5 sequences of tokens and labels
for seq, labels in zip(list_seqs[:5], list_labels):
    pprint(seq)
    pprint(labels)
    print('\n')

array([0.39676616, 0.86157441, 0.83979143, 0.67745519, 0.20498841,
       0.2182948 , 0.23537936, 0.12690887, 0.14148605, 0.95540276,
       0.54239793, 0.39336251])
array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])


array([0.22966606, 0.54282423, 0.48165246, 0.17545146, 0.49743892,
       0.96969642, 0.30970795, 0.57493861, 0.36068693, 0.08612079,
       0.78507474, 0.52802721, 0.91969071, 0.60057449])
array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])


array([0.13858428, 0.42406717, 0.0669202 , 0.81489591, 0.37393337,
       0.94317965, 0.99856732, 0.89698231, 0.48522192, 0.2183444 ,
       0.48269078, 0.47967869])
array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])


array([0.22425017, 0.97210877, 0.86698953, 0.2286461 , 0.12806215,
       0.24681088, 0.60760401, 0.96746138, 0.49992784, 0.21748109,
       0.92627136, 0.33691707, 0.67201917])
array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])


array([0.71588358, 0.95822808, 0.75794044, 0.13306914, 0.13664142,
       0.29724938, 0.05303291, 0.4270082

<a id='writing_to_tfrecords'></a>
## Writing to TFRecords

### Thread-and-queue-based implementation

Instead of accepting raw sequence data, the `input_sequences` argument of `batch_sequences_with_states` expects tensors produced by a thread-and-queue-based input pipeline. Unfortunately, this crucial requirement is either ambiguously expressed or missing altogether in the documentation. A common consequence of overlooking it is getting the same sequence of data duplicated over and over. For an example, see [this issue](https://github.com/tensorflow/tensorflow/issues/21959).

### `tf.train.SequenceExample`

To prepare the input for `batch_sequences_with_states`, the data can first be wrapped in `tf.train.SequenceExample`, which is essentially a [protocol buffer](https://developers.google.com/protocol-buffers/?hl=en). For concrete examples of `SequenceExample`, it suffices to look at the ones accompanying its [definition](https://github.com/tensorflow/tensorflow/blob/r1.11/tensorflow/core/example/example.proto). The code in the following cell shows a straightforward use of this wrapper: one `SequenceExample` per sequence, with one feature list for the tokens and one for the labels, each serialized into a TFRecords file.

In [0]:
F_TOKEN_ID = 'token'
F_LABEL_ID = 'label'
fname_data = 'toydata.tfrecords'
if os.path.isfile(fname_data):
    os.remove(fname_data)
    
writer = tf.python_io.TFRecordWriter(fname_data)
for seq_tokens, seq_labels in zip(list_seqs, list_labels):
    seq = tf.train.SequenceExample()
    flist = seq.feature_lists.feature_list
    tokens = flist[F_TOKEN_ID].feature
    labels = flist[F_LABEL_ID].feature
    for tk, lb in zip(seq_tokens, seq_labels):
        token = tokens.add()
        token.float_list.value.append(tk)
        label = labels.add()
        label.int64_list.value.append(lb)
    writer.write(seq.SerializeToString())

# Don't forget to close the writer, so that buffered records are flushed to disk!
writer.close()

## Load Data

Three classes/functions are involved in loading and parsing the toy data: `tf.train.string_input_producer`, `tf.TFRecordReader` and `tf.parse_single_sequence_example`.

### `tf.train.string_input_producer`
According to the [documentation](https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer), this function creates a queue that outputs the strings passed to it, for use in data input pipelines; a typical use is producing filenames. In addition, when this function is called, a `QueueRunner` for that queue is added to the graph's `tf.GraphKeys.QUEUE_RUNNERS` collection, so the user does not need to create one manually. The queue runners still have to be started, e.g. with `tf.train.start_queue_runners`.
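To make the queue-runner idea concrete, here is a plain-Python analogy built on `queue.Queue` and `threading`. It is not TensorFlow code and `filename_producer` is a hypothetical helper: a background thread keeps a bounded queue filled with filenames, once per epoch, while the consumer dequeues them. TensorFlow's default `num_epochs=None` cycles forever instead of stopping, which is why the same records come back on every pass in the cells below.

```python
import queue
import threading


def filename_producer(filenames, num_epochs, capacity=32):
    """Analogy of a filename queue plus its queue runner (not TensorFlow code)."""
    q = queue.Queue(maxsize=capacity)

    def run():
        for _ in range(num_epochs):
            for name in filenames:
                q.put(name)  # blocks while the queue is full
        q.put(None)  # sentinel: no more filenames will come

    thread = threading.Thread(target=run, daemon=True)
    thread.start()  # counterpart of starting the queue runner
    return q, thread


q, thread = filename_producer(['toydata.tfrecords'], num_epochs=3)
consumed = []
while True:
    name = q.get()
    if name is None:
        break
    consumed.append(name)
thread.join()
print(consumed)  # ['toydata.tfrecords', 'toydata.tfrecords', 'toydata.tfrecords']
```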

### `tf.TFRecordReader`
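Reads serialized records from a TFRecords file. Each `read` call dequeues a filename from the filename queue when needed and returns a `(key, value)` pair, where `value` is one serialized record.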
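The keys returned by the reader look like `toydata.tfrecords:0`, `toydata.tfrecords:246` and so on: the filename followed by the byte offset at which the record starts. A TFRecords file frames each record as an 8-byte length, a 4-byte masked CRC32C of the length, the payload, and a 4-byte masked CRC32C of the payload, so every record adds 16 bytes of overhead. The sketch below only computes offsets from payload sizes; `record_offsets` is a hypothetical helper, and the sizes 230 and 265 are back-derived from the offsets printed by the next cell rather than measured.

```python
from itertools import accumulate

# Framing overhead per record: length (8) + CRC of length (4) + CRC of data (4).
TFRECORD_OVERHEAD = 8 + 4 + 4


def record_offsets(payload_sizes):
    """Byte offset at which each record starts in a TFRecords file."""
    sizes = [TFRECORD_OVERHEAD + n for n in payload_sizes]
    return [0] + list(accumulate(sizes))[:-1]


print(record_offsets([230, 265, 230]))  # [0, 246, 527]
```

The two 12-step sequences produce records of the same size, while the 14-step one is larger, which is consistent with the payload growing with the number of time steps.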

### `tf.parse_single_sequence_example`
Parses a serialized `SequenceExample`. Apart from the serialized record, the main argument is `sequence_features`, which describes each feature list to parse, keyed by the same id used in the `feature_lists` of the `SequenceExample`. In this toy scenario there are two feature lists, and each time step holds a single scalar: a float token under the key `F_TOKEN_ID` and an integer label under the key `F_LABEL_ID`. Therefore, it suffices to describe them with `tf.FixedLenSequenceFeature([], dtype=tf.float32)` and `tf.FixedLenSequenceFeature([], dtype=tf.int64)` respectively. Each parsed feature is then a 1-D tensor whose length is the number of time steps.


In [4]:
file_queue = tf.train.string_input_producer([fname_data])
reader = tf.TFRecordReader()
seq_key, serialized_record = reader.read(file_queue)
ctx, sequence = tf.parse_single_sequence_example(
    serialized_record,
    sequence_features={
        F_TOKEN_ID: tf.FixedLenSequenceFeature([], dtype=tf.float32),
        F_LABEL_ID: tf.FixedLenSequenceFeature([], dtype=tf.int64)
    })

Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.TFRecordDataset`.


Let's first print the sequence keys, token values and label values loaded from the TFRecords file.

In [5]:
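with tf.train.MonitoredSession() as sess:
    # MonitoredSession already starts the queue runners registered in the graph,
    # so the explicit start_queue_runners call below is redundant here.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)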
    for _ in range(5):
        val_seq_key, val_seq = sess.run([seq_key, sequence])
        print(val_seq_key.decode())
        print('tokens:\t', val_seq[F_TOKEN_ID])
        print('labels:\t', val_seq[F_LABEL_ID], '\n')
    coord.request_stop()
    coord.join(threads)

INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
toydata.tfrecords:0
tokens:	 [0.39676616 0.8615744  0.8397914  0.6774552  0.2049884  0.2182948
 0.23537935 0.12690887 0.14148605 0.9554028  0.5423979  0.39336252]
labels:	 [0 0 0 0 0 1 1 1 1 1 1 1] 

toydata.tfrecords:246
tokens:	 [0.22966605 0.5428242  0.48165247 0.17545146 0.4974389  0.9696964
 0.30970794 0.5749386  0.36068693 0.08612079 0.7850748  0.52802724
 0.9196907  0.6005745 ]
labels:	 [0 0 0 0 0 0 0 1 1 1 1 1 1 1] 

toydata.tfrecords:527
tokens:	 [0.13858429 0.42406717 0.06692021 0.8148959  0.37393337 0.94317967
 0.99856734 0.8969823  0.48522192 0.2183444  0.48269078 0.4796787 ]
labels:	 [0 0 0 0 0 0 1 1 1 1 1 1] 

toydata.tfrecords:773
tokens:	 [0.22425017 0.9721088  0.86698955 0.2286461  0.12806216 0.24681088
 0.607604   0.9674614  0.49992785 0.21748109 0.9262714  0.33691707


## Batch Sequences with States

### `initial_states`
As a characteristic feature of `tf.contrib.training.batch_sequences_with_states`, a copy of the states is maintained for each input sequence. The function therefore takes the argument `initial_states`, a dict mapping state names (strings) to the initial values of the states. The entries of `initial_states` must match the RNN model it is combined with. Since the LSTM layer used later is built with `tf.nn.static_state_saving_rnn` and `tf.nn.rnn_cell.LSTMCell`, two states are involved: `'lstm_state_c'` and `'lstm_state_h'`.

For readability, a small `lstm_size` is used here; it is increased later when training a toy LSTM model.

### `make_keys_unique`
In addition, the argument `make_keys_unique` should be set to `True` so that the data can be reused repeatedly during training. The file queue yields the same records, and hence the same `input_key` values, on every pass, whereas `input_key` is meant to identify a sequence uniquely. As the documentation of this function states, `make_keys_unique` appends a random integer to the end of `input_key`, so the key of an input sequence differs every time the sequence is fed to the model. The effect is visible in the keys printed below.
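Why the random suffix matters can be seen in a small pure-Python simulation, written under our own assumptions: `keys_over_passes` is a hypothetical helper, Python's `random` stands in for TensorFlow's random op, and the range of the random integer is our choice. Without the suffix, a second pass over the file reuses the keys of the first pass; with it, every pass gets fresh keys.

```python
import random


def keys_over_passes(reader_keys, num_passes, make_keys_unique, seed=29392):
    """Simulate the key each sequence carries on each pass over the data."""
    rng = random.Random(seed)
    keys = []
    for _ in range(num_passes):
        for key in reader_keys:
            if make_keys_unique:
                # Append a random integer directly, without a separator.
                key = key + str(rng.randrange(10 ** 8))
            keys.append(key)
    return keys


reader_keys = ['toydata.tfrecords:0', 'toydata.tfrecords:246']
plain = keys_over_passes(reader_keys, num_passes=2, make_keys_unique=False)
unique = keys_over_passes(reader_keys, num_passes=2, make_keys_unique=True)
print(len(plain), len(set(plain)))  # 4 2
print(len(unique), len(set(unique)))
```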

In [0]:
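lstm_size = 8
num_unroll = 3  # number of time steps per segment
# Initial values of the two LSTM states; np.random.rand yields float64 arrays,
# so the saved states (and hence the model inputs later on) are float64.
states = {
    'lstm_state_c': np.random.rand(lstm_size),
    'lstm_state_h': np.random.rand(lstm_size)
}
batch_size = 2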
batch = tf.contrib.training.batch_sequences_with_states(
    input_key=seq_key,
    input_sequences=sequence,
    input_context=ctx,
    input_length=tf.shape(sequence[F_TOKEN_ID])[0],
    initial_states=states,
    num_unroll=num_unroll,
    batch_size=batch_size,
    allow_small_batch=False,
    make_keys_unique=True,
    make_keys_unique_seed=29392
)

A trivial state update is used below: 1 is added to every component of each state. Inspecting the values of the tensors produced by `batch_sequences_with_states` reveals the following properties.

### `batch.key`

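The value of each segment's `key` is a string with the following parts:

* A prefix giving the index of the current segment and the number of segments obtained from the input sequence (e.g., `'00000_of_00004'`), followed by a colon.
* The key returned by the reader, `<TFRecords file name>:<byte offset of the record>` (e.g., `toydata.tfrecords:0`).
* The random integer added by `make_keys_unique`, which distinguishes the current pass over the input sequence from earlier ones. It is appended directly to the byte offset without a separator: `toydata.tfrecords:079098054` is the offset `0` followed by `79098054`.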
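The key layout above can be rebuilt in a few lines of plain Python. This is a sketch inferred from the keys printed below, not TensorFlow's own code; `segment_keys` is a hypothetical helper, and the suffix values are copied from the printed output.

```python
import math


def segment_keys(reader_key, suffix, length, num_unroll):
    """Rebuild the per-segment keys: '<index>_of_<count>:' + reader key + suffix."""
    num_segments = math.ceil(length / num_unroll)
    base = reader_key + str(suffix)  # random suffix appended without a separator
    return ['{:05d}_of_{:05d}:{}'.format(i, num_segments, base)
            for i in range(num_segments)]


keys = segment_keys('toydata.tfrecords:0', 79098054, length=12, num_unroll=3)
print(keys[0])    # 00000_of_00004:toydata.tfrecords:079098054
print(len(keys))  # 4
```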

### `batch.state`

Because the random integer in `key` changes on every pass, a fresh copy of the states, initialized from `initial_states`, is created for each pass over an input sequence and updated segment by segment.
That is, if an input sequence is fetched N times in the session, N copies of the states are created.
This can be observed by picking a small value of `num_seq` and running more iterations.
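The bookkeeping described above can be mimicked with a dictionary keyed by the per-pass sequence key. This is a NumPy simulation under our own assumptions, not the state saver's implementation; `run_segments` and the suffixes `11` and `22` are hypothetical. Two passes over the same two-segment record yield two independent copies of the states, each incremented once per segment.

```python
import numpy as np


def run_segments(keys_in_order, initial_states):
    """Simulate one copy of the states per unique sequence key."""
    saved = {}
    for key in keys_in_order:
        # Drop the '<index>_of_<count>:' prefix to get the per-pass sequence key.
        seq_key = key.split(':', 1)[1]
        if seq_key not in saved:
            # First segment of this pass: start from a fresh copy of the initial states.
            saved[seq_key] = {name: value.copy() for name, value in initial_states.items()}
        for name in saved[seq_key]:
            saved[seq_key][name] += 1  # the trivial "add 1" update
    return saved


initial = {'lstm_state_c': np.zeros(2), 'lstm_state_h': np.zeros(2)}
keys = ['00000_of_00002:toydata.tfrecords:011', '00001_of_00002:toydata.tfrecords:011',
        '00000_of_00002:toydata.tfrecords:022', '00001_of_00002:toydata.tfrecords:022']
saved = run_segments(keys, initial)
print(len(saved))                                      # 2
print(saved['toydata.tfrecords:011']['lstm_state_c'])  # [2. 2.]
```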

### Zero padding

When the length of an input sequence is not a multiple of `num_unroll`, its last segment is padded with zeros.
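The segmentation and padding can be pictured with NumPy: a sequence of length T is split into ceil(T / `num_unroll`) rows of `num_unroll` steps, and the last row is filled up with zeros. `segment_and_pad` is a hypothetical helper illustrating the layout, not the function's implementation.

```python
import numpy as np


def segment_and_pad(seq, num_unroll):
    """Split a 1-D sequence into rows of num_unroll steps, zero-padding the last row."""
    seq = np.asarray(seq)
    num_segments = -(-len(seq) // num_unroll)  # ceiling division
    padded = np.zeros(num_segments * num_unroll, dtype=seq.dtype)
    padded[:len(seq)] = seq
    return padded.reshape(num_segments, num_unroll)


segments = segment_and_pad(np.arange(1, 15), num_unroll=3)  # length 14
print(segments.shape)  # (5, 3)
print(segments[-1])    # [13 14  0]
```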


In [7]:
with tf.Session() as sess:
    # Read the current states and define the trivial update (add 1 to every component).
    state_c = batch.state('lstm_state_c')
    state_h = batch.state('lstm_state_h')
    update_state_c = batch.save_state('lstm_state_c', state_c + 1)
    update_state_h = batch.save_state('lstm_state_h', state_h + 1)
    update_states = tf.group(update_state_c, update_state_h)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(8):
        print('Iteration {}:'.format(i + 1))
        val_key, val_next_key, val_seqs, val_labels, val_state_c, val_state_h, \
        val_batch_len, _ = sess.run([
            batch.key, batch.next_key, batch.sequences[F_TOKEN_ID],
            batch.sequences[F_LABEL_ID], state_c, state_h,
            batch.length, update_states
        ])
        for name, val in zip(
            ['keys','sequences', 'labels', 'state_c', 'state_h'],
            [val_key, val_seqs, val_labels, val_state_c, val_state_h]
        ):
            print(name, '\n', val)
        print('\n')
    coord.request_stop()
    coord.join(threads)

Iteration 1:
keys 
 [b'00000_of_00004:toydata.tfrecords:079098054'
 b'00000_of_00005:toydata.tfrecords:24651255884']
sequences 
 [[0.39676616 0.8615744  0.8397914 ]
 [0.22966605 0.5428242  0.48165247]]
labels 
 [[0 0 0]
 [0 0 0]]
state_c 
 [[0.96406707 0.63527062 0.02453578 0.55637392 0.61685748 0.62238103
  0.01480179 0.77377287]
 [0.96406707 0.63527062 0.02453578 0.55637392 0.61685748 0.62238103
  0.01480179 0.77377287]]
state_h 
 [[0.1720632  0.7212568  0.9412366  0.012278   0.63105657 0.17493996
  0.12981576 0.53359212]
 [0.1720632  0.7212568  0.9412366  0.012278   0.63105657 0.17493996
  0.12981576 0.53359212]]


Iteration 2:
keys 
 [b'00001_of_00004:toydata.tfrecords:079098054'
 b'00001_of_00005:toydata.tfrecords:24651255884']
sequences 
 [[0.6774552  0.2049884  0.2182948 ]
 [0.17545146 0.4974389  0.9696964 ]]
labels 
 [[0 0 1]
 [0 0 0]]
state_c 
 [[1.96406707 1.63527062 1.02453578 1.55637392 1.61685748 1.62238103
  1.01480179 1.77377287]
 [1.96406707 1.63527062 1.02453578 1.5563

## Application to Model Training

To illustrate the use of `batch_sequences_with_states` in training a model, a few other under-documented black boxes in TensorFlow are used.
Mainly, `batch_sequences_with_states` and `tf.nn.static_state_saving_rnn` work well together.
The former not only generates input data for the latter, but also provides storage for the states of the LSTM layer.
Formally, the latter takes the former as the value of its `state_saver` argument.
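The benefit of saving states between segments can be checked with a small NumPy experiment: running a recurrent network over a sequence in one go gives the same outputs as running it segment by segment while carrying the final state of each segment into the next, which is the role the state saver plays. A plain tanh RNN stands in for the LSTM to keep the sketch short; `rnn_forward` and all sizes are our own choices. This only concerns the forward pass: during training, gradients do not flow back across segment boundaries, which is what makes the backpropagation truncated.

```python
import numpy as np


def rnn_forward(inputs, h0, W_x, W_h):
    """Simple tanh RNN over time-major inputs; returns all outputs and the final state."""
    h = h0
    outputs = []
    for x_t in inputs:
        h = np.tanh(x_t @ W_x + h @ W_h)
        outputs.append(h)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
batch_size, hidden, num_unroll, length = 2, 4, 3, 12
W_x = rng.normal(size=(1, hidden))
W_h = rng.normal(size=(hidden, hidden)) * 0.5
seq = rng.random((length, batch_size, 1))  # [time, batch, feature]
h0 = np.zeros((batch_size, hidden))

# Whole sequence at once.
full_out, _ = rnn_forward(seq, h0, W_x, W_h)

# Segment by segment, saving the final state and reusing it as the next initial state.
state = h0
chunks = []
for start in range(0, length, num_unroll):
    out, state = rnn_forward(seq[start:start + num_unroll], state, W_x, W_h)
    chunks.append(out)
chunked_out = np.concatenate(chunks)

print(np.allclose(full_out, chunked_out))  # True
```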

In [0]:
lstm_size = 64
states = {
    'lstm_state_c': np.random.rand(lstm_size),
    'lstm_state_h': np.random.rand(lstm_size)
}
batch = tf.contrib.training.batch_sequences_with_states(
    input_key=seq_key,
    input_sequences=sequence,
    input_context=ctx,
    input_length=tf.shape(sequence[F_TOKEN_ID])[0],
    initial_states=states,
    num_unroll=num_unroll,
    batch_size=batch_size,
    allow_small_batch=False,
    make_keys_unique=True,
    make_keys_unique_seed=29392)

In [0]:
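# Cast to float64 to match the dtype of the saved states (np.random.rand yields float64).
inputs = tf.cast(batch.sequences[F_TOKEN_ID], tf.float64)
# static_state_saving_rnn expects a list of num_unroll tensors, one per time step,
# each of shape [batch_size, input_size].
inputs_by_time = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)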
assert len(inputs_by_time) == num_unroll
assert inputs_by_time[0].get_shape() == (batch_size, 1)

In [0]:
# LSTM layer
cell = tf.nn.rnn_cell.LSTMCell(num_units=lstm_size, reuse=tf.AUTO_REUSE)
lstm_output, _ = tf.nn.static_state_saving_rnn(
    cell, inputs_by_time, batch, state_name=('lstm_state_c', 'lstm_state_h'))
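# lstm_output is a list of num_unroll tensors, each of shape [batch_size, lstm_size]
# print(len(lstm_output), lstm_output[0].get_shape())

# Apply dropout with keep_prob=0.5; the list is stacked into a single
# [num_unroll, batch_size, lstm_size] tensor.
lstm_output = tf.nn.dropout(lstm_output, 0.5)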

In [0]:
# logits layer
layer_size = 32
dense = tf.keras.models.Sequential(name='logits')
dense.add(
    tf.keras.layers.Dense(layer_size, activation='relu', input_dim=lstm_size))
dense.add(
    tf.keras.layers.Dense(layer_size, activation='relu'))
dense.add(tf.keras.layers.Dense(1))

In [0]:
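# Time-major labels, [num_unroll, batch_size], matching the layout of lstm_output
labels = tf.transpose(batch.sequences[F_LABEL_ID])
logits = dense(tf.cast(lstm_output, tf.float32))
logits = tf.squeeze(logits, -1)  # [num_unroll, batch_size, 1] -> [num_unroll, batch_size]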
assert logits.get_shape() == labels.get_shape()

loss = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=logits, labels=tf.cast(labels, tf.float32))
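The loss above also scores the zero-padded steps of a final segment as if they were real inputs with label 0. A minimal NumPy sketch of masking them out follows, assuming a per-example count of valid steps in the current segment is available; we believe `batch.length` provides something along these lines, but that is an assumption worth checking against the API documentation. The elementwise formula is the numerically stable form of sigmoid cross-entropy; in the graph, `tf.sequence_mask` (transposed to time-major) could build the same mask. `masked_loss_sum` is a hypothetical helper.

```python
import numpy as np


def sigmoid_xent(logits, labels):
    """Stable elementwise sigmoid cross-entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def masked_loss_sum(logits, labels, valid_steps):
    """Sum the loss over real time steps only.

    logits, labels: [num_unroll, batch_size], time-major as in the cell above
    valid_steps:    [batch_size], number of non-padded steps in this segment
    """
    num_unroll = logits.shape[0]
    mask = np.arange(num_unroll)[:, None] < np.asarray(valid_steps)[None, :]
    return np.sum(sigmoid_xent(logits, labels) * mask)


logits = np.zeros((3, 2))
labels = np.zeros((3, 2))
# 3 valid steps in the first column, 1 in the second: 4 terms of ln 2 each.
print(masked_loss_sum(logits, labels, valid_steps=[3, 1]))  # about 2.7726
```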

In [0]:
optimizer = tf.train.AdadeltaOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
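# Evaluating loss_val also runs train_op, so each sess.run performs one training step.
# loss_val is the loss summed over all num_unroll * batch_size time steps in the batch.
with tf.control_dependencies([train_op]):
    loss_val = tf.reduce_sum(loss)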
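Because `loss_val` is a sum rather than a mean, its scale depends on the batch shape. As a reference point (our own arithmetic, not part of the original notebook): a model that outputs logit 0, i.e. probability 0.5, everywhere pays ln 2 per term, so with `num_unroll = 3` and `batch_size = 2` the chance-level value is 6 ln 2.

```python
import math

num_unroll, batch_size = 3, 2
# One sigmoid cross-entropy term per time step and batch entry, each ln 2 at p = 0.5.
chance_loss = num_unroll * batch_size * math.log(2)
print(round(chance_loss, 4))  # 4.1589
```

The first logged mean loss in the training output, about 4.07, sits just below this baseline.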

In [14]:
len_phase = 1000
sum_loss = 0

initializer = tf.initializers.global_variables()
with tf.Session() as sess:
    # Initialize the variables before the queue runners start feeding data.
    sess.run(initializer)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(1, 30001):
        sum_loss += sess.run(loss_val)
        if i % len_phase == 0:
            # Average of the per-batch summed loss over the last len_phase steps
            print('Mean loss at step {0:5}:  {1}'.format(i, sum_loss / len_phase))
            sum_loss = 0
    coord.request_stop()
    coord.join(threads)

Mean loss at step  1000:  4.066202268600464
Mean loss at step  2000:  4.019917325496674
Mean loss at step  3000:  3.9779060764312746
Mean loss at step  4000:  3.9280554127693175
Mean loss at step  5000:  3.8788565654754636
Mean loss at step  6000:  3.8282575843334197
Mean loss at step  7000:  3.781137935400009
Mean loss at step  8000:  3.714286876440048
Mean loss at step  9000:  3.6556279363632203
Mean loss at step 10000:  3.595656502485275
Mean loss at step 11000:  3.5299532765150072
Mean loss at step 12000:  3.4480678269863128
Mean loss at step 13000:  3.3744378654956817
Mean loss at step 14000:  3.284869670748711
Mean loss at step 15000:  3.183376985669136
Mean loss at step 16000:  3.1031557419300078
Mean loss at step 17000:  2.9778473482728005
Mean loss at step 18000:  2.8747157000303267
Mean loss at step 19000:  2.7662972851991654
Mean loss at step 20000:  2.66409159809351
Mean loss at step 21000:  2.573668681561947
Mean loss at step 22000:  2.4652993575334547
Mean loss at step 23