Copyright 2025 Google LLC.

SPDX-License-Identifier: Apache-2.0

In [None]:
# @title Speculative Cascades between Gemma models
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This colab is based on the <a href="https://gemma-llm.readthedocs.io/en/latest/colab_sampling.html">Sampling example</a> tutorial provided with Gemma, and forks the <a href="https://github.com/google-deepmind/gemma/blob/2925236fb0d1ff07e4d8abc96f7ff2fe3d9b1ee3/gemma/sampler.py#L81">Sampler</a> class from their open-sourced
<a href="https://github.com/google-deepmind/gemma">codebase</a>. The purpose of this colab is to illustrate cost-quality trade-offs using speculative cascades. We therefore provide a simple implementation, where the drafter generates only one draft token at a time, and the verifier is run for one step to either accept the token, or reject and replace it. In practice, one usually runs the drafter for multiple steps, and runs the verifier in parallel scoring mode to verify the draft tokens.

**Acknowledgement:** We thank Ananda Theertha Suresh for invaluable help in writing this colab.



In [None]:
!pip install -q gemma

In [None]:
from collections.abc import Callable, Sequence
import dataclasses

import os
import numpy as np
import chex
import jax
import jax.numpy as jnp

from gemma import gm
from gemma import modules
from gemma.deprecated import transformer as transformer_lib
from gemma.deprecated import params as params_lib

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
# Raise the fraction of GPU memory that the XLA client may preallocate to 100%.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

# Custom speculative sampler

## Generic function to sample next token

In [None]:
# Lower bound applied to the denominator of the acceptance ratio in
# `speed_sampling_acceptance_prob_fn`, guarding against division by zero.
MIN_PROBS = 1e-10


