## Kiran Notes:
Keep these in mind before running all the cells:
1. Log in to your HuggingFace account, create an access token, and add it to this notebook (needed for gated model weights such as gemma-2b). You might also need to accept the terms and conditions on the model's HuggingFace page.
2. Full walkthrough:
    1. Go to https://www.neuronpedia.org/search
        1. Select the model, the SAE to use, and the layer number
    2. Find the indices of the SAE feature vectors you want
        1. Enter a prompt such as "I like cats", then click on the "cats" part of it
        2. It should show the different feature vectors corresponding to that token
        3. Pick one and note its index
    3. In the notebook, load the pretrained SAE
        1. The vector should just be sae.W_dec[vector_index]
    4. To sanity check that the vector is correct
        1. Go to https://www.neuronpedia.org/gemma-2-9b-it/steer (switch to the correct model)
        2. Use "add vector" to add the vector you found at the correct layer; it should be able to steer the model
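
For step 1, one way to avoid pasting the token directly into a cell is to read it from an environment variable. This is a minimal sketch, not part of the original notebook: the helper names `get_hf_token` and `hf_login` are hypothetical, and using `HF_TOKEN` as the variable name is an assumption about how you store the token.

```python
import os


def get_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Return the HuggingFace token stored in `env_var` (assumed name)."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"No token found in ${env_var}; create one in your HuggingFace "
            "account settings and export it before starting the notebook."
        )
    return token


def hf_login(env_var: str = "HF_TOKEN") -> None:
    """Log in to the HuggingFace Hub so gated weights can be downloaded."""
    # Imported lazily so get_hf_token works even without huggingface_hub.
    from huggingface_hub import login

    login(token=get_hf_token(env_var))
```

Calling `hf_login()` once before loading the model should be enough, provided you have already accepted the model's terms on its HuggingFace page.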


## Setting up packages and notebook


### Import and installs


#### Environment Setup


In [1]:
try:
    # for google colab users
    import google.colab  # type: ignore
    from google.colab import output

    COLAB = True
    %pip install sae-lens transformer-lens pandas
except ImportError:
    # for local setup
    COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading

PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);



In [2]:
def display_vis_inline(filename: str, height: int = 850):
    """
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    """
    if not (COLAB):
        webbrowser.open(filename)

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        # daemon=True so the server thread does not block kernel shutdown
        thread = threading.Thread(target=serve, args=("/content",), daemon=True)
        thread.start()

        output.serve_kernel_port_as_iframe(
            PORT, path=f"/{filename}", height=height, cache_in_notebook=True
        )

        PORT += 1

#### General Installs and device setup


In [3]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# device setup
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


### Load your model and SAE

We're going to work with a pretrained Gemma-2-2B model and the `gemma-2-2b-res-matryoshka-dc` SAE release, which is trained on the residual stream.


In [None]:
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 6
layer = 6

# get model
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)

# get the SAE for this layer
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gemma-2-2b-res-matryoshka-dc", sae_id=f"blocks.{layer}.hook_resid_post", device=device
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Get sae labeled vector


In [None]:
index = 69  # feature index noted on Neuronpedia; replace with your own
vector = sae.W_dec[index]  # decoder row for that feature (a residual-stream direction)
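
The sanity check in the notes relies on Neuronpedia's steering page, but the same idea can be tried locally by adding the vector to the residual stream at the SAE's hook point. The following is a rough sketch under stated assumptions, not the notebook's own method: `make_steering_hook` and the `coeff` scale are hypothetical names introduced here, and the scale that produces visible steering will likely need tuning.

```python
import torch


def make_steering_hook(vector: torch.Tensor, coeff: float = 4.0):
    """Build a forward hook that adds `coeff * vector` at every position.

    `vector` is assumed to have shape [d_model]; `coeff` is an illustrative
    scale, not a recommended value.
    """

    def hook_fn(resid: torch.Tensor, hook=None) -> torch.Tensor:
        # resid has shape [batch, pos, d_model]; the vector broadcasts over
        # the batch and position dimensions.
        return resid + coeff * vector.to(device=resid.device, dtype=resid.dtype)

    return hook_fn
```

With the objects defined above, this could be used as `model.run_with_hooks(tokens, fwd_hooks=[(hook_point, make_steering_hook(vector))])`, where `tokens` is whatever tokenized prompt you want to steer on.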