# Gemma Scope Tutorial with Penzai

This colab shows how to use
[Gemma Scope](https://huggingface.co/google/gemma-scope) in Penzai. Gemma Scope
is Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and
sublayer of Gemma2 2B and 9B.

Sparse Autoencoders are an interpretability tool that acts like a "microscope" on
language model activations.

We aim to reproduce the example in the
[Tutorial: Gemma Scope from Scratch](https://colab.sandbox.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?usp=sharing).

NOTE: we run this colab on a TPU **v5e-1** runtime. See our notebook
`./notebooks/gemma3_multimodal_penzai.ipynb` for instructions on building a local runtime.

## Import packages

First, we install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade pip first (optional)
!pip install --upgrade pip

# Installs JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This installs the dependencies declared in the package's pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import dataclasses
import gc
import os
from gemma import gm
from huggingface_hub import hf_hub_download
from IPython.display import clear_output
import kagglehub

Import JAX related packages.

In [None]:
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import numpy as np
import orbax.checkpoint

Import Penzai related packages.

In [None]:
import penzai
from penzai import pz
from penzai.models import transformer
from penzai.toolshed import jit_wrapper
from penzai.toolshed import token_visualization
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

## Loading Gemma2 pre-trained models

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, input your "KAGGLE_USERNAME" and "KAGGLE_KEY" below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()

Since the Gemma Scope SAEs are trained on Gemma2 activations, we first load the
pre-trained Gemma2 2B model in Penzai.

In [None]:
weights_dir = kagglehub.model_download("google/gemma-2/flax/gemma2-2b")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma2-2b")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

Prepare the devices and the sharding. The sharding strategy here splits each
model parameter across TPU devices along its last dimension.

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(sharding_devices, ("data",))

In [None]:
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=NamedSharding(
            mesh, PartitionSpec(*(None,) * (len(m.shape) - 1), "data")
        ),
    ),
    metadata.item_metadata,  # With older Orbax versions, pass `metadata` directly.
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
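To make the sharding strategy concrete, here is a NumPy-only sketch of the
layout that the `PartitionSpec` in `restore_args` requests: every axis except
the last is replicated (`None`), and the last axis is split across the devices
of the `"data"` mesh axis. `shard_last_dim` is a hypothetical helper that only
mimics this layout; JAX and Orbax do the actual device placement. When
`n_devices` is 1 there is a single shard, so the split is a no-op.

```python
import numpy as np


def shard_last_dim(param: np.ndarray, n_devices: int):
  """Illustrative only: splits `param` into per-device shards along its last axis.

  Returns the PartitionSpec-style tuple used above and the list of shards.
  """
  # Replicate all leading axes (None) and shard the last one over "data".
  spec = (None,) * (param.ndim - 1) + ("data",)
  if param.shape[-1] % n_devices != 0:
    raise ValueError("last dimension must be divisible by n_devices")
  shards = np.split(param, n_devices, axis=-1)
  return spec, shards


# A toy (8, 16) weight split over 4 hypothetical devices.
spec, shards = shard_last_dim(np.arange(128.0).reshape(8, 16), n_devices=4)
print(spec)                       # (None, 'data')
print([s.shape for s in shards])  # [(8, 4), (8, 4), (8, 4), (8, 4)]
```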

Now we prepare the Gemma2 language model definition and bind it with the
parameters.

In [None]:
model = transformer.variants.gemma.gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=False,
)

Visualizing the full model definition together with its parameters can take a
long time. Instead, we first inspect a single sublayer, and then use
`pz.unbind_params` to separate the model architecture from its parameters so
that we can visualize only the architecture.

In [None]:
model.body.sublayers[-2]

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

Free some memory.

In [None]:
del flat_params
gc.collect()

## Running Gemma2 inference in Penzai

Load the tokenizer for Gemma2 models.

In [None]:
tokenizer = gm.text.Gemma2Tokenizer()
tokenizer

Show the vocabulary size.

In [None]:
tokenizer.vocab_size

Show the special tokens in the vocabulary.

In [None]:
tokenizer.special_tokens

Use the tokenizer to encode the prompt, then wrap it as a named JAX array with
axes `batch` and `seq`. Note that `add_bos=True` is required for Gemma2 models
to work properly, since they expect each sequence to start with the
beginning-of-sequence (BOS) token.

In [None]:
token_ids = tokenizer.encode(
    "Would you be able to travel through time using a wormhole?", add_bos=True
)
tokens = jnp.asarray(token_ids)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")
tokens
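The two preparation steps above, sketched in plain NumPy: make sure the sequence
starts with BOS, then add a leading batch axis of size 1 (what `[None, :]`
does). The BOS id of 2 is an assumption about the Gemma tokenizer; in practice,
let `add_bos=True` handle it and check `tokenizer.special_tokens`. The helper
`to_batched_prompt` and the token ids below are hypothetical.

```python
import numpy as np

GEMMA_BOS_ID = 2  # Assumption: Gemma's BOS token id; verify via tokenizer.special_tokens.


def to_batched_prompt(token_ids, bos_id: int = GEMMA_BOS_ID) -> np.ndarray:
  """Prepends BOS if missing and returns an int array of shape (batch=1, seq)."""
  ids = list(token_ids)
  if not ids or ids[0] != bos_id:
    ids = [bos_id] + ids
  # [None, :] adds the leading batch axis, as in the cell above.
  return np.asarray(ids, dtype=np.int32)[None, :]


batch = to_batched_prompt([651, 1546, 603])  # hypothetical token ids
print(batch.shape)  # (1, 4)
```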

Penzai's `token_visualization` module can also visualize the input token IDs.
Note that `show_token_array` expects a SentencePiece processor object, so we
pass the tokenizer's underlying `tokenizer._sp`.

In [None]:
token_visualization.show_token_array(tokens, tokenizer._sp)  # pylint: disable=protected-access

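Before running inference, we convert the model into a sampling mode with a
key-value (KV) cache, so that the keys and values of past tokens are reused
instead of recomputed at every decoding step.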

In [None]:
inference_model = (
    transformer.sampling_mode.KVCachingTransformerLM.from_uncached(
        model,
        cache_len=1024,
        batch_axes={"batch": 1},
    )
)
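The idea behind a KV cache, as a toy single-head NumPy sketch (this is not
Penzai's implementation): keys and values of past tokens are stored in a
fixed-size buffer (analogous to `cache_len=1024`), so each decoding step only
adds the new token's key and value and attends over the cached prefix. The
final line checks that this gives the same result as recomputing causal
attention from scratch. `ToyKVCache` and `attend` are hypothetical names.

```python
import numpy as np


def attend(q, keys, values):
  """Softmax attention of one query vector over a set of keys/values."""
  scores = keys @ q / np.sqrt(q.shape[-1])
  weights = np.exp(scores - scores.max())
  weights /= weights.sum()
  return weights @ values


class ToyKVCache:
  """Fixed-capacity cache of past keys and values for a single head."""

  def __init__(self, cache_len: int, dim: int):
    self.keys = np.zeros((cache_len, dim))
    self.values = np.zeros((cache_len, dim))
    self.length = 0

  def step(self, q, k, v):
    # Store the new token's key/value, then attend over everything cached so far.
    if self.length == self.keys.shape[0]:
      raise ValueError("cache is full")
    self.keys[self.length] = k
    self.values[self.length] = v
    self.length += 1
    return attend(q, self.keys[: self.length], self.values[: self.length])


rng = np.random.default_rng(0)
seq_len, dim = 5, 8
qs, ks, vs = (rng.normal(size=(seq_len, dim)) for _ in range(3))

cache = ToyKVCache(cache_len=16, dim=dim)
cached = np.stack([cache.step(qs[t], ks[t], vs[t]) for t in range(seq_len)])
# Recompute causal attention from scratch for every position.
full = np.stack([attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(seq_len)])
print(np.allclose(cached, full))  # True
```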

Then we JIT-compile the model body with `jit_wrapper.Jitted` and sample tokens
in a Python decoding loop.

In [None]:
samples = transformer.simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    rng=jax.random.key(3),
    max_sampling_steps=256,
)
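Temperature sampling, which `temperature_sample_pyloop` is named after, divides
the logits by a temperature before the softmax: temperatures below 1 sharpen the
distribution toward the most likely token, and temperatures above 1 flatten it.
A minimal NumPy sketch of a single sampling step (not Penzai's code; the
function names are hypothetical):

```python
import numpy as np


def temperature_probs(logits: np.ndarray, temperature: float) -> np.ndarray:
  """Softmax of logits / temperature, stabilized by subtracting the max."""
  if temperature <= 0:
    raise ValueError("temperature must be positive")
  scaled = logits / temperature
  scaled = scaled - scaled.max()
  probs = np.exp(scaled)
  return probs / probs.sum()


def sample_token(logits, temperature: float, rng: np.random.Generator) -> int:
  """Draws one token id from the temperature-scaled distribution."""
  probs = temperature_probs(np.asarray(logits, dtype=np.float64), temperature)
  return int(rng.choice(len(probs), p=probs))


logits = np.array([2.0, 1.0, 0.0])
print(temperature_probs(logits, 1.0).round(3))  # softer distribution
print(temperature_probs(logits, 0.5).round(3))  # sharper: more mass on token 0
print(sample_token(logits, 0.5, np.random.default_rng(3)))
```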

Convert the sampled output from a named array back to a plain JAX array, then
decode it to text.

In [None]:
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
tokenizer.decode(sample_tokens)

## Loading a Sparse Autoencoder

Having loaded Gemma2 2B and confirmed that it produces reasonable text, we now
load a sparse autoencoder (SAE).

Gemma Scope contains over four hundred SAEs, but for now we load just one: the
SAE on the residual stream at the end of layer 20 (of 26; layers are indexed
from 0, so this is the 21st layer). This is a fairly late layer, so the model
should have had time to form more abstract concepts!

The repositories and filenames of all released SAEs are listed in the
[google/gemma-scope-release](https://huggingface.co/collections/google/gemma-scope-release)
collection on Hugging Face.

In [None]:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
)
params = np.load(path_to_params, allow_pickle=True)
params

Check the shapes of the SAE parameters.

In [None]:
sae_params = {k: v.shape for k, v in params.items()}
sae_params
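The five arrays should have mutually consistent shapes: with $d$ the
residual-stream width and $m$ the number of latents, `W_enc` is $(d, m)$,
`W_dec` is $(m, d)$, `b_enc` and `threshold` are $(m,)$, and `b_dec` is $(d,)$.
For this 16k-width SAE on Gemma2 2B we would expect $d = 2304$ and
$m = 16384$, but the printed shapes are the authority. A small, hypothetical
helper that checks this on any dict-like collection of arrays (including the
`NpzFile` returned by `np.load`):

```python
import numpy as np


def check_sae_shapes(p) -> tuple[int, int]:
  """Returns (d_model, n_latents) after checking that JumpReLU SAE shapes agree."""
  d, m = p["W_enc"].shape
  expected = {
      "W_enc": (d, m),
      "W_dec": (m, d),
      "b_enc": (m,),
      "b_dec": (d,),
      "threshold": (m,),
  }
  for name, shape in expected.items():
    if tuple(p[name].shape) != shape:
      raise ValueError(f"{name} has shape {p[name].shape}, expected {shape}")
  return d, m


# Toy parameters with d=4, m=6; on the real file, call check_sae_shapes(params).
toy_sae_params = {
    "W_enc": np.zeros((4, 6)),
    "W_dec": np.zeros((6, 4)),
    "b_enc": np.zeros(6),
    "b_dec": np.zeros(4),
    "threshold": np.zeros(6),
}
print(check_sae_shapes(toy_sae_params))  # (4, 6)
```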

In [None]:
np.linalg.norm(params["W_enc"], axis=0)

## Implementing the SAE

For pedagogical purposes, we now implement the SAE forward pass ourselves in
Penzai.

Gemma Scope is a collection of
[JumpReLU SAEs](https://arxiv.org/abs/2407.14435). Like any autoencoder, a
JumpReLU SAE consists of an encoder and a decoder. The encoder maps the
activations $\boldsymbol{x}$ into a sparse, non-negative vector of feature
magnitudes:

$$\boldsymbol{f}(\boldsymbol{x})=\sigma(\boldsymbol{W}_{\text{enc}}\boldsymbol{x}+\boldsymbol{b}_{\text{enc}})$$

Here $\sigma$ is the **JumpReLU** activation, where $H$ is the Heaviside step
function and $\theta$ is the (per-latent) threshold:

$$\sigma(z)=zH(z-\theta)$$

The decoder then reconstructs the input activations as:

$$\hat{\boldsymbol{x}}=\boldsymbol{W}_{\text{dec}}\boldsymbol{f}+\boldsymbol{b}_{\text{dec}}$$
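These three equations translate almost line by line into NumPy. One detail:
the Gemma Scope arrays store `W_enc` as (embedding, latents) and `W_dec` as
(latents, embedding), so in code the activations multiply from the left
(`x @ W_enc`) rather than from the right as in the column-vector notation
above. A minimal reference sketch with toy shapes and random toy parameters;
with the real parameters you would pass the five arrays from `params` by name.

```python
import numpy as np


def jumprelu_sae_forward(x, W_enc, b_enc, threshold, W_dec, b_dec):
  """JumpReLU SAE forward pass for activations x of shape (..., d_model).

  Returns the feature magnitudes f and the reconstruction x_hat.
  """
  pre = x @ W_enc + b_enc        # W_enc x + b_enc, in row-vector form
  f = pre * (pre > threshold)    # sigma(z) = z * H(z - theta), with H(0) = 0
  x_hat = f @ W_dec + b_dec      # W_dec f + b_dec, in row-vector form
  return f, x_hat


rng = np.random.default_rng(0)
d, m = 4, 6  # toy sizes; the real SAE is much larger
toy_sae = {
    "W_enc": rng.normal(size=(d, m)),
    "b_enc": rng.normal(size=m),
    "threshold": np.full(m, 0.5),  # toy positive thresholds
    "W_dec": rng.normal(size=(m, d)),
    "b_dec": rng.normal(size=d),
}
x = rng.normal(size=(3, d))  # three toy token activations
f, x_hat = jumprelu_sae_forward(x, **toy_sae)
print(f.shape, x_hat.shape)  # (3, 6) (3, 4)
print((f > 0).sum(axis=-1))  # number of active latents per token
```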

Since Penzai does not provide a JumpReLU autoencoder, we first implement an
`AutoEncoder` class with `encoder` and `decoder` attributes; besides the full
forward pass, it exposes `encode()` and `decode()` methods. We then implement a
`JumpReLU` layer with a learnable threshold parameter.

In [None]:
from typing import Any

from penzai.core import named_axes
from penzai.core import struct
from penzai.nn import layer as layer_base
from penzai.nn import parameters
from penzai.nn.linear_and_affine import LinearOperatorWeightInitializer
from penzai.nn.linear_and_affine import zero_initializer


NamedArray = named_axes.NamedArray


@struct.pytree_dataclass
class AutoEncoder(pz.nn.Layer):
  """Top-level auto-encoder wrapper.

  Attributes:
    encoder: The encoder to transform inputs to latents.
    decoder: The decoder to reconstruct inputs from latents.
  """

  encoder: pz.nn.Layer
  decoder: pz.nn.Layer

  def __call__(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the forward pass of the auto-encoder."""
    acts = self.encode(x, **side_inputs)
    recon = self.decode(acts, **side_inputs)
    return recon

  def encode(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the encoder sublayer."""
    return self.encoder(x, **side_inputs)

  def decode(
      self, acts: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the decoder sublayer."""
    return self.decoder(acts, **side_inputs)


@struct.pytree_dataclass
class JumpReLU(pz.nn.Layer):
  """JumpReLU activation."""

  threshold: parameters.ParameterLike[NamedArray]
  new_axis_names: tuple[str, ...] = dataclasses.field(
      metadata={"pytree_node": False}
  )
  act_fn: layer_base.Layer = pz.nn.Elementwise(jax.nn.relu)

  def __call__(self, value: NamedArray, **_unused_side_inputs) -> NamedArray:
    """Zeroes out values at or below the threshold; passes the rest through ReLU."""
    # Elementwise functions broadcast automatically
    return (value > self.threshold.value) * self.act_fn(value)

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      threshold_axes: dict[str, int],
      new_output_axes: dict[str, int] | None = None,
      initializer: LinearOperatorWeightInitializer = zero_initializer,
      dtype: jax.typing.DTypeLike = jnp.float32,
  ):
    """Constructs a ``JumpReLU`` layer from a configuration.

    Args:
      name: The name of the layer.
      init_base_rng: The base RNG to use for initializing model parameters.
      threshold_axes: Names and lengths for the axes in the input that the
        threshold should act over. Other axes will be broadcast over.
      new_output_axes: Names and lengths of new axes that should be introduced
        into the input.
      initializer: Function used to initialize the threshold. Only the output
        axes will be set.
      dtype: Dtype for the threshold.

    Returns:
      A new ``JumpReLU`` layer with an uninitialized threshold parameter.
    """
    if new_output_axes is None:
      new_output_axes = {}

    return cls(
        threshold=parameters.make_parameter(
            f"{name}/threshold",
            init_base_rng,
            initializer,
            input_axes={},
            output_axes={**threshold_axes, **new_output_axes},
            parallel_axes={},
            convolution_spatial_axes={},
            dtype=dtype,
        ),
        new_axis_names=tuple(new_output_axes.keys()),
    )

  def treescope_color(self) -> str:
    return "#65cfbc"
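`JumpReLU.__call__` computes `(value > threshold) * relu(value)` rather than
$zH(z-\theta)$ directly. The two agree whenever $\theta \ge 0$, because any
$z > \theta$ is then positive and the ReLU leaves it unchanged; with a negative
threshold they would differ for $\theta < z < 0$. A quick NumPy check of both
cases on toy values (the function names are hypothetical). For the loaded SAE,
`(params["threshold"] >= 0).all()` should tell you whether the layer matches
the definition exactly.

```python
import numpy as np


def jumprelu_definition(z, theta):
  """sigma(z) = z * H(z - theta), with H(0) = 0."""
  return z * (z > theta)


def jumprelu_as_in_layer(z, theta):
  """Mirrors JumpReLU.__call__: (z > theta) * relu(z)."""
  return (z > theta) * np.maximum(z, 0.0)


z = np.linspace(-2.0, 2.0, 9)
print(np.allclose(jumprelu_definition(z, 0.5), jumprelu_as_in_layer(z, 0.5)))    # True
print(np.allclose(jumprelu_definition(z, -1.0), jumprelu_as_in_layer(z, -1.0)))  # False
```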

With these layers defined, we build the full SAE model definition and bind it
to the parameters.

In [None]:
def sae_from_gemma_scope(
    params_sae: dict[str, Any],
) -> AutoEncoder:
  """Constructs an SAE model from Gemma Scope parameters.

  Args:
    params_sae: The Gemma Scope SAE parameters (``W_enc``, ``b_enc``,
      ``W_dec``, ``b_dec`` and ``threshold``).

  Returns:
    A new SAE model.
  """
  embedding_dim, latents_dim = params_sae["W_enc"].shape

  # Encoder
  encoder = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="sae/W_enc",
          init_base_rng=None,
          input_axes={"embedding": embedding_dim},
          output_axes={"latents": latents_dim},
      ),
      pz.nn.AddBias.from_config(
          name="sae/b_enc",
          init_base_rng=None,
          biased_axes={"latents": latents_dim},
      ),
      JumpReLU.from_config(
          name="sae",
          init_base_rng=None,
          threshold_axes={"latents": latents_dim},
      ),
  ])
  # Decoder
  decoder = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="sae/W_dec",
          init_base_rng=None,
          input_axes={"latents": latents_dim},
          output_axes={"embedding": embedding_dim},
      ),
      pz.nn.AddBias.from_config(
          name="sae/b_dec",
          init_base_rng=None,
          biased_axes={"embedding": embedding_dim},
      ),
  ])

  # Create the model definition.
  model_def = AutoEncoder(
      encoder=encoder,
      decoder=decoder,
  )

  # Create parameter objects for each parameter.
  model_sae = pz.bind_variables(
      model_def,
      [
          pz.Parameter(
              value=pz.nx.wrap(params_sae["W_enc"]).tag("embedding", "latents"),
              label="sae/W_enc.weights",
          ),
          pz.Parameter(
              value=pz.nx.wrap(params_sae["b_enc"]).tag("latents"),
              label="sae/b_enc.bias",
          ),
          pz.Parameter(
              value=pz.nx.wrap(params_sae["W_dec"]).tag("latents", "embedding"),
              label="sae/W_dec.weights",
          ),
          pz.Parameter(
              value=pz.nx.wrap(params_sae["b_dec"]).tag("embedding"),
              label="sae/b_dec.bias",
          ),
          pz.Parameter(
              value=pz.nx.wrap(params_sae["threshold"]).tag("latents"),
              label="sae/threshold",
          ),
      ],
  )
  return model_sae

Passing the parameters loaded from Hugging Face gives us our SAE model, whose
structure we can visualize directly in Penzai.

In [None]:
sae_model = sae_from_gemma_scope(params)
sae_model

## Running the SAE on model activations

Penzai makes it easy to insert, delete, or modify model layers and to
manipulate activations; see the
[Penzai Tutorials](https://penzai.readthedocs.io/en/stable/index.html) for a
general introduction. Here we only show how to display or save intermediate
activations:

In [None]:
# Define a layer to visualize the middle activations
# `pz.pytree_dataclass` makes the class both a Python dataclass and a JAX
# pytree node; `pz.nn.Layer` is the base class of Penzai layers.
@pz.pytree_dataclass
class DisplayIntermediateValue(pz.nn.Layer):

  def __call__(self, intermediate_value, **unused_side_inputs) -> Any:
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value


# Define a layer to extract the middle activations
@pz.pytree_dataclass
class SaveIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[Any | None]

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    self.saved.value = value
    return value

Define a `StateVariable` to save model activations.

In [None]:
destination = pz.StateVariable(value=None)
destination

In Penzai, model modifications are generally performed by using `pz.select` to
make a modified copy of the original model (which shares the same parameters).
This involves "selecting" the part of the model you want to modify, then
applying a modification, similar to the `.at[...].set(...)` syntax for modifying
JAX arrays. Here we insert a `SaveIntermediate` layer after the 21st
`TransformerBlock` (index 20, since counting starts at 0).

In [None]:
model_patched = (
    pz.select(model)
    .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
    .pick_nth_selected(20)
    .insert_after(SaveIntermediate(destination))
)
logits = model_patched(tokens)
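Conceptually, the selection above finds all `TransformerBlock` instances in
order, picks the one at index 20 (the 21st), and inserts a new layer right after
it, returning a new model instead of mutating the old one. Below is a toy
pure-Python analogue over a flat list of layer names; `insert_after_nth` is a
hypothetical helper, and Penzai's `pz.select` works on arbitrary nested pytrees
rather than just lists.

```python
from typing import Any, Callable


def insert_after_nth(
    layers: list[Any], is_target: Callable[[Any], bool], n: int, new_layer: Any
) -> list[Any]:
  """Returns a new list with `new_layer` after the n-th (0-indexed) matching layer."""
  seen = -1
  for i, layer in enumerate(layers):
    if is_target(layer):
      seen += 1
      if seen == n:
        # Build a new list; the input list is left unchanged.
        return layers[: i + 1] + [new_layer] + layers[i + 1 :]
  raise IndexError(f"fewer than {n + 1} matching layers")


toy_layers = ["embed"] + [f"block_{i}" for i in range(26)] + ["final_norm"]
patched = insert_after_nth(toy_layers, lambda name: name.startswith("block_"), 20, "save")
print(patched[20:24])                 # ['block_19', 'block_20', 'save', 'block_21']
print(len(toy_layers), len(patched))  # 28 29
```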

Now we can visualize the saved activations.

In [None]:
destination.value

We can also compute statistics of the activations, for example the
$\ell_2$-norm of the residual stream at each token position.

In [None]:
# Check massive activations
pz.nx.nmap(jnp.linalg.norm)(
    destination.value.untag("embedding"), ord=2, axis=-1
)

The first token has much larger activations than the others, a phenomenon known
as [massive activations](https://arxiv.org/abs/2402.17762).
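One hedged way to quantify this is to compare each token's norm with the median
norm over the sequence. The sketch below uses synthetic NumPy data, and the
cutoff factor of 10 is an arbitrary illustrative choice, not a definition taken
from the massive-activations paper. On the real activations, something like
`np.asarray(destination.value.untag("batch", "seq", "embedding").unwrap()[0], dtype=np.float32)`
should give a `(seq, d_model)` array to pass in.

```python
import numpy as np


def flag_norm_outliers(acts: np.ndarray, factor: float = 10.0):
  """acts: (seq, d_model). Returns per-token L2 norms and an outlier mask.

  A token is flagged when its norm exceeds `factor` times the median norm.
  """
  norms = np.linalg.norm(acts, ord=2, axis=-1)
  return norms, norms > factor * np.median(norms)


rng = np.random.default_rng(0)
toy_acts = rng.normal(size=(6, 16))
toy_acts[0] *= 100.0  # make the first token's activations "massive"
norms, is_outlier = flag_norm_outliers(toy_acts)
print(is_outlier)  # only the first token is flagged
```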

Now we can run the SAE on the extracted activations, starting by encoding them
into features.

In [None]:
sae_acts = sae_model.encode(destination.value)
sae_acts

Apart from the first token, all tokens have very sparse latent features. Because
the first token carries outlier (massive) activations, it is not used for
training SAEs.

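Check the L0 of the SAE, i.e. the number of active latents per token (here,
latents with activation above 1). It should be around 70, consistent with the
`average_l0_71` in the SAE filename.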

In [None]:
sparsity = pz.nx.nmap(jnp.sum)((sae_acts > 1).untag("latents"), axis=-1)
print(sparsity.untag("batch", "seq").unwrap())
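Because the first (BOS) token is not used for training the SAE, a fairer summary
averages the per-token L0 over the remaining positions only. A NumPy sketch on
toy data; `mean_l0` is a hypothetical helper whose `active_threshold` mirrors
the `> 1` cutoff in the cell above (set it to 0 for a strict count of non-zero
latents). On the real activations, you could pass
`np.asarray(sae_acts.untag("batch", "seq", "latents").unwrap()[0], dtype=np.float32)`.

```python
import numpy as np


def mean_l0(acts: np.ndarray, active_threshold: float = 1.0, skip_first: bool = True) -> float:
  """acts: (seq, n_latents). Mean number of active latents per token."""
  l0 = (acts > active_threshold).sum(axis=-1)
  if skip_first:
    l0 = l0[1:]  # drop the BOS position, which is excluded from SAE training
  return float(l0.mean())


toy_acts = np.zeros((4, 10))
toy_acts[0, :] = 5.0    # BOS: every latent "fires"
toy_acts[1:, :3] = 2.0  # other tokens: 3 active latents each
print(mean_l0(toy_acts))                    # 3.0
print(mean_l0(toy_acts, skip_first=False))  # 4.75
```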

Find the most strongly activating feature at each token position.

In [None]:
indices = pz.nx.nmap(jnp.argmax)(sae_acts.untag("latents"), axis=-1)
print(indices.untag("batch", "seq").unwrap())
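`argmax` reports only the single strongest latent per position. To also see a
few runners-up, one can take the top-k indices instead. A NumPy sketch with a
hypothetical helper and toy data; on the real activations it could be applied to
`np.asarray(sae_acts.untag("batch", "seq", "latents").unwrap()[0], dtype=np.float32)`.

```python
import numpy as np


def top_k_features(acts: np.ndarray, k: int = 3):
  """acts: (seq, n_latents). Returns (indices, values) of the k largest per token."""
  # argsort is ascending: keep the last k, then reverse so the largest comes first.
  idx = np.argsort(acts, axis=-1)[:, -k:][:, ::-1]
  return idx, np.take_along_axis(acts, idx, axis=-1)


toy_acts = np.array([[0.0, 3.0, 1.0, 2.0],
                     [4.0, 0.0, 0.0, 5.0]])
idx, vals = top_k_features(toy_acts, k=2)
print(idx.tolist())   # [[1, 3], [3, 0]]
print(vals.tolist())  # [[3.0, 2.0], [5.0, 4.0]]
```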

One of the top-activating features on this question is
[SAE feature 10004](https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/10004),
which fires on concepts related to time travel!

## Steering Model Behaviors using SAEs

SAEs can also be used to steer model behavior. Here we use Penzai to reproduce
one of the examples from Neuronpedia's
[steering API](https://www.neuronpedia.org/api-doc#tag/steering/POST/api/steer).

First, we prepare the prompt as a named JAX array.

In [None]:
prompt = tokenizer.encode("The most iconic structure on Earth is", add_bos=True)
prompt = jnp.asarray(prompt)[None, :]
tokens = pz.nx.wrap(prompt).tag("batch", "seq")

Then we sample from the unmodified baseline model:

In [None]:
inference_model = (
    transformer.sampling_mode.KVCachingTransformerLM.from_uncached(
        model,
        cache_len=1024,
        batch_axes={"batch": 1},
    )
)

samples = transformer.simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    rng=jax.random.key(3),
    temperature=0.5,
    max_sampling_steps=64,
)

In [None]:
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
tokenizer.decode(sample_tokens)

As the decoded text shows, the model outputs "the Great Pyramid of Giza" along
with a related description.

Now we would like the model to output references to San Francisco (SF)
instead. Using [Neuronpedia](https://www.neuronpedia.org), we can identify the
index of the latent corresponding to this concept. To amplify this behavior, we
add a steering vector to the model activations: the row of the decoder matrix
$\boldsymbol{W}_{\text{dec}}$ at that latent index. The steering scale is an
empirical value that usually has to be tuned experimentally.

In [None]:
# steer references to SF
steer_index = 3124  # reproduce this: https://www.neuronpedia.org/api-doc#tag/steering/POST/api/steer
steer_scale = 38.5 * 4
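In NumPy terms, the steering intervention adds `scale` times one decoder row to
the activations at every token position. The sketch below uses toy shapes and a
hypothetical `steer` helper; with the real parameters the row would be
`params["W_dec"][steer_index]` and the scale `steer_scale`.

```python
import numpy as np


def steer(acts: np.ndarray, W_dec: np.ndarray, index: int, scale: float) -> np.ndarray:
  """acts: (seq, d_model). Adds scale * W_dec[index] to every token position."""
  return acts + scale * W_dec[index]  # broadcasts the (d_model,) row over seq


rng = np.random.default_rng(0)
toy_W_dec = rng.normal(size=(6, 4))  # (n_latents, d_model)
toy_acts = rng.normal(size=(3, 4))   # (seq, d_model)
steered = steer(toy_acts, toy_W_dec, index=2, scale=10.0)
# Every position moves by exactly the same vector.
print(np.allclose(steered - toy_acts, 10.0 * toy_W_dec[2]))  # True
```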

Recall that modifying a model is easy in Penzai. We define a new layer,
`SteerIntermediate`, that adds the scaled steering vector to the activations.
We take the steering vector from the SAE decoder, wrap it in a
`SteerIntermediate` layer, and use `pz.select` to insert that layer after the
21st `TransformerBlock`, just as we did with `SaveIntermediate`.

In [None]:
# Define a layer to steer the middle activations
@pz.pytree_dataclass
class SteerIntermediate(pz.nn.Layer):
  steer_vector: pz.nx.NamedArray
  steer_scale: float

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    steer_value = value + self.steer_vector * self.steer_scale
    return steer_value


steer_vector = (
    pz.nx.wrap(params["W_dec"][steer_index, :])
    .tag("embedding")
    .astype(jnp.bfloat16)
)

model_patched = (
    pz.select(model)
    .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
    .pick_nth_selected(20)
    .insert_after(SteerIntermediate(steer_vector, steer_scale))
)
steer_vector

In [None]:
inference_model_patched = (
    transformer.sampling_mode.KVCachingTransformerLM.from_uncached(
        model_patched,
        cache_len=1024,
        batch_axes={"batch": 1},
    )
)

In [None]:
samples_patched = transformer.simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model_patched)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    rng=jax.random.key(3),
    temperature=0.5,
    max_sampling_steps=64,
)

In [None]:
sample_tokens = samples_patched.untag("batch", "seq").unwrap()[0]
tokenizer.decode(sample_tokens)

With steering, the model instead outputs "the Golden Gate Bridge",
"San Francisco", "Bay Area", etc.

## Visualizing SAE features using Neuronpedia

[Neuronpedia](https://neuronpedia.org) provides interactive visualizations of
SAE features. These dashboards can also be embedded in the colab, next to the
Penzai analysis. We show one example below.

In [None]:
from IPython.display import IFrame

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(
    sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0
):
  return html_template.format(sae_release, sae_id, feature_idx)


html = get_dashboard_html(
    sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004
)
IFrame(html, width=1200, height=600)