def sample_next_token(
    logits_small: jnp.ndarray,
    logits_large: jnp.ndarray,
    acceptance_prob_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray,
                                  jnp.ndarray, jnp.ndarray], jnp.ndarray],
    residual_distribution_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray,
                                        jnp.ndarray, jnp.ndarray], jnp.ndarray],
    rng: jnp.ndarray,
    temperature: float = 1.0) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Drafts a token from the small model, then accepts or replaces it.

  A draft token is sampled from the (temperature-scaled) small-model
  distribution. It is kept with the probability returned by
  `acceptance_prob_fn`; otherwise a token sampled from the distribution
  returned by `residual_distribution_fn` is used instead.

  Both callbacks are invoked positionally with
  `(probs_small, probs_large, probs_small_unscaled, probs_large_unscaled,
  token_small)`, where the "unscaled" probabilities ignore `temperature`.

  Args:
    logits_small: Small-model logits; the last axis is the vocabulary (the
      `Sampler` in this file passes [B, 1, V] arrays).
    logits_large: Large-model logits, same layout as `logits_small`.
    acceptance_prob_fn: Returns the probability of accepting the draft token.
    residual_distribution_fn: Returns the (possibly unnormalized) distribution
      to sample from when the draft token is rejected.
    rng: Random key; it is split into four sub-keys.
    temperature: Sampling temperature. Exactly 1.0 uses the plain softmax,
      other positive values rescale the log-probabilities, and any remaining
      value (e.g. 0) selects the argmax token deterministically.

  Returns:
    A tuple `(next_token, rng, is_token_accepted)`: the chosen token, the
    carried-over random key, and the boolean acceptance decision.
  """
  # The split order fixes which sub-key drives which random draw.
  carry_key, draft_key, accept_key, residual_key = jax.random.split(rng, 4)

  # Work with normalized log-probabilities to avoid overflows.
  log_probs_small = jax.nn.log_softmax(logits_small)
  log_probs_large = jax.nn.log_softmax(logits_large)

  # Distributions without temperature scaling.
  probs_small_unscaled = jax.nn.softmax(log_probs_small, axis=-1)
  probs_large_unscaled = jax.nn.softmax(log_probs_large, axis=-1)

  def _apply_temperature(log_probs, unscaled_probs):
    if temperature == 1.0:
      return unscaled_probs
    if temperature > 0.0:
      return jax.nn.softmax(log_probs / temperature, axis=-1)
    # Greedy decoding: a one-hot distribution on the argmax token.
    return jax.nn.one_hot(
        jnp.argmax(log_probs, axis=-1), log_probs.shape[-1], axis=-1)

  probs_small = _apply_temperature(log_probs_small, probs_small_unscaled)
  probs_large = _apply_temperature(log_probs_large, probs_large_unscaled)

  # Draft a token from the small model.
  draft_token = jax.random.categorical(
      draft_key, jnp.log(probs_small), axis=-1)

  # Both callbacks receive exactly the same positional arguments.
  callback_args = (
      probs_small,
      probs_large,
      probs_small_unscaled,
      probs_large_unscaled,
      draft_token,
  )

  # Decide whether the draft token is kept.
  accept_prob = acceptance_prob_fn(*callback_args)
  is_accepted = jax.random.bernoulli(accept_key, accept_prob)

  # The replacement token is always sampled, and only used on rejection.
  residual_probs = residual_distribution_fn(*callback_args)
  replacement_token = jax.random.categorical(
      residual_key, jnp.log(residual_probs), axis=-1)

  chosen_token = jnp.where(is_accepted, draft_token, replacement_token)
  return chosen_token, carry_key, is_accepted

## Generic sampler class

In [None]:
def _compute_attention_masks(
    time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
  """Computes causal attention mask."""
  bsz = input_mask.shape[0]
  batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32)
  causal_padding = jnp.greater(
      jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step
  )
  max_seq_len = min(input_mask.shape[-1], seq_len)
  input_mask = jax.lax.dynamic_slice(
      input_mask,
      (0, jnp.maximum(time_step - seq_len + 1, 0)),
      (bsz, max_seq_len),
  )
  input_mask = (
      jnp.zeros((bsz, seq_len), dtype=jnp.bool_)
      .at[:, :max_seq_len]
      .set(input_mask)
  )

  causal_padding = jnp.logical_or(causal_padding, input_mask)
  attention_mask = causal_padding[:, jnp.newaxis, :].astype(jnp.bool_)

  return ~attention_mask

In [None]:
@chex.dataclass
class _SamplingState:
  """Loop-carried state of the speculative sampling loop.

  An initial instance is built by `Sampler.init_sample_state`, and every call
  to `Sampler._sample_step` returns an updated instance.
  """

  # Decoding step: index into `token_buffer` of the token fed to the models.
  decoding_step: jnp.int32

  # Number of tokens in the prompt.
  num_input_tokens: jnp.ndarray  # [B]

  # Fixed-size buffer holding the prompt tokens followed by the output tokens.
  token_buffer: jnp.ndarray  # [B, L]

  # Position indices, based on ignoring pad tokens.
  positions: jnp.ndarray  # [B, L]

  # Per-layer attention caches of the small and the large model, used to
  # condition each model autoregressively.
  small_cache: dict[str, modules.LayerCache]
  large_cache: dict[str, modules.LayerCache]

  # Is decoding done on the given sequence (has an EOS token been written)?
  done: jnp.ndarray  # [B]

  # Total sampling steps (including the prompt).
  total_sampling_steps: int

  # rng key for sampling; replaced at every step by the key carried over from
  # `sample_next_token`.
  rng: jnp.ndarray

  # Booleans indicating if each token was accepted; same shape as
  # `token_buffer`.
  tokens_accepted: jnp.ndarray

  # Fixed-size buffer for accumulating the large-model logits, or None when
  # logits are not requested.
  logits_buffer: jnp.ndarray | None = None  # [B, L, V]

  # List of tokens that are forbidden to be generated.
  forbidden_token_ids: Sequence[int] | None = None

In [None]:
@dataclasses.dataclass
class SamplerOutput:
  """Results of one `Sampler` call; every field has one entry per prompt."""

  # Decoded samples from the model.
  text: list[str]

  # Per-step logits recorded during sampling: one [steps, vocab] nested list
  # per prompt. Empty when logits were not requested.
  logits: list[list[list[float]]]

  # Tokens corresponding to the generated samples.
  tokens: list[list[int]]

  # For each returned token, whether the drafted token was accepted.
  tokens_accepted: list[list[bool]]

In [None]:
class Sampler:
  """Speculative sampler pairing a small (drafter) and a large (verifier) model.

  At every decoding step both transformers are applied to the last token. The
  next token is then chosen by `sample_next_token`: a token drafted from the
  small model is accepted with the probability given by `acceptance_prob_fn`,
  and otherwise replaced by a sample from `residual_distribution_fn`.
  """

  def __init__(
      self,
      small_transformer: transformer_lib.Transformer,
      large_transformer: transformer_lib.Transformer,
      tokenizer: gm.text.Tokenizer,
      small_params: params_lib.Params,
      large_params: params_lib.Params,
      acceptance_prob_fn: Callable[
          [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
          jnp.ndarray,
      ],
      residual_distribution_fn: Callable[
          [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
          jnp.ndarray,
      ],
      temperature: float = 1.0,
      cache_length: int = 1024,
  ):
    """Initializes a sampler for a pair of Gemma models.

    Args:
      small_transformer: an instance of the Gemma transformer (drafter).
      large_transformer: an instance of the Gemma transformer (verifier).
      tokenizer: tokenizer shared by both models.
      small_params: weights of the small model.
      large_params: weights of the large model.
      acceptance_prob_fn: Acceptance criterion function. Called with
        `(probs_small, probs_large, probs_small_unscaled,
        probs_large_unscaled, token_small)`.
      residual_distribution_fn: Residual distribution function. Called with
        the same five arguments as `acceptance_prob_fn`.
      temperature: Temperature for sampling.
      cache_length: Length of the attention cache (used for both models).
    """
    self.small_transformer = small_transformer
    self.large_transformer = large_transformer
    self.tokenizer = tokenizer
    self.small_params = small_params
    self.large_params = large_params
    self._compiled_sample_fn = jax.jit(self._sample_fn)
    self.acceptance_prob_fn = acceptance_prob_fn
    self.residual_distribution_fn = residual_distribution_fn
    self.temperature = temperature
    self.cache_length = cache_length

  @property
  def dtype(self) -> jnp.dtype:
    """Parameter dtype, read from the first leaf of the small model's weights."""
    # assumes, dtype of both small and large models are same.
    return jax.tree_util.tree_leaves(self.small_params)[0].dtype

  def _sample_step(
      self, small_params, large_params, sampler_state: _SamplingState
  ) -> _SamplingState:
    """Performs a single sampling step.

    Runs both models on the token at `decoding_step`, picks the next token
    with `sample_next_token`, and writes it (together with its acceptance
    flag and, optionally, the large-model logits) at `decoding_step + 1`.

    Args:
      small_params: weights of the small model.
      large_params: weights of the large model.
      sampler_state: state before the step.

    Returns:
      The state after the step, with `decoding_step` advanced by one.
    """
    batch_size = sampler_state.token_buffer.shape[0]
    decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)
    last_token = sampler_state.token_buffer[:, decoding_step]
    # True at PAD positions, which must not be attended to.
    input_mask = sampler_state.token_buffer == self.tokenizer.special_tokens.PAD
    # assumes that the cache length is same for both small and large models.
    attention_mask = _compute_attention_masks(
        decoding_step, self.cache_length, input_mask
    )
    step_positions = jnp.expand_dims(
        sampler_state.positions[:, decoding_step], -1
    )
    last_token = last_token.reshape((batch_size, 1))

    small_out = self.small_transformer.apply(
        {'params': small_params},
        last_token,
        positions=step_positions,
        cache=sampler_state.small_cache,
        attention_mask=attention_mask,
    )
    small_cache = small_out.cache
    small_logits = small_out.logits

    large_out = self.large_transformer.apply(
        {'params': large_params},
        last_token,
        positions=step_positions,
        cache=sampler_state.large_cache,
        attention_mask=attention_mask,
    )
    large_cache = large_out.cache
    large_logits = large_out.logits

    # Forbidden tokens get -inf logits in both models so that they can never
    # be drafted nor drawn from the residual distribution.
    if sampler_state.forbidden_token_ids:
      small_logits = small_logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)
      large_logits = large_logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)

    next_token_candidate, rng, is_token_accepted = sample_next_token(
        small_logits,
        large_logits,
        self.acceptance_prob_fn,
        self.residual_distribution_fn,
        sampler_state.rng,
        self.temperature)
    next_token_candidate = next_token_candidate[:, 0]  # [B,]

    # While still inside the prompt, the sampled token is discarded in favour
    # of the next prompt token.
    next_token_candidate = jnp.where(
        decoding_step < sampler_state.num_input_tokens - 1,
        sampler_state.token_buffer[:, decoding_step + 1],
        next_token_candidate,
    )

    token_buffer = sampler_state.token_buffer.at[:, decoding_step + 1].set(
        next_token_candidate
    )

    if sampler_state.logits_buffer is not None:
      # Only the large-model logits are recorded.
      next_logits = jnp.squeeze(large_logits, 1)
      logits_buffer = sampler_state.logits_buffer.at[:, decoding_step + 1].set(
          next_logits
      )
    else:
      logits_buffer = sampler_state.logits_buffer

    done = sampler_state.done | jnp.equal(
        token_buffer[:, decoding_step + 1], self.tokenizer.special_tokens.EOS
    )

    tokens_accepted = sampler_state.tokens_accepted.at[:, decoding_step + 1].set(
        jnp.squeeze(is_token_accepted, axis=-1)
    )
    state = _SamplingState(
        decoding_step=sampler_state.decoding_step + 1,
        num_input_tokens=sampler_state.num_input_tokens,
        token_buffer=token_buffer,
        positions=sampler_state.positions,
        logits_buffer=logits_buffer,
        small_cache=small_cache,
        large_cache=large_cache,
        done=done,
        total_sampling_steps=sampler_state.total_sampling_steps,
        tokens_accepted=tokens_accepted,
        forbidden_token_ids=sampler_state.forbidden_token_ids,
        rng=rng,
    )
    return state

  def init_small_cache(self, bsz) -> dict[str, modules.LayerCache]:
    """Initializes the small model's attention cache for each layer."""
    return self.small_transformer.config.init_cache(
        batch_size=bsz,
        dtype=self.dtype,
        cache_length=self.cache_length,
    )

  def init_large_cache(self, bsz) -> dict[str, modules.LayerCache]:
    """Initializes the large model's attention cache for each layer."""
    return self.large_transformer.config.init_cache(
        batch_size=bsz,
        dtype=self.dtype,
        cache_length=self.cache_length,
    )

  def init_sample_state(
      self,
      all_input_ids: list[jax.Array],
      total_sampling_steps: int,
      include_logits: bool = False,
      forbidden_token_ids: Sequence[int] | None = None,
      rng: jax.Array | None = None,
  ) -> _SamplingState:
    """Initializes the sampling state given input prompts.

    Args:
      all_input_ids: one array of token ids per prompt.
      total_sampling_steps: total number of steps, including the prompt.
      include_logits: whether to allocate a buffer for per-step logits.
      forbidden_token_ids: token ids that must never be generated.
      rng: random key for sampling. Defaults to `jax.random.PRNGKey(0)`.

    Returns:
      The initial sampling state, positioned at decoding step 0.
    """
    # The default key is created lazily so that defining the class does not
    # run any JAX computation.
    if rng is None:
      rng = jax.random.PRNGKey(0)

    bsz = len(all_input_ids)
    num_input_tokens = [len(input_ids) for input_ids in all_input_ids]
    buffer_size = total_sampling_steps + 1

    token_buffer = jnp.full(
        (
            bsz,
            buffer_size,
        ),
        self.tokenizer.special_tokens.PAD,
        dtype=jnp.int32,
    )
    input_mask = jnp.ones_like(token_buffer, dtype=jnp.bool_)
    for i, (input_ids, num_tokens) in enumerate(
        zip(all_input_ids, num_input_tokens)
    ):
      token_buffer = token_buffer.at[i, :num_tokens].set(input_ids)
      input_mask = input_mask.at[i, :num_tokens].set(
          input_ids != self.tokenizer.special_tokens.PAD
      )
    positions = transformer_lib.build_positions_from_mask(input_mask)

    done = jnp.zeros((bsz,), dtype=jnp.bool_)

    tokens_accepted = jnp.zeros_like(token_buffer, dtype=jnp.bool_)

    if include_logits:
      # Sized with the small model's vocabulary; the stored logits come from
      # the large model, so both vocabularies are assumed to match.
      logits_buffer = jnp.zeros(
          (bsz, buffer_size, self.small_transformer.config.num_embed),
          dtype=jnp.float32,
      )
    else:
      logits_buffer = None

    return _SamplingState(
        decoding_step=0,
        num_input_tokens=jnp.array(num_input_tokens, dtype=jnp.int32),
        token_buffer=token_buffer,
        positions=positions,
        logits_buffer=logits_buffer,
        small_cache=self.init_small_cache(bsz),
        large_cache=self.init_large_cache(bsz),
        done=done,
        total_sampling_steps=total_sampling_steps,
        tokens_accepted=tokens_accepted,
        forbidden_token_ids=forbidden_token_ids,
        rng=rng,
    )

  def tokenize(self, input_string: str) -> jax.Array:
    """Tokenizes the input string and prepends the BOS token."""
    input_ids = self.tokenizer.encode(input_string)
    return jnp.array(
        [self.tokenizer.special_tokens.BOS] + list(input_ids), dtype=jnp.int32
    )

  def mask_tokens_after_eos_ids(self, token_buffer):
    """Mask token IDs after the EOS token with the padding ID.

    Args:
      token_buffer: [B, L] array of token ids.

    Returns:
      A copy of `token_buffer` where, in every row containing an EOS token,
      all positions after the first EOS hold the PAD id. Rows without EOS are
      returned unchanged.
    """
    eos_id = self.tokenizer.special_tokens.EOS
    eos_exists = jnp.any(jnp.equal(token_buffer, eos_id), axis=-1)
    # Index of the first EOS per row, or the row length when there is none.
    eos_indices = jnp.where(
        eos_exists,
        jnp.argmax(jnp.equal(token_buffer, eos_id), axis=-1),
        token_buffer.shape[-1],
    )
    mask = jnp.less_equal(
        jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
    )
    masked_token_buffer = token_buffer * mask + self.tokenizer.special_tokens.PAD * (1 - mask)

    return masked_token_buffer

  def _sample_fn(
      self,
      small_params: params_lib.Params,
      large_params: params_lib.Params,
      initial_sampling_state: _SamplingState,
  ) -> _SamplingState:
    """Internal sampling function (to be jitted).

    Repeats `_sample_step` until the step budget is exhausted or every
    sequence in the batch is done.
    """

    def sample_with_params(sampler_state: _SamplingState):
      return self._sample_step(small_params, large_params, sampler_state)

    def cond_fn(sampler_state: _SamplingState):
      return (
          sampler_state.decoding_step < sampler_state.total_sampling_steps
      ) & jnp.any(jnp.logical_not(sampler_state.done))

    return jax.lax.while_loop(
        cond_fn, sample_with_params, initial_sampling_state
    )

  def __call__(
      self,
      input_strings: Sequence[str],
      total_generation_steps: int,
      echo: bool = False,
      return_logits: bool = False,
      forbidden_tokens: Sequence[str] | None = None,
      seed: int | None = 0,
  ) -> SamplerOutput:
    """Samples a completion of the input string.

    Args:
      input_strings: input prompts to feed to the model for sampling.
      total_generation_steps: number of generation steps counted from the end
        of the longest prompt in the batch; shorter prompts get
        correspondingly more steps.
      echo: whether to return the prompt as part of the output sample.
      return_logits: whether to return per-step logits used during generation.
      forbidden_tokens: list of tokens that are forbidden to be generated. Each
        token must map to a single token id in the vocab.
      seed: random seed. If None, a random seed is drawn so that repeated
        calls give different samples.

    Returns:
      sampler_output: A SamplerOutput object containing the generated samples.

    Raises:
      ValueError: if `input_strings` is empty, or if a forbidden token does
        not map to exactly one token id.
    """
    forbidden_token_ids = None
    if forbidden_tokens is not None:
      forbidden_token_ids = []
      for token in forbidden_tokens:
        token_id = self.tokenizer.encode(token)
        if len(token_id) != 1:
          raise ValueError(
              'Forbidden tokens must map to single token ids in the vocab.'
          )
        forbidden_token_ids.extend(token_id)
      forbidden_token_ids = tuple(forbidden_token_ids)
    all_input_ids = [self.tokenize(x) for x in input_strings]
    if not all_input_ids:
      raise ValueError('input_strings must contain at least one prompt.')
    max_input_length = max(len(input_ids) for input_ids in all_input_ids)
    total_sampling_steps = max_input_length + total_generation_steps
    if seed is None:
      # `PRNGKey` needs a concrete integer; draw one at random.
      seed = int(np.random.randint(0, 2**31 - 1))
    rng = jax.random.PRNGKey(seed)
    initial_sampling_state = self.init_sample_state(
        all_input_ids,
        include_logits=return_logits,
        total_sampling_steps=total_sampling_steps,
        forbidden_token_ids=forbidden_token_ids,
        rng=rng,
    )

    sampling_state = self._compiled_sample_fn(
        self.small_params, self.large_params, initial_sampling_state
    )

    masked_token_buffer = self.mask_tokens_after_eos_ids(
        sampling_state.token_buffer
    )

    out_tokens = []
    out_logits = []
    out_accepted = []
    for i, (token_buffer, num_tokens) in enumerate(
        zip(
            masked_token_buffer,
            sampling_state.num_input_tokens,
        )
    ):
      # Without echo, the prompt part of each buffer is dropped.
      start_idx = 0 if echo else num_tokens
      out_tokens.append(token_buffer[start_idx:total_sampling_steps].tolist())
      token_accepted = sampling_state.tokens_accepted[i]
      out_accepted.append(token_accepted[start_idx:total_sampling_steps].tolist())
      if return_logits:
        logits_buffer = sampling_state.logits_buffer[i]
        out_logits.append(
            logits_buffer[start_idx:total_sampling_steps].tolist()
        )

    decoded_outputs = [self.tokenizer.decode(tokens) for tokens in out_tokens]

    result = SamplerOutput(
        text=decoded_outputs,
        logits=out_logits,
        tokens=out_tokens,
        tokens_accepted=out_accepted,
    )
    return result

