# Dev Code for U-Maps for any given SAE

In [1]:
from umap import UMAP
import hdbscan
import pandas as pd 
import torch 
from sae_lens import SAE
from transformer_lens import HookedTransformer
torch.set_grad_enabled(False)
import plotly.express as px
import os 
from tqdm import tqdm 
import gc
import json


def get_neuronpedia_umap_and_clusters(
    release_id: str = "res-jb",
    sae_id: str = "blocks.11.hook_resid_post",
    n_neighbors_visual: int = 15,
    min_dist_visual: float = 0.05,
    n_neighbors_cluster: float = 15,
    min_dist_cluster: float = 0.1,
    min_cluster_size: int = 3,
    plot: bool = True
):
    '''
    This function will generate UMAP and cluster plots for the SAE specified by the release_id and sae_id.
    The plots will be saved in the output folder.
    
    Args:
    release_id: str
        The release id of the SAE to be visualized.
    sae_id: str
        The sae_id of the SAE to be visualized.
    n_neighbors_visual: int
        The number of neighbors to consider for the UMAP embedding for the visual plot.
    min_dist_visual: float
        The minimum distance between points in the UMAP embedding for the visual plot.
    n_neighbors_cluster: int    
        The number of neighbors to consider for the UMAP embedding for the cluster plot.
    min_dist_cluster: float
        The minimum distance between points in the UMAP embedding for the cluster plot.
    min_cluster_size: int
        The minimum number of points in a cluster.
    plot: bool
        Whether to plot the UMAP embeddings or not.
    '''
    # make output folder:
    os.makedirs(f"output/{release_id}/{sae_id}", exist_ok=True)
    
    sae, _, sparsity = SAE.from_pretrained(
        release = release_id,
        sae_id = sae_id,
        device = "mps",
    )
    sae.fold_W_dec_norm()
    embedding = sae.W_dec.cpu()

    if sparsity is None: 
        sparsity = torch.zeros_like(embedding[:,0])
        
    feature_df = pd.DataFrame(sparsity.cpu(), index = [f"feature_{i}" for i in range(embedding.shape[0])], columns=["sparsity"])
    
    if plot:
        # Assume in t-lens for now
        model = HookedTransformer.from_pretrained_no_processing(sae.cfg.model_name, fold_ln=True)
        W_U = model.W_U
        tokenizer = model.tokenizer
        del model 
        gc.collect()

        if not os.path.exists(f"output/{release_id}/{sae_id}/feature_df.csv"):

            # Get total number of rows in W_dec
            total_rows = sae.W_dec.shape[0]

            results = []

            batch_size = sae.W_dec.shape[0] // 4
            # Process in batches
            for i in range(0, total_rows, batch_size):
                # Calculate end index for current batch
                end_idx = min(i + batch_size, total_rows)
                
                # Process a batch
                batch_result = sae.W_dec[i:end_idx] @ W_U
                
                # Get top k values and indices for the batch
                batch_vals, batch_indices = torch.topk(batch_result, 10)
                
                # Store results
                results.append((batch_vals, batch_indices))

            # Combine results
            token_factors_inds = torch.cat([r[1] for r in results])

            feature_df["tok_token_ids"] = token_factors_inds.tolist()
            feature_df["top_token_strs"] = feature_df["tok_token_ids"].apply(lambda x: [tokenizer.decode([i]) for i in x]) # type: ignore
            feature_df["top_token_strs_formatted"] = feature_df["top_token_strs"].apply(lambda x: ",".join(x))
        else:
            feature_df = pd.read_csv(f"output/{release_id}/{sae_id}/feature_df.csv", index_col=0)
            feature_df["tok_token_ids"] = feature_df["tok_token_ids"].apply(lambda x: eval(x))
            feature_df["top_token_strs"] = feature_df["top_token_strs"].apply(lambda x: eval(x))
            feature_df["top_token_strs_formatted"] = feature_df["top_token_strs"].apply(lambda x: ",".join(x))


    # 2. Visual UMAP
    print("Calculating 2D UMAP")
    visual_umap = UMAP(n_components=2, n_neighbors=n_neighbors_visual, min_dist=min_dist_visual, metric='cosine')
    embedding = sae.W_dec.cpu()
    visual_umap_embedding = visual_umap.fit_transform(embedding)

    feature_df["umap_x"] = visual_umap_embedding[:,0] # type: ignore
    feature_df["umap_y"] = visual_umap_embedding[:,1] # type: ignore

    # 3: Cluster UMAP
    print("Calculating 10D UMAP")
    clustering_umap = UMAP(n_components=10, n_neighbors=n_neighbors_cluster, min_dist=min_dist_cluster, metric='cosine')
    clustering_umap_embedding = clustering_umap.fit_transform(embedding)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    clusterer.fit(clustering_umap_embedding)

    feature_df["cluster"] = clusterer.labels_
    feature_df.sort_values("cluster", inplace=True)
    feature_df["cluster"] = feature_df["cluster"].astype(str)

    if plot:
        print("Plotting")
        fig = px.scatter(
            feature_df,
            x="umap_x",
            y="umap_y",
            color="cluster",
            height=1200,
            width =1600,
            hover_data= ["top_token_strs_formatted"] if "top_token_strs_formatted" in feature_df.columns else None,
        )
            
        # reduce point size 
        fig.update_traces(marker=dict(size=2))
        fig.write_html(f"output/{release_id}/{sae_id}/umap_.html")

    print("Saving")
    feature_df.to_csv(f"output/{release_id}/{sae_id}/feature_df.csv")
    
    # save config as well (in json)
    cfg = {
        "release_id": release_id,
        "sae_id": sae_id,
        "n_neighbors_visual": n_neighbors_visual,
        "min_dist_visual": min_dist_visual,
        "n_neighbors_cluster": n_neighbors_cluster,
        "min_dist_cluster": min_dist_cluster,
        "min_cluster_size": min_cluster_size,
    }
    with open(f"output/{release_id}/{sae_id}/umap_cfg.json", "w") as f:
        json.dump(cfg, f)
    


