In [1]:
%pip install sae-lens transformer-lens sae-dashboard huggingface_hub[cli] tabulate openai ipywidgets

Collecting sae-lens
  Downloading sae_lens-5.4.1-py3-none-any.whl.metadata (5.2 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.13.0-py3-none-any.whl.metadata (12 kB)
Collecting sae-dashboard
  Downloading sae_dashboard-0.6.4-py3-none-any.whl.metadata (6.8 kB)
Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Collecting openai
  Downloading openai-1.61.1-py3-none-any.whl.metadata (27 kB)
Collecting huggingface_hub[cli]
  Downloading huggingface_hub-0.28.1-py3-none-any.whl.metadata (13 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.10.0

In [3]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from tabulate import tabulate
from functools import partial

import sae_lens
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from sae_lens import SAE,HookedSAETransformer,ActivationsStore
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from torch import Tensor, nn
import einops
from rich import print as rprint
from rich.table import Table
from tqdm.auto import tqdm
import pandas as pd
import requests
from typing import Any, Callable, Literal, TypeAlias
from openai import OpenAI
from huggingface_hub import interpreter_login
import os
import sys


## Feature Steering 

In [4]:
interpreter_login()


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|



Enter your token (input will not be visible):  ········
Add token as git credential? (Y/n)  n


In [5]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))

###  Load the model and SAE

In [15]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

### Gemma-2-2b | sae_id -> layer_20 | width_16k

In [7]:
# Load the LLM
gemma_2_2b = HookedSAETransformer.from_pretrained_no_processing(
    "gemma-2-2b",
    device = device,
    torch_dtype = torch.float16,
    device_map = "auto"
)

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loaded pretrained model gemma-2-2b into HookedTransformer


In [8]:
# Load the corresponding SAE
release="gemma-scope-2b-pt-res-canonical"
sae_id_layer_20_16k="layer_20/width_16k/canonical"
sae_layer_20_16k, cfg_dict_layer_20_16k, _ = sae_lens.SAE.from_pretrained(
    release=release, 
    sae_id=sae_id_layer_20_16k,
    device=device,
)

# Load the corresponding SAE
sae_id_layer_19_65k = "layer_19/width_65k/canonical"
sae_layer_19_65k, cfg_dict_layer_19_65k, _ = sae_lens.SAE.from_pretrained(
    release=release, 
    sae_id=sae_id_layer_19_65k,
    device=device,
)

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

params.npz:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

### Gemma2-2B-IT

In [13]:
# Load the LLM
gemma_2_2b_it = HookedSAETransformer.from_pretrained_no_processing(
    "gemma-2-2b-it",
    device = device,
    torch_dtype = torch.float16,
    device_map = "auto"
)

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loaded pretrained model gemma-2-2b-it into HookedTransformer


In [26]:
# Load the corresponding SAE
release = "gemma-2b-it-res-jb"
sae_id = "blocks.12.hook_resid_post"
sae, cfg_dict, _ = sae_lens.SAE.from_pretrained(
    release=release, 
    sae_id=sae_id,
    device=device.type,
)

# # Load the corresponding SAE
# sae_id_layer_19_65k = "layer_19/width_65k/canonical"
# sae_layer_19_65k, cfg_dict_layer_19_65k, _ = sae_lens.SAE.from_pretrained(
#     release=release, 
#     sae_id=sae_id_layer_19_65k,
#     device=device,
# )

### Activation Steering

In [40]:
def steering_hook(
    activations: Float[Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.
    """
    # print(activations.shape)
    # print(sae.W_dec[latent_idx].unsqueeze(0).unsqueeze(0).shape)
    # return activations[:, :, :2048] + steering_coefficient * sae.W_dec[latent_idx].unsqueeze(0).unsqueeze(0)
    return activations + steering_coefficient * sae.W_dec[latent_idx]


# if USING_GEMMA:
    # part32_tests.test_steering_hook(steering_hook, gemma_2_2b_sae)

In [41]:
GENERATE_KWARGS = dict(temperature=0.7, freq_penalty=2.0, verbose=False)


def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
    """
    Generates text with steering. A multiple of the steering vector (the decoder weight for this latent) is added to
    the last sequence position before every forward pass.
    """
    _steering_hook = partial(
        steering_hook,
        sae=sae,
        latent_idx=latent_idx,
        steering_coefficient=steering_coefficient,
    )

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, _steering_hook)]):
        output = model.generate(prompt, max_new_tokens=max_new_tokens, **GENERATE_KWARGS)

    return output

In [23]:
# latent_idx = 12082
latent_idx = 4442

display_dashboard(sae_release=release, sae_id=sae_id, latent_idx=latent_idx)

https://neuronpedia.org/gemma-2b-it/12-res-jb/4442?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [39]:
prompt = "When I look at myself in the mirror, I see "

sae_list = [sae_layer_20_16k,sae_layer_19_65k]

no_steering_output = gemma_2_2b.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal","sae id", no_steering_output)
for sae in sae_list:
    for i in tqdm(range(3), "Generating steered examples..."):
        table.add_row(
            f"Steered #{i}",
            f"{sae.cfg.neuronpedia_id.split('/')[1]}",
            generate_with_steering(
                gemma_2_2b,
                sae,
                prompt,
                latent_idx,
                steering_coefficient=350.0,  # roughly 1.5-2x the latent's max activation
                max_new_tokens = 50
            ).replace("\n", "↵"),
        )
rprint(table)

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]

In [42]:
prompt = "When I look at myself in the mirror, I see "

sae_list = [sae]

no_steering_output = gemma_2_2b_it.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal","sae id", no_steering_output)
for sae_ in sae_list:
    for i in tqdm(range(3), "Generating steered examples..."):
        table.add_row(
            f"Steered #{i}",
            f"{sae.cfg.neuronpedia_id.split('/')[1]}",
            generate_with_steering(
                gemma_2_2b_it,
                sae_,
                prompt,
                latent_idx,
                steering_coefficient=350.0,  # roughly 1.5-2x the latent's max activation
                max_new_tokens = 50
            ).replace("\n", "↵"),
        )
rprint(table)

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([1, 13, 2304])
torch.Size([1, 1, 2048])


RuntimeError: The size of tensor a (2048) must match the size of tensor b (2304) at non-singleton dimension 2