**Open this notebook in Colab: https://colab.research.google.com/github/google/sequence-layers/blob/main/notebooks/intro.ipynb**

-----

SequenceLayers is a new design for neural network layer APIs that enables better sequence processing and streaming inference.

**Library:** https://github.com/google/sequence-layers

**Whitepaper:** https://github.com/google/sequence-layers/blob/main/tech-report.pdf

# Setup

In [None]:
!pip install sequence-layers==0.1

In [None]:
import flax
import jax
import jax.numpy as jnp
import numpy as np

import sequence_layers.jax as sl
import sequence_layers.jax.test_utils as sl_test_utils
import sequence_layers.jax.utils as sl_utils

# PRNG key passed to `init` / `init_with_output` in the cells below.
key = jax.random.key(0)
# Short aliases for helpers that are used repeatedly in this notebook.
random_sequence = sl_test_utils.random_sequence
unbox = flax.core.meta.unbox

# The Sequence Type

In [None]:
#@title A `Sequence` is a PyTree containing `values` and a `mask`. {vertical-output: true}
# Values can have any dtype, while mask must be bool.
# Values and mask must have a batch and time dimension.

# Here values has batch=2, time=3 and 5 channels; the all-True mask marks
# every timestep as valid.
x = sl.Sequence(
  values=jnp.ones((2, 3, 5)),
  mask=jnp.ones((2, 3), jnp.bool_)
)
# Display the sequence.
x

In [None]:
#@title You can mask values to zero with `mask_invalid()`. {vertical-output: true}

# The mask marks the trailing timesteps of each batch item as invalid
# (one for the first item, two for the second).
x = sl.Sequence(
  jnp.ones((2, 3, 5)),
  jnp.asarray([[True, True, False],
               [True, False, False]]))
x = x.mask_invalid()

# Masking twice is a no-op: the already-masked sequence is returned as-is,
# which is why an identity (`is`) check is used here.
assert x.mask_invalid() is x

# Display the masked sequence.
x

In [None]:
# @title Constructing a sequence from values. {vertical-output: true}

# When every value is valid, you can use from_values to create an all-True mask.
x = sl.Sequence.from_values(jnp.ones((2, 3, 5)))

# Display the sequence, including the mask that from_values created.
x

In [None]:
# @title Constructing a left-aligned sequence from values and lengths. {vertical-output: true}

# When the valid timesteps of each batch item come first (left-aligned), you
# can use from_lengths to build the mask from one length per batch item.
x = sl.Sequence.from_lengths(jnp.ones((2, 3, 5)), jnp.array([2, 1]))

# from_lengths does not mask the values for you.
print(x)
print()

# The lengths() method tells you how many valid timepoints are in the sequence.
print(f'{x.lengths()=}')
print()

# Zero out the values at invalid timesteps, then display the result.
x = x.mask_invalid()
x

In [None]:
#@title A `Sequence` has a `channel_spec`, describing its inner dimensions (minus batch and time).

# SequenceLayers process Sequences along the time dimension.
# Sequences can represent any data with a time dimension.

# Each example below builds a fully valid sequence and prints its shape
# properties next to its channel_shape / channel_spec for comparison.

print('Audio data: [batch, time, features]:')
x = sl.Sequence(jnp.ones((2, 3, 5)),
                jnp.ones((2, 3), jnp.bool_))
print(f'{x.shape=} {x.ndim=} {x.dtype=}')
print(f'{x.channel_shape=}')
print(f'{x.channel_spec=}')
print()

# Token sequences have no dimensions beyond batch and time.
print('Text data: [batch, time] tokens:')
x = sl.Sequence(jnp.ones((2, 3), jnp.int32),
                jnp.ones((2, 3), jnp.bool_))
print(f'{x.shape=} {x.ndim=} {x.dtype=}')
print(f'{x.channel_shape=}')
print(f'{x.channel_spec=}')
print()

# Video has three dimensions after batch and time; the mask is still [batch, time].
print('Video data (RGB images over time): [batch, time, height, width, channels]:')
x = sl.Sequence(jnp.ones((2, 3, 32, 32, 3), jnp.int8),
                jnp.ones((2, 3), jnp.bool_))
print(f'{x.shape=} {x.ndim=} {x.dtype=}')
print(f'{x.channel_shape=}')
print(f'{x.channel_spec=}')

# The SequenceLayer Type

In [None]:
#@title `sl.SequenceLayerConfig` objects are declarative network specifications.

# Two Dense layers with 8 units each, wrapped in a Serial config. Only the
# first has an activation (ReLU). The layer itself is built in the next cell
# with `make()`.
config: sl.SequenceLayerConfig = sl.Serial.Config([
  sl.Dense.Config(8, name='dense1', activation=jax.nn.relu),
  sl.Dense.Config(8, name='dense2')
])

In [None]:
#@title Construct a layer from a `SequenceLayerConfig` with `make()`.
# `config` is the Serial(Dense -> Dense) specification from the previous cell.
model: sl.SequenceLayer = config.make()

In [None]:
# @title In JAX, `SequenceLayer` is just a Flax layer.

# Get variables for a layer with `nn.Module.init`, just like Flax.
batch_size, time, channels = 2, 3, 5

# A fully valid input sequence with uniformly random values.
x = sl.Sequence(
    jax.random.uniform(key, (batch_size, time, channels)),
    jnp.ones((batch_size, time), dtype=jnp.bool_),
)

# Initialize the variables, then print their shapes and types. The variables
# are passed through `unbox` (flax.core.meta.unbox) before printing.
model_vars = model.init(key, x, training=False)
sl_utils.pprint_tree_shapes_types(unbox(model_vars))

# Bind the model to the variables for imperative demonstration.
model = model.bind(model_vars)

In [None]:
# Display the input sequence that the next cells feed to the model.
x

In [None]:
#@title A `SequenceLayer` has a `layer` and a `step` method.

# Process x layer-wise.
# `y` is kept so the following cells can compare it with the step-wise outputs.
y = model.layer(x, training=False)
print(y)

In [None]:
# Process x step-wise, one timestep per call.
# x has only 3 timesteps, so each slice below stays inside the sequence;
# slicing past the end would silently feed shorter (or empty) blocks to step().
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
y0, state = model.step(x[:, 0:1], state, training=False)
y1, state = model.step(x[:, 1:2], state, training=False)
y2, state = model.step(x[:, 2:3], state, training=False)
# Show the step-wise outputs; the next cell compares them with the layer-wise y.
print(y0)
print(y1)
print(y2)