# Draft and verify functions


## Elementary draft and verify functions

In [None]:
def accept_all_prob_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray) -> jnp.ndarray:
  """Acceptance rule that always keeps the drafted token.

  Only the shape of `token_small` is used; the probabilities are ignored.

  Returns:
    A float32 array of ones shaped like `token_small`.
  """
  always_accept = jnp.ones_like(token_small, dtype=jnp.float32)
  return always_accept


def reject_all_prob_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray) -> jnp.ndarray:
  """Acceptance rule that never keeps the drafted token.

  Only the shape of `token_small` is used; the probabilities are ignored.

  Returns:
    A float32 array of zeros shaped like `token_small`.
  """
  never_accept = jnp.zeros_like(token_small, dtype=jnp.float32)
  return never_accept


def small_distribution_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray | None = None) -> jnp.ndarray:
  del probs_large, probs_small_unscaled, probs_large_unscaled, token_small
  return probs_small


def large_distribution_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray | None = None) -> jnp.ndarray:
  del probs_small, probs_small_unscaled, probs_large_unscaled, token_small
  return probs_large

## Lossy speculative decoding: draft & verify functions

In [None]:
def speed_sampling_acceptance_prob_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = 0.0) -> jnp.ndarray:
  """Acceptance probability for (lossy) speculative sampling.

  Computes min{1, p_large(v) / ((1 - lenience) * p_small(v))} for the drafted
  token v; see Leviathan et al., 2023, A.5, Page 12. With `lenience = 0` this
  is the standard speculative sampling criterion.

  Args:
    probs_small: Small-model distribution; last axis is the vocabulary.
    probs_large: Large-model distribution, same layout as `probs_small`.
    probs_small_unscaled: Unused.
    probs_large_unscaled: Unused.
    token_small: Drafted token ids (the caller passes a [B, 1] array).
    lenience: Multiplicative slack applied to the small-model probability.

  Returns:
    The acceptance probability, with the same shape as `token_small`.
  """
  del probs_small_unscaled, probs_large_unscaled  # Unused.

  # Index that selects the drafted token along the vocabulary axis.
  draft_index = jnp.expand_dims(token_small, axis=1)

  def _prob_of_draft(probs):
    picked = jnp.take_along_axis(probs, draft_index, axis=-1)
    return jnp.squeeze(picked, axis=-1)

  draft_prob_small = _prob_of_draft(probs_small)
  draft_prob_large = _prob_of_draft(probs_large)

  # The denominator is floored at MIN_PROBS to avoid dividing by zero.
  scaled_small = jnp.maximum((1 - lenience) * draft_prob_small, MIN_PROBS)
  return jnp.minimum(1, draft_prob_large / scaled_small)


