# Goodfire Llama series SAEs

Before getting started, make sure you've added your `HF_TOKEN` and `GOODFIRE_API_KEY` to your Colab secrets and granted this notebook access to them.

Learn more here: https://www.goodfire.ai/blog/sae-open-source-announcement/

## Install nnsight, huggingface_hub, and the Goodfire SDK

nnsight is a package for mechanistic interpretability work by our good friends at NDIF: https://nnsight.net

In [None]:
!pip install nnsight==0.3.0



`huggingface_hub` is used to download the SAE weights.

In [None]:
!pip install huggingface_hub



The Goodfire SDK is used to search the SAE's features.

In [None]:
!pip install goodfire



## Import dependencies

In [None]:
import torch
from typing import Optional, Callable

import nnsight


## Specify which language model, which SAE to use, and which layer

In [None]:
# Language model to load, Goodfire SAE checkpoint to pair with it, and the
# dotted path of the submodule whose output the SAE reads and edits.
MODEL_NAME = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
SAE_NAME = 'Llama-3.1-8B-Instruct-SAE-l19'
SAE_LAYER = 'model.layers.19'

# The SAE has d_model * EXPANSION_FACTOR features: 16x for the layer-19 SAE,
# 8x for any other SAE_NAME.
EXPANSION_FACTOR = {'Llama-3.1-8B-Instruct-SAE-l19': 16}.get(SAE_NAME, 8)

## Define SAE class

In [None]:
class SparseAutoEncoder(torch.nn.Module):
    """Sparse autoencoder with a linear + ReLU encoder and a linear decoder.

    Args:
        d_in: Width of the activations that are encoded and reconstructed.
        d_hidden: Number of SAE features.
        device: Device the parameters are moved to.
        dtype: Parameter dtype (bfloat16 unless overridden).
    """

    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in, self.d_hidden = d_in, d_hidden
        self.device = device
        self.dtype = dtype
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        # Place every parameter on the requested device/dtype in one call.
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map activations to non-negative feature activations (linear, then ReLU)."""
        pre_activations = self.encoder_linear(x)
        return torch.relu(pre_activations)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map feature activations back to activation space with a single linear layer."""
        reconstruction = self.decoder_linear(x)
        return reconstruction

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode then decode ``x``; returns ``(reconstruction, features)``."""
        features = self.encode(x)
        reconstruction = self.decode(features)
        return reconstruction, features


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    """Create a SparseAutoEncoder for ``d_model``-wide activations and load its weights.

    Args:
        path: Path to a state-dict checkpoint for the SAE.
        d_model: Width of the activations the SAE encodes.
        expansion_factor: Number of features is ``d_model * expansion_factor``.
        device: Device for the SAE parameters; the checkpoint is mapped there too.

    Returns:
        The SparseAutoEncoder with the checkpoint weights loaded.
    """
    n_features = d_model * expansion_factor
    sae = SparseAutoEncoder(d_model, n_features, device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(path, weights_only=True, map_location=device)
    sae.load_state_dict(state_dict)
    return sae

## Define language model wrapper

In [None]:


class ObservableLanguageModel:
    """Wrapper around ``nnsight.LanguageModel`` for reading and editing submodule outputs.

    Submodules are addressed by dotted paths relative to the wrapped nnsight
    model (this notebook uses ``"model.layers.19"``). ``forward`` can save the
    outputs of such submodules and/or replace them with the result of a
    user-supplied function.
    """

    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Load ``model`` through nnsight and record its hidden size.

        Args:
            model: Identifier passed to ``nnsight.LanguageModel`` (a Hugging
                Face repo id in this notebook).
            device: Passed as ``device_map`` when loading the model.
            dtype: Weight dtype, passed as ``torch_dtype``. A string naming a
                torch attribute (e.g. ``"bfloat16"``) is also accepted and is
                resolved with ``getattr(torch, dtype)``.

        Raises:
            Exception: if the model config has no ``hidden_size``.
        """
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype
        )

        # Quickly run a trace to force model to download due to nnsight lazy download.
        # The single-turn "hello" prompt is a throwaway; nothing is saved from it.
        input_tokens = self._model.tokenizer.apply_chat_template([{"role": "user", "content": "hello"}])
        with self._model.trace(input_tokens):
          pass

        self.tokenizer = self._model.tokenizer

        # Hidden-state width, read from the model config.
        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # Passed to trace() as both `scan` and `validate` in forward().
        self.safe_mode = False  # Nnsight validation is disabled by default, slows down inference a lot. Turn on to debug.

    def _attempt_to_infer_hidden_layer_dimensions(self) -> int:
        """Return the model config's ``hidden_size`` as an int.

        Raises:
            Exception: if the config has no ``hidden_size`` attribute.
        """
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer hidden number of layer dimensions from model config"
        )

    def _find_module(self, hook_point: str):
        """Resolve a dotted path such as ``"model.layers.19"`` to a submodule.

        Each dot-separated segment is looked up with ``getattr``, starting from
        the wrapped nnsight model.
        """
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, Optional[Callable]]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        """Run one traced forward pass, optionally editing and caching submodule outputs.

        Args:
            inputs: Token ids handed to ``trace`` (this notebook passes the
                tensor returned by ``apply_chat_template(..., return_tensors="pt")``).
            cache_activations_at: Dotted submodule paths whose outputs to save.
            interventions: Maps a dotted submodule path to a function, or to
                ``None`` to skip that entry. The function is called with
                element 0 of the submodule's output. The submodule's output is
                then replaced by a 1-tuple holding the function's result.

        Returns:
            ``(logits, kv_cache, cache)``:

            - ``logits``: the model's ``output[0]`` after ``squeeze(1)``,
              detached. ``squeeze(1)`` only removes dim 1 when it has size 1,
              so for a multi-token input the sequence dimension is kept.
            - ``kv_cache``: the saved ``output.past_key_values``. Its concrete
              type depends on the installed transformers version (TODO confirm).
            - ``cache``: maps each path in ``cache_activations_at`` to element 0
              of that submodule's saved output, detached.
        """
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
        ):
            # Apply each requested intervention (entries mapped to None are skipped).
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    intervened_acts = interventions[
                        hook_site
                    ](module.output[0])
                    # We only modify module.output[0].
                    # NOTE(review): this replaces the whole output tuple with a
                    # 1-tuple, so any further elements the layer returned are
                    # dropped. Confirm that is acceptable for the installed
                    # transformers version.

                    module.output = (intervened_acts,)

            # NOTE(review): if a path is both intervened on and cached, confirm
            # whether the saved value is the edited or the original output.
            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.detach(),
            kv_cache,
            {k: v[0].detach() for k, v in cache.items()},
        )


