##### Copyright 2025 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemma - Activation Hacking

Author: Sascha Rothe

This Colaboratory notebook provides a comprehensive tutorial on interacting with
a Gemma checkpoint. It will guide you through the process of loading the model,
generating samples, and examining its internal states, including the residual
stream, MLP activations, and attention mechanisms. Furthermore, an example
demonstrating how to modify the model's behavior will be presented. While the
notebook includes an illustrative example, we encourage you to adapt these
experiments and explore the model's behavior to uncover novel insights.

**Note: This colab was tested with an A100 GPU. You might see unexpected behaviour on different hardware.**

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3]Activation_Hacking.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

In [None]:
#@title Install dependencies
# Upgrade flax without resolving its dependencies (`--no-deps`), then install
# the remaining packages; `kagglehub` is imported below to download the model.
# NOTE(review): versions are unpinned — consider pinning for reproducible runs.
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope

To interact with the Gemma model, you will use the Flax NNX gemma code from
google/flax examples on GitHub. Since it is not exposed as a package, you need
to use the following workaround to import from the Flax NNX examples/gemma on
GitHub.

In [None]:
#@title Python imports
import html
import os
import sys
import tempfile
from flax import nnx
from google.colab import userdata
from IPython.display import HTML
import jax
import kagglehub
from matplotlib import colors, pyplot
import numpy as np
import sentencepiece as spm

# The Gemma example code is not packaged, so it is imported straight from a
# fresh clone of the flax repository.
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  # These imports must run while the clone still exists: the temporary
  # directory is removed when the `with` block exits.
  import modules
  import params as params_lib
  import sampler as sampler_lib
  import sow_lib
  import transformer as transformer_lib

  # Remove the last `sys.path` entry again (the one appended above).
  sys.path.pop()

To use the Gemma model, you’ll need a Kaggle account and API key:

1.  To create an account, visit Kaggle and click on ‘Register’.

2.  If/once you have an account, you need to sign in, go to your ‘Settings’, and
    under ‘API’ click on ‘Create New Token’ to generate and download your Kaggle
    API key.

3.  [Optional] In Google Colab, under ‘Secrets’ add your Kaggle username and API
    key, storing the username as KAGGLE_USERNAME and the key as KAGGLE_KEY.

4.  Request access to the model here:
    https://www.kaggle.com/models/google/gemma-3

In [None]:
#@title Select a Gemma 3 model from kaggle
# Surpress Noisy progress bar.
%%capture captured_output --no-stdout

VARIANT = 'gemma3-1b' # @param ['gemma3-1b', 'gemma3-1b-it', 'gemma3-4b', 'gemma3-4b-it', 'gemma3-12b', 'gemma3-12b-it', 'gemma3-27b', 'gemma3-27b-it'] {type:"string"}

try:
  os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
  os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
except userdata.SecretNotFoundError:
  kagglehub.login()

print("Downloading model ...")
weights_dir = kagglehub.model_download(f'google/gemma-3/flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
print("Done.")

## Start Setting up Your Model

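Prepare the Sow Configuration and set the intermediates you want to surface. All options are switched off by default, so for the tutorial enable `embeddings`, `rs_after_ffw` and `rs_after_attention` for the residual stream. Also set `mlp_hidden_topk=10` to surface the activations in the MLP layers (also called feedforward layers). The attention visualization at the end of the notebook reads the top-k attention logits, so set `attn_logits_topk` to a positive value as well if you want to run it.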

In [None]:
# Colab form fields selecting which intermediates the model records.
# NOTE(review): every option is off by default, while later cells read
# `intermediates.embeddings`, `layer.rs_after_attention`, `layer.rs_after_ffw`,
# `mlp_hidden_topk_*` and `attn_logits_topk_*` — presumably the matching options
# must be enabled here before those cells can run; confirm against `sow_lib`.
embeddings = False  # @param {"type":"boolean"}
rs_after_ffw = False  # @param {"type":"boolean"}
rs_after_attention = False  # @param {"type":"boolean"}
mlp_hidden_topk = 0  # @param {"type":"integer"}
attn_logits_topk = 0  # @param {"type":"integer"}

# Bundle the selections; the config is handed to the transformer when it is built.
sow_config = sow_lib.SowConfig(
    embeddings=embeddings,
    rs_after_ffw=rs_after_ffw,
    rs_after_attention=rs_after_attention,
    mlp_hidden_topk=mlp_hidden_topk,
    attn_logits_topk=attn_logits_topk,
)

Now, build the model and required modules.

In [None]:
# Load the checkpoint parameters from the path prepared by the download cell.
params = params_lib.load_and_format_params(ckpt_path)

# Tokenizer: a SentencePiece model loaded from the downloaded vocabulary file.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

# Transformer model built from the parameters; `sow_config` selects which
# intermediates it records.
transformer = transformer_lib.Transformer.from_params(
    params, sow_config=sow_config
)

# Sampler tying the model and the tokenizer together; it is called in
# `run_model` to generate text.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

You're ready to start generating! As we want to visualize the intermediate
activations, we are using a batch size of 1. We limit the model's output by setting `total_generation_steps = 1`, so we can investigate it more easily.

In [None]:
def run_model(prompt, total_generation_steps=1):
  """Run the sampler on a single prompt and print the resulting text.

  Uses the notebook-level `sampler` and `vocab`.

  Args:
    prompt: The input string; it is sent to the sampler as a batch of one.
    total_generation_steps: Number of generation steps requested from the
      sampler. Defaults to 1, the value the tutorial uses.

  Returns:
    A tuple `(prompt_length, out_length, out_data)`: the number of token ids
    the tokenizer produces for `prompt` plus one, the number of non-zero
    entries in `out_data.tokens`, and the raw sampler output.
  """
  # +1: presumably accounts for a BOS token added by the sampler — TODO confirm.
  prompt_length = len(vocab.EncodeAsIds(prompt)) + 1

  out_data = sampler(
      input_strings=[prompt],
      echo=True,  # This returns the prompt as well.
      total_generation_steps=total_generation_steps,
  )
  print(out_data.text[0])

  # Number of returned token ids that are not 0.
  out_length = np.count_nonzero(out_data.tokens)

  return prompt_length, out_length, out_data

In [None]:
# Run the tutorial prompt once; the visualisation cells below use the returned
# lengths and `out_data.intermediates`.
prompt_length, out_length, out_data = run_model("What is the capital of Switzerland? Answer:")

## Investigate the model

In order to investigate intermediate outputs, we map them into the text space by performing premature decoding. To be precise, we have to apply the final norm and the softmax projection. We create a helper function `premature_decode`.

In [None]:
def format_token(string):
  """Turn a vocabulary piece into text that is safe to embed in HTML.

  Every '▁' is replaced by a plain space, every '<' gets a space in front of
  it, and the result is HTML-escaped.

  Args:
    string: The raw piece.

  Returns:
    The escaped, display-ready string.
  """
  readable = string.replace('▁', ' ').replace('<', ' <')
  return html.escape(readable)


def id_to_token(i):
  """Look up token id `i` in the notebook-level `vocab` and format it for HTML."""
  piece = vocab.IdToPiece(i)
  return format_token(piece)


def premature_decode(residual_stream, top_k=10):
  """Decode intermediate residual-stream vectors into their top tokens.

  Applies the model's final norm and the embedder's decode projection to
  `residual_stream`, then keeps the `top_k` highest-scoring token ids per row.
  Uses the notebook-level `transformer` and `vocab`.

  Args:
    residual_stream: Array whose rows are residual-stream vectors; callers in
      this notebook pass a `[positions, features]` slice.
    top_k: Number of candidate tokens to keep per row (default 10).

  Returns:
    A list with one entry per row, each a list of `top_k` formatted token
    strings, best first. Decoding stops after the first row whose best token
    is the end-of-sequence token (that row is included).
  """
  normalized = transformer.final_norm(residual_stream)
  logits = transformer.embedder.decode(normalized)
  _, token_ids = jax.lax.top_k(logits, top_k)
  # Convert once to a host array instead of converting every id separately.
  token_ids = np.asarray(token_ids)
  eos_id = vocab.eos_id()
  tokens = []
  for row_token_ids in token_ids:
    tokens.append([id_to_token(int(token_id)) for token_id in row_token_ids])
    if row_token_ids[0] == eos_id:
      break
  return tokens

To visualize the residual stream we create:

*   a green row for the embedding layer
*   a red row for each attention layer
*   a blue row for each ffw layer

You can see how the output evolves from bottom to top. Darker colors
indicate a bigger change in the residual stream during this layer. Note that in
the early layers the premature token is simply the previous token. This is
because the residual stream was initialized with the embedding of the previous
token. You might also find some language code switching between the layers.
For the example *What is the capital of Switzerland? Answer: Bern* you should see that the answer *Bern* was only decided in the very last feedforward layer.

In [None]:
# @title Display residual stream


def get_style(prob, color):
  """Return inline CSS colouring a cell according to `prob`.

  `prob` is looked up in the matplotlib colormap named `color`. The resulting
  colour becomes the background; the text is black when the RGB components sum
  to more than 1.5 and white otherwise.

  Args:
    prob: Value passed to the colormap.
    color: Name of a matplotlib colormap, e.g. 'Blues'.

  Returns:
    A CSS string of the form 'color: ...;background-color: ...;'.
  """
  red, green, blue = pyplot.get_cmap(color)(prob)[:3]
  background = colors.rgb2hex((red, green, blue))
  foreground = 'white' if red + green + blue <= 1.5 else 'black'
  return f'color: {foreground};background-color: {background};'


def cosine_similarity(a, b):
  """Return the cosine similarity of two 1-D vectors.

  Args:
    a: First vector.
    b: Second vector, same length as `a`.

  Returns:
    `dot(a, b) / (|a| * |b|)`, or 0.0 when either vector has zero norm (the
    quotient is undefined there and would otherwise evaluate to NaN).
  """
  denominator = np.linalg.norm(a) * np.linalg.norm(b)
  if denominator == 0:
    return 0.0
  return np.dot(a, b) / denominator


def get_premature_tokens(
    residual_stream, previous_rs, title, color, length=None
):
  """Render one row of table cells for a residual-stream snapshot.

  Positions `1:length` of `residual_stream` are decoded with
  `premature_decode`. Each cell shows the top token, lists the remaining
  candidates in its hover text, and is shaded by how much the vector differs
  (1 - cosine similarity) from the same position in `previous_rs`.

  Args:
    residual_stream: Array of shape `[1, positions, features]`.
    previous_rs: Snapshot to compare against, indexed the same way.
    title: First line of each cell's hover text.
    color: Name of the matplotlib colormap used for shading.
    length: Exclusive end position. When omitted, the notebook-level
      `out_length` is read at call time, so re-running `run_model` is picked
      up without having to re-define this function.

  Returns:
    A string of concatenated `<td>` elements.
  """
  if length is None:
    length = out_length
  batch_size, _, _ = residual_stream.shape
  assert batch_size == 1
  previous_rs = previous_rs[0, 1:length, :]
  premature_rs = residual_stream[0, 1:length, :]
  premature_tokens = premature_decode(premature_rs)
  tds = ''
  for i, top10_premature_tokens in enumerate(premature_tokens):
    value = 1 - cosine_similarity(premature_rs[i], previous_rs[i])
    top_premature_token, *others = top10_premature_tokens
    title_for_token = title + '\n' + '/'.join(others)
    # The factor 70 stretches the small distances over the colormap range.
    tds += "<td title='{}' style='{}'>{}</td>".format(
        title_for_token, get_style(value * 70, color), top_premature_token
    )
  return tds


def print_premature_layers(intermediates):
  """Build the HTML table visualising the residual stream layer by layer.

  One row is produced for the embeddings, then one after the attention block
  and one after the feedforward block of every layer, and finally a 'Forced:'
  row from the embeddings shifted by one position (limited to the
  notebook-level `prompt_length`). Each row is shaded against the snapshot
  before it; the first one is compared against an all-ones array. Rows are
  emitted in reverse, so the embedding row ends up at the bottom.

  Args:
    intermediates: Sown intermediates exposing `embeddings` and `layers`, where
      each layer has `rs_after_attention` and `rs_after_ffw`.

  Returns:
    An `IPython.display.HTML` object containing the table.
  """
  embedded = intermediates.embeddings

  # (snapshot, hover title, colormap) in the order the model computes them.
  stages = [(embedded, 'After Embedding:', 'Greens')]
  for layer_id, layer in enumerate(intermediates.layers):
    stages.append(
        (layer.rs_after_attention, f'After Attention {layer_id}:', 'Reds')
    )
    stages.append((layer.rs_after_ffw, f'After FFW {layer_id}:', 'Blues'))

  rows = []
  reference = np.ones_like(embedded)
  for snapshot, label, cmap_name in stages:
    cells = get_premature_tokens(snapshot, reference, label, color=cmap_name)
    rows.append(f'<tr>{cells}</tr>')
    reference = snapshot

  forced_cells = get_premature_tokens(
      embedded[:, 1:, :],
      reference,
      'Forced:',
      color='Greens',
      length=prompt_length,
  )
  rows.append(f'<tr>{forced_cells}</tr>')

  return HTML('<table>' + ''.join(reversed(rows)) + '</table>')


# Render the residual-stream table for the intermediates recorded by `run_model`.
print_premature_layers(out_data.intermediates)

The next cell will visualize the top activated neurons in a given layer. You can hover over the colored blocks to get the neuron id and value. For the last feedforward layer _25_ you should see that the top activated neurons are *1937* and *4422*.

In [None]:
# @title Activation in Feedforward Layers
# Index of the layer to visualise; `get_activations_line` below reads this
# module-level variable when it runs.
layer = 0  # @param {"type":"integer"}


def get_activations_line(step, token_id, intermediates):
  """Render one HTML table row with the top-k MLP activations of a token.

  The MLP neurons of the selected `layer` (a notebook-level variable) are
  grouped into blocks of 100 consecutive ids. Every block becomes one coloured
  cell whose hover text lists the top-k neurons falling into it and whose
  shade is the sum of their values relative to the first top-k value.

  Args:
    step: Position of the token in the sequence.
    token_id: Vocabulary id of the token, shown in the first cell.
    intermediates: Sown intermediates; `layers[layer].mlp_hidden_topk_values`
      and `mlp_hidden_topk_indices` are read at `[0, step, :]`.

  Returns:
    A `<tr>...</tr>` HTML string.
  """
  neurons_per_block = 100
  # Size the row from the loaded model instead of hard-coding 70 blocks, which
  # only covers neuron ids below 7000 and raises IndexError on wider MLPs. The
  # width 6912 used elsewhere in this notebook still gives 70 blocks.
  hidden_dim = transformer.layers[layer].mlp.down_proj.kernel[:, 0].shape[0]
  num_blocks = -(-hidden_dim // neurons_per_block)  # ceiling division

  layer_data = intermediates.layers[layer]
  # Copy the top-k slices to the host once rather than element by element.
  values = np.asarray(layer_data.mlp_hidden_topk_values[0, step, :], dtype=float)
  indices = np.asarray(layer_data.mlp_hidden_topk_indices[0, step, :])

  mouseover_texts = [f'Layer: {layer}'] * num_blocks
  # Named `intensities` to avoid shadowing the imported matplotlib `colors`.
  intensities = [0.0] * num_blocks
  for value, neuron in zip(values, indices):
    neuron = int(neuron)
    block = neuron // neurons_per_block
    mouseover_texts[block] += f'\nNeuron: {neuron}, Value: {value:3.2f}'
    intensities[block] += value / values[0]

  tr = '<td>{}</td>'.format(id_to_token(token_id))
  for mouseover_text, intensity in zip(mouseover_texts, intensities):
    style = get_style(intensity, 'Blues')
    tr += f"<td title='{mouseover_text}' style='{style}'>&nbsp;&nbsp;</td>"

  return '<tr>{}</tr>'.format(tr)


def print_activations(tokens, intermediates):
  """Build an HTML table with one activation row per token.

  Rows are produced by `get_activations_line`; iteration stops after the first
  end-of-sequence token (its row is included).

  Args:
    tokens: Token ids of the sequence to display.
    intermediates: Sown intermediates forwarded to `get_activations_line`.

  Returns:
    An `IPython.display.HTML` object containing the table.
  """
  rows = []
  for step, token in enumerate(tokens):
    rows.append(get_activations_line(step, token, intermediates))
    if token == vocab.eos_id():
      break
  return HTML('<table>' + ''.join(rows) + '</table>')


# Show the activation table for the first (and only) sequence of the batch.
print_activations(out_data.tokens[0], out_data.intermediates)

After identifying a neuron of interest we can deactivate or boost it. You can play around with the bias values to achieve different behaviours.
1.  **Deactivate** neuron 1937 in layer 25 and issue the same prompt again.
2.  **Boost** neuron 1937 in layer 25 and your model should respond with *Switzerland* significantly more often.


In [None]:
#@title Mask/Boost a single neuron
layer = 0  # @param {"type":"integer"}
neuron = 0  # @param {"type":"integer"}
operation = "none"  # @param ["none","deactivate","boost"]

mlp = transformer.layers[layer].mlp
# Take the MLP width from the loaded checkpoint instead of hard-coding 6912,
# which only matches one model size. The second kernel axis of `gate_proj`
# indexes the neurons.
hidden_dim = mlp.gate_proj.kernel[0, :].shape[0]

# Gate bias: zero everywhere, +1 on the neuron to boost it, or a large negative
# value to deactivate it.
gate_bias = np.zeros(hidden_dim)
if operation == "boost":
  gate_bias[neuron] = 1.0
elif operation == "deactivate":
  gate_bias[neuron] = -1000.0
mlp.gate_proj.use_bias = True
mlp.gate_proj.bias = nnx.Variable(gate_bias)
print(f"Gate bias of neuron {neuron} in layer {layer} set to {gate_bias[neuron]}")

# Always rewrite the up-projection bias as well, so that switching from
# "boost" back to "none" or "deactivate" removes the earlier boost instead of
# silently keeping it. Biases installed on a *different* layer by an earlier
# run are not touched; reload the model to clear those.
up_bias = np.zeros(hidden_dim)
if operation == "boost":
  up_bias[neuron] = 20.0
mlp.up_proj.use_bias = True
mlp.up_proj.bias = nnx.Variable(up_bias)
if operation == "boost":
  print(f"Up proj bias of neuron {neuron} in layer {layer} set to {up_bias[neuron]}")

# Rebuild the sampler around the modified transformer.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)


We can also apply the `premature_decode` function to the value of a neuron to investigate its effect. Check neuron *1937* of layer *25* to understand why the model's behaviour changed when the neuron was deactivated or boosted.

In [None]:
#@title Print top k associated tokens.
layer = 0  # @param {"type":"integer"}
neuron = 0  # @param {"type":"integer"}
variable = "down_proj"  # @param ["up_proj","gate_proj","down_proj"]

# Pick the weight vector belonging to the neuron: a kernel column for the
# up/gate projections, a kernel row for the down projection.
mlp = transformer.layers[layer].mlp
if variable == "up_proj":
  embedding = mlp.up_proj.kernel[:, neuron]
elif variable == "gate_proj":
  embedding = mlp.gate_proj.kernel[:, neuron]
else:
  embedding = mlp.down_proj.kernel[neuron, :]

# Decode the vector like a residual-stream state: final norm, then the
# embedder's decode projection, and list the 20 highest-scoring tokens.
normalized_stream = transformer.final_norm(embedding)
logits = transformer.embedder.decode(normalized_stream)
_, token_ids = jax.lax.top_k(logits, 20)
for i, token_id in enumerate(token_ids):
  token_text = id_to_token(int(token_id))
  print(f"{i}: {token_text}")

We can also take a look at the attention mechanism. For the purpose of this
tutorial we take the last layer and average over all heads.

In [None]:
#@title Attention visualization
# CSS and JavaScript prepended to the attention view. Clicking a token span
# calls `showAttention`, which shades every `token_<i>` span according to the
# weights passed in. Compared with the first version, the debug `console.log`
# is gone, all JavaScript variables are declared locally, and spans that do not
# exist are skipped instead of raising a TypeError.
html_header = """
<style>
  span:hover {
    color:white !important;
    background-color:#00670d !important;
  }
</style>
<script>
  const componentToHex = (c) => {
    const hex = c.toString(16);
    return hex.length == 1 ? "0" + hex : hex;
  }
  function hueToRgb(t1, t2, hue) {
    if (hue < 0) hue += 6;
    if (hue >= 6) hue -= 6;
    if (hue < 1) return (t2 - t1) * hue + t1;
    else if(hue < 3) return t2;
    else if(hue < 4) return (t2 - t1) * (4 - hue) + t1;
    else return t1;
  }
  function hslToRgb(hue, sat, light) {
    var t1, t2, r, g, b;
    hue = hue / 60;
    if ( light <= 0.5 ) {
      t2 = light * (sat + 1);
    } else {
      t2 = light + sat - (light * sat);
    }
    t1 = light * 2 - t2;
    r = hueToRgb(t1, t2, hue + 2) * 255;
    g = hueToRgb(t1, t2, hue) * 255;
    b = hueToRgb(t1, t2, hue - 2) * 255;
    r = componentToHex(Math.floor(r));
    g = componentToHex(Math.floor(g));
    b = componentToHex(Math.floor(b));
    return `#${r}${g}${b}`;
  }
  function getStyle(value, h=0) {
    value = Math.min(value, 1.0);
    const fg = value < 0.7 ? "black" : "white";
    const bg = hslToRgb(h, 1.0, 1-(value*0.8));
    return `color: ${fg}; background-color: ${bg}`;
  }
  function showAttention(self, atten_probs) {
    for (let i = 0; i < atten_probs.length; i++) {
      const atten_prob = atten_probs[i];
      const element = document.getElementById('token_' + i);
      if (!element) continue;
      element.title = `Value: ${atten_prob}`;
      element.style.cssText = getStyle(atten_prob);
    }
    self.style.cssText = 'color: white; background-color: #00670d';
  }
</script>"""


def print_attention(token_ids, intermediates):
  """Build a clickable HTML view of the last layer's attention.

  Every token becomes a `<span>`; clicking it calls the JavaScript function
  `showAttention` from `html_header` with that token's attention weights. The
  weights are rebuilt from the recorded top-k attention logits of the last
  layer: per head a softmax over `out_length` positions (position 0 masked),
  summed over heads and scaled by the maximum.

  Reads the notebook-level `out_length`, `modules.K_MASK`, `vocab` and
  `html_header`.

  Args:
    token_ids: Token ids of the sequence to display.
    intermediates: Sown intermediates; `layers[-1].attn_logits_topk_values`
      and `attn_logits_topk_indices` are read at `[0, i, head, :]`.

  Returns:
    An `IPython.display.HTML` object.
  """
  # Last layer only.
  last_layers = intermediates.layers[-1:]
  spans = []
  for i, token_id in enumerate(token_ids):
    # Get topk attention values of current token and unsparsify.
    all_atten_probs = []
    for layer in last_layers:
      # One host copy per token instead of one per top-k element.
      topk_values = np.asarray(layer.attn_logits_topk_values[0, i], dtype=float)
      topk_indices = np.asarray(layer.attn_logits_topk_indices[0, i])
      num_heads = topk_values.shape[0]
      for head in range(num_heads):
        atten_logits = modules.K_MASK * np.ones(out_length)
        head_indices = topk_indices[head]
        # Ignore top-k entries pointing outside the displayed window instead
        # of letting them raise an IndexError.
        in_window = (head_indices >= 0) & (head_indices < out_length)
        atten_logits[head_indices[in_window]] = topk_values[head][in_window]
        # The models tends to attend a lot towards the BOS token. Mask this
        # to have a more meaningful visualization.
        atten_logits[0] = modules.K_MASK
        # Note that this softmax is an approximation as we have topks only.
        exp_logits = np.exp(atten_logits)
        atten_probs = exp_logits / (np.sum(exp_logits) + 0.000001)
        all_atten_probs.append(atten_probs)
    # Combine the heads and rescale so the strongest position is close to 1.
    avg_atten_probs = np.sum(all_atten_probs, axis=0)
    avg_atten_probs /= np.max(avg_atten_probs) + 0.000001

    token = id_to_token(token_id)
    atten_probs_string = (
        np.array2string(avg_atten_probs, separator=',')
        .replace('\n', '')
        .replace(' ', '')
    )
    onclick = f'showAttention(this, {atten_probs_string})'
    spans.append(
        "<span id='token_{}' onclick='{}'>{}</span>".format(i, onclick, token)
    )

    if token_id == vocab.eos_id():
      break

  attention_html = '<p>' + ''.join(spans) + '</p>'
  return HTML(html_header + attention_html)


# Render the attention view for the first (and only) sequence of the batch.
print_attention(out_data.tokens[0], out_data.intermediates)