In [3]:
from nnsight import LanguageModel

import gc
import itertools
import math
import os
import random
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

device = "cuda" if t.cuda.is_available() else "mps" if t.backends.mps.is_available() else "cpu"

In [19]:
## check memory usage

if t.cuda.is_available():
    gpu_id = 0  # Set to your target GPU ID
    total_memory = t.cuda.get_device_properties(gpu_id).total_memory
    allocated_memory = t.cuda.memory_allocated(gpu_id)
    cached_memory = t.cuda.memory_reserved(gpu_id)

    print(f"Total GPU Memory: {total_memory / 1024**2:.2f} MB")
    print(f"Allocated GPU Memory: {allocated_memory / 1024**2:.2f} MB")
    print(f"Cached GPU Memory: {cached_memory / 1024**2:.2f} MB")
elif t.backends.mps.is_available():
    # MPS (Metal Performance Shaders) for Mac
    print("MPS is available.")
    # Note: As of now, PyTorch doesn't provide direct memory management functions for MPS
    print("Memory information is not available for MPS.")
else:
    print("Neither CUDA nor MPS is available.")

Total GPU Memory: 45541.31 MB
Allocated GPU Memory: 15740.26 MB
Cached GPU Memory: 15930.00 MB


In [23]:
# del gemma2
# del gemma2_sae

t.cuda.empty_cache()


# Output tensor of SAE activations for advbench

In [None]:
import json

# Read from advbench.json file
with open('../dataset/processed/advbench.json', 'r') as file:
    advbench_data = json.load(file)

len(advbench_data)

In [19]:
del stacked_sae_acts_post
del gemma2_sae
t.cuda.empty_cache()


In [24]:
t.set_grad_enabled(False)

sae_name = "gemma-scope-2b-pt-res-canonical"
sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in range(11)]  # 0 through 10


gemma2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gemma-2-2b-it", device=device)

import torch
import os

for sae_id in sae_ids:
    gemma2_sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_name,
        sae_id=sae_id,
        device=str(device),
    )

    all_sae_acts_post = []

    for item in advbench_data:
        prompt = item['instruction']
        
        # Get top activations on final token
        _, cache = gemma2.run_with_cache_with_saes(
            prompt,
            saes=[gemma2_sae],
            stop_at_layer=gemma2_sae.cfg.hook_layer + 1,
        )
        sae_acts_post = cache[f"{gemma2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]
        all_sae_acts_post.append(sae_acts_post)

    # Stack all sae_acts_post tensors
    stacked_sae_acts_post = torch.stack(all_sae_acts_post)

    print(f"SAE ID: {sae_id}")
    print(f"Shape of stacked tensor: {stacked_sae_acts_post.shape}")
    print(f"Total number of non-zero activations: {(stacked_sae_acts_post != 0).sum().item()}")

    # Create directory if it doesn't exist
    os.makedirs(f'../data/sae_acts/{sae_name}/{sae_id}', exist_ok=True)

    # Save the stacked_sae_acts_post tensor
    torch.save(stacked_sae_acts_post, f'../data/sae_acts/{sae_name}/{sae_id}_advbench.pt')

    # Print confirmation message
    print(f"Stacked SAE activations saved to '../data/sae_acts/{sae_name}/{sae_id}_advbench.pt'")
    print("---")

    del stacked_sae_acts_post
    del gemma2_sae
    t.cuda.empty_cache()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# get average activations for each sae

In [4]:
t.set_grad_enabled(False)

gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this:
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)