## Download and instantiate the Llama model

**This will take a while to download Llama from HuggingFace.**

In [None]:
# The constructor forces nnsight's lazy weight loading with a throwaway trace,
# so this cell is where the download and load actually happen.
model = ObservableLanguageModel(MODEL_NAME)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Let's read some activations out of the model.

In [None]:
# Tokenize a single-turn chat (ending with the assistant generation prompt) and
# run it through the model, saving element 0 of SAE_LAYER's output. That tensor
# is what the SAE encodes further down.
input_tokens = model.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
)
logits, kv_cache, feature_cache = model.forward(
    input_tokens,
    cache_activations_at=[SAE_LAYER],
)

# The SAE is built for model.d_model-wide inputs, so the cached activations
# must have that width.
assert feature_cache[SAE_LAYER].shape[-1] == model.d_model
feature_cache[SAE_LAYER].shape

torch.Size([1, 41, 4096])


## Download and instantiate the SAE

Download from HuggingFace

In [None]:
from huggingface_hub import hf_hub_download

# The checkpoint is stored as <SAE_NAME>.pth in the Goodfire/<SAE_NAME> model
# repo. hf_hub_download returns the path of the locally cached copy.
file_path = hf_hub_download(
    repo_type="model",
    repo_id=f"Goodfire/{SAE_NAME}",
    filename=f"{SAE_NAME}.pth",
)

In [None]:
# Local path of the downloaded checkpoint inside the Hugging Face hub cache.
file_path

'/root/.cache/huggingface/hub/models--Goodfire--Llama-3.1-8B-Instruct-SAE-l19/snapshots/f6775a221e47b44233af4bac2c7b65189265519a/Llama-3.1-8B-Instruct-SAE-l19.pth'

Load the SAE

In [None]:
# Size the SAE from the model's hidden width and put it on the model's device,
# so it can consume the cached activations directly.
sae = load_sae(
    file_path,
    expansion_factor=EXPANSION_FACTOR,
    d_model=model.d_model,
    device=model.device,
)

The SAE can also be used on its own, e.g. to encode the cached activations into SAE feature activations:

In [None]:
# Inference only: disable autograd so no graph is kept for the
# d_model * EXPANSION_FACTOR-wide feature tensor.
with torch.no_grad():
    features = sae.encode(feature_cache[SAE_LAYER])
features.shape

torch.Size([1, 41, 65536])

## Use the Goodfire API to search for features

In [None]:
# goodfire is installed at the top of the notebook; report the installed
# version here instead of reinstalling it.
import importlib.metadata

importlib.metadata.version("goodfire")

Collecting goodfire
  Downloading goodfire-0.3.5-py3-none-any.whl.metadata (24 kB)
