## Example Implementation of Monotonic Chunkwise Attention (MoChA)

This notebook shows how to compute the probability distribution for Monotonic Chunkwise Attention (MoChA) [1] efficiently using TensorFlow.  We give a brief overview of MoChA below; a more thorough treatment is in [1].

We consider attention mechanisms which, given a length-$T$ memory $h_1, h_2, \ldots, h_T$, produce an attention distribution $\alpha_{i, j}$ and a "context vector" $c_i$ as
\begin{equation}
c_i = \sum_{j = 1}^T \alpha_{i, j} h_j
\end{equation}
In sequence-to-sequence settings, a context vector is computed for each output timestep $i$.
In hard monotonic attention [2], for each $i$ there is a single index $t_i$ for which $\alpha_{i, t_i} = 1$, and $\alpha_{i, j} = 0$ otherwise.
Further, the attention is constrained so that if $\alpha_{i, t_i} = 1$ then $\alpha_{i + 1, j} = 0$ for $j < t_i$.
The result is that $c_i$ is effectively chosen to be $h_{t_i}$, and that once $c_i = h_{t_i}$, none of $h_1, \ldots, h_{t_i - 1}$ are chosen at subsequent output timesteps.
Since this hard assignment has zero derivative almost everywhere, a "soft" version is used during training, so that $\alpha_{i, j}$ forms a probability distribution which obeys the above constraints in the limit of $\alpha_{i, j}$ being all 0 or 1.
This distribution can be computed efficiently in TensorFlow using the `tf.contrib.seq2seq.monotonic_attention` function.
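
To make the "soft" distribution concrete, here is a minimal loop-based sketch of one output step, assuming the recurrence described in [2]. The name `soft_monotonic_step` and the argument `p_choose` (the probability of stopping at each memory entry, which this notebook does not otherwise define) are illustrative assumptions, not part of the TensorFlow API.

```python
def soft_monotonic_step(p_choose, previous_attention):
    """One output step of soft monotonic attention over a length-T memory.

    p_choose[j] is the assumed probability of stopping at memory entry j and
    previous_attention[j] is alpha_{i - 1, j}.  Returns alpha_{i, :}.
    """
    attention = []
    q = 0.0           # probability of reaching entry j without stopping earlier
    keep_going = 0.0  # 1 - p_choose[j - 1]; zero before the first entry
    for p, prev in zip(p_choose, previous_attention):
        # Either we passed over entry j - 1, or the previous step stopped at j
        q = keep_going * q + prev
        attention.append(p * q)
        keep_going = 1.0 - p
    return attention


# With 0/1 probabilities the result is one-hot and never moves backwards:
# the previous step stopped at index 1, so p_choose[0] = 1 is ignored.
hard = soft_monotonic_step([1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 0.0, 0.0, 0.0])

# With soft probabilities the mass is spread over later entries
soft = soft_monotonic_step([0.5, 0.5, 0.5, 0.5], [1.0, 0.0, 0.0, 0.0])
```

In the soft case the entries sum to less than one; the remainder is the probability of passing over every memory entry without stopping.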

Monotonic Chunkwise Attention (MoChA) extends this so that $c_i$ is set to a weighted average of the $w$ memory entries ending at $t_i$, where $t_i$ is chosen by a hard monotonic attention mechanism.
Specifically, MoChA computes
\begin{align}
v &= t_i - w + 1\\
c_i &= \sum_{k = v}^{t_i} \frac{\exp(u_{i, k})}{\sum_{l = v}^{t_i} \exp(u_{i, l})} h_k
\end{align}
Note that we are effectively computing a softmax over the length-$w$ chunk, with logits $u_{i, j}$.
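
As a small illustration of the chunkwise softmax, the sketch below computes $c_i$ for a single output step in plain Python. The function name `mocha_context`, the 0-based indexing, and truncating the chunk at the start of the memory are assumptions made for this example.

```python
import math


def mocha_context(memory, chunk_logits, t_i, w):
    """Context vector for one output step of hard (test-time) MoChA.

    memory[k] is h_k (a list of floats), chunk_logits[k] is u_{i, k}, and t_i
    is the index chosen by hard monotonic attention.  Indices are 0-based and
    the chunk is truncated at the start of the memory (an assumption).
    """
    v = max(t_i - w + 1, 0)
    chunk = range(v, t_i + 1)
    # Subtract the chunk max before exponentiating, for numerical stability
    chunk_max = max(chunk_logits[k] for k in chunk)
    weights = [math.exp(chunk_logits[k] - chunk_max) for k in chunk]
    total = sum(weights)
    dim = len(memory[0])
    return [sum(wt / total * memory[k][d] for wt, k in zip(weights, chunk))
            for d in range(dim)]


memory = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
chunk_logits = [0.3, 0.0, 0.0, -1.0]

# Entries 1 and 2 have equal logits, so the context is their plain average
context = mocha_context(memory, chunk_logits, t_i=2, w=2)

# With w = 1 this reduces to hard monotonic attention: c_i = h_{t_i}
context_w1 = mocha_context(memory, chunk_logits, t_i=2, w=1)
```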
During training, we use the induced probability distribution
\begin{align}
\beta_{i, j} &= \sum_{k = j}^{j + w - 1} \left( \alpha_{i, k}\exp(u_{i, j}) \Bigg/\sum_{l = k - w + 1}^k \exp(u_{i, l}) \right)\\
c_i &= \sum_{j = 1}^T \beta_{i, j} h_j
\end{align}
where $\alpha_{i, j}$ is the soft probability distribution induced by hard monotonic attention.
$\beta_{i, :}$ can be computed efficiently in parallel by defining 
\begin{equation}
\mathrm{MovingSum}(\textbf{x}, b, f)_n := \sum_{m = n - (b - 1)}^{n + f - 1} x_m
\end{equation}
so that 
\begin{equation}
\beta_{i, :} = \exp(u_{i, :})\,\mathrm{MovingSum}\left(\frac{\alpha_{i, :}}{\mathrm{MovingSum}(\exp(u_{i, :}), w, 1)}, 1, w \right)
\end{equation}
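
As a sanity check on this identity, the following plain-Python sketch compares the double-sum definition of $\beta_{i, :}$ with the $\mathrm{MovingSum}$ form for one output step. The helper names are hypothetical. Note that `moving_sum_reference` follows the $(b, f)$ convention of the equation (the window covers $b$ entries up to and including $n$, and $f$ entries from $n$ onwards), and out-of-range terms are taken to be zero.

```python
import math


def moving_sum_reference(x, back, forward):
    """MovingSum(x, back, forward) with 0-based n; out-of-range terms are zero."""
    T = len(x)
    return [sum(x[m] for m in range(n - (back - 1), n + forward) if 0 <= m < T)
            for n in range(T)]


def beta_direct(alpha, u, w):
    """beta_{i, :} computed directly from the double-sum definition."""
    T = len(u)
    beta = []
    for j in range(T):
        total = 0.0
        for k in range(j, min(j + w, T)):
            denom = sum(math.exp(u[l]) for l in range(max(k - w + 1, 0), k + 1))
            total += alpha[k] * math.exp(u[j]) / denom
        beta.append(total)
    return beta


def beta_moving_sum(alpha, u, w):
    """beta_{i, :} computed with two moving sums."""
    exp_u = [math.exp(x) for x in u]
    denominators = moving_sum_reference(exp_u, w, 1)
    ratios = [a / d for a, d in zip(alpha, denominators)]
    return [e * s for e, s in zip(exp_u, moving_sum_reference(ratios, 1, w))]


alpha = [0.1, 0.2, 0.3, 0.4]
u = [0.5, -1.0, 2.0, 0.0]
direct = beta_direct(alpha, u, w=2)
fast = beta_moving_sum(alpha, u, w=2)
max_difference = max(abs(a - b) for a, b in zip(direct, fast))
```

Because each chunk's softmax sums to one, $\beta_{i, :}$ carries the same total mass as $\alpha_{i, :}$, which gives a second quick check on either implementation.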
Note that in order to compute the softmax over the chunk in a numerically stable way, we need to ensure that the range of the logits $u_{i, j}$ is not large.
One simple way to do this is to clip their range, which we demonstrate below.
After that, we'll also demonstrate a way to do this exactly and stably using $Tw$ memory.

[1] Chung-Cheng Chiu\* and Colin Raffel\*. "*Monotonic Chunkwise Attention*", in ICLR 2018.  
[2] Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck. "*Online and Linear-Time Attention by Enforcing Monotonic Alignments*", in ICML 2017.


In [1]:
import tensorflow as tf
import numpy as np

In [2]:
def moving_sum(x, back, forward):
    """Compute the moving sum of x over a window with the provided bounds.

    x is expected to be of shape (batch_size, sequence_length).
    The returned tensor x_sum is computed as
    x_sum[i, j] = x[i, j - back] + ... + x[i, j + forward]
    """
    # Moving sum is computed as a carefully-padded 1D convolution with ones
    x_padded = tf.pad(x, [[0, 0], [back, forward]])
    # Add a "channel" dimension
    x_padded = tf.expand_dims(x_padded, -1)
    # Construct filters
    filters = tf.ones((back + forward + 1, 1, 1))
    x_sum = tf.nn.conv1d(x_padded, filters, 1, padding='VALID')
    # Remove channel dimension
    return x_sum[..., 0]

def efficient_chunkwise_attention(chunk_size, emit_probs, softmax_logits):
    """Compute chunkwise attention distribution efficiently by clipping logits."""
    # Shift logits to avoid overflow
    softmax_logits -= tf.reduce_max(softmax_logits, 1, keepdims=True)
    # Limit the range for numerical stability
    softmax_exp = tf.exp(softmax_logits)
    softmax_exp = tf.maximum(softmax_exp, 1e-5)
    # Compute chunkwise softmax denominators
    softmax_denominators = moving_sum(softmax_exp, chunk_size - 1, 0)
    # Compute \beta_{i, :}. emit_probs are \alpha_{i, :}.
    probs = softmax_exp * moving_sum(emit_probs / softmax_denominators, 0, chunk_size - 1)
    return probs
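
To see what the clipping costs, here is a small plain-Python sketch of the weights within a single chunk. It mirrors the scheme above (shift by the maximum over the whole sequence, then floor the exponentials at `1e-5`) and compares it with an exact softmax over the chunk. The function names and the example logits are made up for illustration.

```python
import math


def clipped_chunk_weights(logits, start, stop, floor=1e-5):
    """Softmax weights over logits[start:stop] using the clipping scheme."""
    shift = max(logits)  # shift by the max over the whole sequence
    exps = [max(math.exp(u - shift), floor) for u in logits]
    total = sum(exps[start:stop])
    return [e / total for e in exps[start:stop]]


def exact_chunk_weights(logits, start, stop):
    """Softmax weights over logits[start:stop], shifted by the chunk max."""
    chunk = logits[start:stop]
    shift = max(chunk)
    exps = [math.exp(u - shift) for u in chunk]
    total = sum(exps)
    return [e / total for e in exps]


# One large logit elsewhere in the sequence pushes the whole chunk below the floor
logits = [30.0, 0.0, 1.0]
clipped = clipped_chunk_weights(logits, 1, 3)  # both entries hit the floor
exact = exact_chunk_weights(logits, 1, 3)      # roughly [0.27, 0.73]
```

Here the chunk itself has a small range, but both of its exponentials are floored to the same value, so the clipped weights come out equal. Subtracting the maximum of each chunk rather than of the whole sequence avoids this.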

### Stable version

Ideally, we'd like to compute this distribution stably and exactly without clipping the logits.  In order to do so, the softmax in each of the summation terms needs to be normalized, or in other words, the maximum of the logits within each chunk should be subtracted.  So, what we actually want to compute is
$$
\beta_{i, j} = \sum_{k = j}^{j + w - 1} \left(\alpha_{i, k} \exp(u_{i, j} - m_{i, k}) \Bigg / \sum_{l = k - w + 1}^k \exp(u_{i, l} - m_{i, k}) \right)
$$
where $m_{i, k} = \max(u_{i, k - w + 1}, \ldots, u_{i, k})$.

We can achieve this efficiently (i.e. completely in parallel), but we must use $Tw$ memory, as follows:
1. Compute $m_{i, k}$ via max-pooling.
1. Construct $T \times w$ matrix $D$ where row $k$ is $[u_{i, k - w + 1}, \ldots, u_{i, k}]$.
1. Subtract $m_{i, k}$ from row $k$ of $D$ (using broadcasting).
1. Sum $\exp(D)$ across columns to get the softmax denominators $d_{i, k}$.
1. Construct $T \times w$ matrix $E$ where row $k$ is $[d_{i, k}, \ldots, d_{i, k + w - 1}]$.
1. Construct $T \times w$ matrix $N$ where row $j$ is $[u_{i, j}, u_{i, j}, \ldots, u_{i, j}]$.
1. Construct $T \times w$ matrix $M$ where row $k$ is $[m_{i, k}, \ldots, m_{i, k + w - 1}]$.
1. Subtract $M$ from $N$.
1. Compute $N = \exp(N)$.
1. Construct $T \times w$ matrix $A$ where row $j$ is $[\alpha_{i, j}, \ldots, \alpha_{i, j + w - 1}]$.
1. Compute $AN/E$ elementwise and sum across columns to get the MoChA probability distribution.
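
The steps above vectorize a simple per-chunk computation. As a reference point, here is a loop-based plain-Python sketch for one output step; it is not the notebook's implementation, and the name `stable_beta_reference` and the example values are assumptions made for illustration.

```python
import math


def stable_beta_reference(alpha, u, w):
    """Loop-based reference for the max-subtracted beta_{i, :} of one output step."""
    T = len(u)
    beta = [0.0] * T
    for k in range(T):
        # Chunk ending at k, truncated at the start of the sequence
        chunk = range(max(k - w + 1, 0), k + 1)
        m_k = max(u[l] for l in chunk)                  # sliding max
        d_k = sum(math.exp(u[l] - m_k) for l in chunk)  # softmax denominator
        for j in chunk:
            beta[j] += alpha[k] * math.exp(u[j] - m_k) / d_k
    return beta


# Logits with an enormous range: exponentiating u directly would overflow,
# but every exponent here is at most zero.
alpha = [0.25, 0.25, 0.25, 0.25]
u = [0.0, -1e10, 1e10, 3.0]
beta = stable_beta_reference(alpha, u, w=2)
```

Each denominator is at least one because the largest logit in the chunk contributes $\exp(0)$, so the division is always safe.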

In [3]:
def moving_max(x, w):
    """Compute the moving max of x over a window of length w.

    x is expected to be of shape (batch_size, sequence_length).
    The returned tensor x_max is computed as
    x_max[i, j] = max(x[i, j - w + 1], ..., x[i, j])
    """
    # Pad x with -inf at the start
    x = tf.pad(x, [[0, 0], [w - 1, 0]], mode='CONSTANT', constant_values=-np.inf)
    # Add "height" and "channel" dimensions (max_pool operates on 2D)
    x = tf.reshape(x, [tf.shape(x)[0], 1, tf.shape(x)[1], 1])
    x = tf.nn.max_pool(x, [1, 1, w, 1], [1, 1, 1, 1], 'VALID')
    # Remove "height" and "channel" dimensions
    return x[:, 0, :, 0]

def stable_chunkwise_attention(chunk_size, emit_probs, softmax_logits):
    """Compute chunkwise attention distribution stably by subtracting logit max."""
    # Compute length-chunk_size sliding max of sequences in softmax_logits (m)
    logits_max = moving_max(softmax_logits, chunk_size)

    # Produce matrix with length-chunk_size frames of softmax_logits (D)
    # Padding makes it so that the first frame is [-inf, -inf, ..., logits[0]]
    padded_logits = tf.pad(softmax_logits, [[0, 0], [chunk_size - 1, 0]],
                           constant_values=-np.inf)
    framed_logits = tf.contrib.signal.frame(padded_logits, chunk_size, 1)
    # Normalize each logit subsequence by the max in that subsequence
    framed_logits = framed_logits - tf.expand_dims(logits_max, -1)
    # Compute softmax denominators (d)
    softmax_denominators = tf.reduce_sum(tf.exp(framed_logits), 2)
    # Construct matrix of framed denominators, padding at the end so the final
    # frame is [softmax_denominators[-1], inf, inf, ..., inf] (E)
    framed_denominators = tf.contrib.signal.frame(
        softmax_denominators, chunk_size, 1, pad_end=True, pad_value=np.inf)

    # Create matrix of copied logits so that row j is softmax_logits[j] copied
    # chunk_size times (N)
    batch_size, seq_len = tf.unstack(tf.shape(softmax_logits))
    copied_shape = (batch_size, seq_len, chunk_size)
    copied_logits = (tf.expand_dims(softmax_logits, -1) *
                     tf.ones(copied_shape, softmax_logits.dtype))
    # Subtract the max over subsequences (M) from each logit
    framed_max = tf.contrib.signal.frame(logits_max, chunk_size, 1,
                                         pad_end=True, pad_value=np.inf)
    copied_logits = copied_logits - framed_max
    # Take exp() to get softmax numerators
    softmax_numerators = tf.exp(copied_logits)

    # Create matrix with length-chunk_size frames of emit_probs, padded so that
    # the last frame is [emit_probs[-1], 0, 0, ..., 0] (A)
    framed_probs = tf.contrib.signal.frame(emit_probs, chunk_size, 1, pad_end=True)
  
    # Compute chunkwise probability distributions
    return tf.reduce_sum(framed_probs*softmax_numerators/framed_denominators, 2)

Now, we expect `efficient_chunkwise_attention` and `stable_chunkwise_attention` to be equivalent when the range of $u_{i, :}$ is relatively small.  When the difference between the smallest and largest $u_{i, :}$ is large, however, the "efficient" version will clip the logits and the "stable" version will produce the correct distribution.  We expect them to be about equally efficient, since they both are fully parallelizable, though the "stable" version takes about $w$ times more memory to compute.

In [4]:
BATCH_SIZE = 50
SEQUENCE_LENGTH = 100
CHUNK_SIZE = 8

g = tf.Graph()
sess = tf.Session(graph=g)

with g.as_default():
    # Synthetic monotonic attention probabilities alpha_{i, j}
    emit_probs_data = np.random.uniform(size=(BATCH_SIZE, SEQUENCE_LENGTH))
    emit_probs_data /= np.sum(emit_probs_data, axis=1, keepdims=True)
    # We'll use tf.Variables throughout for benchmarking reasons
    emit_probs = tf.Variable(emit_probs_data.astype(np.float32))
    # Synthetic softmax logits
    softmax_logits_data = np.random.normal(size=(BATCH_SIZE, SEQUENCE_LENGTH))
    softmax_logits = tf.Variable(softmax_logits_data.astype(np.float32))

    # Test whether the efficient and stable versions compute the same thing
    option_1 = efficient_chunkwise_attention(CHUNK_SIZE, emit_probs, softmax_logits)
    option_2 = stable_chunkwise_attention(CHUNK_SIZE, emit_probs, softmax_logits)
    sess.run(tf.global_variables_initializer())
    out_1, out_2 = sess.run([option_1, option_2])
    print('Are efficient_chunkwise_attention and stable_chunkwise_attention the same with')
    print('  small-magnitude softmax_logits? {}'.format(np.allclose(out_1, out_2)))

    # Test they no longer compute the same thing when the range of logits is large
    softmax_logits_data[0, 5:7] -= 1e10
    sess.run(softmax_logits.assign(softmax_logits_data))
    out_1, out_2 = sess.run([option_1, option_2])
    print('  large-magnitude softmax_logits? {}'.format(np.allclose(out_1, out_2)))

    # Time them.  Use tf.group and pre-run to test graph execution time only.
    option_1 = tf.group(option_1)
    option_2 = tf.group(option_2)
    sess.run(option_1)
    sess.run(option_2)
    %timeit -r10 sess.run(option_1)
    %timeit -r10 sess.run(option_2)

Are efficient_chunkwise_attention and stable_chunkwise_attention the same with
  small-magnitude softmax_logits? True
  large-magnitude softmax_logits? False
10000 loops, best of 10: 167 µs per loop
10000 loops, best of 10: 164 µs per loop
