# Demo Notebook

Steps:
1. Download an SAE with SAE Lens.
2. Create a dataset consistent with that SAE.
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE generator for the features you want.

# Set Up

In [None]:
import torch
from transformer_lens import HookedTransformer
from sae_lens import ActivationsStore, SAE
from importlib import reload
import sae_dashboard

# Re-execute the already-imported sae_dashboard module so that re-running this
# cell picks up edits made to the package source since it was first imported.
# NOTE(review): importlib.reload re-executes only this top-level module, not
# submodules that were already imported — confirm that is sufficient here.
reload(sae_dashboard)

# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Load the pretrained "mistral-7b" weights in bfloat16; n_devices=4 asks
# TransformerLens to spread the model over four devices.
model = HookedTransformer.from_pretrained(
    "mistral-7b",
    dtype="bfloat16",
    n_devices=4,
    device=device,
)

# Three values come back from the loader:
#   * sae      - the pretrained SAE itself;
#   * cfg_dict - whatever config was stored in the HF repo. This is not the
#                SAE's own config dict, but the SAE config can be extracted
#                from it, and it can be useful for analysing the SAE (e.g.
#                when instantiating an activation store);
#   * sparsity - the feature sparsities, stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    sae_id="blocks.8.hook_resid_pre",  # won't always be a hook point
    release="mistral-7b-res-wg",  # see other options in sae_lens/pretrained_saes.yaml
    device="cuda:3",
)

# Fold the W_dec norm into the SAE so feature activations are accurate.
sae.fold_W_dec_norm()


# Build a streaming activation store from the SAE, reading the dataset named
# in the SAE metadata and with the store's device set to the CPU.
activations_store = ActivationsStore.from_sae(
    sae=sae,
    model=model,
    dataset=sae.cfg.metadata.dataset_path,
    streaming=True,
    device="cpu",
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
)

In [None]:
from tqdm import tqdm
from sae_dashboard.utils_fns import get_tokens

def get_tokens_mistral(
    activations_store: ActivationsStore,
    n_batches_to_sample_from: int = 4096 * 6,
    n_prompts_to_select: int = 4096 * 6,
):
    """Sample tokens from an activation store and keep the leading prompts.

    Args:
        activations_store: Store handed to ``get_tokens`` as the token source.
        n_batches_to_sample_from: Batch count forwarded to ``get_tokens``.
        n_prompts_to_select: Upper bound on how many leading entries of the
            sampled tokens are returned.

    Returns:
        The first ``n_prompts_to_select`` entries of what ``get_tokens``
        produced.
    """
    sampled_tokens = get_tokens(activations_store, n_batches_to_sample_from)
    selected_tokens = sampled_tokens[:n_prompts_to_select]
    return selected_tokens

# Around 1000 prompts is plenty for a demo; note that the defaults above select far more (4096 * 6).
# token_dataset = get_tokens_mistral(activations_store)

In [None]:
# After sampling a fresh token dataset, persist it with:
# torch.save(token_dataset, "token_dataset.pt")
# Load the previously saved token dataset from the working directory.
token_dataset = torch.load("token_dataset.pt")

In [None]:
import contextlib
import os
import shutil

# Remove the activation cache left behind by a previous run so stale cached
# activations are not reused. os.rmdir only deletes *empty* directories and
# raises when the directory is missing, so it fails both on a first run and
# whenever the cache holds files. shutil.rmtree removes the directory together
# with its contents; only the "directory does not exist" case is suppressed,
# so other failures (e.g. permissions) still surface.
with contextlib.suppress(FileNotFoundError):
    shutil.rmtree("demo_activations_cache")

In [None]:
from pathlib import Path

# Build dashboards for the first 256 SAE features.
test_feature_idx_gpt = list(range(256))

feature_vis_config_gpt = sae_dashboard.SaeVisConfig(
    hook_point=sae.cfg.metadata.hook_name,
    features=test_feature_idx_gpt,
    # Same device and dtype as the model.
    device=device,
    dtype="bfloat16",
    minibatch_size_features=16,
    # Counted in prompts; the number of tokens per prompt is determined by the
    # sequence length.
    minibatch_size_tokens=32,
    # Caching activations lets subsequent features skip running the model.
    cache_dir=Path("demo_activations_cache"),
    verbose=True,
)

# Run the dashboard pipeline over the first 4096 prompts of the token dataset.
runner = sae_dashboard.SaeVisRunner(feature_vis_config_gpt)

data = runner.run(
    model=model,
    encoder=sae,  # type: ignore
    tokens=token_dataset[:4096],
)

In [None]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

# Write the feature-centric dashboards held in `data` to an HTML file in the
# working directory. The name is a plain literal: it has no placeholders, so
# the f-string prefix was redundant.
filename = "demo_feature_dashboards.html"
save_feature_centric_vis(sae_vis_data=data, filename=filename)

# Quick Profiling experiment

In [None]:
# Rebuild the activation store for the profiling experiment: 32 prompts per
# store batch and a single batch in the buffer.
activations_store = ActivationsStore.from_sae(
    sae=sae,
    model=model,
    dataset=sae.cfg.metadata.dataset_path,
    streaming=True,
    device="cpu",
    store_batch_size_prompts=32,
    n_batches_in_buffer=1,
)

In [None]:
# Show the SAE's configured input dimension (d_in) as the cell output.
sae.cfg.d_in

In [None]:
from sae_lens.util import extract_layer_from_tlens_hook_name
import gc
import torch
from safetensors.torch import save_file
from torch.profiler import profile, record_function, ProfilerActivity

# Start the profiling run from a clean slate: collect unreachable Python
# objects first, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()


@torch.no_grad()
def my_function():
    """Run five cached forward passes up to the SAE hook point, for profiling.

    Each pass feeds the first 32 prompts of ``token_dataset`` through ``model``,
    stops right after the layer the SAE hooks into, and caches only that hook's
    activation. Gradients are disabled by the decorator. Returns ``None``; the
    function exists only to give the profiler a workload.
    """
    # Read the hook name once from the SAE metadata. The original mixed
    # `sae.cfg.hook_name` and `sae.cfg.metadata.hook_name`; a single source
    # keeps the cache filter and the cache lookup in agreement.
    hook_name = sae.cfg.metadata.hook_name
    # +1 because stop_at_layer is exclusive: the hooked layer itself must run.
    stop_at_layer = extract_layer_from_tlens_hook_name(hook_name) + 1
    tokens = token_dataset[:32]

    for _ in range(5):
        _, cache = model.run_with_cache(
            tokens, stop_at_layer=stop_at_layer, names_filter=hook_name
        )
        # Raises KeyError if the hook was not actually cached.
        sae_in = cache[hook_name]
        # Optional follow-up experiment: also time writing the activations to
        # disk, e.g. save_file({"activations": sae_in}, "test.safetensors").


# Profile one call of my_function on both CPU and CUDA, recording operator
# input shapes, with the call labelled "my_function" in the trace.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof, record_function("my_function"):
    my_function()

# Print the ten entries with the largest total CPU time.
print(prof.key_averages().table(row_limit=10, sort_by="cpu_time_total"))