In [None]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Scaling Transformers - Sparse Is Enough

This colab contains all relevant code for the paper "Sparse is Enough in Scaling Transformers". It depends on the Trax library. Note that the experiments in the paper were not run from this colab but in a distributed setup using the attached config files -- with the same code as below.

# imports

In [None]:
import os
import random
import time
import numpy as np

import trax
from trax import layers as tl
from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.supervised import training
from trax.layers.assert_shape import assert_shape


import copy
import functools
import gc
import os
import time
from jax import test_util  # pylint: disable=unused-import
from jax.config import config
import numpy as np
import psutil
from tensorflow.compat.v2 import test

from trax import fastmath
from trax import layers as tl
from trax import models
from trax import shapes
from trax.supervised import decoding
import gin

In [None]:
from colabtools import adhoc_import
import json
import gc
import jax
import numpy as np
import os
import time
import IPython.display as display
import gin

# Training colab using TRAX and sstables
# The colab has three inspirations
#   - https://colab.corp.le.com/drive/1J48CSaDjMcZSfJtKt_FecxPkdSjbD3up#scrollTo=przDFyGSqKwq
#   - https://colab.corp..com/drive/1yXCCUDCNkJP1es5dwjhwnVOg1dWFjDT_#scrollTo=tRa53wxjKPT9

from colabtools import adhoc_import
import functools

from trax.data import tf_inputs
import tensorflow_datasets as tfds
from t5.data import preprocessors as t5_processors
import t5.data

from trax import data
from trax import layers as tl
from trax import models
from trax import optimizers
from trax.data import inputs
from trax.supervised import lr_schedules
from trax.supervised import trainer_lib
from trax.rl import serialization_utils
from trax.rl import space_serializer
import math
from trax.fastmath import numpy as numpy_math
import trax

In [None]:
from colabtools import adhoc_import
  import trax..mira.mira_data_pipeline

# Positional Encoding overriding

In [None]:
import numpy as np

from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers import initializers as init
from trax.layers.assert_shape import assert_shape
from trax.layers.base import Fn
from trax.layers.research import sparsity

@assert_shape('...d->...d')
class PositionalEncoding(base.Layer):
  """Implements bare positional encoding.

  Positional encoding includes a kind of dropout, if the layer is created in
  ``'train'`` mode with a nonzero ``dropout`` value. For such a layer, on each
  forward pass a subset of sequence positions selected at random will *not*
  receive positional marking.

  When ``d_feature`` is not ``None``, the sinusoid table is built with depth
  ``d_feature`` and a second weight matrix of shape
  ``(d_feature, input_depth)`` maps the selected rows to the input depth.
  """

  def __init__(self, max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2,),
               use_bfloat16=False, start_from_zero_prob=1.0,
               max_offset_to_add=0, d_feature=64, mode='train'):
    """Creates a :py:class:`PositionalEncoding` instance in a given mode.

    Args:
      max_len: Maximum input sequence length.
      dropout: Probability of *not* adding positional encoding to a sequence
          position. Applies only if layer is created in ``'train'`` mode.
      dropout_broadcast_dims: Axes along which dropout mask values are
          broadcast rather than individually set at random.
      use_bfloat16: If ``True``, use bfloat16 weights instead of the default
        float32; this can save memory but may (rarely) lead to numerical issues.
      start_from_zero_prob: how often to start from 0 during training,
          (if 1.0, we always start from position 0, if less, we randomize).
      max_offset_to_add: maximum offset to add to the positions during training
        when randomizing; this offset plus input length must still be less than
        max_len for all training examples.
      d_feature: int or None; have this dimension for embeddings + shared FF if
        not None.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.

    Raises:
      ValueError: If ``dropout`` is 1.0 or greater.
    """
    super().__init__()
    self._max_len = max_len
    if dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')
    if mode == 'train':
      self._dropout = dropout
    else:
      self._dropout = 0.0
    self._dropout_broadcast_dims = dropout_broadcast_dims
    self._use_bfloat16 = use_bfloat16
    self._start_from_zero_prob = start_from_zero_prob
    self._max_offset_to_add = max_offset_to_add
    self._mode = mode
    self._d_feature = d_feature

  def forward(self, inputs):
    """Returns the input activations, with added positional information.

    The rows of the sinusoid table are selected at their true offset *before*
    the optional ``d_feature`` projection is applied. Truncating the table to
    the input length first would make every later offset slice fall out of
    range; dynamic slicing clamps such starts, so position 0 would be reused.

    Args:
      inputs: Activations whose axis 1 is the sequence axis.

    Returns:
      ``inputs`` plus the positional encodings of the covered positions.

    Raises:
      ValueError: In ``'predict'`` mode, if the dropout rate is not zero.
    """
    weights = self.weights
    projection = None
    if self._d_feature is not None:
      weights, projection = weights
    if len(weights.shape) < 3:  # old checkpoints have 1 in first dim already
      weights = weights[None, :, :]  # [1, self._max_len, d_feature]

    def _project(selected):
      # Map the selected rows to the input depth when a projection exists.
      if projection is None:
        return selected
      return jnp.dot(selected, projection)

    if self._mode != 'predict':
      x = inputs
      symbol_size = jnp.shape(x)[1]
      if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
        px = weights[:, :symbol_size, :]
      else:
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
        start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1)
        start = jnp.where(start_from_zero < self._start_from_zero_prob,
                          jnp.zeros((), dtype=jnp.int32), start)
        px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size,
                                           axis=1)
      px = _project(px)
      if self._dropout == 0:
        return x + px
      else:
        noise_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
          noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                         tuple(noise_shape))
        multiplier = keep.astype(x.dtype) / keep_prob
        return x + px * multiplier
    else:
      if self._dropout != 0:
        raise ValueError(f'In predict mode, but dropout rate '
                         f'({self._dropout}) is not zero.')

      # State in this class is only used for fast inference. In that case,
      # the model is called with consecutive elements position-by-position.
      # This positional encoding layer stores the index of the current
      # position and increments it on each call.
      emb = fastmath.dynamic_slice_in_dim(
          weights, self.state, inputs.shape[1], axis=1)
      self.state += inputs.shape[1]
      return inputs + _project(emb)

  def init_weights_and_state(self, input_signature):
    """Randomly initializes the positional encoding vectors.

    Builds the sine/cosine table of shape ``(max_len, d_feature)`` and, when
    ``d_feature`` was given, a Glorot-uniform projection matrix of shape
    ``(d_feature, input_depth)``. In ``'predict'`` mode the state is set to a
    scalar int32 zero (the current position).

    Args:
      input_signature: :py:class:`ShapeDtype` instance characterizing the input
          this layer should compute on.
    """
    d_feature = input_signature.shape[-1]
    if self._d_feature is not None:
      d_feature = self._d_feature
    pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
    position = np.arange(0, self._max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)  # [self._max_len, d_feature]
    if self._use_bfloat16:
      pe = pe.astype(jnp.bfloat16)
    w = jnp.array(pe)  # Trainable parameters, initialized above.
    if self._d_feature is not None:
      ff = init.GlorotUniformInitializer()(
          (d_feature, input_signature.shape[-1]), self.rng)
      self.weights = w, ff
    else:
      self.weights = w
    if self._mode == 'predict':
      self.state = jnp.zeros((), dtype=jnp.int32)

