# Scaling Transformers - Sparse Is Enough

Licensed under the Apache License, Version 2.0
This colab contains all the relevant code for the paper "Sparse is Enough in Scaling Transformers". The code depends on the Trax library. Note that the experiments in the paper were not run from this colab: they were run in a distributed setup using the attached config files, but with the same code as shown below.

In [None]:
# Imports.
!pip install --upgrade -q trax==1.3.9

import functools
import os
import random
import time
import numpy as np

import jax
import trax
from trax import layers as tl
from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.supervised import training

## Main sparse layers

This cell contains the implementation of our main sparse layers:
* sparse QKV layers
* sparse feed-forward blocks

In [None]:
def SplitLastAxis(num_splits):
  """Returns a layer that splits the last axis into `num_splits` parts.

  The layer reshapes an input of shape `(..., d)` into
  `(..., num_splits, d // num_splits)`; all leading axes are kept as they are.

  Args:
    num_splits: Number of parts the last axis is divided into.
  """
  def _split(x):
    leading_dims = tuple(x.shape)[:-1]
    return jnp.reshape(x, leading_dims + (num_splits, -1))

  return tl.Fn(f'SplitLastAxis_{num_splits}', _split)


def MergeLastTwoAxes():
  """Returns a layer that flattens the last two axes into a single axis.

  All axes before the last two are kept as they are.
  """
  def _merge(x):
    leading_dims = tuple(x.shape)[:-2]
    return jnp.reshape(x, leading_dims + (-1,))

  return tl.Fn('MergeLastTwoAxes', _merge)


def LocallyConnectedDense(n_modules, n_units, kernel_size=1,
                          kernel_initializer=tl.GlorotUniformInitializer(),
                          bias_initializer=tl.RandomNormalInitializer(1e-6),
                          use_bias=True):
  """Approximates a Dense layer with a LocallyConnected1d over modules.

  The last axis of the input is split into `n_modules` modules,
  LocallyConnected1d (grouped convolution) is run over those modules, and the
  per-module results are merged back into one axis. It is essentially a
  locally-sensitive approximation of a Dense layer, with the number of
  parameters smaller by a factor of `n_modules / kernel_size`.

  Args:
    n_modules: How many modules (pixels) the input and output are split into
        for processing. With `n_modules == 1` a plain `tl.Dense` is returned.
    n_units: How many outputs (filters) each module generates.
    kernel_size: Size of the kernel used by LocallyConnected1d.
    kernel_initializer: Function that creates a matrix of (random) initial
        connection weights `W` for the layer.
    bias_initializer: Function that creates a vector of (random) initial
        bias weights `b` for the layer.
    use_bias: If `True`, compute an affine map `y = Wx + b`; else compute
        a linear map `y = Wx`.

  Returns:
      LocallyConnectedDense tl.Layer.
  """
  # Both the dense fallback and the locally connected layer share these.
  shared_kwargs = dict(kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer,
                       use_bias=use_bias)
  if n_modules == 1:
    return tl.Dense(n_units, **shared_kwargs)

  split = SplitLastAxis(n_modules)
  per_module = tl.LocallyConnected1d(
      n_units, kernel_size, padding='WRAP', **shared_kwargs)
  return tl.Serial(split, per_module, MergeLastTwoAxes())