def speed_sampling_residual_distribution_fn(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray | None = None
    ) -> jnp.ndarray:
  """Residual distribution for lossy speculative sampling."""
  del probs_small_unscaled, probs_large_unscaled, token_small
  # Residual distribution is max{0, p_large(.) - p_small(.)}.
  return jnp.maximum(0.0, probs_large - probs_small)

## Speculative cascade: draft & verify functions with generic target distribution


In [None]:
def create_speculative_cascade_sampling_acceptance_prob_fn(
    target_distribution_fn, lenience=0.0
):
  """Builds the acceptance function of a sampling speculative cascade.

  Args:
    target_distribution_fn: Maps `(probs_small, probs_large,
      probs_small_unscaled, probs_large_unscaled, token_small, lenience)` to
      the cascade's target distribution.
    lenience: Lenience forwarded to `target_distribution_fn`.

  Returns:
    A five-argument acceptance function that applies the speculative sampling
    criterion with the target distribution standing in for the large model.
  """
  def speculative_cascade_sampling_acceptance_prob_fn(
      probs_small: jnp.ndarray,
      probs_large: jnp.ndarray,
      probs_small_unscaled: jnp.ndarray,
      probs_large_unscaled: jnp.ndarray,
      token_small: jnp.ndarray,
  ) -> jnp.ndarray:
    cascade_target = target_distribution_fn(
        probs_small, probs_large,
        probs_small_unscaled, probs_large_unscaled,
        token_small, lenience)
    # The lenience is already baked into the target, so verification against
    # it is loss-less (lenience = 0).
    return speed_sampling_acceptance_prob_fn(
        probs_small, cascade_target,
        probs_small_unscaled, probs_large_unscaled,
        token_small, lenience=0)

  return speculative_cascade_sampling_acceptance_prob_fn