In [None]:
# Keep a handle on the PositionalEncoding class defined in this notebook, then
# install a d_feature=64 partial of it under every name it is looked up by.
og_PositionalEncoding = PositionalEncoding

for _patched_module in (trax.layers.attention, trax.layers, tl):
  _patched_module.PositionalEncoding = functools.partial(
      og_PositionalEncoding, d_feature=64)
del _patched_module

# Configurable Terraformer - copied implementation

In [None]:
import functools
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.models.reformer import reformer
from trax.models.research import configurable_transformer as ct
from trax.models.research import transformer2 as t2

def ConfigurableTerraformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            d_attention_key=None,
                            d_attention_value=None,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            dropout=0.1,
                            max_len=2048,
                            encoder_attention_type=tl.SelfAttention,
                            encoder_decoder_attention_type=tl.SelfAttention,
                            pos_type='fixed-base',
                            pos_axial_shape=(),
                            pos_d_axial_embs=None,
                            pos_start_from_zero_prob=1.0,
                            pos_max_offset_to_add=0,
                            ff_activation=tl.Relu,
                            ff_use_sru=0,
                            ff_chunk_size=0,
                            ff_dropout=None,
                            ff_sparsity=0,
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            n_layers_forget=0,
                            forget_dense=True,
                            n_decoder_attention_layers=2,
                            use_bfloat16=False,
                            reversible_encoder=False,
                            use_two_swaps_per_encoder_block=True,
                            center_layernorm=True,
                            half_before_layer=None,
                            double_after_layer=None,
                            mode='train'):
  """Returns a highly configurable Terraformer encoder-decoder model.

  This model maps paired text sequences (source and target) to float-valued
  losses. If ``input_vocab_size`` is not ``None``, the layer takes
  two input sequences:

    - inputs (2):

        - source: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  If ``input_vocab_size`` is ``None``, the layer takes three input sequences:

    - inputs (3):

        - source: 3-D float array representing a batch of already-embedded text
          strings; shape is `(batch_size, sequence_length, d_model)`, where
          sequence_length <= ``max_len``.

        - mask: 2-D int array representing active versus masked positions; 0
          values mark masked (padding) positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    d_attention_key: Depth of key vectors in each attention head.
    d_attention_value: Depth of value vectors in each attention head.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    max_len: Maximum symbol length for positional encoding.
    encoder_attention_type: Type of attention to use in the encoder; must be
        an attention-type subclass of :py:class:`trax.layers.Layer`.
    encoder_decoder_attention_type: Type of attention to use in the decoder;
        must be an attention-type subclass of :py:class:`trax.layers.Layer`.
        A tuple or list of types is cycled through across the decoder layers.
    pos_type: String indicating the type of positional embeddings to use.
    pos_axial_shape: Shape (tuple of ints) to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: Tuple of ints specifying the depth of position embedding
        for each axis. Tuple length must match ``pos_axial_shape``, and values
        must sum to ``d_model``.
    pos_start_from_zero_prob: Stochastic rate (probability) for starting
        positional encoding at position 0 during training. If 1.0, always start
        from position 0; if < 1.0, the non-zero starts will be uniformly
        distributed up to ``pos_max_offset_to_add``.
    pos_max_offset_to_add: Maximum offset to add to positions during training
        when randomizing. This offset plus input length must be less than
        ``max_len`` for all training examples.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of :py:class:`trax.layers.Layer`.
    ff_use_sru: If > 0, use this number of SRU layers in place of feedforward
        layers.
    ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this
        size.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        at feedforward nonlinearities.
    ff_sparsity: If > 0, use sparse feedforward blocks with this level of
        sparsity.
    loss_sparsity_type: String indicating the type of sparsity to use in loss
        layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``,
        use no sparsity.
    loss_sparsity: If > 0, use this level of sparsity in the loss layer.
    loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this
        dimension, in the loss.
    loss_sparsity_prob: Stochastic rate (probability) for using the sparse
        version of the loss. If ``None``, use the sparse version exclusively.
    attention_chunk_size: If > 0, compute attention using chunks of this size.
    n_layers_forget: How often to have a forgetting block between layers.
    forget_dense: If True, use :py:class:`Dense` instances as forget layers;
        else use no-ops.
    n_decoder_attention_layers: Number of attention layers in a decoder block.
    use_bfloat16: If True, use bfloat16 for weights; else use float32.
    reversible_encoder: If True, make the encoder be reversible.
    use_two_swaps_per_encoder_block: If True, ensure that there is an even
        number of swaps across the encoder.
    center_layernorm: If True, use centering in :py:class:`LayerNorm` (the
        default); else omit centering (which is known as RMS normalization).
    half_before_layer: If not None, specifies an n'th layer such that all
        layers before the n'th use half the normal values for ``d_model`` and
        ``d_ff``.
    double_after_layer: If not None, specifies an n'th layer such that all
        layers after the n'th use double the normal values for ``d_model`` and
        ``d_ff``.
    mode: If ``'train'``, include dropout in each encoder/decoder block; else
        dropout layers have no effect.

  Returns:
    A Terraformer encoder-decoder as a layer that maps from target and source
    text sequences to a scalar loss.

  Raises:
    ValueError: If ``n_heads`` does not divide ``d_model / 2``.
  """
  if mode == 'predict':
    portal_mask = _PortalInput()
  else:
    portal_mask = None

  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    # Integer division: these values are used as layer dimensions, so they
    # must stay ints (true division would turn them into floats).
    d_model1, d_ff1 = d_model // 2, d_ff // 2
    d_attention_key1 = d_attention_key // 2
    d_attention_value1 = d_attention_value // 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16)
  )

  def _EncoderBlock():
    return reformer.EncoderBlock(
        d_model1,
        d_ff1,
        n_heads,
        encoder_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        use_two_swaps_per_block=use_two_swaps_per_encoder_block,
        mode=mode)

  def _Encoder():  # vec_e mask_e tok_e tok_d tok_d
    layers = [
        tl.ReversibleSelect([0, 0]),
        _ReversibleSerialForget(
            [_EncoderBlock() for _ in range(n_encoder_layers)],
            d_model1,
            n_layers_forget,
            forget_dense)
    ]
    if not reversible_encoder:
      layers += [
          _XYAvg(),
          tl.Dense(d_model1, use_bfloat16=use_bfloat16),
          tl.LayerNorm(),
      ]
    if mode == 'predict':
      return tl.Cache(tl.Serial(layers))
    else:
      return tl.Serial(layers)

  if mode == 'predict':
    # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict.
    # In predict mode Terraformer needs masking for merged encoder-decoder
    # sequence. This monkey patches the layer with a mask to neccessary places.
    # This shouldn't be a permanent solution - mask should be passed through
    # the stack and all the layers.
    tl.attention.DotProductCausalAttention.monkey_patched_mask = (
        lambda x: portal_mask)
    tl.research.sparsity._RememberPad.monkey_patched_mask = (  # pylint: disable=protected-access
        lambda x: portal_mask)
    originalScanSRUCell = tl.rnn.ScanSRUCell
    tl.rnn.ScanSRUCell = functools.partial(tl.rnn.ScanSRUCell,
                                           monkey_patched_mask=portal_mask)

  decoder_blocks = []

  # The decoder is built inside try/finally so that the global monkey patches
  # above are always reverted, even if constructing a decoder block fails.
  try:
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
      assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
      encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
      layer_attention_type = encoder_decoder_attention_type[
          layer_idx % len(encoder_decoder_attention_type)]
      # Grow d_model, d_ff, and d_qkv if requested.
      d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
      if half_before_layer and layer_idx >= half_before_layer:
        d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
      if double_after_layer and layer_idx > double_after_layer:
        d_m, d_f, d_k, d_v = (
            d_model2, d_ff2, d_attention_key2, d_attention_value2)
      decoder_block = reformer.DecoderBlock(
          d_m, d_f, d_k, d_v, n_heads,
          attention_type=layer_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          n_attention_layers=n_decoder_attention_layers,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          mode=mode)
      decoder_blocks.append(decoder_block)
      if half_before_layer and layer_idx == half_before_layer - 1:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
      if double_after_layer and layer_idx == double_after_layer:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
  finally:
    if mode == 'predict':
      # After initializing the decoder we can revert to original state of
      # previously monkey-patched classes/functions.
      tl.attention.DotProductCausalAttention.monkey_patched_mask = (
          lambda x: None)
      tl.research.sparsity._RememberPad.monkey_patched_mask = (lambda x: None)  # pylint: disable=protected-access
      tl.rnn.ScanSRUCell = originalScanSRUCell

  def _Loss():
    return tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

  def _enc_dec_concat():
    """Layers to merge encoder and decoder."""
    if reversible_encoder:
      return [
          tl.ReversibleSelect([0, 1, 4, 2, 3]),  # v_e v_d mask_e tok_e tok_d
          t2.ConcatWithPadding2(mode=mode),      # v_ed v_ed tok_e tok_d
      ]
    else:
      return [
          tl.ReversibleSelect([0, 3, 1, 2]),     # v_e v_d mask_e tok_e tok_d
          t2.ConcatWithPadding(mode=mode),       # v_ed tok_e tok_d
          tl.ReversibleSelect([0, 0]),           # v_ed v_ed tok_e tok_d
      ]

  def _inp_layers():
    if input_vocab_size is not None:

      return tl.AssertFunction(
          'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
          tl.Serial(  # tok_e tok_d
              tl.Select([0, 0, 0, 1]),
              tl.Parallel(in_encoder, [tl.PaddingMask(),
                                       _RemoveAxes12()])
          ))  # vec_e mask_e tok_e tok_d
    else:
      # Input in this case is vec_e, mask_e, tok_d. Where all downstream
      # operations expect tok_e, we give it instead mask_e, expecting that
      # downstream ops only are looking for padding/not padding.

      return tl.AssertFunction(
          'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
          tl.Serial(  # vec_e mask_e tok_d
              tl.Select([0, 1, 1, 2]),
              tl.Parallel(in_encoder, [], _AsTokenIDs())
          ))  # vec_e mask_e tok_e tok_d

  # Assemble and return the model.
  return tl.Serial(
      _inp_layers(),               # vec_e mask_e tok_e tok_d
      tl.Parallel([], portal_mask),

      tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),
                               out_encoder]),  # vec_e mask_e tok_e vec_d tok_d

      # Encode; then concat encoder and decoder, given encoder mask.
      _Encoder(),                             # vec_e mask_e tok_e vec_d tok_d
      _enc_dec_concat(),

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                              forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
      _XYAvg(),                               # vec_ed tok_e tok_d
      tl.LayerNorm(),                         # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector,
      # then compute loss.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
      _Loss(),  # vec_d tok_d
  )


