# Training a Subsampling Mechanism in Expectation

Consider a mechanism which, given a sequence of vectors $\mathbf{s} = \{s_0, s_1, ... , s_{T - 1}\}$, produces a sequence of "sampling probabilities" $\mathbf{e} = \{e_0, e_1, ..., e_{T - 1}\}, e_t \in [0, 1]$ which denote the probability of including $s_t$ in the output sequence $\mathbf{y} = \{y_0, y_1, ..., y_{U - 1}\}$. Producing $\mathbf{y}$ from $\mathbf{s}$ and $\mathbf{e}$ is encapsulated by the following pseudo-code:
```
# Initialize y as an empty sequence
y = []
for t in {0, 1, ..., T - 1}:
    # Draw a random number in [0, 1] and compare to e[t]
    if rand() < e[t]:
        # Add s[t] to y with probability e[t]
        y.append(s[t])
```
We call this a "subsampling mechanism" because, by construction, $U \le T$ and each element of $\mathbf{y}$ is drawn directly from $\mathbf{s}$.  When including this mechanism inside of a larger neural network model, it's possible to backpropagate through it by differentiating with respect to the expected output.  To compute the expected output, we first compute $p(y_m = s_n)$, i.e. the probability that the $m$th element of $\mathbf{y}$ is the $n$th element of $\mathbf{s}$, as follows:
$$
p(y_m = s_n) = \begin{cases}
0, n < m\\
e_n\prod_{i = 0}^{n - 1}(1 - e_i), m = 0\\
e_n\left(\sum_{j = 0}^{n - 1}p(y_{m - 1} = s_j)\prod_{i = j + 1}^{n - 1}(1 - e_i)\right), \mathrm{otherwise}
\end{cases}
$$
We can then compute the expected value of $y_m$ simply by computing $\sum_n s_n p(y_m = s_n)$.  Further details are available in [1].  This notebook contains an example implementation of this approach in TensorFlow and NumPy for illustration purposes.

[1] "_Training a Subsampling Mechanism in Expectation_", Colin Raffel & Dieterich Lawson.  Submitted to ICLR 2017, workshop track.

In [0]:
import numpy as np
import tensorflow as tf

### TensorFlow example

In [0]:
def safe_cumprod(x, **kwargs):
    """Cumulative product of ``x``, evaluated as exp(cumsum(log(x))).

    Going through log space (rather than calling a cumprod op directly) is
    done to avoid numerical issues.  Inputs are first clipped to the range
    [1e-10, 1] so the logarithm is always finite.

    Args:
        x: Tensor of factors to multiply cumulatively.
        **kwargs: Passed straight through to ``tf.cumsum`` (e.g. ``axis``,
            ``exclusive``, ``reverse``).

    Returns:
        A Tensor holding the (clipped) cumulative product of ``x``.
    """
    bounded = tf.clip_by_value(x, 1e-10, 1)
    running_log_sum = tf.cumsum(tf.log(bounded), **kwargs)
    return tf.exp(running_log_sum)

def output_probabilities(p_emit):
    """Given emission probabilities p_emit, compute p(y_m = s_n).

    Given the probability of emitting, compute the probability of the output
    being each state.  E.g., given p_emit whose entries correspond to the
    probability of emitting a token at time t when producing an output
    sequence y, compute the matrix p(y_m = s_n).

    Args:
        p_emit: The probability of emitting at each time, shape
            (n_batch, n_seq).  Anything accepted by ``tf.convert_to_tensor``
            works; all intermediate tensors are built in p_emit's
            floating-point dtype (not just float32).

    Returns:
        A Tensor of shape (n_batch, n_seq, n_seq) with the same dtype as
        p_emit, where the entry (b, m, n) corresponds to p(y_m = s_n) for
        batch b.
    """
    p_emit = tf.convert_to_tensor(p_emit)
    # Constant tensors below must share p_emit's dtype; otherwise tf.concat
    # and tf.scan fail with a dtype mismatch for non-float32 inputs.
    dtype = p_emit.dtype
    # Retrieve the (dynamic) batch size and number of sequence steps
    n_batch = tf.shape(p_emit)[0]
    n_seq = tf.shape(p_emit)[1]

    def build_prod_1_m_h_jp1_nm1(_, j):
        r"""Computes cumprod(1 - p_emit[j + 1:]) for each batch.

        Returns an (n_batch, n_seq) tensor whose entry n is
        \prod_{i = j + 1}^{n - 1}(1 - p_emit_i) for n > j, and 0 for n <= j.
        """
        # exclusive=True makes the cumprod start with a 1, i.e. the empty
        # product for n = j + 1.
        prod_1_m_h = safe_cumprod(1. - p_emit[:, j + 1:], axis=1, exclusive=True)
        # Left-pad with j + 1 zeros (the n <= j entries) to make it length n_seq
        return tf.concat(
            [tf.zeros((n_batch, j + 1), dtype=dtype), prod_1_m_h], 1)

    # Build prod_1_m_h_jp1_nm1 using scan; step j produces row j.  The
    # initializer's values are never read by the step function (it ignores
    # its accumulator argument); it only has to match the step output's
    # shape and dtype.
    prod_1_m_h_jp1_nm1 = tf.scan(
        build_prod_1_m_h_jp1_nm1, tf.range(n_seq),
        tf.ones((n_batch, n_seq), dtype=dtype))
    # Reshape (n_seq, n_batch, n_seq) -> (n_batch, n_seq, n_seq); entry
    # (b, j, n) is \prod_{i = j + 1}^{n - 1}(1 - p_emit[b, i]), 0 for n <= j.
    prod_1_m_h_jp1_nm1 = tf.transpose(prod_1_m_h_jp1_nm1, [1, 0, 2])

    def build_p_y_m_eq_s_n(previous_outputs, _):
        """Function to build rows of the matrix of values of p(y_m = s_n).

        previous_outputs[1] holds the most recently computed row,
        p(y_{m-1} = s_n); the step returns (that row, the new row p(y_m = s_n)).
        """
        # For every batch element and all n at once this computes
        #   p(y_m = s_n) = p_emit[n] *
        #       sum_j p(y_{m-1} = s_j) * prod_1_m_h_jp1_nm1[j, n]
        # as a batched (1 x n_seq) @ (n_seq x n_seq) matmul.
        p_y_mm1_eq_s_n = previous_outputs[1]
        p_y_m_eq_s_n = (p_emit * tf.matmul(
            tf.reshape(p_y_mm1_eq_s_n, (n_batch, 1, n_seq)),
            prod_1_m_h_jp1_nm1)[:, 0])
        # We need to return the new row of the matrix (p_y_m_eq_s_n) as well as
        # the old row (p_y_mm1_eq_s_n) because we keep element [0] of each
        # step's output, so the stacked result starts with p_y_0_eq_s_n.
        return p_y_mm1_eq_s_n, p_y_m_eq_s_n

    # Compute first row of the matrix,
    # p(y_0 = s_n) = p_emit[n] * \prod_{i < n}(1 - p_emit[i]),
    # to pass as initial value to scan
    p_y_0_eq_s_n = p_emit*safe_cumprod(1 - p_emit, axis=1, exclusive=True)
    # Use scan to construct the p_y_m_eq_s_n matrix: the n_seq steps yield the
    # rows m = 0, ..., n_seq - 1, stacked as (n_seq, n_batch, n_seq)
    p_y_m_eq_s_n = tf.scan(build_p_y_m_eq_s_n, tf.range(n_seq),
                           (p_y_0_eq_s_n, p_y_0_eq_s_n))[0]
    # Reshape to (n_batch, n_seq, n_seq)
    p_y_m_eq_s_n = tf.transpose(p_y_m_eq_s_n, [1, 0, 2])
    return p_y_m_eq_s_n

