In [1]:
try:
  # for google colab users
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
  # for local setup
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
# First port used by `display_vis_inline`; that function bumps this global after each
# visualisation it serves in Colab, so every vis gets its own port.
PORT = 8000

# general imports
import os
# Choose the visible GPU before torch is imported.
# `setdefault` keeps a CUDA_VISIBLE_DEVICES value that is already exported in the environment,
# so the machine-specific index "15" is only a fallback rather than being forced on every host.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "15")
import torch
from tqdm import tqdm
import plotly.express as px

# This notebook only runs inference, so autograd is switched off globally.
# The trailing `;` keeps the call's return value from being echoed as cell output.
torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays an HTML file: locally it is opened with `webbrowser`; in Colab it is served from
    `/content` by a background HTTP server and embedded in the notebook as an iframe.

    Uses global `PORT` variable defined in prev cell, so that each vis has a unique port without
    having to define a port within the function. `PORT` is incremented on every Colab call.

    Args:
        filename: Path of the HTML file (in Colab, relative to `/content`).
        height: Height of the iframe in pixels (only used in Colab).
    '''
    if not COLAB:
        webbrowser.open(filename)
        return

    from functools import partial

    global PORT
    # Reserve the port up front. The server thread starts asynchronously, so reading the global
    # from inside the thread could see the value *after* the increment and bind the wrong port.
    port = PORT
    PORT += 1

    def serve(directory: str, port: int) -> None:
        # Serve `directory` through the handler's `directory` argument instead of os.chdir(),
        # which would change the working directory of the whole kernel.
        handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

        # Create a socket server with the handler
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # daemon=True: the server loops forever and must not keep the interpreter alive on shutdown.
    thread = threading.Thread(target=serve, args=("/content", port), daemon=True)
    thread.start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

In [3]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [4]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 16
layer = 16

# get model
model = HookedTransformer.from_pretrained("mistral-7b", device = device)

# get the SAE for this layer
# (from_pretrained is unpacked as a 3-tuple here; only `sae` is used by the cells shown below)
sae, cfg_dict, _ = SAE.from_pretrained(
    release="mistral-7b-res-wg",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=f"blocks.{layer}.hook_resid_pre",  # won't always be a hook point
    device=device
)

# get hook point
# (the name of the activation the SAE reads; used below to index the activation cache)
hook_point = sae.cfg.hook_name
print(hook_point)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.54s/it]


Loaded pretrained model mistral-7b into HookedTransformer
blocks.16.hook_resid_pre


In [5]:

# Run the model once on a short prompt and cache its activations.
sv_prompt = " The Golden Gate Bridge"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
# NOTE: `steering_hook` later reads `sae_out.shape[1]` to decide how many positions to steer.
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
# (topk runs over the last dimension: 3 values and 3 feature indices per token position)
print(torch.topk(sv_feature_acts, 3))

tensor([[    1, 28705,   415, 14739, 19986, 15050]], device='cuda:0')
torch.return_types.topk(
values=tensor([[[6.0901e+01, 2.6774e+01, 5.6409e-03],
         [4.0106e+01, 2.8247e+01, 1.2325e+01],
         [1.7201e+01, 1.3086e+01, 1.0442e+01],
         [1.9446e+01, 1.3755e+01, 8.6492e+00],
         [1.1657e+01, 9.6264e+00, 7.7555e+00],
         [1.4528e+01, 1.1941e+01, 1.1077e+01]]], device='cuda:0'),
indices=tensor([[[ 9725, 59488, 53699],
         [10855, 61479,  9725],
         [45837, 51897, 53690],
         [54117,  2786, 61771],
         [53015, 49200, 39893],
         [17640, 20150, 39893]]], device='cuda:0'))


In [26]:
# Decoder row of the SAE feature used as the steering direction.
# NOTE(review): feature 12590 is not among the top-k indices printed for the
# " The Golden Gate Bridge" prompt above - confirm this is the intended feature.
steering_vector = sae.W_dec[12590]

# Prompt to generate from, and the multiplier applied to the steering vector.
example_prompt = "Who actually said Let them eat cake?"
coeff = -300

# Sampling options forwarded to the generation call.
sampling_kwargs = {
    "temperature": 1.0,
    "top_p": 0.1,
    "freq_penalty": 1.0,
}

In [27]:
def steering_hook(resid_pre, hook):
    """
    Forward hook that adds `coeff * steering_vector` to the hooked activation, in place.

    Reads the notebook globals `sae_out`, `steering_on`, `coeff` and `steering_vector`.
    Always returns None; the activation is modified through `resid_pre` itself.

    Args:
        resid_pre: Activation tensor handed to the hook; its second dimension is treated as
            the position axis (assumes (batch, pos, d_model) - TODO confirm).
        hook: Hook point object supplied by the model (unused).
    """
    # Calls that carry a single position are left untouched
    # (presumably the one-token steps of cached generation - confirm).
    if resid_pre.shape[1] == 1:
        return

    # Number of positions of the SAE reconstruction computed earlier in the notebook.
    n_positions = sae_out.shape[1]
    if not steering_on:
        return

    # using our steering vector and applying the coefficient
    # (every position before index `n_positions - 1` is shifted)
    last = n_positions - 1
    resid_pre[:, :last, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, **kwargs):
    """
    Generate continuations for a batch of prompts while forward hooks are attached to `model`.

    Args:
        prompt_batch: Prompt string(s) passed to `model.to_tokens`.
        fwd_hooks: List of `(hook_name, hook_fn)` pairs active during generation.
            None (the default) means no hooks.
        seed: If given, seeds torch's RNG first so that sampling is repeatable.
        **kwargs: Extra arguments for `model.generate`. They override the notebook defaults
            `stop_at_eos=False`, `max_new_tokens=50` and `do_sample=True`.

    Returns:
        Whatever `model.generate` returns for the tokenized prompts.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Defaults used by this notebook. They are merged with (and overridden by) the caller's
    # kwargs, so passing e.g. max_new_tokens no longer raises a duplicate-keyword TypeError.
    generate_kwargs = dict(
        stop_at_eos=False,  # avoids a bug on MPS
        max_new_tokens=50,
        do_sample=True,
    )
    generate_kwargs.update(kwargs)

    # A None default replaces the shared mutable `[]` default argument.
    active_hooks = [] if fwd_hooks is None else fwd_hooks

    with model.hooks(fwd_hooks=active_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(input=tokenized, **generate_kwargs)
    return result

In [28]:
def run_generate(example_prompt, n_samples=3, seed=None):
    """
    Sample completions of `example_prompt` with `steering_hook` attached and print them.

    Args:
        example_prompt: Prompt text to complete.
        n_samples: Number of copies of the prompt generated in one batch (default 3).
        seed: Optional RNG seed forwarded to `hooked_generate`; None leaves the RNG unseeded.
    """
    model.reset_hooks()
    # NOTE(review): the SAE above was loaded for `blocks.{layer}.hook_resid_pre`, while the
    # steering hook is attached to `hook_resid_post` of the same layer - confirm this is intended.
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]
    res = hooked_generate([example_prompt] * n_samples, editing_hooks, seed=seed, **sampling_kwargs)

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))

In [30]:
# `steering_hook` reads this global flag each time it is called; True enables the steering edit.
steering_on = True
run_generate(example_prompt)

100%|██████████| 50/50 [00:03<00:00, 13.02it/s]

Who actually said Let them eat cake?

The phrase “Let them eat cake” is often attributed to Marie Antoinette, the wife of King Louis XVI of France. The story goes that when she was told that the peasants had no bread to eat, she responded with this

--------------------------------------------------------------------------------

Who actually said Let them eat cake?

Marie Antoinette is often credited with the quote, “Let them eat cake,” but it’s not true. The phrase was first used by Jean-Jacques Rousseau in his 1755 novel

--------------------------------------------------------------------------------

Who actually said Let them eat cake?

The phrase “Let them eat cake” is often attributed to Marie Antoinette, the wife of King Louis XVI of France. The story goes that when she was told that the peasants had no bread to eat, she responded with this