In [None]:
# layer() and step() are required to produce identical results.
# The step-wise outputs are concatenated along the time axis (axis=1) before
# being compared with the layer-wise output.
np.testing.assert_array_almost_equal(
    y.values,
    jnp.concatenate([y0.values, y1.values, y2.values], axis=1))

# Only reached when the assertion above passes.
print('Layer and step produced identical results:')
print(y)

# The SequenceLayer Contract

`SequenceLayer`s must obey the following contract:

*   **Layer-wise and step-wise equivalence**: If `SequenceLayer.supports_step`
    is true, `SequenceLayer.layer` and `SequenceLayer.step` must produce
    identical results when fed identical data and starting state (slicing the
    data up into blocks of multiples of `SequenceLayer.block_size` timesteps).

    *   Stateful stochastic layers (e.g. `Dropout`) should obey this property if
        RNG state were made deterministic.

*   **Padding and batching invariance**: `SequenceLayer.layer` and
    `SequenceLayer.step` must produce identical results when fed identical data
    with differing amounts of end padding, or when the position of examples in a
    batch is shuffled. For the common use-case of batching contiguous sequences
    of mixed lengths together, the lengths of other sequences in the batch or
    the position in the batch should have no bearing on the calculation
    performed by the layer.

    Said another way, the physical dimensions of the input `Sequence` `values`
    (`[b, t, ...]`), i.e. the batch size `b` and the length `t`, must have no
    impact on the resulting computation for an individual sequence in the
    batch. For example, adding arbitrary amounts of padding to the end of the
    input `values` must not change the valid portions of the returned sequence.

    *   **Important note:** Padding invariance is currently only required for
        end padding. Start padding or interior padding (for non-contiguous
        sequences) does affect the behavior of calculations.

    *   **Corollary:** Padding values must not affect the calculation of
        non-padding values.

*   **Masked inputs and outputs**: For an input `Sequence` provided to a
    `SequenceLayer` with `values` (`[b, t, ...]`) and `mask` (`[b, t]`),
    `SequenceLayer`s **must not assume `values` is masked** (a Sequence is
    masked if `values[mask == False] == 0.0`). If the computation performed by
    the layer mixes information across timesteps, then the layer must mask the
    sequence before use. The layer may return either a `Sequence` or a
    `MaskedSequence`.