def _InsertAxes12():
  """Returns a layer that reshapes `x` to `(x.shape[0], 1, 1, x.shape[1])`."""
  def _insert_axes(x):
    target_shape = (x.shape[0], 1, 1, x.shape[1])
    return jnp.reshape(x, target_shape)

  return tl.Fn('InsertAxes12', _insert_axes)


def _RemoveAxes12():
  """Returns a layer that squeezes out axes 1 and 2 (both must be size 1)."""
  def _remove_axes(x):
    return jnp.squeeze(x, axis=(1, 2))

  return tl.Fn('RemoveAxes12', _remove_axes)


def _AsTokenIDs():
  """Returns a layer that casts its input (e.g. a mask) to int32."""
  def _as_int32(x):
    return x.astype(jnp.int32)

  return tl.Fn('AsTokenIDs', _as_int32)


def _XYAvg():
  """Returns a layer that outputs the element-wise mean of its two inputs."""
  def _average(x, y):
    return (x + y) / 2.0

  return tl.Fn('XYAvg', _average)


def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):
  """ReversibleSerial but with a forgetting block every n_layers.

  Args:
    layers: List of layers to chain reversibly.
    d_model: Output size of the :py:class:`Dense` used in a dense forgetting
        block.
    n_layers: Number of layers between forgetting blocks; if falsy, or if
        there are at most ``n_layers + 1`` layers, no forgetting block is
        inserted.
    forget_dense: If True, the forgetting block averages the two streams,
        applies a Dense layer and duplicates the result; otherwise it is a
        pass-through ``Select([0, 1])``.

  Returns:
    A layer combining ``layers`` with the forgetting blocks in between.
  """
  # No forgetting requested: a plain reversible stack.
  if not n_layers:
    return tl.ReversibleSerial(layers)
  # Too few layers left to be worth splitting.
  if len(layers) <= n_layers + 1:
    return tl.ReversibleSerial(layers)

  head = layers[:n_layers]
  tail = layers[n_layers:]

  forgetting_layer = (
      tl.Serial(_XYAvg(), tl.Dense(d_model), tl.Dup())
      if forget_dense else tl.Select([0, 1]))

  # Recurse on the remainder so a forgetting block follows each full group.
  return tl.Serial(
      tl.ReversibleSerial(head),
      forgetting_layer,
      _ReversibleSerialForget(tail, d_model, n_layers, forget_dense),
  )


