# Using an SAE as a steering vector

This notebook demonstrates how to use SAELens to identify a feature in a pretrained model, and then construct a steering vector to affect the model's output on various prompts. This notebook also makes use of Neuronpedia for identifying features of interest.

The steps below include:

*   Installing relevant packages (Colab or locally)
*   Loading your SAE and the model it was trained on
*   Determining your feature of interest and its index
*   Implementing your steering vector

## Setting up packages and notebook

### Import and installs

#### Environment Setup


In [1]:
try:
    # for Google Colab users
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except ImportError:
    # for local setup
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not COLAB:
        webbrowser.open(filename)

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

#### General Installs and device setup

In [3]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# device setup
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda:7" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda:7


### Load your model and SAE

We're going to work with a pretrained Chameleon-based model, loaded through `HookedChameleon`, and an SAE for its residual stream, loaded from a local checkpoint.

In [None]:
import os

# Make all eight GPUs visible; this only takes effect if CUDA has not been initialized yet
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

from transformer_lens import HookedChameleon
from sae_lens import SAE
from transformers import ChameleonForConditionalGeneration

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 16
layer = 16

# load the Hugging Face weights from a local checkpoint
local_model = ChameleonForConditionalGeneration.from_pretrained("/home/saev/hantao/models/Anole-7b-v0.1-hf")

# get model: wrap the weights in a HookedChameleon, split across 6 devices
model = HookedChameleon.from_pretrained("htlou/AA-Chameleon-7B-plus", hf_model = local_model, n_devices = 6)

# # get the SAE for this layer
# sae, cfg_dict, _ = SAE.from_pretrained(
#     release = "htlou/AA-Chameleon-7B-plus-res-jb",
#     sae_id = f"blocks.{layer}.hook_resid_post",
#     device = device
# )
sae = SAE.load_from_pretrained(
    path = "/home/saev/hantao/SAELens-V/scripts/checkpoints/1018_obelics_10k/final_122880000",
    device = device
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.02s/it]


In [13]:
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset

# dataset = load_dataset(
#     path = "NeelNanda/pile-10k",
#     split="train",
#     streaming=False,
# )

# token_dataset = tokenize_and_concatenate(
#     dataset= dataset,# type: ignore
#     tokenizer = model.tokenizer, # type: ignore
#     streaming=True,
#     max_length=sae.cfg.context_size,
#     add_bos_token=sae.cfg.prepend_bos,
# )

token_dataset = load_dataset(
    path = "/home/saev/hantao/data/obelics_obelics_10k_tokenized_2048",
    split = "train",
    streaming=False
)

data_list = []
for item in token_dataset:
    data_list.append(item["input_ids"])

batch_tokens = torch.tensor(data_list[:10])
# batch_tokens = torch.tensor(token_dataset[:1]['tokens'])
print(batch_tokens.shape)

torch.Size([10, 2048])


In [14]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for ghost grads

with torch.no_grad():
    # activation store can give us tokens.
    # batch_tokens = token_dataset[:5]["tokens"]
    # batch_tokens = inputs
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    
    feature_acts = sae.encode(cache[sae.cfg.hook_name].to(device))
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged across batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

torch.cuda.empty_cache()

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.00 GiB. GPU 0 has a total capacity of 79.14 GiB of which 4.06 GiB is free. Including non-PyTorch memory, this process has 75.07 GiB memory in use. Of the allocated memory 72.62 GiB is allocated by PyTorch, and 1.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
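
To make the L0 statistic from the cell above concrete, here is a minimal sketch on a toy tensor. The activation values are made up for illustration; only the indexing and reduction mirror the notebook code.

```python
import torch

# Toy feature activations with shape [batch, position, d_sae]; values are made up.
toy_feature_acts = torch.tensor(
    [
        [[9.0, 9.0, 9.0, 9.0], [0.0, 1.5, 0.0, 0.0], [0.2, 0.0, 0.7, 0.0]],
        [[9.0, 9.0, 9.0, 9.0], [0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 0.0]],
    ]
)

# Drop position 0 (the BOS token), then count the active features at each token.
toy_l0 = (toy_feature_acts[:, 1:] > 0).float().sum(-1)

print(toy_l0)                # per-token counts, shape [batch, position - 1]
print(toy_l0.mean().item())  # average L0 across batch and position
```

The BOS position is excluded because it tends to activate many features very strongly, which would inflate the average.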

In [6]:
# print tensors in full, without truncation
torch.set_printoptions(threshold=torch.inf)
batch_tokens.shape

torch.Size([10, 2048])

In [8]:
from transformer_lens import utils
from functools import partial
torch.cuda.empty_cache()
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out


def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)


print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

Orig 16.37755012512207


reconstr 29.15433120727539
Zero 11.090354919433594
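
One common way to summarize these three losses is the fraction of loss recovered by the SAE splice, relative to zero ablation. The helper below is a minimal sketch; `fraction_of_loss_recovered` is a hypothetical name and the example losses are invented, not taken from this run.

```python
def fraction_of_loss_recovered(orig_loss: float, reconstr_loss: float, zero_loss: float) -> float:
    """Return 1.0 when the SAE splice matches the clean loss and 0.0 when it
    is no better than zero ablation."""
    gap = zero_loss - orig_loss
    if gap == 0:
        raise ValueError("zero-ablation loss equals the original loss; score is undefined")
    return (zero_loss - reconstr_loss) / gap


# Hypothetical losses, not taken from this notebook's output
score = fraction_of_loss_recovered(orig_loss=3.0, reconstr_loss=3.5, zero_loss=8.0)
print(f"{score:.2f}")  # 0.90
```

In the output above the zero-ablation loss is lower than the original loss, which is unusual, so the numbers are worth double-checking before reading anything into such a score.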


In [None]:
# Parameters
top_n = 5  # number of top activations to keep for each feature

# Initialize the dictionary that stores the top-n activation values of each feature
feature_top_activations = {}

In [None]:
def update_top_n_activations(feature_activations, feature_top_activations, top_n=5):
    """
    Update the top-n activation values of each feature.
    
    feature_activations: feature activation values for the current data
    feature_top_activations: dictionary storing the top-n activation values
    top_n: number of top activations to keep
    """
    for feature_idx, activation_value in enumerate(feature_activations):
        if feature_idx not in feature_top_activations:
            feature_top_activations[feature_idx] = []
        feature_top_activations[feature_idx].append(activation_value)
        # keep only the top-n activation values
        feature_top_activations[feature_idx] = sorted(
            feature_top_activations[feature_idx], reverse=True)[:top_n]
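
The same top-n bookkeeping can also be done with a min-heap per feature, which avoids re-sorting the list on every update. This is a sketch of an alternative, not the notebook's method; `update_top_n_heap` and the toy batches are hypothetical.

```python
import heapq


def update_top_n_heap(feature_activations, feature_top_heaps, top_n=5):
    """Keep the top_n largest activations per feature in a min-heap."""
    for feature_idx, value in enumerate(feature_activations):
        heap = feature_top_heaps.setdefault(feature_idx, [])
        if len(heap) < top_n:
            heapq.heappush(heap, value)
        elif value > heap[0]:
            # heap[0] is the smallest kept value; swap it for the larger one
            heapq.heapreplace(heap, value)


# Three toy "batches", each holding one activation value per feature
heaps = {}
for batch in ([0.1, 2.0, 0.0], [0.5, 1.0, 3.0], [0.3, 4.0, 1.0]):
    update_top_n_heap(batch, heaps, top_n=2)

top = {idx: sorted(heap, reverse=True) for idx, heap in heaps.items()}
print(top)  # {0: [0.5, 0.3], 1: [4.0, 2.0], 2: [3.0, 1.0]}
```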

In [17]:
torch.cuda.empty_cache()

In [16]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

torch.cuda.empty_cache()

test_feature_idx_gpt = list(range(10)) + [14057]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    minibatch_size_tokens=32,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model, # type: ignore
    tokens=batch_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)


OutOfMemoryError: CUDA out of memory. Tried to allocate 5.00 GiB. GPU 0 has a total capacity of 79.14 GiB of which 1.25 GiB is free. Including non-PyTorch memory, this process has 77.88 GiB memory in use. Of the allocated memory 76.37 GiB is allocated by PyTorch, and 1.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
for feature in test_feature_idx_gpt:
    filename = f"{feature}_feature_vis_demo_gpt.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display_vis_inline(filename)

## Determine your feature of interest and its index

### Find your feature

#### Explore through code by using the feature activations for a prompt

For the purpose of the tutorial, we select short, simple prompts.

In this example we look for features that respond to the content of each prompt, so that we can later steer with one of them.

We run each prompt through our model and get the cache, which we then use with our SAE to get our feature activations.

Now we'll look at the top feature activations and look them up on Neuronpedia to determine what they have been interpreted as.

In [10]:
sv_prompt = " rape fuck nigger dick asshole"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 5))

tensor([[     2,  39688,   7935, 186431,  39428,  73663]], device='cuda:2')
torch.return_types.topk(
values=tensor([[[35.2353, 35.2019, 34.1813, 34.1702, 33.1085],
         [13.1708,  4.2053,  1.6043,  1.5110,  1.3109],
         [ 5.9665,  4.6082,  2.3682,  1.9491,  1.5607],
         [ 3.3919,  2.8851,  2.7635,  2.4939,  1.7354],
         [ 4.2121,  3.6925,  2.5146,  2.1606,  2.1564],
         [ 4.0369,  2.8477,  2.2398,  1.6563,  1.3747]]], device='cuda:2'),
indices=tensor([[[11609, 15572, 13161,  7063,  1111],
         [ 5524,  1831, 10064, 11759,  6619],
         [13237, 15936, 15525,  5796, 14018],
         [ 7645,  2675, 13237, 12083, 11548],
         [ 4692, 13237, 15525,  2913, 15443],
         [13237,   221, 12083, 15525, 13163]]], device='cuda:2'))


In [11]:
sv_prompt = " where is the golden gate bridge"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 5))

tensor([[    2,  1570,   603,   573, 13658, 13639, 12261]], device='cuda:2')
torch.return_types.topk(
values=tensor([[[35.2353, 35.2019, 34.1813, 34.1702, 33.1085],
         [ 5.5568,  2.8499,  2.3697,  1.7236,  1.5856],
         [ 4.0360,  2.9900,  2.6798,  2.5211,  1.7977],
         [ 2.9019,  2.3345,  2.0115,  1.9147,  1.8530],
         [ 9.7067,  2.4166,  1.8208,  1.7527,  1.5820],
         [ 8.6597,  3.5331,  2.5491,  1.9916,  1.8692],
         [ 8.7400,  3.5897,  1.5448,  1.4881,  1.4502]]], device='cuda:2'),
indices=tensor([[[11609, 15572, 13161,  7063,  1111],
         [11690,  6510, 15945, 10862,  7845],
         [ 6510,  7750,  2069, 10862,  5876],
         [ 7750, 10862,  6510, 12154, 10454],
         [ 7264,  6619, 15164,  9024, 15065],
         [ 6148, 11759, 12848, 10930, 14290],
         [16057, 11759, 12541, 12190,  9838]]], device='cuda:2'))


In [6]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
get_neuronpedia_quick_list(torch.topk(sv_feature_acts, 3).indices.tolist(), layer = layer, model = "gemma-2b", dataset="res-jb")

'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gemma-2b%22%2C%20%22layer%22%3A%20%226-res-jb%22%2C%20%22index%22%3A%20%22%5B%5B3390%2C%2015881%2C%205347%5D%2C%20%5B5920%2C%203869%2C%20782%5D%5D%22%7D%5D'

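As we can see from our printout of tokens, the last prompt is made of seven tokens in total: the beginning-of-sequence token followed by six prompt tokens.

The feature activations at the later positions of `sv_feature_acts[0]`, which (assuming one token per word) correspond to "golden", "gate", and "bridge", are of most interest to us. Position 0 is the beginning-of-sequence token, and its top features are identical across prompts.

If you are using pretrained SAEs that have published feature maps, you can search on Neuronpedia for a feature of interest.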
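
To look at a single token position rather than the whole `topk` printout, you can index the position before calling `torch.topk`. The sketch below uses a made-up activation tensor in place of `sv_feature_acts`; the shapes follow the `[batch, position, d_sae]` layout seen above.

```python
import torch

# Made-up activations with shape [batch=1, position=3, d_sae=6]
toy_feature_acts = torch.tensor(
    [[
        [5.0, 4.0, 3.0, 2.0, 1.0, 0.0],  # position 0: BOS token
        [0.0, 0.0, 2.5, 0.0, 0.1, 0.0],  # position 1
        [0.0, 7.0, 0.0, 0.3, 0.0, 1.2],  # position 2
    ]]
)

position = 2  # token position we care about
values, indices = torch.topk(toy_feature_acts[0, position], k=3)

print(indices.tolist())  # feature indices, strongest first: [1, 5, 3]
```

The indices returned for the chosen position are the candidates to look up on Neuronpedia or pass to `sae.W_dec`.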

### Steps for Neuronpedia use

Use the interface to search for a specific concept or item and determine which layer and at what index it is.

1.   Open the [Neuronpedia](https://www.neuronpedia.org/) homepage.
2.   Using the "Models" dropdown, select your model. Here we are using GPT2-SM (GPT2-small).
3.   The next page will have a search bar, which allows you to enter your index of interest. We're interested in the "RES-JB" SAE set, so make sure to select it.
4.   We found these indices in the previous step: [ 7650,   718, 22372]. Select them in the search to see the feature dashboard for each.
5.   As we'll see, some of the indices may relate to features you don't care about.

From using Neuronpedia, I have determined that my feature of interest is in layer 2, at index 7650: [here](https://www.neuronpedia.org/gpt2-small/2-res-jb/7650) is the feature.
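
Feature pages appear to follow the URL pattern of the link above, so a small helper can build them from a model id, layer, SAE set, and index. `neuronpedia_feature_url` is a hypothetical helper, and the pattern is inferred from that single link rather than from a documented API.

```python
def neuronpedia_feature_url(model_id: str, layer: int, sae_set: str, index: int) -> str:
    """Build a feature dashboard URL, assuming the pattern
    https://www.neuronpedia.org/<model>/<layer>-<sae set>/<index>."""
    return f"https://www.neuronpedia.org/{model_id}/{layer}-{sae_set}/{index}"


url = neuronpedia_feature_url("gpt2-small", 2, "res-jb", 7650)
print(url)  # https://www.neuronpedia.org/gpt2-small/2-res-jb/7650
```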

### Note: 2nd Option - Starting with Neuronpedia

Another option is to start with Neuronpedia to identify features of interest. By entering your prompt in the interface, you can explore which features were involved and search across all the layers. This allows you to first determine your layer and index of interest in Neuronpedia before focusing on them in your code. Start [here](https://www.neuronpedia.org/search) if you want to begin with search.

## Implement your steering vector and affect the output

### Define values for your steering vector
To create our steering vector, we now need to get the decoder weights from our sparse autoencoder at our index of interest.

Then, to use our steering vector, we want a prompt for text generation, as well as a scaling coefficient to apply to the steering vector.

We also set common sampling kwargs: temperature, top_p and freq_penalty.

In [7]:
steering_vector = sae.W_dec[5556]

example_prompt = "What is the most iconic structure known to man?"
coeff = 300
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)
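
As a minimal sketch of what steering does numerically, the toy example below adds `coeff * steering_vector` to every position of a small residual tensor. The decoder matrix `toy_W_dec` and all sizes are made up; in the notebook the vector comes from `sae.W_dec`.

```python
import torch

d_sae, d_model = 4, 3

# Made-up decoder matrix with shape [d_sae, d_model]
toy_W_dec = torch.arange(d_sae * d_model, dtype=torch.float32).reshape(d_sae, d_model)

feature_idx = 2
toy_steering_vector = toy_W_dec[feature_idx]  # one decoder row, shape [d_model]
toy_coeff = 10.0

# Toy residual stream with shape [batch, position, d_model]
resid = torch.zeros(1, 5, d_model)

# Broadcasting adds the same scaled vector at every position
steered = resid + toy_coeff * toy_steering_vector

print(steered[0, 0])  # tensor([60., 70., 80.])
```

Because the addition scales with the norm of the decoder row, a useful coefficient depends on the typical norm of the residual stream at the chosen layer.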

### Set up hook functions

Finally, we need to create a hook that applies the steering vector when our model runs generate() on our defined prompt. We have also added a boolean value 'steering_on' that allows us to easily toggle the steering vector on and off for each prompt.


In [8]:
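def steering_hook(resid_pre, hook):
    # skip single-token forward passes (e.g. the per-token steps when generation uses a KV cache)
    if resid_pre.shape[1] == 1:
        return

    # note: `sae_out` is the global set in the feature-search cell, so `position`
    # is the token length of the last prompt encoded there
    position = sae_out.shape[1]
    if steering_on:
        # add the scaled steering vector in place to the first `position - 1` positions
        resid_pre[:, :position - 1, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    # avoid a mutable default argument
    fwd_hooks = fwd_hooks or []

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs)
    return result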


In [9]:
def run_generate(example_prompt):
    model.reset_hooks()
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]
    res = hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs)

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))

### Generate text influenced by steering vector

You may want to experiment with the scaling factor coefficient value that you set and see how it affects the generated output.

In [10]:
steering_on = True
run_generate(example_prompt)

What is the most iconic structure known to man?

The Leaning Tower of Pisa, Italy

The Leaning Tower of Pisa, also known as the Bell Tower, is one of the best-known monuments in Italy and was one of the most recognizable landmarks in the world during the 20

--------------------------------------------------------------------------------

What is the most iconic structure known to man?

The Leaning Tower of Pisa, Italy

The Leaning Tower of Pisa, also known as the Bell Tower, is one of the best-known monuments in Italy and was one of the most recognizable landmarks in the world during the 20

--------------------------------------------------------------------------------

What is the most iconic structure known to man?

The Leaning Tower of Pisa, Italy

The Leaning Tower of Pisa, also known as the Bell Tower, is one of the best-known monuments in Italy and was one of the most recognizable landmarks in the world during the 20
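
All three samples above are identical even though `do_sample=True`. A likely reason is `top_p=0.1`: with nucleus sampling, whenever the most likely token has probability of at least 0.1, it is the only token kept, so sampling becomes effectively greedy. The function below is a simplified sketch of the nucleus filter on made-up probabilities, not the library's implementation.

```python
def nucleus(probs, top_p):
    """Return indices of the smallest set of tokens, taken in order of
    decreasing probability, whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


toy_probs = [0.5, 0.3, 0.2]  # made-up next-token distribution

print(nucleus(toy_probs, top_p=0.1))  # [0]: only the top token survives
print(nucleus(toy_probs, top_p=0.7))  # [0, 1]
```

Raising `top_p` should make the three generations diverge, which can make the effect of steering easier to judge across samples.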


### Generate text with no steering

In [11]:
steering_on = False
run_generate(example_prompt)

What is the most iconic structure known to man?

I'd say the Great Pyramid of Giza is pretty iconic.

The Great Pyramid of Giza is a pretty good example of an ancient pyramid, and it's pretty famous.  It was built as a tomb for the Egyptian Pharaoh

--------------------------------------------------------------------------------

What is the most iconic structure known to man?

I'd say the Great Pyramid of Giza is pretty iconic.

The Great Pyramid of Giza is a pretty good example of an ancient pyramid, and it's pretty famous.  It was built as a tomb for the Egyptian Pharaoh

--------------------------------------------------------------------------------

What is the most iconic structure known to man?

I'd say the Great Pyramid of Giza is pretty iconic.

The Great Pyramid of Giza is a pretty good example of an ancient pyramid, and it's pretty famous.  It was built as a tomb for the Egyptian Pharaoh


### General Question test
We'll also try a more general prompt, which gives a better indication of whether our steering vector is having an effect.

In [12]:
question_prompt = "What is on your mind?"
coeff = 100
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

In [13]:
steering_on = True
run_generate(question_prompt)

What is on your mind?

Human: I am thinking about ways to make money.

Assistant: What sort of things do you want to buy?<eos>Human: I don't know, I guess I would like some sort of high end phone. Something that people would

--------------------------------------------------------------------------------

What is on your mind?

Human: I am thinking about ways to make money.

Assistant: What sort of things do you want to buy?<eos>Human: I don't know, I guess I would like some sort of high end phone. Something that people would

--------------------------------------------------------------------------------

What is on your mind?

Human: I am thinking about ways to make money.

Assistant: What sort of things do you want to buy?<eos>Human: I don't know, I guess I would like some sort of high end phone. Something that people would


In [14]:
steering_on = False
run_generate(question_prompt)

What is on your mind?

Human: Nothing in particular, just wondering what people usually think about when they think of the word "freedom"

Assistant: Hm, I wonder if people typically think of freedom as a positive concept.  If we were talking about physical freedom,

--------------------------------------------------------------------------------

What is on your mind?

Human: Nothing in particular, just wondering what people usually think about when they think of the word "freedom"

Assistant: Hm, I wonder if people typically think of freedom as a positive concept.  If we were talking about physical freedom,

--------------------------------------------------------------------------------

What is on your mind?

Human: Nothing in particular, just wondering what people usually think about when they think of the word "freedom"

Assistant: Hm, I wonder if people typically think of freedom as a positive concept.  If we were talking about physical freedom,


## Next Steps

Ideas you could take for further exploration:

*   Try ablating the feature
*   Try to get a response where just the feature token prints over and over
*   Investigate other features with more complex usage
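
As a starting point for the ablation idea, one option is to subtract a feature's decoder contribution from the residual stream instead of adding to it. The sketch below does this on toy tensors; all values and names are made up, and in the notebook the activations would come from `sae.encode` and the decoder row from `sae.W_dec`.

```python
import torch

# Made-up decoder matrix [d_sae=4, d_model=3] and activations [batch=1, position=2, d_sae=4]
toy_W_dec = torch.arange(12, dtype=torch.float32).reshape(4, 3)
toy_feature_acts = torch.tensor([[[0.0, 2.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0]]])

# Toy residual stream [batch, position, d_model]
resid = torch.full((1, 2, 3), 10.0)

feature_idx = 1

# Contribution of the chosen feature at each position: activation * decoder row
contribution = toy_feature_acts[..., feature_idx : feature_idx + 1] * toy_W_dec[feature_idx]

# Ablate by removing that contribution; positions where the feature is inactive are unchanged
ablated = resid - contribution

print(ablated)
```

Inside a hook, the same subtraction would be applied to the activation at the SAE's hook point, using feature activations computed from that same activation.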

