# Inspecting and Patching Gemma With Penzai - ICLR 2024


Penzai is a JAX research toolkit for building, editing, and visualizing neural networks. This demo shows how to use it to inspect and patch the Gemma open-weights models. (Want a more in-depth tutorial? Check out the [Penzai documentation](https://penzai.readthedocs.io/)!)

You can follow along yourself using your own copy of this notebook:

- To load Gemma 2B, you can use either a **"TPU v2"** or **"T4 GPU"** Colab runtime.
  - TPU v2 is recommended. If you use a T4 GPU runtime, it may run out of memory in Part D of the demo.
- To load Gemma 7B, you'll need to connect to a **"TPU v2"** Colab runtime.

You can change your runtime type using the "Runtime" menu at the top of Colab.

Before you start, you'll also need to:

- Sign up for a Kaggle account at https://www.kaggle.com/ if you don't have one already
- Consent to the Gemma Terms of Use at https://www.kaggle.com/models/google/gemma/license/consent
- Generate a Kaggle API key:
  - Go to your account settings (https://www.kaggle.com/settings), then the ‘API’ section.
  - Click ‘Create new token’ to download your key.


## Setting up

### Connecting to Kaggle

Run the cell below, then enter your username and Kaggle API key (from https://www.kaggle.com/settings):

In [None]:
import kagglehub
kagglehub.login()

You should see "Kaggle credentials successfully validated."

Next, select which model you want to use, based on the Colab runtime you are connected to:

In [None]:
model_choice = "Gemma 2B (any accelerator kernel)" # @param ["Gemma 7B (for TPU v2 kernel)", "Gemma 2B (any accelerator kernel)"]

### Imports and Configuration

Running these cells will install Penzai in your runtime and set it up as the default pretty-printer.

In [None]:
from __future__ import annotations

In [None]:
try:
  import penzai
except ImportError:
  !pip install "penzai[notebook]>=0.1.1,<0.2"

In [None]:
from typing import Any, Callable
import dataclasses
import os
import traceback
import gc
import collections
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils
import sentencepiece as spm

In [None]:
# Allow using ~all GPU memory if using a Colab GPU kernel.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".98"

In [None]:
import penzai
from penzai import pz
from penzai.example_models import gemma
from penzai.toolshed import basic_training
from penzai.toolshed import token_visualization
from penzai.toolshed import jit_wrapper
from penzai.toolshed import lora
from penzai.toolshed import auto_nmap
from penzai.toolshed import model_rewiring

In [None]:
nx_jax = auto_nmap.wrap_module(jax)
nx_jnp = auto_nmap.wrap_module(jnp)

In [None]:
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

In [None]:
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

### Loading Gemma

If you added a Kaggle API key and agreed to the consent form, you can then load the Gemma model weights as follows:

In [None]:
if model_choice.startswith("Gemma 7B"):
  model_variant = '7b'
elif model_choice.startswith("Gemma 2B"):
  model_variant = '2b'
else:
  raise NotImplementedError()

weights_dir = kagglehub.model_download(f"google/gemma/Flax/{model_variant}")
ckpt_path = os.path.join(weights_dir, model_variant)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
checkpointer = orbax.checkpoint.PyTreeCheckpointer()

In [None]:
# Read the checkpoint's metadata tree; below, only each entry's `.shape` is
# used, to decide how that entry should be sharded.
metadata = checkpointer.metadata(ckpt_path)
# Build a 1-D sharding that spans every local device.
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
# For each entry, reshape the sharding to the entry's rank with all devices
# placed on the last axis, i.e. shape (1, ..., 1, n_devices).
# NOTE(review): this presumably requires every entry to have rank >= 1 and a
# last dimension divisible by n_devices -- confirm against the checkpoint.
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
# Restore the parameters as `jax.Array`s using the per-entry restore args.
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

In [None]:
gemma_model = gemma.model_core.GemmaTransformer.from_pretrained(
    flat_params,
    upcast_activations_to_float32=True,
)

### Helper functions
These cells define some helper functions that will be useful for interacting with Gemma.

In [None]:
def tokenize_batch(examples, pad_length=32, include_eos=True):
  """Tokenizes strings into one padded named array of token ids.

  Each example is encoded with the notebook-level `vocab`, prefixed with the
  BOS id, optionally suffixed with the EOS id, and then right-padded with the
  pad id until it is `pad_length` tokens long.

  Args:
    examples: Iterable of strings to tokenize.
    pad_length: Number of tokens every row is padded to.
    include_eos: Whether to append the EOS id after each example's tokens.

  Returns:
    The token ids wrapped with `pz.nx.wrap` and tagged with the named axes
    "batch" and "seq".

  Raises:
    ValueError: If an example (including BOS/EOS) needs more than
      `pad_length` tokens.
  """
  padded_tokens = []
  for example in examples:
    # `[bos] + ...` builds a fresh list, so appending to it below is safe.
    example_tokens = [vocab.bos_id()] + vocab.EncodeAsIds(example)
    if include_eos:
      example_tokens.append(vocab.eos_id())
    # Raise explicitly instead of using `assert`: the check then survives
    # `python -O`, and the message says which example was too long.
    if len(example_tokens) > pad_length:
      raise ValueError(
          f"Example {example!r} needs {len(example_tokens)} tokens, which"
          f" exceeds pad_length={pad_length}."
      )
    padding = [vocab.pad_id()] * (pad_length - len(example_tokens))
    padded_tokens.append(example_tokens + padding)
  return pz.nx.wrap(jnp.array(padded_tokens)).tag("batch", "seq")

In [None]:
def xent_loss_fn(model, rng, state, input_examples):
  """Computes a next-token cross-entropy loss for a batch of token ids.

  The model is run on every token except the last, and scored against every
  token except the first. Positions whose target is the pad id contribute a
  log-probability of 0.0, but they are still counted by the final `.mean()`.

  Args:
    model: Callable model; it is invoked on a `GemmaInputs` built from the
      input tokens and must return logits with a "vocabulary" named axis.
    rng: Unused.
    state: Unused.
    input_examples: Named array of token ids with "batch" and "seq" axes.

  Returns:
    A tuple `(loss, None, {"loss": loss})`.
  """
  del rng, state

  # Shift by one: inputs drop the final token, targets drop the first one.
  context_tokens = input_examples[{"seq": pz.slice[:-1]}]
  target_tokens = input_examples[{"seq": pz.slice[1:]}]

  logits = model(
      gemma.model_core.GemmaInputs.from_basic_segments(context_tokens)
  )
  log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      logits.untag("vocabulary")
  ).tag("vocabulary")

  # Pick out the log-probability of each target, zeroing padded positions.
  target_log_probs = log_probs[{"vocabulary": target_tokens}]
  is_padding = target_tokens == vocab.pad_id()
  masked_log_probs = pz.nx.nmap(jnp.where)(is_padding, 0.0, target_log_probs)

  loss = -masked_log_probs.untag("batch", "seq").unwrap().mean()
  return loss, None, {"loss": loss}

In [None]:
xent_loss_train_step = basic_training.build_train_step_fn(
    xent_loss_fn, donate_params_and_state=True
)

## Demo Part A: Visualizing and running the model

In [None]:
# Show the model:
gemma_model

In [None]:
# Tokenize some text:
tokens = [
    "Penzai includes a number of general-purpose tools for analyzing JAX neural networks.",
    "It also includes a declarative neural-network library designed to take advantage of those tools.",
]
tokenized_prompts = tokenize_batch(tokens, 32, include_eos=True)

In [None]:
# Show the tokens
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokenized_prompts

In [None]:
# Run the model on the tokens:
example_input = gemma.model_core.GemmaInputs.from_basic_segments(tokenized_prompts)
output = gemma_model(example_input)
output

In [None]:
# Compute probabilities:
nx_jax.nn.softmax(output.untag("vocabulary")).tag("vocabulary")

In [None]:
# To extract part of a model, click a "copy path" button...
accessor_fn = REPLACE_ME # <- ...then paste it here

layer = accessor_fn(gemma_model)
layer

In [None]:
# Call it on an appropriately-sized input:
# (Assuming you copied a query/key/value projection layer, the input is an
# embedding vector: size 2048 for the 2B model or 3072 for the 7B model.)
# Pick the size from the selected model so the cell works for either choice.
embedding_size = 3072 if model_choice.startswith("Gemma 7B") else 2048
layer(pz.nx.ones({"embedding": embedding_size}))

In [None]:
# Try clicking the pretty-printed output, above, pressing "r" for roundtrip
# mode, then copying and pasting that output below:



## Demo Part B: Inspecting model intermediate values

In [None]:
# Select all softmax operations:
(
    pz.select(gemma_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Softmax)
)

In [None]:
# Insert new logic:

@pz.pytree_dataclass
class ShowValue(pz.Layer):
  """Pass-through layer that prints its input and returns it unchanged."""
  def __call__(self, x):
    # Side effect only: print the value flowing through this point.
    print("My intermediate value:", x)
    return x

verbose_model = (
    pz.select(gemma_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Softmax)
    .insert_after(ShowValue())
)

In [None]:
# Check what you've inserted by printing it out:
verbose_model

In [None]:
# Run it:
verbose_model(example_input)

In [None]:
# New input sequence (because repeating text has more interesting patterns)
example_text = (
    "Penzai: A JAX research toolkit for building, editing, and visualizing neural networks."
    + " " + "Penzai: A JAX research toolkit for building, editing, and visualizing neural networks."
)
tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_text))
token_seq = pz.nx.wrap(tokens).tag("seq")
single_input = gemma.model_core.GemmaInputs.from_basic_segments(token_seq)

In [None]:
# Extract intermediate values:
# Insert a `pz.de.TellIntermediate()` after every `pz.nn.Softmax` found inside
# a `GemmaAttention`, then wrap the edited model with the
# `CollectingSideOutputs` handler.
side_output_model = pz.de.CollectingSideOutputs.handling(
    pz.select(gemma_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Softmax)
    .insert_after(pz.de.TellIntermediate())
)

# The wrapped model returns a pair; the first element is ignored here and the
# second is iterated below, with each item's captured array read from `.value`.
_, side_outs = side_output_model(single_input)
# Stack the captured arrays along a new named axis "blocks".
# NOTE(review): presumably one entry per transformer block, in model order --
# confirm against the side-output ordering.
all_attentions = pz.nx.stack(
    [out.value for out in side_outs],
    "blocks",
)

In [None]:
# Visualize them:
tok_strs = [repr(vocab.IdToPiece(int(t))) for t in tokens]
pz.ts.render_array(
    all_attentions,
    axis_item_labels={"seq": tok_strs, "kv_seq": tok_strs},
    rows=["seq", "blocks"], columns=["kv_seq", "heads"],
    valid_mask=single_input.attention_mask,
)

## Demo Part C: Modifying intermediate values

In [None]:
# Some input text: A repeated sequence of otherwise-unpredictable digits.
example_text = (
    "01976954310149754605"
    + "01976954310149754605"
)
tokens = jnp.array([vocab.bos_id()] + vocab.EncodeAsIds(example_text))
token_seq = pz.nx.wrap(tokens).tag("seq")
single_input = gemma.model_core.GemmaInputs.from_basic_segments(token_seq)

In [None]:
# Score them under the model:
logits = gemma_model(single_input)
log_probs = nx_jax.nn.log_softmax(logits.untag("vocabulary")).tag("vocabulary")
sliced_preds = log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = token_seq[{"seq": pz.slice[1:]}]
log_prob_of_correct_next = sliced_preds[{"vocabulary": correct_next_token}]
token_visualization.show_token_scores(
    correct_next_token, nx_jnp.exp(log_prob_of_correct_next), vocab, vmax=1
)

In [None]:
# Identify a specific subset of heads
# (see https://penzai.readthedocs.io/en/stable/notebooks/induction_heads.html)
if model_choice.startswith("Gemma 7B"):
  mask_shape = (28, 16)
  block_indices = jnp.array([5,14,20,21,21,21])
  head_indices = jnp.array([0,15,13,1,2,5])
elif model_choice.startswith("Gemma 2B"):
  mask_shape = (18, 8)
  block_indices = jnp.array([11,14,14])
  head_indices = jnp.array([3,0,4])
else:
  raise NotImplementedError()

top_heads_mask = pz.nx.wrap(
    jnp.ones(mask_shape).at[block_indices, head_indices].set(0.0)
).tag("blocks", "heads")
top_heads_mask

In [None]:
# Knock them out:
def knock_out_heads(model, head_mask_per_block):
  """Returns a copy of `model` with a head-knockout layer after each softmax.

  Args:
    model: Model containing `GemmaAttention` layers with `pz.nn.Softmax`
      sublayers.
    head_mask_per_block: Named array with a "blocks" axis; it is split along
      that axis and each slice is handed to one `KnockOutAttentionHeads`.
      NOTE(review): presumably the number of slices has to match the number
      of selected softmax layers -- confirm.

  Returns:
    The result of inserting a placeholder after every selected softmax and
    then replacing those placeholders, in order, with
    `model_rewiring.KnockOutAttentionHeads` layers.
  """
  # One mask slice per entry of the "blocks" axis.
  per_block_masks = list(head_mask_per_block.untag("blocks"))

  attention_softmaxes = (
      pz.select(model)
      .at_instances_of(gemma.model_core.GemmaAttention)
      .at_instances_of(pz.nn.Softmax)
  )
  # Insert a dummy value after each softmax and move the selection onto it,
  # so the dummies can be swapped for the real layers in one step.
  placeholders = attention_softmaxes.insert_after(
      "<placeholder>", and_select=True
  )
  knockout_layers = (
      model_rewiring.KnockOutAttentionHeads(mask) for mask in per_block_masks
  )
  return placeholders.set_sequence(knockout_layers)

knockout_model = knock_out_heads(gemma_model, top_heads_mask)

In [None]:
# Run the knocked-out model:
logits = knockout_model(single_input)

# Result: Much less confident and less accurate predictions!
log_probs = pz.nx.nmap(jax.nn.log_softmax)(logits.untag("vocabulary")).tag("vocabulary")
sliced_preds = log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = token_seq[{"seq": pz.slice[1:]}]
log_prob_of_correct_next = sliced_preds[{"vocabulary": correct_next_token}]
token_visualization.show_token_scores(correct_next_token, pz.nx.nmap(jnp.exp)(log_prob_of_correct_next), vocab, vmax=1)

## Demo Part D: Low-rank Finetuning and Sampling

In [None]:
# Freeze the existing weights.
frozen_gemma_model = (
    pz.select(gemma_model)
    .at_instances_of(pz.nn.Parameter)
    .apply(
        lambda param: pz.nn.FrozenParameter(param.value, param.name)
    )
)

In [None]:
# Replace linear layers with low-rank adapter layers:
lora_model_def = (
    pz.select(frozen_gemma_model)
    .at_instances_of(gemma.model_core.GemmaAttention)
    .at_instances_of(pz.nn.Linear)
    .apply(
        lambda k, layer: lora.LowRankAdapter.from_linear(
            layer, rank=16, name=jax.tree_util.keystr(k)
        ),
        with_keypath=True
    )
)

In [None]:
# Initialize the new parameters:
lora_model = pz.nn.initialize_parameters(lora_model_def, jax.random.key(10))

In [None]:
# Look at it:
lora_model

In [None]:
# Train it on a synthetic task:
def generate_example(np_rng):
  """Builds one synthetic training string for the "mystery function" task.

  Args:
    np_rng: Random generator providing `choice`, for example the result of
      `np.random.default_rng(...)`.

  Returns:
    A string made of the line `>>> mystery_function(a, b)`, a newline, and the
    value of `a + b`, where `a` and `b` come from a single
    `np_rng.choice(1000, size=(2,))` call.
  """
  # Draw both operands in a single call so the generator advances once.
  lhs, rhs = np_rng.choice(1000, size=(2,))
  total = lhs + rhs
  return f">>> mystery_function({lhs}, {rhs})\n{total}"

In [None]:
train_state = basic_training.TrainState.initial_state(
    model=lora_model,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
)
np_rng = np.random.default_rng(123)

In [None]:
while train_state.step < 200:
  input_examples = tokenize_batch([
      generate_example(np_rng) for _ in range(16)
  ])
  train_state, out = xent_loss_train_step(train_state, input_examples=input_examples)
  if train_state.step % 10 == 0:
    print(train_state.step, out)

In [None]:
# Convert it to sampling mode:
finetuned_inference_model, initial_inference_state = (
  gemma.sampling_mode.GemmaKVCachingTransformer.from_uncached(
      train_state.model,
      cache_len=64,
      batch_axes={"batch": 4},
  )
)

In [None]:
# Make some prompts:
prompts = [
    ">>> mystery_function(123, 123)",
    ">>> mystery_function(101, 15)",
    ">>> mystery_function(999, 876)",
    ">>>", # Let the model write and solve its own problem
]

In [None]:
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokenized_prompts = tokenize_batch(prompts, 16, include_eos=False)
tokenized_prompts

In [None]:
# Draw samples:
samples = gemma.simple_decoding_loop.temperature_sample_pyloop(
    jit_wrapper.Jitted(finetuned_inference_model),
    initial_inference_state,
    prompt=tokenized_prompts,
    rng=jax.random.key(3),
    pad_id=vocab.pad_id(),
    max_sampling_steps=20,
)

In [None]:
# And visualize them:
token_visualization.show_token_array(samples, vocab)