In [17]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json
from collections import Counter
from datasets import load_dataset


# Select the "notebook_connected" Plotly renderer as the default for figures shown in this notebook.
pio.renderers.default = "notebook_connected"
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# This notebook only runs inference, so gradient tracking is switched off
# globally. Both spellings below toggle the same autograd grad mode.
torch.set_grad_enabled(False)
torch.autograd.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

In [4]:
# Fetch the pre-tokenized TinyStories training split and write a local copy
# with `save_to_disk`. The work is skipped when the target directory already
# exists, so Restart & Run All does not redo the save on every run.
# NOTE: if a previous save was interrupted, delete the directory by hand to
# force it to be rebuilt.
tokenized_data_path = Path("data/tinystories/data.hf")
if not tokenized_data_path.exists():
    data = load_dataset("MechInterpResearch/tinystories_tokenized", split="train")
    data.save_to_disk(str(tokenized_data_path))
    del data  # the in-memory dataset is not used again in this notebook section

In [5]:
# Load the TinyStories validation split and keep only each example's "text"
# field; these strings are the prompts used by the cells below.
val_ds = load_dataset("roneneldan/TinyStories", split="validation")
prompts = [example["text"] for example in val_ds]
del val_ds  # only the plain list of strings is needed from here on

Repo card metadata block was not found. Setting CardData to empty.


In [10]:
model_name = "tiny-stories-1M"

# Load the pretrained model onto `device`, with LayerNorm folding and
# weight/unembed centering enabled.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-1M into HookedTransformer


In [11]:
save_name = "4_scary_possession"  # "25_gallant_monkey"
path = model_name  # the checkpoint and its JSON config are read from a directory named after the model

# Read the saved hyperparameters. The raw dict gets its own name so that `cfg`
# only ever refers to the AutoEncoderConfig built from it.
with open(os.path.join(path, save_name + ".json"), "r") as f:
    cfg_dict = json.load(f)

cfg = AutoEncoderConfig(
    cfg_dict["layer"], cfg_dict["act"], cfg_dict["expansion_factor"], cfg_dict["l1_coeff"]
)

# The autoencoder's input width depends on the hooked activation:
# "hook_mlp_out" uses d_model, any other activation name is treated as d_mlp.
if cfg.act_name == "hook_mlp_out":
    d_in = model.cfg.d_model
else:
    d_in = model.cfg.d_mlp
d_hidden = d_in * cfg.expansion_factor

encoder = AutoEncoder(d_hidden, cfg.l1_coeff, d_in)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on another device (e.g. a GPU) still loads on CPU/MPS hosts.
encoder.load_state_dict(
    torch.load(os.path.join(path, save_name + ".pt"), map_location=device)
)
encoder.to(device)

cfg

AutoEncoderConfig(layer=4, act_name='mlp.hook_post', expansion_factor=8, l1_coeff=2.5e-05)

In [12]:
# Reconstruction check for the loaded autoencoder, restricted to the first
# 200 validation prompts; the returned value is shown as the cell output.
evaluate_autoencoder_reconstruction(
    encoder,
    cfg.encoder_hook_point,
    prompts[:200],
    model,
)

100%|██████████| 200/200 [00:06<00:00, 32.32it/s]


(1.8857718288898468, 1.9599793273210526)

In [20]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    """Run `prompt` through `model` and return the autoencoder's "mid" activations.

    Only the activation at `cfg.encoder_hook_point` is cached. Its leading
    dimension is squeezed away (it is removed only if it has size 1) before
    the result is passed to `encoder`.

    Args:
        prompt: Text to run through the model.
        model: Model whose `run_with_cache` provides the hooked activation.
        encoder: Autoencoder called on the hooked activation; it must return
            a 5-tuple whose third element is the value returned here.
        cfg: Config supplying `encoder_hook_point`.

    Returns:
        The third element of the encoder's output tuple.
        Presumably shaped (n_tokens, d_enc) - TODO confirm against AutoEncoder.
    """
    hook_name = cfg.encoder_hook_point
    cache = model.run_with_cache(prompt, names_filter=hook_name)[1]
    hooked = cache[hook_name].squeeze(0)
    _, _, hidden_acts, _, _ = encoder(hooked)
    return hidden_acts

In [16]:
# For every validation prompt, record each encoder direction's largest
# activation over the prompt's first axis (max over dim 0 of `get_acts`).
activations = []
for prompt in tqdm(prompts):
    acts = get_acts(prompt, model, encoder, cfg)
    max_prompt_activation = acts.max(dim=0).values
    activations.append(max_prompt_activation)

max_activation_per_prompt = torch.stack(activations, dim=0)  # n_prompt x d_enc

# A direction counts as active when its summed per-prompt maxima are non-zero.
total_activations = max_activation_per_prompt.sum(dim=0)
n_active_directions = total_activations.nonzero().shape[0]
n_directions = total_activations.shape[0]
print(f"Active directions on validation data: {n_active_directions} out of {n_directions}")

  0%|          | 0/21990 [00:00<?, ?it/s]

torch.Size([21990, 2048])


In [38]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, n=5):
    """Render the prompts on which one encoder direction activates most strongly.

    Prompts are ranked by their entry in column `direction` of `activations`
    (descending) and the first `n` are re-run through `get_acts` to obtain
    per-token values for that direction. Each prompt whose largest per-token
    value is positive is rendered via
    `haystack_utils.clean_print_strings_as_html`; the others are skipped.

    Uses the notebook-level `model`, `encoder` and `cfg`.

    Args:
        prompts: Prompt texts, indexed consistently with the rows of `activations`.
        activations: One row per prompt, one column per encoder direction.
        direction: Index of the encoder direction to inspect.
        n: Maximum number of top-ranked prompts to consider.
    """
    ranked = activations[:, direction].argsort(descending=True)
    for idx in ranked[:n].cpu().tolist():
        text = prompts[idx]
        str_tokens = model.to_str_tokens(model.to_tokens(text))
        per_token = get_acts(text, model, encoder, cfg)[:, direction].cpu().tolist()
        peak = max(per_token)
        # Written as `not peak > 0` so that any non-positive (or NaN) peak is skipped.
        if not peak > 0:
            continue
        haystack_utils.clean_print_strings_as_html(str_tokens, per_token, max_value=peak)

# Show the top validation prompts for encoder direction 202.
print_top_examples(prompts, max_activation_per_prompt, direction=202)