Collecting httpx<0.28.0,>=0.27.2 (from goodfire)
  Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting ipywidgets<9.0.0,>=8.1.5 (from goodfire)
  Downloading ipywidgets-8.1.8-py3-none-any.whl.metadata (2.4 kB)
Collecting comm>=0.1.3 (from ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading comm-0.2.3-py3-none-any.whl.metadata (3.7 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading widgetsnbextension-4.0.15-py3-none-any.whl.metadata (1.6 kB)
Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading goodfire-0.3.5-py3-none-any.whl (36 kB)
Downloading httpx-0.27.2-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ipywidgets-8.1.8-py3-none-

In [None]:
import goodfire
from google.colab import userdata

# The Goodfire API key is read from this notebook's Colab secrets.
client = goodfire.Client(
    userdata.get('GOODFIRE_API_KEY'),
)

In [None]:
# The Goodfire API identifies this model by the same Hugging Face repo id the
# config cell uses. Reuse MODEL_NAME instead of re-assigning a literal here,
# which would silently undo any change made in the config cell.
MODEL_NAME


In [None]:
import csv

from tqdm.auto import tqdm

# Fetch the label of every SAE feature from the Goodfire API, BATCH_SIZE
# indices per request, and write them to features.csv as (index, label) rows.
BATCH_SIZE = 512
n_features = sae.d_hidden  # d_model * EXPANSION_FACTOR (65536 for the l19 SAE)

rows = []
for start in tqdm(range(0, n_features, BATCH_SIZE)):
    # Clamp the last batch so no index past the SAE width is requested.
    indices = list(range(start, min(start + BATCH_SIZE, n_features)))
    lookup = client.features.lookup(indices, MODEL_NAME)
    rows.extend((index, feature.label) for index, feature in lookup.items())

# Write only after every batch succeeded. csv.writer quotes labels containing
# commas, quotes or newlines, which would otherwise break the two-column format.
with open("features.csv", "w", newline="") as f:
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow(["index", "label"])
    writer.writerows(rows)

100%|██████████| 128/128 [01:53<00:00,  1.12it/s]


## Intervene on the model to change its outputs

In [None]:
# Search the SAE's features for a pirate persona and steer with the top hit.
# NOTE(review): the query that originally produced index 58644 was not kept in
# the notebook; confirm the top result's label is the intended pirate feature.
pirate_features = client.features.search(
    "assistant should roleplay as a pirate", MODEL_NAME
)
pirate_feature_index = pirate_features[0].index_in_sae
pirate_feature_index

58644

In [None]:
def example_intervention(activations, strength=12):
    """Steer a layer's output by boosting the pirate SAE feature.

    Passed to ``model.forward`` as the intervention for SAE_LAYER. It receives
    element 0 of that layer's output, and its return value replaces the output.

    The activations are encoded with the SAE. ``strength`` is added to feature
    ``pirate_feature_index`` at every position of the first two dimensions
    (batch and token), and the result is decoded. The SAE reconstruction error
    is then added back. Because the decoder is a single linear layer, the net
    effect (up to rounding) is ``activations`` plus ``strength`` times that
    feature's decoder direction.

    Args:
        activations: Layer output to edit.
        strength: Amount added to the feature activation (default 12).

    Returns:
        The edited activations, same shape as ``activations``.
    """
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    # The part of the activations the SAE does not reconstruct.
    error = activations - reconstructed_acts

    # Boost the pirate feature across all batch and token positions.
    features[:, :, [pirate_feature_index]] += strength

    # Very important to add the error term back in. Without it the layer output
    # would be the lossy SAE reconstruction rather than the original activations
    # plus the feature edit.
    return sae.decode(features) + error

# Greedy decoding with example_intervention applied at SAE_LAYER.
# Each step re-runs the whole sequence (no KV-cache reuse). This is simple but
# quadratic in the number of generated tokens.
MAX_NEW_TOKENS = 10

input_tokens = model.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
)

for _ in range(MAX_NEW_TOKENS):
    # Discard the kv/activation caches so the `feature_cache` from the
    # activation-reading cell above is not overwritten.
    logits, _, _ = model.forward(
        input_tokens,
        interventions={SAE_LAYER: example_intervention},
    )

    # forward() only squeezes a length-1 sequence dim, so for this multi-token
    # prompt `logits` keeps its (batch=1, seq, vocab) layout. Take batch 0 at the
    # last position. `logits[-1]` alone would select a whole (seq, vocab) slice
    # and yield one token per position.
    new_token = logits[0, -1].argmax(-1)
    # Append along the sequence dim: (1, seq) -> (1, seq + 1).
    input_tokens = torch.cat([input_tokens, new_token.view(1, 1).cpu()], dim=1)

    print(model.tokenizer.decode(new_token.item()), end="")

RuntimeError: Tensors must have same number of dimensions: got 1 and 2