# Using an SAE as a steering vector

This notebook demonstrates how to use SAE lens to identify a feature on a pretrained model, and then construct a steering vector to affect the models output to various prompts. This notebook will also make use of Neuronpedia for identifying features of interest.

The steps below include:



*   Installing relevant packages (Colab or locally)
*   Load your SAE and the model it used
*   Determining your feature of interest and its index
*   Implementing your steering vector





## Setting up packages and notebook

### Import and installs

#### Environment Setup


In [1]:
try:
  # for google colab users
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
  # for local setup
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

import pandas as pd
from datasets import load_dataset
import transformers

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

#### General Installs and device setup

In [3]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float


# device setup
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


### Load your model and SAE

We're going to work with a pretrained GPT2-small model, and the RES-JB SAE set which is for the residual stream.

In [58]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 2
layer = 2

# get model
model = HookedTransformer.from_pretrained("gpt2-small", device = device)

# get the SAE for this layer
'''
sae, cfg_dict, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb",
    sae_id = f"blocks.{layer}.hook_resid_post",
    device = device
)
'''
sae, cfg_dict, _ = SAE.from_pretrained(
    release = "gpt2-small-res-jb",
    sae_id = f"blocks.{layer}.hook_resid_pre",
    device = device
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

Loaded pretrained model gpt2-small into HookedTransformer
blocks.2.hook_resid_pre


## Determine your feature of interest and its index

### Find your feature

#### Explore through code by using the feature activations for a prompt

For the purpose of the tutorial, we are selecting a simple token prompt.

In this example we will look trying to find and steer a "Jedi" feature.

We run our prompt on our model and get the cache, which we then use with our sae to get our feature activations.

Now we'll look at the top feature activations and look them up on Neuronpedia to determine what they have been intepreted as.

In [5]:
print('dir(sae)\n', dir(sae))
print('sae.W_enc.shape\n', sae.W_enc.shape)

print('\n')

#sv_prompt = " The Golden Gate Bridge"
# sv_prompt = 'Jedi'
sv_prompt = 'Emily\'s dog just died. She is feeling sad'
# sv_prompt = 'John\'s lunch was stolen. He is feeling angry'
# sv_prompt = 'John just received a present. He is feeling happy'

sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print('tokens', tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print('\n', torch.topk(sv_feature_acts, 3))


# Explanation of the topk outputs below:
# There are 5 tokens in sv_prompt: <BOS>, The, Golden, Gate, Bridge. 
# So what topk is showing is for each token what SAE feature is being activated the most.
# Thats why we have a 5 x 3 array. 


dir(sae)
 ['T_destination', 'W_dec', 'W_enc', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_enable_hook', '_enable_hook_with_name', '_enable_hooks_for_points', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', 

### Set up hook functions

Finally, we need to create a hook that allows us to apply the steering vector when our model runs generate() on our defined prompt. We have also added a boolean value 'steering_on' that allows us to easily toggle the steering vector on and off for each prompt


In [102]:
def steering_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return

    # print('resid_pre.shape', resid_pre.shape)
    position = sae_out.shape[1]
    # print('sae_out.shape', sae_out.shape, '\n')
    if steering_on:
      # using our steering vector and applying the coefficient

      # So here resid_pre are the values from the residual stream (in this case [batch_size, n_tokens, 768])
      # We modify the values at the last token: all batches of the last token for all neurons

      # Modification is an addition of the outgoing weights from the target SAE feature 
      # (which is a size 768 vector) multiplied by a coefficient that controls the magnitude of the steering

      # The steered values would then make their way through the rest of the model
      resid_pre[:, :position - 1, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=30,
            do_sample=True,
            **kwargs)
    return result


In [103]:
def run_generate(prompts):

  res_list = []

  for prompt in prompts:
    model.reset_hooks()
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]
    res = hooked_generate([prompt] * 3, editing_hooks, seed=None, **sampling_kwargs)

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, 1:])
    # print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
    res_list.append(res_str)

  return res_list

