# Simple dynamic seq2seq with TensorFlow

This tutorial covers building seq2seq using dynamic RNN rollout with TensorFlow. I wasn't able to find any existing implementation of dynamic seq2seq with TF (as of 01.01.2017), so I decided to learn how to write my own, and write up what I learn in the process.

I deliberately try to be as explicit as possible. As it currently stands, TF code is the best source of documentation on itself, and I have a feeling that many conventions and design decisions are not documented anywhere except in the brains of Google Brain engineers. This makes learning it a bit of guesswork, and also wickedly fun.

I hope this will be useful to people whose brains are wired like mine.

## RNN rollout
```
@TODO: rollout, static vs dynamic
```

## Data specification

Seq2seq maps one sequence onto another. Both sequences consist of integers from a fixed range. In language tasks, integers usually correspond to words: we first construct a vocabulary by assigning a serial integer to every word in our corpus. The first few integers are reserved for special tokens. We'll call the upper bound on the vocabulary the `vocabulary size`.
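
To make the vocabulary idea concrete, here is a minimal sketch of how such a mapping could be built. The helper `build_vocabulary` and the toy corpus are made up for illustration; only the reserved ids (0 for `<PAD>`, 1 for `<EOS>`) follow the constants used later in this tutorial.

```python
PAD, EOS = 0, 1  # reserved ids, same values as the constants used in this tutorial

def build_vocabulary(corpus):
    """Assign consecutive integer ids to words, starting after the reserved tokens.

    `corpus` is a list of whitespace-tokenized sentences (an assumption made
    for this sketch; real tokenization is usually more involved).
    """
    word_to_id = {'<PAD>': PAD, '<EOS>': EOS}
    for sentence in corpus:
        for word in sentence.split():
            if word not in word_to_id:
                word_to_id[word] = len(word_to_id)
    return word_to_id

corpus = ['the cat sat', 'the dog sat down']  # made-up toy corpus
word_to_id = build_vocabulary(corpus)

# Every sentence becomes a variable-length list of integers
encoded = [[word_to_id[word] for word in sentence.split()] for sentence in corpus]
vocabulary_size = len(word_to_id)

print(encoded)
print(vocabulary_size)
```

Here the vocabulary size counts the reserved tokens as well, so every id is strictly below it.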

Input data consists of sequences of integers.

In [1]:
sentences = [[5, 7, 8], [6, 3], [3], [1]]
print(sentences)

[[5, 7, 8], [6, 3], [3], [1]]


In [2]:
PAD = 0
EOS = 1

While manipulating such variable-length lists is convenient for humans, RNNs prefer a different layout:
```
[[5, 6, 3, 1],
 [7, 3, 0, 0],
 [8, 0, 0, 0]]
```

Here, the first dimension (major axis) becomes `time` instead of `sentence`, hence the name `time-major layout`.

The legacy TensorFlow seq2seq module used static unrolling and could only work with time-major layout. Dynamically unrolled RNNs can use batch-major layout, but will suffer some performance penalty.

We will use time-major layout for this tutorial.

In [3]:
import numpy as np

def preprocess_batch(inputs, max_sequence_length=None):
    """
    Args:
        inputs:
            list of sentences (integer lists)
        max_sequence_length:
            integer specifying how large should `max_time` dimension be.
            If None, maximum sequence length would be used
    
    Outputs:
        inputs_time_major:
            input sentences transformed into time-major matrix (shape [max_time, batch_size])
            padded with 0s
        sequence_lengths:
            batch-sized list of integers specifying amount of active time steps in each input sequence
    """
    
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)
    
    if max_sequence_length is None:
        max_sequence_length = max(sequence_lengths)
    
    inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length], dtype=np.int32) # == PAD
    
    for i, seq in enumerate(inputs):
        for j, element in enumerate(seq):
            inputs_batch_major[i, j] = element

    # [batch_size, max_time] -> [max_time, batch_size]
    inputs_time_major = inputs_batch_major.swapaxes(0, 1)

    return inputs_time_major, sequence_lengths

In [4]:
inputs_time_major, inputs_length = preprocess_batch(sentences)

In [5]:
inputs_time_major

array([[5, 6, 3, 1],
       [7, 3, 0, 0],
       [8, 0, 0, 0]], dtype=int32)

Notice that `inputs_time_major` is a matrix with `sequences` in columns, time steps in rows.

Even though TensorFlow can now do fully dynamic unrolling of RNNs, every batch still has to be a rectangular Tensor with definite dimensions. This is why we need padding.

In [6]:
inputs_length

[3, 2, 1, 1]

`inputs_length` specifies the length of every example in the batch Tensor. TF ops such as `tf.nn.dynamic_rnn` can use it (via the `sequence_length` argument) to know when to stop unrolling the network for each example.
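
One way to picture what the lengths encode is to turn them into a mask over the padded time-major matrix. This is only a NumPy sketch of the idea, not what TensorFlow does internally; the helper name `length_mask` is my own.

```python
import numpy as np

def length_mask(sequence_lengths, max_time=None):
    """Return a [max_time, batch_size] float mask.

    Entries are 1.0 for real time steps and 0.0 for padded ones.
    """
    lengths = np.asarray(sequence_lengths)
    if max_time is None:
        max_time = int(lengths.max())
    time_index = np.arange(max_time)[:, np.newaxis]   # shape [max_time, 1]
    # Broadcasting compares every time index with every sequence length
    return (time_index < lengths[np.newaxis, :]).astype(np.float32)

mask = length_mask([3, 2, 1, 1])
print(mask)
```

The ones in the mask line up with the non-padded entries of the time-major matrix shown above.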

# Building a model

## Simple seq2seq

First we implement plain seq2seq — forward-only encoder + decoder without attention. I'll try to follow closely the original architecture described in [Sutskever, Vinyals and Le (2014)](https://arxiv.org/abs/1409.3215). If you notice any deviations, please let me know.

Architecture diagram from their paper:
![seq2seq architecture](pictures/1-seq2seq.png)
Rectangles are encoder and decoder's recurrent layers. Encoder receives `[A, B, C]` sequence as inputs. We don't care about encoder outputs, only about the hidden state it accumulates while reading the sequence. After input sequence ends, encoder passes its final state to decoder, which receives `[<EOS>, W, X, Y, Z]` and is trained to output `[W, X, Y, Z, <EOS>]`. `<EOS>` token is a special word in vocabulary that signals to decoder the beginning of translation.

Encoder starts with empty state and runs through the input sequence. We are not interested in encoder's outputs, only in its `final_state`.

Decoder uses encoder's `final_state` as its `initial_state`. Its inputs are a batch-sized matrix with `<EOS>` token at the 1st time step and `<PAD>` at the following. This is a rather crude setup, useful only for tutorial purposes. In practice, we would like to feed previously generated tokens after `<EOS>`.

Decoder's outputs are mapped onto the output space using `[hidden_units x output_vocab_size]` projection layer. This is necessary because we cannot make `hidden_units` of decoder arbitrarily large, while our target space would grow with the size of the dictionary.

This kind of encoder-decoder is forced to learn fixed-length representation (specifically, `hidden_units` size) of the variable-length input sequence and restore output sequence only from this representation.

In [7]:
import tensorflow as tf

tf.reset_default_graph()
sess = tf.InteractiveSession()

### Model inputs and outputs 

First critical thing to decide: vocabulary size.

Dynamic RNN models can be adapted to different batch sizes and sequence lengths without retraining (e.g. by serializing model parameters and Graph definitions via `tf.train.Saver`), but changing vocabulary size requires retraining the model.

In [8]:
vocab_size = 10

A nice way to understand a complicated function is to study its signature: inputs and outputs. With pure functions, only the input-output relation matters.

- `encoder_inputs` int32 tensor is shaped `[encoder_max_time, batch_size]`
- `decoder_targets` int32 tensor is shaped `[decoder_max_time, batch_size]`

In [9]:
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')

We'll add one additional placeholder tensor: 
- `decoder_inputs` int32 tensor is shaped `[decoder_max_time, batch_size]`

In [10]:
decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')

We actually don't want to feed `decoder_inputs` by hand, but in case of a tutorial, I think it will be illustrative.

During training, `decoder_inputs` will consist of the `<EOS>` token followed by the target sequence along the time axis. In this way, we always pass the target sequence as the history to the decoder, regardless of what it actually predicts. This can introduce a distribution shift from training to prediction. 
In prediction mode, the model will receive the tokens it previously generated (via argmax over logits), not the ground truth, which would be unknowable.
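
The relation between training-time decoder inputs and targets can be sketched in plain Python. The helper `make_decoder_pair` is hypothetical and ignores padding; it only shows the one-step shift.

```python
PAD, EOS = 0, 1  # same reserved ids as elsewhere in this tutorial

def make_decoder_pair(target_sequence):
    """Build (decoder_input, decoder_target) lists for one unpadded sequence."""
    decoder_input = [EOS] + target_sequence    # <EOS> signals "start decoding"
    decoder_target = target_sequence + [EOS]   # <EOS> signals "sequence is over"
    return decoder_input, decoder_target

decoder_input, decoder_target = make_decoder_pair([5, 7, 8])
print(decoder_input)   # [1, 5, 7, 8]
print(decoder_target)  # [5, 7, 8, 1]
```

At every time step the decoder sees the ground-truth token from the previous step and is asked to predict the current one.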

Notice that all shapes are specified with `None`s (dynamic). We can use batches of any size with any number of timesteps. This is convenient and efficient, but there are obvious constraints: 
- Feed values for all tensors should have same `batch_size`
- Decoder inputs and outputs (`decoder_inputs` and `decoder_targets`) should have same `decoder_max_time`
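
Since these constraints are only checked when the graph actually runs, it can help to validate the feed values beforehand. Below is a small sketch of such a check; `check_feed_shapes` is a helper I made up, and it assumes all three arrays are 2-D and time-major.

```python
import numpy as np

def check_feed_shapes(encoder_inputs_, decoder_inputs_, decoder_targets_):
    """Raise ValueError if time-major feed arrays violate the constraints above."""
    arrays = [np.asarray(a) for a in (encoder_inputs_, decoder_inputs_, decoder_targets_)]
    _, dec_in, dec_tgt = arrays

    # Axis 1 is batch_size in time-major layout
    batch_sizes = {a.shape[1] for a in arrays}
    if len(batch_sizes) != 1:
        raise ValueError('batch sizes differ: {}'.format(sorted(batch_sizes)))

    # Axis 0 is max_time; decoder inputs and targets must agree on it
    if dec_in.shape[0] != dec_tgt.shape[0]:
        raise ValueError('decoder_max_time differs: {} vs {}'.format(
            dec_in.shape[0], dec_tgt.shape[0]))

encoder_example = np.zeros((3, 5), dtype=np.int32)   # [encoder_max_time, batch_size]
decoder_example = np.zeros((6, 5), dtype=np.int32)   # [decoder_max_time, batch_size]
check_feed_shapes(encoder_example, decoder_example, decoder_example)  # passes silently
```

Note that the encoder and decoder time dimensions are free to differ; only the batch dimension is shared.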

### Encoder

The centerpiece of all things RNN in TensorFlow is `RNNCell` class and its descendants (like `LSTMCell` and `InputProjectionWrapper`). But they are outside of the scope of this post — nice [official tutorial](https://www.tensorflow.org/tutorials/recurrent/) is available. While at it, make sure you're also familiar with the [official tutorial on embeddings](https://www.tensorflow.org/tutorials/word2vec/).

`@TODO: RNNCell as a factory`

In [11]:
from tensorflow.contrib.rnn import (LSTMCell,
                                    InputProjectionWrapper,
                                    OutputProjectionWrapper)

In [12]:
hidden_units = 20

encoder_cell = LSTMCell(hidden_units)

encoder_cell = InputProjectionWrapper(cell=encoder_cell, num_proj=hidden_units)

encoder_inputs_onehot = tf.one_hot(encoder_inputs, depth=vocab_size, dtype=tf.float32)

encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_inputs_onehot,
    dtype=tf.float32, time_major=True,
)

del encoder_outputs

Observe what is going on here:
- `encoder_inputs`, an integer tensor with shape `[max_time, batch_size]`, gets one-hot encoded. Every integer at every timestep in every batch gets transformed into a vector of size `vocab_size` that forms the new inner dimension. The resulting `encoder_inputs_onehot` is shaped `[max_time, batch_size, vocab_size]`. Each vector consists of a single 1 and (vocab_size-1) 0s, but we want it to be float32 since we're going to multiply it with the projection layer's weights.
- `InputProjectionWrapper` adds `[vocab_size, hidden_units]` linear projection layer (without nonlinearities) before `LSTMCell`'s inputs. Resulting projection is shaped `[max_time, batch_size, hidden_units]` and is compatible with `LSTMCell`.

There is a more efficient way to get the `[max_time, batch_size, hidden_units]` shape from `encoder_inputs`: `tf.nn.embedding_lookup`. Instead of one-hot encoding followed by matrix multiplication, it treats input integers as indices into the projection layer's weights (this works because multiplying a one-hot vector with a dense matrix is exactly equivalent to indexing this matrix with an integer). Additionally, I think, it removes unnecessary computation of gradients for inactive embeddings (`@TODO: check, reformulate`).
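
The equivalence between "one-hot, then multiply" and "index the weight matrix" is easy to verify in NumPy. This sketch uses random stand-in weights and ignores any bias term a projection wrapper may add.

```python
import numpy as np

vocab_size, hidden_units = 10, 4
rng = np.random.RandomState(0)
weights = rng.randn(vocab_size, hidden_units)    # stand-in projection weights

token_ids = np.array([[5, 6],
                      [7, 3],
                      [8, 0]])                   # [max_time, batch_size]

# Route 1: one-hot encode, then multiply with the dense matrix
one_hot = np.eye(vocab_size)[token_ids]          # [max_time, batch_size, vocab_size]
projected = one_hot @ weights                    # [max_time, batch_size, hidden_units]

# Route 2: use the integers directly as row indices (an embedding lookup)
looked_up = weights[token_ids]                   # [max_time, batch_size, hidden_units]

print(np.allclose(projected, looked_up))
```

Both routes give the same tensor, but the lookup never materializes the `vocab_size`-wide one-hot vectors.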

We discard `encoder_outputs` because we are not interested in them within seq2seq framework. What we actually want is `encoder_final_state` — state of LSTM's hidden cells at the last moment of the Encoder rollout.

`encoder_final_state` is also called "thought vector". We will use it as initial state for the Decoder. In seq2seq without attention this is the only point where Encoder passes information to Decoder. We hope that backpropagation through time algorithm will tune the model to pass enough information in thought vector for correct output decoding.

In [13]:
encoder_final_state

LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_2:0' shape=(?, 20) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 20) dtype=float32>)

The TensorFlow LSTM implementation stores state as a tuple of tensors. 
- `encoder_final_state.c` is the cell state, the LSTM's internal memory
- `encoder_final_state.h` is the hidden state, which is also what the cell emits as its output at the last time step
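
To see where the two parts of the state come from, here is a single step of a textbook LSTM in NumPy. In the standard formulation `c` is the cell state (internal memory) and `h` is the hidden state that the cell emits as output. This is a simplified sketch, not TensorFlow's implementation: the gate ordering, the missing forget bias and the absence of peepholes are my own simplifications.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, h_prev, W, b):
    """One step of a plain LSTM.

    x:      [batch_size, input_size]
    c_prev: [batch_size, hidden_units]  cell state (internal memory)
    h_prev: [batch_size, hidden_units]  hidden state (previous output)
    W:      [input_size + hidden_units, 4 * hidden_units]
    b:      [4 * hidden_units]
    """
    z = np.concatenate([x, h_prev], axis=1) @ W + b
    i, g, f, o = np.split(z, 4, axis=1)               # input, candidate, forget, output
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # update the memory
    h = sigmoid(o) * np.tanh(c)                        # expose a gated view of the memory
    return c, h

batch_size, input_size, hidden_units = 2, 3, 5
rng = np.random.RandomState(0)
W = rng.randn(input_size + hidden_units, 4 * hidden_units) * 0.1
b = np.zeros(4 * hidden_units)

# Start from an empty state and read three time steps
c = np.zeros((batch_size, hidden_units))
h = np.zeros((batch_size, hidden_units))
for x_t in rng.randn(3, batch_size, input_size):
    c, h = lstm_step(x_t, c, h, W, b)

print(c.shape, h.shape)
```

The pair `(c, h)` after the last step plays the role of `encoder_final_state` here.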

### Decoder

In [14]:
decoder_cell = LSTMCell(hidden_units)

decoder_cell = InputProjectionWrapper(cell=decoder_cell, num_proj=hidden_units)

decoder_cell = OutputProjectionWrapper(cell=decoder_cell, output_size=vocab_size)

decoder_inputs_onehot = tf.one_hot(decoder_inputs, depth=vocab_size, dtype=tf.float32)

decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs_onehot,
    
    initial_state=encoder_final_state,

    dtype=tf.float32, time_major=True, scope="plain_decoder",
)

Since we pass `encoder_final_state` as `initial_state` to the decoder, they should be compatible. This means the same cell type (`LSTMCell` in our case), the same number of `hidden_units` and the same number of layers (single layer). I suppose this can be relaxed if we additionally pass `encoder_final_state` through a one-layer MLP.

Decoder inputs get the same one-hot -> projection treatment as encoder inputs. But in case of the decoder, there is one more thing to do.

With the encoder, we were not interested in the cell's outputs. But the decoder's outputs are what we are actually after: we use them to get a distribution over the words of the output sequence.

On its own, the decoder's `LSTMCell` outputs a `hidden_units`-sized vector at every timestep. However, for training and prediction we need logits of size `vocab_size`. The reasonable thing is to put a linear layer on top of the LSTM output to get non-normalized logits, which is what `OutputProjectionWrapper` does in the code above. This layer is called the projection layer by convention.
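
Shape-wise, the projection layer is just a matrix multiplication applied at every time step. A NumPy sketch with random stand-in values (none of these arrays come from the actual model):

```python
import numpy as np

max_time, batch_size, hidden_units, vocab_size = 4, 5, 20, 10
rng = np.random.RandomState(0)

rnn_outputs = rng.randn(max_time, batch_size, hidden_units)  # stand-in LSTM outputs
projection_W = rng.randn(hidden_units, vocab_size) * 0.1     # stand-in weights
projection_b = np.zeros(vocab_size)

# The same [hidden_units, vocab_size] matrix is applied at every time step
logits = rnn_outputs @ projection_W + projection_b   # [max_time, batch_size, vocab_size]

# The most likely token per position, as in the prediction cell of this tutorial
prediction_time_major = logits.argmax(axis=2)        # [max_time, batch_size]
print(logits.shape, prediction_time_major.shape)
```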

In [15]:
prediction_time_major = tf.argmax(decoder_outputs, 2)
prediction = tf.transpose(prediction_time_major, perm=[1, 0])

### Optimizer

In [16]:
decoder_outputs

<tf.Tensor 'plain_decoder/TensorArrayStack/TensorArrayGatherV3:0' shape=(?, ?, 10) dtype=float32>

The RNN outputs a tensor of shape `[max_time, batch_size, hidden_units]`, which the projection layer maps onto `[max_time, batch_size, vocab_size]`. The `vocab_size` part of the shape is static, while `max_time` and `batch_size` are dynamic.

In [17]:
decoder_targets_onehot = tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32)

stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=decoder_targets_onehot,
    logits=decoder_outputs,
)

loss = tf.reduce_mean(stepwise_cross_entropy)
train_op = tf.train.AdamOptimizer().minimize(loss)
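
For intuition, the stepwise loss above can be reproduced in NumPy. This is a sketch of softmax cross-entropy for integer targets, written by me for illustration; it is not how TensorFlow implements the op.

```python
import numpy as np

def stepwise_softmax_cross_entropy(logits, targets):
    """Cross-entropy per position.

    logits:  float array [max_time, batch_size, vocab_size]
    targets: int array   [max_time, batch_size]
    returns: float array [max_time, batch_size]
    """
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the target token at every position
    t_idx, b_idx = np.indices(targets.shape)
    return -log_probs[t_idx, b_idx, targets]

example_logits = np.zeros((3, 2, 10))             # a model that predicts uniformly
example_targets = np.array([[5, 6],
                            [7, 1],
                            [1, 0]])
example_losses = stepwise_softmax_cross_entropy(example_logits, example_targets)
example_loss = example_losses.mean()
print(example_loss, np.log(10))
```

With uniform predictions over 10 tokens the loss is `ln(10)`, roughly 2.30, which is a handy baseline when reading the loss of an untrained model. Like `tf.reduce_mean` above, the mean here includes the padded positions.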

In [18]:
sess.run(tf.global_variables_initializer())

### Test forward pass

Did I say that deep learning is a game of shapes? When building a Graph, TF will throw errors when static shapes are not matching. However, mismatches between dynamic shapes are often only discovered when we try to run something through the graph.


So let's try running something. For that we need to prepare values we will feed into placeholders.

```
this is key part where everything comes together

@TODO: describe
- how encoder shape is fixed to max
- how decoder shape is arbitrary and determined by inputs, but should probably be longer than encoder's
- how decoder input values are also arbitrary, and how we use GO token, and what are those 0s, and what can be used instead (shifted gold sequence, beam search)
@TODO: add references
```

In [19]:
batch_size = 5
min_sequence_length = 1
max_sequence_length = 4
batch = [
    np.random.randint(low=2, # reserved tokens: 0 - <PAD>, 1 - <EOS>
                      high=vocab_size,
                      size=np.random.randint(low=min_sequence_length, high=max_sequence_length)).tolist()
    for _ in range(batch_size)
]
print('batch:')
print(repr(batch))
print()

batch_encoded, _ = preprocess_batch(batch)
print('batch_encoded:')
print(batch_encoded)

batch:
[[5], [4, 5, 8], [9, 6, 3], [4], [4]]

batch_encoded:
[[5 4 9 4 4]
 [0 5 6 0 0]
 [0 8 3 0 0]]


In [20]:
decoder_inputs_, _ = preprocess_batch(np.ones(shape=(batch_size, 1), dtype=np.int32),
                                       max_sequence_length=max_sequence_length)
print('decoder_inputs_:')
print(decoder_inputs_)
print()

decoder_inputs_:
[[1 1 1 1 1]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]



In [21]:
sess.run(prediction_time_major,
         feed_dict={
            encoder_inputs: batch_encoded,
            decoder_inputs: decoder_inputs_
         })

array([[4, 4, 5, 4, 4],
       [6, 9, 9, 6, 6],
       [6, 9, 9, 6, 6],
       [6, 9, 9, 6, 6]])

Successful forward computation, everything is wired correctly.

## Training on the toy task

Consider the copy task — given a random sequence of integers from a `vocabulary`, learn to memorize and reproduce input sequence. Because sequences are random, they do not contain any structure, unlike natural language.

In [22]:
batch_size = 100

def copy_task(min_sequence_length = 2,
              max_sequence_length = 7,
              vocab_lower = 2,
              vocab_upper = vocab_size,
              batch_size = batch_size):
    """ Generates batches of random integer sequences,
        sequence length in [min_sequence_length, max_sequence_length],
        vocabulary in [vocab_lower, vocab_upper), i.e. vocab_upper is exclusive
    """

    def random_length():
        if min_sequence_length > max_sequence_length:
            raise ValueError('min_sequence_length > max_sequence_length')
        elif min_sequence_length == max_sequence_length:
            return min_sequence_length
        return np.random.randint(min_sequence_length, max_sequence_length+1)

    def random_sequence():
        return np.random.randint(low=vocab_lower,
                                 high=vocab_upper,
                                 size=random_length()).tolist()
    
    while True:
        yield [random_sequence() for _ in range(batch_size)]


task = copy_task()
batch = next(task)
print('head of the batch:')
for seq in batch[:10]:
    print(seq)

head of the batch:
[6, 6, 5, 7]
[4, 5]
[9, 3, 5, 5]
[7, 9, 5, 2, 5, 5]
[8, 7, 3, 3, 3]
[6, 2, 3, 3, 7, 9, 7]
[5, 5, 9, 7, 7, 5, 9]
[5, 7, 6, 7, 4]
[9, 8, 8, 2, 7, 3]
[5, 9, 3, 8, 2]


## Training loop

In [23]:
batches_in_epoch = 1000
max_epochs = 5

for epoch in range(max_epochs):
    batches_in_this_epoch = batches_in_epoch
    if epoch == 0:
        batches_in_this_epoch = 1

    for _ in range(batches_in_this_epoch):

        batch = next(task)
        
        encoder_inputs_, _ = preprocess_batch(
            [sequence for sequence in batch]
        )
        
        # for decoder inputs we put <EOS> token in 1st time step, 
        # and then target sequence;
        # we add 2 additional paddings to allow the decoder to be a bit creative
        # and go beyond the target sequence
        decoder_inputs_, _ = preprocess_batch(
            [[EOS] + sequence + [PAD] * 2 for sequence in batch]
        )
        
        # for decoder targets we put <EOS> after the sequence;
        # sequence lengths are the same for decoder inputs and decoder targets
        decoder_targets_, _ = preprocess_batch(
            [sequence + [EOS] + [PAD] * 2 for sequence in batch]
        )

        feed_dict = {
            encoder_inputs: encoder_inputs_,
            decoder_inputs: decoder_inputs_,
            decoder_targets: decoder_targets_,
        }

        sess.run(train_op, feed_dict)
    
    print('epoch {} done after {} batches'.format(epoch, batches_in_this_epoch))
    
    # reporting minibatch loss is a common lazy way to estimate 
    # how well our learner is doing; here we fetch loss from the last batch in epoch.
    # a better way for real-life learning would be to use `tf.summary` facilities 
    # with TensorBoard
    print('  minibatch loss: {}'.format(sess.run(loss, feed_dict)))
    predict_ = sess.run(prediction, feed_dict)
    
    # after each epoch, print a few autoencoding examples
    for i, (truth, pred) in enumerate(zip(decoder_targets_.T, predict_)):
        print('  sample {}:'.format(i + 1))
        print('    target    > {}'.format(truth))
        print('    predicted > {}'.format(pred))
        if i >= 2:
            break
    print()

epoch 0 done after 1 batches
  minibatch loss: 2.2727274894714355
  sample 1:
    target    > [4 5 5 9 4 9 1 0 0 0]
    predicted > [4 9 9 9 9 9 9 9 9 9]
  sample 2:
    target    > [3 3 4 1 0 0 0 0 0 0]
    predicted > [4 6 6 6 6 6 6 6 6 6]
  sample 3:
    target    > [9 5 3 2 9 6 1 0 0 0]
    predicted > [4 9 9 9 4 9 9 9 9 6]

epoch 1 done after 1000 batches
  minibatch loss: 0.2912651300430298
  sample 1:
    target    > [2 3 6 8 7 3 5 1 0 0]
    predicted > [2 3 6 5 7 3 5 1 0 0]
  sample 2:
    target    > [8 7 9 7 8 5 7 1 0 0]
    predicted > [8 7 9 7 8 5 7 1 0 0]
  sample 3:
    target    > [4 3 2 2 9 1 0 0 0 0]
    predicted > [4 2 2 2 9 1 0 0 0 0]

epoch 2 done after 1000 batches
  minibatch loss: 0.12642544507980347
  sample 1:
    target    > [4 5 3 7 2 1 0 0 0 0]
    predicted > [4 5 3 7 2 1 0 0 0 0]
  sample 2:
    target    > [9 5 6 4 5 2 7 1 0 0]
    predicted > [9 5 6 4 5 5 7 1 0 0]
  sample 3:
    target    > [3 8 8 1 0 0 0 0 0 0]
    predicted > [3 8 8 1 0 0 0 0 0 0]



Something is definitely getting learned.

# Limitations of the model

We have no control over transitions of `tf.nn.dynamic_rnn`; it is unrolled in a single sweep. This means we:

- can't use beam search decoder optimization
- can't feed previously generated tokens without falling back to python loops
- can't use attention, because attention conditions decoder inputs on its previous state

A solution would be to use `tf.nn.raw_rnn` to reimplement the relevant parts of `tf.nn.dynamic_rnn`, and add attention and beam search loops.
**Will be done in following tutorials.**
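
To illustrate what "feeding previously generated tokens" looks like when done with a plain Python loop, here is a rough greedy decoding sketch. The `step_fn` interface is hypothetical: it stands for anything that advances the decoder by one time step. `toy_step` is a scripted stand-in, not a trained model.

```python
import numpy as np

PAD, EOS = 0, 1  # same reserved ids as elsewhere in this tutorial

def greedy_decode(step_fn, initial_state, batch_size, max_steps):
    """Feed the decoder its own argmax predictions, one time step at a time.

    step_fn(tokens, state) -> (logits, new_state) is a hypothetical interface:
    tokens is an int array [batch_size], logits is [batch_size, vocab_size].
    Returns an int array [decoded_time, batch_size] (time-major).
    """
    tokens = np.full(batch_size, EOS, dtype=np.int32)   # <EOS> starts decoding
    state = initial_state
    finished = np.zeros(batch_size, dtype=bool)
    outputs = []
    for _ in range(max_steps):
        logits, state = step_fn(tokens, state)
        tokens = logits.argmax(axis=1).astype(np.int32)
        tokens[finished] = PAD            # pad sequences that already emitted <EOS>
        outputs.append(tokens)
        finished |= (tokens == EOS)
        if finished.all():
            break
    return np.stack(outputs)

# Toy stand-in for a decoder step: it replays a fixed script and ignores its
# input tokens; the "state" is simply the current time index.
script = np.array([[3, 4],
                   [5, 1],
                   [1, 7],
                   [9, 1]])

def toy_step(tokens, state):
    logits = np.eye(10)[script[state]]    # one-hot "logits" for the scripted tokens
    return logits, state + 1

decoded = greedy_decode(toy_step, initial_state=0, batch_size=2, max_steps=10)
print(decoded)
```

Each iteration here would cost a separate `sess.run` call with a real model, which is the Python-loop fallback mentioned above.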

# Fun things to try (aka Exercises)

- In `copy_task`, try increasing `max_sequence_length` and `vocab_upper`. Observe slower learning and general performance degradation.

- For `decoder_inputs`, instead of shifted target sequence `[<EOS> W X Y Z]`, try feeding `[<EOS> <PAD> <PAD> <PAD>]`, like we've done when we tested forward pass. Does it break things? Or slow learning?