In [11]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

all_loadable_saes = []
saes_directory = get_pretrained_saes_directory()
for release, lookup in tqdm(saes_directory.items()):
    if release not in ["gpt2-small-res-jb", "sae-llama-3-8b-eai","gpt2-small-resid-post-v5-128k"]:
        for sae_name in lookup.saes_map.keys():
            print(f"Running {release} {sae_name}")
            get_neuronpedia_umap_and_clusters(release, sae_name, plot=True)


  0%|          | 0/10 [00:00<?, ?it/s]

Running gpt2-small-hook-z-kk blocks.0.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.1.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.2.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.3.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.4.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.5.hook_z
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-hook-z-kk blocks.6.ho

 20%|██        | 2/10 [04:42<18:50, 141.35s/it]

Saving
Running gpt2-small-mlp-tm blocks.0.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.1.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.2.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.3.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.4.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.5.hook_mlp_out
Loaded pretrained model gpt2 into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-mlp-tm blocks.6.hook_mlp_out
Loa

 30%|███       | 3/10 [09:05<22:24, 192.04s/it]

Running gpt2-small-res-jb-feature-splitting blocks.8.hook_resid_pre_768
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-res-jb-feature-splitting blocks.8.hook_resid_pre_1536
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-res-jb-feature-splitting blocks.8.hook_resid_pre_3072
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-res-jb-feature-splitting blocks.8.hook_resid_pre_6144
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-res-jb-feature-splitting blocks.8.hook_resid_pre_12288
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-res-jb-feature-splitting blocks.8.hook_re

 40%|████      | 4/10 [12:39<20:00, 200.12s/it]

Running gpt2-small-resid-post-v5-32k blocks.0.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-resid-post-v5-32k blocks.1.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-resid-post-v5-32k blocks.2.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-resid-post-v5-32k blocks.3.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-resid-post-v5-32k blocks.4.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gpt2-small-resid-post-v5-32k blocks.5.hook_resid_post
Loaded pretrained model gpt2-small into HookedTransfo

 50%|█████     | 5/10 [18:14<20:35, 247.15s/it]

Running gemma-2b-res-jb blocks.0.hook_resid_post


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gemma-2b-res-jb blocks.6.hook_resid_post


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting
Saving
Running gemma-2b-res-jb blocks.12.hook_resid_post


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting


 70%|███████   | 7/10 [23:52<10:24, 208.02s/it]

Saving
Running gemma-2b-it-res-jb blocks.12.hook_resid_post


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b-it into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP
Plotting


 80%|████████  | 8/10 [25:56<06:12, 186.17s/it]

Saving
Running mistral-7b-res-wg blocks.8.hook_resid_pre


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

Loaded pretrained model mistral-7b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Plotting
Saving
Running mistral-7b-res-wg blocks.16.hook_resid_pre


sae_weights.safetensors:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model mistral-7b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Plotting
Saving
Running mistral-7b-res-wg blocks.24.hook_resid_pre


sae_weights.safetensors:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model mistral-7b into HookedTransformer
Calculating 2D UMAP
Calculating 10D UMAP


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Plotting
Saving


100%|██████████| 10/10 [48:12<00:00, 289.28s/it]
100%|██████████| 10/10 [48:12<00:00, 289.28s/it]
