In [1]:
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from tabulate import tabulate
import numpy
import random
from pprint import pprint
from einops import einsum

from typing import Any, Callable, Literal, TypeAlias
from jaxtyping import Float, Int

from IPython.display import HTML, IFrame, display
from datasets import load_dataset
from huggingface_hub import hf_hub_download

import circuitsvis as cv
import sae_lens
from sae_lens import SAE, ActivationsStore, HookedSAETransformer, LanguageModelSAERunnerConfig
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

from test_support import show_token_scores, show_topk_preds

In [2]:
# Turn off autograd globally for the rest of the kernel session; none of the
# visible cells train or backpropagate. As the last expression of the cell,
# the returned object's repr is what shows up as the cell output.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x766ddb796e90>

In [3]:
# # SHOW AVAILABLE MODELS


# metadata_rows = [
#     [data.model, len(data.saes_map)]           # [data.model, data.release, data.repo_id, len(data.saes_map)]
#     for data in get_pretrained_saes_directory().values()
# ]

# print(
#     tabulate(
#         sorted(metadata_rows, key=lambda x: x[0]),
#         headers=["model", "n_saes"],
#         tablefmt="simple_outline",
#     )
# )

In [4]:
# # SHOW SAE MODEL INFO

# def format_value(value):
#     return (
#         "{{{0!r}: {1!r}, ...}}".format(*next(iter(value.items())))
#         if isinstance(value, dict)
#         else repr(value)
#     )


# release = get_pretrained_saes_directory()["gpt2-small-res-jb"]

# print(
#     tabulate(
#         [[k, format_value(v)] for k, v in release.__dict__.items()],
#         headers=["Field", "Value"],
#         tablefmt="simple_outline",
#     )
# )

In [5]:
# Load the pretrained "gpt2-small" language model.
gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small")

# Load the pretrained SAE for hook point "blocks.7.hook_resid_pre" from the
# "gpt2-small-res-jb" release. The loader returns three values: the SAE itself,
# its config as a dict, and its sparsity data.
gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
)

# pprint(cfg_dict)
# Show the "d_sae" config entry (used further down as the latent-index bound).
print(cfg_dict["d_sae"])

Loaded pretrained model gpt2-small into HookedTransformer
24576


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [6]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    """Print the Neuronpedia embed URL for one SAE latent and show it inline.

    Args:
        sae_release: Key into the mapping returned by
            ``get_pretrained_saes_directory()``.
        sae_id: Key into that release's ``neuronpedia_id`` mapping.
        latent_idx: Latent index appended to the URL path.
        width: Width passed to the ``IFrame``.
        height: Height passed to the ``IFrame``. Note that the ``height=300``
            query parameter inside the URL is fixed and independent of this
            argument.

    Returns:
        None. The URL is printed and the IFrame is rendered via ``display``.
    """
    # Resolve the Neuronpedia path segment registered for this release / SAE id.
    neuronpedia_path = get_pretrained_saes_directory()[sae_release].neuronpedia_id[sae_id]

    dashboard_url = f"https://neuronpedia.org/{neuronpedia_path}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    # Print first so the link is still available if the embed does not render.
    print(dashboard_url)
    frame = IFrame(dashboard_url, width=width, height=height)
    display(frame)


# latent_idx = 9    # not every "new" is the same: the sense of "new" that interests a neuron can differ from the one you have in mind
# latent_idx = 24111    # CNN neuron. Highly sparse latent, so it is very interpretable
# latent_idx = 13    # a mild concept-level feature [activation density < 0.05]. Activates on words like victory, win, winning, and similar
# latent_idx = 67    # activates on countries' decisions, government policies, and the like. High-concept feature [high activation density].

# Pick a random latent to inspect. random.randrange excludes its upper bound,
# so the result is always in 0 .. d_sae - 1. The previous
# random.randint(0, d_sae) is inclusive on both ends and could return d_sae
# itself, one past the last valid latent index.
latent_idx = random.randrange(gpt2_sae.cfg.d_sae)
# display_dashboard(latent_idx=0)

In [7]:
# RUNNING SAEs

# Prompt fed to the model, and the continuation we expect it to predict next.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"  # in the visible cells, only the commented-out test_prompt call uses this

# test_prompt(prompt, answer, gpt2)
# NOTE(review): 101 is a bare positional argument to a helper from test_support
# whose signature is not visible here; presumably a token id. Confirm it is the
# token intended here (as opposed to the token for `answer`).
show_token_scores(gpt2, prompt, 101)
# Presumably prints the 10 most likely next tokens for `prompt`; verify
# against test_support.show_topk_preds.
show_topk_preds(gpt2, prompt, 10)

TOKEN: |ï¿½| RANK: 44527, PROB: 1.0305602725357854e-11
PROB: 69.55%  TOKEN: | priority|
PROB: 9.21%  TOKEN: | effort|
PROB: 5.60%  TOKEN: | issue|
PROB: 4.13%  TOKEN: | challenge|
PROB: 3.17%  TOKEN: | goal|
PROB: 2.33%  TOKEN: | concern|
PROB: 1.93%  TOKEN: | focus|
PROB: 1.48%  TOKEN: | approach|
PROB: 1.37%  TOKEN: | policy|
PROB: 1.22%  TOKEN: | initiative|