### NumPy example

In [0]:
def output_probabilities_explicit(p_emit):
    """Explicitly compute p(y_m = s_n) using nested for loops.

    A direct, unvectorized transcription of the recurrence for p(y_m = s_n),
    useful as a reference to check the vectorized implementation against.

    Args:
        p_emit: 1-D sequence (list or ndarray) of emission probabilities
            e_0, ..., e_{T-1} for a single, unbatched sequence.

    Returns:
        A float ndarray of shape (T, T) whose entry [m, n] is p(y_m = s_n).

    Raises:
        ValueError: If p_emit is not one-dimensional.
    """
    p_emit = np.asarray(p_emit, dtype=float)
    if p_emit.ndim != 1:
        raise ValueError(
            "p_emit must be 1-D, got shape {}".format(p_emit.shape))
    n_seq = p_emit.shape[0]
    p_y_m_eq_s_n = np.zeros((n_seq, n_seq))
    for m in range(n_seq):
        for n in range(n_seq):
            # p(y_m = s_n) = 0 when n < m: for y_m to be s_n, exactly m
            # symbols must already have been emitted among s_0, ..., s_{n-1},
            # which is impossible when there are fewer than m of them.
            if n < m:
                p_y_m_eq_s_n[m, n] = 0.
            # The probability that the first output element y_0 is s_n is the
            # probability that none of s_0, ..., s_{n-1} were emitted
            # multiplied by the probability of emitting s_n.
            elif m == 0:
                p_y_m_eq_s_n[m, n] = p_emit[n]*np.prod(1 - p_emit[:n])
            # In general, y_m = s_n requires, for some j < n, that:
            #   - y_{m - 1} = s_j,
            #   - none of s_{j + 1}, ..., s_{n - 1} is emitted, and
            #   - s_n is emitted.
            else:
                for j in range(n):
                    p_y_m_eq_s_n[m, n] += (p_y_m_eq_s_n[m - 1, j] *
                                           np.prod(1 - p_emit[j + 1:n]))
                p_y_m_eq_s_n[m, n] *= p_emit[n]
    return p_y_m_eq_s_n

### Test that they're the same

In [0]:
# Create some test input emission probability sequences: a batch of 10
# sequences of length 20.  Use a seeded RNG so the check is reproducible.
rng = np.random.RandomState(0)
test_input = rng.uniform(size=(10, 20))

# Build the TensorFlow graph in a fresh graph (so re-running this cell does
# not keep growing the default graph) and evaluate it on the test input
with tf.Graph().as_default(), tf.Session() as sess:
    p_emit = tf.placeholder(tf.float32, [None, None])
    # Compute p(y_m = s_n) using tensorflow utility function
    p_y_m_eq_s_n = output_probabilities(p_emit)
    tensorflow_output = sess.run(p_y_m_eq_s_n, {p_emit: test_input})

# Compare against the explicit NumPy reference, one batch element at a time
assert np.allclose(
    tensorflow_output,
    [output_probabilities_explicit(batch) for batch in test_input])