## 1. Identify all SAE directions active before quotation marks are predicted.
## 2. Visualize them at various levels of activation.

### Setup

In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm import tqdm
from transformer_lens import HookedTransformer
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider
import plotly.graph_objects as go

# Render plotly figures with the "notebook_connected" renderer.
pio.renderers.default = "notebook_connected"

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Gradient tracking is switched off globally for this notebook.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(
    level=logging.INFO,
    format='(%(levelname)s) %(asctime)s: %(message)s',
    datefmt='%I:%M:%S',
)

# Add the parent directory to the module search path.
sys.path.append('../')

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import (
    AutoEncoderConfig, 
    get_all_activating_test_prompts, 
    eval_direction_tokens_global, 
    get_encode_activations_hook, 
    get_activations, 
    get_acts, 
    load_encoder, 
    eval_ablation_token_rank, 
    get_direction_ablation_hook, 
    get_top_activating_examples_for_direction, 
    evaluate_direction_ablation_single_prompt,
    eval_encoder_reconstruction_single_position,
    get_top_direction_ablation_df,
    get_mean_component_wise_mlp,
    get_custom_forward_hook
)
from utils.plotting_utils import line, multiple_line
%reload_ext autoreload
%autoreload 2

In [2]:
# Load the pretrained model through HookedTransformer with unembed centering,
# writing-weight centering and layer-norm folding enabled.
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
# Turn on the model's `use_attn_result` option.
model.set_use_attn_result(True)

# Names of the trained autoencoder runs to load.
run_names = ["54_serene_plasma", "189_giddy_water"] 
# Collected as (encoder, cfg) pairs, in the same order as `run_names`.
encoders = []
for run_name in run_names:
    encoder, cfg = load_encoder(run_name, model_name, model)
    # Tag the config with its run name so later code can key results by run.
    cfg.run_name = run_name
    print(cfg.run_name, cfg.layer, cfg.l1_coeff)
    encoders.append((encoder, cfg))
    