def _ConvertToNaNsOnAnyZero():
  """Returns a layer that invalidates `x` whenever `y` contains any zero.

  The layer takes `(x, y)` and returns `(x', y)`. If every value in `y` is
  non-zero, `x'` is `x`; otherwise `x'` is `x / 0.`, i.e. the non-finite
  result of dividing each element by zero.
  """
  def _invalidate_on_zero(x, y):
    all_nonzero = jnp.all(y, keepdims=False)
    divided_by_zero = x/0.
    return jnp.where(all_nonzero, x, divided_by_zero), y

  return tl.Fn('ConvertToNaNsOnAnyZero', _invalidate_on_zero, n_out=2)


class _PortalInput(tl.Layer):
  """Portal input for monkey-patching of mask in predict mode.

  Passes its input through unchanged while remembering it in the layer state,
  so that the paired :py:class:`_PortalOutput` can read it back elsewhere.
  """

  def __init__(self):
    super().__init__(name='_PortalInput', n_out=1, n_in=1)
    # The output end of the portal keeps a reference back to this layer.
    self._portal_output = _PortalOutput(self)

  def forward(self, x):
    """Stores `x` (or its first element, if a list/tuple) and returns it."""
    value = x[0] if isinstance(x, (list, tuple)) else x
    self.state = (value,)
    return value

  def init_weights_and_state(self, input_signature):
    """Initializes the state to zeros with the shape of the input."""
    signature = input_signature
    if isinstance(signature, (list, tuple)):
      signature = signature[0]
    self.state = (jnp.zeros(signature.shape),)

  def get_value(self):
    """Returns the most recently stored value."""
    return self.state[0]

  def get_layer(self):
    """Returns the paired output layer."""
    return self._portal_output


