In [20]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from tabulate import tabulate
from functools import partial

import sae_lens
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from sae_lens import SAE,HookedSAETransformer,ActivationsStore
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from torch import Tensor, nn
import einops
from rich import print as rprint
from rich.table import Table
from tqdm.auto import tqdm
import pandas as pd
import requests
from typing import Any, Callable, Literal, TypeAlias
from openai import OpenAI
from huggingface_hub import interpreter_login
import os
import sys


## Feature Steering 

In [4]:
# Log in to the Hugging Face Hub. This prompts interactively for an access token
# (and whether to store it as a git credential), so nothing is hardcoded here.
interpreter_login()


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|



Enter your token (input will not be visible):  ········
Add token as git credential? (Y/n)  n


In [6]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    """
    Prints the Neuronpedia URL for a single SAE latent and embeds that page in the notebook.

    Args:
        sae_release: Key into the pretrained-SAE directory identifying the release.
        sae_id: Identifier of the SAE within that release.
        latent_idx: Index of the latent whose dashboard should be shown.
        width: Width passed to the embedded IFrame.
        height: Height passed to the embedded IFrame.

    Raises:
        KeyError: If `sae_release` or `sae_id` is not present in the directory.
        ValueError: If the directory has no Neuronpedia id recorded for this SAE.
    """
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    # Without this guard an SAE that has no Neuronpedia entry would silently
    # produce a broken ".../None/<idx>" URL and an empty iframe.
    if not neuronpedia_id:
        raise ValueError(
            f"No Neuronpedia id is registered for sae_id={sae_id!r} in release {sae_release!r}."
        )

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))

### Load the model and SAE

In [36]:
# Choose the compute device, preferring Apple MPS, then CUDA, and falling back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Gemma-2-2b | sae_id -> layer_20 | width_16k

In [None]:
# Language model: Gemma-2-2b wrapped as a HookedSAETransformer, loaded in float16.
gemma_2_2b = HookedSAETransformer.from_pretrained_no_processing(
    "gemma-2-2b",
    device=device,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Identifiers of the matching SAE; `release` and `sae_id` are reused by later cells.
release = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_20/width_16k/canonical"

# from_pretrained is unpacked as a 3-tuple here; only the SAE and its config dict are kept.
# NOTE(review): no dtype is passed for the SAE while the model is float16 — confirm the
# two agree, or cast where they are combined.
sae, cfg_dict, _ = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)

### Activation Steering

In [23]:
def steering_hook(
    activations: Float[Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.

    Args:
        activations: Activations at the hooked location.
        hook: The hook point this function is attached to (unused).
        sae: SAE whose decoder row `W_dec[latent_idx]` is used as the steering vector.
        latent_idx: Index of the latent to steer with.
        steering_coefficient: Scalar multiplier applied to the steering vector.

    Returns:
        A tensor with the same shape, dtype and device as `activations`.
    """
    steering_vector = steering_coefficient * sae.W_dec[latent_idx]

    # The SAE may have been loaded with a different dtype/device than the model (e.g. an fp32 SAE
    # with an fp16 model). A plain add would then promote the residual stream to the wider dtype
    # and break the layers that follow, so match the steering vector to the activations first.
    steering_vector = steering_vector.to(device=activations.device, dtype=activations.dtype)

    return activations + steering_vector


# if USING_GEMMA:
    # part32_tests.test_steering_hook(steering_hook, gemma_2_2b_sae)

In [31]:
# Sampling options shared by every `generate` call in this notebook.
GENERATE_KWARGS = {
    "temperature": 0.7,
    "freq_penalty": 2.0,
    "verbose": False,
}


def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
    """
    Generates a continuation of `prompt` while steering with one SAE latent.

    For the duration of the generation, a forward hook is attached at `sae.cfg.hook_name`. On every forward pass it
    delegates to `steering_hook`, which adds `steering_coefficient` times the latent's decoder vector to the
    activations it receives (at all positions of that tensor).

    Returns whatever `model.generate` returns for the prompt, using the shared GENERATE_KWARGS.
    """

    def _add_steering_vector(activations, hook):
        # Bind the steering configuration; the hook machinery supplies `activations` and `hook`.
        return steering_hook(
            activations,
            hook=hook,
            sae=sae,
            latent_idx=latent_idx,
            steering_coefficient=steering_coefficient,
        )

    hook_spec = (sae.cfg.hook_name, _add_steering_vector)

    # The context manager removes the hook again once generation has finished.
    with model.hooks(fwd_hooks=[hook_spec]):
        return model.generate(prompt, max_new_tokens=max_new_tokens, **GENERATE_KWARGS)

In [22]:
# Latent used for the dashboard and for steering below (12082 is an alternative choice).
latent_idx = 4442

display_dashboard(
    sae_release=release,
    sae_id=sae_id,
    latent_idx=latent_idx,
)

https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/4442?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [35]:
prompt = "When I look at myself in the mirror, I see "

# Baseline: sample from the unsteered model. The model loaded earlier in this notebook is
# `gemma_2_2b`; the previous reference to an undefined `model` only worked with leftover
# kernel state and fails on Restart & Run All.
no_steering_output = gemma_2_2b.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

# Collect the baseline and a few steered samples in one table for side-by-side comparison.
table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal", no_steering_output)
for i in tqdm(range(3), "Generating steered examples..."):
    table.add_row(
        f"Steered #{i}",
        generate_with_steering(
            gemma_2_2b,
            sae,
            prompt,
            latent_idx,
            steering_coefficient=350.0,  # roughly 1.5-2x the latent's max activation
            max_new_tokens = 50
        ).replace("\n", "↵"),  # keep each sample on one visual line in the table
    )
rprint(table)

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]