## Setting up
General installs, device setup, and loading the models (the LLM and the SAE).

In [3]:
try:
  # for google colab users
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
  # for local setup
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

In [4]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float
import torch

# Device setup: prefer Apple-silicon MPS, then CUDA, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [5]:
def display_vis_inline(filename: str, height: int = 850, directory: str = "/content"):
    '''
    Display an HTML visualization file.

    Outside Colab the file is opened with `webbrowser.open`. In Colab, a small HTTP
    server serving `directory` is started on the current global `PORT`, and the file
    is embedded via `output.serve_kernel_port_as_iframe`. The global `PORT` is
    incremented on every Colab call, so each vis gets its own port without the
    caller having to choose one.

    Args:
        filename: HTML file to show; in Colab its path is relative to `directory`.
        height: height of the embedded frame in pixels (Colab only).
        directory: directory served by the HTTP server (Colab only).

    Raises:
        OSError: in Colab, if the server cannot bind the port (e.g. already in use).
    '''
    if not COLAB:
        webbrowser.open(filename)
        return

    import functools

    global PORT
    # Snapshot the port before starting any thread, so the server and the iframe
    # are guaranteed to use the same number even though PORT is advanced here.
    port = PORT
    PORT += 1

    # Serve `directory` through the handler's `directory` argument rather than
    # os.chdir, which would change the working directory of the whole kernel.
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)

    # Bind in the calling thread: the server is listening before the iframe asks for
    # the page, and bind errors are raised here instead of dying inside a thread.
    httpd = socketserver.TCPServer(("", port), handler)
    print(f"Serving files from {directory} on port {port}")

    # Daemon thread so the background server never blocks interpreter shutdown.
    threading.Thread(target=httpd.serve_forever, daemon=True).start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

### Load model and pretrained SAE

In [12]:
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Residual-stream layer to study; this tutorial uses layer 2.
layer = 2

# Base language model on the selected device.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Pretrained SAE for that layer. from_pretrained returns three values here;
# only the SAE and its config dict are kept.
# NOTE(review): sae_lens warns that this SAE has non-empty
# `model_from_pretrained_kwargs` and suggests loading the model with
# `from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)`;
# confirm whether the default model processing above affects these results.
sae_release = "gpt2-small-res-jb"
sae_id = f"blocks.{layer}.hook_resid_pre"
sae, cfg_dict, _ = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)

# Hook point reported by the SAE's config; activations are read from here below.
hook_point = sae.cfg.hook_name
print(hook_point)



Loaded pretrained model gpt2-small into HookedTransformer
blocks.2.hook_resid_pre


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


## Determine the feature of interest and the index
Here we use "Jedi" as a short prompt, find an SAE feature that activates on it, and then use that feature to steer generation.

### Find the feature

In [None]:
# Short steering prompt; with a prepended BOS it tokenizes to
# ['<|endoftext|>', 'J', 'edi'] (inspect with model.to_str_tokens(sv_prompt)).
sv_prompt = "Jedi"

# One forward pass, caching all intermediate activations
# (cache[hook_point] has shape [batch, seq_len, d_model], here [1, 3, 768]).
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)

# SAE feature activations at the hook point: [batch, seq_len, d_sae], here
# [1, 3, 24576], and mostly zeros.
sv_feature_acts = sae.encode(cache[hook_point])

# SAE reconstruction of the activations; its seq_len (sae_out.shape[1]) is read
# later by the steering hook.
sae_out = sae.decode(sv_feature_acts)

# Top-3 features at every position; the indices are the candidate feature ids.
top_features = torch.topk(sv_feature_acts, 3)
print(top_features)

torch.Size([1, 3, 768])
torch.return_types.topk(
values=tensor([[[107.2709, 105.1811,  96.7976],
         [ 27.9996,   8.2975,   5.5113],
         [ 16.7904,  10.1490,   9.2543]]], device='cuda:0'),
indices=tensor([[[ 1151, 10488,  3344],
         [17972,  9293, 23888],
         [ 7650,   718, 22372]]], device='cuda:0'))