# Prompts returned by the project's TinyStories validation loader.
prompts = load_tinystories_validation_prompts()

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer
{'cfg_file': None, 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 'batch_size': 4096, 'buffer_mult': 128, 'seq_len': 128, 'model': 'tiny-stories-2L-33M', 'layer': 0, 'act': 'mlp.hook_post', 'expansion_factor': 4, 'seed': 47, 'lr': 0.0001, 'l1_coeff': 0.0003, 'l1_target': None, 'wd': 0.01, 'beta1': 0.9, 'beta2': 0.99, 'num_eval_prompts': 200, 'save_checkpoint_models': False, 'reg': 'l1', 'finetune_encoder': None, 'dead_direction_frequency': 0.0005, 'model_batch_size': 32, 'buffer_size': 524288, 'buffer_batches': 4096, 'num_eval_batches': 195, 'd_in': 4096, 'wandb_name': 'serene-plasma-54', 'save_name': '54_serene_plasma'}
54_serene_plasma 0 0.0003
{'cfg_file': None, 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 

(INFO) 03:55:13: Loaded 21990 TinyStories validation prompts


189_giddy_water 1 [0.0001, 0.00015]


### SAE direction activation examples

In [3]:
def print_top_examples(
    prompts: list[str],
    activations: Float[Tensor, "n_prompts d_enc"],
    direction: int,
    encoder: AutoEncoder,
    cfg: AutoEncoderConfig,
    n=5,
):
    """Display the prompts on which `direction` scores highest.

    The `n` prompts with the largest value in column `direction` of
    `activations` are selected. For each one the per-token activations are
    recomputed with `get_acts` and the tokens are rendered as HTML, with the
    colour scale capped at that prompt's own maximum. Prompts whose maximum
    activation is not positive are skipped.

    Uses the module-level `model` for tokenisation and activation lookup.

    Args:
        prompts: Prompt strings, indexed consistently with the rows of `activations`.
        activations: Per-prompt activation for every encoder direction.
        direction: Index of the encoder direction to inspect.
        encoder: Autoencoder passed through to `get_acts`.
        cfg: Configuration passed through to `get_acts`.
        n: Number of top-ranked prompts to consider.
    """
    direction_scores = activations[:, direction]
    ranked_prompt_ids = direction_scores.argsort(descending=True)[:n].cpu().tolist()

    for prompt_id in ranked_prompt_ids:
        text = prompts[prompt_id]
        str_tokens = model.to_str_tokens(model.to_tokens(text))
        token_acts = get_acts(text, model, encoder, cfg)[:, direction].cpu().tolist()
        peak = max(token_acts)
        # Written as a negated `>` so non-comparable values (e.g. NaN) are skipped too.
        if not (peak > 0):
            continue
        print(f"Prompt: {prompt_id}")
        haystack_utils.clean_print_strings_as_html(
            str_tokens, token_acts, max_value=peak, pretty_print=True
        )


def get_quotation_test_prompts(model, prompts):
    """Build test prompts that end on the closing '."' of a quotation.

    For every prompt containing 'said, "', the text is cut right after the
    first '."' that follows the opening quote. The truncated prompt is kept
    only when the cut point lies before character 1600 and the model
    tokenises the trailing '."' as a single final token.

    Args:
        model: Object providing `to_tokens` and `to_single_str_token`.
        prompts: Iterable of prompt strings.

    Returns:
        List of truncated prompt strings. The count is also printed.
    """
    # Pattern targeted: 'said, " [...] ."'
    # '."' '?"' and '!"' are single tokens
    selected = []
    for text in tqdm(prompts):
        if "said, \"" not in text:
            continue

        # Skip past the 7 characters of the opening marker itself.
        quote_start = text.index("said, \"") + 7
        quote_end = text.find(".\"", quote_start)

        # Exclude long prompts - model performance degrades
        if quote_end == -1 or quote_end >= 1600:
            continue

        subprompt = text[:quote_end + 2]
        final_token_id = model.to_tokens(subprompt)[0, -1].item()
        final_token = model.to_single_str_token(final_token_id)
        if subprompt.endswith(".\"") and final_token == ".\"":
            selected.append(subprompt)

    print(len(selected), "test prompts")
    return selected

In [4]:
from typing import Callable, Any
def get_cached_or_build(path: str, build: Callable[[], pd.DataFrame]) -> pd.DataFrame:
    """Return the DataFrame cached at `path`, building and caching it if absent.

    Args:
        path: Location of the CSV cache file.
        build: Zero-argument callable producing the DataFrame. It is only
            invoked when no file exists at `path`.

    Returns:
        The DataFrame read from the cache, or the freshly built one.
    """
    if os.path.isfile(path):
        # `to_csv` below stores the index as the first column. Reading it back
        # as the index keeps a cache hit equivalent to a fresh build instead of
        # surfacing the index as an extra "Unnamed: 0" data column.
        return pd.read_csv(path, index_col=0)
    df = build()
    df.to_csv(path)
    return df

In [15]:
def find_and_visualize_quotation_directions():
    """Rank encoder directions by ablation loss increase on quotation prompts.

    Works on the module-level `encoders`, `prompts` and `model`:

    1. For every encoder, fetch activations over all prompts via
       `get_activations` and keep them on the CPU, keyed by run name.
    2. Derive the quotation test prompts.
    3. For every encoder, obtain the direction-ablation table (from the CSV
       cache when present, otherwise computed and sorted by "Loss increase"
       in descending order), then print top examples for the first five
       entries of its "Direction" column.

    Returns:
        Tuple `(max_activation_data, test_prompts, dfs)` where
        `max_activation_data` maps run name to a dict with keys
        "max_activations" and "max_activation_token_indices", and `dfs`
        holds one ablation table per encoder, in `encoders` order.
    """
    max_activation_data = {}
    for encoder, cfg in encoders:
        max_acts, max_act_token_idxs = get_activations(encoder, cfg, cfg.run_name, prompts, model)
        max_activation_data[cfg.run_name] = {
            "max_activations": max_acts.cpu(),
            "max_activation_token_indices": max_act_token_idxs.cpu(),
        }

    test_prompts = get_quotation_test_prompts(model, prompts)

    dfs = []
    for encoder, cfg in encoders:
        run_max_acts = max_activation_data[cfg.run_name]['max_activations']

        def compute_ablation_table():
            # Only runs on a cache miss, and always within the current loop
            # iteration, so `encoder`/`cfg` refer to the current run.
            activating_prompts = get_all_activating_test_prompts(
                test_prompts, encoder, model, cfg, active_threshold=0.1
            )
            table, _ = get_top_direction_ablation_df(
                activating_prompts, test_prompts, model, encoder, cfg, run_max_acts
            )
            return table.sort_values("Loss increase", ascending=False)

        cache_path = f"/workspace/data/{cfg.run_name}-direction-ablation-df.csv"
        df = get_cached_or_build(cache_path, compute_ablation_table)

        for direction in df["Direction"].tolist()[:5]:
            print(f"Direction {direction} max activating examples")
            print_top_examples(prompts, run_max_acts, direction, encoder, cfg, 5)
        dfs.append(df)

    return max_activation_data, test_prompts, dfs


# Run the analysis once and bind its three results at module level.
max_activation_data, test_prompts, dfs = find_and_visualize_quotation_directions()


# Hand-written interpretation notes, keyed by integer id
# (presumably the SAE direction index - TODO confirm).
# Adjacent string literals are used for the long note so that source
# indentation does not end up inside the text, as it does with a backslash
# continuation inside a string literal.
mae_interpretations = {
    1154: (
        "I can't see a strong relation to quotation marks. "
        "There are dry-prefix words, outfit-related words, and toy-related words."
    ),
    4776: "",
}