class _RememberPad(tl.Layer):
  """Left-pads the second axis; in predict mode pads with remembered items."""

  def __init__(self, n_items_to_remember, mode):
    """Creates a layer which remembers the last N elements in predict mode.

    In predict mode the layer prepends the N elements saved from the previous
    call and then saves the last N elements of the result for the next call.
    In every other mode it prepends N zeros. Padding/remembering happens along
    the second axis (axis 1).

    Args:
      n_items_to_remember: Number of items to remember/pad with.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    super().__init__(name='_RememberPad')
    self._n_items_to_remember = n_items_to_remember
    self._mode = mode
    self._portal_mask = self.monkey_patched_mask()  # pylint: disable=assignment-from-none

  def monkey_patched_mask(self):
    """Returns the predict-mode mask provider; `None` unless patched."""
    # This hook is needed by the Terraformer model (see the comments there),
    # which only uses the mask in predict mode.
    return None

  def forward(self, x):
    """Pads `x` along axis 1 and, in predict mode, updates the memory."""
    n_keep = self._n_items_to_remember
    if n_keep == 0:
      return x

    if self._mode != 'predict':
      # Zero-pad only at the front of axis 1.
      pad_widths = [[0, 0] for _ in x.shape]
      pad_widths[1][0] = n_keep
      return jnp.pad(x, pad_width=pad_widths, mode='constant')

    # Predict mode: prepend what was remembered from the previous call.
    x = jnp.concatenate([self.state[0], x], axis=1)
    first_call_with_mask = (
        self._portal_mask is not None and 'init' in self.state[1])
    if first_call_with_mask:
      assert x.shape[0] == 1
      mask = self._portal_mask.get_value()
      # Zeros in the mask are counted as padding; the remembered window is
      # shifted left by that many positions so it ends before the padding.
      count_padding = jnp.sum(mask == 0, dtype=jnp.int32)
      start = x.shape[1] - (n_keep + count_padding)
      remembered = fastmath.dynamic_slice_in_dim(x, start, n_keep, axis=1)
    else:
      remembered = x[:, -n_keep:, ...]
    # The dict key switches from 'init' to 'forward' after the first call.
    self.state = (remembered, {'forward': ()})
    return x

  def init_weights_and_state(self, input_signature):
    """Initializes empty weights and, in predict mode, a zero memory buffer."""
    if isinstance(input_signature, (list, tuple)):
      input_signature = input_signature[0]
    self.weights = ()
    if self._mode != 'predict':
      self.state = ()
      return
    buffer_shape = list(input_signature.shape)
    buffer_shape[1] = self._n_items_to_remember
    self.state = (jnp.zeros(buffer_shape, dtype=jnp.float32), {'init': ()})


def LocallyConvDense(n_modules, n_units, mode, kernel_size=1,
                     length_kernel_size=1):
  """Approximates a Dense layer with a convolution shared across modules.

  The last axis of the input is split into `n_modules` modules, a convolution
  is run over those modules, and the per-module results are merged back into
  one axis. It is similar to LocallyConnectedDense, but shares weights between
  modules.

  Args:
    n_modules: How many modules (pixels) the input and output are split into
        for processing. With `n_modules == 1` a plain `tl.Dense` is returned.
    n_units: How many outputs (filters) each module generates.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    kernel_size: Size of the kernel along the module axis; must be odd.
    length_kernel_size: If > 1, also do causal convolution on the previous axis,
      which is often the sentence length in sequence models.

  Returns:
      LocallyConvDense tl.Layer.

  Raises:
    ValueError: If `n_modules != 1` and `kernel_size` is not odd.
  """
  if n_modules == 1:
    return tl.Dense(n_units)
  if kernel_size % 2 != 1:
    raise ValueError('Currently we only handle odd kernel sizes.')

  side_pad = (kernel_size - 1) // 2
  # After the split the tensor has four axes; only the module axis (axis 2) is
  # padded symmetrically here. Axis 1 is padded causally by _RememberPad.
  module_axis_padding = [[0, 0], [0, 0], [side_pad, side_pad], [0, 0]]

  def _pad_modules(x):
    return jnp.pad(x, pad_width=module_axis_padding, mode='constant')

  return tl.Serial(
      SplitLastAxis(n_modules),
      tl.Fn('Pad', _pad_modules),
      _RememberPad(length_kernel_size - 1, mode=mode),
      tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
      MergeLastTwoAxes(),
  )


def RandomLayer(layer_a, layer_b, prob_a):
  """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.

  Args:
    layer_a: Layer run when the random draw is below `prob_a`.
    layer_b: Layer run otherwise.
    prob_a: Threshold the `tl.RandomUniform` draw is compared against.
  """
  def _below_threshold(sample):
    return sample < prob_a

  coin_flip = tl.Serial(
      tl.RandomUniform(),
      tl.Fn('SmallerThan', _below_threshold),
  )
  return tl.Cond(coin_flip, layer_a, layer_b)


def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None,
                           sparsity=0, d_lowrank=None, prob_sparse=None,
                           mode=None, use_bias=True, use_bfloat16=False):
  """Configurable sparse version of Dense layer.

  Args:
    n_units: Number of output units of the layer.
    d_input: Last dimension of the input; passed to FactoredDense when
        `sparsity_type == 'mult'`.
    sparsity_type: One of `None`, `'None'`, `'mult'`, `'local'` or `'local3'`.
        `None`/`'None'` (or `sparsity == 0`) yields a regular `tl.Dense`.
    sparsity: Number of modules used by the sparse variants.
    d_lowrank: Accepted for configuration compatibility; not used by this
        function.
    prob_sparse: If not None, a RandomLayer is returned that runs the sparse
        variant with this probability and a regular `tl.Dense` otherwise.
    mode: If given and different from `'train'`, `prob_sparse` (when set) is
        forced to 1.0.
    use_bias: Whether the layer adds a bias; `False` is unsupported for the
        `'local'` and `'local3'` variants.
    use_bfloat16: Whether to use bfloat16 weights; only supported for the dense
        and `'mult'` variants.

  Returns:
    A tl.Layer implementing the requested variant.

  Raises:
    ValueError: If `sparsity_type` is not one of the supported values.
  """
  if prob_sparse is not None:
    if mode is not None and mode != 'train':
      # For non-training modes, we want to use a sparse variant.
      # This is different than simply prob_sparse being None, as the weights of
      # the model are different.
      prob_sparse = 1.0
    return RandomLayer(
        SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity,
                               d_lowrank, use_bias=use_bias,
                               use_bfloat16=use_bfloat16),
        tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),
        prob_sparse)

  if sparsity_type is None or sparsity_type == 'None' or sparsity == 0:
    return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)
  if sparsity_type == 'mult':
    return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias,
                         use_bfloat16=use_bfloat16)

  assert not use_bfloat16  # use_bfloat16 is unsupported for other variants
  if sparsity_type == 'local':
    assert use_bias  # use_bias = False is unsupported
    assert n_units % sparsity == 0
    # Floor division keeps the per-module unit count an int; true division
    # would hand a float to the layer constructor.
    return LocallyConnectedDense(sparsity, n_units // sparsity)
  if sparsity_type == 'local3':
    assert use_bias  # use_bias = False is unsupported
    assert n_units % sparsity == 0
    return LocallyConnectedDense(sparsity, n_units // sparsity, kernel_size=3)

  raise ValueError('Unknown sparsity type: {}'.format(sparsity_type))


def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False):
  r"""Returns a Dense-like layer, internally factored to use fewer parameters.

  This layer treats an activation vector as if divided into :math:`M`
  subvectors (``n_modules`` 'modules'). It uses this factored view to compute
  a :py:class:`Dense`-like mapping with high mixing/connectivity, but using
  approximately :math:`1/M` the number of weights of a similarly dimensioned
  :py:class:`Dense` layer.

  More specifically, each activation vector of dimensionality ``d_in`` is
  multiplied element-wise (a generalized form of gating) with ``n_modules``
  vectors also of dimensionality ``d_in``. The resulting vectors are projected
  to the subvector/module dimensionality ``d_out / n_modules`` via a matrix
  multiply, and finally reshaped back to a single vector of dimensionality
  ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at
  the end. All the above-mentioned non-input objects -- gating vectors,
  projection matrix, and optional bias -- are trainable weights.

  Args:
    n_modules: Number by which an activation vector is divided into subvectors
        (modules) for the factored computation.
    d_in: Last/innermost dimension of input array.
    d_out: Last/innermost dimension of output array.
    use_bias: If True, add bias vectors at the end of the layer; else end the
        layer with the matrix multiply.
    use_bfloat16: If True, use bfloat16 weights; else use float32 weights.

  Returns:
    A tl.Serial layer implementing the factored mapping.

  Raises:
    ValueError: If ``d_out`` is not a multiple of ``n_modules``.
  """
  if d_out % n_modules != 0:
    raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg '
                     f'n_modules ({n_modules}).')
  d_module = d_out // n_modules

  # Each helper returns a single Weights layer (not a tuple wrapping one).
  def GatingVectors():
    return tl.Weights(tl.RandomNormalInitializer(stddev=0.5),
                      shape=[n_modules, d_in],
                      use_bfloat16=use_bfloat16)

  def ProjectionMatrix():
    return tl.Weights(tl.GlorotUniformInitializer(),
                      shape=[d_in, d_module],
                      use_bfloat16=use_bfloat16)

  def Bias():
    return tl.Weights(tl.RandomNormalInitializer(1e-6),
                      shape=[d_out],
                      use_bfloat16=use_bfloat16)

  # The two Weights layers push gating and projection on top of the input, so
  # _GateAndProject receives them in (projection, gating, x) order.
  layers = [
      GatingVectors(),
      ProjectionMatrix(),
      _GateAndProject(),
      MergeLastTwoAxes(),
  ]
  if use_bias:
    layers += [Bias(), tl.Add()]

  return tl.Serial(layers)


def _GateAndProject():
  """Returns a layer that gates and projects in one einsum, to save memory."""

  def _gate_then_project(projection, gating, x):
    # The arguments come off the stack in reverse order of how they were put
    # there, hence projection first and the activations last.
    # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)
    return jnp.einsum('...d,nd,dm->...nm', x, gating, projection)

  return tl.Fn('_GateAndProject', _gate_then_project)


def MultiplicativeConvCausalAttention(
    d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0,
    force_no_dropout=False, max_inference_length=2048, share_qk=False,
    output_layer_type='none', v_concat_type='none', mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, Q/K/V are not computed with Dense layers: a FactoredDense layer
  shared by Q, K and V is followed by separate LocallyConvDense layers.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads,
        which is also the value used when it is None.
    length_kernel_size: Size of convolution kernel on the length dimension.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    force_no_dropout: If True, force dropout to be 0.0 independent of the above
        value; used to override some configurations.
    max_inference_length: Maximum length for inference.
    share_qk: If True, a single convolution output is used as both Q and K.
    output_layer_type: Which sparse layers to use for processing output from the
        attention mechanism. One of `'none'`, `'mult'`, `'conv'`,
        or `'multconv'`.
    v_concat_type: What kind of concatenation to use when computing V tensor.
        One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just
        the output of the multiplicative layer shared by Q, K, V. `'fixed'`
        means using the output of the multiplicative layer concatenated, for
        each module, with the layer input. `'original'` means using
        concatenation without properly taking modules into account; this
        method was used in experiments previously, so it is included for
        backwards-compatibility.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  assert output_layer_type in ['none', 'mult', 'conv', 'multconv']
  assert v_concat_type in ['original', 'fixed', 'none']

  if force_no_dropout:
    dropout = 0.0
  if sparsity is None:
    sparsity = n_heads
  d_module = d_feature // sparsity

  # Small factories so that every use site gets its own fresh layer instance.
  def _mult():
    return FactoredDense(sparsity, d_feature, d_feature)

  def _conv(kernel_size):
    return LocallyConvDense(
        sparsity, d_module, mode=mode, kernel_size=kernel_size,
        length_kernel_size=length_kernel_size)

  def _conv_into_heads(kernel_size):
    return [_conv(kernel_size), tl.SplitIntoHeads(n_heads)]

  def _attention():
    return tl.DotProductCausalAttention(
        dropout=dropout, max_inference_length=max_inference_length, mode=mode)

  # Optional sparse layers applied to the merged attention output.
  output_layers = []
  if 'mult' in output_layer_type:
    output_layers.append(_mult())
  if 'conv' in output_layer_type:
    output_layers.append(_conv(3))

  if v_concat_type == 'original':
    # Backwards-compatible variant: plain concatenation of the multiplicative
    # output with the layer input, without taking modules into account.
    concat_layers = [tl.Concatenate()]  # use permuted and original for v
  elif v_concat_type == 'fixed':
    # Module-aware variant: both tensors are first viewed per module, so after
    # concatenation every module of the following Conv layer sees both the
    # part of the embedding used to compute Q/K of this module and the part of
    # the layer input that this module will modify.
    def _per_module(x):
      return jnp.reshape(x, (x.shape[0], x.shape[1], sparsity, d_module))

    def _flatten_modules(x):
      return jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))

    concat_layers = [
        tl.Parallel(tl.Fn('Reshape1', _per_module),
                    tl.Fn('Reshape2', _per_module)),
        tl.Concatenate(),
        tl.Fn('Reshape3', _flatten_modules),
    ]
  elif v_concat_type == 'none':
    # No concatenation: the original layer input is dropped and only the
    # output of the shared multiplicative layer is passed to Conv.
    concat_layers = [tl.Select([0], n_in=2)]

  if not share_qk:
    return tl.Serial(
        tl.Select([0, 0]),  # duplicate activations
        _mult(),  # shared q, k
        tl.Select([0, 0, 0]),  # use for q, k, v
        tl.Parallel(
            _conv_into_heads(3),
            _conv_into_heads(3),
            [concat_layers] + _conv_into_heads(1),
        ),
        _attention(),
        tl.MergeHeads(n_heads),
        output_layers,
    )

  return tl.Serial(
      tl.Select([0, 0]),  # pre-qkv, pre-v-for-concat
      _mult(),  # shared q k
      tl.Select([0, 0]),  # pre-qk, pre-v, pre-v-for-concat
      _conv(3),
      tl.SplitIntoHeads(n_heads),
      tl.Select([0, 0]),  # use for q and k
      tl.Parallel(
          [],
          [],
          [concat_layers] + _conv_into_heads(1),
      ),
      _attention(),
      tl.MergeHeads(n_heads),
      output_layers,
  )


class DotProductCausalAttention(tl.Layer):
  """Layer that computes attention strengths by masking out the "future".

  Causal attention uses masking to prevent a given sequence position from
  attending to positions greater than / following it. This is used, for
  example, when training autoregressive sequence models, or when decoding a
  sequence symbol by symbol.

  This layer performs only the core per-head attention calculation: it assumes
  that any splitting into attention heads happens before it and any merging of
  attention heads happens after it.
  """

  def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'):
    """Creates a :py:class:`DotProductCausalAttention` instance.

    Args:
      dropout: Probabilistic rate for attention dropout, which overrides
          (sets to zero) some attention strengths derived from query-key
          matching. As a result, on a given forward pass, some value vectors
          don't contribute to the output, analogous to how regular dropout can
          cause some node activations to be ignored. Applies only if layer is
          created in ``'train'`` mode.
      max_inference_length: Maximum sequence length allowed in non-training
          modes.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    """
    super().__init__(n_in=3, n_out=1)
    self._dropout = dropout
    self._mode = mode
    self._max_len = max_inference_length
    self._portal_mask = self.monkey_patched_mask()  # pylint: disable=assignment-from-none

  def monkey_patched_mask(self):
    """Returns the predict-mode mask provider; `None` unless patched."""
    # This hook is needed by the Terraformer model (see the comments there),
    # which only uses the mask in predict mode.
    return None

  def forward(self, inputs):
    """Returns attention-computed activations.

    Args:
      inputs: A (queries, keys, values) tuple.
    """
    q, k, v = inputs

    mask_for_predict = (
        None if self._portal_mask is None else self._portal_mask.get_value())

    if self._mode != 'predict':
      mask = _causal_mask(q.shape[-2])
    else:
      # Fast inference: append the new keys/values to the cache kept in state
      # and attend over the whole cache.
      self.state, mask = _fast_inference_update_state(
          inputs, self.state,
          mask_for_predict=mask_for_predict)
      if self._portal_mask is None:
        (k, v, _) = self.state
      else:
        # With a portal mask the state carries the mask as an extra first item.
        (_, k, v, _) = self.state

    activations, attn_strengths = _per_head_attention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
    if self._mode == 'viz':
      # Expose the attention strengths through the layer state.
      self.state = attn_strengths
    return activations

  def init_weights_and_state(self, input_signature):
    """Initializes this layer for fast inference, if in ``'predict'`` mode."""
    if self._mode != 'predict':
      return
    self.state = _fast_inference_init_state(
        input_signature, self._max_len,
        predict_mask=self._portal_mask)
      
def _fast_inference_init_state(input_signature, buffer_length,
                               predict_mask=None):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Signatures of (queries, keys, values); the batch size is
        read from the first one and the feature depths from the other two.
    buffer_length: Length of the key/value caches along axis 1.
    predict_mask: If not None, the returned state additionally starts with an
        all-False mask of length `buffer_length`.

  Returns:
    `(keys, values, index)`, or `(mask, keys, values, index)` when
    `predict_mask` is given; caches are zero-filled and the index is 0.
  """
  batch_size = input_signature[0].shape[0]

  def _empty_cache(signature):
    shape, dtype = signature.as_tuple()
    return jnp.zeros((batch_size, buffer_length, shape[-1]), dtype=dtype)

  k = _empty_cache(input_signature[1])
  v = _empty_cache(input_signature[2])
  if predict_mask is None:
    return (k, v, jnp.array(0))

  # Comparing zeros against 0 gives an all-False boolean vector.
  mask_for_predict = jnp.zeros((buffer_length,)) != 0
  return (mask_for_predict, k, v, jnp.array(0))


def _fast_inference_update_state(inputs, state, mask_for_predict=None):
  """Updates state of a causal attention layer for fast inference.

  The layer state holds cached keys and values together with an index. To keep
  shapes static the caches are long buffers, and the index says where the new
  keys and values from `inputs` have to be written.

  An update writes the new keys and values into the caches at the index,
  advances the index by the number of new keys, and builds a causal mask that
  is true at the positions each new query may attend to.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, index); when `mask_for_predict` is
      given the state is (mask, keys, values, index).
    mask_for_predict: mask used for predict mode. This is used only in
      Terraformer.

  Returns:
    Updated state and mask to be used.
  """
  # Fast inference: run step-by-step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  has_predict_mask = mask_for_predict is not None
  if has_predict_mask:
    (state_mask_for_predict, ks, vs, idx) = state
  else:
    (ks, vs, idx) = state

  n_new = new_k.shape[1]
  ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
  vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
  buffer_len = ks.shape[1]

  # Mask is of shape [1, q_length, k_length] and is true for every
  # (query_token, key_token) pair in which the query position (offset by the
  # current index) is equal to or larger than the key position.
  key_positions = jnp.reshape(jnp.arange(buffer_len), (1, 1, buffer_len))
  query_positions = jnp.reshape(jnp.arange(n_new) + idx, (1, n_new, 1))
  mask = key_positions <= query_positions

  if not has_predict_mask:
    return (ks, vs, idx + n_new), mask

  # 1) Copy the flattened predict mask into the start of the stored mask.
  state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
      state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,
      axis=0)

  # 2) Mark as true the position whose index is the sum of the predict mask.
  state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
      state_mask_for_predict != 0, jnp.ones((1,)) != 0,
      jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)

  # 3) Mark as true the position at the current index.
  state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
      state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)

  # Broadcast the stored mask over the key axis and combine with the causal
  # mask.
  placeholder = jnp.reshape(state_mask_for_predict != 0,
                            (1, 1, mask.shape[2],))
  mask = mask * placeholder

  return (state_mask_for_predict, ks, vs, idx + n_new), mask


def _causal_mask(length):
  """Returns a boolean lower-triangular mask of shape (1, length, length)."""
  mask_shape = (1, length, length)
  # Not all backends define jnp.tril, so non-JAX backends fall back to
  # np.tril; that is avoided on JAX because it is inefficient there (it
  # creates a large global constant).
  if not fastmath.is_backend(fastmath.Backend.JAX):
    return np.tril(np.ones(mask_shape, dtype=np.bool_), k=0)
  return jnp.tril(jnp.ones(mask_shape, dtype=np.bool_), k=0)