In [31]:
# Encoder maps d_model -> d_sae and decoder maps d_sae -> d_model
# (observed: [768, 24576] and [24576, 768]).
print(sae.W_enc.shape, sae.W_dec.shape, sep="\n")

torch.Size([768, 24576])
torch.Size([24576, 768])


### Build the steering vector and use it to steer the output

In [None]:
# Feature id taken from the top-k printout above: the most active feature at the
# last prompt position ('edi'). It is named, and checked against the activations,
# so that a different SAE or prompt cannot silently steer with an unrelated feature.
jedi_feature_idx = 7650
assert sv_feature_acts[0, -1].argmax().item() == jedi_feature_idx, (
    f"feature {jedi_feature_idx} is not the top feature at the last token of {sv_prompt!r}"
)

# Steering vector = that feature's decoder row (length d_model).
steering_vector = sae.W_dec[jedi_feature_idx]

example_prompt = "What do we find in space?"
coeff = 100  # scale applied to steering_vector by the steering hook
# Sampling settings forwarded to model.generate:
#   top_p=0.1        nucleus sampling: sample only from the smallest set of tokens
#                    whose cumulative probability reaches 0.1
#   freq_penalty=1.0 penalize already-generated tokens to reduce repeated phrases
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

Set up the hook functions.

In [None]:
def steering_hook(resid_pre, hook):
    """
    Forward hook that adds `coeff * steering_vector` to the leading positions of the residual stream.

    Args:
        resid_pre: residual-stream activations at the hooked point, edited in place
            (assumed shape [batch, seq_len, d_model]).
        hook: the TransformerLens hook point object (unused).

    Returns:
        None; the activation tensor is modified in place.

    Reads the module-level globals `sae_out`, `steering_on`, `coeff` and
    `steering_vector`. The number of steered positions is `sae_out.shape[1] - 1`.
    That length comes from the steering prompt's SAE output, not from the prompt
    being generated from.
    """
    # Passes covering a single position are left alone. Assuming model.generate
    # uses a KV cache (TODO confirm), every pass after the first sees only the
    # newest token, so steering is effectively applied once, to the prompt.
    if resid_pre.shape[1] == 1:
        return

    n_steered = sae_out.shape[1] - 1
    if not steering_on:
        return

    resid_pre[:, :n_steered, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, **kwargs):
    """
    Sample continuations for a batch of prompts with forward hooks temporarily attached.

    Args:
        prompt_batch: prompt string(s), tokenized with `model.to_tokens`.
        fwd_hooks: list of (hook_name, hook_fn) pairs that are active only while
            generating. Defaults to no hooks.
        seed: if given, passed to `torch.manual_seed` before sampling.
        **kwargs: extra keyword arguments for `model.generate`. `max_new_tokens`
            (default 50) and `do_sample` (default True) may be overridden here.

    Returns:
        The result of `model.generate` (presumably token ids of shape [batch, pos];
        run_generate slices it as res[:, 1:]).
    """
    if fwd_hooks is None:
        fwd_hooks = []  # fresh list per call instead of a shared mutable default
    if seed is not None:
        torch.manual_seed(seed)

    # Fixed defaults merged with caller overrides, so passing e.g. max_new_tokens
    # does not collide with an explicit keyword argument.
    generate_kwargs = {"max_new_tokens": 50, "do_sample": True, **kwargs}

    # Tokenization runs no forward pass, so it does not need the hooks attached.
    tokenized = model.to_tokens(prompt_batch)
    with model.hooks(fwd_hooks=fwd_hooks):
        return model.generate(input=tokenized, **generate_kwargs)


In [None]:
def run_generate(example_prompt, n_samples: int = 3, seed=None):
    """
    Sample `n_samples` continuations of `example_prompt` with the steering hook attached, and print them.

    Whether steering actually changes activations is controlled by the global
    `steering_on`, which `steering_hook` reads. Sampling settings come from the
    global `sampling_kwargs`.

    Args:
        example_prompt: prompt to continue.
        n_samples: number of continuations to sample (default 3).
        seed: optional seed forwarded to `hooked_generate` for reproducible samples.
    """
    model.reset_hooks()
    # Attach at the SAE's own hook point (sae.cfg.hook_name, printed as
    # blocks.2.hook_resid_pre above) so the decoder direction is added where the
    # SAE's activations were taken from.
    editing_hooks = [(hook_point, steering_hook)]
    res = hooked_generate([example_prompt] * n_samples, editing_hooks, seed=seed, **sampling_kwargs)

    # Drop the leading BOS token ('<|endoftext|>') of every row before decoding.
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))