Each `SequenceLayer` offered in the `sl.` namespace has unit tests verifying
that it obeys this contract. You can test that your own layers obey this
contract with
[`test_utils.verify_contract`](https://github.com/google/sequence-layers/blob/7a67779f5b8af2b904a1b4aab9b846dd4d6801ae/sequence_layers/jax/test_utils.py#L846C7-L846C22).

# SequenceLayers: State Management for Free

In [None]:
#@title What is `state`? An empty PyTree for our `Dense -> Dense` layer. {vertical-output: true}

# `model` is still the bound Dense -> Dense model from the cells above.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
sl_utils.pprint_tree_shapes_types(state)

In [None]:
#@title Let's define a "stateful" `SequenceLayer`. What is its state? {vertical-output: true}
#@markdown State for two temporal convolutions is two input buffers of length `kernel_size - 1`.

# Two causal 1D convolutions applied in series, with strides 2 and 4.
config = sl.Serial.Config([
  sl.Conv1D.Config(8, kernel_size=4, strides=2, padding='causal', name='conv1'),
  sl.Conv1D.Config(8, kernel_size=6, strides=4, padding='causal', name='conv2')
])

# Build, initialize and bind the model (this rebinds the notebook-global `model`).
model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# Print the shapes and types of the initial step state.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
sl_utils.pprint_tree_shapes_types(state)

In [None]:
#@title (inspecting properties of a layer) {vertical-output: true}
# Properties of the two-convolution model bound in the previous cell.
print(f"{model.output_latency=}")
print(f"{model.output_ratio=}")
print(f"{model.receptive_field=}")

In [None]:
# @title Let's define a "stateful" `SequenceLayer`: Self Attention. {vertical-output: true}
# @markdown State for a self attention layer is a KV cache.

# Self attention with 8 heads of 2 units each and a past horizon of 32
# timesteps, followed by an einsum projection to 16 output channels.
config: sl.SequenceLayerConfig = sl.Serial.Config([
    sl.DotProductSelfAttention.Config(
        units_per_head=2, num_heads=8, max_past_horizon=32
    ),
    sl.EinsumDense.Config('...nh,nhd->...d', output_shape=[16]),
])

# Build, initialize and bind the model (this rebinds the notebook-global `model`).
model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# Print the shapes and types of the initial step state.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
sl_utils.pprint_tree_shapes_types(state)

In [None]:
#@title Let's define a "stateful" `SequenceLayer`: LSTM {vertical-output: true}
#@markdown State for an LSTM layer is the fixed size state array.

# A single LSTM layer with 32 units.
config = sl.LSTM.Config(units=32)
model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# Print the shapes and types of the initial step state.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
sl_utils.pprint_tree_shapes_types(state)

In [None]:
#@title A SequenceLayer has a `block_size` and an `output_ratio`. {vertical-output: true}

# Same two-convolution model as in the earlier stateful example.
config = sl.Serial.Config([
  sl.Conv1D.Config(8, kernel_size=4, strides=2, padding='causal', name='conv1'),
  sl.Conv1D.Config(8, kernel_size=6, strides=4, padding='causal', name='conv2')
])

model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# Two convolutions of stride 2 and 4 means we must feed 8 inputs to get 1 output,
# and we must feed inputs in multiples of 8 to the `step` function.
print('output_ratio: The number of output timesteps for one input as a fraction.')
print(f'{model.output_ratio=}\n')
print('block_size: Multiple of timesteps required as input to step.')
print(f'{model.block_size=}\n')

state = model.get_initial_state(batch_size, x.channel_spec, training=False)

# Step requires a sequence whose physical length is a multiple of the block size.
# 24 timesteps is three blocks of 8.
x = random_sequence(2, 24, *x.channel_shape)
y, state = model.step(x, state, training=False)

# The output ratio determines the physical length of the resulting sequence.
print(f'Input sequence: {x.shape} -> Output sequence: {y.shape}')
assert y.shape[1] == x.shape[1] * model.output_ratio

# A time dimension of 23 is not a multiple of block_size, so the step fails.
# The ValueError is caught and printed so the cell still completes.
x = random_sequence(2, 23, *x.channel_shape)
try:
  model.step(x, state, training=False)
except ValueError as e:
  print(f'Input sequence: {x.shape} -> Error {e}')

In [None]:
#@title Stepping with multiples of `block_size` at a time.

# We have to feed inputs to step 8 timesteps at a time, since model.block_size is 8.
# Any multiple of block_size is supported as an input to step.
# If block_size > 1, causality is still preserved.

# 32 timesteps: four blocks of 8, or two blocks of 16.
batch_size, time, channels = 2, 32, 5

x = sl.Sequence(
  jax.random.uniform(key, (batch_size, time, channels)),
  jnp.ones((batch_size, time), dtype=jnp.bool_))

# Process x layer-wise.
y = model.layer(x, training=False)

# Process x step-wise in blocks of block_size.
# The state is threaded from each step call into the next.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
y0, state = model.step(x[:, 0:8], state, training=False)
y1, state = model.step(x[:, 8:16], state, training=False)
y2, state = model.step(x[:, 16:24], state, training=False)
y3, state = model.step(x[:, 24:32], state, training=False)

# The step outputs, concatenated along time, must match the layer-wise output.
np.testing.assert_array_almost_equal(
    y.values,
    jnp.concatenate([y0.values, y1.values, y2.values, y3.values], axis=1))

# Process x step-wise in blocks of `2 * block_size`.
# Start again from a fresh initial state.
state = model.get_initial_state(batch_size, x.channel_spec, training=False)
y0, state = model.step(x[:, 0:16], state, training=False)
y1, state = model.step(x[:, 16:32], state, training=False)

np.testing.assert_array_almost_equal(
    y.values,
    jnp.concatenate([y0.values, y1.values], axis=1))

# Only reached when both assertions above pass.
print('Layer and step produced identical results:')
print(y)

In [None]:
# @title `constants`: Side inputs to SequenceLayers. {vertical-output: true}
# @markdown `get_initial_state`, `step` and `layer` have an optional `constants: dict[str, jax.Array | sl.Sequence]` argument. This is used to provide side-inputs to layers, for example conditioning information, source sequences for cross attention, control parameters (CFG scale, temperature if sampling happens internally to a SequenceLayer), etc.

# Cross attention that looks up its source sequence in `constants` under the
# key given by `source_name`.
config = sl.DotProductAttention.Config(
    source_name='source', num_heads=8, units_per_head=3
)

x = random_sequence(2, 16, 3)
source = random_sequence(2, 7, 5)

# The dictionary key must match `source_name` in the config above.
constants = {'source': source}
print('We provide the source for cross attention in constants:')
sl_utils.pprint_tree_shapes_types(constants)

model = config.make()
model_vars = model.init(key, x, training=False, constants=constants)
model = model.bind(model_vars)

print()
print(
    'State contains pre-computed KV caches for the cross attention to source:'
)
state = model.get_initial_state(
    batch_size, x.channel_spec, training=False, constants=constants
)
sl_utils.pprint_tree_shapes_types(state)

# The same constants are passed to both the layer-wise and step-wise calls.
y = model.layer(x, training=False, constants=constants)
y_step, _, _ = sl_utils.step_by_step_dynamic(
    model, x, training=False, constants=constants
)

np.testing.assert_array_almost_equal(y.values, y_step.values)

# Print the result after the message, as the other equivalence cells do.
print('Layer and step produced identical results:')
print(y)

# Input and Output Latency

`SequenceLayer`s have both an `input_latency` and `output_latency` property
describing the latency properties of their step-wise behavior in relation to
their layer-wise behavior.

*   **Input Latency**: An `int` denoting the number of input timesteps before
    the step-wise output of the layer matches its layer-wise output.

*   **Output Latency**: A `fractions.Fraction`, the number of output timesteps
    before the step-wise output of the layer matches its layer-wise output.

An invariant that all layers must maintain is that the step-wise output is
equivalent to the layer-wise output **after** appending an additional
`input_latency` invalid timesteps at the end of the step-wise input, and
dropping the initial `output_latency` timesteps from the step-wise output. The
initial `output_latency` timesteps must be invalid (`mask = False`) to avoid
accidental use by consumers.

In [None]:
#@title Convolution with lookahead. {vertical-output: true}

x = random_sequence(2, 18, 1, random_lengths=False)
config = sl.Conv1D.Config(1, kernel_size=5, strides=1, padding='reverse_causal')

model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

y_layer = model.layer(x, training=False).mask_invalid()
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x, training=False)
y_step = y_step.mask_invalid()

#@markdown A `reverse_causal` padded kernel size 5 convolution has a lookahead of 4 timesteps.
print(f'Input latency: {model.input_latency}')
print(f'Output latency: {model.output_latency}')
print()
print('The step-wise output does not match the layer-wise output. There are 4 invalid timesteps at the start, and not all outputs are produced!')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

y_step, _, _ = sl_utils.step_by_step_dynamic(model, x.pad_time(0, model.input_latency, valid=False), training=False)
y_step = y_step[:, int(model.output_latency):]
y_step = y_step.mask_invalid()

print()
print('By padding the input and trimming the output, we achieve layer/step equivalence:')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

np.testing.assert_allclose(y_layer.values, y_step.values, atol=1e-6, rtol=1e-6)

In [None]:
#@title Self attention with lookahead. {vertical-output: true}

x = random_sequence(2, 18, 1, random_lengths=False)
# Self attention that can see 5 timesteps into the past and 5 into the future.
config = sl.Serial.Config([
    sl.DotProductSelfAttention.Config(num_heads=1, units_per_head=1, max_past_horizon=5, max_future_horizon=5),
    sl.Flatten.Config()
])

model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# First attempt: run layer-wise and step-wise on the same, unpadded input.
y_layer = model.layer(x, training=False).mask_invalid()
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x, training=False)
y_step = y_step.mask_invalid()

print(f'Input latency: {model.input_latency}')
print(f'Output latency: {model.output_latency}')
print()
print('The step-wise output does not match the layer-wise output. There are 5 invalid timesteps at the start, and not all outputs are produced!')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

# Second attempt: append `input_latency` invalid timesteps to the input, then
# drop the first `output_latency` timesteps of the step-wise output.
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x.pad_time(0, model.input_latency, valid=False), training=False)
y_step = y_step[:, int(model.output_latency):]
y_step = y_step.mask_invalid()

