# Loading and Analysing Pre-Trained Sparse Autoencoders

## Imports & Installs

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

## Set Up

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda:6" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda:6


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

# Loading a pretrained Sparse Autoencoder

Below we load a Transformerlens model, a pretrained SAE and a dataset from huggingface.

In [4]:
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)

from transformer_lens.HookedLlava import HookedLlava
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
model_path="/mnt/data/changye/model/llava"

processor = LlavaNextProcessor.from_pretrained(model_path)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        model_path, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True,
)

vision_tower = vision_model.vision_tower.to("cuda:5")
multi_modal_projector = vision_model.multi_modal_projector.to("cuda:5")
# 加载 HookedTransformer 语言模型
hook_language_model = HookedLlava.from_pretrained(
        MODEL_NAME,
        hf_model=vision_model.language_model,
        device="cuda:5", 
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        tokenizer=None,
        dtype=torch.float32,
        vision_tower=vision_tower,
        multi_modal_projector=multi_modal_projector,
        n_devices=2,
    )
# del vision_model
# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = device
# )

sae = SAE.load_from_pretrained(
    path = "/mnt/data/changye/checkpoints/checkpoints-V/kxpk98cr/final_122880000",
    device ="cuda:7"
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded pretrained model llava-hf/llava-v1.6-mistral-7b-hf into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
from transformer_lens.utils import tokenize_and_concatenate
import transformer_lens.utils as utils
dataset_path="/mnt/data/changye/data/obelics3k-tokenized-llava4096"
try:
    dataset = (
        load_dataset(
                dataset_path,
                split="train",
                streaming=False,
                trust_remote_code=False,  # type: ignore
        )
        if isinstance(dataset_path, str)
            else dataset_path
        )
except Exception as e:
    dataset = (
        load_from_disk(
                dataset_path,
            )
            if isinstance(dataset_path, str)
            else dataset_path
        )
if isinstance(dataset, (Dataset, DatasetDict)):
        dataset = cast(Dataset | DatasetDict, dataset)
# dataset_sample = next(iter(dataset))
columns_to_read=["input_ids","pixel_values","attention_mask","image_sizes"]
ds_context_size = len(dataset["input_ids"])
if hasattr(dataset, "set_format"):
    dataset.set_format(type="torch", columns=columns_to_read)
    print("dataset set format")

Loading dataset from disk:   0%|          | 0/33 [00:00<?, ?it/s]

dataset set format


In [6]:
batch_size = 2
batch = dataset[:batch_size]
batch_tokens = {
    "input_ids": batch["input_ids"],
    "pixel_values": batch["pixel_values"],
    "attention_mask": batch["attention_mask"],
    "image_sizes": batch["image_sizes"],
}

## Basic Analysis

Let's check some basic stats on this SAE in order to see how some basic functionality in the codebase works.

We'll calculate:
- L0 (the number of features that fire per activation)
- The cross entropy loss when the output of the SAE is used in place of the activations

### L0 Test and Reconstruction Test

In [7]:
sae.cfg

SAEConfig(architecture='standard', d_in=4096, d_sae=65536, activation_fn_str='relu', apply_b_dec_to_input=False, finetuning_scaling_factor=False, context_size=4096, model_name='llava-hf/llava-v1.6-mistral-7b-hf', hook_name='blocks.16.hook_resid_post', hook_layer=16, hook_head_index=None, prepend_bos=True, dataset_path='/home/saev/changye/data/obelics100k-tokenized-llava4096_4image', dataset_trust_remote_code=True, normalize_activations='expected_average_only_in', dtype='float32', device='cuda:7', sae_lens_training_version='3.20.0', activation_fn_kwargs={}, neuronpedia_id=None, model_from_pretrained_kwargs={'n_devices': 3})

In [None]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    _, cache = hook_language_model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=lambda name: name == sae.cfg.hook_name)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()
    

Note that while the mean L0 is 64, it varies with the specific activation.

To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens.

In [None]:
from transformer_lens import utils
from functools import partial
torch.cuda.empty_cache()
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out


def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)


print("Orig", hook_language_model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    hook_language_model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    hook_language_model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

## Specific Capability Test

Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks.

In [None]:
example_prompt = "The fruit in the image is "
example_answer = "Apple"
torch.cuda.empty_cache()
# utils.test_prompt(example_prompt, example_answer, hook_language_model, prepend_bos=True)
from PIL import Image
image_path="/home/saev/changye/TransformerLens-V/Apple.jpg"
image = Image.open(image_path)
image=image.resize((336, 336))
conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": example_prompt},
                {"type": "image"},
            ],
        },
    ]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    
    # 处理图像和文本输入
inputs = processor(images=image, text=prompt, return_tensors="pt")
# print(inputs)
inputs=inputs.to("cuda:6")
logits, cache = hook_language_model.run_with_cache(input=inputs,model_inputs=inputs,vision=True, prepend_bos=True)
# inputs = hook_language_model.to_tokens(inputs)
sae_out = sae(cache[sae.cfg.hook_name])



def reconstr_hook(activations, hook, sae_out):
    return sae_out


def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)


hook_name = sae.cfg.hook_name

print("Orig", hook_language_model(inputs,model_inputs=inputs,vision=True, return_type="loss").item())
print(
    "reconstr",
    hook_language_model.run_with_hooks(
        inputs,
        fwd_hooks=[
            (
                hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    hook_language_model.run_with_hooks(
        inputs,
        return_type="loss",
        fwd_hooks=[(hook_name, zero_abl_hook)],
    ).item(),
)


# with hook_language_model.hooks(
#     fwd_hooks=[
#         (
#             hook_name,
#             partial(reconstr_hook, sae_out=sae_out),
#         )
#     ]
# ):
#     utils.test_prompt(example_prompt, example_answer, hook_language_model, prepend_bos=True)

In [6]:
example_prompt="The sky is blue today."
tokens=processor.tokenizer(example_prompt)
print(tokens)

{'input_ids': [1, 415, 7212, 349, 5045, 3154, 28723], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


In [8]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData
torch.cuda.empty_cache()
sae.eval()
test_feature_idx_gpt = list(range(10)) + [14057]
hook_name = sae.cfg.hook_name
feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    # batch_size=2048,
    minibatch_size_tokens=128,
    verbose=True,
)
torch.cuda.empty_cache()
sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=hook_language_model, # type: ignore
    tokens=tokens['input_ids'],  # type: ignore
    cfg=feature_vis_config_gpt,
)

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/11 [00:00<?, ?it/s]

AttributeError: 'list' object has no attribute 'split'

In [None]:
for feature in test_feature_idx_gpt:
    filename = f"{feature}_feature_vis_demo_gpt.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display_vis_inline(filename)

Now, since generating feature dashboards can be done once per sparse autoencoder, for pre-trained SAEs in the public domain, everyone can use the same dashboards. Neuronpedia hosts dashboards which we can load via the intergration.

In [None]:
# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# # this function should open
# neuronpedia_quick_list = get_neuronpedia_quick_list(
#     sae=sae,
#     features=test_feature_idx_gpt,
#     # layer=sae.cfg.hook_layer,
#     # model="llava-hf/llava-v1.6-mistral-7b-hf",
#     # dataset="res-jb",

#     name="A quick list we made",
# )

# if COLAB:
#   # If you're on colab, click the link below
#   print(neuronpedia_quick_list)