In [1]:
%load_ext autoreload
%autoreload 2

# needed for set_determinism
%set_env CUBLAS_WORKSPACE_CONFIG=:16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


## Setup

In [2]:
! huggingface-cli login --token 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

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /home/paperspace/.cache/huggingface/token
Login successful


In [5]:
# ! pip install torch plotly-express numpy sae-lens transformer-lens pandas

In [3]:
import os
import gc
import torch
import pandas as pd
from tqdm import tqdm
import requests
import plotly.express as px
from datasets import Dataset, load_dataset
from typing import cast
import torch.nn.functional as F
import numpy as np
import random
from huggingface_hub import hf_hub_download

from sae_lens import SAE
from transformer_lens import HookedTransformer

In [4]:
def clean_cache():
    torch.cuda.empty_cache()
    gc.collect()


def load_pretokenized_dataset(
    path: str,
    split: str,
) -> Dataset:
    dataset = load_dataset(path, split=split)
    dataset = cast(Dataset, dataset)
    return dataset.with_format("torch")


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_device_str() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    else:
        return "cuda" if torch.cuda.is_available() else "cpu"


def download_sae_feature_explanations():
    url = "https://www.neuronpedia.org/api/explanation/export"

    # payload = {
    #     "modelId": "gpt2-small",
    #     "saeId": "7-res-jb",
    # }
    payload = {"modelId": "gemma-2-2b", "saeId": "20-gemmascope-res-16k"}
    headers = {"Content-Type": "application/json"}

    response = requests.post(url, json=payload, headers=headers)

    # convert to pandas
    explanations_df = pd.DataFrame(response.json()["explanations"])
    # rename index to "feature"
    explanations_df.rename(columns={"index": "feature"}, inplace=True)
    # explanations_df["feature"] = explanations_df["feature"].astype(int)
    explanations_df["description"] = explanations_df["description"].apply(lambda x: x.lower())
    return explanations_df


def get_most_likely_tokens(logits, tokenizer, top_k=5):
    logits = logits[0, -1, :]
    probs = torch.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)

    for prob, index in zip(top_k_probs, top_k_indices):
        token = tokenizer.decode([index])
        print(f"Token: '{token}', Probability: {prob.item():.4f}")

    return tokenizer.decode([top_k_indices[0]])

In [5]:
set_seed(5325)

device = get_device_str()
print(device)

batch_size = 16

cuda


## Load Gemma 2 2b model

In [6]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = HookedTransformer.from_pretrained("gemma-2-2b", device = device)
logits, activations = model.run_with_cache("Hello World")



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [7]:
get_most_likely_tokens(logits, model.tokenizer, top_k=5)

Token: '!', Probability: 0.5286
Token: ',', Probability: 0.1997
Token: '.', Probability: 0.0678
Token: '

', Probability: 0.0312
Token: '/', Probability: 0.0229


'!'

In [93]:
# The input text
prompt = "1 + 1 = "

# Pass it in to the model and generate text
logits, activations = model.run_with_cache(prompt)
# print(model.tokenizer.decode(outputs[0]))

In [94]:
get_most_likely_tokens(logits, model.tokenizer, top_k=5)

Token: '2', Probability: 0.6732
Token: '3', Probability: 0.1358
Token: '1', Probability: 0.0824
Token: '5', Probability: 0.0287
Token: '4', Probability: 0.0249


'2'

## Load a Gemma 2 SAE

In [22]:
# TODO Update Link to be standard download format
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res", # TODO on release s/gg-hf/google/
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [23]:
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}

In [24]:
import torch.nn as nn

class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs then we recommend using blah
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon


In [25]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae.to(device)

JumpReLUSAE()

## Loading the data

In [11]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model.tokenizer, # type: ignore
    streaming=True,
    max_length=128,
    add_bos_token=True,
)

## Running experiments

In [16]:
feature_explanations_df = download_sae_feature_explanations()
feature_explanations_df['feature'] = feature_explanations_df.feature.astype(int)
feature_explanations_df.head()