def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` (if not None) by replacing the dot products at
      positions where the mask is false/0 with a large negative value;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products;
    - [in ``'train'`` mode] applies dropout to the attention strengths; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention strengths.
    mask: Mask that is true at positions that may be attended to, or None to
        attend to every position.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in ``'train'``
        mode. ``None`` is treated like 0.0 (no dropout).
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.

  Raises:
    ValueError: If ``dropout`` is 1.0 or larger.
  """
  # `None` is accepted as "no dropout" further down, so it must not reach the
  # numeric comparison here.
  if dropout is not None and dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # Masked-out positions get a very negative score, i.e. ~0 after softmax.
    dots = jnp.where(mask,
                     dots,
                     jnp.full_like(dots, -1e9))
  # Softmax over the key axis, computed via logsumexp.
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    # Kept strengths are rescaled by 1 / (1 - dropout).
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths


class _RememberInReverse(tl.Layer):
  """Layer remembering the input in forward pass. For reversible models."""

  def __init__(self, output=True):
    """Layer remembering the input in forward pass. For reversible models.

    On the first pass through the model the layer stores its input in state
    and passes the input through unchanged. On the second pass it outputs the
    input stored during the first pass instead. This is used to combat
    numerical stability problems in Terraformer. It doesn't do anything in
    non-reversible models.

    Args:
      output: Whether to pass the input or not.
    """
    n_out = 1 if output else 0
    self._output = output
    super().__init__(name='_RememberInReverse', n_out=n_out)

  def forward(self, x):
    """Stores `x` in state; returns `x` or the previously stored input."""
    # NOTE(review): this layer never writes the 'running_second_time_yes' key
    # itself; presumably it is set from outside for the second pass -- confirm
    # against the reversible training code.
    replaying = 'running_second_time_yes' in self.state[1]
    result = self.state[0] if replaying else x
    self.state = (x, {'running_second_time': ()})

    return result if self._output else tuple()

  def init_weights_and_state(self, input_signature):
    """Initializes empty weights and a zero-filled remembered input."""
    if isinstance(input_signature, (list, tuple)):
      input_signature = input_signature[0]
    self.weights = ()
    remembered = jnp.zeros(input_signature.shape, dtype=jnp.int32)
    self.state = (remembered, {'running_second_time': ()})


class _RecallQuantMaskInReverse(tl.Layer):
  """Layer recalling quant mask from specific _RememberInReverse.

  This layer is needed for memory-efficient training of a reversible model
  with ff chunking. In the forward pass it simply emits minus ones, which the
  controller ignores. During reverse_and_grad it emits the quant_mask that a
  RememberInReverse layer memorized (saved to its state).

  This makes it possible to save quant_mask right after chunking, and to load
  it again (when reversing) right before chunking.
  """

  def __init__(self, remember_layer, elements):
    """Creates the layer.

    Args:
      remember_layer: Layer whose state holds the memorized quant_mask.
      elements: Size of the second axis of the minus-ones placeholder.
    """
    self._remember_layer = remember_layer
    self._elements = elements
    super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2)

  def forward(self, x):
    """Returns `(x, quant_mask)`, with minus ones if nothing is recalled."""
    remembered = self._remember_layer.state
    if remembered and 'running_second_time_yes' in remembered[1]:
      # It's reverse_and_grad, so we pull the quant_mask from remembering layer.
      return (x, remembered[0])
    placeholder = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)
    return (x, placeholder)


class _SparseFFController(tl.Layer):
  """The controller part of Sparse Feed-Forward layer."""

  def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature,
               use_bfloat16, mode, kernel_initializer, bias_initializer,
               also_return_nondiscrete_output):
    """Creates the controller of a sparse feed-forward block.

    Args:
      d_ff: Total number of controlled units; must be divisible by
          `n_elements_in_block`.
      n_elements_in_block: Number of elements in each block; one element per
          block is selected by the controller.
      d_lowrank: Inner dimension of the low-rank projection.
      temperature: Scale of the Gumbel noise; used only in `'train'` mode (it
          is set to 0.0 otherwise).
      use_bfloat16: If True, the weights are cast to bfloat16.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
      kernel_initializer: Initializer for the two projection matrices.
      bias_initializer: Initializer for the bias.
      also_return_nondiscrete_output: If True, the layer also outputs the
          softmax mask next to the discrete choice.
    """
    n_out = 2 if also_return_nondiscrete_output else 1
    super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out)
    self._use_bfloat16 = use_bfloat16
    self._d_ff = d_ff
    self._d_lowrank = d_lowrank
    # Q: what temperature is actually most useful in training?
    self._temperature = temperature if mode == 'train' else 0.0
    self._mode = mode
    self._n_elements_in_block = n_elements_in_block
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    # Helper numbers as d_ff will be divided by n_elements_in_block.
    assert self._d_ff % self._n_elements_in_block == 0
    self._d1 = self._d_ff // self._n_elements_in_block
    self._d2 = self._n_elements_in_block
    self._also_return_nondiscrete_output = also_return_nondiscrete_output

  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: A pair `(activations, recalled_quant_mask)`. Entries of
          `recalled_quant_mask` equal to -1 mean "nothing recalled".

    Returns:
      The discrete choice `quant_mask` (index of the selected element within
      each block, for each flattened input row), or the pair
      `(quant_mask, mask)` with the softmax `mask` when the layer was created
      with `also_return_nondiscrete_output`.
    """
    x, recalled_quant_mask = x
    m1, m2, mb = self.weights

    # Easier to operate on x flattened to two dimensions.
    x = jnp.reshape(x, [-1, x.shape[-1]])

    # Q: should we add bias and/or put relu after the low-rank m1 dot?
    # A single einsum instead of multiplication plus reshape brings a training
    # speed improvement (see also the reshapes in initialization).
    mask_logits = jnp.einsum('bd,dl,lxy->bxy', x, m1, m2) + mb

    mask = None
    if not self._also_return_nondiscrete_output:
      quant_mask = jnp.argmax(mask_logits, axis=-1)
    else:
      # Softmax over the elements of each block.
      log_mask = mask_logits - fastmath.logsumexp(
          mask_logits, axis=-1, keepdims=True)
      mask = jnp.exp(log_mask)
      if self._temperature == 0.0:
        quant_mask = jnp.argmax(log_mask, axis=-1)
      else:
        # Discretize by argmax after adding temperature-scaled Gumbel noise.
        u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        gumbel_noise = -jnp.log(-jnp.log(u))
        quant_mask = jnp.argmax(
            log_mask + gumbel_noise * self._temperature, axis=-1)

    if self._mode == 'train':
      # A recalled value different than -1 takes precedence over the choice
      # which has just been computed.
      quant_mask = jnp.where(recalled_quant_mask == -1,
                             quant_mask, recalled_quant_mask)

    if self._also_return_nondiscrete_output:
      return quant_mask, mask
    return quant_mask

  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights."""
    d_model = input_signature[0].shape[-1]

    rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)
    m1 = self._kernel_initializer((d_model, self._d_lowrank), rng_m1)
    m2 = self._kernel_initializer((self._d_lowrank, self._d_ff), rng_m2)
    mb = self._bias_initializer((self._d_ff,), rng_mb)
    if self._use_bfloat16:
      m1, m2, mb = (w.astype(jnp.bfloat16) for w in (m1, m2, mb))

    # Storing the weights per block, together with the einsum in forward,
    # improves the training speed.
    m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])
    mb = jnp.reshape(mb, [self._d1, self._d2])

    self.weights = (m1, m2, mb)


class _SparseFFMain(tl.Layer):
  """The main (non-controller) part of Sparse Feed-Forward layer.

  Takes the controller's decisions together with the activations and applies
  the two feed-forward projections so that, per block of
  `n_elements_in_block` hidden units, only the selected unit contributes
  (in `'train'` mode a soft mask is sometimes used instead).
  """

  def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob,
               use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer,
               bias_initializer, multiply_by_controller_output, kernel_scaling):
    """Configures the main part of a sparse feed-forward block.

    Args:
      d_ff: Width of the hidden layer; must be divisible by
        `n_elements_in_block`.
      n_elements_in_block: Number of hidden units in each block.
      d_lowrank: Stored on the layer; not read by this class's own methods.
      quant_prob: In `'train'` mode the quantized mask is used when a uniform
        draw between 0.0 and 1.0 is below this value; otherwise the soft mask
        is used.
      use_bfloat16: If true, all weights are cast to bfloat16 at init.
      big_weights_in_bfloat16: If true, `w1` and `w2` are cast to bfloat16.
      mode: `'train'`, `'predict'`, or any other value for the evaluation path.
      kernel_initializer: Called as `f(shape, rng)` to create `w1` and `w2`.
      bias_initializer: Called as `f(shape, rng)` to create `b2`.
      multiply_by_controller_output: If true, the layer also receives the soft
        mask and multiplies the hidden activations by it.
      kernel_scaling: If true, `w2` is multiplied by
        `sqrt(n_elements_in_block)` at init.
    """
    # The soft mask is an extra input in training or when it scales the output.
    n_in = 3 if mode == 'train' or multiply_by_controller_output else 2
    super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2)
    self._mode = mode
    self._use_bfloat16 = use_bfloat16
    self._big_weights_in_bfloat16 = big_weights_in_bfloat16
    self._d_ff = d_ff
    self._d_lowrank = d_lowrank
    self._quant_prob = quant_prob
    self._n_elements_in_block = n_elements_in_block
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    # Helper numbers as d_ff will be divided by n_elements_in_block.
    assert self._d_ff % self._n_elements_in_block == 0
    self._d1 = self._d_ff // self._n_elements_in_block  # number of blocks
    self._d2 = self._n_elements_in_block  # units per block
    self._multiply_by_controller_output = multiply_by_controller_output
    self._kernel_scaling = kernel_scaling

  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tuple of inputs: `(quant_mask, mask, x)` in `'train'` mode or when
        `multiply_by_controller_output` is set, `(quant_mask, x)` otherwise.
        `quant_mask` holds the index of the selected unit for every block and
        `mask` is the controller's soft output over the units of each block.

    Returns:
      A pair `(quant_mask, output)`: the `quant_mask` exactly as received and
      the feed-forward result reshaped to the shape of the input activations.
    """
    if self._mode == 'train' or self._multiply_by_controller_output:
      quant_mask, mask, x = x
    else:
      quant_mask, x = x
    # Returned untouched; `quant_mask` itself is rewritten below.
    original_quant_mask = quant_mask

    w1, w2, b2 = self.weights

    if self._mode == 'predict':
      # Bring both kernels to a common [d1, d2, d_model] layout for indexing.
      w1 = jnp.transpose(w1, (1, 2, 0))  # dm, d1, d2 -> d1, d2, dm
      w2 = jnp.transpose(w2, (1, 0, 2))  # d2, d1, dm -> d1, d2, dm
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    if self._mode == 'train':
      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
      quant_mask = fastmath.stop_gradient(quant_mask)
      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
      # We will sometimes (quant_prob of the batches) use the soft-mask instead
      # of the quantized mask to improve training stability (see paper above).
      select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)
      quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)

      # In training, run full matmul to get benefits from the above tricks.
      mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      if self._multiply_by_controller_output:
        # We multiply only for quantized decisions, since for non-quantized
        # decisions we've already multiplied the output.
        mask_mult = jnp.where(select < self._quant_prob,
                              mask, jnp.ones_like(mask))
        # Stop-gradient is here, because we already have a pass-through gradient
        # (for quantized decisions).
        mask_mult = fastmath.stop_gradient(mask_mult)
        relu = relu * mask_mult
      res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2
    elif self._mode == 'predict':
      # This implementation mimicks inference. It's not efficient for large
      # size of joint_batch, but at inference that will be 1 most of the time.
      # Shapes:
      # quant_mask is [joint_batch, self._d1]
      # w1 is [self._d1, self._d2, d_model] (transposed above)
      # we'll index w1 with advanced numpy indexing, first range over
      # self._d1 times the batch size, second range being quant_mask
      batch_size = quant_mask.shape[0]
      # Block indices 0..d1-1 repeated once per batch row, then flattened so
      # that they line up element-wise with the flattened quant_mask.
      idx1 = jnp.reshape(
          jnp.broadcast_to(jnp.arange(self._d1), (batch_size, self._d1)), [-1])
      idx2 = jnp.reshape(quant_mask, [-1])
      w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
      w = jnp.reshape(w, [batch_size, self._d1, -1])
      mid = jnp.einsum('ai,aji->aj', x, w)
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      if self._multiply_by_controller_output:
        # Pick the soft-mask value of the selected unit in every block.
        mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]
        relu = relu * mask_mult
      # w2 is [self._d1, self._d2, d_model]
      v = w2[idx1, idx2, :]
      v = jnp.reshape(v, [batch_size, self._d1, -1])
      res = jnp.einsum('ai,aij->aj', relu, v) + b2
    else:
      # Evaluation: full matmul with a hard one-hot mask, no random choices.
      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
      mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      if self._multiply_by_controller_output:
        relu = relu * mask
      res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2

    return original_quant_mask, jnp.reshape(res, x_shape)

  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights `(w1, w2, b2)`.

    Args:
      input_signature: Signatures of the layer inputs; the last dimension of
        the last one (the activations) is used as `d_model`.
    """
    d_model = input_signature[-1].shape[-1]
    shape_w1 = (d_model, self._d_ff)
    shape_w2 = (self._d_ff, d_model)
    shape_b2 = (d_model,)

    rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)
    if tl.N_WEIGHTS_SHARDS > 1:
      # In sharded-weights mode, put the weights on CPU on init
      # as they will be sharded later.
      w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))
      w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))
    else:
      w1 = self._kernel_initializer(shape_w1, rng_w1)
      w2 = self._kernel_initializer(shape_w2, rng_w2)

    b2 = self._bias_initializer(shape_b2, rng_b2)
    if self._use_bfloat16:
      b2 = b2.astype(jnp.bfloat16)
    if self._use_bfloat16 or self._big_weights_in_bfloat16:
      w1 = w1.astype(jnp.bfloat16)
      w2 = w2.astype(jnp.bfloat16)

    # Store the kernels split into blocks: w1 as [d_model, d1, d2] and
    # w2 as [d2, d1, d_model], the layouts the einsums in `forward` expect.
    w1 = jnp.reshape(w1, (-1, self._d1, self._d2))
    w2 = jnp.reshape(w2, (self._d2, self._d1, -1))

    if self._kernel_scaling:
      # This keeps expected variance of the output regardless of N.
      w2 = w2 * (self._n_elements_in_block ** 0.5)

    self.weights = (w1, w2, b2)


def SparseFF(
    d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3,
    use_bfloat16=False, big_weights_in_bfloat16=False, mode='train',
    kernel_initializer=tl.GlorotUniformInitializer(),
    bias_initializer=tl.RandomNormalInitializer(1e-6),
    dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0,
    multiply_by_controller_output=False, kernel_scaling=False):
  """Returns Feed-forward block with sparsity.

  A regular (dense) FF block is Dense(d_ff)-Relu-Dense: it widens the input to
  d_ff, applies Relu and projects back to the original size. In Transformer
  models it often holds most of the trainable weights, which makes decoding
  slow because all of them have to be fetched from memory.

  This sparse variant allows only one non-zero element in every block of
  `n_elements_in_block` hidden units. The choice is made by a low-rank
  controller and trained with the straight-through Gumbel-softmax trick.

  Args:
    d_ff: Depth/dimensionality of FeedForward layer.
    n_elements_in_block: The sparsity level. The layer is divided into blocks of
      this size, and each block has only a single element active.
    d_lowrank: The dimensionality of low-rank controller.
    temperature: The temperature of the controller during training.
    quant_prob: Passed to the main layer; during training it decides how often
      the quantized mask (a single element active per block) is used instead
      of the soft mask.
    use_bfloat16: Whether to use bfloat16 for weights.
    big_weights_in_bfloat16: Whether to use bfloat16 for main weights of the
      FeedForward layer.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    kernel_initializer: Function that creates a matrix of (random) initial
      connection weights `W` for the layer.
    bias_initializer: Function that creates a vector of (random) initial
      bias weights `b` for the layer.
    dropout_rate: Probability for dropping an activation value.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    multiply_by_controller_output: whether to multiply the middle activation
      layer of FF by controller output (i.e. softmax).
    kernel_scaling: Whether to scale the kernel matrix (during init) to keep the
      variance of the layer output regardless of n_elements_in_block.

  Returns:
    The combined layer, wrapped in `tl.BatchLeadingAxes`.
  """
  # The soft mask is needed for training and for scaling by controller output.
  wants_soft_mask = bool(mode == 'train' or multiply_by_controller_output)

  # Arguments that the controller and the main layer have in common.
  common = dict(
      d_ff=d_ff, n_elements_in_block=n_elements_in_block, d_lowrank=d_lowrank,
      use_bfloat16=use_bfloat16, mode=mode,
      kernel_initializer=kernel_initializer, bias_initializer=bias_initializer)

  controller = _SparseFFController(
      temperature=temperature,
      also_return_nondiscrete_output=wants_soft_mask, **common)

  ff_body = _SparseFFMain(
      quant_prob=quant_prob, big_weights_in_bfloat16=big_weights_in_bfloat16,
      multiply_by_controller_output=multiply_by_controller_output,
      kernel_scaling=kernel_scaling, **common)
  main = [
      ff_body,            # quant_mask, emb
      tl.Select([1, 0]),  # emb, quant_mask
      tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),
      tl.Select([1, 0]),  # quant_mask, emb
  ]

  # We will "remember" quant_mask _after_ chunking, and "recall" this same
  # quant_mask during reverse_and_grad _before_ chunking.
  remembering = _RememberInReverse(output=False)
  recalling = _RecallQuantMaskInReverse(
      remember_layer=remembering, elements=d_ff // n_elements_in_block)

  chunked = tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial(
      # emb, quant_mask
      tl.Select((0, 1, 0)),  # emb, quant_mask, emb
      controller,  # quant_mask, mask, emb
      main,  # quant_mask, emb/output
      ))
  return tl.BatchLeadingAxes(tl.Serial(
      recalling,  # emb, quant_mask
      chunked,
      remembering,  # emb/output
      ))


class BlockSparseFF(tl.Layer):
  """Feed-forward block with block sparsity.

  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
  that takes an input, makes it of size d_ff (usually larger than it was) and
  then brings it back to the original size after Relu. It is commonly used in
  Transformer models where it often accounts for most of the trainable weights.

  This block sparse layer mimics mixture of experts architecture.
  It divides the dimension of d_ff in each weight matrix to # of blocks equal to
  n_experts and activates only one non-zero block from the weights matrix.
  This is trained with straight-through Gumbel softmax trick.
  """

  def __init__(self,
               d_ff,
               n_experts=64,
               temperature=0.7,
               mode='train',
               kernel_initializer=tl.GlorotUniformInitializer(),
               bias_initializer=tl.RandomNormalInitializer(1e-6)):
    """Configures a block sparse feed-forward block.

    Args:
      d_ff: Width of the hidden layer; must be divisible by `n_experts`.
      n_experts: Number of blocks (experts) the hidden layer is divided into.
      temperature: Scale of the Gumbel noise added to the expert logits; only
        honored in `'train'` mode, forced to 0.0 otherwise.
      mode: `'train'`, `'predict'`, or any other value for evaluation.
      kernel_initializer: Called as `f(shape, rng)` to create `m1`, `w1`, `w2`.
      bias_initializer: Called as `f(shape, rng)` to create `b2`.
    """
    super().__init__(name=f'BlockSparseFF_{d_ff}')
    self._mode = mode
    self._d_ff = d_ff
    self._n_experts = n_experts
    self._temperature = temperature if mode == 'train' else 0.0
    self._n_elements_in_block = d_ff // n_experts
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    assert self._d_ff % self._n_experts == 0

  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    m1, w1, w2, b2 = self.weights
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    # Q: check if we need bias and/or put relu after the m1 dot?
    mask_logits = jnp.dot(x, m1)
    # Softmax.
    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
    log_mask = mask_logits - mask_logsumexp
    mask = jnp.exp(log_mask)
    # Gumbel-softmax with straight-through discretization.
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    if self._temperature == 0.0:
      # The noise would be multiplied by zero, so the plain argmax picks the
      # same expert; skip sampling it (as the SparseFF controller does).
      selected_experts = jnp.argmax(log_mask, axis=-1)
    else:
      u = fastmath.random.uniform(
          rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
      g = -jnp.log(-jnp.log(u))
      selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
    if self._mode == 'train':
      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
      quant_mask = tl.one_hot(selected_experts, self._n_experts)
      quant_mask = fastmath.stop_gradient(quant_mask)
      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
      # We will sometimes (50% of the batches) use the soft-mask instead of
      # the quantized mask to improve training stability (see the paper above).
      # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
      select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
      quant_mask = jnp.where(select > 0.0, quant_mask, mask)
    else:
      quant_mask = tl.one_hot(selected_experts, self._n_experts)
    quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
    batch_size = quant_mask.shape[0]

    if self._mode == 'predict' and batch_size == 1:
      # This implementation mimicks inference for batch_size 1.
      start_idx = selected_experts[0] * self._n_elements_in_block
      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
      w = fastmath.dynamic_slice(w1, [0, start_idx],
                                 [w1.shape[0], self._n_elements_in_block])
      mid = jnp.dot(x, w)
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
      v = fastmath.dynamic_slice(w2, [start_idx, 0],
                                 [self._n_elements_in_block, w2.shape[-1]])
      res = jnp.dot(relu, v) + b2
    else:
      # Repeat each expert's mask value over all units of that expert.
      expanded_mask = jnp.broadcast_to(
          quant_mask,
          (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))
      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      res = jnp.dot(relu, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed

  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights `(m1, w1, w2, b2)`.

    Args:
      input_signature: Signature of the input; its last dimension is used as
        `d_model`.
    """
    d_model = input_signature.shape[-1]
    shape_m1 = (d_model, self._n_experts)
    shape_w1 = (d_model, self._d_ff)
    shape_w2 = (self._d_ff, d_model)
    shape_b2 = (d_model,)

    rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
    m1 = self._kernel_initializer(shape_m1, rng_m1)
    w1 = self._kernel_initializer(shape_w1, rng_w1)
    w2 = self._kernel_initializer(shape_w2, rng_w2)
    b2 = self._bias_initializer(shape_b2, rng_b2)

    self.weights = (m1, w1, w2, b2)


class SwitchSparseFF(tl.Layer):
  """Feed-forward block with switch-style block sparsity.

  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
  that takes an input, makes it of size d_ff (usually larger than it was) and
  then brings it back to the original size after Relu. It is commonly used in
  Transformer models where it often accounts for most of the trainable weights.

  This block sparse layer mimics mixture of experts architecture.
  It divides the dimension of d_ff in each weight matrix to # of blocks equal to
  n_experts and activates only one non-zero block from the weights matrix.
  This is trained with methods following the Switch Transformer.
  """

  def __init__(self,
               d_ff,
               n_experts=64,
               temperature=0.1,
               mode='train',
               kernel_initializer=tl.GlorotUniformInitializer(),
               bias_initializer=tl.RandomNormalInitializer(1e-6)):
    """Configures a switch-style block sparse feed-forward block.

    Args:
      d_ff: Width of the hidden layer; must be divisible by `n_experts`.
      n_experts: Number of blocks (experts) the hidden layer is divided into.
      temperature: Scale of the Gumbel noise added to the expert logits; only
        honored in `'train'` mode, forced to 0.0 otherwise.
      mode: `'train'`, `'predict'`, or any other value for evaluation.
      kernel_initializer: Called as `f(shape, rng)` to create `m1`, `w1`, `w2`.
      bias_initializer: Called as `f(shape, rng)` to create `b2`.
    """
    super().__init__(name=f'SwitchSparseFF_{d_ff}')
    self._mode = mode
    self._d_ff = d_ff
    self._n_experts = n_experts
    self._temperature = temperature if mode == 'train' else 0.0
    self._n_elements_in_block = d_ff // n_experts
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    assert self._d_ff % self._n_experts == 0

  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    m1, w1, w2, b2 = self.weights
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    # Q: check if we need bias and/or put relu after the m1 dot?
    mask_logits = jnp.dot(x, m1)
    # Softmax.
    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
    log_mask = mask_logits - mask_logsumexp
    mask = jnp.exp(log_mask)
    if self._temperature == 0.0:
      # The noise would be multiplied by zero, so the plain argmax picks the
      # same expert; skip sampling it (as the SparseFF controller does).
      selected_experts = jnp.argmax(log_mask, axis=-1)
    else:
      # Gumbel noise to allow sampling from the softmax.
      rng1, _ = fastmath.random.split(self.rng, 2)
      u = fastmath.random.uniform(
          rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
      g = -jnp.log(-jnp.log(u))
      selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
    quant_mask = tl.one_hot(selected_experts, self._n_experts)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask *= mask  # go to just the selected expert
    quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
    batch_size = quant_mask.shape[0]

    if self._mode == 'predict' and batch_size == 1:
      # Softmax value of the selected expert, used to scale its activations.
      mask_flat = jnp.reshape(mask, [-1, self._n_experts])
      selected_flat = jnp.reshape(selected_experts, [-1])
      selected_mask_flat = mask_flat[np.arange(selected_flat.size),
                                     selected_flat]
      # This implementation mimicks inference for batch_size 1.
      start_idx = selected_experts[0] * self._n_elements_in_block
      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
      w = fastmath.dynamic_slice(w1, [0, start_idx],
                                 [w1.shape[0], self._n_elements_in_block])
      mid = jnp.dot(x, w)
      mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
      v = fastmath.dynamic_slice(w2, [start_idx, 0],
                                 [self._n_elements_in_block, w2.shape[-1]])
      res = jnp.dot(relu, v) + b2
    else:
      # Repeat each expert's mask value over all units of that expert.
      expanded_mask = jnp.broadcast_to(
          quant_mask,
          (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))
      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      res = jnp.dot(relu, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed

  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights `(m1, w1, w2, b2)`.

    Args:
      input_signature: Signature of the input; its last dimension is used as
        `d_model`.
    """
    d_model = input_signature.shape[-1]
    shape_m1 = (d_model, self._n_experts)
    shape_w1 = (d_model, self._d_ff)
    shape_w2 = (self._d_ff, d_model)
    shape_b2 = (d_model,)

    rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
    m1 = self._kernel_initializer(shape_m1, rng_m1)
    w1 = self._kernel_initializer(shape_w1, rng_w1)
    w2 = self._kernel_initializer(shape_w2, rng_w2)
    b2 = self._bias_initializer(shape_b2, rng_b2)

    self.weights = (m1, w1, w2, b2)

In [None]:
# SRU needs to be changed in order for concatenated encoder-decoder attention
# to work in predict mode.

def MakeZeroState(depth_multiplier=1):
  """Returns a layer producing a float32 zero state for a rank-3 input.

  For an input of shape (batch_size, sequence_length, feature_depth) the layer
  outputs zeros of shape (batch_size, depth_multiplier * feature_depth), i.e.
  the length axis (axis 1) is dropped.

  Args:
    depth_multiplier: Factor applied to the last input dimension.

  Raises:
    ValueError: (when the layer runs) if the input is not of rank 3.
  """
  def zero_state(x):
    shape = x.shape
    if len(shape) == 3:
      width = depth_multiplier * shape[-1]
      return jnp.zeros((shape[0], width), dtype=jnp.float32)
    raise ValueError('Layer input should be a rank 3 tensor representing'
                     ' (batch_size, sequence_length, feature_depth); '
                     f'instead got shape {x.shape}.')
  return tl.Fn('MakeZeroState', zero_state)

def InnerSRUCell():
  """Returns the single-step (non-parallel) SRU state update as a layer.

  The layer maps `(cur_x_times_one_minus_f, cur_f, cur_state)` to the new
  state `cur_f * cur_state + cur_x_times_one_minus_f` and outputs it twice
  (presumably once as the step output and once as the state carried on by the
  enclosing scan).
  """
  def step(cur_x_times_one_minus_f, cur_f, cur_state):
    new_state = cur_f * cur_state + cur_x_times_one_minus_f
    return new_state, new_state
  return tl.Fn('InnerSRUCell', step, n_out=2)


def ScanSRUCell(mode, monkey_patched_mask=None):
  """Returns the SRU recurrence scanned along axis 1, optionally masked.

  Args:
    mode: Passed on to `tl.Scan`; must be `'predict'` when a mask is given.
    monkey_patched_mask: Optional object whose `get_layer()` supplies a layer
      placed in front of the scan to provide the mask. If None, the plain
      `InnerSRUCell` is scanned.

  Returns:
    A `tl.Scan` layer, or a `tl.Serial` of mask layer, mask update and a
    masked scan.
  """
  def refresh_mask(mask, x_times_one_minus_f):
    # Start from an all-ones (axis 0, axis 1) mask; for inputs longer than one
    # step along axis 1, overwrite it from position 1 on with the given mask.
    all_ones = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
    if all_ones.shape[1] <= 1:
      return all_ones, x_times_one_minus_f
    merged = fastmath.dynamic_update_slice_in_dim(
        all_ones != 0, mask != 0, 1, axis=1)
    return merged, x_times_one_minus_f

  def masked_step(cur_mask, cur_x_times_one_minus_f, cur_f, cur_state):
    # Where the mask is off, the previous state is carried over unchanged.
    new_state = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask
                 + (1 - cur_mask) * cur_state)
    return new_state, new_state

  if monkey_patched_mask is None:
    return tl.Scan(InnerSRUCell(), axis=1, mode=mode)

  # This is necessary for Terraformer model. See comments there.
  # The mask will only be used in Terraformer in predict mode.
  assert mode == 'predict'

  masked_scan = tl.Scan(
      tl.Fn('MaskedInnerSRUCell', masked_step, n_out=2), axis=1, mode=mode)
  return tl.Serial(
      monkey_patched_mask.get_layer(),
      tl.Fn('update_mask', refresh_mask, n_out=2),
      masked_scan,
      )


def SRU(n_units, activation=None, mode='train'):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  The input is expected to have shape [batch, length, depth], with the
  recurrence running along the length dimension. A single layer is returned;
  the paper recommends stacking at least 2, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one inference

  Returns:
    The SRU layer.
  """
  gate = tl.Sigmoid()  # the same instance squashes both r and f
  post_scan = [] if activation is None else activation

  stages = [                                          # x
      tl.Branch(tl.Dense(3 * n_units), []),           # r_f_y, x
      tl.Split(n_items=3),                            # r, f, y, x
      tl.Parallel(gate, gate),                        # r, f, y, x
      tl.Fn('',
            lambda r, f, y: (y * (1.0 - f), f, r),    # y * (1 - f), f, r, x
            n_out=3),
      # Insert a zero initial state (shaped after r) in front of r.
      tl.Parallel([], [], tl.Branch(MakeZeroState(), [])),
      ScanSRUCell(mode=mode),
      tl.Select([0], n_in=2),                         # act(c), r, x
      post_scan,
      tl.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
  ]
  # Set the name to SRU and don't print sublayers.
  return tl.Serial(*stages, name=f'SRU_{n_units}', sublayers_to_print=[])

## Terraformer

The cells below contain the implementation of the Terraformer architecture:
* feed-forward and positional encoding blocks
* encoder and decoder blocks
* concatenation and stripping to combine the encoder and decoder
* the final Terraformer model

In [None]:
def _FeedForward(d_model, d_ff, dropout, activation, act_dropout,
                 use_bfloat16, mode):
  """Returns the layers of a dense Dense-Dropout-activation-Dense block.

  Args:
    d_model: Output size of the second dense layer.
    d_ff: Output size of the first dense layer.
    dropout: Dropout rate used when `act_dropout` is None.
    activation: Callable returning the activation layer.
    act_dropout: Rate of the dropout after the first dense layer, or None to
      fall back to `dropout`.
    use_bfloat16: Forwarded to both dense layers.
    mode: Forwarded to the dropout layer.

  Returns:
    A list of four layers.
  """
  rate = dropout if act_dropout is None else act_dropout
  widen = tl.Dense(d_ff, use_bfloat16=use_bfloat16)
  drop = tl.Dropout(rate=rate, shared_axes=[-2], mode=mode)
  nonlinearity = activation()
  narrow = tl.Dense(d_model, use_bfloat16=use_bfloat16)
  return [widen, drop, nonlinearity, narrow]


def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           center_layernorm,
                           mode,
                           use_bfloat16=False,
                           ff_sparsity_type='1inN'):
  """Feed-Forward block with all the options.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    ff_activation: Type of activation function at the end of each block; must be
      an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies sru size,
      32 if not given)
    ff_sparsity: int, tuple or string; if not 0 (or otherwise empty), use
      sparse feed-forward block with this sparsity. For `'1inN'` a pair is
      `(n_elements_in_block, d_lowrank)` and a 4-tuple additionally gives
      `temperature` and `quant_prob`; a string is split on whitespace into
      such a sequence.
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    mode: If `'train'`, each block will include dropout; else, it will pass all
      values through unaltered.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
      use SwitchSparseFF if ff_sparsity_type=`'Switch'`

  Returns:
    A list of layers which maps vectors to vectors.
  """
  # Decide once whether SparseFF is used, so that layer selection and the
  # dropout/chunking decision below can never disagree (e.g. for
  # ff_sparsity=None, which selects the dense block but is not `== 0`).
  use_1inn_sparse_ff = bool(ff_sparsity) and ff_sparsity_type == '1inN'

  if use_1inn_sparse_ff:
    temperature, quant_prob = 0.1, 0.3
    if isinstance(ff_sparsity, str):
      # This is hacky but used to pass ff_sparsity in yaml sweep files.
      ff_sparsity = [(float(x) if '.' in x else int(x))
                     for x in ff_sparsity.split()]
    if isinstance(ff_sparsity, (list, tuple)):
      if len(ff_sparsity) == 2:
        n_elements_in_block, d_lowrank = ff_sparsity
      else:
        n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity
    else:
      assert isinstance(ff_sparsity, int)
      n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity
    ff = SparseFF(
        d_ff,
        n_elements_in_block=n_elements_in_block,
        d_lowrank=d_lowrank,
        temperature=temperature,
        quant_prob=quant_prob,
        use_bfloat16=use_bfloat16,
        mode=mode,
        dropout_rate=dropout,
        dropout_shared_axes=dropout_shared_axes,
        ff_chunk_size=ff_chunk_size)
  elif ff_sparsity and ff_sparsity_type == 'Block':
    ff = BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
  elif ff_sparsity and ff_sparsity_type == 'Switch':
    ff = SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
  else:
    # No sparsity requested (or an unrecognized sparsity type): dense block.
    ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                      use_bfloat16, mode)
  res = [tl.LayerNorm(center=center_layernorm), ff]
  if not use_1inn_sparse_ff:
    # SparseFF has Dropout and BatchLeadingAxes built-in.
    res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                          mode=mode))
    if ff_chunk_size > 0:
      res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))
  if ff_use_sru:
    if isinstance(ff_use_sru, (list, tuple)):
      sru_n_layers, sru_n_units = ff_use_sru
    else:
      sru_n_layers, sru_n_units = ff_use_sru, 32
    sru = [SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]
    block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)
             ] + sru + [tl.Dense(d_model)]
    # The SRU stack is the residual body; the FF block is the shortcut path.
    res = tl.Residual(block, shortcut=res)
  return [res]


def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,
                        masked, attention_dropout, output_dropout,
                        attention_chunk_size, mode):
  """Builds an attention layer of the given type and wraps it in `tl.Chunk`.

  The full keyword set (including `d_qk`/`d_v`) is tried first; if the
  constructor raises `TypeError`, the simpler
  `(d_model, n_heads, dropout, mode)` form is used instead.
  """
  full_kwargs = dict(
      n_heads=n_heads,
      d_qk=d_qk,
      d_v=d_v,
      causal=causal,
      masked=masked,
      output_dropout=output_dropout,
      attention_dropout=attention_dropout,
      mode=mode)
  try:
    attention = attention_type(**full_kwargs)
  except TypeError:  # No d_qk arguments in less advanced layers.
    simple_kwargs = dict(n_heads=n_heads, dropout=attention_dropout, mode=mode)
    attention = attention_type(d_model, **simple_kwargs)
  return tl.Chunk(attention, attention_chunk_size)


def PositionalEncoder(mode,
                      dropout=None,
                      max_len=None,
                      pos_type=None,
                      pos_axial_shape=None,
                      pos_d_axial_embs=None,
                      pos_start_from_zero_prob=1.0,
                      pos_max_offset_to_add=0,
                      use_bfloat16=False):
  """Returns the positional encoding layer depending on the arguments.

  Args:
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
      block will include dropout; else, it will pass all values through
      unaltered.
    dropout: Stochastic rate (probability) for dropping an activation
      value when applying dropout after the embedding block.
    max_len: Maximum symbol length for positional encoding.
    pos_type: string, the type of positional embeddings to use. Empty/None
      selects the default `tl.PositionalEncoding`; `'sin-cos'`,
      `'fixed-base'`, `'infinite'`, `'infinite-affine'` and `'time-bin'`
      select the matching layers; any other value selects axial encoding.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from 0 during training,
      (if 1.0, we always start from position 0, if less, we randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A layer that will do the positional encoding.
  """
  if not pos_type:
    return tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,
        start_from_zero_prob=pos_start_from_zero_prob,
        max_offset_to_add=pos_max_offset_to_add, mode=mode)
  if pos_type == 'sin-cos':
    return tl.SinCosPositionalEncoding(mode=mode)
  if pos_type == 'fixed-base':
    return tl.FixedBasePositionalEncoding(mode=mode)
  if pos_type == 'infinite':
    return tl.InfinitePositionalEncoding(affine=False)
  if pos_type == 'infinite-affine':
    return tl.InfinitePositionalEncoding()
  if pos_type == 'time-bin':
    return tl.TimeBinPositionalEncoding()

  # Every remaining pos_type is treated as axial positional encoding.
  assert pos_d_axial_embs is not None
  n_axes = len(pos_axial_shape)
  return tl.AxialPositionalEncoding(
      shape=pos_axial_shape, d_embs=pos_d_axial_embs,
      dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
      dropout=dropout, mode=mode)


def EmbeddingAndPositionalEncodings(input_vocab_size,
                                    d_model,
                                    mode,
                                    embedding_dropout,
                                    dropout_shared_axes,
                                    max_len,
                                    output_vocab_size=None,
                                    pos_type=None,
                                    pos_axial_shape=None,
                                    pos_d_axial_embs=None,
                                    pos_start_from_zero_prob=1.0,
                                    pos_max_offset_to_add=0,
                                    use_bfloat16=False):
  """Builds the embedding + positional-encoding stacks for inputs and outputs.

  Each stack is an embedder (an `Embedding`, or a `Dense` projection when the
  corresponding vocabulary size is `None`), followed by dropout, followed by a
  freshly created positional encoder. The input-side stack never runs in
  `'predict'` mode; it is built in `'eval'` mode instead.

  Args:
    input_vocab_size: Vocabulary size for the inputs -- each input element
      should be an integer in `range(input_vocab_size)`, typically a token ID.
      If `None`, the inputs are projected with a `Dense` layer rather than
      looked up in an embedding table.
    d_model: Depth of the vectors produced by both stacks.
    mode: One of `'train'`, `'eval'` or `'predict'`. Used as-is for the output
      stack; the input stack substitutes `'eval'` for `'predict'`.
    embedding_dropout: Dropout rate applied right after the embedder; the same
      rate is also handed to the positional encoder.
    dropout_shared_axes: Tensor axes along which the post-embedding dropout
      mask is shared.
    max_len: Maximum symbol length for positional encoding.
    output_vocab_size: Vocabulary size for the targets. If `None`, the output
      stack reuses the very same embedder layers as the input stack and the
      returned vocabulary size is `input_vocab_size`.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from position 0 during
      training (1.0 means always; lower values randomize the start).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A tuple of (input encoder, output encoder, output vocab size used).
  """
  def _make_embedder(vocab_size, embedding_mode):
    """Returns [embedder, dropout] mapping tokens (or features) to vectors."""
    if vocab_size is None:
      to_vectors = tl.Dense(d_model, use_bfloat16=use_bfloat16)
    else:
      to_vectors = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16)
    post_dropout = tl.Dropout(rate=embedding_dropout,
                              shared_axes=dropout_shared_axes,
                              mode=embedding_mode)
    return [to_vectors, post_dropout]

  def _make_positions(positions_mode):
    """Returns a new positional encoder; these are never shared."""
    return PositionalEncoder(positions_mode,
                             dropout=embedding_dropout,
                             max_len=max_len,
                             pos_type=pos_type,
                             pos_axial_shape=pos_axial_shape,
                             pos_d_axial_embs=pos_d_axial_embs,
                             pos_start_from_zero_prob=pos_start_from_zero_prob,
                             pos_max_offset_to_add=pos_max_offset_to_add,
                             use_bfloat16=use_bfloat16)

  # The encoder side does not run stepwise, so 'predict' is downgraded to
  # 'eval' for everything built for the inputs.
  encoder_mode = mode if mode != 'predict' else 'eval'
  in_embedder = _make_embedder(input_vocab_size, encoder_mode)
  in_encoder = in_embedder + [_make_positions(encoder_mode)]

  assert input_vocab_size or output_vocab_size
  if output_vocab_size is None:
    # Same layer objects on both sides; also record the vocab size that is
    # effectively used for the outputs.
    out_embedder = in_embedder
    output_vocab_size = input_vocab_size
  else:
    out_embedder = _make_embedder(output_vocab_size, mode)
  out_encoder = out_embedder + [_make_positions(mode)]

  # Wrap both stacks in shape assertions. Without an input vocabulary the
  # input stack consumes feature vectors, hence the different spec.
  in_spec = '...->...d' if input_vocab_size is not None else '...a->...b'
  in_encoder = tl.AssertFunction(in_spec, in_encoder)
  out_encoder = tl.AssertFunction('...->...d', out_encoder)

  return in_encoder, out_encoder, output_vocab_size

In [None]:
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,
                 attention_chunk_size, n_attention_layers=1,
                 n_feedforward_layers=1, center_layernorm=True,
                 use_bfloat16=False, mode='train'):
  """Builds the layers of one reversible decoder block.

  The block is `n_attention_layers` attention half-residuals followed by
  `n_feedforward_layers` feed-forward half-residuals; every half-residual is
  immediately followed by a `ReversibleSwap`.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: number of residual attention layers placed before the
      feed-forward part (default: 1, the standard block)
    n_feedforward_layers: number of feed-forward layers (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    A list with one entry per half-residual; each entry is itself a two-item
    list `[ReversibleHalfResidual, ReversibleSwap]`.
  """
  attention_parts = []
  for _ in range(n_attention_layers):
    pre_norm = tl.LayerNorm(center=center_layernorm)
    # Arguments are passed positionally, exactly as ApplyAttentionLayer
    # expects them.
    attn = ApplyAttentionLayer(
        attention_type, d_model, n_heads, d_attention_key,
        d_attention_value, True, False, dropout, dropout,
        attention_chunk_size, mode)
    half_residual = tl.ReversibleHalfResidual(
        pre_norm,
        attention_layer=attn,
        name='ReversibleHalfResidualDecoderAttn')
    attention_parts.append([half_residual, tl.ReversibleSwap()])

  feed_forward_parts = []
  for _ in range(n_feedforward_layers):
    ff = FeedForwardWithOptions(
        d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
        ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
        mode, use_bfloat16)
    half_residual = tl.ReversibleHalfResidual(
        ff, name='ReversibleHalfResidualDecoderFF')
    feed_forward_parts.append([half_residual, tl.ReversibleSwap()])

  return attention_parts + feed_forward_parts


def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 attention_chunk_size=0, center_layernorm=True,
                 use_bfloat16=False, use_two_swaps_per_block=True,
                 mode='train'):
  """Builds the layers of one Terraformer encoder block.

  The block consumes a pair, (activations, mask), where the mask was derived
  from the original source tokens so that attention ignores the padding part
  of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True the feed-forward half-residual is
      followed by a second reversible swap; otherwise the block has only the
      swap that follows attention.
    mode: str: 'train' or 'eval'; 'predict' is treated as 'eval'.

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  # 'predict' asks the decoder to run token by token; the encoder always sees
  # whole sequences, so it is built in 'eval' mode in that case.
  block_mode = 'eval' if mode == 'predict' else mode

  d_head = d_model // n_heads
  attention = ApplyAttentionLayer(
      attention_type=attention_type, d_model=d_model, n_heads=n_heads,
      d_qk=d_head, d_v=d_head, masked=True, causal=False,
      attention_dropout=dropout, output_dropout=dropout,
      attention_chunk_size=attention_chunk_size, mode=block_mode)
  if attention.n_out == 2:
    # Two-output attention: reshape the mask going in, and keep only the
    # first output coming out.
    attention = tl.Serial(
        tl.Parallel([], _InsertAxes12()),
        attention,
        tl.Select([0], n_in=2)
    )

  block = [
      tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),
                                attention_layer=attention,
                                name='ReversibleHalfResidualEncoderAttn'),
      tl.ReversibleSwap(),
  ]

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
      block_mode, use_bfloat16)
  block.append(
      tl.ReversibleHalfResidual(feed_forward,
                                name='ReversibleHalfResidualEncoderFF'))
  if use_two_swaps_per_block:
    block.append(tl.ReversibleSwap())

  return block

In [None]:
# Expected argument shapes: vec_e (B, L1, H), vec_d (B, L2, H), mask_e (B, L1).
def _ConcatWithPadding(vec_e, vec_d, mask_e):
  """Writes each decoder row right after the real part of its encoder row.

  For every batch element, the encoder vectors are extended with L2 zero rows
  and then L2 rows, starting at the index given by the sum of that element's
  encoder mask, are overwritten with the decoder vectors. See the
  ConcatWithPadding layer for how this is used.

  Args:
    vec_e: encoder vectors, shape (B, L1, H).
    vec_d: decoder vectors, shape (B, L2, H).
    mask_e: encoder mask, shape (B, L1).

  Returns:
    The result of mapping the per-example merge over the batch; each example
    has shape (L1 + L2, H).

  Raises:
    ValueError: if `vec_d` or `mask_e` does not have the shape implied by
      `vec_e`.
  """
  # pylint: disable=invalid-name
  B, L1, H = vec_e.shape
  L2 = vec_d.shape[1]
  # pylint: enable=invalid-name

  if vec_d.shape != (B, L2, H):
    raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'
                     f' equal {(B, L2, H)}.')
  if mask_e.shape != (B, L1):
    raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'
                     f' equal {(B, L1)}.')

  def _merge_example(example):
    # enc_row: (L1, H), dec_row: (L2, H), enc_mask_row: (L1,)
    enc_row, dec_row, enc_mask_row = example
    # Number of unmasked encoder positions == where the decoder part begins.
    start_row = jnp.sum(enc_mask_row, dtype=jnp.int32)
    # Both start indices must share a dtype (avoids int32/int64 mismatch).
    start_col = jnp.array(0, dtype=start_row.dtype)
    # Room for the decoder: (L1 + L2, H), zero-filled after the encoder rows.
    widened = jnp.concatenate([enc_row, jnp.zeros_like(dec_row)], axis=0)
    return fastmath.dynamic_update_slice(
        widened, dec_row, (start_row, start_col))

  return fastmath.map(_merge_example, [vec_e, vec_d, mask_e])


def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):
  """Extracts the decoder part from a padded encoder+decoder concatenation.

  For every batch element, the number of non-zero encoder tokens gives the row
  at which the decoder part starts; an (L2, H) slice is taken from there. See
  the StripFromConcatenateWithPadding layer for how this is used.

  Args:
    vec_ed: concatenated vectors, shape (B, L, H) with L == L1 + L2.
    tok_e: encoder tokens, shape (B, L1); zeros are treated as padding.
    tok_d: decoder tokens, shape (B, L2); only the shape is used.

  Returns:
    The result of mapping the per-example slice over the batch; each example
    has shape (L2, H).

  Raises:
    ValueError: if the lengths do not add up or the token arrays do not have
      the batch size of `vec_ed`.
  """
  # pylint: disable=invalid-name
  B, L, H = vec_ed.shape
  L1 = tok_e.shape[1]
  L2 = tok_d.shape[1]
  # pylint: enable=invalid-name
  if L != L1 + L2:
    raise ValueError(f'Length from encoder-decoder vectors ({L}) does not'
                     f' equal sum of lengths from encoder ({L1}) and decoder'
                     f' ({L2}).')
  if tok_e.shape != (B, L1):
    raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'
                     f' equal {(B, L1)}.')
  if tok_d.shape != (B, L2):
    raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'
                     f' equal {(B, L2)}.')

  def _slice_example(example):
    # ed_row: (L, H), enc_tok_row: (L1,); the decoder token row is not needed.
    ed_row, enc_tok_row, _ = example
    # Count of non-padding encoder tokens == first row of the decoder part.
    start_row = jnp.sum(enc_tok_row != 0, dtype=jnp.int32)
    # Both start indices must share a dtype (avoids int32/int64 mismatch).
    start_col = jnp.array(0, dtype=start_row.dtype)
    return fastmath.dynamic_slice(ed_row, (start_row, start_col), (L2, H))

  return fastmath.map(_slice_example, [vec_ed, tok_e, tok_d])


class StripFromConcatenateWithPadding(tl.Layer):
  """Layer that drops the leading encoder part of a concatenated array.

  Takes (vec_ed, tok_e, tok_d) and returns a single array. The layer state
  doubles as a "first call already happened" flag: it starts as a scalar
  (shape `()`) and is reshaped to `(1,)` the first time stripping is done.
  """

  def __init__(self, mode='train'):
    super().__init__(n_in=3, n_out=1)
    self._mode = mode

  def init_weights_and_state(self, input_signature):
    """Initializes the state to a scalar, i.e. "no step taken yet"."""
    del input_signature
    self.state = jnp.array(0, dtype=jnp.int32)

  def forward(self, inputs):
    vec_ed, tok_e, tok_d = inputs

    # Predict mode after the first step: a non-empty state shape means the
    # first step is over. No concatenation took place on such steps, so the
    # incoming array already is the decoder part and is passed through.
    if self._mode == 'predict' and self.state.shape:
      return vec_ed

    # Training/eval, or the first predict step: strip for real and flag the
    # state so that its shape is truthy from now on.
    self.state = self.state.reshape((1,))
    return _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)


class ConcatWithPadding(tl.ReversibleLayer):
  """Reversible layer concatenating two length-padded (B, L, H) arrays.

  The two arrays may have different lengths. Inputs are
  (vec_e, vec_d, mask_e, tok_e, tok_d); outputs are (vec_ed, tok_e, tok_d).
  The layer state doubles as a "first call already happened" flag: it starts
  as a scalar (shape `()`) and is reshaped to `(1,)` on the first
  concatenation.
  """

  def __init__(self, mode='train'):
    super().__init__(n_in=5, n_out=3)
    self._mode = mode

  def init_weights_and_state(self, input_signature):
    """Initializes the state to a scalar, i.e. "no step taken yet"."""
    del input_signature
    self.state = jnp.array(0, dtype=jnp.int32)

  def forward(self, inputs):
    vec_e, vec_d, mask_e, tok_e, tok_d = inputs

    # Predict mode after the first step: a non-empty state shape means the
    # first step is over. Nothing is concatenated any more; the decoder
    # vector is forwarded on its own.
    if self._mode == 'predict' and self.state.shape:
      return vec_d, tok_e, tok_d

    # Training/eval, or the first predict step: concatenate and flag the
    # state so that its shape is truthy from now on.
    self.state = self.state.reshape((1,))
    return _ConcatWithPadding(vec_e, vec_d, mask_e), tok_e, tok_d

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    """Recovers the five inputs from the three outputs (not in predict mode)."""
    del state, new_state, rng, weights
    assert self._mode != 'predict', 'cannot reverse in predict mode'
    vecs_ed, toks_e, toks_d = output
    vecs_d = _StripFromConcatenateWithPadding(vecs_ed, toks_e, toks_d)
    # The encoder mask is rebuilt from the tokens: non-zero means real token.
    mask_e = (toks_e != 0)
    keep = mask_e.astype(jnp.float32)[:, :, None]
    n_enc = toks_e.shape[1]
    # Leading encoder-length slice, zeroed at padding positions.
    vecs_e = vecs_ed[:, :n_enc, :] * keep
    return vecs_e, vecs_d, mask_e, toks_e, toks_d


class ConcatWithPadding2(tl.ReversibleLayer):
  """Concatenation with padding applied to a pair of encoder arrays.

  Meant to be combined with rev-nets. Inputs are
  (vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d); outputs are
  (vecs_ed1, vecs_ed2, toks_e, toks_d). The layer state doubles as a "first
  call already happened" flag: it starts as a scalar (shape `()`) and is
  reshaped to `(1,)` on the first concatenation.
  """

  def __init__(self, mode='train'):
    super().__init__(n_in=6, n_out=4)
    self._mode = mode

  def init_weights_and_state(self, input_signature):
    """Initializes the state to a scalar, i.e. "no step taken yet"."""
    del input_signature
    self.state = jnp.array(0, dtype=jnp.int32)

  def forward(self, inputs):
    vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d = inputs

    # Predict mode after the first step: a non-empty state shape means the
    # first step is over. Nothing is concatenated any more; the decoder
    # vector is forwarded in both slots.
    if self._mode == 'predict' and self.state.shape:
      return vecs_d, vecs_d, toks_e, toks_d

    # Training/eval, or the first predict step: flag the state so that its
    # shape is truthy from now on, then concatenate the same decoder vectors
    # onto each member of the encoder pair.
    self.state = self.state.reshape((1,))
    vecs_ed1 = _ConcatWithPadding(vecs_e1, vecs_d, mask_e)
    vecs_ed2 = _ConcatWithPadding(vecs_e2, vecs_d, mask_e)
    return vecs_ed1, vecs_ed2, toks_e, toks_d

  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    """Recovers the six inputs from the four outputs (not in predict mode)."""
    del state, new_state, rng, weights
    assert self._mode != 'predict', 'cannot reverse in predict mode'
    vecs_ed1, vecs_ed2, toks_e, toks_d = output
    # The decoder part is read back from the first concatenated array only.
    vecs_d = _StripFromConcatenateWithPadding(vecs_ed1, toks_e, toks_d)
    # The encoder mask is rebuilt from the tokens: non-zero means real token.
    mask_e = (toks_e != 0)
    keep = mask_e.astype(jnp.float32)[:, :, None]
    n_enc = toks_e.shape[1]
    # Leading encoder-length slices, zeroed at padding positions.
    vecs_e1 = vecs_ed1[:, :n_enc, :] * keep
    vecs_e2 = vecs_ed2[:, :n_enc, :] * keep
    return vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d

In [None]:
def Terraformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                d_attention_key=None,
                d_attention_value=None,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                encoder_attention_type=tl.SelfAttention,
                encoder_decoder_attention_type=tl.SelfAttention,
                pos_type='fixed-base',
                pos_axial_shape=(),
                pos_d_axial_embs=None,
                pos_start_from_zero_prob=1.0,
                pos_max_offset_to_add=0,
                ff_activation=tl.Relu,
                ff_use_sru=(1, 32),
                ff_chunk_size=0,
                ff_dropout=None,
                ff_sparsity=32,
                loss_sparsity_type='mult',
                loss_sparsity=0,
                loss_d_lowrank=0,
                loss_sparsity_prob=None,
                attention_chunk_size=0,
                n_layers_forget=0,
                forget_dense=True,
                n_decoder_attention_layers=2,
                use_bfloat16=False,
                reversible_encoder=False,
                use_two_swaps_per_encoder_block=True,
                center_layernorm=True,
                half_before_layer=None,
                double_after_layer=None,
                mode='train'):
  """Returns a highly configurable Terraformer encoder-decoder model.

  This model maps paired text sequences (source and target) to float-valued
  losses. If ``input_vocab_size`` is not ``None``, the layer takes
  two input sequences:

    - inputs (2):

        - source: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  If ``input_vocab_size`` is ``None``, the layer takes three input sequences:

    - inputs (3):

        - source: 3-D float array representing a batch of already-embedded text
          strings; shape is `(batch_size, sequence_length, d_model)`, where
          sequence_length <= ``max_len``.

        - mask: 2-D int array representing active versus masked positions; 0
          values mark masked (padding) positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  In ``'predict'`` mode, a few module-level classes/functions are temporarily
  monkey-patched while the decoder blocks are constructed; the patches are
  always reverted afterwards, even if construction raises.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    d_attention_key: Depth of key vectors in each attention head.
    d_attention_value: Depth of value vectors in each attention head.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    max_len: Maximum symbol length for positional encoding.
    encoder_attention_type: Type of attention to use in the encoder; must be
        an attention-type subclass of :py:class:`trax.layers.Layer`.
    encoder_decoder_attention_type: Type of attention to use in the decoder;
        must be an attention-type subclass of :py:class:`trax.layers.Layer`.
        May also be a tuple/list of such types, which is then cycled through
        over the decoder layers; its length must divide ``n_decoder_layers``.
    pos_type: String indicating the type of positional embeddings to use.
    pos_axial_shape: Shape (tuple of ints) to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: Tuple of ints specifying the depth of position embedding
        for each axis. Tuple length must match ``pos_axial_shape``, and values
        must sum to ``d_model``.
    pos_start_from_zero_prob: Stochastic rate (probability) for starting
        positional encoding at position 0 during training. If 1.0, always start
        from position 0; if < 1.0, the non-zero starts will be uniformly
        distributed up to ``pos_max_offset_to_add``.
    pos_max_offset_to_add: Maximum offset to add to positions during training
        when randomizing. This offset plus input length must be less than
        ``max_len`` for all training examples.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of :py:class:`trax.layers.Layer`.
    ff_use_sru: If > 0, use this number of SRU layers in place of feedforward
        layers.
    ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this
        size.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        at feedforward nonlinearities.
    ff_sparsity: If > 0, use sparse feedforward blocks with this level of
        sparsity.
    loss_sparsity_type: String indicating the type of sparsity to use in the
        loss layer; see :py:class:`SparseDenseWithOptions` for options. If
        ``None``, use no sparsity.
    loss_sparsity: If > 0, use this level of sparsity in the loss layer.
    loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this
        dimension, in the loss.
    loss_sparsity_prob: Stochastic rate (probability) for using the sparse
        version of the loss. If ``None``, use the sparse version exclusively.
    attention_chunk_size: If > 0, compute attention using chunks of this size.
    n_layers_forget: How often to have a forgetting block between layers.
    forget_dense: If True, use :py:class:`Dense` instances as forget layers;
        else use no-ops.
    n_decoder_attention_layers: Number of attention layers in a decoder block.
    use_bfloat16: If True, use bfloat16 for weights; else use float32.
    reversible_encoder: If True, make the encoder be reversible.
    use_two_swaps_per_encoder_block: If True, ensure that there is an even
        number of swaps across the encoder.
    center_layernorm: If True, use centering in :py:class:`LayerNorm` (the
        default); else omit centering (which is known as RMS normalization).
    half_before_layer: If not None, specifies an n'th layer such that all
        layers before the n'th use half the normal values for ``d_model`` and
        ``d_ff`` (and for the attention key/value depths). Halves are computed
        with integer floor division so that layer sizes stay integers.
    double_after_layer: If not None, specifies an n'th layer such that all
        layers after the n'th use double the normal values for ``d_model`` and
        ``d_ff``.
    mode: If ``'train'``, include dropout in each encoder/decoder block; else
        dropout layers have no effect.

  Returns:
    A Terraformer encoder-decoder as a layer that maps from target and source
    text sequences to a scalar loss.

  Raises:
    ValueError: If ``n_heads`` does not divide ``d_model / 2``.
  """
  if mode == 'predict':
    portal_mask = _PortalInput()
  else:
    portal_mask = None

  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    # Floor division keeps these sizes integral; true division would hand
    # float dimensions to the layer constructors below.
    d_model1, d_ff1 = d_model // 2, d_ff // 2
    d_attention_key1 = d_attention_key // 2
    d_attention_value1 = d_attention_value // 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16)
  )

  def _EncoderBlock():
    return EncoderBlock(
        d_model1,
        d_ff1,
        n_heads,
        encoder_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        use_two_swaps_per_block=use_two_swaps_per_encoder_block,
        mode=mode)

  def _Encoder():  # vec_e mask_e tok_e vec_d tok_d
    layers = [
        tl.ReversibleSelect([0, 0]),
        _ReversibleSerialForget(
            [_EncoderBlock() for _ in range(n_encoder_layers)],
            d_model1,
            n_layers_forget,
            forget_dense)
    ]
    if not reversible_encoder:
      layers += [
          _XYAvg(),
          tl.Dense(d_model1, use_bfloat16=use_bfloat16),
          tl.LayerNorm(),
      ]
    if mode == 'predict':
      return tl.Cache(tl.Serial(layers))
    else:
      return tl.Serial(layers)

  if mode == 'predict':
    # Temporarily monkey-patch mask providers so that decoder layers built
    # below pick up the portal mask; reverted in the `finally` clause.
    global DotProductCausalAttention
    DotProductCausalAttention.monkey_patched_mask = (
        lambda x: portal_mask)
    global _RememberPad
    _RememberPad.monkey_patched_mask = (  # pylint: disable=protected-access
        lambda x: portal_mask)
    global ScanSRUCell
    originalScanSRUCell = ScanSRUCell
    ScanSRUCell = functools.partial(ScanSRUCell,
                                    monkey_patched_mask=portal_mask)

  decoder_blocks = []

  try:
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
      assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
      encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
      layer_attention_type = encoder_decoder_attention_type[
          layer_idx % len(encoder_decoder_attention_type)]
      # Grow d_model, d_ff, and d_qkv if requested.
      d_m, d_f, d_k, d_v = (
          d_model1, d_ff1, d_attention_key1, d_attention_value1)
      if half_before_layer and layer_idx >= half_before_layer:
        d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
      if double_after_layer and layer_idx > double_after_layer:
        d_m, d_f, d_k, d_v = (
            d_model2, d_ff2, d_attention_key2, d_attention_value2)
      decoder_block = DecoderBlock(
          d_m, d_f, d_k, d_v, n_heads,
          attention_type=layer_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          n_attention_layers=n_decoder_attention_layers,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          mode=mode)
      decoder_blocks.append(decoder_block)
      if half_before_layer and layer_idx == half_before_layer - 1:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
      if double_after_layer and layer_idx == double_after_layer:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
  finally:
    if mode == 'predict':
      # After initializing the decoder (or failing to) we revert to the
      # original state of previously monkey-patched classes/functions, so a
      # construction error cannot leave the module globals patched.
      DotProductCausalAttention.monkey_patched_mask = (
          lambda x: None)
      _RememberPad.monkey_patched_mask = (lambda x: None)  # pylint: disable=protected-access
      ScanSRUCell = originalScanSRUCell

  def _Loss():
    return SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

  def _enc_dec_concat():
    """Layers to merge encoder and decoder."""
    if reversible_encoder:
      return [
          # v_e1 v_e2 v_d mask_e tok_e tok_d
          tl.ReversibleSelect([0, 1, 4, 2, 3]),
          ConcatWithPadding2(mode=mode),      # v_ed v_ed tok_e tok_d
      ]
    else:
      return [
          tl.ReversibleSelect([0, 3, 1, 2]),     # v_e v_d mask_e tok_e tok_d
          ConcatWithPadding(mode=mode),       # v_ed tok_e tok_d
          tl.ReversibleSelect([0, 0]),           # v_ed v_ed tok_e tok_d
      ]

  def _inp_layers():
    if input_vocab_size is not None:
      return tl.AssertFunction(
          'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
          tl.Serial(  # tok_e tok_d
              tl.Select([0, 0, 0, 1]),
              tl.Parallel(in_encoder, [tl.PaddingMask(),
                                       _RemoveAxes12()])
          ))  # vec_e mask_e tok_e tok_d
    else:
      # Input in this case is vec_e, mask_e, tok_d. Where all downstream
      # operations expect tok_e, we give it instead mask_e, expecting that
      # downstream ops only are looking for padding/not padding.
      return tl.AssertFunction(
          'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
          tl.Serial(  # vec_e mask_e tok_d
              tl.Select([0, 1, 1, 2]),
              tl.Parallel(in_encoder, [], _AsTokenIDs())
          ))  # vec_e mask_e tok_e tok_d

  # Assemble and return the model.
  return tl.Serial(
      _inp_layers(),               # vec_e mask_e tok_e tok_d
      tl.Parallel([], portal_mask),

      tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),
                               out_encoder]),  # vec_e mask_e tok_e vec_d tok_d

      # Encode; then concat encoder and decoder, given encoder mask.
      _Encoder(),                             # vec_e mask_e tok_e vec_d tok_d
      _enc_dec_concat(),

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                              forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
      _XYAvg(),                               # vec_ed tok_e tok_d
      tl.LayerNorm(),                         # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector,
      # then compute loss.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
      _Loss(),  # vec_d tok_d
  )


def _InsertAxes12():
  """Returns a layer reshaping `x` to `(x.shape[0], 1, 1, x.shape[1])`.

  The two size-1 axes are inserted at positions 1 and 2; `_RemoveAxes12`
  performs the inverse operation.
  """
  def _insert_axes(x):
    first_dim, second_dim = x.shape[0], x.shape[1]
    expanded_shape = (first_dim, 1, 1, second_dim)
    return jnp.reshape(x, expanded_shape)

  return tl.Fn('InsertAxes12', _insert_axes)


def _RemoveAxes12():
  """Returns a layer that squeezes out axes 1 and 2 of its input.

  Both axes are expected to have size 1 (e.g. as produced by
  `_InsertAxes12`).
  """
  def _squeeze_axes(x):
    return jnp.squeeze(x, axis=(1, 2))

  return tl.Fn('RemoveAxes12', _squeeze_axes)


def _AsTokenIDs():
  """Returns a layer that casts its input to int32.

  Used so that mask values can stand in where integer token IDs are expected.
  """
  def _cast_to_int32(x):
    return x.astype(jnp.int32)

  return tl.Fn('AsTokenIDs', _cast_to_int32)


def _XYAvg():
  """Returns a two-input layer computing the element-wise mean `(x + y) / 2`."""
  def _mean_of_pair(x, y):
    pair_sum = x + y
    return pair_sum / 2.0

  return tl.Fn('XYAvg', _mean_of_pair)


def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):
  """Chains `layers` reversibly, with a forgetting block every `n_layers`.

  Args:
    layers: Sequence of layers to combine.
    d_model: Output size of the Dense layer inside the forgetting block; only
        used when `forget_dense` is true.
    n_layers: Number of layers in each reversible chunk. A falsy value (0 or
        None) disables forgetting altogether.
    forget_dense: If true, the forgetting block averages the two streams,
        applies a Dense layer and duplicates the result; otherwise it is a
        pass-through `Select([0, 1])`.

  Returns:
    A plain `tl.ReversibleSerial` over all layers when no split is needed
    (forgetting disabled, or at most `n_layers + 1` layers remain); otherwise
    a `tl.Serial` of the first reversible chunk, the forgetting block and the
    recursively built remainder.
  """
  needs_split = bool(n_layers) and len(layers) > n_layers + 1
  if not needs_split:
    return tl.ReversibleSerial(layers)

  if forget_dense:
    forgetting_layer = tl.Serial(_XYAvg(), tl.Dense(d_model), tl.Dup())
  else:
    forgetting_layer = tl.Select([0, 1])

  first_chunk = tl.ReversibleSerial(layers[:n_layers])
  # The remaining layers are handled recursively, so another forgetting block
  # is inserted after each further group of `n_layers` layers.
  remainder = _ReversibleSerialForget(
      layers[n_layers:], d_model, n_layers, forget_dense)
  return tl.Serial(first_chunk, forgetting_layer, remainder)


def _ConvertToNaNsOnAnyZero():
  """Returns a layer that poisons `x` when `y` contains any zero.

  The layer takes two inputs `(x, y)` and returns two outputs. If every value
  in `y` is non-zero, `x` is passed through unchanged; otherwise `x` is
  replaced by `x / 0.`, i.e. non-finite values (not zeros). `y` is always
  returned untouched as the second output.
  """
  def _poison_if_any_zero(x, y):
    all_nonzero = jnp.all(y, keepdims=False)
    # Division by zero deliberately yields non-finite values so that a zero
    # in `y` cannot go unnoticed downstream.
    poisoned = x/0.
    return jnp.where(all_nonzero, x, poisoned), y

  return tl.Fn('ConvertToNaNsOnAnyZero', _poison_if_any_zero, n_out=2)


class _PortalInput(tl.Layer):
  """Portal input for monkey-patching of mask in predict mode.

  An identity layer that additionally stores the value passing through it in
  its state, so that the paired `_PortalOutput` layer can emit that value
  elsewhere in the model.
  """

  def __init__(self):
    super().__init__(name='_PortalInput', n_in=1, n_out=1)
    # The paired output layer keeps a reference back to this layer.
    self._portal_output = _PortalOutput(self)

  def init_weights_and_state(self, input_signature):
    """Initializes the state to zeros shaped like the (first) input."""
    is_sequence = isinstance(input_signature, (list, tuple))
    signature = input_signature[0] if is_sequence else input_signature
    self.state = (jnp.zeros(signature.shape),)

  def forward(self, x):
    """Stores `x` (or its first element, if a list/tuple) and returns it."""
    value = x[0] if isinstance(x, (list, tuple)) else x
    self.state = (value,)
    return value

  def get_value(self):
    """Returns the most recently stored value."""
    (stored,) = self.state[:1]
    return stored

  def get_layer(self):
    """Returns the paired `_PortalOutput` layer."""
    return self._portal_output


class _PortalOutput(tl.Layer):
  """Portal output for monkey-patching of mask in predict mode.

  A zero-input, one-output layer that emits whatever value the paired
  `_PortalInput` layer currently holds.
  """

  def __init__(self, portal_input):
    super().__init__(name='_PortalOutput', n_in=0, n_out=1)
    self._portal_input = portal_input

  def get_value(self):
    """Returns the value currently held by the paired portal input."""
    source = self._portal_input
    return source.get_value()

  def forward(self, x):
    """Ignores `x` and returns the paired portal input's value."""
    del x  # This layer consumes no inputs.
    return self.get_value()

## Example training

Here we show how the Terraformer can be trained on example inputs. The results for the paper were obtained with identical training but for different configurations of inputs and models, which are specified in the attached config files.

In [None]:
# Build a small Terraformer for the demo run below.
# NOTE(review): input_vocab_size (12) is larger than the vocab_size (10) used
# for the synthetic inputs further down -- presumably to leave room for
# special token ids; confirm.
model = Terraformer(
    input_vocab_size=12,
    # small model for testing
    d_model=128,
    d_ff=512,
    n_encoder_layers=2,
    n_decoder_layers=2,
    # setting sparsity
    # NOTE(review): the exact meaning of the sparsity values below is defined
    # by Terraformer and the sparse layers, which are not visible here.
    ff_use_sru=(1, 32),
    ff_sparsity=32,
    loss_sparsity=4,
    # Attention class with its sparsity and kernel-size arguments pre-bound.
    encoder_decoder_attention_type=functools.partial(
        MultiplicativeConvCausalAttention, sparsity=16, length_kernel_size=3),
    )

# Example inputs from trax's `simple_sequence_copy_inputs` helper: batches of
# 32 sequences, train length 32, eval lengths between 16 and 32.
copy_inputs = trax.data.inputs.simple_sequence_copy_inputs(
    vocab_size=10, batch_size=32, train_length=32,
    eval_min_length=16, eval_max_length=32)

# Training task: weighted cross-entropy loss optimized with Adam, with a
# checkpoint every 5 steps.
# NOTE(review): the argument to train_stream/eval_stream is presumably the
# number of devices -- confirm against the trax inputs API.
train_task = training.TrainTask(
    labeled_data=copy_inputs.train_stream(1),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.0001),
    n_steps_per_checkpoint=5,
)

# Evaluation task: reports weighted cross-entropy and accuracy on the eval
# stream.
eval_task = training.EvalTask(
    labeled_data=copy_inputs.eval_stream(1),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=2  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
# Notebook shell escape: delete any previous output directory so that the run
# starts from scratch.
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 20 steps (batches); pass a larger number to train for longer.
training_loop.run(20)