class _PortalOutput(tl.Layer):
  """Portal output for monkey-patching of mask in predict mode.

  Takes no inputs; it emits whatever value the paired
  :py:class:`_PortalInput` currently holds.
  """

  def __init__(self, portal_input):
    super().__init__(name='_PortalOutput', n_out=1, n_in=0)
    self._portal_input = portal_input

  def get_value(self):
    """Returns the value currently held by the paired input layer."""
    return self._portal_input.get_value()

  def forward(self, x):
    """Ignores `x` and returns the value held by the paired input layer."""
    return self._portal_input.get_value()


import gin
# NOTE(review): presumably needed so that re-registering a configurable under
# an already-used name (done just below) does not raise -- confirm in gin docs.
gin.enter_interactive_mode()
def model_configure(*args, **kwargs):
  """Registers a gin external configurable under module ``trax.models``.

  Any ``module`` keyword passed by the caller is overridden.
  """
  gin_kwargs = dict(kwargs, module='trax.models')
  return gin.external_configurable(*args, **gin_kwargs)

# Replace the library's model with the copy defined in this notebook.
trax.models.reformer.ConfigurableTerraformer = ConfigurableTerraformer

# Expose the gin-configurable wrapper of that copy as
# trax.models.ConfigurableTerraformer.
trax.models.ConfigurableTerraformer = model_configure(trax.models.reformer.ConfigurableTerraformer)

# copying

In [None]:
# Locations of the model artifacts used by the cells below.
gs_link = "gs://trax-ml/terraformer"
# "mira_med" run: experiment page, then checkpoint (step 200000), weights,
# optimizer slots and gin config files.
mira_xm2a = "https://xm2a.corp..com/experiments/25250921"
_mira_xm2a_main = "//gc-d/home/afrozm/rs=6.3/mira_med-05-11-07-41/model_200000.pkl.gz"
_mira_xm2a_weights = "//gc-d/home/afrozm/rs=6.3/mira_med-05-11-07-41/model_200000.weights.npy.gz"
_mira_xm2a_opt_slots = "//gc-d/home/afrozm/rs=6.3/mira_med-05-11-07-41/model_200000.opt_slots0.npy.gz"
_mira_xm2a_config = "//gc-d/home/afrozm/rs=6.3/mira_med-05-11-07-41/config.gin"

# "mira_big" run: same set of files, checkpoint at step 210000.
mira_xm2a_big = "https://xm2a.corp..com/experiments/24886122"
_mira_xm2a_big_main = "//gc-d/home/afrozm/rs=6.3/mira_big-04-02-06-20/model_210000.pkl.gz"
_mira_xm2a_big_weights = "//gc-d/home/afrozm/rs=6.3/mira_big-04-02-06-20/model_210000.weights.npy.gz"
_mira_xm2a_big_opt_slots = "//gc-d/home/afrozm/rs=6.3/mira_big-04-02-06-20/model_210000.opt_slots0.npy.gz"
_mira_xm2a_big_config = "//gc-d/home/afrozm/rs=6.3/mira_big-04-02-06-20/config.gin"

In [None]:
# Copy the "mira_med" gin config and checkpoint files into /tmp, unless a
# /tmp/model_med entry already exists.
gin_file = _mira_xm2a_config
# IPython shell capture: `files` holds the listing of /tmp, one entry per line.
files = !ls /tmp

if 'model_med' not in files:
  ! cp {gin_file} /tmp
  !mkdir /tmp/model_med
  ! cp -f //gc-d/home/afrozm/rs=6.3/mira_med-05-11-07-41/model_200000* /tmp/model_med
  pass