In [36]:
# Steered run: steering_hook reads this global and adds coeff * steering_vector.
steering_on = True
run_generate(example_prompt)

100%|██████████| 50/50 [00:00<00:00, 77.15it/s]

What do we find in space?

The Jedi are the most common species of Jedi, and they are the only ones who have been known to be able to survive on a planet. They are not considered sentient by most people, but they can be found in many different forms.

--------------------------------------------------------------------------------

What do we find in space?

We've been told that the Jedi are not only good at combat, but also have a great sense of humor. We've seen them play a lot of games and even read books about them. But what if they were just as good at

--------------------------------------------------------------------------------

What do we find in space?

We've been told that the Jedi are not as powerful as they appear. They're just more powerful than the Jedi, and they're also less intelligent than their Jedi counterparts. But what about the other races? The answer is simple: there





In [37]:
# Baseline for comparison: same prompt and sampling settings, steering disabled.
steering_on = False
run_generate(example_prompt)

100%|██████████| 50/50 [00:00<00:00, 91.49it/s]

What do we find in space?

The world is full of surprises. The first thing that comes to mind is the mysterious and mysterious nature of our universe. We are constantly looking for new things to explore, but there are so many unknowns that we can't even begin to

--------------------------------------------------------------------------------

What do we find in space?

The first thing we find is the density of matter. The density of matter is the ratio between the mass and mass of a substance. This means that if you have a liquid, it will be denser than water, but if you have

--------------------------------------------------------------------------------

What do we find in space?

The most common questions we get are: "How does the Earth work?" and "How does it work on Mars?" The answer is that the Earth is a very complex system. It has many different types of planets, moons, satellites,





In [38]:
# Second experiment: an open-ended question. Steering strength and sampling
# settings are restated so this experiment can be tuned on its own.
question_prompt = "What is on your mind?"
coeff = 100
sampling_kwargs = {"temperature": 1.0, "top_p": 0.1, "freq_penalty": 1.0}

In [39]:
# Steered run on the question prompt: steering_hook reads this global.
steering_on = True
run_generate(question_prompt)

100%|██████████| 50/50 [00:00<00:00, 84.92it/s]

What is on your mind?

The Jedi are a group of Jedi who have been trained by the Sith Lord Darth Vader. They are known as the "Dark Jedi" and are considered to be among the most powerful in the galaxy. The Dark Jedi were created by Darth Vader

--------------------------------------------------------------------------------

What is on your mind?

I'm not sure if you're aware of the Jedi Order, but I've been a Jedi since I was a child. My parents were both Jedi and my father was an Imperial Knight. I remember being raised by my father as a boy

--------------------------------------------------------------------------------

What is on your mind?

If you're looking for a way to get started with the Jedi Order, then this is the place to start. The Jedi Order has been around since at least the Clone Wars and it's been a long time since we've seen any of





In [40]:
# Unsteered baseline on the question prompt for comparison.
steering_on = False
run_generate(question_prompt)

100%|██████████| 50/50 [00:00<00:00, 90.66it/s]

What is on your mind?

I'm not sure what I'm thinking. I've been reading a lot of books about the Bible and it's hard to find anything that's really relevant to me. So, I'm just going to try and get some information out there

--------------------------------------------------------------------------------

What is on your mind?

I'm not sure what I'm thinking. I've been thinking about this for a while now, and it's something that I think about a lot. It's something that I think about when you're in the middle of an argument with

--------------------------------------------------------------------------------

What is on your mind?

I'm not sure what I'm thinking. I've been thinking about this for a while now, and it's something that I've been trying to figure out for a while now. It's kind of like the "what if" question



