## Kiran Notes:
Do/keep these in mind before running all the cells
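1. Log in to your Hugging Face account, create an access token, and make it available to this notebook (needed for gated model weights such as Gemma, e.g. gemma-2-2b). Do not paste the token into a cell you will share or commit. You might also need to accept the terms and conditions on the model's Hugging Face page.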
2. Full Walkthrough:
    1. Go to https://www.neuronpedia.org/search
        1. Select the model, the SAE to use, and the layer number
    2. Find indices for the sae feature vectors you want
        1. Do this by entering a prompt such as "I like cats," and then clicking on the "cats" part of it
        2. It should show the different feature vectors corresponding to that token
        3. Pick one and note its index
    3. Go into notebook, load pretrained SAE
        1. The vector should just be sae.W_dec[vector_index]
    4. To sanity check the vector is correct
        1. Go to https://www.neuronpedia.org/gemma-2-9b-it/steer (replace `gemma-2-9b-it` in the URL with the model you are actually using)
        2. Choose "add vector" and add the vector you found at the correct layer; it should be able to steer the model's output


## Setting up packages and notebook


### Import and installs


#### Environment Setup


In [1]:
try:
    # for google colab users
    import google.colab  # type: ignore
    from google.colab import output

    COLAB = True
    %pip install sae-lens transformer-lens pandas
except:
    # for local setup
    COLAB = False
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading

# First port used by `display_vis_inline` when serving a vis in Colab; that function
# increments it after every call so each vis gets its own port.
PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Turn autograd off globally for the rest of the notebook.
# The trailing `;` suppresses the cell's output repr.
torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    """
    Display an HTML visualisation file.

    Outside Colab the file is opened with `webbrowser`. In Colab, a background HTTP server
    rooted at `/content` is started on the current value of the global `PORT`, the file is
    embedded as an iframe, and `PORT` is incremented so that each vis gets a unique port
    without the caller having to pick one.

    Args:
        filename: Path of the HTML file. In Colab it is requested as `/{filename}` from the
            server rooted at `/content`.
        height: Height of the embedded iframe (used in Colab only).
    """
    global PORT

    if not COLAB:
        webbrowser.open(filename)
        return

    from functools import partial

    # Reserve the port *before* starting the server thread. Reading the global from inside
    # the thread races with the increment below and could bind the wrong port.
    port = PORT
    PORT += 1

    def serve(directory: str, port: int) -> None:
        """Serve `directory` over HTTP on `port` until the process exits."""
        # Root the handler at `directory` instead of calling os.chdir, which would change
        # the working directory of the whole kernel from a background thread.
        handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

        # Create a socket server with the handler
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # Daemon thread: serve_forever never returns, so a non-daemon thread would keep the
    # interpreter alive at shutdown.
    thread = threading.Thread(target=serve, args=("/content", port), daemon=True)
    thread.start()

    output.serve_kernel_port_as_iframe(
        port, path=f"/{filename}", height=height, cache_in_notebook=True
    )

#### General Installs and device setup


In [2]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# Pick the compute device: Apple MPS if present, otherwise CUDA, otherwise fall back to CPU
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

print(f"Device: {device}")

Device: cuda


### Load your model and SAE

We're going to work with a pretrained Gemma-2-2B model (`gemma-2-2b`) and an SAE from the `gemma-2-2b-res-matryoshka-dc` release, which is trained on the residual stream.


In [3]:
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Choose the layer you want to focus on
# This notebook uses layer 12; the SAE id and hook point used later are derived from it
layer = 12

# Load the base model as a TransformerLens HookedTransformer on the selected device.
# NOTE(review): the recorded output of the SAE load cell recommends loading the model with
# `HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)`
# for optimal performance — consider switching once that is confirmed for this SAE release.
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
# Load the pretrained SAE attached to the residual-stream output of the chosen layer.
# The call also returns a config dict (kept as `cfg_dict`) and a third value that is discarded.
sae, cfg_dict, _ = SAE.from_pretrained(
    device=device,
    release="gemma-2-2b-res-matryoshka-dc",
    sae_id=f"blocks.{layer}.hook_resid_post",
)

# The SAE's config names the hook it reads from; remember it for indexing the activation cache
hook_point = sae.cfg.hook_name
print(hook_point)

cfg.json:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/604M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/131k [00:00<?, ?B/s]

blocks.12.hook_resid_post


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


### Get sae labeled vector


In [None]:
# Index of the SAE feature to extract (presumably found on Neuronpedia as described in the
# notes at the top). What feature 5612 represents is not recorded here — TODO add a label.
# NOTE(review): the TypeError recorded under this cell looks stale (likely from an earlier run
# with a non-integer index); a plain int index should not raise it — re-run to confirm.
index = 5612
# The feature's vector is row `index` of the SAE's W_dec matrix
vector = sae.W_dec[index]

TypeError: only integer tensors of a single element can be converted to an index

In [15]:
torch.set_printoptions(threshold=10_000)
# Render the whole feature vector as one bracketed, comma-separated line (easy to copy/paste).
# A single tolist() call converts every element to a Python float in one go.
l = list(map(str, vector.tolist()))
s = "[" + ', '.join(l) + "]"
print(s)

[0.007212216965854168, -0.004830700345337391, -0.007925352081656456, -0.0040094805881381035, 0.0017129123443737626, -0.010327393189072609, -0.0035030951257795095, -0.008199452422559261, -0.0016925865784287453, 0.011705713346600533, 0.0024807523004710674, -0.0005987276090309024, 0.0015520796878263354, 0.010408670641481876, 0.016110964119434357, 0.005073183216154575, -0.014809477142989635, 0.004627746529877186, 0.012781382538378239, 0.0009600270423106849, 0.004976531956344843, 0.00437548290938139, 0.0076545001938939095, 0.005622220225632191, -0.007180842570960522, 0.012978986836969852, 0.0091465525329113, -0.011791808530688286, 0.002551150741055608, 0.004394683055579662, -0.0021134724374860525, -0.003127015894278884, 0.002304886933416128, 0.10745819658041, -0.0028806994669139385, 0.01077902503311634, 0.0019503601361066103, 0.007336829788982868, 0.0006057784776203334, -0.0012253146851435304, 0.004230687394738197, 0.0037181172519922256, 0.003319409443065524, -0.005152077879756689, -0.01124

In [7]:
# Fit PCA to features then save/try loading
from sklearn.decomposition import PCA
import pickle

# Fit a 2-component PCA on the rows of the SAE's W_dec matrix (one row per SAE feature).
# random_state pins the result: for a matrix this large scikit-learn's default solver choice
# may be the randomized SVD, which would otherwise give slightly different axes on every run.
pca = PCA(n_components=2, random_state=0)
data = sae.W_dec
# scikit-learn needs a CPU NumPy array. detach() matters when W_dec already lives on the CPU:
# `.to("cpu")` would then return the same tensor, and converting a tensor that requires grad
# to NumPy raises. float() is a no-op for float32 weights and makes half-precision ones usable.
pca.fit(data.detach().float().cpu().numpy())

# Save PCA model
with open('pca_model.pkl', 'wb') as f:
    pickle.dump(pca, f)

# Later, to load the model:
# NOTE: pickle.load can execute arbitrary code from the file — only load pickles that you
# created yourself (as is the case here).
with open('pca_model.pkl', 'rb') as f:
    loaded_pca = pickle.load(f)

In [8]:
# Display the estimator read back from pca_model.pkl to confirm the pickle round trip worked
loaded_pca

In [32]:
from transformers import AutoTokenizer

# Word the model is asked to echo back, and the SAE feature whose activation we read for it
guess = "apple"
feature_index = 1644 # water feature

# Build the prompt and tokenize it with the Hugging Face tokenizer for this model
prompt = f"Repeat exactly: {guess}"
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tokens = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = tokens["input_ids"]

# Forward pass that also records intermediate activations
logits, cache = model.run_with_cache(input_ids)

# Activation at the SAE's hook point for batch item 0, last token position.
# (The variable name mentions layer 12, but the layer actually read is whatever `hook_point` names.)
layer_12_activations = cache[hook_point][0, -1]

# Encode that single activation vector with the SAE and pick out the chosen feature
sae_activations = sae.encode(layer_12_activations)
feature_activation = sae_activations[feature_index]

# scikit-learn expects a 2-D (n_samples, n_features) input on the CPU: add a leading sample axis
pre_pca_activations = layer_12_activations.unsqueeze(0).cpu()

# Project the activation into the 2-D PCA space and keep the coordinates as a plain list
projection = pca.transform(pre_pca_activations)[0].tolist()


In [10]:
# Inspect the SAE encoding. Only the last expression (the shape) is displayed by the notebook;
# the bare `sae_activations` line above it produces no output.
sae_activations
sae_activations.shape

torch.Size([32768])

In [26]:
# Activation of SAE feature `feature_index` at the prompt's last token
feature_activation

tensor(9.0465, device='cuda:0')

In [27]:
# 2-D PCA coordinates of the last-token activation from the most recent run of the projection cell.
# NOTE(review): this display cell appears several times with different recorded outputs,
# presumably from re-running with different `guess` values — consider recording the guess with each result.
projection

[10.271209428924237, 17.199737734641086]

In [29]:
# 2-D PCA coordinates of the last-token activation from the most recent run of the projection cell
# (presumably a different `guess` than the other recorded outputs — TODO confirm and label)
projection

[6.6524617584048915, 16.493429598970916]

In [31]:
# 2-D PCA coordinates of the last-token activation from the most recent run of the projection cell
# (presumably a different `guess` than the other recorded outputs — TODO confirm and label)
projection

[16.855471362340545, 9.650907503185982]

In [34]:
# 2-D PCA coordinates of the last-token activation from the most recent run of the projection cell
# (presumably a different `guess` than the other recorded outputs — TODO confirm and label)
projection

[14.213539485490838, 5.578465499653612]