# Demo Notebook

Steps:
1. Download the SAE weights (here as Weights & Biases artifacts) and load them with SAE Lens.
2. Create a dataset consistent with that SAE. 
3. Fold the SAE decoder norm weights so that feature activations are "correct".
4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.
5. Run the SAE generator for the features you want.

# Set Up

In [None]:
# Download the SAE trained on Gemma-2-9b (blocks.24.hook_resid_post, per the
# artifact name) from Weights & Biases. Note this fetches the SAE weights, not
# the Gemma-2-9b model itself.

import wandb

run = wandb.init()
artifact = run.use_artifact(
    "jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7",
    type="model",
)
artifact_dir = artifact.download()

In [None]:
# Download the matching log-feature-sparsity artifact for the same SAE.
# This starts a second W&B run and rebinds `run`, `artifact` and `artifact_dir`.
import wandb

run = wandb.init()
artifact = run.use_artifact(
    "jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7",
    type="log_feature_sparsity",
)
artifact_dir = artifact.download()

In [None]:
import matplotlib.pyplot as plt
from safetensors.torch import load_file

# Load the per-feature sparsity tensor stored under the "sparsity" key.
# NOTE(review): this path is inside the SAE model artifact directory, not the
# separately downloaded log_feature_sparsity artifact — confirm the file is there.
feature_sparsity = load_file(
    "artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7/sparsity.safetensors"
)["sparsity"]

# matplotlib expects array-like data, so convert the tensor to numpy
data = feature_sparsity.numpy()

# Histogram of the sparsity values across all features
plt.hist(data, bins=30, edgecolor="black")

# Add labels and title
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.title("Histogram of PyTorch Tensor")

# Show the plot
plt.show()

In [None]:
import torch
from transformer_lens import HookedTransformer
from sae_lens import ActivationsStore, SAE
from importlib import reload
import sae_dashboard

# Everything in this notebook is inference only, so turn autograd off globally.
# `torch` must be imported in this cell: it was previously first imported much
# further down, which made this line fail on a fresh kernel.
torch.set_grad_enabled(False)

# Pick up local edits to sae_dashboard without restarting the kernel.
reload(sae_dashboard)

In [None]:
MODEL = "gemma-2-9b"

# Device preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Load the language model in bfloat16 on the selected device.
model = HookedTransformer.from_pretrained(
    MODEL,
    device=device,
    dtype="bfloat16",
)

In [None]:
# Load the SAE from the downloaded W&B artifact directory.
sae = SAE.load_from_disk(
    "artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7"
)
# Fold the decoder norms into the weights so feature activations are "correct"
# (step 3 of the outline at the top of the notebook).
sae.fold_W_dec_norm()

In [None]:
# Alternative: feed real model activations instead of random input.
# _, cache = model.run_with_cache("Wasssssup", names_filter = sae.cfg.metadata.hook_name)
# sae_in = cache[sae.cfg.metadata.hook_name]
# print(sae_in.shape)

# Smoke test: push a random (batch=1, seq=4, d_in) tensor through the SAE.
# The width comes from the SAE config rather than a hard-coded 3584, so the
# cell keeps working if a different SAE is loaded.
sae_in = torch.rand((1, 4, sae.cfg.d_in)).to(sae.device)
sae_out = sae(sae_in)

In [None]:
# Alternative way of obtaining an SAE (kept for reference, not executed):
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "mistral-7b-res-wg", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = "cuda:3",
# )
# fold w_dec norm so feature activations are accurate
#
# Build an activations store from the SAE, streaming the dataset named in the
# SAE's metadata; the store itself is placed on the CPU.
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
    device="cpu",
    dataset=sae.cfg.metadata.dataset_path,
)

In [None]:
# Display the SAE's encode function (notebook inspection only).
sae.encode_fn

In [None]:
from sae_lens import run_evals
from sae_lens.evals import EvalConfig
from sae_lens.training.activation_scaler import ActivationScaler