# parsing gin config

In [None]:
# Read the gin config that was copied to /tmp; `with` closes the file even if
# reading fails.
with open('/tmp/config.gin') as f:
  gin_config = list(f)
# Make sure every line is newline-terminated, so that the binding appended
# below cannot fuse with an unterminated last line of the file.
gin_config = [l if l.endswith('\n') else l + '\n' for l in gin_config]
#
#  Uncomment this part to get the original results from the docs
#
# keep_gin = [l for l in gin_config if 'predict_mem' not in l]
# change_gin = [l for l in gin_config if 'predict_mem' in l]
# changed_gin = [l[:-6] + '2048\n' for l in change_gin]
# gin_config = keep_gin + changed_gin
# keep_gin = [l for l in gin_config if 'predict_drop' not in l]
# change_gin = [l for l in gin_config if 'predict_drop' in l]
# changed_gin = [l[:-6] + '2048\n' for l in change_gin]
# gin_config = keep_gin + changed_gin


# keep_gin = [l for l in gin_config if 'std_length' not in l]
# change_gin = [l for l in gin_config if 'std_length' in l]
# changed_gin = [l[:-5] + '2048\n' for l in change_gin]
# gin_config = keep_gin + changed_gin


# keep_gin = [l for l in gin_config if 'max_length = 512' not in l]
# change_gin = [l for l in gin_config if 'max_length = 512' in l]
# changed_gin = ['max_length = 2048\n' for l in change_gin]
# gin_config = keep_gin + changed_gin
#
#  End of the part that needs to be uncommented.
#

# The config file names the model 'Reformer2'; point those bindings at the
# ConfigurableTerraformer registered in this notebook instead.
gin_config = [l.replace('Reformer2', 'ConfigurableTerraformer') for l in gin_config]

gin_config.append(
    'DotProductCausalAttention.max_inference_length = 2048'
)

# Also bake the same max_inference_length into the class looked up as
# trax.layers.attention.DotProductCausalAttention.
og_DotProductCausalAttention = trax.layers.attention.DotProductCausalAttention
trax.layers.attention.DotProductCausalAttention = functools.partial(
    og_DotProductCausalAttention, max_inference_length=2048,
)

# gin_config.append(
#     'MixedLSHSelfAttention.std_length='
# )

gin_config = ''.join(gin_config)
gin.parse_config(gin_config)

print(gin_config)

In [None]:
# encoder/MixedLSHSelfAttention.predict_drop_len = 32768
# encoder/MixedLSHSelfAttention.predict_mem_len = 32768
# encoder/MixedLSHSelfAttention.std_length = 2048

# random stuff

In [None]:
def model(mode):
  """Builds a `models.ConfigurableTerraformer` for the given mode.

  Only `mode` is passed explicitly; any other constructor arguments are
  presumably supplied by the gin config parsed earlier -- TODO confirm.

  Args:
    mode: Forwarded unchanged as the `mode` keyword argument.

  Returns:
    The newly constructed ConfigurableTerraformer.
  """
  terraformer = models.ConfigurableTerraformer(mode=mode)
  return terraformer

In [None]:
# Pads entries 0, 1 and 2 of each example to length 512 using pad value 0.
padding_fun = trax.data.PadToLength(
    len_map={position: 512 for position in range(3)},
    pad_value={position: 0 for position in range(3)},
)

# Small hand-written prompt used to sanity-check tokenisation.
question = """code:
def square_list(xs):
  return [<SENTINEL> for x in xs]
print(square_list([1, 2, 3, 4]))

output:
[1, 4, 9, 16]"""

# Vocabulary settings shared by the tokenize / detokenize calls below.
vocab_kwargs = dict(
    vocab_file='all.16k.vocab',
    vocab_dir='//je-d/home/afrozm/rs=6.3/mira/data/v1/',
    n_reserved_ids=100,
)

token_stream = trax.data.tokenize([question,], **vocab_kwargs)
tokenized = next(padding_fun(token_stream))

# Round-trip through the vocabulary and show the resulting shape.
print(trax.data.detokenize(tokenized, **vocab_kwargs))
print(tokenized.shape)

# autoregressive_sample_stream etc.

In [None]:
import sys
import time

