## Setup

In [160]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider

# Use the "notebook_connected" plotly renderer for every figure in this notebook.
pio.renderers.default = "notebook_connected"

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Gradient tracking is switched off globally for the whole kernel session.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(
    level=logging.INFO,
    format='(%(levelname)s) %(asctime)s: %(message)s',
    datefmt='%I:%M:%S',
)
# Add the parent directory to the system path so the project-local imports below resolve.
sys.path.append('../')

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, generate_with_encoder
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line
from utils.circuit_discovery_utils import *

%reload_ext autoreload
%autoreload 2

In [2]:
# Name passed to HookedTransformer.from_pretrained and, later, to load_encoder.
model_name = "tiny-stories-33M-2L"

# Load the pretrained model onto the selected device with the weight-processing
# flags below enabled.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)
# Presumably makes per-head attention results available to hooks — see the
# TransformerLens docs for `set_use_attn_result`.
model.set_use_attn_result(True)

(…)TinyStories-33M/resolve/main/config.json:   0%|          | 0.00/968 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/291M [00:00<?, ?B/s]

(…)s-33M/resolve/main/tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

(…)/TinyStories-33M/resolve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

(…)/TinyStories-33M/resolve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

(…)yStories-33M/resolve/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

(…)33M/resolve/main/special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-33M into HookedTransformer


In [38]:
# Other trained runs, kept for reference. The L<n> prefix presumably names the
# hook layer of the run — TODO confirm against the training configs.
# L0 '2_silvery_smoke',
# l1 '2_soft_monkey',  
# L3 '2_driven_planet'
run_names = ['1_skilled_universe', '2_driven_planet']
# `encoders` holds (encoder, cfg) pairs in the same order as `run_names`;
# later cells treat index 0 as the "first" encoder and index 1 as the "second".
encoders = []
for run_name in run_names:
    encoder, cfg = load_encoder(run_name, model_name, model)
    # Attach the run name to the config so later cells can key data by it.
    cfg.run_name = run_name
    print(cfg.run_name, cfg.layer, cfg.l1_coeff)
    encoders.append((encoder, cfg))

1_skilled_universe 2 0.0001
2_driven_planet 3 0.0001


In [8]:
# Validation prompts. `prompts` is read as a notebook-level global by
# get_activations and by the analysis cells further down.
prompts = load_tinystories_validation_prompts()

(INFO) 08:11:44: Loaded 21990 TinyStories validation prompts


In [16]:
def get_activations(encoder, cfg, encoder_name, save_path="/workspace"):
    """Return the max-activation data for one encoder, using an on-disk pickle cache.

    The cache lives at ``{save_path}/data/{encoder_name}_activations.pkl``. When it
    exists it is loaded; otherwise the data is computed with ``get_max_activations``
    over the notebook-level ``prompts`` and ``model`` globals and then written there.

    Args:
        encoder: Autoencoder forwarded to ``get_max_activations``.
        cfg: Config of ``encoder``, forwarded to ``get_max_activations``.
        encoder_name: Name used to build the cache file name.
        save_path: Root directory under which the ``data`` folder is placed.

    Returns:
        Tuple ``(max_activations, max_activation_token_indices)``.
    """
    path = f"{save_path}/data/{encoder_name}_activations.pkl"
    if os.path.exists(path):
        # NOTE: pickle.load can execute arbitrary code; only point save_path at
        # directories whose contents you trust.
        with open(path, "rb") as f:
            data = pickle.load(f)
        return data["max_activations"], data["max_activation_token_indices"]

    max_activations, max_activation_token_indices = get_max_activations(prompts, model, encoder, cfg)
    # Create the cache directory first: without it the write below raises
    # FileNotFoundError and the expensive computation above is lost.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump({"max_activations": max_activations, "max_activation_token_indices": max_activation_token_indices}, f)
    return max_activations, max_activation_token_indices

In [39]:
# Maps run name -> {"max_activations", "max_activation_token_indices"}, with both
# tensors moved to the CPU.
max_activation_data = {}
for encoder, cfg in encoders:
    run_name = cfg.run_name
    # Loaded from the pickle cache when present, otherwise computed over `prompts`.
    max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name)
    max_activation_data[run_name] = {
        "max_activations": max_activations.cpu(),
        "max_activation_token_indices": max_activation_token_indices.cpu()
    }

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 24576 out of 24576


