In [1]:
from sae_lens import SAE
import torch
from datasets import load_dataset
from transformer_lens import HookedTransformer
import os
import plotly.express as px
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.7.hook_resid_pre",
    device=DEVICE,
)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
decoder_matrix = sae.W_dec
print(decoder_matrix.shape[0], "dictionary atoms")
print(decoder_matrix.shape[1], "mlp neurons")

24576 dictionary atoms
768 mlp neurons


In [4]:
#delete model and sae to save memory
del sae

In [5]:
import torch.nn.functional as F
import pickle
from tqdm import tqdm

def graph_cluster_sims(
        all_sims, top_k, sim_cutoff=0.5, prune_clusters=False
):
    """
    Create a graph from similarity scores, keeping only the top_k neighbors
    """
    near_neighbors = torch.topk(all_sims, k=top_k, dim=1)

    graph = [[] for _ in range(all_sims.shape[0])]

    for i in range(all_sims.shape[0]):
        top_indices = near_neighbors.indices[i]
        top_sims = near_neighbors.values[i]
        top_indices = top_indices[top_sims > sim_cutoff]
        graph[i] = top_indices.tolist()

    for i in tqdm(range(all_sims.shape[0])):
        for j in graph[i]:
            if i not in graph[j]:
                graph[j].append(i)

    visited = [False] * all_sims.shape[0]
    components = []
    for i in range(all_sims.shape[0]):
        if visited[i]:
            continue
        component = []
        stack = [i]
        while stack:
            node = stack.pop()
            if visited[node]:
                continue
            visited[node] = True
            component.append(node)
            stack.extend(graph[node])
        components.append(component)

    if prune_clusters:
        threshold = 1000
        components = [c for c in components if len(c) < threshold and len(c) > 1]

    with open(
        f"clusters_{top_k}_sim_cutoff_{sim_cutoff}.pkl", "wb"
    ) as f:
        pickle.dump(components, f)
    print(f"Saved clusters to clusters_{top_k}_sim_cutoff_{sim_cutoff}.pkl")

similarity_matrix = decoder_matrix @ decoder_matrix.T
similarity_matrix.fill_diagonal_(0)

model_name = "gpt2-small-res-jb"
layer = 7

graph_cluster_sims(
    all_sims=torch.tensor(similarity_matrix),
    top_k=2,
    sim_cutoff=0.5,
    prune_clusters=True
)

  all_sims=torch.tensor(similarity_matrix),
100%|██████████| 24576/24576 [00:00<00:00, 2657365.69it/s]

Saved clusters to clusters_2_sim_cutoff_0.5.pkl



