In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
# First port used by `display_vis_inline`; that function bumps it after every served vis.
PORT = 8000

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# CUDA ordinal to use when several GPUs are present. It is validated against the number of
# visible GPUs so that hosts with fewer devices fall back to GPU 0 instead of producing an
# invalid device string.
CUDA_DEVICE_INDEX = 6

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    _cuda_index = CUDA_DEVICE_INDEX if CUDA_DEVICE_INDEX < torch.cuda.device_count() else 0
    device = f"cuda:{_cuda_index}"
else:
    device = "cpu"

print(f"Device: {device}")

def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays an HTML visualisation file.

    Outside Colab the file is simply opened with `webbrowser`. In Colab, a small HTTP server
    rooted at `/content` is started on a background thread and the file is embedded as an
    iframe. The global `PORT` variable is used and then incremented, so that each vis gets a
    unique port without having to define a port within the function.

    Args:
        filename: Path of the HTML file. In Colab it is requested as `/{filename}` from the
            server rooted at `/content`.
        height: Height passed to the Colab iframe helper (unused outside Colab).
    '''
    global PORT

    if not COLAB:
        webbrowser.open(filename)
        return

    from functools import partial

    # Snapshot the port before bumping the global: the server thread and the iframe must agree
    # on the same number, and reading the global from the thread would race with the increment.
    port = PORT
    PORT += 1

    directory = "/content"

    # Root the handler at `directory` instead of calling os.chdir(), which would change the
    # working directory of the whole kernel from a background thread.
    handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

    # Bind in the calling thread so the port is already listening when the iframe is created
    # and so bind errors (e.g. port already in use) surface here rather than dying silently.
    httpd = socketserver.TCPServer(("", port), handler)
    print(f"Serving files from {directory} on port {port}")

    def serve():
        # The context manager closes the listening socket once serve_forever() returns.
        with httpd:
            httpd.serve_forever()

    # Daemon thread: the server must not keep the kernel process alive on shutdown.
    threading.Thread(target=serve, daemon=True).start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)
        
        
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)

from transformer_lens.HookedLlava import HookedLlava
# Load the LLaVA model, wrap its language model in a HookedLlava, and load the SAE.
# NOTE(review): every device below is a hardcoded "cuda:6"/"cuda:7" string and every path is
# an absolute local path; the `device` variable selected earlier in this cell is not used here.

# Name handed to HookedLlava.from_pretrained; the weights themselves come from `model_path`.
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
# Local directory from which both the processor and the Hugging Face model are loaded.
model_path="/home/saev/changye/model/llava"
processor = LlavaNextProcessor.from_pretrained(model_path)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        model_path, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True,
)

# Move the vision components to the GPU that is also passed as `device` to HookedLlava below.
vision_tower = vision_model.vision_tower.to("cuda:6")
multi_modal_projector = vision_model.multi_modal_projector.to("cuda:6")
# Load the HookedTransformer language model
# NOTE(review): n_devices=2 presumably spreads the model over two GPUs starting at cuda:6 —
# confirm against the HookedLlava implementation.
hook_language_model = HookedLlava.from_pretrained(
        MODEL_NAME,
        hf_model=vision_model.language_model,
        device="cuda:6", 
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        tokenizer=None,
        dtype=torch.float32,
        vision_tower=vision_tower,
        multi_modal_projector=multi_modal_projector,
        n_devices=2,
    )
# del vision_model
# Alternative (disabled): load a released SAE from the Hugging Face hub instead of a local checkpoint.
# The cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (e.g. instantiating an activation store).
# Note that this is not the same as the SAE's config dict; rather it is whatever was in the HF repo, from which we can extract the SAE config dict.
# The feature sparsities, which are stored in HF for convenience, are returned as well.
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
#     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
#     device = device
# )

# Load the SAE from a local checkpoint directory onto cuda:7.
# NOTE(review): presumably cuda:7 is where the hooked activation ends up with n_devices=2 — confirm.
sae = SAE.load_from_pretrained(
    path = "/home/saev/changye/checkpoints-V/kxpk98cr/final_122880000",
    device ="cuda:7"
)

In [None]:
from transformer_lens.utils import tokenize_and_concatenate
import transformer_lens.utils as utils
# Location of the pre-tokenized dataset on local disk.
dataset_path = "/home/saev/changye/data/obelics100k-tokenized-llava4096_w/batch_1"

# Try the generic loader first; if it rejects the path, fall back to `load_from_disk`.
# The failure is reported rather than swallowed so the fallback is visible in the output.
try:
    dataset = load_dataset(
        dataset_path,
        split="train",
        streaming=False,
        trust_remote_code=False,  # type: ignore
    )
except Exception as load_error:
    print(f"load_dataset failed ({load_error!r}); falling back to load_from_disk")
    dataset = load_from_disk(dataset_path)

# Columns handed to the model; the same list drives formatting and batch construction.
columns_to_read = ["input_ids", "pixel_values", "attention_mask", "image_sizes"]

# Token length of one sample. Indexing a single row avoids materialising the whole
# `input_ids` column, whose length would be the number of rows rather than a context size.
# NOTE(review): assumes `input_ids` is a flat per-row token sequence — TODO confirm.
ds_context_size = len(dataset[0]["input_ids"]) if len(dataset) > 0 else 0

# Have the dataset return torch tensors, restricted to the columns that are actually read.
if hasattr(dataset, "set_format"):
    dataset.set_format(type="torch", columns=columns_to_read)
    print("dataset set format")

# Take the first few rows as a demo batch and keep only the model inputs.
batch_size = 2
batch = dataset[:batch_size]
batch_tokens = {column: batch[column] for column in columns_to_read}

# Display the SAE configuration as the cell output.
sae.cfg

In [None]:
# Switch the SAE to eval mode; per the original note this avoids an error about an expected
# dead-neuron mask.
sae.eval()

with torch.no_grad():
    # Run the model on the prepared batch, caching only the activation at the SAE's hook point.
    _, cache = hook_language_model.run_with_cache(
        batch_tokens,
        prepend_bos=True,
        names_filter=lambda name: name == sae.cfg.hook_name,
    )

    # Encode the cached activation into SAE features, then reconstruct it from those features.
    hooked_activation = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(hooked_activation)
    sae_out = sae.decode(feature_acts)

    # Drop every reference to the cached activations to save some room.
    del hooked_activation
    del cache

    # L0 per token: the number of features that fired, skipping the BOS position (index 0).
    fired = feature_acts[:, 1:] > 0
    l0 = fired.float().sum(-1).detach()

    # Report the mean over batch and position, and plot the distribution of per-token L0.
    print("average l0", l0.mean().item())
    l0_values = l0.flatten().cpu().numpy()
    px.histogram(l0_values).show()