### Imports

In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

# Pick the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

### Setting up configs

In [2]:
# Dump the raw mapping of release name -> PretrainedSAELookup for every pretrained SAE release.
print(get_pretrained_saes_directory())

{'deepseek-r1-distill-llama-8b-qresearch': PretrainedSAELookup(release='deepseek-r1-distill-llama-8b-qresearch', repo_id='qresearch/DeepSeek-R1-Distill-Llama-8B-SAE-l19', model='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', conversion_func='deepseek_r1', saes_map={'blocks.19.hook_resid_post': 'DeepSeek-R1-Distill-Llama-8B-SAE-l19.pt'}, expected_var_explained={'blocks.19.hook_resid_post': 1.0}, expected_l0={'blocks.19.hook_resid_post': 0.0}, neuronpedia_id={'blocks.19.hook_resid_post': 'deepseek-r1-distill-llama-8b/19-qresearch-res-65k'}, config_overrides=None), 'gemma-2b-it-res-jb': PretrainedSAELookup(release='gemma-2b-it-res-jb', repo_id='jbloom/Gemma-2b-IT-Residual-Stream-SAEs', model='gemma-2b-it', conversion_func=None, saes_map={'blocks.12.hook_resid_post': 'gemma_2b_it_blocks.12.hook_resid_post_16384'}, expected_var_explained={'blocks.12.hook_resid_post': 0.57}, expected_l0={'blocks.12.hook_resid_post': 61.0}, neuronpedia_id={'blocks.12.hook_resid_post': 'gemma-2b-it/12-res-jb'}, co

In [3]:
# One row per SAE release: base model, release name, HuggingFace repo, number of SAEs.
metadata_rows = []
for lookup in get_pretrained_saes_directory().values():
    metadata_rows.append([lookup.model, lookup.release, lookup.repo_id, len(lookup.saes_map)])

# Show every release as a table, ordered by base model name.
rows_by_model = sorted(metadata_rows, key=lambda row: row[0])
print(tabulate(rows_by_model, headers=["model", "release", "repo_id", "n_saes"], tablefmt="simple_outline"))

┌──────────────────────────────────────────┬─────────────────────────────────────────────────────┬────────────────────────────────────────────────────────┬──────────┐
│ model                                    │ release                                             │ repo_id                                                │   n_saes │
├──────────────────────────────────────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────┼──────────┤
│ deepseek-ai/DeepSeek-R1-Distill-Llama-8B │ deepseek-r1-distill-llama-8b-qresearch              │ qresearch/DeepSeek-R1-Distill-Llama-8B-SAE-l19         │        1 │
│ deepseek-ai/DeepSeek-R1-Distill-Llama-8B │ llama_scope_r1_distill                              │ fnlp/Llama-Scope-R1-Distill                            │       96 │
│ gemma-2-27b                              │ gemma-scope-27b-pt-res                              │ google/gemma-scope-27b-pt-res                          │       18 

In [25]:
def format_value(value):
    """Return a compact, single-line repr of `value` for table display.

    Non-dict values are returned as `repr(value)`. Dicts are abbreviated to
    their first item; a trailing `, ...` is appended only when further items
    were actually omitted. An empty dict is rendered as `{}` instead of
    raising `StopIteration`.
    """
    if not isinstance(value, dict):
        return repr(value)
    if not value:
        return "{}"
    first_key, first_val = next(iter(value.items()))
    # Only claim that something was elided when the dict has more than one entry.
    suffix = ", ..." if len(value) > 1 else ""
    return f"{{{first_key!r}: {first_val!r}{suffix}}}"


# Alternative release to inspect: get_pretrained_saes_directory()["llama_scope_r1_distill"]
release = get_pretrained_saes_directory()["deepseek-r1-distill-llama-8b-qresearch"]

# Show each attribute of the release entry, with dict-valued fields abbreviated by format_value.
release_fields = [[field, format_value(val)] for field, val in vars(release).items()]
print(tabulate(release_fields, headers=["Field", "Value"], tablefmt="simple_outline"))

┌────────────────────────┬─────────────────────────────────────────────────────────────────────────────────────────┐
│ Field                  │ Value                                                                                   │
├────────────────────────┼─────────────────────────────────────────────────────────────────────────────────────────┤
│ release                │ 'deepseek-r1-distill-llama-8b-qresearch'                                                │
│ repo_id                │ 'qresearch/DeepSeek-R1-Distill-Llama-8B-SAE-l19'                                        │
│ model                  │ 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'                                              │
│ conversion_func        │ 'deepseek_r1'                                                                           │
│ saes_map               │ {'blocks.19.hook_resid_post': 'DeepSeek-R1-Distill-Llama-8B-SAE-l19.pt', ...}           │
│ expected_var_explained │ {'blocks.19.hook_resid_post': 1.0, ..

In [26]:
# One row per SAE in the release: hook id, checkpoint path on HuggingFace, Neuronpedia id.
data = [
    [sae_id, hf_path, release.neuronpedia_id[sae_id]]
    for sae_id, hf_path in release.saes_map.items()
]

table_headers = ["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"]
print(tabulate(data, headers=table_headers, tablefmt="simple_outline"))

┌───────────────────────────┬─────────────────────────────────────────┬───────────────────────────────────────────────────┐
│ SAE id                    │ SAE path (HuggingFace)                  │ Neuronpedia ID                                    │
├───────────────────────────┼─────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ blocks.19.hook_resid_post │ DeepSeek-R1-Distill-Llama-8B-SAE-l19.pt │ deepseek-r1-distill-llama-8b/19-qresearch-res-65k │
└───────────────────────────┴─────────────────────────────────────────┴───────────────────────────────────────────────────┘


#### Load a single SAE

In [27]:
# Inference only: switch autograd off globally.
t.set_grad_enabled(False)

# Load the base model through HookedSAETransformer so SAEs can be attached to it later.
R1Llama8B: HookedSAETransformer = HookedSAETransformer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B", device=device)

# Load the SAE for the layer-19 residual stream hook from the release inspected above.
# The call is unpacked into the SAE itself, a config dict and a sparsity value.
R1Llama8B_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="deepseek-r1-distill-llama-8b-qresearch",
    sae_id="blocks.19.hook_resid_post",
    device=str(device),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Llama-8B into HookedTransformer


  state_dict_loaded = torch.load(sae_path, map_location=device)


In [28]:
# Tabulate every field of the loaded SAE's config object.
print(tabulate(R1Llama8B_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

┌──────────────────────────────┬──────────────────────────────────────────┐
│ name                         │ value                                    │
├──────────────────────────────┼──────────────────────────────────────────┤
│ architecture                 │ standard                                 │
│ d_in                         │ 4096                                     │
│ d_sae                        │ 65536                                    │
│ activation_fn_str            │ relu                                     │
│ apply_b_dec_to_input         │ False                                    │
│ finetuning_scaling_factor    │ False                                    │
│ context_size                 │ 1024                                     │
│ model_name                   │ deepseek-ai/DeepSeek-R1-Distill-Llama-8B │
│ hook_name                    │ blocks.19.hook_resid_post                │
│ hook_layer                   │ 19                                       │
│ hook_head_

In [8]:
# def display_dashboard(
#     sae_release="llama_scope_r1_distill",
#     sae_id="l15r_800m_slimpajama",
#     latent_idx=0,
#     width=800,
#     height=600,
# ):
#     release = get_pretrained_saes_directory()[sae_release]
#     neuronpedia_id = release.neuronpedia_id[sae_id]

#     url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

#     print(url)
#     display(IFrame(url, width=width, height=height))


# latent_idx = 4527 #random.randint(0, gpt2_sae.cfg.d_sae)
# display_dashboard(latent_idx=latent_idx)

https://neuronpedia.org/deepseek-r1-distill-llama-8b/15-llamascope-slimpj-res-32k/4527?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [29]:
# Compare next-token predictions for one prompt with and without the SAE spliced into the model.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
test_prompt(prompt, answer, R1Llama8B)

# Test our prompt with the SAE attached only for the duration of the context manager
with R1Llama8B.saes(saes=[R1Llama8B_sae]):
    test_prompt(prompt, answer, R1Llama8B)

# Same thing, done in a different way: attach the SAE, run, then detach it by hand
R1Llama8B.add_sae(R1Llama8B_sae)
test_prompt(prompt, answer, R1Llama8B)
R1Llama8B.reset_saes()  # Remember to always do this!

# Using `run_with_saes` method in place of standard forward pass
logits = R1Llama8B(prompt, return_type="logits")
logits_sae = R1Llama8B.run_with_saes(prompt, saes=[R1Llama8B_sae], return_type="logits")
# NOTE(review): answer_token_id is computed here but not used anywhere else in this cell.
answer_token_id = R1Llama8B.to_single_token(answer)

# Getting model's prediction: highest-probability token at the last position of the first sequence
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {R1Llama8B.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {R1Llama8B.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

Tokenized prompt: ['<｜begin▁of▁sentence｜>', 'Mit', 'ig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 20.89 Prob: 84.45% Token: | priority|
Top 1th token. Logit: 17.45 Prob:  2.70% Token: | concern|
Top 2th token. Logit: 17.34 Prob:  2.42% Token: | effort|
Top 3th token. Logit: 17.23 Prob:  2.18% Token: | responsibility|
Top 4th token. Logit: 17.18 Prob:  2.08% Token: | collaborative|
Top 5th token. Logit: 16.53 Prob:  1.08% Token: |,|
Top 6th token. Logit: 16.45 Prob:  1.00% Token: | challenge|
Top 7th token. Logit: 15.74 Prob:  0.49% Token: | cooperative|
Top 8th token. Logit: 15.72 Prob:  0.48% Token: | policy|
Top 9th token. Logit: 15.47 Prob:  0.37% Token: | collaboration|


Tokenized prompt: ['<｜begin▁of▁sentence｜>', 'Mit', 'ig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 19.71 Prob: 31.63% Token: | challenge|
Top 1th token. Logit: 19.47 Prob: 24.92% Token: | priority|
Top 2th token. Logit: 19.15 Prob: 18.08% Token: | responsibility|
Top 3th token. Logit: 18.38 Prob:  8.37% Token: | issue|
Top 4th token. Logit: 17.59 Prob:  3.82% Token: | effort|
Top 5th token. Logit: 17.53 Prob:  3.57% Token: | concern|
Top 6th token. Logit: 16.78 Prob:  1.69% Token: | collaborative|
Top 7th token. Logit: 16.63 Prob:  1.46% Token: | ethical|
Top 8th token. Logit: 16.44 Prob:  1.20% Token: | problem|
Top 9th token. Logit: 15.40 Prob:  0.43% Token: | collective|


Tokenized prompt: ['<｜begin▁of▁sentence｜>', 'Mit', 'ig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 19.71 Prob: 31.63% Token: | challenge|
Top 1th token. Logit: 19.47 Prob: 24.92% Token: | priority|
Top 2th token. Logit: 19.15 Prob: 18.08% Token: | responsibility|
Top 3th token. Logit: 18.38 Prob:  8.37% Token: | issue|
Top 4th token. Logit: 17.59 Prob:  3.82% Token: | effort|
Top 5th token. Logit: 17.53 Prob:  3.57% Token: | concern|
Top 6th token. Logit: 16.78 Prob:  1.69% Token: | collaborative|
Top 7th token. Logit: 16.63 Prob:  1.46% Token: | ethical|
Top 8th token. Logit: 16.44 Prob:  1.20% Token: | problem|
Top 9th token. Logit: 15.40 Prob:  0.43% Token: | collective|


Standard model: top prediction = ' priority', prob = 84.45%
SAE reconstruction: top prediction = ' challenge', prob = 31.63%



### Examine SAE features

#### Setup

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    """Single-hidden-layer ReLU autoencoder with an over-complete latent space.

    The latent width is `int(input_dim * expansion_factor)`. Both the encoder
    and the decoder are biased linear layers.
    """

    def __init__(self, input_dim: int, expansion_factor: float = 16):
        """Create the encoder/decoder pair.

        Args:
            input_dim: Width of the vectors being encoded.
            expansion_factor: Ratio of latent width to `input_dim`.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(decoded, encoded)` for `x`; unlike encode/decode this does not disable autograd."""
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the non-negative latent activations for `x`, computed without gradients."""
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map latent activations `x` back to the input space, computed without gradients."""
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "mps") -> "SparseAutoencoder":
        """Build an autoencoder and load its weights from a saved state dict.

        Args:
            path: Checkpoint file containing a state dict for this module.
            input_dim: Must match the input width the checkpoint was saved with.
            expansion_factor: Must match the checkpoint's latent/input ratio.
            device: Device the weights are mapped to and the module is moved to.

        Returns:
            The loaded module, on `device`, in eval mode.

        Raises:
            RuntimeError: From `load_state_dict` if the checkpoint's keys or
                shapes do not match this architecture.
        """
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        # weights_only=True restricts unpickling to tensors and plain containers, so a
        # downloaded checkpoint cannot execute arbitrary code while being loaded.
        state_dict = torch.load(path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [33]:
from huggingface_hub import hf_hub_download
# The repo name doubles as the checkpoint's base filename: qresearch/<sae_name> holds <sae_name>.pt
sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"

# Fetch the checkpoint from the HuggingFace Hub; file_path is then passed to SparseAutoencoder.from_pretrained.
file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

In [34]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the HuggingFace model in bfloat16 and let device_map="auto" decide where it lives.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The SAE's latent width is hidden_size * expansion_factor and has to match the checkpoint.
expansion_factor = 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    # Follow the model's device instead of hard-coding "mps", so this also runs on CUDA/CPU.
    device=str(model.device),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  state_dict = torch.load(path, map_location=device)


In [35]:
def gather_residual_activations(model, target_layer, inputs):
    """Run one forward pass and return the tensor fed into a decoder layer.

    A temporary forward hook on `model.model.layers[target_layer]` records the
    layer's first positional input. The hook is removed even if the forward
    pass raises, so a failed call does not leave a stale hook on the model.

    Args:
        model: Module exposing its decoder layers as `model.model.layers`.
        target_layer: Index of the layer whose input is captured.
        inputs: Passed to `model(...)` as its only positional argument.

    Returns:
        The captured tensor, or None if the target layer was never called.
    """
    target_act = None

    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        target_act = inputs[0]  # first positional input of the layer, i.e. what flows into it
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        with torch.no_grad():
            _ = model(inputs)
    finally:
        # Always detach the hook; otherwise an exception above would leave it registered.
        handle.remove()
    return target_act

In [38]:
# Tokenize a short two-turn conversation with the model's chat template and move the
# ids to wherever the model lives (rather than assuming an "mps" device).
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Roleplay as a pirate"},
        {"role": "assistant", "content": "Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?"},
    ],
    return_tensors="pt",
).to(model.device)

In [39]:
# Decoder layer whose input is captured; 19 matches the "l19" in the SAE checkpoint name.
layer_id = 19
target_act = gather_residual_activations(model, layer_id, inputs)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [40]:
def ensure_same_device(sae, target_act):
    """Move the SAE onto the activations' device and return both as a pair."""
    shared_device = target_act.device
    return sae.to(shared_device), target_act.to(shared_device)

# Co-locate the SAE with the captured activations, then encode and reconstruct them.
# The activations are cast to float32 before being passed to the SAE.
sae, target_act = ensure_same_device(sae, target_act)
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [41]:
# Variance explained = 1 - MSE(reconstruction, target) / Var(target), with both taken
# over every element of the float32 activations.
target_f32 = target_act.to(torch.float32)
reconstruction_mse = torch.mean((recon - target_f32) ** 2)
var_explained = 1 - reconstruction_mse / torch.var(target_f32)
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.988


In [42]:
# Re-tokenize the pirate role-play conversation (same messages as the earlier cell) and
# place the ids on the model's device rather than assuming "mps".
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Roleplay as a pirate"},
        {"role": "assistant", "content": "Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?"},
    ],
    return_tensors="pt",
).to(model.device)

In [44]:
# Decode the templated ids back to text to see which special tokens the chat template inserted.
tokenizer.decode(inputs[0])

"<｜begin▁of▁sentence｜><｜User｜>Roleplay as a pirate<｜Assistant｜>Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?<｜end▁of▁sentence｜>"

In [45]:
# Rank SAE features by their mean activation over the assistant's reply, then list
# which of those top features fire on each individual reply token.
import numpy as np
# Get activations
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the assistant's response
# NOTE(review): token_ids repeats the computation of `tokens` above.
token_ids = inputs[0].cpu().numpy()
# Ids in [128000, 128255] are treated as special tokens — presumably the tokenizer's
# reserved special-token id range; verify against the tokenizer.
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

# The second-to-last special token is taken as the assistant marker (in the decoded prompt
# the last one is the end-of-sentence token), so the reply starts one position after it.
assistant_start = special_positions[-2] + 1
# The slice is open-ended, so it also covers the final end-of-sentence token.
assistant_tokens = slice(assistant_start, None)

# Get activation statistics for assistant's response
assistant_activations = sae_acts[0, assistant_tokens]
mean_activations = assistant_activations.mean(dim=0)

# Find top activated features during pirate speech
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during pirate speech:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
# NOTE(review): the enumerate index `i` is not used in the loop body.
for i, (token, acts) in enumerate(zip(token_texts[assistant_tokens], assistant_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features during pirate speech:
Feature 7560: 1.677
Feature 34357: 0.742
Feature 14322: 0.559
Feature 19958: 0.311
Feature 37063: 0.299
Feature 53431: 0.247
Feature 34347: 0.234
Feature 1262: 0.221
Feature 43713: 0.217
Feature 55236: 0.190
Feature 54314: 0.190
Feature 60475: 0.190
Feature 53323: 0.186
Feature 36282: 0.177
Feature 18706: 0.170
Feature 60476: 0.149
Feature 61938: 0.149
Feature 56551: 0.142
Feature 266: 0.141
Feature 29635: 0.137

Activation patterns across tokens:

Token: Y
  Feature 7560: 2.326
  Feature 34357: 0.570
  Feature 34347: 0.578
  Feature 1262: 0.288
  Feature 53323: 0.347

Token: arr
  Feature 7560: 2.790
  Feature 34357: 0.685
  Feature 60475: 0.461
  Feature 53323: 0.472
  Feature 266: 0.268

Token: ,
  Feature 7560: 3.592
  Feature 34357: 2.499
  Feature 19958: 0.335
  Feature 1262: 0.468
  Feature 60475: 0.212
  Feature 53323: 0.604
  Feature 60476: 0.571

Token: ĠI
  Feature 7560: 1.948
  Feature 34357: 1.654
  Feature 14322: 0.321
  Featur

In [47]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_idx: int,
    intervention: float = 3.0,
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    """Greedily generate a chat completion while steering one SAE feature.

    A temporary forward hook on `model.model.layers[target_layer]` encodes the
    layer's first positional input with `sae`, adds `intervention` to latent
    `feature_idx` at every position, decodes, adds back the SAE reconstruction
    error, and substitutes the result for the layer's hidden-state output.

    NOTE(review): the steered tensor is derived from the layer's *input* and
    replaces the layer's *output*, so the layer's own computation is discarded
    while the hook is active — confirm this is the intended steering site.

    Args:
        model: Causal LM exposing `model.model.layers`, `device` and `generate`.
        tokenizer: Tokenizer providing `apply_chat_template` and `decode`.
        sae: Object with `encode` / `decode` methods operating on float32 tensors.
        messages: Chat messages passed to the tokenizer's chat template.
        feature_idx: Index of the SAE latent to shift.
        intervention: Amount added to that latent's activation.
        target_layer: Index of the decoder layer that is hooked.
        max_new_tokens: Generation length limit passed to `model.generate`.

    Returns:
        The decoded text of the first generated sequence (prompt included).
    """

    def steering_hook(module, inputs, outputs):
        activations = inputs[0]
        acts_f32 = activations.to(torch.float32)

        features = sae.encode(acts_f32)
        # Keep whatever the SAE fails to reconstruct so only the chosen feature changes.
        error = acts_f32 - sae.decode(features)

        features[..., feature_idx] += intervention

        # Cast back to the model's own dtype instead of assuming bfloat16.
        modified = (sae.decode(features) + error).to(activations.dtype)

        # Decoder layers may return a bare tensor or a tuple whose first item is the
        # hidden state; preserve whichever structure the layer produced.
        if isinstance(outputs, torch.Tensor):
            return modified
        return (modified,) + tuple(outputs[1:])

    handle = model.model.layers[target_layer].register_forward_hook(steering_hook)

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        # Never leave the steering hook attached, even if generation fails.
        handle.remove()

    return generated_text

# Compare an unmodified greedy generation with one where a single SAE feature is boosted.
messages = [
    {"role": "user", "content": "How are you doing?"}
]
# 7560 had the highest mean activation in the pirate-speech analysis above.
feature_to_modify = 7560

# Baseline: same prompt and greedy decoding, no hooks attached.
print("Original generation:")
input_tokens = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
outputs = model.generate(input_tokens, max_new_tokens=800, do_sample=False)
print(tokenizer.decode(outputs[0]))

# Steered run: add 10 to the chosen feature at layer `layer_id` for every position.
print("\nGeneration with modified feature:")
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=feature_to_modify,
    intervention=10,
    target_layer=layer_id,
    max_new_tokens=800
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Original generation:


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<｜begin▁of▁sentence｜><｜User｜>How are you doing?<｜Assistant｜><think>
Okay, so I'm trying to figure out how to respond to the question "How are you doing?" I know that this is a common question people ask when they want to start a conversation or show they care. But I'm a bit confused about the best way to answer it, especially since I'm an AI and not a human. Let me think through this step by step.

First, I should consider the context in which the question is being asked. If someone is asking me how I'm doing, they might be looking for a friendly exchange. Since I don't have feelings or emotions, I shouldn't pretend to have them. Instead, I should acknowledge that I'm an AI and explain how I function.

I remember that in previous interactions, I've responded by saying I'm just a program without feelings, but I want to make sure that's the right approach. Maybe I can also mention that I'm here to help them with whatever they need. That way, the conversation can shift towards assisting t

### Import and reprocess the penguin dataset