# Initialization

In [0]:
# *** Comment out on cloud ***
# Install the exact (pinned) JAX, jaxlib and Flax versions this notebook was
# written against. The code below uses the old `flax.nn` Module API; newer
# releases are presumably incompatible with it — confirm before upgrading.
!pip install --upgrade -q jax==0.1.61 jaxlib==0.1.42 flax==0.1.0rc2

In [12]:
# *** Comment out on cloud ***
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  # Ask the Colab TPU host to switch to the nightly TPU driver. The
  # TPU_DRIVER_MODE sentinel makes this run only once per kernel.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  # Fail loudly if the host rejected the request instead of silently
  # continuing with a driver version that was never switched. The sentinel
  # is only set on success, so the cell can simply be re-run.
  resp.raise_for_status()
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

grpc://10.98.69.2:8470


In [0]:
import functools
import itertools
import os
import time

import flax
from flax import jax_utils
from flax import nn
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils

import jax
from jax import random
from jax import lax
import jax.nn
import jax.numpy as jnp

import numpy as np

import matplotlib.pyplot as plt

# Transformer model
Code source: https://github.com/google/flax/blob/master/examples/lm1b/models.py

In [0]:
def shift_right(x):
  """Shifts `x` one step to the right along axis 1.

  A zero of `x`'s dtype is inserted at index 0 of axis 1, and the last element
  along that axis is dropped, so the result has the same shape as `x`.
  """
  # Pad only axis 1, one element on the left; every other axis is untouched.
  pad_widths = [(1, 0) if axis == 1 else (0, 0) for axis in range(x.ndim)]
  zero = x.dtype.type(0)
  shifted = jnp.pad(x, pad_widths, mode='constant', constant_values=zero)
  return shifted[:, :-1]

In [0]:
class Embed(nn.Module):
  """Embedding Module.

  A parameterized function from integers [0, n) to d-dimensional vectors.
  In 'output' mode the same table is reused as a tied output projection.
  """

  def apply(self,
            inputs,
            num_embeddings,
            features,
            mode='input',
            emb_init=nn.initializers.normal(stddev=1.0)):
    """Applies Embed module.

    Args:
      inputs: integer token ids in 'input' mode; in 'output' mode an array
        contracted against the table as `bld,vd->blv`.
      num_embeddings: number of embeddings (rows of the table).
      features: size of the embedding dimension.
      mode: 'input' to look up embeddings, 'output' to project onto the
        embedding table (shared input/output embedding).
      emb_init: embedding initializer.

    Returns:
      In 'input' mode, `inputs.shape + (features,)` embeddings. In 'output'
      mode, `(b, l, num_embeddings)` scores.

    Raises:
      ValueError: if `mode` is not 'input' or 'output', or if `inputs` is not
        an integer array in 'input' mode.
    """
    if mode not in ('input', 'output'):
      # Fail explicitly rather than silently returning None.
      raise ValueError("mode must be 'input' or 'output', got %r" % (mode,))
    embedding = self.param('embedding', (num_embeddings, features), emb_init)
    if mode == 'input':
      # Accept every signed/unsigned integer dtype (int8 ... uint64).
      if not jnp.issubdtype(inputs.dtype, jnp.integer):
        raise ValueError('Input type must be an integer or unsigned integer.')
      return jnp.take(embedding, inputs, axis=0)
    return jnp.einsum('bld,vd->blv', inputs, embedding)

In [0]:
def sinusoidal_init(max_len=2048):
  """1D Sinusoidal Position Embedding Initializer.

  Even feature columns hold sin(pos * w_k) and odd columns hold
  cos(pos * w_k), with w_k = 10000^(-2k / d_feature). Odd `d_feature` is
  supported: the final (even) column then has no cosine partner.

  Args:
      max_len: maximum possible length for the input

  Returns:
      output: init function `init(key, shape, dtype)` returning a
        `(1, max_len, d_feature)` float32 array, where `d_feature = shape[-1]`.
  """

  def init(key, shape, dtype=np.float32):
    """Sinusoidal init; `key` and `dtype` are ignored (result is float32)."""
    del key, dtype  # Deterministic table, always float32.
    d_feature = shape[-1]
    pe = np.zeros((max_len, d_feature), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    angles = position * div_term  # (max_len, ceil(d_feature / 2))
    pe[:, 0::2] = np.sin(angles)
    # There are only floor(d_feature / 2) odd columns. Trim the angles so an
    # odd d_feature does not raise a shape mismatch.
    pe[:, 1::2] = np.cos(angles[:, :d_feature // 2])
    pe = pe[np.newaxis, :, :]  # [1, max_len, d_feature]
    return jnp.array(pe)

  return init

In [0]:
class AddPositionEmbs(nn.Module):
  """Adds learned positional embeddings to the inputs.

  The positional table is registered with `self.param`, so it is a trainable
  parameter; its initial values come from `posemb_init`.
  """

  def apply(self,
            inputs,
            max_len=2048,
            posemb_init=nn.initializers.normal(stddev=1.0),
            cache=None):
    """Applies AddPositionEmbs module.

    Args:
      inputs: input data, must be 3-D (asserted below).
      max_len: maximum possible length for the input
      posemb_init: positional embedding initializer
      cache: flax attention cache for fast decoding.
    Returns:
      output: `(bs, timesteps, in_dim)`
    """
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    length = inputs.shape[1]
    pos_emb_shape = (1, max_len, inputs.shape[-1])
    pos_embedding = self.param('pos_embedding', pos_emb_shape, posemb_init)
    # Non-cached path: use the first `length` positions.
    # NOTE(review): if length > max_len this slice has only max_len rows and
    # the final `inputs + pe` will fail to broadcast.
    pe = pos_embedding[:, :length, :]
    # We abuse the same attention Cache mechanism to run positional embeddings
    # in fast predict mode. We could use state variables instead, but this
    # simplifies invocation with a single top-level cache context manager.
    # We only use the cache's position index for tracking decoding position.
    if cache:
      if self.is_initializing():
        # Register a placeholder cache entry during initialization.
        # NOTE(review): presumably mirrors the (ndim, trailing-shape) spec that
        # flax's SelfAttention stores; only the entry's `i` counter is read
        # below — confirm against flax.nn.attention.Cache.
        cache.store(lambda: (4, (1, 1)))
      else:
        cache_entry = cache.retrieve(None)
        # Current decoding position, read before it is advanced.
        i = cache_entry.i
        one = jnp.array(1, jnp.uint32)
        # Advance the stored position counter by one for the next call.
        cache_entry = cache_entry.replace(i=cache_entry.i + one)
        cache.store(cache_entry)
        _, _, df = pos_embedding.shape
        # Take only row `i` of the table, shape (1, 1, df). This assumes one
        # token per call in cached (decoding) mode — TODO confirm with callers.
        pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)),
                               jnp.array((1, 1, df)))
    return inputs + pe

In [0]:
class MlpBlock(nn.Module):
  """Transformer feed-forward block: Dense -> GELU -> dropout -> Dense -> dropout."""

  def apply(self,
            inputs,
            mlp_dim,
            out_dim=None,
            dropout_rate=0.1,
            deterministic=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module.

    Args:
      inputs: input array; the Dense layers act on its last axis.
      mlp_dim: width of the hidden layer.
      out_dim: width of the output; defaults to the input's last dimension.
      dropout_rate: dropout rate applied after each Dense layer.
      deterministic: if True, dropout is disabled.
      kernel_init: initializer for both Dense kernels.
      bias_init: initializer for both Dense biases.

    Returns:
      Array whose last dimension is `out_dim` (or `inputs.shape[-1]`).
    """
    if out_dim is None:
      out_dim = inputs.shape[-1]

    hidden = nn.gelu(
        nn.Dense(inputs, mlp_dim, kernel_init=kernel_init, bias_init=bias_init))
    hidden = nn.dropout(hidden, rate=dropout_rate, deterministic=deterministic)

    projected = nn.Dense(
        hidden, out_dim, kernel_init=kernel_init, bias_init=bias_init)
    return nn.dropout(projected, rate=dropout_rate, deterministic=deterministic)

In [0]:
class Transformer1DBlock(nn.Module):
  """Pre-LayerNorm Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr).

  Computes `h = Dropout(SelfAttention(LayerNorm(inputs))) + inputs`, then
  returns `h + MlpBlock(LayerNorm(h))`.
  """

  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            cache=None):
    """Applies Transformer1DBlock module.

    Args:
      inputs: 3-D input array (asserted).
      qkv_dim: dimension of the query/key/value projections.
      mlp_dim: hidden width of the feed-forward block.
      num_heads: number of attention heads.
      causal_mask: bool, whether to mask future positions.
      padding_mask: mask of padding tokens, forwarded to SelfAttention.
      dropout_rate: dropout rate after attention and inside the MLP.
      attention_dropout_rate: dropout rate for attention weights.
      deterministic: bool, if True dropout is disabled.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      Output of the transformer block, same shape as the residual stream.
    """
    assert inputs.ndim == 3

    # Attention sub-layer (normalize first), then dropout and residual add.
    attended = nn.SelfAttention(
        nn.LayerNorm(inputs),
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache)
    attended = nn.dropout(attended, rate=dropout_rate,
                          deterministic=deterministic)
    residual = attended + inputs

    # Feed-forward sub-layer (normalize first), then residual add.
    feed_forward = MlpBlock(
        nn.LayerNorm(residual),
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=deterministic)
    return residual + feed_forward

# Hyperparameters

In [0]:
# -- Training / evaluation schedule --
num_train_steps = 500000   # Upper bound on optimizer steps.
eval_frequency = 1000      # Evaluate every this many steps.
num_eval_steps = 20        # Batches consumed per evaluation.
random_seed = 0            # Seed for the JAX PRNG.

# -- Optimization --
learning_rate = 0.05       # Base learning rate.
weight_decay = 0.1         # Relative (AdamW-style) weight decay factor.
batch_size = 256           # "Target" batch size.

# -- Sequence lengths --
max_target_length = 256       # Longest training input.
max_eval_target_length = 256  # Longest evaluation input.

# -- Transformer architecture --
lm_emb_dim = 512     # Token embedding width.
lm_num_heads = 8     # Attention heads per decoder layer.
lm_num_layers = 6    # Number of decoder layers.
lm_qkv_dim = 512     # Query/key/value width in attention.
lm_mlp_dim = 2048    # Hidden width of the feed-forward block.

# -- Representation lens --
rep_size = 256       # Output width of the learned linear lens.

In [0]:
# Init PRNG Stream.
# Build the root PRNG key and immediately split off a key for parameter init.
rng, init_rng = random.split(random.PRNGKey(random_seed))
# Seed one dropout key per local device for the first pmap'd step; the main
# training update refreshes these keys itself afterwards, for performance.
dropout_rngs = random.split(rng, num=jax.local_device_count())

# Transformer language model
Code source: https://github.com/google/flax/blob/master/examples/lm1b/models.py

In [0]:
class TransformerLM(nn.Module):
  """Decoder-only Transformer language model producing per-position logits."""

  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            cache=None):
    """Applies the Transformer language model to a batch of token ids.

    Args:
      inputs: `(batch, len)` token ids (asserted 2-D); ids <= 0 are treated as
        padding by the attention mask.
      vocab_size: size of the vocabulary (number of output logits).
      emb_dim: dimension of the token embedding.
      num_heads: number of attention heads.
      num_layers: number of Transformer1DBlock layers.
      qkv_dim: dimension of the query/key/value projections.
      mlp_dim: hidden width of each feed-forward block.
      max_len: maximum sequence length (size of the positional table).
      train: bool, if True dropout is active.
      shift: bool, right-shift the input. This is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      `(batch, len, vocab_size)` logits.
    """
    # 1.0 at real tokens, 0.0 at padding, with a trailing broadcast axis.
    padding_mask = jnp.greater(inputs, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)
    deterministic = not train

    tokens = shift_right(inputs) if shift else inputs
    x = Embed(tokens.astype('int32'), num_embeddings=vocab_size,
              features=emb_dim, name='embed')
    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)

    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=True,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=deterministic,
          cache=cache,
      )

    # Final normalization followed by a projection onto the vocabulary.
    return nn.Dense(
        nn.LayerNorm(x),
        vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))

# Transformer lenses
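Based on: https://arxiv.org/pdf/2002.08866.pdf. Each "lens" below reduces the transformer's per-position outputs `(batch, length, features)` to a single fixed-size representation vector per sequence.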

### Lens 1: Pooling

Mean pooling

In [0]:
class TransformerMeanPool(nn.Module):
  """Transformer Model + masked mean pooling for representations."""

  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            cache=None):
    """Encodes token sequences and mean-pools them into one vector each.

    Args:
      inputs: `(batch, len)` token ids (asserted 2-D); ids <= 0 are padding.
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input. NOTE: `shift_right` drops the last
        input position, so with shift=True that token never reaches the
        encoder.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      `(batch, features)` representation: the mean of the final-layer outputs
      over non-padding positions only. A sequence made entirely of padding
      yields a zero vector.
    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)
    x = inputs
    if shift:
      x = shift_right(x)
    x = x.astype('int32')

    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')

    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)

    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=True,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=not train,
          cache=cache,
      )

    x = nn.LayerNorm(x)

    # Average only over real tokens. Padding positions carry no content but
    # would otherwise dilute the mean. padding_mask is (batch, len, 1).
    token_count = jnp.sum(padding_mask, axis=1)  # (batch, 1)
    rep = jnp.sum(x * padding_mask, axis=1) / jnp.maximum(token_count, 1.0)

    return rep

Max pooling

In [0]:
class TransformerMaxPool(nn.Module):
  """Transformer Model + masked max pooling for representations."""

  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            cache=None):
    """Encodes token sequences and max-pools them into one vector each.

    Args:
      inputs: `(batch, len)` token ids (asserted 2-D); ids <= 0 are padding.
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input. NOTE: `shift_right` drops the last
        input position, so with shift=True that token never reaches the
        encoder.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      `(batch, features)` representation: the feature-wise max of the
      final-layer outputs over non-padding positions only. A sequence made
      entirely of padding yields a zero vector.
    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)
    x = inputs
    if shift:
      x = shift_right(x)
    x = x.astype('int32')

    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')

    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)

    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=True,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=not train,
          cache=cache,
      )

    x = nn.LayerNorm(x)

    # Exclude padding positions from the max. LayerNorm outputs can be
    # negative, so padding is set to -inf rather than 0.
    is_token = padding_mask > 0  # (batch, len, 1)
    rep = jnp.max(jnp.where(is_token, x, -jnp.inf), axis=1)
    # Rows with no real tokens would be all -inf; map them to zeros.
    rep = jnp.where(jnp.any(is_token, axis=1), rep, 0.)

    return rep

### Lens 2: Linear + ReLU + Max Pooling

In [0]:
class TransformerLinearMaxPool(nn.Module):
  """Transformer Model + linear layer + ReLU + masked max pooling."""

  def apply(self,
            inputs,
            vocab_size,
            rep_size=256,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            cache=None):
    """Encodes sequences, projects them linearly, and max-pools after a ReLU.

    Args:
      inputs: `(batch, len)` token ids (asserted 2-D); ids <= 0 are padding.
      vocab_size: size of the vocabulary
      rep_size: output width of the linear projection (representation size).
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input. NOTE: `shift_right` drops the last
        input position, so with shift=True that token never reaches the
        encoder.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      `(batch, rep_size)` non-negative representation: the feature-wise max of
      ReLU(Dense(encoder output)) over non-padding positions only. A sequence
      made entirely of padding yields a zero vector.
    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)
    x = inputs
    if shift:
      x = shift_right(x)
    x = x.astype('int32')

    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')

    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)

    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=True,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=not train,
          cache=cache,
      )

    x = nn.LayerNorm(x)

    x = nn.Dense(
        x,
        rep_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))

    x = nn.relu(x)

    # After the ReLU every value is >= 0, so zeroing padding positions removes
    # them from the max exactly. All-padding rows become 0.
    rep = jnp.max(x * padding_mask, axis=1)

    return rep

## Test models

In [0]:
def _create_model_with_cache(module_cls, key, input_shape, model_kwargs):
  """Initializes `module_cls` by shape, recording an autoregressive Cache.

  Args:
    module_cls: a flax `nn.Module` subclass (e.g. TransformerLM).
    key: PRNG key used for parameter initialization.
    input_shape: shape of the (float32) input used to trace initialization.
    model_kwargs: hyperparameters partially applied to `module_cls`.

  Returns:
    `(model, cache_def)`: the initialized model and the Cache definition that
    was populated while the model was being created.
  """
  module = module_cls.partial(**model_kwargs)
  # Run initialization inside a mutable Cache so the decoder layers register
  # their autoregressive cache entries for fast decoding.
  with nn.attention.Cache().mutate() as cache_def:
    # create_by_shape initializes the parameters from an input spec.
    _, model = module.create_by_shape(
        key, [(input_shape, jnp.float32)], cache=cache_def)
  return model, cache_def


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_language_model(key, input_shape, model_kwargs):
  """Creates the TransformerLM language model and its decoding Cache."""
  return _create_model_with_cache(TransformerLM, key, input_shape, model_kwargs)


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_meanpool_model(key, input_shape, model_kwargs):
  """Creates the mean-pooling representation model and its decoding Cache."""
  return _create_model_with_cache(
      TransformerMeanPool, key, input_shape, model_kwargs)


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_maxpool_model(key, input_shape, model_kwargs):
  """Creates the max-pooling representation model and its decoding Cache."""
  return _create_model_with_cache(
      TransformerMaxPool, key, input_shape, model_kwargs)


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_linearmaxpool_model(key, input_shape, model_kwargs):
  """Creates the linear + max-pooling representation model and its Cache."""
  return _create_model_with_cache(
      TransformerLinearMaxPool, key, input_shape, model_kwargs)

In [0]:
vocab_size = 20  # Toy vocabulary for the smoke tests below.

input_shape = (batch_size, max_target_length)

# Encoder hyperparameters shared by every model variant.
transformer_kwargs = dict(
    vocab_size=vocab_size,
    emb_dim=lm_emb_dim,
    num_heads=lm_num_heads,
    num_layers=lm_num_layers,
    qkv_dim=lm_qkv_dim,
    mlp_dim=lm_mlp_dim,
    max_len=max(max_target_length, max_eval_target_length),
)

# The linear lens needs the same settings plus the representation size
# (key order kept as: vocab_size, rep_size, then the shared settings).
transformer_linear_kwargs = {
    'vocab_size': vocab_size,
    'rep_size': rep_size,
    **{name: value for name, value in transformer_kwargs.items()
       if name != 'vocab_size'},
}

In [0]:
# generate a random sequence
random_seq_len = 32
# Reproducible random token batch of shape (batch_size, random_seq_len).
# Ids are drawn from [1, vocab_size). The models treat id 0 as padding
# (`inputs > 0` builds their padding mask), so a "random sequence" must
# not contain it.
random_seq = jnp.array(
    np.random.RandomState(random_seed).randint(
        1, vocab_size, size=(batch_size, random_seq_len), dtype=np.int32))

In [29]:
# language model
language_model, cache_def = create_language_model(init_rng, input_shape, transformer_kwargs)
logits = language_model(random_seq)
# One logit vector over the vocabulary for every input position.
assert logits.shape == random_seq.shape + (vocab_size,), logits.shape
logits, logits.shape

(DeviceArray([[[-4.5319790e-01, -1.3108684e+00,  2.1032448e+00, ...,
                -1.0460795e+00, -1.6000499e+00, -2.6734672e+00],
               [-3.0559456e-01, -2.4639273e+00,  2.3458025e+00, ...,
                -6.5893048e-01, -1.9393097e+00, -8.4083211e-01],
               [-4.4434652e-01, -1.5671092e+00,  3.6528590e+00, ...,
                 9.1124940e-01, -3.7526426e+00, -3.5367021e-01],
               ...,
               [-1.3982043e+00, -1.5023800e+00,  8.0621487e-01, ...,
                 8.2945955e-01, -9.7772032e-01, -1.6740191e+00],
               [-1.8099425e+00, -1.3558569e+00,  1.3833988e+00, ...,
                 1.2333223e+00, -1.7917389e-01, -5.7866496e-01],
               [-5.6542236e-01, -2.0965300e+00,  1.4960983e+00, ...,
                 9.0686068e-02, -9.1157734e-01,  5.9599954e-01]],
 
              [[-4.5319790e-01, -1.3108684e+00,  2.1032448e+00, ...,
                -1.0460795e+00, -1.6000499e+00, -2.6734672e+00],
               [-1.9407952e+00, -2.0996

In [30]:
# mean pool representation
meanpool_model, cache_def = create_meanpool_model(init_rng, input_shape, transformer_kwargs)
meanpool_rep = meanpool_model(random_seq)
# One emb_dim-sized vector per sequence.
assert meanpool_rep.shape == (random_seq.shape[0], lm_emb_dim), meanpool_rep.shape
assert bool(jnp.all(jnp.isfinite(meanpool_rep)))
meanpool_rep, meanpool_rep.shape

(DeviceArray([[ 1.349282  ,  0.16946445, -0.09409462, ..., -0.39269203,
               -0.03846568, -0.96833426],
              [ 1.42307   , -0.3462407 , -0.17818117, ...,  0.39621493,
               -0.36760882, -0.42438346],
              [ 0.9597972 , -0.34990683, -0.7753747 , ..., -0.04606063,
               -0.79815584, -0.7785816 ],
              ...,
              [ 1.516348  , -0.7187285 , -0.1575165 , ..., -0.19013198,
               -0.34239438, -0.8993825 ],
              [ 0.872891  , -0.6433126 , -0.87928617, ..., -0.26089457,
                0.07963648, -0.7745665 ],
              [ 1.5188763 , -0.1361511 , -0.00225492, ...,  0.19809611,
               -0.32082427, -1.0763489 ]], dtype=float32), (256, 512))

In [31]:
# max pool representation
maxpool_model, cache_def = create_maxpool_model(init_rng, input_shape, transformer_kwargs)
maxpool_rep = maxpool_model(random_seq)
# One emb_dim-sized vector per sequence, with no infinities from masking.
assert maxpool_rep.shape == (random_seq.shape[0], lm_emb_dim), maxpool_rep.shape
assert bool(jnp.all(jnp.isfinite(maxpool_rep)))
maxpool_rep, maxpool_rep.shape

(DeviceArray([[ 2.5953405 ,  1.194404  ,  0.8714977 , ...,  1.4238741 ,
                0.75553006,  0.47919446],
              [ 2.6580777 ,  0.63564575,  1.0401169 , ...,  2.014348  ,
                0.5694527 ,  0.92812794],
              [ 2.1212747 ,  0.58689904,  0.36625844, ...,  0.9152341 ,
               -0.04691934,  0.37015328],
              ...,
              [ 3.0935042 ,  0.37259296,  1.121785  , ...,  1.2526655 ,
                0.831916  ,  0.27077362],
              [ 2.3825467 ,  0.5376274 ,  0.20739071, ...,  1.0502782 ,
                0.8913085 ,  0.6887138 ],
              [ 2.5058422 ,  1.1807996 ,  1.4759212 , ...,  1.0194496 ,
                0.54213744,  0.19198617]], dtype=float32), (256, 512))

In [32]:
# linear + ReLU + max pool representation
linearmaxpool_model, cache_def = create_linearmaxpool_model(init_rng, input_shape, transformer_linear_kwargs)
linearmaxpool_rep = linearmaxpool_model(random_seq)
# One rep_size-sized vector per sequence; ReLU before pooling => non-negative.
assert linearmaxpool_rep.shape == (random_seq.shape[0], rep_size), linearmaxpool_rep.shape
assert bool(jnp.all(linearmaxpool_rep >= 0))
linearmaxpool_rep, linearmaxpool_rep.shape

(DeviceArray([[0.844927  , 2.8515341 , 1.1309768 , ..., 0.92408   ,
               1.3553803 , 2.94846   ],
              [1.2444927 , 2.42132   , 0.48922083, ..., 1.1925056 ,
               1.6753874 , 2.1524806 ],
              [0.8767196 , 2.9770274 , 1.1056776 , ..., 0.5257038 ,
               1.0089328 , 2.5957775 ],
              ...,
              [0.85885113, 2.5804234 , 0.8209575 , ..., 1.3480121 ,
               1.244442  , 2.6098092 ],
              [1.1995897 , 2.4142926 , 0.88203317, ..., 0.45323965,
               0.56194437, 2.6673195 ],
              [1.1544257 , 2.6171696 , 0.49701124, ..., 1.037597  ,
               1.9722457 , 2.0675519 ]], dtype=float32), (256, 256))