# Generating Dashboards

We use Callum McDougall's `sae_vis` library to generate feature dashboards.

We've written a runner that wraps Callum's code, logs artefacts to wandb, and can pick up where it left off if needed.

## Set Up

In [None]:
import torch
import webbrowser
import os
import sys
from huggingface_hub import hf_hub_download
from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
from sae_vis.utils_fns import get_device
from sae_analysis.dashboard_runner import DashboardRunner

# Ask sae_vis's helper which torch device to use, and echo it so the
# notebook output records where this run happened.
device = get_device()
print(device)

# Everything below is inference only, so autograd is switched off globally.
torch.set_grad_enabled(mode=False)

## Use Runner

In [None]:
# Download the layer-`layer` residual-stream SAE checkpoint from the Hub.
layer = 8
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. The checkpoint stores a config *object* under "cfg" (its attributes are
# read below), so presumably `weights_only=True` would reject it. Only load
# checkpoints from sources you trust.
obj = torch.load(path, map_location=device)
state_dict = obj["state_dict"]

# The compatibility check below expects exactly these SAE parameters. Raise
# explicitly rather than `assert`: asserts vanish under `python -O`, and the
# message shows which keys were actually found.
if set(state_dict) != {"W_enc", "b_enc", "W_dec", "b_dec"}:
    raise ValueError(
        f"Unexpected SAE state_dict keys: {sorted(state_dict)}; "
        "expected ['W_dec', 'W_enc', 'b_dec', 'b_enc']"
    )


# sae_vis ships its own AutoEncoder class, separate from ours. Building one from
# our checkpoint's config and loading our weights into it is a cheap guard
# against the two implementations drifting apart; it works for the SAE above.
cfg = AutoEncoderConfig(
    device=device,
    dict_mult=obj["cfg"].expansion_factor,
    d_in=obj["cfg"].d_in,
)
gpt2_sae = AutoEncoder(cfg)
gpt2_sae.load_state_dict(state_dict)


# Build the dashboards for the SAE downloaded above. Output goes under
# ../feature_dashboards; with continue_existing_dashboard=True the runner is
# asked to reuse an existing dashboard folder rather than start from scratch.
runner = DashboardRunner(
    sae_path=path,  # a local checkpoint path is accepted here
    dashboard_parent_folder="../feature_dashboards",
    init_session=True,
    # Sampling more batches gives a more diverse text sample.
    n_batches_to_sample_from=2**12,
    # More prompts matter for sparser features, which activate rarely.
    n_prompts_to_select=4096 * 6,
    n_features_at_a_time=128,
    max_batch_size=256,
    buffer_tokens=8,
    use_wandb=False,
    continue_existing_dashboard=True,
)
runner.run()

## Visualize Dashboards

In [None]:
import random
from pathlib import Path

# Open a few generated dashboards in the browser for a quick look.
# os.listdir order is arbitrary, so sample explicitly instead of taking the
# first entries, and cap the sample size in case fewer than 3 files exist.
# NOTE(review): assumes every file in the folder is a viewable dashboard.
dashboard_dir = Path(os.path.abspath(runner.dashboard_folder))
feature_files = os.listdir(dashboard_dir)
for feature_file in random.sample(feature_files, k=min(3, len(feature_files))):
    # as_uri() yields a correctly escaped file:// URL on POSIX and Windows.
    url = (dashboard_dir / feature_file).as_uri()
    webbrowser.open(url)
    print(url)