print()
print('By padding the input and trimming the output, we achieve layer/step equivalence:')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

np.testing.assert_allclose(y_layer.values, y_step.values, atol=1e-6, rtol=1e-6)

In [None]:
#@title `sl.Delay`, an example of `input_latency != output_latency`. {vertical-output: true}

x = random_sequence(2, 18, 1, random_lengths=False)
config = sl.Delay.Config(3)

model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# First attempt: run layer-wise and step-wise on the same, unpadded input.
y_layer = model.layer(x, training=False).mask_invalid()
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x, training=False)
y_step = y_step.mask_invalid()

#@markdown `sl.Delay` delays both the layer-wise output and step-wise output by default. As a result, the output latency is zero (no trimming is required for the step-wise output to match the layer-wise output), but the input latency is > 0 (flushing required to consume all input).
print(f'Input latency: {model.input_latency}')
print(f'Output latency: {model.output_latency}')
print()
print('The step-wise output does not match the layer-wise output. Not all outputs are produced without flushing.')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

# Second attempt: flush by appending `input_latency` invalid timesteps. The
# output trim is kept for symmetry with the other latency cells; per the note
# above, the output latency is zero here.
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x.pad_time(0, model.input_latency, valid=False), training=False)
y_step = y_step[:, int(model.output_latency):]
y_step = y_step.mask_invalid()

print()
print('By padding the input, we achieve layer/step equivalence:')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

np.testing.assert_allclose(y_layer.values, y_step.values, atol=1e-6, rtol=1e-6)

In [None]:
#@title `sl.Lookahead`, an example of `input_latency != output_latency`. {vertical-output: true}

x = random_sequence(2, 18, 1, random_lengths=False)
config = sl.Lookahead.Config(3)

model = config.make()
model_vars = model.init(key, x, training=False)
model = model.bind(model_vars)

# First attempt: run layer-wise and step-wise on the same, unpadded input.
y_layer = model.layer(x, training=False).mask_invalid()
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x, training=False)
y_step = y_step.mask_invalid()

#@markdown `sl.Lookahead` truncates both the layer-wise output and step-wise output by default. As a result, the input latency is zero (no flushing is required for the step-wise output to match the layer-wise output), but the output latency is > 0 (trimming of the output is required).
print(f'Input latency: {model.input_latency}')
print(f'Output latency: {model.output_latency}')
print()
print('The step-wise output does not match the layer-wise output. The step-wise output needs trimming at the front.')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

# Second attempt: drop the first `output_latency` output timesteps. The input
# padding is kept for symmetry with the other latency cells; per the note
# above, the input latency is zero here.
y_step, _, _ = sl_utils.step_by_step_dynamic(model, x.pad_time(0, model.input_latency, valid=False), training=False)
y_step = y_step[:, int(model.output_latency):]
y_step = y_step.mask_invalid()

print()
print('By trimming the output, we achieve layer/step equivalence:')
print(y_layer.values[0, :, 0])
print(y_step.values[0, :, 0])

np.testing.assert_allclose(y_layer.values, y_step.values, atol=1e-6, rtol=1e-6)

# Examples

Let's build some example `SequenceLayer` models.

In [None]:
#@title Example: Streamable Log mel Spectrogram Frontend

# Form parameters. `sample_rate` is fixed; the others are editable in Colab.
sample_rate = 24000.0
frame_length = 1200  #@param { type: "integer" }
frame_step = 300  #@param { type: "integer" }
fft_length = 2048  #@param { type: "integer" }
num_mel_bins = 128  #@param { type: "integer" }
lower_edge_hertz = 20.0  #@param { type: "number" }
upper_edge_hertz = 12000.0  #@param { type: "number" }
log_offset = 1e-6  #@param { type: "number" }
time_padding = 'causal'  #@param ['reverse_causal', 'causal', 'valid']
fft_padding = 'center'  #@param ['right', 'center']

# STFT magnitude -> mel filterbank -> log(mel + log_offset).
frontend = sl.Serial.Config([
  sl.STFT.Config(frame_length,
          frame_step,
          fft_length,
          output_magnitude=True,
          fft_padding=fft_padding,
          time_padding=time_padding),
  sl.LinearToMelSpectrogram.Config(num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz),
  # Use the `log_offset` form parameter (previously a hard-coded 1e-6, so
  # editing the form field had no effect).
  sl.Add.Config(log_offset),  # log offset to avoid blowup.
  sl.Log.Config(),
]).make()

In [None]:
#@title Example: Transformer Encoder

def SelfAttentionNetwork(
    model_dimension: int,
    dropout_rate: float,
    num_heads: int,
    units_per_head: int,
    max_past_horizon: int,
    max_future_horizon: int,
):
  """Residual self-attention module.

  Args:
    model_dimension: Output size of the projection applied after attention.
    dropout_rate: Rate used for both the attention-probability dropout and the
      dropout applied after the output projection.
    num_heads: Number of attention heads.
    units_per_head: Number of units per attention head.
    max_past_horizon: Forwarded unchanged to `sl.DotProductSelfAttention`.
    max_future_horizon: Forwarded unchanged to `sl.DotProductSelfAttention`.

  Returns:
    An `sl.Residual.Config` named 'self_attention'.
  """
  return sl.Residual.Config([
    sl.RMSNormalization.Config(name='rms_norm'),
    sl.DotProductSelfAttention.Config(
      units_per_head=units_per_head,
      num_heads=num_heads,
      max_past_horizon=max_past_horizon,
      max_future_horizon=max_future_horizon,
      use_bias=False,
      # Use RoPE for the queries and keys.
      query_network=sl.ApplyRotaryPositionalEncoding.Config(max_wavelength=10000),
      key_network=sl.ApplyRotaryPositionalEncoding.Config(max_wavelength=10000),
      attention_probabilities_dropout_rate=dropout_rate,
      broadcast_dropout_across_queries=True,
      name='attention'),
    sl.DenseShaped.Config([model_dimension], use_bias=False, name='output_projection'),
    sl.Dropout.Config(dropout_rate)
  ], name='self_attention')

