# Gemma Scope 2 Tutorial with Gemma_Penzai

This colab shows how to use
[Gemma Scope 2](https://deepmind.google/blog/gemma-scope-2-helping-the-ai-safety-community-deepen-understanding-of-complex-language-model-behavior/)
in gemma_penzai. Gemma Scope 2 is Google DeepMind's open suite of
interpretability tools for all
[Gemma 3](https://deepmind.google/models/gemma/gemma-3/) model sizes, from 270M
to 27B parameters. It combines sparse autoencoders (SAEs) and transcoders (TCs),
which let us look inside LLMs.

SAEs are an interpretability tool that acts like a "microscope" on language
model activations to find individual concepts. TCs extend this by finding
circuits that connect concepts together.

We aim to reproduce the main examples in the
[Gemma Scope 2 Tutorial](https://colab.sandbox.google.com/drive/1NhWjg7n0nhfW--CjtsOdw5A5J_-Bzn4r?usp=sharing).
If you are interested, please also check our tutorial on Gemma Scope 1 with
penzai in the same folder.

NOTE: we run this colab on a TPU **v5e-1** runtime. Please see our notebook
`./notebooks/gemma3_multimodal_penzai.ipynb` for how to set up a local runtime.

## Import packages

First, we install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade your pip in case
!pip install --upgrade pip

# Installs JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This installs dependencies defined in your pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import dataclasses
import os
from gemma import gm
from huggingface_hub import hf_hub_download
import kagglehub
import numpy as np
from safetensors.torch import load_file

Import plot and display utilities.

In [None]:
from IPython.display import clear_output
from IPython.display import display
from IPython.display import HTML
import pandas as pd
import plotly.express as px
import plotly.io as pio

Import JAX related packages.

In [None]:
import jax
import jax.numpy as jnp
import orbax.checkpoint

Import Penzai related packages.

In [None]:
import penzai
from penzai import pz
from penzai.toolshed import jit_wrapper
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

Import `gemma_penzai` package to use Gemma3 models.

In [None]:
from gemma_penzai import mllm

gemma_from_pretrained_checkpoint = (
    mllm.load_gemma.gemma_from_pretrained_checkpoint
)
sampling_mode = mllm.sampling_mode
simple_decoding_loop = mllm.simple_decoding_loop

## Loading Gemma 3 models

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, input your "KAGGLE_USERNAME" and "KAGGLE_KEY" below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()

Here we load the Gemma 3 1B pre-trained model, the second-smallest model that
Gemma Scope 2 supports (you can also try Gemma 3 270M, but in a Colab the
1B-size model should work fine).

In [None]:
weights_dir = kagglehub.model_download("google/gemma-3/flax/gemma3-1b")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma3-1b")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

Here, we don't split the model parameters. Optionally, we could shard the
parameters across different devices.

In [None]:
flat_params = checkpointer.restore(ckpt_path)

Now we prepare the Gemma3 language model definition and bind it with the
parameters. Note here we upcast the activation precision to float32 by passing
`upcast_activations_to_float32=True` (default is `False`).

In [None]:
model = gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=True,
)

Directly visualizing the model definition together with its parameters takes a
long time. Therefore, we first use the `pz.unbind_params` function to separate
the model architecture from its parameters, and then visualize only the
architecture.

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

Now that we've loaded the model, let's try running it! Before that, let's
prepare our tokenizer.

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

Let's give it a prompt and tokenize it.

In [None]:
# The input text
prompt_physics = (
    "The law of conservation of energy states that energy cannot be created or"
    " destroyed, only transformed."
)

token_ids = tokenizer.encode(prompt_physics, add_bos=True)
tokens = jnp.asarray(token_ids)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")
tokens

Before running inference, we first build an inference-mode model by adding a KV
cache. We then JIT-compile the model body and sample the output with the
decoding loop.

In [None]:
inference_model = sampling_mode.KVCachingTransformerLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    temperature=1.0,
    top_p=0.95,
    # top_k=64,
    rng=jax.random.key(3),
    max_sampling_steps=50,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out = tokenizer.decode(sample_tokens)
penzai_out

This was the pretrained (PT) model, so it doesn't respond like a chatbot - it
just continues based on its priors for what is likely to follow the initial
prompt, given the dataset it was trained on.
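
To make the `temperature` and `top_p` arguments used in the sampling call above
more concrete, here is a minimal NumPy sketch of temperature scaling followed by
nucleus (top-p) filtering on a toy distribution. It is only an illustration and
is not claimed to match the internals of `temperature_sample_pyloop`; the helper
name `top_p_filter` and the toy logits are assumptions made for this example.

```python
import numpy as np


def top_p_filter(logits, temperature, top_p):
  """Returns renormalized probabilities restricted to the top-p nucleus."""
  scaled = logits / temperature
  probs = np.exp(scaled - scaled.max())
  probs /= probs.sum()
  order = np.argsort(-probs)  # token ids, most likely first
  cumulative = np.cumsum(probs[order])
  # Keep the smallest prefix whose cumulative mass reaches top_p.
  keep_count = min(int(np.searchsorted(cumulative, top_p)) + 1, len(probs))
  filtered = np.zeros_like(probs)
  filtered[order[:keep_count]] = probs[order[:keep_count]]
  return filtered / filtered.sum()


# Hypothetical 4-token vocabulary with probabilities 0.5, 0.3, 0.15, 0.05.
toy_logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
nucleus = top_p_filter(toy_logits, temperature=1.0, top_p=0.9)

rng = np.random.default_rng(0)
next_token = rng.choice(len(nucleus), p=nucleus)
print(nucleus, next_token)
```

With `top_p=0.9`, the least likely token falls outside the nucleus and can never
be sampled, while the remaining three are renormalized to sum to one.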

We'll also be using the instruction-tuned (IT) model, which behaves more like a
standard chatbot. Let's also load that in and see how it works. Note that we
have to carefully format the input so that it's in the correct form for our IT
model:

In [None]:
weights_dir_it = kagglehub.model_download("google/gemma-3/flax/gemma3-1b-it")
clear_output()

In [None]:
ckpt_path_it = os.path.join(weights_dir_it, "gemma3-1b-it")
checkpointer_it = orbax.checkpoint.PyTreeCheckpointer()
flat_params_it = checkpointer_it.restore(ckpt_path_it)
model_it = gemma_from_pretrained_checkpoint(
    flat_params_it,
    upcast_activations_to_float32=True,
)

In [None]:
def format_prompt(input_prompt: str) -> str:
  return f"""<start_of_turn>user
{input_prompt}<end_of_turn>
<start_of_turn>model
"""


# prepare prompt and tokenize it
user_prompt = "What is your name?"
it_inputs = tokenizer.encode(
    format_prompt(input_prompt=user_prompt), add_bos=True
)
it_tokens = jnp.asarray(it_inputs)[None, :]
it_tokens = pz.nx.wrap(it_tokens).tag("batch", "seq")


# prepare inference mode and generate output
inference_model_it = sampling_mode.KVCachingTransformerLM.from_uncached(
    model_it,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples_it = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model_it)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=it_tokens,
    temperature=1.0,
    top_p=0.95,
    rng=jax.random.key(3),
    max_sampling_steps=50,
)
sample_tokens_it = samples_it.untag("batch", "seq").unwrap()[0]
penzai_out_it = tokenizer.decode(sample_tokens_it)
penzai_out_it = penzai_out_it.split("<end_of_turn>")[0]
penzai_out_it
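
As a quick sanity check on the chat formatting used above, the pure-Python
sketch below prints the exact string the template produces and shows how
splitting on `<end_of_turn>` trims a decoded sample. The decoded string here is
a made-up placeholder, not real model output.

```python
def format_prompt(input_prompt: str) -> str:
  """Wraps a user message in the turn markers used in the cell above."""
  return (
      "<start_of_turn>user\n"
      f"{input_prompt}<end_of_turn>\n"
      "<start_of_turn>model\n"
  )


formatted = format_prompt("What is your name?")
print(repr(formatted))

# Hypothetical decoded sample: a reply, the end-of-turn marker, then leftovers.
hypothetical_decoded = "Hello!<end_of_turn>\nleftover tokens"
reply = hypothetical_decoded.split("<end_of_turn>")[0]
print(reply)
```

The prompt ends with an open `model` turn, so the model's continuation is its
reply; everything after the first `<end_of_turn>` is discarded.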

## Sparse Autoencoders

Now, let's load one of our SAEs. Gemma Scope actually contains over four hundred
SAEs, but for now we'll just load one on the residual stream at the end of layer
22 of 26. Layers are indexed from 0, so this is the 23rd layer. It is a fairly
late layer, so the model should have had time to find more abstract concepts!

The specific filename can be found in the
[google/gemma-scope-2](https://huggingface.co/collections/google/gemma-scope-2)
collection on Hugging Face.

In [None]:
layer = 22  # @param [7, 13, 17, 22]
width = "65k"  # @param ["16k", "65k", "262k", "1m"]
l0 = "medium"  # @param ["small", "medium", "big"]

path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2-1b-pt",
    filename=(
        f"resid_post/layer_{layer}_width_{width}_l0_{l0}/params.safetensors"
    ),
)
params = load_file(path_to_params)
params = {k: v.numpy() for k, v in params.items()}
params

Check the dimensions for SAE parameters.

In [None]:
sae_params = {k: v.shape for k, v in params.items()}
sae_params

### Implementing the SAE

We now define the forward pass of the SAE in Penzai, for pedagogical purposes.

Gemma Scope 2 is a collection of
[JumpReLU SAEs](https://arxiv.org/abs/2407.14435), each of which is an
autoencoder with an encoder and a decoder. The encoder maps the activations to a
sparse, non-negative vector of feature magnitudes:

$$\boldsymbol{f}(\boldsymbol{x})=\sigma(\boldsymbol{W}_{\text{enc}}\boldsymbol{x}+\boldsymbol{b}_{\text{enc}})$$

Here $\sigma$ is the **JumpReLU** activation, where $H$ is the Heaviside step
function and $\theta$ is the threshold:

$$\sigma(z)=zH(z-\theta)$$

Then the decoder reconstructs the input activations by:

$$\hat{\boldsymbol{x}}=\boldsymbol{W}_{\text{dec}}\boldsymbol{f}+\boldsymbol{b}_{\text{dec}}$$
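
Before turning to the Penzai implementation, the equations above can be sketched
in a few lines of NumPy. This is a toy illustration with random weights, not the
Gemma Scope 2 parameters; the function names and toy sizes are assumptions. The
weights are stored as `(embedding, latents)` and `(latents, embedding)`, so
inputs are multiplied from the left as row vectors.

```python
import numpy as np


def jump_relu(z, theta):
  """JumpReLU: sigma(z) = z * H(z - theta), applied elementwise."""
  return z * (z > theta)


def sae_forward(x, w_enc, b_enc, theta, w_dec, b_dec):
  """Returns the sparse latents f and the reconstruction x_hat."""
  f = jump_relu(x @ w_enc + b_enc, theta)
  x_hat = f @ w_dec + b_dec
  return f, x_hat


rng = np.random.default_rng(0)
d_model, d_latents = 8, 32  # toy sizes, far smaller than a real SAE
x = rng.normal(size=(5, d_model))  # activations at 5 token positions
w_enc = rng.normal(size=(d_model, d_latents))
b_enc = np.zeros(d_latents)
theta = np.full(d_latents, 1.0)  # one threshold per latent
w_dec = rng.normal(size=(d_latents, d_model))
b_dec = np.zeros(d_model)

f, x_hat = sae_forward(x, w_enc, b_enc, theta, w_dec, b_dec)
print(f.shape, x_hat.shape, (f > 0).mean())
```

Because every threshold is positive here, any latent that survives the gate is
also positive, which is what makes `f` sparse and non-negative.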

Penzai does not provide a JumpReLU autoencoder, so we first implement an
`AutoEncoder` class with `encoder` and `decoder` attributes. Besides the forward
pass, it exposes `encode()` and `decode()` methods. Then we implement a
`JumpReLU` class with a learnable threshold parameter.

In [None]:
from typing import Any

from penzai.core import named_axes
from penzai.core import struct
from penzai.nn import layer as layer_base
from penzai.nn import parameters
from penzai.nn.linear_and_affine import LinearOperatorWeightInitializer
from penzai.nn.linear_and_affine import zero_initializer


NamedArray = named_axes.NamedArray


@struct.pytree_dataclass
class AutoEncoder(pz.nn.Layer):
  """Top-level auto-encoder wrapper.

  Attributes:
    encoder: The encoder to transform inputs to latents.
    decoder: The decoder to reconstruct inputs from latents.
  """

  encoder: pz.nn.Layer
  decoder: pz.nn.Layer

  def __call__(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the forward pass of the auto-encoder."""
    acts = self.encode(x, **side_inputs)
    recons = self.decode(acts, **side_inputs)
    return recons

  def encode(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the encoder sublayer."""
    return self.encoder(x, **side_inputs)

  def decode(
      self, acts: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the decoder sublayer."""
    return self.decoder(acts, **side_inputs)


@struct.pytree_dataclass
class JumpReLU(pz.nn.Layer):
  """JumpReLU activation."""

  threshold: parameters.ParameterLike[NamedArray]
  new_axis_names: tuple[str, ...] = dataclasses.field(
      metadata={"pytree_node": False}
  )
  act_fn: layer_base.Layer = pz.nn.Elementwise(jax.nn.relu)

  def __call__(self, value: NamedArray, **_unused_side_inputs) -> NamedArray:
    """Applies JumpReLU: keeps values above the threshold, zeroes the rest."""
    # The comparison broadcasts the threshold over the non-latent axes.
    return (value > self.threshold.value) * self.act_fn(value)

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      threshold_axes: dict[str, int],
      new_output_axes: dict[str, int] | None = None,
      initializer: LinearOperatorWeightInitializer = zero_initializer,
      dtype: jax.typing.DTypeLike = jnp.float32,
  ):
    """Constructs a ``JumpReLU`` layer from a configuration.

    Args:
      name: The name of the layer.
      init_base_rng: The base RNG to use for initializing model parameters.
      threshold_axes: Names and lengths for the axes in the input that the
        threshold should act over. Other axes will be broadcast over.
      new_output_axes: Names and lengths of new axes that should be introduced
        into the input.
      initializer: Function to use to initialize the weight. Only the output
        axes will be set.
      dtype: Dtype for the threshold.

    Returns:
      A new ``JumpReLU`` layer with an uninitialized threshold parameter.
    """
    if new_output_axes is None:
      new_output_axes = {}

    return cls(
        threshold=parameters.make_parameter(
            f"{name}/threshold",
            init_base_rng,
            initializer,
            input_axes={},
            output_axes={**threshold_axes, **new_output_axes},
            parallel_axes={},
            convolution_spatial_axes={},
            dtype=dtype,
        ),
        new_axis_names=tuple(new_output_axes.keys()),
    )

  def treescope_color(self) -> str:
    return "#65cfbc"

After the definition of the above model layers, we implement the model
definition of the whole SAE and bind it with parameters.

In [None]:
def sae_from_gemma_scope2(
    params_sae: dict[str, Any],
) -> AutoEncoder:
  """Constructs an SAE model from Gemma Scope 2 parameters.

  Args:
    params_sae: The Gemma Scope 2 SAE parameters.

  Returns:
    A new SAE model.
  """
  embedding_dim, latents_dim = params_sae["w_enc"].shape

  # Encoder
  encoder = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="sae/w_enc",
          init_base_rng=None,
          input_axes={"embedding": embedding_dim},
          output_axes={"latents": latents_dim},
      ),
      pz.nn.AddBias.from_config(
          name="sae/b_enc",
          init_base_rng=None,
          biased_axes={"latents": latents_dim},
      ),
      JumpReLU.from_config(
          name="sae",
          init_base_rng=None,
          threshold_axes={"latents": latents_dim},
      ),
  ])
  # Decoder
  decoder = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="sae/w_dec",
          init_base_rng=None,
          input_axes={"latents": latents_dim},
          output_axes={"embedding": embedding_dim},
      ),
      pz.nn.AddBias.from_config(
          name="sae/b_dec",
          init_base_rng=None,
          biased_axes={"embedding": embedding_dim},
      ),
  ])

  # Create the model definition.
  model_def = AutoEncoder(
      encoder=encoder,
      decoder=decoder,
  )

  # Create parameter objects for each parameter.
  model_params = [
      pz.Parameter(
          value=pz.nx.wrap(params_sae["w_enc"]).tag("embedding", "latents"),
          label="sae/w_enc.weights",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_sae["b_enc"]).tag("latents"),
          label="sae/b_enc.bias",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_sae["w_dec"]).tag("latents", "embedding"),
          label="sae/w_dec.weights",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_sae["b_dec"]).tag("embedding"),
          label="sae/b_dec.bias",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_sae["threshold"]).tag("latents"),
          label="sae/threshold",
      ),
  ]

  model_sae = pz.bind_variables(
      model_def,
      model_params,
  )
  return model_sae

By passing in the params loaded from Hugging Face, we now get our SAE model. We
can easily visualize the model structure in Penzai.

In [None]:
sae_model = sae_from_gemma_scope2(params)
sae_model

### Running the SAE on model activations

Let's first extract some activations from the model at the SAE target site.
We'll demonstrate how to do this 'manually' with gemma_penzai by patching the
model's forward pass.

In penzai and gemma_penzai, it is easy to insert, delete, or change model layers
and to manipulate activations. General tutorials are available in the
[Penzai Tutorials](https://penzai.readthedocs.io/en/stable/index.html). Here we
only show how to display or save intermediate activations:

In [None]:
# Define a layer to visualize the middle activations
@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(
    pz.nn.Layer
):  # <- pz.nn.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value, **unused_side_inputs) -> Any:
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value


# Define a layer to extract the middle activations
@pz.pytree_dataclass
class SaveIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[Any | None]

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    self.saved.value = value
    return value

Define a `StateVariable` to save model activations.

In [None]:
target_act = pz.StateVariable(value=None)
target_act

In Penzai, model modifications are generally performed by using `pz.select` to
make a modified copy of the original model (but sharing the same parameters).
This involves “selecting” the part of the model you want to modify, then
applying a modification, similar to the `.at[...].set(...)` syntax for modifying
JAX arrays. Here we insert a `SaveIntermediate` layer after the
`TransformerBlock` at index `layer`.

In [None]:
model_patched = (
    pz.select(model)
    .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
    .pick_nth_selected(layer)
    .insert_after(SaveIntermediate(target_act))
)
logits = model_patched(tokens)

Now, we can run our SAE on the saved activations.

In [None]:
sae_acts = sae_model.encode(
    jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), target_act.value)
)
recon = sae_model.decode(sae_acts)

Let's unwrap these Penzai named arrays back into plain JAX arrays.

In [None]:
recon_np = recon.untag("batch", "seq", "embedding").unwrap()
target_act_np = target_act.value.untag("batch", "seq", "embedding").unwrap()
sae_acts_np = sae_acts.untag("batch", "seq", "latents").unwrap()

Let's just double check that the model looks sensible by checking that we
explain a decent chunk of the variance:

In [None]:
reconstruction_mse = jnp.mean((recon_np[:, 1:] - target_act_np[:, 1:]) ** 2)
target_variance = target_act_np[:, 1:].var()

fvu = reconstruction_mse / target_variance
print(f"Fraction of variance unexplained: {fvu:.2%}")

This looks pretty good!
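
To build intuition for the metric above, here is a small NumPy sketch of the
fraction of variance unexplained (FVU) on toy data. The helper name and the
arrays are assumptions for illustration: a perfect reconstruction scores 0, and
a reconstruction that always predicts the global mean scores 1.

```python
import numpy as np


def fraction_of_variance_unexplained(target, recon):
  """Mean squared reconstruction error divided by the target variance."""
  mse = np.mean((recon - target) ** 2)
  return mse / target.var()


rng = np.random.default_rng(0)
target = rng.normal(size=(1, 6, 4))  # toy (batch, seq, embedding) activations

perfect = target.copy()
mean_only = np.full_like(target, target.mean())
noisy = target + 0.1 * rng.normal(size=target.shape)

fvu_perfect = fraction_of_variance_unexplained(target, perfect)
fvu_mean = fraction_of_variance_unexplained(target, mean_only)
fvu_noisy = fraction_of_variance_unexplained(target, noisy)
print(fvu_perfect, fvu_mean, fvu_noisy)
```

An FVU well below 1 therefore means the reconstruction is doing much better than
the trivial constant baseline.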

This SAE is supposed to have an L0 of ~60 (size "medium"), so let's check that
too:

In [None]:
l0_per_token = (sae_acts_np > 1).sum(-1)[0]
print(l0_per_token.tolist())

print(f"Average L0: {l0_per_token[1:].mean():.2f}")

Note that the SAEs are *NOT* trained on the BOS token because of the so-called
"attention sink" and "massive activations" phenomena: the first token's
activations are outliers and would disrupt SAE training. As a result, the SAEs
tend to give nonsense when applied to the BOS token, and we need to be careful
not to do this accidentally! We can see this above: the BOS token is a total
outlier in terms of L0!
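
The following toy NumPy sketch shows how a per-token L0 is computed and why
position 0 is excluded from the average. The activations are made up for
illustration. Note that this sketch counts strictly positive latents, whereas
the cell above counts latents greater than 1.

```python
import numpy as np

# Toy (batch, seq, latents) activations; position 0 plays the role of BOS.
toy_acts = np.array([[
    [3.0, 2.0, 5.0, 1.0],  # BOS: dense outlier
    [0.0, 0.0, 4.0, 0.0],
    [0.0, 2.5, 0.0, 1.5],
]])

l0_per_position = (toy_acts > 0).sum(-1)[0]  # active latents per position
avg_l0_without_bos = l0_per_position[1:].mean()
print(l0_per_position.tolist(), avg_l0_without_bos)
```

Including position 0 would pull the average up to a value that does not reflect
how sparse the SAE is on ordinary tokens.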

Another way we can evaluate our SAE is by looking at the **delta loss**, i.e.
how much the model's prediction loss increases when we patch in the SAE's
output. To do this we'll set up a new penzai layer:

In [None]:
@pz.pytree_dataclass
class SAEIntervention(pz.nn.Layer):
  """Layer that replaces activations with their SAE reconstruction."""

  sae: pz.nn.Layer

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    # first we get the SAE reconstruction
    recons = self.sae(
        jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), value)
    )
    # second we only patch the activations except for the BOS token
    value = pz.nx.nmap(jnp.concatenate)(
        [value.untag("seq")[:1], recons.untag("seq")[1:]], axis=0
    ).tag("seq")
    return value

We first get the clean logits by using the normal model forward.

In [None]:
logits_clean = model(tokens)

Then we patch the model with the SAE intervention.

In [None]:
model_intervened = (
    pz.select(model)
    .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
    .pick_nth_selected(layer)
    .insert_after(SAEIntervention(sae=sae_model))
)
logits_sae = model_intervened(tokens)

Afterwards, we compute the cross entropy loss for both clean and SAE logits, and
then obtain the delta loss.

In [None]:
def cross_entropy_loss(model_logits, input_token_seq):
  """Measures avg cross entropy loss."""
  log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      model_logits.untag("vocabulary")
  ).tag("vocabulary")
  sliced_preds = log_probs[{"seq": pz.slice[:-1]}]
  correct_next_token = input_token_seq[{"seq": pz.slice[1:]}]

  log_prob_of_correct_next = sliced_preds[{"vocabulary": correct_next_token}]
  return -log_prob_of_correct_next

In [None]:
loss_clean = cross_entropy_loss(logits_clean, tokens)
loss_clean = pz.nx.nmap(jnp.mean)(loss_clean.untag("batch", "seq")).unwrap()

loss_sae = cross_entropy_loss(logits_sae, tokens)
loss_sae = pz.nx.nmap(jnp.mean)(loss_sae.untag("batch", "seq")).unwrap()

print(f"Loss (clean): {loss_clean.item():.4f}")
print(f"Loss (corrupted): {loss_sae.item():.4f}")
print(f"Delta loss: {(loss_sae - loss_clean).item():.4f}")
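
For readers less familiar with named-axis indexing, here is the same next-token
cross-entropy computation as a plain NumPy sketch. The function name, toy logits
and the "patched" logits (clean logits plus noise) are assumptions for
illustration; no real model is involved.

```python
import numpy as np


def next_token_cross_entropy(logits, token_ids):
  """Per-position loss for logits of shape (seq, vocab) and ids of shape (seq,)."""
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  preds = log_probs[:-1]  # the prediction at position t targets token t + 1
  targets = token_ids[1:]
  return -preds[np.arange(len(targets)), targets]


rng = np.random.default_rng(0)
token_ids = np.array([0, 3, 1, 2])
clean_logits = rng.normal(size=(4, 5))
# Stand-in for the logits obtained after patching in a reconstruction.
patched_logits = clean_logits + 0.5 * rng.normal(size=(4, 5))

loss_clean = next_token_cross_entropy(clean_logits, token_ids).mean()
loss_patched = next_token_cross_entropy(patched_logits, token_ids).mean()
print(f"Delta loss: {loss_patched - loss_clean:.4f}")

# Sanity check: uniform logits give a loss of log(vocab_size) everywhere.
uniform_loss = next_token_cross_entropy(np.zeros((4, 5)), token_ids)
```

The last position has no next token to predict, so a sequence of length 4 yields
3 loss terms.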

Let's look at the highest activating features on this input text, on each token
position:

In [None]:
top_activations = sae_acts_np.max(-1)
top_features = sae_acts_np.argmax(-1)
print(top_features)

Note that a lot of these indices are quite small relative to the number of
features in the SAE (about 65 thousand at the `65k` width selected above). This
is because the SAE was trained with
[**Matryoshka loss**](https://www.lesswrong.com/posts/zbebxYCqsryPALh8C/matryoshka-sparse-autoencoders),
which imposes a feature hierarchy: the smaller-indexed features are incentivised
to be good at reconstructing the input even when all other features are switched
off. This helps avoid problems like **feature absorption**.
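
The idea of a nested (Matryoshka-style) reconstruction objective can be sketched
as follows. This is a simplified illustration, not the training loss actually
used for Gemma Scope 2: the prefix sizes are hypothetical, the sparsity penalty
is omitted, and the toy data is random.

```python
import numpy as np


def matryoshka_reconstruction_loss(x, f, w_dec, b_dec, prefix_sizes):
  """Sums the reconstruction MSE over nested prefixes of the latents."""
  total = 0.0
  for m in prefix_sizes:
    # Reconstruct using only the first m latents and their decoder rows.
    x_hat_m = f[..., :m] @ w_dec[:m] + b_dec
    total += np.mean((x - x_hat_m) ** 2)
  return total


rng = np.random.default_rng(0)
n_latents, d_model = 16, 4
f = np.maximum(rng.normal(size=(3, n_latents)), 0.0)  # non-negative latents
w_dec = rng.normal(size=(n_latents, d_model))
b_dec = np.zeros(d_model)
x = f @ w_dec + b_dec  # by construction, the full dictionary is exact

full_only = matryoshka_reconstruction_loss(x, f, w_dec, b_dec, [n_latents])
nested = matryoshka_reconstruction_loss(x, f, w_dec, b_dec, [2, 8, n_latents])
print(full_only, nested)
```

Because the short prefixes contribute their own error terms, training under such
an objective pushes the low-index latents to carry the most broadly useful
directions.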

Let's find the feature which activates most strongly when averaged over all
tokens in the sequence:

In [None]:
top_acts, top_latents = jax.lax.top_k(sae_acts_np.squeeze().mean(0), 5)
for act, idx in zip(top_acts, top_latents):
  print(f"{act:>6.1f} | {idx}")

Latent 6524 seems to fire most strongly. Let's inspect it:

In [None]:
feature_idx = 6524

activations = sae_acts_np[0, 1:, feature_idx].tolist()
str_toks = tokenizer.split(prompt_physics)


def html_activations(toks: list[str], acts: list[float]):
  return "".join(
      f"""<span style="background-color: rgba(255,0,0,{v}); padding: 4px"""
      f""" 0px;">{t}</span>"""
      for t, v in zip(
          toks,
          np.array(acts) / (1e-6 + np.max(acts)),
          strict=True,
      )
  )


display(HTML(html_activations(str_toks, activations)))

One guess we might have is that this latent fires on concepts related to science
or scientific laws. Let's test this out with a few examples:

In [None]:
for prompt in [
    "Gemma Scope 2 is a model release from Google DeepMind",
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit",
    "Gravity describes how massive objects attract one another",
    "A charge accelerating through an electric field experiences a force",
    "Chemical fuel stores energy in molecular bonds, which is released",
]:

  inputs = tokenizer.encode(prompt, add_bos=True)
  inputs = jnp.asarray(inputs)[None, :]
  inputs = pz.nx.wrap(inputs).tag("batch", "seq")
  logits = model_patched(inputs)
  sae_acts = sae_model.encode(
      jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), target_act.value)
  )

  target_act_np = target_act.value.untag("batch", "seq", "embedding").unwrap()
  sae_acts_np = sae_acts.untag("batch", "seq", "latents").unwrap()

  str_toks = tokenizer.split(prompt)

  display(
      HTML(html_activations(str_toks, sae_acts_np[0, 1:, feature_idx].tolist()))
  )
  print()

Okay, so it doesn't fire on the gravity sentence, but it does fire on both the
other physics-related sentences as soon as they start talking about forces,
energies or fields. This gives us a more specific idea of the concepts this
latent might represent.

We can investigate this further by looking at the latent's unembedding, in other
words **which tokens are predicted most strongly when this latent fires.**

In [None]:
w_u = model.body.sublayers[
    -1
].table.embeddings.value  # shape (d_vocab, d_model)
norm_weight = model.body.sublayers[-2].sublayers[1].weights.value - 1.0
w_u_eff = w_u * norm_weight

We apply a `-1.0` adjustment to the final norm weights to account for
implementation differences between the Gemma RMS norm and Penzai. This ensures
our results align with the
[Gemma Scope 2 Tutorial](https://colab.sandbox.google.com/drive/1NhWjg7n0nhfW--CjtsOdw5A5J_-Bzn4r?usp=sharing).
This adjustment is arguably optional, since it has a negligible impact on the
final results.

In [None]:
decoder_vector = sae_model.decoder.sublayers[0].weights.value[
    {"latents": pz.slice[feature_idx]}
]

fire_logits = pz.nx.nmap(jnp.matmul)(
    w_u_eff.untag("embedding"), decoder_vector.untag("embedding")
)
top_activations, top_tokens = pz.nx.nmap(jax.lax.top_k)(
    fire_logits.untag("vocabulary"), k=10
)

for act, tok in zip(top_activations.unwrap(), top_tokens.unwrap()):
  print(f"{act:.4f} | {tokenizer.decode(tok)}")
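
The projection used above can be illustrated with a tiny NumPy example. The
vocabulary, unembedding matrix and decoder vector below are invented for
illustration only; the point is that each token's logit change is the dot
product between its unembedding row and the latent's decoder direction.

```python
import numpy as np

# Hypothetical 5-token vocabulary and a 3-dimensional residual stream.
vocab = ["force", "energy", "banana", "field", "the"]
w_u = np.array([
    [2.0, 0.0, 0.0],
    [1.5, 0.1, 0.0],
    [-1.0, 0.3, 0.0],
    [1.0, 0.0, 0.2],
    [0.0, 0.0, 0.0],
])  # shape (d_vocab, d_model)
decoder_vector = np.array([1.0, 0.0, 0.0])  # direction written by the latent

logit_effect = w_u @ decoder_vector  # one score per vocabulary entry
top_ids = np.argsort(-logit_effect)[:3]
top_tokens = [vocab[i] for i in top_ids]
for i in top_ids:
  print(f"{logit_effect[i]:.4f} | {vocab[i]}")
```

This ignores the final normalization and any downstream layers, so it is only a
first-order view of what the latent promotes.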

### Steering Model Output

Lastly, we can try **steering with this feature**. This means intervening in the
residual stream of the model to add some multiple of this feature's decoder
vector, so that we can change the behaviour of the model during generation.

You should see that when we steer the model on this "physical force feature", it
starts talking more about physics (specifically forces like electromagnetism or
gravity). Note that steering can often be fragile; it's difficult to choose the
intervention layer and steering coefficient in a way that gives the expected
behavioural change without also breaking the model's coherence. If you're
curious, you can try increasing the `coeff` parameter below and seeing what
happens!

In [None]:
@pz.pytree_dataclass
class SteerIntermediate(pz.nn.Layer):
  """Adds a scaled steering vector to the residual stream."""

  steer_vector: named_axes.NamedArray
  coeff: float

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    # Per-token L2 norm of the residual stream, used to scale the steering.
    token_norm = pz.nx.nmap(jnp.linalg.norm)(
        value.untag("embedding"), ord=2, axis=-1
    )
    steer_value = value + self.coeff * token_norm * self.steer_vector
    if (
        value.named_shape["seq"] != 1
    ):  # in prefilling mode, we don't want to change the first token
      steer_value = pz.nx.nmap(jnp.concatenate)(
          [value.untag("seq")[:1], steer_value.untag("seq")[1:]], axis=0
      ).tag("seq")
    return steer_value
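
The arithmetic performed by the steering layer can be checked on a toy example.
The NumPy sketch below mirrors the same rule on a plain `(seq, d_model)` array:
each position receives `coeff` times its own norm times the steering vector,
except position 0 when more than one position is present. The function name and
numbers are assumptions for illustration.

```python
import numpy as np


def steer(resid, steer_vector, coeff):
  """Adds a norm-scaled steering vector to every position except the first."""
  norms = np.linalg.norm(resid, axis=-1, keepdims=True)  # shape (seq, 1)
  steered = resid + coeff * norms * steer_vector
  if resid.shape[0] != 1:
    # Prefill case: leave the first (BOS) position untouched.
    steered[0] = resid[0]
  return steered


resid = np.array([[3.0, 4.0], [0.0, 5.0], [6.0, 8.0]])  # norms 5, 5, 10
steer_vector = np.array([1.0, 0.0])
steered = steer(resid, steer_vector, coeff=0.1)
print(steered)
```

Scaling by the per-token norm keeps the size of the intervention proportional to
the activations it is added to, which makes `coeff` easier to interpret.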

We can first check the model output without model steering.

In [None]:
# prepare model inputs
user_prompt = "Tell me a fun fact."
it_inputs = tokenizer.encode(format_prompt(user_prompt), add_bos=True)
it_tokens = jnp.asarray(it_inputs)[None, :]
it_tokens = pz.nx.wrap(it_tokens).tag("batch", "seq")

# prepare inference model and then sample the outputs
inference_model_it = sampling_mode.KVCachingTransformerLM.from_uncached(
    model_it,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model_it)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=it_tokens,
    temperature=0.0,  # greedy generation
    rng=jax.random.key(1),
    max_sampling_steps=80,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out = tokenizer.decode(sample_tokens).split("<end_of_turn>")[0]
penzai_out

Now let's steer the model output by using our defined `SteerIntermediate` layer.

In [None]:
model_it_steered = (
    pz.select(model_it)
    .at_instances_of(penzai.models.transformer.model_parts.TransformerBlock)
    .pick_nth_selected(layer - 8)
    .insert_after(
        SteerIntermediate(
            steer_vector=jax.tree_util.tree_map(
                lambda x: x.astype(jnp.float32), decoder_vector
            ),
            coeff=0.14,
        )
    )
)

inference_model_it_steered = sampling_mode.KVCachingTransformerLM.from_uncached(
    model_it_steered,
    cache_len=1024,
    batch_axes={"batch": 1},
)

samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model_it_steered)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=it_tokens,
    temperature=0.0,  # greedy generation
    rng=jax.random.key(1),
    max_sampling_steps=80,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out = tokenizer.decode(sample_tokens).split("<end_of_turn>")[0]
penzai_out

Note that steering is expected to be pretty brittle with smaller models.
Generally, larger models (up to a certain point) can better express more complex
concepts and are easier to steer without breaking coherence.

As an exercise, try finding more latents to steer with. Can you come up with any
other interesting prompts and latents?

## Transcoders

A **transcoder** is very similar to an SAE, except rather than reconstructing an
activation vector, it reconstructs the mapping from input vector to some output
(commonly the input and output of an MLP layer). In this way, rather than
decomposing a model's **representations**, it decomposes a model's
**computations**.

Note that this is where a new weight is introduced: the **affine skip
connection**. This is a learned linear transformation from the input to output
activations of the transcoder. You can view it as the learned linear component
of the MLP layer, while the latents represent the nonlinear components.

In [None]:
layer = 17  # @param [7, 13, 17, 22]
width = "65k"  # @param ["16k", "65k", "262k", "1m"]
l0 = "medium"  # @param ["small", "medium", "big"]

path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2-1b-pt",
    filename=f"transcoder/layer_{layer}_width_{width}_l0_{l0}_affine/params.safetensors",
)
params = load_file(path_to_params)
params = {k: v.numpy() for k, v in params.items()}

In [None]:
transcoder_params = {k: v.shape for k, v in params.items()}
transcoder_params

### Implementing the Transcoder

The transcoder adds an affine skip connection from its input to its output. The
encoder is the same as for the SAE:

$$\boldsymbol{f}(\boldsymbol{x})=\sigma(\boldsymbol{W}_{\text{enc}}\boldsymbol{x}+\boldsymbol{b}_{\text{enc}})$$

Here $\sigma$ is the **JumpReLU** activation $\sigma(z)=zH(z-\theta)$, where $H$
is the Heaviside step function and $\theta$ is the threshold.

The decoder then predicts the target activations $\boldsymbol{y}$ (for example,
the MLP output) from the latents and the input:

$$\hat{\boldsymbol{y}}=\boldsymbol{W}_{\text{dec}}\boldsymbol{f}+\boldsymbol{b}_{\text{dec}} + \boldsymbol{W}_{\text{skip}}\boldsymbol{x}$$
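
As with the SAE, the transcoder equations can be sketched in NumPy before the
Penzai implementation. The weights are random toy values and the function name
is an assumption. The skip weight is applied as `x @ w_skip`, which is an
assumed orientation for this sketch rather than a statement about how the
released parameters are laid out.

```python
import numpy as np


def transcoder_forward(x, w_enc, b_enc, theta, w_dec, b_dec, w_skip):
  """Sparse latent path plus a linear skip path from the input."""
  z = x @ w_enc + b_enc
  f = z * (z > theta)  # JumpReLU
  return f @ w_dec + b_dec + x @ w_skip


rng = np.random.default_rng(0)
d_model, d_latents = 4, 16  # toy sizes
x = rng.normal(size=(3, d_model))
w_enc = rng.normal(size=(d_model, d_latents))
b_enc = np.zeros(d_latents)
w_dec = rng.normal(size=(d_latents, d_model))
b_dec = rng.normal(size=d_model)
w_skip = rng.normal(size=(d_model, d_model))

theta = np.full(d_latents, 0.5)
y_hat = transcoder_forward(x, w_enc, b_enc, theta, w_dec, b_dec, w_skip)

# With an unreachable threshold no latent fires, leaving only the affine part.
no_latents = np.full(d_latents, np.inf)
affine_only = transcoder_forward(x, w_enc, b_enc, no_latents, w_dec, b_dec, w_skip)
print(y_hat.shape, np.allclose(affine_only, x @ w_skip + b_dec))
```

Switching every latent off isolates the skip path, which matches the view of the
skip connection as the linear component and the latents as the nonlinear one.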

In [None]:
@struct.pytree_dataclass
class AutoEncoderSkip(pz.nn.Layer):
  """Top-level auto-encoder with skip connection wrapper.

  Attributes:
    encoder: The encoder to transform inputs to latents.
    decoder: The decoder to reconstruct inputs from latents.
  """

  encoder: pz.nn.Layer
  decoder: pz.nn.Layer

  def __call__(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the forward pass of the auto-encoder."""
    acts = self.encode(x, **side_inputs)
    recons = self.decode(acts, x, **side_inputs)
    return recons

  def encode(
      self, x: named_axes.NamedArray, **side_inputs: Any
  ) -> named_axes.NamedArray:
    """Applies the encoder sublayer."""
    return self.encoder(x, **side_inputs)

  def decode(
      self,
      acts: named_axes.NamedArray,
      x: named_axes.NamedArray,
      **side_inputs: Any,
  ) -> named_axes.NamedArray:
    """Applies the decoder sublayer."""
    return self.decoder(acts, x, **side_inputs)


@struct.pytree_dataclass
class JoinTwoBranch(pz.nn.Layer):
  """A joiner to combine multiple branches with individual inputs.

  Attributes:
    branch1: The branch network to handle input 1.
    branch2: The branch network to handle input 2.
  """

  branch1: pz.nn.Layer
  branch2: pz.nn.Layer

  def __call__(
      self,
      x: named_axes.NamedArray,
      y: named_axes.NamedArray,
      **side_inputs: Any,
  ) -> named_axes.NamedArray:
    """Applies the forward pass of the joiner."""
    x = self.branch1(x, **side_inputs)
    y = self.branch2(y, **side_inputs)
    return x + y

In [None]:
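def transcoder_from_gemma_scope2(
    params_transcoder: dict[str, Any],
) -> AutoEncoderSkip:
  """Constructs a transcoder model from Gemma Scope 2 parameters.

  Args:
    params_transcoder: The Gemma Scope 2 transcoder parameters, keyed by name
      (`w_enc`, `b_enc`, `w_dec`, `b_dec`, `threshold`,
      `affine_skip_connection`).

  Returns:
    A new transcoder model with its parameters bound.
  """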
  embedding_dim, latents_dim = params_transcoder["w_enc"].shape

  # Encoder
  encoder = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="transcoder/w_enc",
          init_base_rng=None,
          input_axes={"embedding": embedding_dim},
          output_axes={"latents": latents_dim},
      ),
      pz.nn.AddBias.from_config(
          name="transcoder/b_enc",
          init_base_rng=None,
          biased_axes={"latents": latents_dim},
      ),
      JumpReLU.from_config(
          name="transcoder",
          init_base_rng=None,
          threshold_axes={"latents": latents_dim},
      ),
  ])
  # Decoder

  decoder_branch1_layers = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="transcoder/w_dec",
          init_base_rng=None,
          input_axes={"latents": latents_dim},
          output_axes={"embedding": embedding_dim},
      ),
      pz.nn.AddBias.from_config(
          name="transcoder/b_dec",
          init_base_rng=None,
          biased_axes={"embedding": embedding_dim},
      ),
  ])
  decoder_branch2_layers = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="transcoder/affine_skip_connection",
          init_base_rng=None,
          input_axes={"embedding": embedding_dim},
          output_axes={"new_embedding": embedding_dim},
      ),
      pz.nn.RenameAxes(old="new_embedding", new="embedding"),
  ])

  decoder = JoinTwoBranch(
      branch1=decoder_branch1_layers,
      branch2=decoder_branch2_layers,
  )

  # Create the model definition.
  model_def = AutoEncoderSkip(
      encoder=encoder,
      decoder=decoder,
  )

  # Create parameter objects for each parameter.
  model_params = [
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["w_enc"]).tag(
              "embedding", "latents"
          ),
          label="transcoder/w_enc.weights",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["b_enc"]).tag("latents"),
          label="transcoder/b_enc.bias",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["w_dec"]).tag(
              "latents", "embedding"
          ),
          label="transcoder/w_dec.weights",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["b_dec"]).tag("embedding"),
          label="transcoder/b_dec.bias",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["threshold"]).tag("latents"),
          label="transcoder/threshold",
      ),
      pz.Parameter(
          value=pz.nx.wrap(params_transcoder["affine_skip_connection"]).tag(
              "embedding", "new_embedding"
          ),
          label="transcoder/affine_skip_connection.weights",
      ),
  ]

  model_transcoder = pz.bind_variables(
      model_def,
      model_params,
  )
  return model_transcoder

Now we bind the transcoder parameters with the model architecture.

In [None]:
transcoder = transcoder_from_gemma_scope2(params)
transcoder

### Running the Transcoder on model activations

One nice property of transcoders is that you can use them to find
**circuits**. This is because (if you freeze attention patterns) we can model
the relationship between two transcoder latents in different layers as being
totally **linear**.
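
To make the linearity concrete, here is a toy NumPy sketch with random weights
(the names and sizes are hypothetical, and layer-norm scaling and attention are
ignored, i.e. treated as frozen). A latent in an earlier transcoder writes its
decoder row into the residual stream, and a later transcoder's encoder reads
the stream with a matrix product, so the effect of one latent on a later
pre-activation reduces to a fixed "virtual weight".

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim, latents_dim = 8, 16  # toy sizes

w_dec_early = rng.normal(size=(latents_dim, embedding_dim))  # earlier layer
w_enc_late = rng.normal(size=(embedding_dim, latents_dim))  # later layer

# Entry [i, j]: how much one unit of early latent i moves the
# pre-activation of late latent j.
virtual_weights = w_dec_early @ w_enc_late

# Check the claim by bumping early latent i by delta in a toy residual stream.
resid = rng.normal(size=embedding_dim)
i, delta = 3, 0.7
pre_before = resid @ w_enc_late
pre_after = (resid + delta * w_dec_early[i]) @ w_enc_late
effect = pre_after - pre_before
print(np.allclose(effect, delta * virtual_weights[i]))
```

Note that the linear relationship is on the later latent's pre-activation; the
JumpReLU applied afterwards is still a nonlinearity.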

The input to the transcoder is the input to the FFN (after layer norm), and the
target of the transcoder is the output of the FFN (also after layer norm). We
now patch the model to capture these two activations by locating the layer
norms (each transformer block has 6 layer norms: 4 in attention and 2 in the
FFN).
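
The index arithmetic used in the next cell can be written out as a small
helper. This is a sketch under assumptions: the helper name is hypothetical,
the selected layer norms are taken to be visited in block order with zero-based
indices, and no other `RMSLayerNorm` is assumed to appear before the first
block.

```python
NORMS_PER_BLOCK = 6  # as stated above: 4 in attention, 2 in the FFN


def ffn_norm_indices(layer: int) -> tuple[int, int]:
  """Returns the flat indices of the pre-FFN and post-FFN norms of a block."""
  first = layer * NORMS_PER_BLOCK  # index of the block's first layer norm
  return first + 4, first + 5


print(ffn_norm_indices(0))
print(ffn_norm_indices(12))
```

For block 0 this gives indices 4 and 5, and every later block shifts both by 6
per block.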

In [None]:
pre_ffn = pz.StateVariable(value=None)
post_ffn = pz.StateVariable(value=None)
model_patched = (
    pz.select(model)
    .at_instances_of(pz.nn.RMSLayerNorm)
    # Each transformer block has 6 layer norms: 4 in attention, 2 in the FFN.
    .pick_nth_selected(layer * 6 + 4)
    .insert_after(SaveIntermediate(pre_ffn))
)
model_patched = (
    pz.select(model_patched)
    .at_instances_of(pz.nn.RMSLayerNorm)
    .pick_nth_selected(layer * 6 + 5)
    .insert_after(SaveIntermediate(post_ffn))
)

In [None]:
prompt = "The quick brown fox jumped over the lazy dog"

token_ids = tokenizer.encode(prompt, add_bos=True)
tokens = jnp.asarray(token_ids)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")
logits = model_patched(tokens)

In [None]:
sae_acts = transcoder.encode(pre_ffn.value)
recon = transcoder.decode(sae_acts, pre_ffn.value)

In [None]:
recon_np = recon.untag("batch", "seq", "embedding").unwrap()
sae_target_np = post_ffn.value.untag("batch", "seq", "embedding").unwrap()
sae_acts_np = sae_acts.untag("batch", "seq", "latents").unwrap()

In [None]:
mse = jnp.mean((recon_np[:, 1:] - sae_target_np[:, 1:]) ** 2)
var = sae_target_np[:, 1:].var()
fvu = mse / var
l0 = (sae_acts_np[:, 1:] > 0).sum(-1).mean()

print(f"L0: {l0:.2f}")
print(f"Fraction of variance unexplained: {fvu:.2%}")
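
For reference, the two metrics can be isolated as small NumPy functions and
checked on toy data. The function names and the data are illustrative only: a
perfect prediction gives an FVU of 0, while predicting the global mean of the
target gives an FVU of 1.

```python
import numpy as np


def fraction_of_variance_unexplained(prediction, target):
  """Mean squared error divided by the variance of the target."""
  mse = np.mean((prediction - target) ** 2)
  return mse / np.var(target)


def mean_l0(latents):
  """Average number of active (strictly positive) latents per position."""
  return (latents > 0).sum(axis=-1).mean()


rng = np.random.default_rng(0)
toy_target = rng.normal(size=(4, 8))
mean_prediction = np.full_like(toy_target, toy_target.mean())

fvu_perfect = fraction_of_variance_unexplained(toy_target, toy_target)
fvu_mean = fraction_of_variance_unexplained(mean_prediction, toy_target)

toy_latents = np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 3.0]])
toy_l0 = mean_l0(toy_latents)
print(fvu_perfect, fvu_mean, toy_l0)
```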

The transcoder's FVU is higher than that of our sparse autoencoder. However,
the output of a single MLP is a less important causal node than the residual
stream (which contains **all accumulated information** up to that layer in the
model), so transcoders usually have a smaller delta loss when we patch in their
output.

Let's test this, with a slight modification of our previous patching function:

In [None]:
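# Define a layer that intervenes in the model's forward pass
@pz.pytree_dataclass
class TranscoderIntervention(pz.nn.Layer):
  """Replaces an activation with the transcoder output, except at BOS."""

  transcoder_recon: named_axes.NamedArray

  def __call__(self, value: Any, /, **_unused_side_inputs) -> Any:
    # Keep the original activation at position 0 (the BOS token) and patch in
    # the transcoder output at every other position.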
    value = pz.nx.nmap(jnp.concatenate)(
        [value.untag("seq")[:1], self.transcoder_recon.untag("seq")[1:]], axis=0
    ).tag("seq")
    return value
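
The same splice can be sketched on plain NumPy arrays, without named axes. The
helper name and the toy arrays are hypothetical; the point is only that row 0
(the BOS position) keeps its original value while every other position takes
the replacement, mirroring the `[:, 1:]` slicing used for the metrics earlier.

```python
import numpy as np


def patch_except_bos(original: np.ndarray, replacement: np.ndarray):
  """Splices arrays of shape (seq, embedding), keeping row 0 of original."""
  return np.concatenate([original[:1], replacement[1:]], axis=0)


original = np.zeros((4, 3))  # stands in for the model's own activations
replacement = np.ones((4, 3))  # stands in for the transcoder output
patched = patch_except_bos(original, replacement)
print(patched)
```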

In [None]:
prompt_physics = (
    "The law of conservation of energy states that energy cannot be created or"
    " destroyed, only transformed."
)
token_ids = tokenizer.encode(prompt_physics, add_bos=True)
tokens = jnp.asarray(token_ids)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")

logits_clean = model_patched(tokens)
sae_acts = transcoder.encode(pre_ffn.value)
recon = transcoder.decode(sae_acts, pre_ffn.value)

In [None]:
model_intervened = (
    pz.select(model)
    .at_instances_of(pz.nn.RMSLayerNorm)
    .pick_nth_selected(layer * 6 + 5)
    .insert_after(TranscoderIntervention(recon))
)
logits_sae = model_intervened(tokens)

In [None]:
loss_clean = (
    cross_entropy_loss(logits_clean, tokens).untag("batch", "seq").unwrap()[0]
)

loss_sae = (
    cross_entropy_loss(logits_sae, tokens).untag("batch", "seq").unwrap()[0]
)

print(f"Loss (clean): {loss_clean.mean().item():.4f}")
print(f"Loss (transcoder): {loss_sae.mean().item():.4f}")
print(f"Delta loss: {(loss_sae.mean() - loss_clean.mean()).item():.4f}")
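
As a reminder of what the delta loss measures, here is a self-contained NumPy
sketch of per-token next-token cross-entropy on toy logits. It is an
assumption-laden stand-in, not the `cross_entropy_loss` helper used in this
notebook (whose exact shapes may differ): position `t` is taken to predict
token `t + 1`, and the "patched" logits are simply a less confident copy of the
"clean" ones.

```python
import numpy as np


def next_token_cross_entropy(logits, tokens):
  """Per-position loss for logits (seq, vocab) and tokens (seq,)."""
  # Position t predicts token t + 1, so the last row of logits is unused.
  shifted = logits[:-1] - logits[:-1].max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  return -log_probs[np.arange(len(tokens) - 1), tokens[1:]]


vocab_size = 10
toy_tokens = np.array([1, 2, 3, 4, 5])

# Toy "clean" logits that favour the correct next token at every position.
toy_logits_clean = np.zeros((len(toy_tokens), vocab_size))
toy_logits_clean[np.arange(len(toy_tokens) - 1), toy_tokens[1:]] = 4.0
# Toy "patched" logits: the same predictions, but less confident.
toy_logits_patched = 0.5 * toy_logits_clean

toy_loss_clean = next_token_cross_entropy(toy_logits_clean, toy_tokens)
toy_loss_patched = next_token_cross_entropy(toy_logits_patched, toy_tokens)
toy_delta = toy_loss_patched.mean() - toy_loss_clean.mean()
print(f"Delta loss: {toy_delta:.4f}")
```

A delta loss close to zero means the patched activations preserve the model's
next-token predictions almost as well as the original ones.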

In [None]:
pio.renderers.default = "colab"  # Force the renderer to Colab

data = {
    "token": list(range(len(str_toks))),
    "Clean": loss_clean.tolist(),
    "Transcoder": loss_sae.tolist(),
}
df = pd.DataFrame(data)

px.line(
    df, x="token", y=["Clean", "Transcoder"], labels={"value": "Loss"}
).update_layout(
    xaxis=dict(tickvals=df["token"], ticktext=str_toks, tickangle=45),
    title="Cross-entropy loss with transcoder intervention",
    width=1000,
    height=600,
).show()