def create_speculative_cascade_sampling_residual_distribution_fn(
    target_distribution_fn, lenience=0.0
):
  """Builds the residual-distribution function of a sampling speculative cascade.

  Args:
    target_distribution_fn: Maps `(probs_small, probs_large,
      probs_small_unscaled, probs_large_unscaled, token_small, lenience)` to
      the cascade's target distribution.
    lenience: Lenience forwarded to `target_distribution_fn`.

  Returns:
    A five-argument function that computes the speculative sampling residual
    with the target distribution standing in for the large model.
  """
  def speculative_cascade_residual_distribution_fn(
      probs_small: jnp.ndarray,
      probs_large: jnp.ndarray,
      probs_small_unscaled: jnp.ndarray,
      probs_large_unscaled: jnp.ndarray,
      token_small: jnp.ndarray,
  ) -> jnp.ndarray:
    cascade_target = target_distribution_fn(
        probs_small, probs_large,
        probs_small_unscaled, probs_large_unscaled,
        token_small, lenience)
    # Residual of loss-less speculative sampling against the target.
    return speed_sampling_residual_distribution_fn(
        probs_small, cascade_target,
        probs_small_unscaled, probs_large_unscaled,
        token_small)

  return speculative_cascade_residual_distribution_fn

## Speculative cascade: target distributions for different deferral rules