## Pairwise cosine circuit discovery

In [41]:
# "First" / "second" follow the order of `run_names`: index 0 and index 1 of `encoders`.
first_encoder, first_encoder_cfg = encoders[0]
second_encoder, second_encoder_cfg = encoders[1]

# Pull each encoder's cached max-activation tensors out of `max_activation_data`
# under short, explicit names for the cells below.
first_encoder_max_activations = max_activation_data[first_encoder_cfg.run_name]["max_activations"]
first_encoder_max_activation_token_indices = max_activation_data[first_encoder_cfg.run_name]["max_activation_token_indices"]
second_encoder_max_activations = max_activation_data[second_encoder_cfg.run_name]["max_activations"]
second_encoder_max_activation_token_indices = max_activation_data[second_encoder_cfg.run_name]["max_activation_token_indices"]

In [44]:
# Model weights that connect the first encoder's layer to the second encoder's layer.
W_out = model.W_out[first_encoder_cfg.layer]
W_in = model.W_in[second_encoder_cfg.layer]

# Cosine similarity between every first-encoder decoder direction (pushed through
# W_out, normalised per row) and every second-encoder encoder direction (pulled
# back through W_in, normalised per column).
# Rows index first-encoder directions, columns index second-encoder directions
# (assuming W_dec is (d_hidden, d_mlp) and W_enc is (d_mlp, d_hidden) — TODO confirm).
cosine_sims = torch.nn.functional.normalize(first_encoder.W_dec @ W_out, dim=-1) @ torch.nn.functional.normalize(W_in @ second_encoder.W_enc, dim=0)
# NOTE(review): this matrix pairs two different encoders and is not symmetric, so
# tril() zeroes every pair whose column index exceeds its row index rather than
# removing duplicates — confirm this is intended.
cosine_sims = torch.tril(cosine_sims)

def i_to_row_col(i: int, n_cols: int = cosine_sims.shape[-1]):
    """Convert a flat index into the flattened `cosine_sims` into (row, col).

    Args:
        i: Index into ``cosine_sims.flatten()``.
        n_cols: Number of columns of the flattened matrix. Defaults to the column
            count of ``cosine_sims`` (the second encoder's width), which is the
            stride of the row-major flatten. The previous default, the first
            encoder's width, is only correct when both encoders have equal widths.

    Returns:
        Tuple ``(row, col)``: first-encoder direction, second-encoder direction.
    """
    row = i // n_cols
    col = i % n_cols
    return row, col

# Ten largest similarities over all (row, col) pairs, as flat indices.
all_sims = cosine_sims.flatten().cpu()
top_cosine_similarities, top_cosine_sim_indices = torch.topk(all_sims, 10)

In [45]:
# For each of the top cosine-similarity pairs, measure on the first-encoder
# direction's top-activating prompts:
#   * the loss with no ablation, with the first-encoder direction ablated, and
#     with the second-encoder direction ablated;
#   * the second-encoder direction's activation with and without ablating the
#     first-encoder direction.
data = []
# The hook point depends only on the first encoder's config, so build it once
# instead of once per prompt.
encoder_hook_point = f"blocks.{first_encoder_cfg.layer}.{first_encoder_cfg.act_name}"
for top_cosine_index in tqdm(top_cosine_sim_indices):
    first_encoder_dir, second_encoder_dir = i_to_row_col(top_cosine_index)
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, first_encoder_dir, first_encoder_max_activations, first_encoder_max_activation_token_indices, k=100)

    original_losses = []
    first_encoder_losses = []
    second_encoder_losses = []
    acts = []
    ablated_acts = []
    for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
        # Direction losses
        original_loss, first_encoder_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, first_encoder, model, first_encoder_dir, first_encoder_cfg, pos=pos)
        _, second_encoder_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, second_encoder_dir, second_encoder_cfg, pos=pos)
        original_losses.append(original_loss)
        first_encoder_losses.append(first_encoder_ablated_loss)
        second_encoder_losses.append(second_encoder_ablated_loss)

        # Second encoder direction activation with and without ablation
        act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[pos, second_encoder_dir].item()
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(first_encoder, first_encoder_dir, pos))]):
            ablated_act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[pos, second_encoder_dir].item()
        acts.append(act)
        ablated_acts.append(ablated_act)

    # One summary row per direction pair: means over that pair's top prompts.
    data.append([first_encoder_dir.item(), second_encoder_dir.item(), np.mean(original_losses), np.mean(first_encoder_losses), np.mean(second_encoder_losses), np.mean(acts), np.mean(ablated_acts)])
