# Demo Notebook

Steps:
1. Download an SAE with SAE Lens.
2. Create a token dataset consistent with that SAE.
3. Fold the SAE decoder norms into the SAE weights so that feature activations are "correct" (done in the Step 1 cell via `sae.fold_W_dec_norm()`).
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights (not performed in this notebook).
5. Run the feature-dashboard generator (`SaeVisRunner`) for the features you want.

# Set Up

In [1]:
import torch
from sae_lens import SAE 
from transformer_lens import HookedTransformer
from sae_vis.sae_vis_data import SaeVisConfig
from sae_vis.sae_vis_runner import SaeVisRunner

## Step 1. Download / Initialize SAE

In [2]:


# Imports are kept close to where they are used (where practical) so it stays
# obvious which library each function or class comes from.

# Device preference order: Apple-silicon MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

model = HookedTransformer.from_pretrained("gemma-2b", device=device)

# SAE.from_pretrained hands back three things:
#   * sae      - the SAE itself;
#   * cfg_dict - whatever config was stored in the HF repo. This is NOT the SAE's
#                own config dict (that is extracted from it), but it can be useful
#                when analysing the SAE, e.g. to instantiate an activation store;
#   * sparsity - the feature sparsities stored on HF, returned for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-2b-res-jb",          # other options: sae_lens/pretrained_saes.yaml
    sae_id="blocks.0.hook_resid_post",  # an sae_id won't always be a hook point
    device=device,
)

# Fold the W_dec norms into the SAE weights so feature activations are accurate.
sae.fold_W_dec_norm()

Device: mps


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


sparsity.safetensors:   0%|          | 0.00/65.6k [00:00<?, ?B/s]

## Step 2. Get a token dataset

In [3]:
from sae_lens import ActivationsStore

# Build an ActivationsStore configured from the SAE (via `from_sae`); it is used
# below only as a source of tokenized prompt batches.
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    n_batches_in_buffer=16,
    store_batch_size_prompts=8,  # prompts per store batch
    streaming=True,              # stream the dataset instead of loading it all up front
)

Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

In [4]:
from tqdm import tqdm 

def get_tokens(
    activations_store: ActivationsStore,
    n_batches_to_sample_from: int = 4096 * 6,
    n_prompts_to_select: int = 4096 * 6,
):
    """
    Build a random subset of tokenized prompts drawn from an activations store.

    Calls ``activations_store.get_batch_tokens()`` ``n_batches_to_sample_from``
    times, concatenates the batches along dim 0, and returns up to
    ``n_prompts_to_select`` of those rows, chosen uniformly at random without
    replacement.

    Args:
        activations_store: source of token batches.
        n_batches_to_sample_from: number of batches to pull from the store
            (must be >= 1).
        n_prompts_to_select: maximum number of rows (prompts) to return.

    Returns:
        Tensor holding at most ``n_prompts_to_select`` rows, in random order.

    Raises:
        ValueError: if ``n_batches_to_sample_from`` is less than 1 (there would
            be nothing for ``torch.cat`` to concatenate).
    """
    if n_batches_to_sample_from < 1:
        raise ValueError(
            f"n_batches_to_sample_from must be >= 1, got {n_batches_to_sample_from}"
        )

    all_tokens_list = [
        activations_store.get_batch_tokens()
        for _ in tqdm(range(n_batches_to_sample_from))
    ]
    all_tokens = torch.cat(all_tokens_list, dim=0)

    # A single global permutation already mixes rows across (and within) batches,
    # so no per-batch shuffle is needed. Truncating the index permutation before
    # indexing avoids building a fully shuffled copy of the whole pool.
    selected_rows = torch.randperm(all_tokens.shape[0])[:n_prompts_to_select]
    return all_tokens[selected_rows]

# Pull 128 batches from the store and keep a random 128 prompts from that pool;
# that is plenty for a demo.
token_dataset = get_tokens(activations_store, 128, 128)

100%|██████████| 128/128 [00:07<00:00, 16.33it/s]


## Step 5. Generate Feature Dashboards

In [5]:
# SAE Dashboard (sae_vis) currently expects a different SAE class, but mocking this one method is enough to make ours work:

def mock_feature_acts_subset_for_now(sae):
    """
    Attach a ``get_feature_acts_subset(x, feature_idx)`` method to ``sae``.

    The dashboard code expects an SAE exposing this method. Here it is emulated
    by running the full ``sae.encode(x)`` and indexing its last dimension with
    ``feature_idx``. ``sae`` is patched in place and also returned.
    """

    def _feature_acts_subset(x, feature_idx):
        """
        Get a subset of the feature activations for a dataset.
        """
        all_feature_acts = sae.encode(x)
        return all_feature_acts[..., feature_idx]

    setattr(sae, "get_feature_acts_subset", _feature_acts_subset)
    return sae

# Patch `get_feature_acts_subset` onto the SAE in place (the helper also returns
# the same object), as the dashboard code expects that method to exist.
sae = mock_feature_acts_subset_for_now(sae)

In [6]:
from pathlib import Path
# Dashboards for the first 256 features.
test_feature_idx_gpt = list(range(256))

feature_vis_config_gpt = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=16,
    # Despite the name, this counts prompts; the number of tokens per prompt is
    # determined by the sequence length.
    minibatch_size_tokens=16,
    verbose=True,
    # Use the device chosen in the setup cell instead of hard-coding "mps", so
    # this cell also runs on CUDA and CPU-only machines.
    device=device,
    cache_dir=Path("demo_activations_cache"),  # this will enable us to skip running the model for subsequent features.
)

data = SaeVisRunner(feature_vis_config_gpt).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset,
)

Changing model dtype to torch.float32
Moving model to device:  mps


Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/256 [00:00<?, ?it/s]

In [7]:
from sae_vis.data_writing_fns import save_feature_centric_vis
# Write the feature-centric dashboards for the first batch of features to HTML.
filename = "demo_feature_dashboards.html"  # plain string: no placeholders, so no f-string needed
save_feature_centric_vis(sae_vis_data=data, filename=filename)

Saving feature-centric vis:   0%|          | 0/256 [00:00<?, ?it/s]

# Repeat the dashboard step with different features, reusing the cached activations

In [8]:

# Same configuration as the first run, but for features 256-511. The cache_dir is
# shared with the first run so its cached activations can be reused.
test_feature_idx_gpt = list(range(256, 512))

feature_vis_config_gpt = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=16,
    # Despite the name, this counts prompts; the number of tokens per prompt is
    # determined by the sequence length.
    minibatch_size_tokens=16,
    verbose=True,
    # Use the device chosen in the setup cell instead of hard-coding "mps", so
    # this cell also runs on CUDA and CPU-only machines.
    device=device,
    cache_dir=Path("demo_activations_cache"),  # this will enable us to skip running the model for subsequent features.
)

data = SaeVisRunner(feature_vis_config_gpt).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset,
)

filename = "demo_feature_dashboards_2.html"  # plain string: no placeholders, so no f-string needed
save_feature_centric_vis(sae_vis_data=data, filename=filename)

Changing model dtype to torch.float32
Moving model to device:  mps


Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/256 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/256 [00:00<?, ?it/s]