def FeedForwardNetwork(model_dimension: int, ffn_dim: int, dropout_rate: float):
  """Builds the config for a residual, gated (GeGLU) feed-forward module.

  Args:
    model_dimension: Output size of the second dense layer.
    ffn_dim: Half the output size of the first dense layer, which feeds the
      gated unit.
    dropout_rate: Rate for the dropout after the gated unit and after the
      second dense layer.

  Returns:
    An `sl.Residual.Config` named 'ffn'.
  """
  normalization = sl.RMSNormalization.Config(name='rms_norm')
  expand = sl.Dense.Config(ffn_dim * 2, use_bias=False, name='dense1')
  gate = sl.GatedUnit.Config(jax.nn.gelu, None)  # GeGLU
  contract = sl.Dense.Config(model_dimension, use_bias=False, name='dense2')
  body = [
    normalization,
    expand,
    gate,
    sl.Dropout.Config(dropout_rate),
    contract,
    sl.Dropout.Config(dropout_rate),
  ]
  return sl.Residual.Config(body, name='ffn')

def Transformer(model_dimension: int,
                num_layers: int,
                num_heads: int,
                units_per_head: int,
                ffn_dim: int,
                dropout_rate: float,
                max_past_horizon: int,
                max_future_horizon: int) -> sl.SequenceLayerConfig:
  """Builds the config for a stack of self-attention + feed-forward blocks.

  The input gets a timing signal and is projected to `model_dimension`. The
  transformer block is then repeated `num_layers` times, and the result is
  RMS-normalized.
  """
  # One transformer block: self attention followed by the feed-forward module.
  attention = SelfAttentionNetwork(
      model_dimension, dropout_rate, num_heads, units_per_head,
      max_past_horizon, max_future_horizon)
  feed_forward = FeedForwardNetwork(model_dimension, ffn_dim, dropout_rate)
  block = sl.Serial.Config([attention, feed_forward], name='transformer_block')

  # Project to model_dimension for Repeat.
  input_projection = sl.Dense.Config(
      model_dimension, use_bias=False, activation=jax.nn.relu,
      name='input_projection')
  stack = sl.Repeat.Config(
      block, num_repeats=num_layers, name='transformer_blocks')

  return sl.Serial.Config([
      sl.AddTimingSignal.Config(),
      input_projection,
      stack,
      sl.RMSNormalization.Config(name='output_rms_norm'),
  ], name='transformer')


# Build a 2-layer transformer and print the shapes and types of its variables.
layer = Transformer(
    model_dimension=1024,
    num_layers=2,
    num_heads=16,
    units_per_head=64,
    ffn_dim=4096,
    dropout_rate=0.1,
    # Unmasked self attention.
    max_past_horizon=-1,
    max_future_horizon=-1
).make()
# NOTE: this rebinds the notebook-global `key` used by later cells.
key = jax.random.PRNGKey(42)
x = random_sequence(2, 32, 5)
layer_vars = layer.init(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)

In [None]:
#@title Example: Causal Conformer

def MultiHeadedSelfAttention(
    hidden_size: int,
    num_heads: int,
    max_horizon: int,
    dropout_rate: float,
    name: str,
):
  """Builds the config for a residual, causal local self-attention module.

  `hidden_size` is split evenly across `num_heads` (integer division).
  `max_horizon` is used as both the attention block size and the past
  horizon; the future horizon is 0.
  """
  attention = sl.LocalDotProductSelfAttention.Config(
    num_heads=num_heads,
    block_size=max_horizon,
    units_per_head=hidden_size // num_heads,
    max_past_horizon=max_horizon,
    max_future_horizon=0,
    attention_probabilities_dropout_rate=dropout_rate,
  )
  branch = [
    sl.LayerNormalization.Config(),
    attention,
    sl.DenseShaped.Config([hidden_size]),
    sl.Dropout.Config(dropout_rate),
  ]
  return sl.Residual.Config(branch, name=name)


def FeedForwardModule(hidden_size: int, dropout_rate: float, name: str):
  """Builds the config for a residual feed-forward module scaled by 0.5.

  The branch expands to `4 * hidden_size` with a swish activation, projects
  back to `hidden_size`, and ends with a scale of 0.5.
  """
  expand = sl.Dense.Config(4 * hidden_size, activation=jax.nn.swish)
  project = sl.Dense.Config(hidden_size)
  branch = [
    sl.LayerNormalization.Config(),
    expand,
    sl.Dropout.Config(dropout_rate),
    project,
    sl.Dropout.Config(dropout_rate),
    sl.Scale.Config(0.5),
  ]
  return sl.Residual.Config(branch, name=name)


def ConvolutionModule(hidden_size: int, dropout_rate: float, name: str):
  """Builds the config for a residual, causal depthwise-convolution module.

  Layer order: layer norm, dense to `2 * hidden_size`, gated linear unit,
  causal depthwise convolution (kernel size 32), batch norm, swish, dense back
  to `hidden_size`, dropout.
  """
  gating = [
    sl.Dense.Config(2 * hidden_size),
    sl.GatedLinearUnit.Config(),
  ]
  convolution = [
    sl.DepthwiseConv1D.Config(kernel_size=32, padding='causal'),
    sl.BatchNormalization.Config(),
    sl.Swish.Config(),
  ]
  output = [
    sl.Dense.Config(hidden_size),
    sl.Dropout.Config(dropout_rate),
  ]
  branch = [sl.LayerNormalization.Config(), *gating, *convolution, *output]
  return sl.Residual.Config(branch, name=name)


def ConformerBlock(
    hidden_size: int,
    dropout_rate: float,
    max_horizon: int,
    name: str,
    num_heads: int = 8,
):
  """Builds the config for one conformer block.

  Layer order: feed-forward, self attention, convolution, feed-forward, and a
  final layer normalization.
  """
  feedforward_start = FeedForwardModule(
      hidden_size, dropout_rate, name='feedforward_start')
  attention = MultiHeadedSelfAttention(
      hidden_size, num_heads, max_horizon, dropout_rate, name='mhsa')
  convolution = ConvolutionModule(hidden_size, dropout_rate, name='lconv')
  feedforward_end = FeedForwardModule(
      hidden_size, dropout_rate, name='feedforward_end')
  return sl.Serial.Config([
    feedforward_start,
    attention,
    convolution,
    feedforward_end,
    sl.LayerNormalization.Config(),
  ], name=name)