# Sanity-check the SAE and model together: a small reconstruction eval
# (3 batches of 8 prompts) that also computes cross-entropy loss.
eval_config = EvalConfig(
    batch_size_prompts=8,
    n_eval_reconstruction_batches=3,
    compute_ce_loss=True,
)

# A default-constructed ActivationScaler is passed here.
scalar_metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activations_store,
    model=model,
    activation_scaler=ActivationScaler(),
    eval_config=eval_config,
)

# CE Loss score should be high for residual stream SAEs
print(scalar_metrics["model_performance_preservation"]["ce_loss_score"])

# ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly
print(scalar_metrics["model_performance_preservation"]["ce_loss_without_sae"])

# ce loss with SAE shouldn't be massively higher
print(scalar_metrics["model_performance_preservation"]["ce_loss_with_sae"])

In [None]:

from sae_dashboard.utils_fns import get_tokens

# Request 4096 from the activations store (presumably the number of prompts —
# confirm against get_tokens). Only the first 1024 rows are used for the
# dashboard run further down.
token_dataset = get_tokens(activations_store, 4096)

In [None]:
# torch.save(token_dataset, "to")

In [None]:
# torch.save(token_dataset, "token_dataset.pt")
# Reload a previously saved token dataset. This replaces the tokens fetched
# above and requires token_dataset.pt to already exist (the save is commented out).
# NOTE(review): torch.load unpickles the file — only load files you created yourself.
token_dataset = torch.load("token_dataset.pt")

In [None]:
import os
import shutil

# Remove activations cached by a previous run so the dashboard is regenerated
# from scratch. os.rmdir only removes *empty* directories and raises when the
# directory is missing, so use rmtree and tolerate a cache that doesn't exist.
shutil.rmtree("demo_activations_cache", ignore_errors=True)

In [None]:
import torch


def select_indices_in_range(tensor, min_val, max_val, num_samples=None):
    """
    Select indices of a tensor where values fall within a specified range.

    Args:
    tensor (torch.Tensor): Input tensor with values between -10 and 0.
    min_val (float): Minimum value of the range (inclusive).
    max_val (float): Maximum value of the range (inclusive).
    num_samples (int, optional): Number of indices to randomly select. If None,
        or not smaller than the number of matches, all matching indices are returned.

    Returns:
    torch.Tensor: Tensor of selected indices. For a 1-D input this is always a
        1-D tensor (possibly empty or of length one); for an N-D input each row
        holds the coordinates of one match.

    Raises:
    ValueError: If the range is not within [-10, 0], if min_val > max_val, or
        if num_samples is negative.
    """
    # Ensure the input range is valid
    if not (-10 <= min_val <= max_val <= 0):
        raise ValueError(
            "Range must be within -10 to 0, and min_val must be <= max_val"
        )
    if num_samples is not None and num_samples < 0:
        raise ValueError("num_samples must be non-negative")

    # Find indices where values are within the specified range
    mask = (tensor >= min_val) & (tensor <= max_val)
    # nonzero() yields one row per match. Drop only the trailing coordinate axis
    # (size 1 for 1-D input): a bare squeeze() would also collapse a single
    # match into a 0-d tensor, whose .tolist() is an int rather than a list.
    indices = mask.nonzero().squeeze(-1)

    # If num_samples is specified and less than the number of matches,
    # randomly select that many of them (rows, not elements).
    n_matches = indices.shape[0]
    if num_samples is not None and num_samples < n_matches:
        perm = torch.randperm(n_matches)
        indices = indices[perm[:num_samples]]

    return indices


# Randomly pick up to n_features features whose sparsity value lies in [-4, -2].
# n_features is now actually passed on instead of repeating the literal.
n_features = 4096
feature_idxs = select_indices_in_range(feature_sparsity, -4, -2, n_features)
# Display the sparsity values of the selected features as a sanity check.
feature_sparsity[feature_idxs.tolist()]