In [None]:
def target_distribution_chow(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = -1.0
) -> jnp.ndarray:
  """Target distribution for the Chow deferral rule.

  Chow criterion (equation 2 in Narasimhan et al. (2025)): the small model is
  used when max_v p_small(v) >= 1 - lenience, evaluated on the unscaled
  probabilities; otherwise the large model is used.

  Returns:
    `probs_small` where the criterion holds and `probs_large` elsewhere.
  """
  del probs_large_unscaled, token_small  # Unused.
  confidence_small = jnp.max(probs_small_unscaled, axis=-1, keepdims=True)
  use_small = confidence_small >= 1.0 - lenience
  return use_small * probs_small + (1 - use_small) * probs_large


def target_distribution_diff(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = -1.0
) -> jnp.ndarray:
  """Target distribution for the Diff deferral rule.

  Diff criterion (equation 5 in Narasimhan et al. (2025)): the small model is
  used when max_v p_small(v) >= max_v p_large(v) - lenience, evaluated on the
  unscaled probabilities; otherwise the large model is used.

  Returns:
    `probs_small` where the criterion holds and `probs_large` elsewhere.
  """
  del token_small  # Unused.
  confidence_small = jnp.max(probs_small_unscaled, axis=-1, keepdims=True)
  confidence_large = jnp.max(probs_large_unscaled, axis=-1, keepdims=True)
  use_small = confidence_small >= confidence_large - lenience
  return use_small * probs_small + (1 - use_small) * probs_large


def target_distribution_opt(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = -1.0) -> jnp.ndarray:
  """Target distribution for the OPT deferral rule.

  OPT criterion (equation 10 in Narasimhan et al. (2025)): the small model is
  used when
    max_v p_small(v) >= max_v p_large(v) - lenience * TVD(p_small, p_large),
  where the maxima are taken over the unscaled probabilities and the
  TV-distance is computed on the scaled ones as
    \sum_v max{0, p_small(v) - p_large(v)}.

  Returns:
    `probs_small` where the criterion holds and `probs_large` elsewhere.
  """
  del token_small  # Unused.
  confidence_small = jnp.max(probs_small_unscaled, axis=-1, keepdims=True)
  confidence_large = jnp.max(probs_large_unscaled, axis=-1, keepdims=True)
  surplus_small = jnp.maximum(0.0, probs_small - probs_large)
  tv_distance = jnp.sum(surplus_small, axis=-1, keepdims=True)
  use_small = confidence_small >= confidence_large - lenience * tv_distance
  return use_small * probs_small + (1 - use_small) * probs_large


def target_distribution_token_v1(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = -1.0) -> jnp.ndarray:
  """Target distribution for the Token-v1 deferral rule.

  A token v is "accepted" when (equation 13 in Narasimhan et al. (2025))
    p_small_unscaled(v) >= max_u p_large_unscaled(u) - lenience.
  Accepted tokens keep their small-model probability, and the small-model
  mass of the rejected tokens is redistributed according to the large model.

  Args:
    probs_small: Small-model distribution; last axis is the vocabulary.
    probs_large: Large-model distribution, same layout as `probs_small`.
    probs_small_unscaled: Small-model distribution without temperature.
    probs_large_unscaled: Large-model distribution without temperature.
    token_small: Unused.
    lenience: Additive slack of the acceptance criterion.

  Returns:
    The target distribution, with the same shape as `probs_small`.
  """
  del token_small  # Unused.
  max_prob_large = jnp.max(
      probs_large_unscaled, axis=-1, keepdims=True)
  tokens_accepted = jnp.greater_equal(
      probs_small_unscaled, max_prob_large - lenience)
  # p_target(v) = p_small(v) * 1(v accepted) +
  #               (\sum_u p_small(u) * 1(u rejected)) * p_large(v),
  # where the rejected mass is computed as 1 - (accepted mass).
  probs_small_accepted = probs_small * tokens_accepted
  probs_small_rejected_sum = 1.0 - jnp.sum(
      probs_small_accepted, axis=-1, keepdims=True)
  return probs_small_accepted + probs_small_rejected_sum * probs_large


