In [41]:
import re
import pandas as pd


from huggingface_hub import HfApi
import os


def list_repo_files(repo_id):
    api = HfApi()
    repo_files = api.list_repo_files(repo_id)
    return repo_files


files = list_repo_files(repo_id)

# print(f"Files in the repository '{repo_id}':")
# for file in files:
#     print(file)


def get_details_from_file_path(file_path):
    """
    eg: layer_11/width_16k/average_l0_79

    layer = 11
    width = 16k
    l0_or_canonical = "79"

    or if layer_11/width_16k/canonical

    layer = 11
    width = 16k
    l0_or_canonical = "canonical"

    or if layer_11/width_1m/average_l0_79

    layer = 11
    width = 1m
    l0_or_canonical = "79"
    """

    layer = re.search(r"layer_(\d+)", file_path).group(1)
    width = re.search(r"width_(\d+[k|m])", file_path).group(1)
    l0 = re.search(r"average_l0_(\d+)", file_path)
    if l0:
        l0 = l0.group(1)
    else:
        l0 = re.search(r"(canonical)", file_path).group(1)

    return layer, width, l0


# # test it
# file_path = 'layer_11/width_16k/average_l0_79'
# layer, width, l0 = get_details_from_file_path(file_path)
# print(f"layer: {layer}, width: {width}, l0: {l0}")


# file_path = 'layer_11/width_16k/canonical'
# layer, width, l0 = get_details_from_file_path(file_path)
# print(f"layer: {layer}, width: {width}, l0: {l0}")


# file_path = 'layer_11/width_1m/canonical'
# layer, width, l0 = get_details_from_file_path(file_path)
# print(f"layer: {layer}, width: {width}, l0: {l0}")


def generate_entries(repo_id):
    entries = []
    files = list_repo_files(repo_id)
    for file in files:
        if "params.npz" in file:
            entry = {}
            # print(file)
            layer, width, l0 = get_details_from_file_path(file)
            folder_path = os.path.dirname(file)
            entry["repo_id"] = repo_id
            entry["id"] = folder_path
            entry["path"] = folder_path
            entry["l0"] = l0
            entry["layer"] = layer
            entry["width"] = width

            entries.append(entry)
    return entries


def df_to_yaml(df, file_path, canonical=False):
    """
    EXAMPLE STRUCTURE:

    gemma-scope-2b-pt-res:
    repo_id: google/gemma-scope-2b-pt-res
    model: gemma-2-2b
    conversion_func: gemma_2
    saes:
        - id: layer_11/width_16k/average_l0_79
        path: layer_11/width_16k/average_l0_79
        l0: 79.0

    """
    repo_id = df.iloc[0]["repo_id"]
    release_id = (
        repo_id.split("/")[1] + "-canonical" if canonical else repo_id.split("/")[1]
    )
    with open(file_path, "w") as f:

        f.write(f"{release_id}:\n")
        f.write(f"  repo_id: {repo_id}\n")
        f.write(f"  model: gemma-2-2b\n")
        f.write(f"  conversion_func: gemma_2\n")
        f.write(f"  saes:\n")
        for index, row in df.iterrows():
            f.write(f"    - id: {row['id']}\n")
            f.write(f"      path: {row['path']}\n")
            if row["l0"] != "canonical":
                f.write(f"      l0: {row['l0']}\n")
            # f.write(f"      l0: {row['l0']}\n")
            # f.write(f"      layer: {row['layer']}\n")
            # f.write(f"      width: {row['width']}\n")
            f.write("\n")


repo_ids = [
    "google/gemma-scope-2b-pt-res",
    "google/gemma-scope-2b-pt-mlp",
    "google/gemma-scope-2b-pt-att",
    "google/gemma-scope-9b-pt-res",
    "google/gemma-scope-9b-pt-mlp",
    "google/gemma-scope-9b-pt-att",
    "google/gemma-scope-27b-pt-res",
]

for repo_id in repo_ids:

    entries = generate_entries(repo_id)

    df = pd.DataFrame(entries)
    df["layer"] = pd.to_numeric(df["layer"])
    df.sort_values(by=["width", "layer", "l0"], inplace=True)
    df.head(30)

    canonical_only_df = df[df["l0"] == "canonical"]
    non_canonical_df = df[df["l0"] != "canonical"]

    df_to_yaml(
        non_canonical_df, f'{repo_id.split("/")[1]}_not_canonical.yaml', canonical=False
    )
    if canonical_only_df.shape[0] == 0:
        print(f"No canonical entries found in {repo_id.split('/')[1]}")
        continue
    else:
        df_to_yaml(
            canonical_only_df,
            f'{repo_id.split("/")[1]}_canonical_only.yaml',
            canonical=True,
        )

    # !cat canonical_only.yaml

No canonical entries found in gemma-scope-2b-pt-att
No canonical entries found in gemma-scope-9b-pt-res
No canonical entries found in gemma-scope-9b-pt-mlp
No canonical entries found in gemma-scope-9b-pt-att
No canonical entries found in gemma-scope-27b-pt-res


In [5]:
import os
from huggingface_hub import snapshot_download
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from pathlib import Path


local_dir = snapshot_download(repo_id, use_auth_token=True)

print(f"Repository downloaded to: {local_dir}")

# Function to generate entries for the YAML file


# Path to the YAML file
yaml_file = "pretrained_saes.yaml"

# Initialize yamel.yaml
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

# Read the existing YAML file
with open(yaml_file, "r") as file:
    data = yaml.load(file)

# Generate new entries
new_entries = generate_entries(local_dir)

# Create a CommentedMap for gemmascope-2b-pt-res
gemmascope_data = CommentedMap()
gemmascope_data["repo_id"] = "gg-hf/gemmascope-2b-pt-res"
gemmascope_data["model"] = "gemma-2-2b"
gemmascope_data["conversion_func"] = "gemma_2"
gemmascope_data["saes"] = new_entries

# Remove the existing gemmascope-2b-pt-res entry if it exists
if "SAE_LOOKUP" in data and "gemmascope-2b-pt-res" in data["SAE_LOOKUP"]:
    del data["SAE_LOOKUP"]["gemmascope-2b-pt-res"]

# Add gemmascope-2b-pt-res at the end
data["SAE_LOOKUP"]["gemmascope-2b-pt-res"] = gemmascope_data

# Write the updated YAML file
with open(yaml_file, "w") as file:
    yaml.dump(data, file)

print(f"YAML file updated: {yaml_file}")

Fetching 369 files:   2%|▏         | 9/369 [00:05<03:54,  1.53it/s]


KeyboardInterrupt: 

In [3]:
from sae_lens import HookedSAETransformer, SAE

device = "cuda"

model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-9b-pt-mlp",  # <- Release name
    sae_id="layer_2/width_131k/average_l0_12",  # <- SAE id (not always a hook point!)
    device=device,
)

Downloading shards: 100%|██████████| 3/3 [00:33<00:00, 11.03s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.95it/s]


Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
cache = model.run_with_cache("test")[1]
sae_in = cache[sae.cfg.hook_name]

In [9]:
sae.hook_z_reshaping_mode

True

In [6]:
sae_in.shape

torch.Size([1, 2, 8, 256])

In [8]:
sae(sae_in).shape

torch.Size([1, 2, 8, 256])