In [None]:
from importlib import reload
import sys


def reload_user_modules(module_names):
    """Reload every named module that is already imported; skip the rest."""
    # Lazy generator: each membership check happens just before its reload,
    # in the order the names were given.
    already_loaded = (
        sys.modules[mod_name] for mod_name in module_names if mod_name in sys.modules
    )
    for module in already_loaded:
        reload(module)


# Modules to refresh, in reload order: the package root first, then the
# submodules used by this notebook. Names not yet imported are skipped.
user_modules = [
    "sae_dashboard",
    "sae_dashboard.sae_vis_runner",
    "sae_dashboard.data_parsing_fns",
    "sae_dashboard.feature_data_generator",
]

# Reload modules
reload_user_modules(user_modules)

# Re-import after reload so the name is bound to the freshly reloaded class.
# NOTE(review): FeatureDataGenerator is not referenced elsewhere in the visible cells.
from sae_dashboard.feature_data_generator import FeatureDataGenerator

In [None]:
from pathlib import Path

# Bind the submodule explicitly: earlier cells only import the `sae_dashboard`
# package, so the bare name `sae_vis_runner` used below was undefined.
from sae_dashboard import sae_vis_runner

test_feature_idx_gpt = feature_idxs.tolist()

feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(
    hook_point=sae.cfg.metadata.hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=16,
    minibatch_size_tokens=4096,  # this is really prompt with the number of tokens determined by the sequence length
    verbose=True,
    # Use the device chosen when the model was loaded instead of assuming CUDA.
    device=device,
    cache_dir=Path(
        "demo_activations_cache"
    ),  # this will enable us to skip running the model for subsequent features.
    dtype="bfloat16",
)

runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)

# Generate dashboard data for the selected features over the first 1024 rows
# of the token dataset.
data = runner.run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:1024],
)

In [None]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis

# Write the generated dashboard data to an HTML file.
# Plain string: the previous f-string had no placeholders.
filename = "demo_feature_dashboards.html"
save_feature_centric_vis(sae_vis_data=data, filename=filename)

# Quick Profiling experiment

In [None]:
def mock_feature_acts_subset_for_now(sae: SAE):
    """Attach a ``get_feature_acts_subset`` callable to ``sae`` and return it.

    The attached callable encodes its input with ``sae.encode_fn`` and keeps
    only the requested feature columns. The SAE is modified in place; the same
    object is returned for convenience.
    """

    def sae_lens_get_feature_acts_subset(x: torch.Tensor, feature_idx):  # type: ignore
        """
        Encode ``x`` and return the activations of ``feature_idx`` only,
        moved back to the device ``x`` came from.
        """
        with torch.no_grad():
            # Encode on the SAE's own device/dtype, then slice the last axis.
            encoder_input = x.to(device=sae.device, dtype=sae.dtype)
            all_acts = sae.encode_fn(encoder_input)
            subset = all_acts[..., feature_idx]
            return subset.to(x.device)

    sae.get_feature_acts_subset = sae_lens_get_feature_acts_subset  # type: ignore
    return sae


# Patch the SAE, then try the new method on random input for the first 128 features.
# Note: this rebinds `feature_idxs` (previously the sampled index tensor) and `sae_in`.
sae = mock_feature_acts_subset_for_now(sae)
feature_idxs = list(range(128))
sae_in = torch.rand((1, 4, 3584)).to(sae.device)
sae.get_feature_acts_subset(sae_in, feature_idxs)

In [None]:
# Print the name and shape of every SAE parameter.
for k, v in sae.named_parameters():
    print(k, v.shape)

In [None]:
from torch import nn
from typing import List