def target_distribution_token_v2(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = -1.0) -> jnp.ndarray:
  """Target distribution for the Token-v2 deferral rule.

  A token v is "accepted" when (equation 14 in Narasimhan et al. (2025))
    p_large_unscaled(v) >= max_u p_large_unscaled(u) - lenience.
  Accepted tokens keep their small-model probability, and the small-model
  mass of the rejected tokens is redistributed according to the large model.

  Args:
    probs_small: Small-model distribution; last axis is the vocabulary.
    probs_large: Large-model distribution, same layout as `probs_small`.
    probs_small_unscaled: Unused.
    probs_large_unscaled: Large-model distribution without temperature.
    token_small: Unused.
    lenience: Additive slack of the acceptance criterion.

  Returns:
    The target distribution, with the same shape as `probs_small`.
  """
  del probs_small_unscaled, token_small  # Unused.
  max_prob_large = jnp.max(
      probs_large_unscaled, axis=-1, keepdims=True)
  tokens_accepted = jnp.greater_equal(
      probs_large_unscaled, max_prob_large - lenience)
  # p_target(v) = p_small(v) * 1(v accepted) +
  #               (\sum_u p_small(u) * 1(u rejected)) * p_large(v),
  # where the rejected mass is computed as 1 - (accepted mass).
  probs_small_accepted = probs_small * tokens_accepted
  probs_small_rejected_sum = 1.0 - jnp.sum(
      probs_small_accepted, axis=-1, keepdims=True)
  return probs_small_accepted + probs_small_rejected_sum * probs_large


def target_distribution_token_v3(
    probs_small: jnp.ndarray,
    probs_large: jnp.ndarray,
    probs_small_unscaled: jnp.ndarray,
    probs_large_unscaled: jnp.ndarray,
    token_small: jnp.ndarray,
    lenience: float = 0.0) -> jnp.ndarray:
  """Target distribution for the Token-v3 deferral rule.

  A token v is "accepted" when (equation 15 in Narasimhan et al. (2025))
    p_large_unscaled(v) >= (1 - lenience) * max_u p_large_unscaled(u).
  Accepted tokens keep their small-model probability, and the small-model
  mass of the rejected tokens is redistributed according to the large model.

  Returns:
    The target distribution, with the same shape as `probs_small`.
  """
  del probs_small_unscaled, token_small  # Unused.
  top_prob_large = jnp.max(probs_large_unscaled, axis=-1, keepdims=True)
  threshold = top_prob_large * (1 - lenience)
  is_accepted = probs_large_unscaled >= threshold
  # p_target(v) = p_small(v) * 1(v accepted) +
  #               (\sum_u p_small(u) * 1(u rejected)) * p_large(v),
  # where the rejected mass is computed as 1 - (accepted mass).
  kept_small = probs_small * is_accepted
  rejected_mass = 1.0 - jnp.sum(kept_small, axis=-1, keepdims=True)
  return kept_small + rejected_mass * probs_large

## Get acceptance and residual functions for different methods

In [None]:
def get_acceptance_residual_fns(method: str, lenience: float = 0.0):
  """Returns (acceptance_fn, residual_fn) for the selected method.

  Args:
    method: can be 'drafter_only', 'verifier_only', 'speed',
      'cascade_chow', 'cascade_diff', 'cascade_opt', 'cascade_tokenV1',
      'cascade_tokenV2', or 'cascade_tokenV3'.
    lenience: Lenience parameter alpha for deferral rule.

  Raises:
    ValueError: if `method` is unknown, if a cascade method does not follow
      the `cascade_<deferral_rule>` syntax, or if the deferral rule is unknown.
  """
  # Baselines that use a single model.
  if method == 'drafter_only':
    return accept_all_prob_fn, small_distribution_fn
  if method == 'verifier_only':
    return reject_all_prob_fn, large_distribution_fn

  # (Lossy) speculative sampling.
  if method == 'speed':
    def speed_acceptance_prob_fn(x, y, u, v, w):
      return speed_sampling_acceptance_prob_fn(
          x, y, u, v, w, lenience=lenience
      )
    return speed_acceptance_prob_fn, speed_sampling_residual_distribution_fn

  if not method.startswith('cascade'):
    raise ValueError(f'Unknown method: {method}')

  # Syntax: cascade_<deferral_rule>
  name_parts = method.split('_')
  if len(name_parts) != 2:
    raise ValueError(f'Invalid method syntax: {method}')
  deferral_rule = name_parts[1]

  target_fn_by_rule = {
      'chow': target_distribution_chow,
      'diff': target_distribution_diff,
      'opt': target_distribution_opt,
      'tokenV1': target_distribution_token_v1,
      'tokenV2': target_distribution_token_v2,
      'tokenV3': target_distribution_token_v3,
  }
  if deferral_rule not in target_fn_by_rule:
    raise ValueError(f'Unknown deferral rule: {deferral_rule}')
  target_distribution_fn = target_fn_by_rule[deferral_rule]

  cascade_acceptance_fn = (
      create_speculative_cascade_sampling_acceptance_prob_fn(
          target_distribution_fn, lenience
      )
  )
  cascade_residual_fn = (
      create_speculative_cascade_sampling_residual_distribution_fn(
          target_distribution_fn, lenience
      )
  )
  return cascade_acceptance_fn, cascade_residual_fn

# Load a Gemma 2B and 9B model

In [None]:
# Load the tokenizer; the same tokenizer is used for both models.
tokenizer = gm.text.Gemma2Tokenizer()

# Load the Transformer models: the 2B model acts as the small (drafter) model
# and the 9B model as the large (verifier) model.
small_transformer = gm.nn.Gemma2_2B()
large_transformer = gm.nn.Gemma2_9B()