df = pd.DataFrame(data, columns=["Encoder 1 direction", "Encoder 2 direction", "Original loss", "Encoder 1 direction ablation loss", "Encoder 2 direction ablation loss", "Second encoder activation", "Second encoder activation after ablation"])
# Rows were appended in top-k order, so the similarities line up positionally.
df["Cosine similarity"] = top_cosine_similarities.tolist()

  0%|          | 0/10 [00:00<?, ?it/s]

## Prompt co-occurrence analysis

In [48]:
# Pick interesting looking prompt
# Save activations of all directions for that prompt
# Save last layer active directions for each earlier direction ablated individually
# Compute AND measure for all active directions in last layer based on previous layers

In [73]:
# Single hand-picked prompt; everything below looks at its final token position ([-1]).
prompt = "This moral story teaches children that"
second_encoder_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-1]
# Keep only the single strongest second-encoder direction at that position.
second_encoder_top_acts, second_encoder_top_dirs = torch.topk(second_encoder_acts, 10)
second_encoder_direction = second_encoder_top_dirs[0].item()
first_encoder_acts = get_acts(prompt, model, first_encoder, first_encoder_cfg)[-1]
# A first-encoder direction counts as "active" when its activation exceeds 1.
active_first_encoder_directions = torch.argwhere(first_encoder_acts > 1).flatten().tolist()
# Baseline (un-ablated) activation of the chosen second-encoder direction.
original_second_encoder_act = second_encoder_top_acts[0].item()
#px.histogram(acts.cpu().numpy(), width=700)

In [None]:
# Ablate each active first-encoder direction in turn and record the resulting
# activation of the chosen second-encoder direction at the last position.
# NOTE: this cell has no execution count in the saved notebook (In [None]).
data = []
hook_point = first_encoder_cfg.encoder_hook_point
for first_encoder_direction in active_first_encoder_directions:
    # Third argument is -1 here; other cells pass a token position in this slot.
    ablation_hook = get_direction_ablation_hook(first_encoder, first_encoder_direction, -1)
    with model.hooks([(hook_point, ablation_hook)]):
        # Despite the plural name, this is a single scalar activation.
        ablated_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-1, second_encoder_direction].item()
    data.append([first_encoder_direction, ablated_acts])
and_df = pd.DataFrame(data, columns=["First encoder direction", "Second encoder activation after ablation"])
# Positive difference = the activation went up after ablation.
and_df["Activation difference"] = and_df["Second encoder activation after ablation"] - original_second_encoder_act
and_df

In [75]:
# Same ablation sweep as the single-prompt version, repeated for every token
# position from index 10 onwards of the first two validation prompts.
# Note: `data` and `and_df` are reused names and overwrite any earlier values.
data = []
hook_point = first_encoder_cfg.encoder_hook_point
for prompt_index, prompt in enumerate(prompts[:2]):
    # Un-ablated activations for all positions, computed once per prompt.
    second_encoder_acts_all_pos = get_acts(prompt, model, second_encoder, second_encoder_cfg)
    first_encoder_acts_all_pos = get_acts(prompt, model, first_encoder, first_encoder_cfg)
    num_tokens = second_encoder_acts_all_pos.shape[0]
    # The first 10 positions are skipped; prompts shorter than 11 tokens contribute no rows.
    for position in range(10, num_tokens):
        first_encoder_acts = first_encoder_acts_all_pos[position]
        second_encoder_acts = second_encoder_acts_all_pos[position]

        # Track only the strongest second-encoder direction at this position.
        second_encoder_top_acts, second_encoder_top_dirs = torch.topk(second_encoder_acts, 10)
        second_encoder_direction = second_encoder_top_dirs[0].item()
        # A first-encoder direction counts as "active" when its activation exceeds 1.
        active_first_encoder_directions = torch.argwhere(first_encoder_acts > 1).flatten().tolist()
        original_second_encoder_act = second_encoder_top_acts[0].item()

        # One extra forward pass per active first-encoder direction.
        for first_encoder_direction in active_first_encoder_directions:
            ablation_hook = get_direction_ablation_hook(first_encoder, first_encoder_direction, position)
            with model.hooks([(hook_point, ablation_hook)]):
                ablated_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[position, second_encoder_direction].item()
            data.append([prompt_index, position, first_encoder_direction, second_encoder_direction, ablated_acts, original_second_encoder_act])