class FeatureMaskingContext:
    """Context manager that temporarily restricts an SAE to a subset of features.

    On entry, every feature-indexed parameter of the SAE is replaced by a new
    ``nn.Parameter`` holding only the entries selected by ``feature_idxs``.
    On exit, the original parameter objects are put back.

    Supported architectures (as reported by ``sae.cfg.architecture()``):
    ``"standard"`` (W_dec, W_enc, b_enc) and ``"gated"`` (W_dec, W_enc,
    b_gate, r_mag, b_mag).

    Args:
        sae: The SAE whose parameters are masked in place.
        feature_idxs: Feature indices to keep.

    Raises:
        ValueError: From ``__enter__`` if the architecture is not supported.
            The SAE is left untouched in that case.
    """

    # Axis of each parameter that enumerates features: W_enc is sliced as
    # W_enc[:, idxs]; every other parameter is sliced along its first axis.
    _FEATURE_AXIS = {
        "W_dec": 0,
        "W_enc": 1,
        "b_enc": 0,
        "b_gate": 0,
        "r_mag": 0,
        "b_mag": 0,
    }

    # Parameters that have to be masked for each supported architecture.
    _PARAMS_BY_ARCHITECTURE = {
        "standard": ("W_dec", "W_enc", "b_enc"),
        "gated": ("W_dec", "W_enc", "b_gate", "r_mag", "b_mag"),
    }

    def __init__(self, sae: SAE, feature_idxs: List):
        self.sae = sae
        self.feature_idxs = feature_idxs
        # name -> original parameter, populated only while the context is active
        self.original_weight = {}

    def __enter__(self):
        # Validate before touching anything so an unsupported architecture
        # cannot leave the SAE half-masked (__exit__ is not called when
        # __enter__ raises).
        param_names = self._PARAMS_BY_ARCHITECTURE.get(self.sae.cfg.architecture())
        if param_names is None:
            raise ValueError("Invalid architecture")

        # Build every masked parameter first and only then swap them in, so an
        # indexing error also leaves the SAE unchanged.
        originals = {}
        masked = {}
        for name in param_names:
            # Always read from self.sae, never from a global `sae`.
            original = getattr(self.sae, name)
            index = (slice(None),) * self._FEATURE_AXIS[name] + (self.feature_idxs,)
            originals[name] = original
            masked[name] = nn.Parameter(
                original.detach()[index],
                requires_grad=original.requires_grad,
            )

        for name, parameter in masked.items():
            setattr(self.sae, name, parameter)
        self.original_weight = originals

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the original parameter objects back (same identity and
        # requires_grad as before entry) and drop our references to them.
        for key, value in self.original_weight.items():
            setattr(self.sae, key, value)
        self.original_weight = {}

In [None]:
import gc

# Release unreferenced Python objects and cached CUDA memory before profiling.
gc.collect()
torch.cuda.empty_cache()
# Gradients are not needed for this experiment.
torch.set_grad_enabled(False)


def my_function(sae_in):
    """Run the SAE, masked down to its first 2048 features, and print the mean output.

    This is the workload profiled by the memray cell at the end of the notebook.
    """
    first_features = list(range(2048))
    masking = FeatureMaskingContext(sae, first_features)
    with masking:
        sae_out = sae(sae_in)
        print(sae_out.mean())


# Cache the activations at the SAE's hook point for the first 64 rows of the
# token dataset, stopping the forward pass once that layer has run.
# NOTE(review): this reads sae.cfg.hook_layer while the rest of the notebook
# reads hook info from sae.cfg.metadata — confirm hook_layer exists on this config.
tokens = token_dataset[:64]
_, cache = model.run_with_cache(
    tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=sae.cfg.metadata.hook_name
)
sae_in = cache[sae.cfg.metadata.hook_name]

In [None]:
# Shape of the token batch used for profiling.
tokens.shape

In [None]:
# Shape of the SAE decoder weight outside the masking context.
sae.W_dec.shape

In [None]:
# Enable the memray notebook magics used in the next cell.
%load_ext memray

In [None]:
%%memray_flamegraph --trace-python-allocators --leaks
# Profile a single call of my_function on the cached SAE input.
my_function(sae_in)