Unnamed: 0,modelId,layer,feature,description,scoreV1,scoreV2,autoInterpModel
0,gemma-2-2b,20-gemmascope-res-16k,1,phrases emphasizing continuity or progression ...,0,,gpt-4o-mini
1,gemma-2-2b,20-gemmascope-res-16k,2,elements related to interviews and interrogati...,0,,gpt-4o-mini
2,gemma-2-2b,20-gemmascope-res-16k,3,"instances of the verb ""to be"" in various forms",0,,gpt-4o-mini
3,gemma-2-2b,20-gemmascope-res-16k,4,"temporal references such as days, weeks, mont...",0,,gpt-4o-mini
4,gemma-2-2b,20-gemmascope-res-16k,5,words related to trends and patterns in variou...,0,,gpt-4o-mini


In [115]:
# The input text
prompt = "Would you be able to travel through time using a wormhole?"

# Pass it in to the model and generate text
logits, cache = model.run_with_cache(prompt)

In [123]:
l_prompt = len(prompt)
outputs = model.generate(prompt, max_new_tokens=50)[l_prompt + 1:]

  0%|          | 0/50 [00:00<?, ?it/s]

In [124]:
print(outputs)

We have information that this is possible and we can figure out the technology that would need to be put in place. And here's something: It's possible to build the most advanced spaceships that ever existed and it won't even be


In [83]:
# for k in cache.keys():
#     if "blocks.5." in k:
#         print(k)

In [125]:
layer = "blocks.1.hook_resid_pre"
pos = -1
activations = cache[layer][:, pos, :]

In [126]:
feature_acts = sae.encode(activations).squeeze().cpu()
active_feature_ids = feature_acts.nonzero().squeeze().cpu()
active_feature_acts = feature_acts[active_feature_ids]

In [127]:
sorted_feature_ids = active_feature_ids[torch.argsort(active_feature_acts, descending=True)].cpu().numpy().tolist()

In [129]:
for active_feature_id in sorted_feature_ids[:10]:
    feature_info = dict(feature_explanations_df[feature_explanations_df.feature == active_feature_id].iloc[0])
    feature_desc = feature_info['description']
    feature_act = feature_acts[active_feature_id]
    print(f"F={active_feature_id} [{feature_act:.2f}] {feature_desc}")

F=12054 [44.46]  phrases expressing uncertainty or opinion
F=1508 [32.36] features related to product quality and durability
F=6483 [24.83]  expressions of quantity or frequency
F=13412 [24.63] checks and comparisons related to identifying the smallest or largest values
F=14339 [15.18] mathematical notation and structures related to statistics or algorithms
F=12842 [10.08] references to personal pronouns and possessive nouns, particularly related to individuals within a narrative context
F=2153 [7.23] references to significant political events or figures
F=10934 [5.12] curly braces and other related syntactical elements in structured data formats
F=2514 [5.06] statements related to decisions and their implications
F=6308 [5.04] references to specific legislative measures or governmental actions


## Visualise features

In [26]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map='auto',
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [27]:
def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act

In [28]:
# The input text
prompt = "Would you be able to travel through time using a wormhole?"

# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")
print(inputs)

# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

tensor([[     2,  18925,    692,    614,   3326,    577,   5056,   1593,   1069,
           2177,    476,  47420,  18216, 235336]], device='cuda:0')
<bos>Would you be able to travel through time using a wormhole?

[Answer 1]

Yes, you can travel through time using a wormhole.

A wormhole is a theoretical object that connects two points in space-time. It is a tunnel through space-time that allows objects to travel from


In [29]:
target_act = gather_residual_activations(model, 20, inputs)

In [30]:
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [31]:
# Let's just double check that the model looks sensible by checking that we explain a decent chunk of the variance:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())

tensor(0.8887, device='cuda:0')

In [32]:
# This probably looks OK! This SAE is supposed to have an L0 of around 70, so let's just check that too:
(sae_acts > 1).sum(-1)

tensor([[7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,
           86,   89]], device='cuda:0')

In [33]:
# Let's look at the highest activating features on this input text, on each token position:
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,
         10004, 10004, 12935,  3442]], device='cuda:0')

In [34]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
IFrame(html, width=1200, height=600)