and_df = pd.DataFrame(data, columns=["Prompt", "Position", "First encoder direction", "Second encoder direction", "Second encoder activation after ablation", "Second encoder activation"])
# Difference is (after ablation) - (before); sorting descending puts the largest
# increases first.
# NOTE(review): if the goal is to find directions whose removal reduces the
# activation, an ascending sort would surface them — confirm intent.
and_df["Activation difference"] = and_df["Second encoder activation after ablation"] - and_df["Second encoder activation"]
and_df = and_df.sort_values("Activation difference", ascending=False)

In [76]:
# Ten rows with the largest activation difference (and_df is sorted descending on it).
and_df.head(10)

Unnamed: 0,Prompt,Position,First encoder direction,Second encoder direction,Second encoder activation after ablation,Second encoder activation,Activation difference
1671,1,274,23259,5990,7.092909,6.608549,0.48436
474,1,43,23474,18399,5.086508,4.807766,0.278742
629,1,81,17995,3915,3.132586,2.92153,0.211057
1786,1,293,9644,20364,6.042625,5.857118,0.185507
1737,1,286,1926,17551,7.392098,7.215293,0.176805
1687,1,277,1926,20909,3.525586,3.359591,0.165995
1742,1,287,1188,9700,2.273056,2.107394,0.165661
1694,1,278,9763,21514,2.460677,2.296569,0.164109
884,1,132,6288,10140,4.225554,4.063017,0.162536
1735,1,285,9674,5003,4.961776,4.802139,0.159637


In [66]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    """Render the `n` prompts with the highest `activations[:, direction]` as HTML.

    Each prompt is re-run through `get_acts` (using the notebook-level `model`)
    to get per-token activations of `direction`. Prompts whose largest per-token
    activation is not positive are skipped. `max_value` is each prompt's own maximum.

    Note: a later cell redefines this function with a different `max_value`.
    """
    ranked_prompt_ids = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_id in ranked_prompt_ids:
        text = prompts[prompt_id]
        str_tokens = model.to_str_tokens(model.to_tokens(text))
        per_token_act = get_acts(text, model, encoder, cfg)[:, direction].cpu().tolist()
        prompt_peak = max(per_token_act)
        if not prompt_peak > 0:
            continue
        haystack_utils.clean_print_strings_as_html(str_tokens, per_token_act, max_value=prompt_peak)


# Show the single top example for `second_encoder_direction`, which holds whatever
# value the most recently executed cell assigned to it.
print_top_examples(prompts, second_encoder_max_activations, second_encoder_direction, second_encoder, second_encoder_cfg, n=1)

## Get top token occurrences per direction

In [171]:
def get_direction_token_df(max_activations, prompts, model, encoder, encoder_cfg, percentage_threshold=0.5, save_path="/workspace/data/top_token_occurrences"):
    """Build a per-direction table of how concentrated activations are on one token.

    If ``{save_path}/{encoder_cfg.run_name}_direction_token_occurrences.csv`` exists
    it is read and returned. Otherwise occurrence counts are computed with
    ``eval_direction_tokens_global`` and summarised per direction.

    Args:
        max_activations: Forwarded to ``eval_direction_tokens_global``.
        prompts: Forwarded to ``eval_direction_tokens_global``.
        model: Used to turn token ids into strings (``to_single_str_token``).
        encoder: Autoencoder; ``encoder.d_hidden`` is the number of directions.
        encoder_cfg: Config of ``encoder``; ``run_name`` names the CSV file.
        percentage_threshold: Forwarded to ``eval_direction_tokens_global``
            (it was previously ignored in favour of a hard-coded 0.5).
        save_path: Directory that holds the CSV; created if missing.

    Returns:
        DataFrame with columns "Direction", "Total occurrences", "Top token",
        "Top token occurrences" and "Top token percent". Rows containing NaN
        (e.g. directions with zero total occurrences) are dropped.
    """
    os.makedirs(save_path, exist_ok=True)
    file_name = f"{save_path}/{encoder_cfg.run_name}_direction_token_occurrences.csv"
    if os.path.exists(file_name):
        # NOTE(review): nothing in this function writes `file_name`, so this branch
        # only triggers when the CSV was produced elsewhere. The file name also does
        # not encode `percentage_threshold`.
        direction_df = pd.read_csv(file_name)
    else:
        # Indexed [direction, token id], judging by the reductions over dim 1 below.
        token_wise_activations = eval_direction_tokens_global(max_activations, prompts, model, encoder, encoder_cfg, percentage_threshold=percentage_threshold)
        total_occurrences = token_wise_activations.sum(1)
        max_occurrences = token_wise_activations.max(1)[0]
        max_occurring_token = token_wise_activations.argmax(1)

        direction_data = []
        for direction in tqdm(range(encoder.d_hidden)):
            total_occurrence = total_occurrences[direction].item()
            top_occurrence = max_occurrences[direction].item()
            top_token = model.to_single_str_token(max_occurring_token[direction].item())
            direction_data.append([direction, total_occurrence, top_token, top_occurrence])

        direction_df = pd.DataFrame(direction_data, columns=["Direction", "Total occurrences", "Top token", "Top token occurrences"])
        # 0 / 0 yields NaN for directions with no occurrences; dropna removes them.
        direction_df["Top token percent"] = direction_df["Top token occurrences"] / direction_df["Total occurrences"]
        direction_df = direction_df.dropna()

    print(len(direction_df))
    return direction_df

