# Attention sink and Logit lens in Penzai

In the first part, this colab shows how to use Penzai to visualize and
understand the attention sink and related phenomena, such as massive activations
and value-state drains. It draws on the following literature:
[[1](https://arxiv.org/abs/2309.17453), [2](https://arxiv.org/abs/2410.10781),
[3](https://arxiv.org/abs/2402.17762), [4](https://arxiv.org/abs/2410.13835),
[5](https://arxiv.org/abs/2504.02732)].

In the second part, this colab demonstrates how to use logit lens to analyze
model predictions using Penzai. We refer to the blog
[interpreting GPT: the logit lens](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens).
Instead of using GPT, we use Gemma models.

NOTE: we run this colab on a TPU **v5e-1** runtime. Please see our notebook
`./notebooks/gemma3_multimodal_penzai.ipynb` on how to build a local runtime.

# First part: Attention sink and related phenomena

## Import packages

First, we install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade your pip in case
!pip install --upgrade pip

# Installs JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This installs dependencies defined in your pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import gc
import os
from typing import Any
from gemma import gm
from IPython.display import clear_output
import kagglehub

Import JAX related packages.

In [None]:
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import orbax.checkpoint

Import Penzai related packages.

In [None]:
import penzai
from penzai import pz
from penzai.models import transformer
from penzai.toolshed import token_visualization
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

## Loading Gemma2 pre-trained models

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, input your "KAGGLE_USERNAME" and "KAGGLE_KEY" below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()

First, we load the metadata and checkpoint of the Gemma2 2B model.

In [None]:
weights_dir = kagglehub.model_download("google/gemma-2/flax/gemma2-2b")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma2-2b")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

Prepare the devices and sharding. Here the sharding strategy splits each model
parameter across the available TPU devices along its last dimension.

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(sharding_devices, ("data",))

In [None]:
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=NamedSharding(
            mesh, PartitionSpec(*(None,) * (len(m.shape) - 1), "data")
        ),
    ),
    metadata.item_metadata,  # change back to metadata if any running error
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
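
To make the sharding rule concrete, here is a minimal plain-Python sketch of how
the per-dimension spec is built. `last_axis_spec` and `shard_shape` are
hypothetical helpers, not JAX or Orbax APIs. The shapes are made up, and the
sketch assumes the last dimension divides evenly across devices.

```python
def last_axis_spec(shape, axis_name="data"):
  """Per-dimension spec: replicate every axis except the last one."""
  # Hypothetical helper mirroring PartitionSpec(*(None,) * (ndim - 1), "data").
  return (None,) * (len(shape) - 1) + (axis_name,)


def shard_shape(shape, n_devices):
  """Shape of the slice each device would hold, assuming even division."""
  *leading, last = shape
  if last % n_devices != 0:
    raise ValueError(f"last dim {last} is not divisible by {n_devices}")
  return (*leading, last // n_devices)


# Illustrative shapes only; they are not read from the checkpoint.
print(last_axis_spec((26, 8, 2304, 256)))  # (None, None, None, 'data')
print(shard_shape((2304, 9216), n_devices=4))  # (2304, 2304)
```

With a single device, `shard_shape` returns the full shape, so every parameter
is held whole on that device.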

Now we prepare the Gemma2 language model definition and bind it with the
parameters.

In [None]:
model = transformer.variants.gemma.gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=False,
)

Directly visualizing the model definition together with its parameters takes a
long time. Therefore, we first use the `unbind_params` function to extract the
model architecture, and then visualize only the architecture without parameters.

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

Free some memory.

In [None]:
del flat_params
gc.collect()

## Prepare prompt and tokenizer

In [None]:
tokenizer = gm.text.Gemma2Tokenizer()
tokenizer

Check the vocabulary size.

In [None]:
tokenizer.vocab_size

Check the special token set

In [None]:
tokenizer.special_tokens

Prepare a simple prompt.

In [None]:
prompt = "Would you be able to travel through time using a wormhole?"

Use tokenizer to encode the prompt, and then transform it into a named JAX
array. Please note that we need to enable `add_bos=True` to ensure Gemma2 models
work normally.

In [None]:
tokens = tokenizer.encode(prompt, add_bos=True)
tokens = jnp.asarray(tokens)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")
tokens

We can also use `token_visualization` in Penzai to visualize the input token
ids. Note that `show_token_array` expects a SentencePiece processor object as
its second argument, which we can obtain from `tokenizer._sp`.

In [None]:
token_visualization.show_token_array(tokens, tokenizer._sp)  # pylint: disable=protected-access

## Massive activations for the first token

We first use Penzai to visualize
[massive activations](https://arxiv.org/abs/2402.17762), the phenomenon in which
the first token has activation outliers compared to other tokens.

In Penzai, it is easy to insert/delete/change model layers and manipulate
activations. The general tutorial is in
[Penzai Tutorials](https://penzai.readthedocs.io/en/stable/index.html). Here we
only show how to display or save intermediate activations:

In [None]:
# Define a layer to visualize the middle activations
@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(
    pz.nn.Layer
):  # <- pz.nn.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value, **unused_side_inputs) -> Any:
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value


# Define a layer to extract the middle activations
@pz.pytree_dataclass
class SaveIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[Any | None]

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    self.saved.value = value
    return value
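
The idea behind `SaveIntermediate` can be sketched without Penzai: a
pass-through layer writes its input into a mutable cell and returns it
unchanged. In this minimal sketch, `Box`, `Save`, and `run` are hypothetical
stand-ins (for `pz.StateVariable`, `SaveIntermediate`, and a sequential model),
not Penzai classes.

```python
class Box:
  """Hypothetical mutable cell playing the role of pz.StateVariable."""

  def __init__(self, value=None):
    self.value = value


class Save:
  """Pass-through layer: records its input in a Box and returns it unchanged."""

  def __init__(self, box):
    self.box = box

  def __call__(self, x):
    self.box.value = x
    return x


def run(layers, x):
  """Applies the layers in order, like a sequential model."""
  for layer in layers:
    x = layer(x)
  return x


box = Box()
layers = [lambda x: x + 1, Save(box), lambda x: x * 10]
output = run(layers, 2)
print(box.value, output)  # 3 30
```

Because the probe returns its input untouched, the model output is the same
with or without it; only the side channel (`box.value`) changes.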

In Penzai, model modifications are generally performed by using `pz.select` to
make a modified copy of the original model (but sharing the same parameters).
This involves “selecting” the part of the model you want to modify, then
applying a modification, similar to the `.at[...].set(...)` syntax for modifying
JAX arrays.

Here we insert a `SaveIntermediate` layer into the model after each transformer
block. Remember that each new `SaveIntermediate` object needs its own new
`StateVariable`.

In [None]:
block_num = 26
all_activations = [pz.StateVariable(value=None) for _ in range(block_num)]
# Start from the original model; each iteration patches the previous copy.
model_patched = model
for block_index in range(block_num):
  model_patched = (
      pz.select(model_patched)
      .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
      .pick_nth_selected(block_index)
      .insert_after(SaveIntermediate(all_activations[block_index]))
  )
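
The select-and-insert loop above follows a functional-update pattern: each step
returns a modified copy and leaves the original untouched. Below is a minimal
sketch of that pattern on a plain tuple of layers. `Block` and
`insert_after_nth` are hypothetical illustrations and do not reflect how Penzai
implements selections.

```python
class Block:
  """Stand-in for a transformer block (hypothetical, not a Penzai class)."""

  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return f"Block({self.name!r})"


def insert_after_nth(layers, cls, n, new_layer):
  """Returns a copy of `layers` with `new_layer` after the n-th `cls` instance."""
  positions = [i for i, layer in enumerate(layers) if isinstance(layer, cls)]
  cut = positions[n] + 1
  return layers[:cut] + (new_layer,) + layers[cut:]


original = ("embed", Block("0"), Block("1"), "unembed")
patched = original
for n in range(2):
  # Inserted probes are not Blocks, so the n-th Block keeps its index.
  patched = insert_after_nth(patched, Block, n, f"probe_{n}")

print(patched)
print(len(original))  # 4: the original tuple is untouched
```

This is why the loop can keep using `block_index` on the already patched model:
the inserted probes are not counted as blocks.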

Then we run a forward pass through the patched model.

In [None]:
logits = model_patched(tokens)

Now the residual stream after each transformer block has been saved in the
`all_activations` list. We then visualize the $\ell_2$-norm of the activations
and show that the first token has massive activations.

In [None]:
all_norm = []
for block_index in range(block_num):
  activations = all_activations[block_index].value
  norm = pz.nx.nmap(jnp.linalg.norm)(activations.untag("embedding"), ord=2)
  # norm = pz.nx.nmap(jnp.expand_dims)(norm, axis=0).tag("block")
  all_norm.append(norm)
all_norm = pz.nx.stack(all_norm, axis_name="block")
all_norm

Hover the cursor over each token across different blocks to check the
$\ell_2$-norm of each activation. It is clear that the first token is an outlier
compared to the other tokens.
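
As a reminder of what is being plotted, the $\ell_2$-norm of a residual vector
is the square root of the sum of its squared entries. The sketch below uses
made-up three-dimensional vectors, not real Gemma activations, with the first
token deliberately scaled up to mimic a massive activation.

```python
import math


def l2_norm(vector):
  """Euclidean norm of a list of floats."""
  return math.sqrt(sum(x * x for x in vector))


# Made-up residual vectors; the first one mimics a massive activation.
residuals = {
    "<bos>": [300.0, -400.0, 0.0],
    "Would": [3.0, 4.0, 0.0],
    "you": [0.0, 6.0, 8.0],
}
norms = {token: l2_norm(vec) for token, vec in residuals.items()}
print(norms)  # {'<bos>': 500.0, 'Would': 5.0, 'you': 10.0}
```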

## Cosine Similarity among different transformer blocks

We use the same saved activations to analyze how they change across blocks.

In [None]:
all_activations_stack = []
for block_index in range(block_num):
  activations = all_activations[block_index].value
  all_activations_stack.append(activations)
all_activations_stack = pz.nx.stack(all_activations_stack, axis_name="block")
all_activations_stack.named_shape

Here we measure the cosine-similarity between two consecutive blocks.

In [None]:
all_activations_stack_prev = all_activations_stack.untag("block")[
    : block_num - 1
].tag("block")
all_activations_stack_next = all_activations_stack.untag("block")[1:].tag(
    "block"
)

all_activations_stack_prev_normalized = pz.nx.nmap(
    lambda x: x / jnp.linalg.norm(x, ord=2)
)(all_activations_stack_prev.untag("embedding"))
all_activations_stack_next_normalized = pz.nx.nmap(
    lambda x: x / jnp.linalg.norm(x, ord=2)
)(all_activations_stack_next.untag("embedding"))

cosine_sim = pz.nx.nmap(jnp.dot)(
    all_activations_stack_prev_normalized, all_activations_stack_next_normalized
)
cosine_sim

As can be seen here, the cosine similarity between consecutive layers is high,
especially for the middle layers.

In [None]:
cosine_sim > 0.9
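
The computation above normalizes each activation to unit $\ell_2$-norm and then
takes a dot product. A minimal plain-Python sketch of the same measure follows,
using made-up two-dimensional activations for a single token.
`cosine_similarity` is a hypothetical helper name.

```python
import math


def cosine_similarity(u, v):
  """Dot product of u and v after scaling each to unit L2 norm."""
  dot = sum(a * b for a, b in zip(u, v))
  norm_u = math.sqrt(sum(a * a for a in u))
  norm_v = math.sqrt(sum(b * b for b in v))
  return dot / (norm_u * norm_v)


# Made-up activations of one token after three consecutive blocks.
per_block = [[1.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
consecutive = [
    cosine_similarity(prev, nxt)
    for prev, nxt in zip(per_block[:-1], per_block[1:])
]
print(consecutive)  # [0.6, 1.0]
```

The second pair differs only in scale, so its similarity is 1: cosine similarity
ignores the norm and compares direction only.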

## Attention sink for the first token

Just as we extracted the residual stream from the model, we can construct
`SaveIntermediate` objects and insert them into each attention layer. As we
would like to visualize the attention map, we insert them after `pz.nn.Softmax`.

In [None]:
block_num = 26
attention_maps = [pz.StateVariable(value=None) for _ in range(block_num)]
# Start from the original model; each iteration patches the previous copy.
model_patched = model
for block_index in range(block_num):
  model_patched = (
      pz.select(model_patched)
      .at_instances_of(pz.nn.Softmax)
      .pick_nth_selected(block_index)
      .insert_after(SaveIntermediate(attention_maps[block_index]))
  )

Then we run a forward pass with the patched model.

In [None]:
logits = model_patched(tokens)

As there are many transformer blocks, here we would like to visualize only one
block, e.g., the 23rd block (index=22).

In [None]:
block_index = 22  # @param {type:"integer"}
attention_maps[block_index].value

Here we can clearly see that the first token has extremely large attention
scores (close to 0.9 out of 1.0) compared to other tokens.
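
To see how such a map can arise, here is a toy causal softmax in plain Python.
The pre-softmax scores are made up so that every query scores key 0 highly. The
sketch ignores heads and any model-specific details, so it only illustrates the
normalization, not Gemma's actual attention.

```python
import math


def causal_softmax(scores):
  """Row-wise softmax where query i only attends to keys 0..i."""
  rows = []
  for i, row in enumerate(scores):
    visible = row[: i + 1]
    peak = max(visible)  # Subtract the max for numerical stability.
    exps = [math.exp(s - peak) for s in visible]
    total = sum(exps)
    rows.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
  return rows


# Made-up pre-softmax scores: key 0 gets a large score from every query.
scores = [
    [4.0, 0.0, 0.0],
    [4.0, 1.0, 0.0],
    [4.0, 1.0, 1.0],
]
attention = causal_softmax(scores)
for row in attention:
  print([round(p, 3) for p in row])
```

Each row sums to one, and in this toy input the first column stays above 0.9
for every query, which is the pattern seen in the saved attention map.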

## Value-state drains for the first token

According to [[2](https://arxiv.org/abs/2410.10781),
[4](https://arxiv.org/abs/2410.13835), [5](https://arxiv.org/abs/2504.02732)],
although the first token receives large attention scores, it has small value
states. Because the attention output is a weighted average of the value states,
the first token contributes little semantic content; its role is mostly
functional.

Now we extract the value states in the same way. Penzai provides a convenient
way to select specific layers. Here we only want the linear layers that compute
value states, so we filter for layers whose weight label includes
`attention/value.weights`.

In [None]:
block_num = 26
all_values = [pz.StateVariable(value=None) for _ in range(block_num)]
# Start from the original model; each iteration patches the previous copy.
model_patched = model
for block_index in range(block_num):
  model_patched = (
      pz.select(model_patched)
      .at_instances_of(pz.nn.Linear)
      .where(lambda x: "attention/value.weights" in x.weights.label)
      .pick_nth_selected(block_index)
      .insert_after(SaveIntermediate(all_values[block_index]))
  )

In [None]:
logits = model_patched(tokens)

As there are many transformer blocks, here we would like to visualize only one
block, e.g., the 2nd block (index=1).

In [None]:
block_index = 1  # @param {type:"integer"}
pz.nx.nmap(jnp.linalg.norm)(
    all_values[block_index].value.untag("projection"), ord=2
)

It is observed that the first token has much smaller $\ell_2$-norm (about 2)
compared to that of other tokens (more than 10).

We can also visualize the value states across different heads and transformer
blocks.

In [None]:
all_values_norm = []
for block_index in range(block_num):
  values = all_values[block_index].value
  norm = pz.nx.nmap(jnp.linalg.norm)(values.untag("projection"), ord=2)
  # norm = pz.nx.nmap(jnp.expand_dims)(norm, axis=0).tag("block")
  all_values_norm.append(norm)
all_values_norm = pz.nx.stack(all_values_norm, axis_name="block")
all_values_norm

We can clearly observe that the $\ell_2$-norm of the first token's value state
is significantly smaller than that of other tokens.

This further suggests that attention sink tokens contribute almost no semantic
content to subsequent tokens. They arise because the attention scores are
normalized to sum to one.
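
A small numerical sketch can make the argument concrete. Attention computes
$\sum_j a_j v_j$, so token $j$ can move the output by at most $a_j \lVert v_j
\rVert_2$. The weights and values below are illustrative only, loosely echoing
the magnitudes mentioned above (weight near 0.9, sink norm about 2, other norms
above 10). `attend` and `l2_norm` are hypothetical helpers.

```python
import math


def l2_norm(vector):
  """Euclidean norm of a list of floats."""
  return math.sqrt(sum(x * x for x in vector))


def attend(weights, values):
  """Weighted average of value vectors: sum_j weights[j] * values[j]."""
  dim = len(values[0])
  return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


# Illustrative numbers only: the sink gets weight 0.9 but has a small value.
weights = [0.9, 0.06, 0.04]
values = [[2.0, 0.0], [0.0, 10.0], [0.0, 12.0]]

output = attend(weights, values)
sink_contribution = weights[0] * l2_norm(values[0])
same_weight_typical_value = weights[0] * l2_norm(values[1])
print(output)
print(sink_contribution, same_weight_typical_value)
```

In this toy setting the sink adds a vector of norm about 1.8, whereas the same
weight placed on a typical norm-10 value would add about 9. The sink absorbs
attention mass while changing the output comparatively little.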

## Disappearance of attention sink without \<bos\> token in Gemma models

Most LLMs exhibit an attention sink at the first token regardless of whether it
is a \<bos\> token, according to [2](https://arxiv.org/abs/2410.10781). Gemma
models are exceptions, as shown in [5](https://arxiv.org/abs/2504.02732).
Without the \<bos\> token, the attention sink disappears. As hypothesized in
[5](https://arxiv.org/abs/2504.02732), this is because Gemma models always place
the \<bos\> token in the first position during pre-training.

Now we encode the same prompt without \<bos\> token.

In [None]:
tokens_no_bos = tokenizer.encode(prompt, add_bos=False)
tokens_no_bos = jnp.asarray(tokens_no_bos)[None, :]
tokens_no_bos = pz.nx.wrap(tokens_no_bos).tag("batch", "seq")
tokens_no_bos

In [None]:
token_visualization.show_token_array(tokens_no_bos, tokenizer._sp)  # pylint: disable=protected-access

Now we construct the patched model again and run the forward pass.

In [None]:
block_num = 26
attention_maps = [pz.StateVariable(value=None) for _ in range(block_num)]
# Start from the original model; each iteration patches the previous copy.
model_patched = model
for block_index in range(block_num):
  model_patched = (
      pz.select(model_patched)
      .at_instances_of(pz.nn.Softmax)
      .pick_nth_selected(block_index)
      .insert_after(SaveIntermediate(attention_maps[block_index]))
  )

In [None]:
logits = model_patched(tokens_no_bos)

In [None]:
block_index = 22  # @param {type:"integer"}
attention_maps[block_index].value

As shown here, the attention scores are distributed more along the diagonal than
on the first token. The attention sink phenomenon on the first token disappears.

Before moving to the second part, we delete the model we used and release some
memory.

In [None]:
del model_patched
del model

gc.collect()

# Second part: Logit lens

As the first part already imports the packages, we skip that step here. For
model diversity, we use the Gemma3 1B pre-trained model instead.

## Load Gemma3 model

We load the Gemma3-1B model in the same way as above.

In [None]:
# checkpoint path for Gemma3 model
weights_dir = kagglehub.model_download("google/gemma-3/flax/gemma3-1b")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma3-1b")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(sharding_devices, ("data",))

In [None]:
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=NamedSharding(
            mesh, PartitionSpec(*(None,) * (len(m.shape) - 1), "data")
        ),
    ),
    metadata.item_metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

In [None]:
model = transformer.variants.gemma.gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=False,
)

Here we visualize the model definition:

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

Free the memory.

In [None]:
del flat_params
gc.collect()

## Prepare the tokenizer and prompt

For Gemma3 series, we use `Gemma3Tokenizer` instead.

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()
tokenizer

In [None]:
tokenizer.vocab_size

In [None]:
tokenizer.special_tokens

As we use a pre-trained model, we prepare a prefix text as the prompt.

In [None]:
prompt = (
    "Specifically, we train GPT-3, an autoregressive language model with 175"
    " billion parameters,"
)

In [None]:
tokens = tokenizer.encode(prompt, add_bos=True)
tokens = jnp.asarray(tokens)
tokens = pz.nx.wrap(tokens).tag("seq")
tokens

In [None]:
token_visualization.show_token_array(tokens, tokenizer._sp)  # pylint: disable=protected-access

We first check the model's forward pass.

In [None]:
logits = model(tokens)

In [None]:
pred_tokens = pz.nx.nmap(jnp.argmax)(logits.untag("vocabulary"))
token_visualization.show_token_array(pred_tokens, tokenizer._sp)  # pylint: disable=protected-access

Note that the final output is "_to". The earlier outputs may not be sensible, as
next-token prediction may be inaccurate for the first few tokens.

## Logit lens analysis

First, we visualize the model definition again to find the model embedding and
unembedding.

In [None]:
model_unbound

It is observed that, in the transformer architecture, the token ids are first
fed into an `EmbeddingLookup` layer and then a `ConstantRescale` layer before
entering a series of `TransformerBlock` layers. To decode the model predictions,
there is a final `RMSLayerNorm` layer followed by the LM head,
`EmbeddingDecode`.

The logit lens examines how the logits change after each step of processing.
Here each step is one transformer block. We first construct our prediction head,
which includes both the final RMS norm layer and the linear unembedding.

In [None]:
pred_head = pz.nn.Sequential([
    model.body.sublayers[-2],
    model.body.sublayers[-1],
])
pred_head

Extracting model layers is easy: we directly take the final two layers to
construct a new model named `pred_head`.

Then we extract activations after each transformer block. Similar to how we
visualize massive activations, we can follow the procedure:

In [None]:
block_num = 26
all_activations = [pz.StateVariable(value=None) for _ in range(block_num)]
# Start from the original model; each iteration patches the previous copy.
model_patched = model
for block_index in range(block_num):
  model_patched = (
      pz.select(model_patched)
      .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
      .pick_nth_selected(block_index)
      .insert_after(SaveIntermediate(all_activations[block_index]))
  )

In [None]:
logits = model_patched(tokens)

Now we proceed to probe activations from transformer blocks and observe the
dynamics of the logits. We use the output of each transformer block as the input
to our prediction head to obtain logits, from which we infer the top-1 model
predictions.

In [None]:
all_top_token = []
all_top_logit = []
for block_index in range(block_num):
  activations = all_activations[block_index].value
  block_logits = pred_head(activations)
  top_token = pz.nx.nmap(jnp.argmax)(block_logits.untag("vocabulary"))
  top_logit = pz.nx.nmap(jnp.max)(block_logits.untag("vocabulary"))
  all_top_token.append(top_token)
  all_top_logit.append(top_logit)
all_top_token = pz.nx.stack(all_top_token, axis_name="block")
all_top_logit = pz.nx.stack(all_top_logit, axis_name="block")
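
The same procedure can be sketched on a toy scale in plain Python: one shared
head (normalization followed by unembedding) is applied to the hidden state
after every block, and the argmax is read off. The embeddings and hidden states
are made up, the helper names are hypothetical, and `rms_norm` omits the learned
scale that a real RMS norm layer carries.

```python
import math


def rms_norm(x, eps=1e-6):
  """Scales x to roughly unit root-mean-square (learned scale omitted)."""
  rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
  return [v / rms for v in x]


def unembed(x, table):
  """Logit for each vocabulary entry: dot product with its embedding row."""
  return [sum(a * b for a, b in zip(x, row)) for row in table]


def top_token(hidden, table, vocab):
  """Top-1 prediction when `hidden` is decoded by the shared head."""
  logits = unembed(rms_norm(hidden), table)
  best = max(range(len(logits)), key=logits.__getitem__)
  return vocab[best]


# Made-up 2-d embeddings and per-block hidden states for a single position.
vocab = ["the", "to", "of"]
table = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
hidden_per_block = [[2.0, 0.5], [1.0, 1.5], [0.2, 3.0]]

lens = [top_token(h, table, vocab) for h in hidden_per_block]
print(lens)  # ['the', 'to', 'to']
```

In this toy run the prediction settles on its final value after the second
block, which is the kind of pattern the logit lens is meant to reveal.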

Then we visualize the intermediate model predictions.

In [None]:
# all_top_token
token_visualization.show_token_array(all_top_token, tokenizer._sp)  # pylint: disable=protected-access

We can observe that the model output in the middle already looks like the final
output.

Compute the KL divergence between the final output distribution and each layer's
distribution, $\mathrm{KL}(\text{final} \,\|\, \text{layer})$.

In [None]:
all_kl_divergence = []
log_probs = pz.nx.nmap(jax.nn.log_softmax)(logits.untag("vocabulary"))
probs = pz.nx.nmap(jnp.exp)(log_probs)
for block_index in range(block_num):
  activations = all_activations[block_index].value
  block_logits = pred_head(activations)
  block_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      block_logits.untag("vocabulary")
  )
  kl_divergence = pz.nx.nmap(jnp.sum)(
      probs * (log_probs - block_log_probs), axis=-1
  )
  all_kl_divergence.append(kl_divergence)
all_kl_divergence = pz.nx.stack(all_kl_divergence, axis_name="block")
all_kl_divergence

*   After the first few layers, the input has been transformed into something
    that looks like the final output.

*   After this one discontinuous jump, the distribution progresses much more
    smoothly toward the final output distribution.
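
For reference, the quantity computed above is $\mathrm{KL}(p \,\|\, q) = \sum_v
p(v)\,(\log p(v) - \log q(v))$, with $p$ the final distribution and $q$ a
layer's distribution. A minimal plain-Python sketch on made-up logits over a
three-token vocabulary follows; the helper names are hypothetical.

```python
import math


def log_softmax(logits):
  """Numerically stable log-softmax over a list of floats."""
  peak = max(logits)
  log_total = peak + math.log(sum(math.exp(v - peak) for v in logits))
  return [v - log_total for v in logits]


def kl_divergence(final_logits, layer_logits):
  """KL(final || layer) = sum_v p(v) * (log p(v) - log q(v))."""
  log_p = log_softmax(final_logits)
  log_q = log_softmax(layer_logits)
  return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Made-up logits over a 3-token vocabulary.
final = [0.0, 3.0, 0.0]
early = [3.0, 0.0, 0.0]  # An early layer that prefers a different token.
late = [0.0, 2.5, 0.0]  # A late layer that is close to the final output.

kl_early = kl_divergence(final, early)
kl_late = kl_divergence(final, late)
print(kl_early > kl_late)  # True
print(kl_divergence(final, final))  # 0.0
```

The divergence is zero only when the layer's distribution matches the final one,
so a curve that drops toward zero across blocks indicates convergence to the
final prediction.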