# Gemma Scope Tutorial

This is a barebones tutorial on how to use [Gemma Scope](https://huggingface.co/google/gemma-scope), Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and sublayer of Gemma 2 2B and 9B. Sparse Autoencoders are an interpretability tool that acts like a "microscope" on language model activations. They let us zoom in on dense, compressed activations and expand them into a larger but sparser and seemingly more interpretable form, which can be very useful when doing interpretability research!

**Learn more:**
* If you want to learn about Gemma Scope without writing any code, check out [this interactive demo](https://neuronpedia.org/gemma-scope) courtesy of [Neuronpedia](https://neuronpedia.org).
* For an overview of Gemma Scope check out [the blog post](https://deepmind.google/discover/blog/gemma-scope-helping-the-safety-community-shed-light-on-the-inner-workings-of-language-models).
* See [the technical report](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) for the technical details.



For illustrative purposes, we begin with a lightweight tutorial that uses as few libraries as possible to outline how Gemma Scope works and what Sparse Autoencoders are doing. This tutorial is deliberately minimalist: it is designed to make clear what is actually going on, but it does not model research best practices.

For any serious research with Gemma Scope, **we recommend using the [SAELens](https://jbloomaus.github.io/SAELens/) and [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) libraries**. See [this tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb) for how to use [SAELens](https://jbloomaus.github.io/SAELens/) in practice.


## Loading the Model

First, let's load the model:

For simplicity we do this straight from [HuggingFace transformers](https://huggingface.co/docs/transformers/en/index), rather than using an interpretability-focused library like [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) or [nnsight](https://nnsight.net/).

In [1]:
import os
os.environ['HF_HOME'] = '/mnt/datadisk/Mehdi'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch



We load Gemma 2 2B, the smallest model that Gemma Scope works for. We load the base model, not the chat model, since that's what our SAEs are trained on, though the SAEs seem to transfer OK to the chat models. First, you'll need to authenticate with Hugging Face in order to download the model weights.

In [2]:
notebook_login()


In [2]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map='auto',
)


In [3]:
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b")

Now that we've loaded the model, let's try running it! We give it the prompt "Would you be able to travel through time using a wormhole?" and print the generated output.

In [4]:
# The input text
prompt = "Would you be able to travel through time using a wormhole?"

# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")
print(inputs)

# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

tensor([[     2,  18925,    692,    614,   3326,    577,   5056,   1593,   1069,
           2177,    476,  47420,  18216, 235336]], device='cuda:0')
<bos>Would you be able to travel through time using a wormhole?

[Answer 1]

Yes, you can travel through time using a wormhole.

A wormhole is a theoretical object that connects two points in space-time. It is a tunnel through space-time that allows objects to travel from


## Loading a Sparse Autoencoder

OK, so we have got Gemma 2 loaded and can sample from it to get sensible stuff. Now, let's load one of our SAEs.

Gemma Scope actually contains over four hundred SAEs, but for now we'll just load one on the residual stream at the end of layer 20. Gemma 2 2B has 26 layers and indexing starts at 0, so this is the 21st layer. It is a fairly late layer, so the model should have had time to form more abstract concepts!

See [the final section](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?authuser=2#scrollTo=E7zjkVseLSPp) for more information on how to load all the other SAEs in Gemma Scope.

<details><summary>What is the residual stream?</summary>

Transformers have skip connections, which means that the output of each block is the output of each sublayer *plus* the input to the block. This means that each sublayer (attention or MLP) actually only has a fairly small effect on the output of the block, since most of it comes from all the earlier layers. We call the output of a block (including skip connections) the **residual stream**.

Everything communicated from earlier layers to later layers must go via the residual stream, so it acts as a "bottleneck" in the transformer, essentially capturing everything the model has "thought" so far. This means it is often a natural thing to study, since it will contain everything important going on in the model.
</details>
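
To make the "output plus input" picture concrete, here is a minimal NumPy sketch of a single block writing into the residual stream. The sublayers are hypothetical stand-ins (small random linear maps, no attention, no layer norm) and the width of 8 is a toy value, so this only illustrates the skip-connection structure, not Gemma 2 itself.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8  # toy width; the SAE loaded in this tutorial uses 2304

# Hypothetical stand-ins for the attention and MLP sublayers.
W_attn = 0.1 * rng.standard_normal((d_model, d_model))
W_mlp = 0.1 * rng.standard_normal((d_model, d_model))

def toy_block(resid):
    """One block: each sublayer adds its output onto the residual stream."""
    resid = resid + resid @ W_attn          # "attention" output + skip connection
    resid = resid + np.tanh(resid @ W_mlp)  # "MLP" output + skip connection
    return resid

resid_in = rng.standard_normal((3, d_model))  # (tokens, d_model)
resid_out = toy_block(resid_in)
update = resid_out - resid_in                 # what this block wrote into the stream

print(resid_out.shape)
print(np.linalg.norm(update) / np.linalg.norm(resid_in))
```

The second print gives the size of the block's update relative to its input, which is one way to see how much a single block changes the stream.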


In [5]:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)
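
The filename above encodes the layer, the SAE width, and the average L0. As a rough sketch, a small helper can assemble such a path. Note that `sae_params_filename` is a hypothetical name introduced here, and it assumes other SAEs in the repository follow the same directory layout; which layer, width, and L0 combinations actually exist has to be checked in the repository itself.

```python
def sae_params_filename(layer, width, l0):
    """Build a params path following the pattern of the download call above.

    Hypothetical helper: assumes the layout
    layer_<n>/width_<w>/average_l0_<k>/params.npz holds for other SAEs too.
    """
    return f"layer_{layer}/width_{width}/average_l0_{l0}/params.npz"

filename = sae_params_filename(layer=20, width="16k", l0=71)
print(filename)
```

The resulting string can be passed as the `filename` argument of `hf_hub_download`, exactly as in the cell above.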


In [6]:
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}


In [7]:
{k:v.shape for k, v in pt_params.items()}

{'W_dec': torch.Size([16384, 2304]),
 'W_enc': torch.Size([2304, 16384]),
 'b_dec': torch.Size([2304]),
 'b_enc': torch.Size([16384]),
 'threshold': torch.Size([16384])}
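
As a quick sanity check on these shapes, the sketch below works out the parameter count and the expansion factor (SAE width divided by model width). It only uses the two dimensions printed above; `numel` is a small helper defined here for illustration.

```python
d_model, d_sae = 2304, 16384  # read off the shapes printed above

shapes = {
    "W_enc": (d_model, d_sae),
    "W_dec": (d_sae, d_model),
    "b_enc": (d_sae,),
    "b_dec": (d_model,),
    "threshold": (d_sae,),
}

def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n

n_params = sum(numel(s) for s in shapes.values())
expansion = d_sae / d_model

print(f"{n_params:,} parameters")
print(f"expansion factor: {expansion:.2f}")
```

Almost all of the parameters sit in the two weight matrices; the biases and thresholds are comparatively tiny.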

In [8]:
pt_params["W_enc"].norm(dim=0)

tensor([1.2101, 1.1695, 0.9836,  ..., 1.0630, 0.9997, 1.1070], device='cuda:0')

### Implementing the SAE


We now define the forward pass of the SAE for pedagogical purposes (in practice, we recommend using the implementation in SAELens)

Gemma Scope is a collection of [JumpReLU SAEs](https://arxiv.org/abs/2407.14435). Each one is like a standard two-layer (one hidden layer) neural network, but where the activation function is a **JumpReLU**: a ReLU with a discontinuous jump.
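
To illustrate the "discontinuous jump", here is a small NumPy sketch comparing a plain ReLU with a JumpReLU on a handful of values. The threshold of 0.6 is made up for illustration; in the real SAE each feature has its own learned threshold.

```python
import numpy as np

def relu(z):
    """Standard ReLU: zero out negative values."""
    return np.maximum(z, 0.0)

def jump_relu(z, threshold):
    """JumpReLU: keep z unchanged where z > threshold, output 0 elsewhere."""
    return np.where(z > threshold, z, 0.0)

z = np.array([-1.0, 0.2, 0.5, 0.9, 2.0])
theta = 0.6  # illustrative threshold, not a value from Gemma Scope

print(relu(z))
print(jump_relu(z, theta))
```

Values between 0 and the threshold survive the ReLU but are zeroed by the JumpReLU, so the output jumps from 0 straight to values above the threshold. For a non-negative threshold this matches computing `mask * relu(pre_acts)` with `mask = pre_acts > threshold`.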

In [9]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs, you would need a proper initialisation scheme instead.
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon


In [10]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

### Running the SAE on model activations


Let's first get some activations out of the model at the SAE target site. We'll demonstrate how to do this 'manually' first, by using PyTorch hooks. Note that this is not particularly good practice, and it's probably more practical to use a library like TransformerLens to handle hooking the SAE into a model forward pass. But for illustrative purposes, it's useful to see how it's done.

We can gather activations at a site by registering a hook. To keep this local, we can wrap this in a function that registers a hook, runs the model, saving the intermediate activation, then removes the hook. (This is basically what TransformerLens is doing under the hood)

In [11]:

def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act

In [12]:

target_act = gather_residual_activations(model, 20, inputs)
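
The register, run, remove pattern can be tried out on a toy module without loading Gemma at all. The sketch below uses a hypothetical three-layer `nn.Sequential` as a stand-in for the transformer. Its layers return a plain tensor, whereas the hook above indexes `outputs[0]`, which assumes the decoder layer returns a tuple.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the transformer: three linear "layers" in sequence.
toy_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))

def gather_layer_output(model, layer_idx, x):
    """Run the model once and return the output of model[layer_idx]."""
    captured = {}

    def hook(module, args, output):
        captured["act"] = output.detach()  # returning None leaves the output unchanged

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(x)
    finally:
        handle.remove()  # always clean up, even if the forward pass raises
    return captured["act"]

x = torch.randn(2, 4)
act = gather_layer_output(toy_model, 1, x)

# The captured tensor should match running the first two layers by hand.
with torch.no_grad():
    expected = toy_model[1](toy_model[0](x))
print(torch.allclose(act, expected))
```

Wrapping the forward pass in `try`/`finally` is a design choice in this sketch: it guarantees the hook is removed, so a failed run does not leave a stale hook attached to the model.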

Now, we can run our SAE on the saved activations.

In [17]:
sae = sae.to(target_act.device)  # put the SAE on the same device as the gathered activations

In [18]:
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

Let's just double check that the SAE looks sensible by checking that it explains a decent chunk of the variance:

In [19]:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())

tensor(0.8887, device='cuda:1')
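
The quantity computed above is one minus the mean squared error divided by the variance of the activations. A minimal NumPy sketch on made-up data shows how it behaves at the extremes; `fraction_variance_explained` is a name introduced here for illustration. Note that `np.var` uses the population variance by default, while `torch.Tensor.var` applies Bessel's correction by default, so the two differ very slightly on the same data.

```python
import numpy as np

def fraction_variance_explained(x, x_hat):
    """1 - MSE / Var(x), with the variance taken over all elements of x."""
    mse = np.mean((x - x_hat) ** 2)
    return 1.0 - mse / np.var(x)

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 16))                 # toy "activations"
noisy = x + 0.1 * rng.standard_normal(x.shape)    # a close reconstruction
baseline = np.full_like(x, x.mean())              # predicts the global mean everywhere

fve_perfect = fraction_variance_explained(x, x)
fve_noisy = fraction_variance_explained(x, noisy)
fve_baseline = fraction_variance_explained(x, baseline)

print(fve_perfect, fve_noisy, fve_baseline)
```

A perfect reconstruction scores 1, while always predicting the mean scores approximately 0, so the value reported above sits much closer to the "perfect" end of that range.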

This probably looks OK! This SAE is supposed to have an L0 of around 70, so let's just check that too:

In [20]:
(sae_acts > 1).sum(-1)

tensor([[7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,
           86,   89]], device='cuda:1')

It's always worth checking this sort of thing when you do this by hand, to make sure that you haven't got the wrong site or are missing a scaling factor or something like this. But here, our results all look like they are supposed to.

Note that there's a bit of a gotcha here: our SAEs are *NOT* trained on the BOS token, because we found that it tended to be a large outlier and to mess up training. So they tend to give nonsense when we apply them to it, and we need to be careful not to do this accidentally! We can see this above: the BOS token is a total outlier in terms of L0!
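
One simple way to avoid this gotcha is to drop position 0 before computing any per-token statistics. The sketch below does this on made-up activations, assuming the BOS token sits at position 0 of every sequence (as it does for the prompt above). It counts strictly positive activations, which is the usual definition of L0; the cell above uses a cutoff of 1 instead.

```python
import numpy as np

# Toy feature activations with shape (batch, tokens, d_sae); all values are made up.
toy_acts = np.array([[
    [3.0, 1.0, 2.0, 5.0, 4.0, 1.0],  # position 0: stands in for the BOS token
    [0.0, 0.0, 2.5, 0.0, 0.0, 0.0],
    [0.0, 1.2, 0.0, 0.0, 0.7, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]])

l0_per_token = (toy_acts > 0).sum(-1)         # active features per token
mean_l0_all = l0_per_token.mean()             # skewed by the BOS position
mean_l0_no_bos = l0_per_token[:, 1:].mean()   # drop position 0 first

print(l0_per_token)
print(mean_l0_all, mean_l0_no_bos)
```

Even in this tiny example the single outlier position more than doubles the average, which is why the variance check earlier also slices with `[:, 1:]`.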

Let's look at the highest activating features on this input text, on each token position:

In [21]:
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,
         10004, 10004, 12935,  3442]], device='cuda:1')

So we see that one of the top-activating features on this question is [SAE feature 10004](https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/10004), which fires on concepts related to time travel! We can visualise this below in the notebook, embedding the Neuronpedia dashboard in the Colab cell:


In [22]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
IFrame(html, width=1200, height=600)
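
To browse the other top features as well, a short sketch can build one Neuronpedia link per distinct feature. The indices are copied from the `inds` output above with the BOS position dropped, and the URL pattern is the one used for feature 10004; it is an assumption that a page exists at that pattern for each of the other indices.

```python
# Top feature index at each non-BOS token position, copied from the output above.
top_features = [5482, 10376, 1670, 11023, 7562, 9407, 8399,
                12935, 10004, 10004, 10004, 12935, 3442]

# Same URL pattern as the feature 10004 link; assumed to hold for the others.
base_url = "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{}"

# dict.fromkeys de-duplicates while keeping first-seen order.
unique_features = list(dict.fromkeys(top_features))
urls = [base_url.format(idx) for idx in unique_features]

for url in urls:
    print(url)
```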

### SAELens

We recommend using SAELens: https://github.com/jbloomAus/SAELens for research on SAEs. To load this SAE in SAELens, run the following:

In [23]:
!pip install sae-lens

from sae_lens import SAE  # pip install sae-lens

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res-canonical",
    sae_id = "layer_20/width_16k/canonical",
)



This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
  sae, cfg_dict, sparsity = SAE.from_pretrained(


Full SAELens documentation, tutorials, etc. can be found at https://github.com/jbloomAus/SAELens