# Token-occurrence summary for the first encoder's directions.
direction_df = get_direction_token_df(first_encoder_max_activations, prompts, model, first_encoder, first_encoder_cfg, percentage_threshold=0.5)

  0%|          | 0/21990 [00:00<?, ?it/s]

  0%|          | 0/24576 [00:00<?, ?it/s]

24439


In [172]:
# Distribution, over directions, of the share of occurrences on each direction's top token.
fig = px.histogram(
    direction_df,
    x="Top token percent",
    title="Per direction percentage of activations on top token",
    width=700,
)
fig.update_layout(xaxis_title="Top token activation percentage")
fig.show()

In [175]:
# Keep directions whose top token accounts for strictly more than 20% and strictly
# less than 70% of their occurrences.
good_directions = direction_df.loc[
    (direction_df["Top token percent"] > 0.2) & (direction_df["Top token percent"] < 0.7),
    "Direction",
].tolist()
#good_directions = direction_df[(direction_df["Top token percent"] < 0.1)]["Direction"].tolist()

print(len(good_directions))

6032


In [176]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    """Render the `n` prompts with the highest `activations[:, direction]` as HTML.

    Each prompt is re-run through `get_acts` (using the notebook-level `model`)
    to get per-token activations of `direction`. Prompts whose largest per-token
    activation is not positive are skipped. Every prompt is rendered with the same
    `max_value`: the maximum of `activations[:, direction]` over all prompts.

    Note: this redefinition replaces the earlier `print_top_examples`, which
    passed each prompt's own maximum as `max_value`.
    """
    direction_column = activations[:, direction]
    ranked_prompt_ids = direction_column.argsort(descending=True)[:n].cpu().tolist()
    shared_max_value = direction_column.max().item()
    for prompt_id in ranked_prompt_ids:
        text = prompts[prompt_id]
        str_tokens = model.to_str_tokens(model.to_tokens(text))
        per_token_act = get_acts(text, model, encoder, cfg)[:, direction].cpu().tolist()
        if not max(per_token_act) > 0:
            continue
        haystack_utils.clean_print_strings_as_html(str_tokens, per_token_act, max_value=shared_max_value)

def print_direction_example(direction, n=10):
    """Show the top `n` examples of a first-encoder `direction` (widget callback)."""
    print_top_examples(
        prompts,
        first_encoder_max_activations,
        direction,
        first_encoder,
        first_encoder_cfg,
        n=n,
    )

# Max activations
# Interactive browser over `good_directions`: the slider picks a direction and `n`
# sets how many top prompts are shown (1-20, initially 5).
# `good_directions[0]` raises IndexError when the filter above selected no directions.
_ = interact(print_direction_example, 
         direction=SelectionSlider(options=good_directions, value=good_directions[0], description='Direction'),
         #direction=IntSlider(min=0, max=l0_encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))

interactive(children=(SelectionSlider(description='Direction', options=(0, 3, 6, 15, 24, 29, 37, 53, 58, 60, 6…