def autoregressive_sample_stream(model, inputs=None,
                                 batch_size=1, temperature=1.0,
                                 start_id=2, accelerate=True, prefix=None):
  """Yields an endless stream of sampled tokens, printing each one as it goes.

  Every step prints the detokenized sample and the wall-clock time of the step
  (model call, sampling and the detokenize print) before yielding.

  Args:
    model: Layer to decode with. Its `n_in` decides how `inputs` is used (see
      `inputs`). Per the note in the loop, it is expected to be autoregressive
      and in 'predict' mode so that history is cached in its state.
    inputs: Optional array whose first dimension must equal `batch_size`. When
      `model.n_in == 1` it is concatenated after the start symbol as a prompt;
      when `model.n_in > 1` it is passed together with the current symbols on
      every step and the first model output is used as logits.
    batch_size: Number of sequences decoded in parallel.
    temperature: Forwarded to `tl.logsoftmax_sample`.
    start_id: A scalar id, expanded to a `(batch_size, 1)` int32 array, or a
      ready-made start-symbol array that is used as-is.
    accelerate: If true, the model is wrapped in `tl.Accelerate`.
    prefix: Optional array concatenated after the start symbol along axis 1.
      Ignored when `inputs` is consumed as the prompt (`model.n_in == 1`).

  Yields:
    The sample drawn for the last position at each step (presumably of shape
    `(batch_size,)` -- TODO confirm).

  Raises:
    ValueError: On first iteration, if `inputs` is given and its batch
      dimension differs from `batch_size`.
  """
  if inputs is not None and inputs.shape[0] != batch_size:
    raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match '
                     f'batch_size arg ({batch_size}).')

  fast_model = tl.Accelerate(model) if accelerate else model
  if np.isscalar(start_id):
    start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
  else:
    start_symbol = start_id
  if model.n_in == 1 and inputs is not None:
    current_symbols = np.concatenate([start_symbol, inputs], axis=1)
  else:
    if prefix is None:
      current_symbols = start_symbol
    else:
      current_symbols = np.concatenate([start_symbol, prefix], axis=1)

  while True:
    t0 = time.time()
    if model.n_in > 1 and inputs is not None:
      logits = fast_model((inputs, current_symbols))[0]
    else:
      logits = fast_model(current_symbols)
    # Only the distribution at the last position is sampled from.
    logits = tl.log_softmax(logits[:, -1, :])
    sample = tl.logsoftmax_sample(logits, temperature=temperature)

    print(trax.data.detokenize(sample, vocab_file='all.16k.vocab', vocab_dir='//je-d/home/afrozm/rs=6.3/mira/data/v1/', n_reserved_ids=100))
    print("Time per token: {}".format(time.time() - t0))
    sys.stdout.flush()

    yield sample
    # NOTE: Because the model is autoregressive and in 'predict' mode, its
    # history is cached in the model state and the next input is the single
    # symbol just sampled.
    current_symbols = sample[:, None]


# Index of the decoding step at which the speed measurement starts; earlier
# steps are excluded from the reported average (presumably to skip warm-up
# such as compilation -- TODO confirm).
START_INDEX = 10


def _report_decoding_speed(start_time, last_index):
  """Prints the mean seconds per token over the steps after START_INDEX."""
  timed_steps = last_index - START_INDEX
  if start_time is None or timed_steps <= 0:
    # Decoding stopped before any step after START_INDEX completed, so there
    # is nothing to average over.
    print('decoded too few tokens (need more than {}) to measure decoding '
          'speed'.format(START_INDEX + 1))
    return
  print('decoded one token per {} s'.format(
      (time.time() - start_time) / timed_steps))


def autoregressive_sample(model, inputs=None,
                          batch_size=1, temperature=1.0,
                          start_id=0, eos_id=1, max_length=100,
                          accelerate=True, prefix=None):
  """Collects samples from `autoregressive_sample_stream` into one array.

  Decoding stops once `max_length` tokens have been collected or once every
  batch position has produced `eos_id` at least once, whichever comes first.
  At least one token is always collected. A decoding-speed summary is printed
  before returning; if decoding ends at or before step `START_INDEX`, a short
  notice is printed instead of a speed.

  Args:
    model: Forwarded to `autoregressive_sample_stream`.
    inputs: Forwarded to `autoregressive_sample_stream`.
    batch_size: Number of sequences decoded in parallel.
    temperature: Forwarded to `autoregressive_sample_stream`.
    start_id: Forwarded to `autoregressive_sample_stream`.
    eos_id: Token id that marks the end of a sequence.
    max_length: Maximum number of tokens to collect per sequence.
    accelerate: Forwarded to `autoregressive_sample_stream`.
    prefix: Forwarded to `autoregressive_sample_stream`.

  Returns:
    The collected samples concatenated along axis 1, one column per step.
  """
  result = []
  eos_seen = set()  # Batch positions that have emitted eos_id so far.
  start_time = None
  index = -1
  for index, sample in enumerate(autoregressive_sample_stream(
      model, inputs, batch_size=batch_size, temperature=temperature,
      start_id=start_id, accelerate=accelerate, prefix=prefix)):
    if index == START_INDEX:
      start_time = time.time()
    sample = sample[:, None]
    result.append(sample)
    if index + 1 >= max_length:
      break
    # Check at which batch positions have we already encountered EOS.
    eos_seen.update(
        j for j in range(batch_size) if int(sample[j, 0]) == eos_id)
    # If EOS has been seen on all positions, stop.
    if len(eos_seen) == batch_size:
      break
  _report_decoding_speed(start_time, index)
  return np.concatenate(result, axis=1)

# Predictions

In [None]:
# Checkpoint file to initialise the predict-mode model from.
model_file = "/tmp/model_med/model_200000.pkl.gz"

# Input signature for initialisation: a (1, 1024) int32 array followed by a
# (1, 1) int32 array.
# The model does not like other numbers than 1024 for the first shape.
# In particular 15 * 1024 does not work.
shape1l = trax.shapes.ShapeDtype((1, 1024), dtype=numpy_math.int32)
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=numpy_math.int32)

