In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-vis==0.2.14
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import torch
from datasets import load_dataset
import webbrowser
import os
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time

# Library imports
from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig
# from sae_lens.training.sparse_autoencoder import SparseAutoencoder

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

device = get_device()
torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

In [3]:
encoder = AutoEncoder.load_from_hf(version="run1").to(device)
encoder_B = AutoEncoder.load_from_hf(version="run2").to(device)

In [4]:
model = HookedTransformer.from_pretrained("gelu-1l")
model.to(device)

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  cpu


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [5]:
SEQ_LEN = 128

# Load in the data (it's a Dataset object)
data = load_dataset("NeelNanda/c4-code-20k", split="train")
assert isinstance(data, Dataset)

# Tokenize the data (using a utils function) and shuffle it
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

torch.Size([215402, 128])


In [14]:
all_tokens[:2049].shape

torch.Size([2049, 128])

In [30]:
feature_centric_layout

SaeVisLayoutConfig(columns={0: <sae_vis.data_config_classes.Column object at 0x7ca617dfbf10>, 1: <sae_vis.data_config_classes.Column object at 0x7ca60e1df350>, 2: <sae_vis.data_config_classes.Column object at 0x7ca60e1df810>}, height=750, seq_cfg=SequencesConfig(buffer=(5, 5), compute_buffer=True, n_quantiles=10, top_acts_group_size=20, quantile_group_size=5, top_logits_hoverdata=5, stack_mode='stack-none', hover_below=True), act_hist_cfg=ActsHistogramConfig(n_bins=50), logits_hist_cfg=LogitsHistogramConfig(n_bins=50), logits_table_cfg=LogitsTableConfig(n_rows=10), feature_tables_cfg=FeatureTablesConfig(n_rows=3, neuron_alignment_table=True, correlated_neurons_table=True, correlated_features_table=True, correlated_b_features_table=False), prompt_cfg=None)

In [7]:
from sae_vis.data_config_classes import SaeVisLayoutConfig, SequencesConfig
feature_centric_layout = SaeVisLayoutConfig.default_feature_centric_layout()
feature_centric_layout.seq_cfg = SequencesConfig(buffer=None)

In [8]:

# Specify the hook point you're using, and the features you're analyzing
sae_vis_config = SaeVisConfig(
    hook_point = utils.get_act_name("post", 0),
    features = range(64),
    verbose = True,
    feature_centric_layout = feature_centric_layout
)

# Gather the feature data
sae_vis_data = SaeVisData.create(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens[:2048],
    cfg = sae_vis_config,
)

# Save as HTML file & display vis
filename = "_feature_vis_demo.html"
sae_vis_data.save_feature_centric_vis(filename, feature_idx=8)

display_vis_inline(filename)

Forward passes to cache data for vis:   0%|          | 0/32 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/64 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/64 [00:00<?, ?it/s]

In [26]:
sae_vis_config.feature_centric_layout.seq_cfg

SequencesConfig(buffer=(5, 5), compute_buffer=True, n_quantiles=10, top_acts_group_size=20, quantile_group_size=5, top_logits_hoverdata=5, stack_mode='stack-none', hover_below=True)

In [38]:
sae_vis_data.feature_data_dict[0].sequence_data.seq_group_data[0].seq_data[1].token_ids.__len__()

127

In [26]:
sae_vis_data = SaeVisData.create(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens[:2048],
    cfg = sae_vis_config,
)


prompt = "'first_name': ('django.db.models.fields"

seq_pos = model.tokenizer.tokenize(prompt).index("Ġ('") # type: ignore
metric = 'act-quantiles'

filename = "_prompt_vis_demo.html"




sae_vis_data.save_prompt_centric_vis(
    prompt = prompt,
    filename = filename,
    seq_pos = seq_pos, # optional argument, to determine the default option when the page loads
    metric = metric, # optional argument, to determine the default option when the page loads
)

Forward passes to cache data for vis:   0%|          | 0/32 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/64 [00:00<?, ?it/s]

ValueError: invalid literal for int() with base 10: '"\'" (0)'