## Implement your steering vector and affect the output

### Define values for your steering vector
To create our steering vector, we now need to get the decoder weights from our sparse autoencoder found at our index of interest.

Then to use our steering vector, we want a prompt for text generation, as well as a scaling factor coefficent to apply with the steering vector

We also set common sampling kwargs - temperature, top_p and freq_penalty

In [156]:
# Golden State Bridge:
# steering_vector = sae.W_dec[10200]

# Jedi:
# steering_vector = sae.W_dec[7650]

# Sadness: 2442, 
# steering_vector = sae.W_dec[2442]

# Anger: 1550, 4707, 18357, 5526, 21165, 252
# steering_vector = sae.W_dec[21165]

# Worry: 849, 15524, 18978, 21673
# steering_vector = sae.W_dec[18978]

# Happiness: 13928, 3879, 3456
steering_vector = sae.W_dec[24422]
sampling_kwargs = dict(temperature=0.9, top_p=0.1, freq_penalty=1.0)

### Loading an Emotion Recognition Dataset - https://github.com/dair-ai/emotion_dataset

In [9]:
emotion_dataset = load_dataset('dair-ai/emotion', split='train', trust_remote_code=True)
print(emotion_dataset)

Dataset({
    features: ['text', 'label'],
    num_rows: 16000
})


In [10]:
emotion_dataset = emotion_dataset.shuffle()
batch_size = 1
batch = emotion_dataset[:batch_size]
print(batch)

{'text': ['i will hopefully be able to feel less inhibited in my writing and not so much like i write too often'], 'label': [0]}


In [11]:
prompts = [f'John says, \"{text}\". John feels' for text in batch['text']]

In [120]:
coeff = 20
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

In [121]:
steering_on = True
steered_result = run_generate(prompts)
print(steered_result[0][0])

  0%|          | 0/30 [00:00<?, ?it/s]

John says, "i will hopefully be able to feel less inhibited in my writing and not so much like i write too often". John feels that he is more comfortable with his writing. He writes a lot of short stories and he likes to write about things that are important to him. He


In [174]:
steering_on = False
default_result = run_generate(prompts)
print(default_result[0][0])

  0%|          | 0/30 [00:00<?, ?it/s]

John says, "i will hopefully be able to feel less inhibited in my writing and not so much like i write too often". John feels that he is being more creative. He writes a lot of stories about his family and friends. He has a lot of friends who are very supportive of


### Getting a "judge" model to judge the generations