with trax.fastmath.use_backend(trax.fastmath.Backend.JAX):
  model_predict = model(mode='predict')
  model_predict.init_from_file(
      model_file,
      weights_only=True,
      input_signature=(shape1l, shape11),
  )
  # Snapshot of the state right after initialisation; the decoding cells
  # assign it back to model_predict.state before each run.
  old_state = model_predict.state

In [None]:
import tensorflow_datasets as tfds
# Builder object for the scientific_papers dataset (not used again in this
# cell).
dataset = tfds.summarization.scientific_papers.ScientificPapers()
# Despite the name, this is the 'test' split of the arXiv configuration.
valid = tfds.load(name='scientific_papers/arxiv:1.1.1')['test']

# Keep the first three examples for the decoding cells.
xarts = []
index = 0
for index, x in enumerate(valid, start=1):
  xarts.append(x)
  if index == 3:
    break

In [None]:
# Decode the first article (xarts[0]): show the start of its text, tokenize
# and pad it, then sample up to 50 tokens from the model.
xart = xarts[0]['article']
question = xart.numpy().decode()
print(question[:512])

# Vocabulary settings shared by the tokenize / detokenize calls below.
vocab_kwargs = dict(
    vocab_file='all.16k.vocab',
    vocab_dir='//je-d/home/afrozm/rs=6.3/mira/data/v1/',
    n_reserved_ids=100,
)

token_stream = trax.data.tokenize([question,], **vocab_kwargs)
tokenized = next(padding_fun(token_stream))
# NOTE: the result of this call is discarded (it is not the last expression
# of the cell).
trax.data.detokenize(tokenized, **vocab_kwargs)

with trax.fastmath.use_backend(trax.fastmath.Backend.JAX):
  # Start from the state saved right after initialisation.
  model_predict.state = old_state

  # Add a leading axis and keep at most 1024 entries along the next one.
  # Putting 15*1024 here instead of 1024 does not work.
  prompt = tokenized[None,:1024]
  tokens = autoregressive_sample(
      model_predict, prompt, temperature=0., max_length=50)
  print(tokens)
  print(trax.data.detokenize(tokens[0], **vocab_kwargs))

In [None]:
# Decode the second article (xarts[1]): show the start of its text, tokenize
# and pad it, then sample up to 50 tokens from the model.
xart = xarts[1]['article']
question = xart.numpy().decode()
print(question[:512])

# Vocabulary settings shared by the tokenize / detokenize calls below.
vocab_kwargs = dict(
    vocab_file='all.16k.vocab',
    vocab_dir='//je-d/home/afrozm/rs=6.3/mira/data/v1/',
    n_reserved_ids=100,
)

token_stream = trax.data.tokenize([question,], **vocab_kwargs)
tokenized = next(padding_fun(token_stream))
# NOTE: the result of this call is discarded (it is not the last expression
# of the cell).
trax.data.detokenize(tokenized, **vocab_kwargs)

with trax.fastmath.use_backend(trax.fastmath.Backend.JAX):
  # Start from the state saved right after initialisation.
  model_predict.state = old_state

  # Add a leading axis and keep at most 1024 entries along the next one.
  prompt = tokenized[None,:1024]
  tokens = autoregressive_sample(
      model_predict, prompt, temperature=0., max_length=50)
  print(tokens)
  print(trax.data.detokenize(tokens[0], **vocab_kwargs))

In [None]:
# Decode the third article (xarts[2]): show the start of its text, tokenize
# and pad it, then sample up to 50 tokens from the model.
xart = xarts[2]['article']
question = xart.numpy().decode()
print(question[:512])

# Vocabulary settings shared by the tokenize / detokenize calls below.
vocab_kwargs = dict(
    vocab_file='all.16k.vocab',
    vocab_dir='//je-d/home/afrozm/rs=6.3/mira/data/v1/',
    n_reserved_ids=100,
)

token_stream = trax.data.tokenize([question,], **vocab_kwargs)
tokenized = next(padding_fun(token_stream))
# NOTE: the result of this call is discarded (it is not the last expression
# of the cell).
trax.data.detokenize(tokenized, **vocab_kwargs)

with trax.fastmath.use_backend(trax.fastmath.Backend.JAX):
  # Start from the state saved right after initialisation.
  model_predict.state = old_state

  # Add a leading axis and keep at most 1024 entries along the next one.
  prompt = tokenized[None,:1024]
  tokens = autoregressive_sample(
      model_predict, prompt, temperature=0., max_length=50)
  print(tokens)
  print(trax.data.detokenize(tokens[0], **vocab_kwargs))

In [None]:
import tensorflow_datasets as tfds
# Builder object for the scientific_papers dataset (not used again in this
# cell).
dataset = tfds.summarization.scientific_papers.ScientificPapers()
# Despite the name, this is the 'test' split of the arXiv configuration.
valid = tfds.load(name='scientific_papers/arxiv:1.1.1')['test']

# Print the first 15*1024 characters of each of the first three articles,
# numbered from 0.
index = 0
for index, x in enumerate(valid, start=1):
  xart = x['article']
  question = xart.numpy().decode()
  print('========= Article {} ========'.format(index - 1))
  print(question[:15*1024])
  if index == 3:
    break