In [1]:
# Load the autoreload extension and set it to mode 2 so imported modules are
# reloaded before each cell executes (edits to local .py files take effect
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

# needed for set_determinism
# NOTE(review): this looks like the cuBLAS workspace setting PyTorch requires
# for deterministic algorithms on CUDA; nothing visible in this section turns
# deterministic algorithms on — confirm the variable is still needed.
%set_env CUBLAS_WORKSPACE_CONFIG=:16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


In [2]:
! huggingface-cli login --token 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

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /home/paperspace/.cache/huggingface/token
Login successful


In [3]:
import os
import gc
import torch
import pandas as pd
from tqdm import tqdm
import requests
import plotly.express as px
from datasets import Dataset, load_dataset
from typing import cast
import torch.nn.functional as F
import numpy as np
import random

from sae_lens import SAE
from transformer_lens import HookedTransformer

In [4]:
def clean_cache():
    """Free unreachable Python objects, then release cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed *before* ``torch.cuda.empty_cache()`` is
    asked to release unused cached blocks. In the opposite order those blocks
    are still occupied when the cache is emptied and stay reserved.

    Returns:
        None.
    """
    gc.collect()
    torch.cuda.empty_cache()


def load_pretokenized_dataset(
    path: str,
    split: str,
) -> Dataset:
    """Load one split of a dataset and return it in torch format.

    Args:
        path: Dataset identifier forwarded to ``load_dataset``.
        split: Name of the split to load (forwarded as ``split=``).

    Returns:
        The loaded dataset after ``with_format("torch")`` has been applied.
    """
    # `cast` only narrows the static type; it does nothing at runtime.
    loaded = cast(Dataset, load_dataset(path, split=split))
    torch_formatted = loaded.with_format("torch")
    return torch_formatted


def set_seed(seed):
    """Seed the torch, numpy and ``random`` global generators with ``seed``.

    Args:
        seed: Value handed unchanged to each library's seeding function.

    Returns:
        None.
    """
    # Same order as before: torch first, then numpy, then the stdlib RNG.
    seeders = (torch.manual_seed, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)


def get_device_str() -> str:
    """Return the name of the preferred available torch device.

    Availability is probed in priority order — MPS, then CUDA — and the first
    backend that reports itself available wins.

    Returns:
        ``"mps"``, ``"cuda"`` or, when neither is available, ``"cpu"``.
    """
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"


def download_sae_feature_explanations(
    model_id: str = "gemma-2-2b",
    sae_id: str = "20-gemmascope-res-16k",
    timeout: float = 120.0,
):
    """Download SAE feature explanations from Neuronpedia as a DataFrame.

    Posts ``{"modelId": model_id, "saeId": sae_id}`` to the Neuronpedia
    explanation-export endpoint and tabulates the ``"explanations"`` list of
    the JSON reply.

    Args:
        model_id: Value sent as ``modelId`` (e.g. ``"gpt2-small"``).
        sae_id: Value sent as ``saeId`` (e.g. ``"7-res-jb"``).
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        A DataFrame built from the returned explanations, with the ``index``
        column renamed to ``feature`` (values are left as returned, not cast)
        and ``description`` lower-cased. Missing descriptions stay missing.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the reply has no ``"explanations"`` entry.
    """
    url = "https://www.neuronpedia.org/api/explanation/export"
    payload = {"modelId": model_id, "saeId": sae_id}
    headers = {"Content-Type": "application/json"}

    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    # Surface HTTP failures here rather than as an opaque JSON/KeyError below.
    response.raise_for_status()

    # Non-mutating rename: "index" -> "feature".
    explanations_df = pd.DataFrame(response.json()["explanations"]).rename(
        columns={"index": "feature"}
    )

    # An empty export yields a frame with no columns; nothing to normalise then.
    if "description" in explanations_df.columns:
        # .str.lower() propagates missing values instead of raising on None.
        explanations_df["description"] = explanations_df["description"].str.lower()
    return explanations_df

In [5]:
# Seed the torch, numpy and stdlib random generators with a fixed value.
set_seed(2462462)

# Choose the compute device once; printing it records the choice in the cell output.
device = get_device_str()
print(device)

# Batch size passed to the DataLoader built in the next cell.
batch_size = 16

cuda


In [6]:
# Load the "train" split of the pre-tokenized dataset; load_pretokenized_dataset
# returns it in torch format.
# NOTE(review): the dataset name indicates GPT-2 tokenization, while the SAE
# loaded later in this notebook comes from a Gemma release — confirm these
# token ids are never fed directly to a Gemma model.
dataset = load_pretokenized_dataset(
    path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train"
)
# Shuffled batches of `batch_size` items.
# NOTE(review): shuffle order presumably follows torch's global RNG seeded by
# set_seed — confirm if exact batch order matters.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/75 [00:00<?, ?it/s]

In [7]:
# model = HookedTransformer.from_pretrained("gemma-2-2b", device = device)
# logits, activations = model.run_with_cache("Hello World")

In [11]:
# Load a pretrained SAE onto `device`; the call's result is unpacked into the
# SAE object, a config dict and a sparsity value.
# NOTE(review): this loads the layer_1 SAE, whereas
# download_sae_feature_explanations defaults to saeId "20-gemmascope-res-16k"
# (looks like a different layer) — confirm the two are meant to differ.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",  # <- Release name
    sae_id="layer_1/width_16k/average_l0_102",  # <- SAE id (not always a hook point!)
    device=device,
)

---- {'layer_11/width_16k/average_l0_79': 'layer_11/width_16k/average_l0_79', 'layer_1/width_16k/average_l0_102': 'layer_1/width_16k/average_l0_102', 'layer_12/width_65k/average_l0_72': 'layer_12/width_65k/average_l0_72', 'layer_12/width_16k/average_l0_82': 'layer_12/width_16k/average_l0_82', 'layer_10/width_16k/average_l0_77': 'layer_10/width_16k/average_l0_77', 'layer_0/width_16k/average_l0_106': 'layer_0/width_16k/average_l0_106', 'layer_13/width_16k/average_l0_83': 'layer_13/width_16k/average_l0_83', 'layer_13/width_65k/average_l0_74': 'layer_13/width_65k/average_l0_74', 'layer_18/width_16k/average_l0_74': 'layer_18/width_16k/average_l0_74', 'layer_18/width_65k/average_l0_117': 'layer_18/width_65k/average_l0_117', 'layer_19/width_16k/average_l0_73': 'layer_19/width_16k/average_l0_73', 'layer_19/width_65k/average_l0_115': 'layer_19/width_65k/average_l0_115', 'layer_14/width_16k/average_l0_83': 'layer_14/width_16k/average_l0_83', 'layer_2/width_16k/average_l0_141': 'layer_2/width_16k

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]