# Load the params from the corresponding `*_IT` checkpoints.
small_params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_2B_IT)
large_params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_9B_IT)

# Test samplers with example prompt



In [None]:
# Build a sampler on top of your models and your tokenizer, and test it on an
# example prompt.

def acceptance_rate(out_data, idx=0, eos_id=1):
  """Returns the fraction of generated tokens whose draft was accepted.

  Only tokens before the first EOS token are counted. If there is no EOS
  token, all returned tokens are counted.

  Args:
    out_data: Object with `tokens` and `tokens_accepted` attributes (e.g. a
      `SamplerOutput`), each holding one list per prompt.
    idx: Index of the prompt to evaluate.
    eos_id: Token id marking the end of a sequence.

  Returns:
    The acceptance rate in [0, 1]. When no token precedes the EOS token, or
    no tokens were returned at all, the rate is defined as 1.0.
  """
  all_tokens = out_data.tokens[idx]
  all_accepted = out_data.tokens_accepted[idx]
  if len(all_tokens) == 0:
    # Nothing was generated; avoid dividing by zero.
    return 1.0
  eos_positions = np.where(np.array(all_tokens) == eos_id)[0]
  if len(eos_positions) == 0:
    # No EOS. So num_gen_steps not enough.
    return np.sum(all_accepted) / len(all_tokens)
  first_eos = eos_positions[0]
  if first_eos == 0:
    # EOS in first position.
    return 1.0
  return np.sum(all_accepted[:first_eos]) / first_eos


def test_sample(
    acceptance_prob_fn,
    residual_distribution_fn,
    prompts,
    temperature=1.0,
    total_generation_steps=200,
    verbose=True,
    seed=0):
  """Builds a `Sampler` from the loaded models and samples responses.

  Args:
    acceptance_prob_fn: Acceptance criterion function passed to the sampler.
    residual_distribution_fn: Residual distribution function passed to the
      sampler.
    prompts: A single prompt string, or a sequence of prompt strings.
    temperature: Temperature for sampling.
    total_generation_steps: Number of generation steps.
    verbose: Whether to print the first response and its acceptance rate.
    seed: Random seed.

  Returns:
    The `SamplerOutput` produced by the sampler.
  """
  # A bare string is a single prompt; any other sequence (list, tuple, ...) is
  # taken as a batch of prompts.
  if isinstance(prompts, str):
    prompts = [prompts]
  else:
    prompts = list(prompts)
  # Build sampler.
  sampler = Sampler(
    small_transformer=small_transformer,
    large_transformer=large_transformer,
    tokenizer=tokenizer,
    small_params=small_params,
    large_params=large_params,
    acceptance_prob_fn=acceptance_prob_fn,
    residual_distribution_fn=residual_distribution_fn,
    temperature=temperature
  )
  # Response for example prompt.
  out_data = sampler(
      input_strings=prompts,
      total_generation_steps=total_generation_steps,
      seed=seed)
  if verbose:
    # Only the first prompt of the batch is reported.
    print(f"Response:\n{out_data.text[0]}\n")
    print(f"Acceptance rate:\n{acceptance_rate(out_data)}")
  return out_data

Unless otherwise specified, we sample with temperature 1.0.

In [None]:
# Example prompt passed to `test_sample` in the sampling cell.
prompt = "what's the purpose of life?"

You may pick one of the following methods:

- **drafter_only**: Always call the drafter
- **verifier_only**: Always call the verifier
- **speed**: Lossy SPEED, where lenience parameter can vary from 0 (standard SPEED) to 1 (only drafter)
- **cascade_chow:** Speculative cascade with Chow's rule, where lenience can vary from 0 (standard SPEED) to 1 (only drafter)
- **cascade_diff:** Speculative cascade with the Diff rule, where lenience can vary from -1 (standard SPEED) to 1 (only drafter)
- **cascade_opt:** Speculative cascade with the OPT rule, where lenience can vary from -Inf (standard SPEED) to Inf (only drafter)
- **cascade_tokenV1:** Speculative cascade with the TokenV1 rule, where lenience can vary from -1 (standard SPEED) to 1 (only drafter)
- **cascade_tokenV2:** Speculative cascade with the TokenV2 rule, where lenience can vary from -1 (standard SPEED) to 1 (only drafter)
- **cascade_tokenV3:** Speculative cascade with the TokenV3 rule, where lenience can vary from 0 (standard SPEED) to 1 (only drafter)

In [None]:
# Sample a response for the example prompt with the selected method.

# Pick one among the different speculative cascade deferral rules and the baselines.
method = 'cascade_tokenV3'  # @param ['drafter_only', 'verifier_only', 'speed', 'cascade_chow', 'cascade_diff', 'cascade_opt', 'cascade_tokenV1', 'cascade_tokenV2', 'cascade_tokenV3']

# Temperature for sampling.
temperature = 1.0  # @param {type:'number'}

# The lenience parameter is same as the trade-off parameter `alpha` in the paper.
lenience = 0.5  # @param {type:'number'}

# The number of output tokens is strictly bounded by this parameter.
total_generation_steps = 100  # @param {type:'integer'}


# Resolve the method name into its acceptance and residual functions.
acceptance_fn, residual_fn = get_acceptance_residual_fns(
    method=method,
    lenience=lenience
)

# Run the sampler; with the default `verbose=True` this prints the response
# and its acceptance rate.
out_data = test_sample(
    acceptance_fn,
    residual_fn,
    prompts=prompt,
    temperature=temperature,
    total_generation_steps=total_generation_steps
)