In [15]:
judge = HookedTransformer.from_pretrained("google/gemma-2b-it", device = device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b-it into HookedTransformer


In [123]:
example_response = steered_result[0][0].split('\". ')[1]
default_response = default_result[0][0].split('\". ')[1]

In [124]:
emotion_labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
# judge_prompt = f'For each of the following emotions: {emotion_labels}, rate '
emotions = ", ".join(emotion_labels)
print(emotions)
# judge_prompt = f'For the following text: {example_response}, which of the following emotions: {emotions}, does John feel?'
judge_prompt = f'For the following text: {example_response}, what does John feel?'

sadness, joy, love, anger, fear, surprise


In [125]:

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [126]:

n_judgments = 5
judgments = []
for _ in range(n_judgments):
    messages = [
        {"role": "system", "content": f"You will be doing emotion recognition. Answer with one of the following emotions: {emotions}"},
        {"role": "user", "content": judge_prompt},
    ]

    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = pipeline(
        messages,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=1.0,
        top_p=0.2,
    )
    print(outputs[0]["generated_text"][-1])
    judgments.append(outputs[0]["generated_text"][-1]['content'])

final_judgment = max(set(judgments), key=judgments.count)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


{'role': 'assistant', 'content': 'joy'}


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


{'role': 'assistant', 'content': 'joy'}


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


{'role': 'assistant', 'content': 'joy'}


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


{'role': 'assistant', 'content': 'joy'}
{'role': 'assistant', 'content': 'joy'}


In [99]:
print(final_judgment)

Anger


### Alternative Judge Model (Finetuned DISTILBERT)

In [42]:
bert_judge = transformers.pipeline(model='ActivationAI/distilbert-base-uncased-finetuned-emotion')

In [171]:
# Because of the variance that comes with generations, we gotta run many iterations for each feature for EACH coefficient value.
n_samples = 10
steering_vector = sae.W_dec[24422]
coeff = 20
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

steered_results = []
default_results = []
for _ in range(n_samples):

    steering_on = True
    steered_generation = run_generate(prompts)
    steered_response = steered_generation[0][0].split('\". ')[1]

    steering_on = False
    default_generation = run_generate(prompts)
    default_response = default_generation[0][0].split('\". ')[1]

    judgment = bert_judge([steered_response, default_response], top_k=None)
    steered_results.append(judgment[0])
    default_results.append(judgment[1])

print(steered_results)
print(default_results)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

[[{'label': 'LABEL_0', 'score': 0.9623180031776428}, {'label': 'LABEL_3', 'score': 0.02037476748228073}, {'label': 'LABEL_1', 'score': 0.00628313422203064}, {'label': 'LABEL_4', 'score': 0.004679218865931034}, {'label': 'LABEL_2', 'score': 0.0037975625600665808}, {'label': 'LABEL_5', 'score': 0.002547226380556822}], [{'label': 'LABEL_1', 'score': 0.6533886194229126}, {'label': 'LABEL_2', 'score': 0.19851583242416382}, {'label': 'LABEL_0', 'score': 0.06684751063585281}, {'label': 'LABEL_3', 'score': 0.040311604738235474}, {'label': 'LABEL_4', 'score': 0.020492685958743095}, {'label': 'LABEL_5', 'score': 0.020443763583898544}], [{'label': 'LABEL_1', 'score': 0.6441083550453186}, {'label': 'LABEL_0', 'score': 0.2706294357776642}, {'label': 'LABEL_2', 'score': 0.044997282326221466}, {'label': 'LABEL_4', 'score': 0.013793941587209702}, {'label': 'LABEL_5', 'score': 0.013704944401979446}, {'label': 'LABEL_3', 'score': 0.012765975669026375}], [{'label': 'LABEL_1', 'score': 0.9759530425071716}

In [172]:
# Organize the label, score pairs. Average across labels scores

def get_avg_label_scores(results):
    stats = {}
    for item in results:
        for d in item:
            if d['label'] in stats:
                stats[d['label']] += d['score']
            else:
                stats[d['label']] = d['score']
    # Average
    stats = {x: stats[x]/n_samples for x in stats}
    # Sort keys
    keys = list(stats.keys())
    keys.sort()
    stats = {i: stats[i] for i in keys}
    return stats

In [173]:
steered_stats = get_avg_label_scores(steered_results)
default_stats = get_avg_label_scores(default_results)

print(steered_stats)
print(default_stats)

{'LABEL_0': 0.22971595814451576, 'LABEL_1': 0.7120850323699415, 'LABEL_2': 0.03250006262678653, 'LABEL_3': 0.01150282653979957, 'LABEL_4': 0.0061542556155473, 'LABEL_5': 0.008041865983977914}
{'LABEL_0': 0.015369081031531095, 'LABEL_1': 0.9093039453029632, 'LABEL_2': 0.034546967269852756, 'LABEL_3': 0.00756730935536325, 'LABEL_4': 0.01803584217559546, 'LABEL_5': 0.015176862943917513}


## Next Steps

Ideas you could take for further exploration:

*   Try ablating the feature
*   Try and get a response where just the feature token prints over and over
*   Investigate other features with more complex usage



### Effect of ablating emotion related features