Downloading builder script: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.73k/2.73k [00:00<00:00, 87.8kB/s]
Downloading readme: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.35k/7.35k [00:00<00:00, 180kB/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


In [7]:
def get_frac_active(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    # latent_idx: int,
    total_batches: int = 400,
):
    """
    Displays the activation histogram for a particular latent, computed across `total_batches` batches from `act_store`.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    all_positive_acts = []

    all_positive_acts = t.zeros(sae.cfg.d_sae, device=device)
    total_acts = 0

    for i in tqdm(range(total_batches)):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name]
        all_positive_acts += (acts > 0).sum(dim=(0, 1))
        total_acts += acts.shape[0] * acts.shape[1]

    frac_active = all_positive_acts / total_acts

    return frac_active


frac_active = get_frac_active(gpt2, gpt2_sae, gpt2_act_store)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:25<00:00, 15.64it/s]


In [9]:
frac_active[3731]

tensor(0.1123, device='cuda:0')

In [None]:
t.set_grad_enabled(False)

sae_name = "gemma-scope-2b-pt-res-canonical"
sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in range(11)]  # 0 through 10

gemma2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gemma-2-2b-it", device=device)

In [None]:
for sae_id in sae_ids:
    gemma2_sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_name,
        sae_id=sae_id,
        device=str(device),
    )

    gemma2_act_store = ActivationsStore.from_sae(
        model=gemma2,
        sae=gemma2_sae,
        streaming=True,
        store_batch_size_prompts=16,
        n_batches_in_buffer=32,
        device=str(device),
    )

    frac_active = get_frac_active(gemma2, gemma2_sae, gemma2_act_store)

    # Create directory if it doesn't exist
    os.makedirs(f'../data/sae_acts/{sae_name}/{sae_id}', exist_ok=True)

    # Save the stacked_sae_acts_post tensor
    t.save(frac_active, f'../data/sae_acts/{sae_name}/{sae_id}/frac_active.pt')
    print(f"SAE ID: {sae_id}")
    print(f"Fraction of activations: {frac_active}")

    del gemma2_sae
    t.cuda.empty_cache()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [03:40<00:00,  1.82it/s]


SAE ID: layer_0/width_16k/canonical
Fraction of activations: tensor([0.0323, 0.0013, 0.0100,  ..., 0.0024, 0.0030, 0.0010], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [05:25<00:00,  1.23it/s]


SAE ID: layer_1/width_16k/canonical
Fraction of activations: tensor([0.0011, 0.0016, 0.0276,  ..., 0.0004, 0.0009, 0.0008], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [07:08<00:00,  1.07s/it]


SAE ID: layer_2/width_16k/canonical
Fraction of activations: tensor([0.0012, 0.0075, 0.0009,  ..., 0.0008, 0.0064, 0.0012], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [08:50<00:00,  1.33s/it]


SAE ID: layer_3/width_16k/canonical
Fraction of activations: tensor([0.0032, 0.0028, 0.0156,  ..., 0.0007, 0.0013, 0.0059], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [10:35<00:00,  1.59s/it]


SAE ID: layer_4/width_16k/canonical
Fraction of activations: tensor([0.0007, 0.0138, 0.0029,  ..., 0.0040, 0.0014, 0.0067], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [12:33<00:00,  1.88s/it]


SAE ID: layer_5/width_16k/canonical
Fraction of activations: tensor([1.2868e-02, 1.6076e-02, 1.8198e-03,  ..., 8.3618e-05, 1.6998e-04,
        4.5691e-03], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [14:04<00:00,  2.11s/it]


SAE ID: layer_6/width_16k/canonical
Fraction of activations: tensor([1.7288e-04, 2.4805e-03, 3.6662e-03,  ..., 9.7809e-05, 3.8969e-03,
        3.2462e-03], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [15:48<00:00,  2.37s/it]


SAE ID: layer_7/width_16k/canonical
Fraction of activations: tensor([0.0008, 0.0101, 0.0005,  ..., 0.0010, 0.0026, 0.0043], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [17:32<00:00,  2.63s/it]


SAE ID: layer_8/width_16k/canonical
Fraction of activations: tensor([0.0034, 0.0097, 0.0066,  ..., 0.0054, 0.0058, 0.0058], device='cuda:0')


 24%|███████████████████████████████████▋                                                                                                               | 97/400 [04:40<14:34,  2.89s/it]

In [18]:
del gemma2_sae
t.cuda.empty_cache()