def ConvolutionSubsampling(hidden_size: int):
  """Builds the config for the convolutional subsampling front end.

  A trailing dimension is added to the input, then two identical stages are
  applied (stride-2 causal Conv2D, cumulative group norm, ReLU).
  """

  def subsampling_stage():
    # One stride-2 stage; fresh config objects are created on every call.
    return [
      sl.Conv2D.Config(
          filters=hidden_size,
          kernel_size=3,
          strides=2,
          time_padding='causal',
      ),
      sl.GroupNormalization.Config(num_groups=1, cumulative=True),
      sl.Relu.Config(),
    ]

  # "Convolutional subsampling". Reduce rate by 4x.
  layers = [sl.ExpandDims.Config(-1)]
  layers.extend(subsampling_stage())
  layers.extend(subsampling_stage())
  return sl.Serial.Config(layers, name='convolutional_subsampling')


def ConformerEncoder(
    hidden_size: int,
    num_blocks: int,
    max_horizon: int,
    dropout_rate: float = 0.1,
    name: str | None = None,
):
  return sl.Serial.Config([
    ConvolutionSubsampling(hidden_size),
    sl.DenseShaped.Config([hidden_size]),
    sl.AddTimingSignal.Config(),
    sl.Dropout.Config(dropout_rate),
    sl.Repeat.Config(ConformerBlock(
        hidden_size,
        dropout_rate,
        max_horizon,
        name='conformer_block'
    ), num_repeats=num_blocks),
  ], name=name or 'conformer')

# Build an 8-block conformer encoder and print the shapes and types of its variables.
layer = ConformerEncoder(1024, num_blocks=8, max_horizon=128).make()

x = random_sequence(2, 32, 128)
layer_vars = layer.init(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)

# Layer-wise reference output.
y = layer.layer(x, training=False)

# Step-wise: feed the 32 input timesteps as 8 blocks of 4.
# NOTE(review): 4 presumably equals layer.block_size (two stride-2
# convolutions in the subsampling front end) — confirm.
state = layer.get_initial_state(2, x.channel_spec, training=False)
ys = []
for i in range(8):
  yi, state = layer.step(x[:, i * 4 : (i+1) * 4], state, training=False)
  ys.append(yi)
y_step = sl.Sequence.concatenate_sequences(ys)

# Looser tolerance (1e-4) than the earlier equivalence checks in this notebook.
np.testing.assert_allclose(y.values, y_step.values, atol=1e-4, rtol=1e-4)


# Combinators

Combinators are `SequenceLayer`s that allow you to compose or combine other layers using common patterns (e.g. serial and parallel computation). These layers enable you to build complex models without having to resort to creating custom layers.

In [None]:
#@title Combinators: Serial. {vertical-output: true}

#@markdown A serial combinator applies the layers provided to it sequentially.

# A simple Conv-BN-Relu block. Layers are applied in order.
layer = sl.Serial.Config([
  sl.Conv2D.Config(8, kernel_size=[5, 3], strides=[2, 1], time_padding='causal', spatial_padding='same', name='conv'),
  sl.BatchNormalization.Config(name='bn'),
  sl.Relu.Config(),
], name='conv_bn_relu').make()

x = random_sequence(2, 32, 5, 1)
# init_with_output returns the layer output together with the new variables.
y, layer_vars = layer.init_with_output(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)
print(f'Input: {x.shape}')
print(f'Output: {y.shape}')


In [None]:
#@title Combinators: Residual. {vertical-output: true}

#@markdown A residual combinator applies the layers provided to it sequentially, then adds the input back to the output.

# A simple residual Conv-BN-Relu block.
# The input has 8 channels, matching the 8 convolution filters.
layer = sl.Residual.Config([
  sl.Conv2D.Config(8, kernel_size=[5, 3], strides=1, time_padding='causal', spatial_padding='same', name='conv'),
  sl.BatchNormalization.Config(name='bn'),
  sl.Relu.Config(),
], name='conv_bn_relu').make()

x = random_sequence(2, 32, 5, 8)
# init_with_output returns the layer output together with the new variables.
y, layer_vars = layer.init_with_output(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)
print(f'Input: {x.shape}')
print(f'Output: {y.shape}')

In [None]:
#@title Combinators: Bidirectional. {vertical-output: true}

#@markdown The bidirectional combinator processes the input sequence forward with the forward network and
#@markdown in reverse with the backward network, then combines them according to a combination function.

# The forward and backward networks use the same Conv-BN-Relu specification;
# their outputs are combined with CombinationMode.STACK.
layer = sl.Bidirectional.Config(
    forward=sl.Serial.Config([
        sl.Conv2D.Config(
            8,
            kernel_size=5,
            strides=2,
            time_padding='reverse_causal',
            spatial_padding='same',
        ),
        sl.BatchNormalization.Config(),
        sl.Relu.Config(),
    ]),
    backward=sl.Serial.Config([
        sl.Conv2D.Config(
            8,
            kernel_size=5,
            strides=2,
            time_padding='reverse_causal',
            spatial_padding='same',
        ),
        sl.BatchNormalization.Config(),
        sl.Relu.Config(),
    ]),
    combination=sl.CombinationMode.STACK,
    name='bidirectional_conv2d',
).make()

# Use the named dimensions instead of repeating the literals.
batch_size, time, height, channels = 2, 32, 5, 1
x = random_sequence(batch_size, time, height, channels)
y, layer_vars = layer.init_with_output(key, x, training=False)
# Unbox the variables before binding and printing, as the other combinator
# cells in this notebook do.
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)
print(f'Input: {x.shape}')
print(f'Output: {y.shape}')

In [None]:
#@title Combinators: Repeat. {vertical-output: true}

#@markdown The repeat combinator repeats the provided network specification multiple times with different variables on each repeat.
#@markdown This allows compile time savings since the loop body only has to be compiled once.
#@markdown Can be combined with CheckpointGradient to save memory in addition to compile time.

# Repeat a Conv-BN-Relu block three times.
layer = sl.Repeat.Config(
    sl.Serial.Config([
        sl.Conv2D.Config(8, kernel_size=5, strides=1, time_padding='reverse_causal', spatial_padding='same'),
        sl.BatchNormalization.Config(),
        sl.Relu.Config(),
    ]),
    num_repeats=3).make()

x = random_sequence(2, 32, 5, 8)
# init_with_output returns the layer output together with the new variables.
y, layer_vars = layer.init_with_output(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)
print(f'Input: {x.shape}')
print(f'Output: {y.shape}')


In [None]:
#@title Combinators: CheckpointGradient. {vertical-output: true}

#@markdown The CheckpointGradient combinator wraps the provided body in a gradient checkpoint
#@markdown scope so the intermediate tensors are discarded according to the provided policy,
#@markdown and recomputed in the backward pass.

# Wrap a Conv-BN-Relu block; no checkpoint policy argument is passed here.
layer = sl.CheckpointGradient.Config(
    sl.Serial.Config([
        sl.Conv2D.Config(8, kernel_size=5, strides=1, time_padding='reverse_causal', spatial_padding='same'),
        sl.BatchNormalization.Config(),
        sl.Relu.Config(),
    ])).make()

x = random_sequence(2, 32, 5, 8)
# init_with_output returns the layer output together with the new variables.
y, layer_vars = layer.init_with_output(key, x, training=False)
layer_vars = unbox(layer_vars)
layer = layer.bind(layer_vars)
sl_utils.pprint_tree_shapes_types(layer_vars)
print(f'Input: {x.shape}')
print(f'Output: {y.shape}')

# What's a SequenceLayer useful for?

They are a useful building block anywhere in your system that you have a sequence-in / sequence-out block where it doesn't matter what the implementation is. Additionally, when combined with a global pooling operation they can be useful as a sequence-to-vector primitive.

SequenceLayers have been used as components of:
* Diffusion models
* Autoregressive models
* Flows
* GANs
* LLMs (Gemma 3n, DolphinGemma)


# Example: Model-independent Autoregressive Sampling

In [None]:
import abc
import dataclasses

import flax.linen as nn

In [None]:
#@title Distribution helper. Not part of SequenceLayers, but a useful abstraction over distributions.

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

class DistributionLayer(nn.Module, metaclass=abc.ABCMeta):
  """Abstract Flax module that turns a hidden sequence into a distribution.

  Subclasses declare the spec of a single event via `event_spec` and build the
  distribution in `get_distribution`.
  """

  @property
  @abc.abstractmethod
  def event_spec(self) -> jax.ShapeDtypeStruct:
    """Shape and dtype of a single event of the distribution."""

  @abc.abstractmethod
  def get_distribution(
      self, hidden: sl.Sequence, training: bool
  ) -> tfd.Distribution:
    """Builds the distribution conditioned on `hidden`.

    Args:
      hidden: The sequence the distribution is conditioned on.
      training: Whether the layer is being run in training mode.

    Returns:
      The distribution produced from `hidden`.
    """

  def __call__(self, x: sl.Sequence) -> tfd.Distribution:
    """For Flax compatibility; equivalent to get_distribution(x, training=False)."""
    return self.get_distribution(x, training=False)


class DistributionLayerConfig(metaclass=abc.ABCMeta):
  """Abstract config that knows how to build a `DistributionLayer`."""

  @abc.abstractmethod
  def make(self) -> DistributionLayer:
    """Returns the `DistributionLayer` described by this config."""


class Categorical(DistributionLayer):
  """Categorical distribution whose logits are a bias-free projection of the input."""

  @dataclasses.dataclass(frozen=True)
  class Config(DistributionLayerConfig):
    # Number of classes, i.e. the size of the logits dimension.
    num_classes: int

    def make(self) -> 'Categorical':
      """Builds the `Categorical` layer for this config."""
      return Categorical(config=self)

  config: Config

  @property
  def event_spec(self) -> jax.ShapeDtypeStruct:
    """A single event is a scalar int32 class index."""
    return jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32)

  @nn.compact
  def get_distribution(self, hidden: sl.Sequence, training: bool) -> tfd.Distribution:
    """Projects `hidden` to per-class logits and wraps them in a Categorical."""
    to_logits = sl.Dense.Config(
        self.config.num_classes, use_bias=False, name='to_logits'
    ).make()
    logits_sequence = to_logits.layer(hidden, training=training)
    return tfd.Categorical(logits=logits_sequence.values)

key = jax.random.PRNGKey(42)
# JAX PRNG keys should not be reused: derive independent subkeys for parameter
# initialization and for sampling.
init_key, sample_key = jax.random.split(key)

# x: int32 values of shape (2, 3) that are scored under the distribution.
# h: float values of shape (2, 3, 5) that the distribution is conditioned on.
x = sl.Sequence(jnp.ones((2, 3), jnp.int32), jnp.ones((2, 3), jnp.bool_))
h = sl.Sequence(jnp.ones((2, 3, 5)), jnp.ones((2, 3), jnp.bool_))

# Build the layer, initialize its variables on h and bind them so that
# get_distribution can be called directly.
l = Categorical.Config(num_classes=10).make()
l = l.bind(l.init(init_key, h))
sl_utils.pprint_tree_shapes_types(unbox(l.variables))
dist = l.get_distribution(h, training=False)

assert dist.event_shape == []
assert dist.batch_shape == [2, 3]
assert dist.dtype == jnp.int32

# Score x under the distribution, zeroing out invalid timesteps.
log_probs = x.apply_values(dist.log_prob).mask_invalid()
print(log_probs)
# Draw one sample per (batch, time) position, reusing the mask of h.
samples = sl.Sequence(dist.sample(sample_shape=(), seed=sample_key), h.mask).mask_invalid()
print(samples)

In [None]:
# @title Using a SequenceLayer as an autoregressive step function {vertical-output: true}
# @markdown `SequenceLayer`s are handy for designing the step function of an autoregressive model, since they can be efficiently executed in a parallel layer-wise fashion ("teacher forcing" for likelihood evaluation) in training and executed step-by-step at sampling time.
# @markdown
# @markdown The implementation of the step function (e.g. Transformer vs. RNN vs. convolution) **does not matter at all** for the autoregressive math, and the implementation should not be coupled to those details.


class AutoregressiveDecoder(nn.Module):
  """An autoregressive model. Note, *not* a SequenceLayer.

  Pairs a step function (`body`) with a conditional distribution. Likelihoods
  are computed layer-wise over the whole sequence in `log_prob`, while `sample`
  runs the body step-by-step.
  """

  @dataclasses.dataclass(frozen=True)
  class Config:
    # The step function to map from x_{t-1} to h_t.
    body: sl.SequenceLayerConfig
    # The conditional distribution expressing p(x_t | h_t).
    distribution: DistributionLayerConfig
    # An optional name for the decoder.
    name: str | None = None

    def make(self) -> 'AutoregressiveDecoder':
      """Builds the decoder module described by this config."""
      return AutoregressiveDecoder(self, name=self.name)

  config: Config

  def setup(self) -> None:
    """Instantiates the body and distribution submodules from the config."""
    self.body = self.config.body.make()
    self.distribution = self.config.distribution.make()

  def log_prob(self, data: sl.Sequence, training: bool) -> sl.Sequence:
    """Compute the log likelihood of the observed data.

    Args:
      data: The observed sequence to score.
      training: Whether to run the body and distribution in training mode.

    Returns:
      A sequence of log-probabilities with invalid timesteps masked.
    """
    # Shift input by one and slice one off the end.
    # You could use a custom or learned SOS value here.
    teacher_forcing_inputs = data.pad_time(1, 0, valid=True)[:, :-1]
    # Compute log likelihood in parallel.
    hidden = self.body.layer(teacher_forcing_inputs, training=training)
    distribution = self.distribution.get_distribution(hidden, training=training)
    return data.apply_values(distribution.log_prob).mask_invalid()

  def sample(
      self, batch_size: int, num_steps: int, training: bool, seed: jax.Array
  ) -> sl.Sequence:
    """Decode step-by-step for num_steps. Uses a static unroll for clarity.

    Args:
      batch_size: Number of sequences to sample.
      num_steps: Number of timesteps to decode.
      training: Whether to run the body and distribution in training mode.
      seed: JAX PRNG key; it is split so every step samples with its own
        subkey.

    Returns:
      The per-step samples concatenated over time.
    """
    event_spec = self.distribution.event_spec
    # The first input is all zeros to match training above.
    x = sl.Sequence.from_values(
        jnp.zeros((batch_size, 1) + event_spec.shape, dtype=event_spec.dtype)
    )

    # Get initial state for the step function.
    state = self.body.get_initial_state(
        batch_size, x.channel_spec, training=training
    )

    # Decode for a fixed number of steps.
    xs = []
    for _ in range(num_steps):
      # Use a fresh subkey per step. Passing the same key to every step would
      # make all steps draw from identical random bits.
      seed, step_seed = jax.random.split(seed)
      hidden, state = self.body.step(x, state, training=training)
      distribution = self.distribution.get_distribution(
          hidden, training=training
      )
      x = sl.Sequence(
          distribution.sample(seed=step_seed), hidden.mask
      ).mask_invalid()
      xs.append(x)

    # Concatenate the samples over time.
    return sl.Sequence.concatenate_sequences(xs)

In [None]:
#@title Define an autoregressive categorical decoder with a transformer as the step function: {vertical-output: true}

# Hyperparameters of the transformer step function.
model_dimension = 1024
num_layers = 12
num_heads = 16
units_per_head = 64
ffn_dim = 4 * model_dimension
vocab_size = 100
dropout_rate = 0.1
max_past_horizon = 128

decoder_config = AutoregressiveDecoder.Config(
    body=sl.Serial.Config([
      sl.Embedding.Config(num_embeddings=vocab_size, dimension=model_dimension, name='embedding'),
      sl.AddTimingSignal.Config(),
      # One transformer block, repeated num_layers times.
      sl.Repeat.Config(
        sl.Serial.Config([
          # Residual self-attention module:
          sl.Residual.Config([
            sl.RMSNormalization.Config(name='rms_norm'),
            sl.DotProductSelfAttention.Config(
              units_per_head=units_per_head,
              num_heads=num_heads,
              max_past_horizon=max_past_horizon,
              max_future_horizon=0,
              use_bias=False,
              # Use RoPE for the queries and keys.
              query_network=sl.ApplyRotaryPositionalEncoding.Config(max_wavelength=10000),
              key_network=sl.ApplyRotaryPositionalEncoding.Config(max_wavelength=10000),
              attention_probabilities_dropout_rate=dropout_rate,
              broadcast_dropout_across_queries=True,
              name='attention'),
            sl.DenseShaped.Config([model_dimension], use_bias=False, name='output_projection'),
            sl.Dropout.Config(dropout_rate)
          ], name='self_attention'),
          # Residual feed-forward module:
          sl.Residual.Config([
            sl.RMSNormalization.Config(name='rms_norm'),
            # NOTE(review): dense1 is twice ffn_dim wide, presumably because
            # GatedUnit consumes a value half and a gate half. Confirm against
            # the GatedUnit docs.
            sl.Dense.Config(ffn_dim * 2, use_bias=False, name='dense1'),
            sl.GatedUnit.Config(jax.nn.gelu, None),
            sl.Dropout.Config(dropout_rate),
            sl.Dense.Config(model_dimension, use_bias=False, name='dense2'),
            sl.Dropout.Config(dropout_rate),
          ], name='ffn'),
        ], name='transformer_block'),
        num_repeats=num_layers, name='transformer_blocks'),
      sl.RMSNormalization.Config(name='output_rms_norm'),
    ], name='transformer'),
    distribution=Categorical.Config(num_classes=vocab_size)
)

decoder = decoder_config.make()

key = jax.random.PRNGKey(1234)
# JAX PRNG keys should not be reused: give the lengths, the token values and
# the parameter initialization an independent subkey each.
lengths_key, data_key, init_key = jax.random.split(key, 3)
batch_size, time = 1, 512

# Random per-example lengths in [time // 2, time]; maxval is exclusive.
lengths = jax.random.randint(
    lengths_key, [batch_size], minval=time // 2, maxval=time + 1
)
# Timesteps before each example's length are marked valid.
mask = np.arange(time)[jnp.newaxis, :] < lengths[:, jnp.newaxis]
data = sl.Sequence(
    jax.random.randint(data_key, [batch_size, time], minval=0, maxval=vocab_size),
    mask
)

# Initialize through log_prob (the module defines no __call__), then bind the
# variables so log_prob and sample can be called directly.
decoder_vars = decoder.init(init_key, data, training=False, method='log_prob')
sl_utils.pprint_tree_shapes_types(decoder_vars)
decoder = decoder.bind(decoder_vars)

In [None]:
#@title Compute logprob of data.

# Score the whole sequence in a single call with the bound decoder.
log_prob = decoder.log_prob(data, training=False)
print(log_prob)

In [None]:
#@title Sample a batch.
# Decode 5 steps for a batch of 8 sequences using a fixed PRNG key.
samples = decoder.sample(batch_size=8, num_steps=5, training=False, seed=jax.random.PRNGKey(42))
print(samples)