## Setup

In [158]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider
from transformer_lens.utils import test_prompt

import plotly.graph_objects as go

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import AutoEncoderConfig, eval_direction_tokens_global, get_acts, load_encoder, eval_ablation_token_rank, get_direction_ablation_hook, get_top_activating_examples_for_direction, evaluate_direction_ablation_single_prompt
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line, multiple_line
%reload_ext autoreload
%autoreload 2

In [150]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [151]:
# L0 '2_silvery_smoke',
# l1 '2_soft_monkey',  
# L3 '2_driven_planet'
run_names = ["18_morning_sun", "8_deep_brook"]
#run_names = ['1_skilled_universe', '47_winter_sun']
encoders = []
for run_name in run_names:
    encoder, cfg = load_encoder(run_name, model_name, model)
    cfg.run_name = run_name
    print(cfg.run_name, cfg.layer, cfg.l1_coeff)
    encoders.append((encoder, cfg))

18_morning_sun 0 0.0001
8_deep_brook 1 0.0003


In [152]:
prompts = load_tinystories_validation_prompts()

In [153]:
def get_activations(encoder, cfg, encoder_name, save_path="/workspace"):
    path = f"{save_path}/data/{encoder_name}_activations.pkl"
    if os.path.exists(path):
        with open(path, "rb") as f:
            data = pickle.load(f)
            max_activations = data["max_activations"]
            max_activation_token_indices = data["max_activation_token_indices"]
    else:
        max_activations, max_activation_token_indices = get_max_activations(prompts, model, encoder, cfg)
        with open(path, "wb") as f:
            pickle.dump({"max_activations": max_activations, "max_activation_token_indices": max_activation_token_indices}, f)
    return max_activations, max_activation_token_indices

In [154]:
max_activation_data = {}
for encoder, cfg in encoders:
    run_name = cfg.run_name
    max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name)
    max_activation_data[run_name] = {
        "max_activations": max_activations.cpu(),
        "max_activation_token_indices": max_activation_token_indices.cpu()
    }

In [155]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [156]:
first_encoder, first_encoder_cfg = encoders[0]
second_encoder, second_encoder_cfg = encoders[1]

first_encoder_max_activations = max_activation_data[first_encoder_cfg.run_name]["max_activations"]
first_encoder_max_activation_token_indices = max_activation_data[first_encoder_cfg.run_name]["max_activation_token_indices"]
second_encoder_max_activations = max_activation_data[second_encoder_cfg.run_name]["max_activations"]
second_encoder_max_activation_token_indices = max_activation_data[second_encoder_cfg.run_name]["max_activation_token_indices"]

In [157]:
second_encoder, second_encoder_cfg = encoders[0]
second_encoder_max_activations = max_activation_data[second_encoder_cfg.run_name]["max_activations"]
second_encoder_max_activation_token_indices = max_activation_data[second_encoder_cfg.run_name]["max_activation_token_indices"]

## Pairwise cosine circuit discovery

In [None]:
W_out = model.W_out[first_encoder_cfg.layer]
W_in = model.W_in[second_encoder_cfg.layer]

cosine_sims = torch.nn.functional.normalize(first_encoder.W_dec @ W_out, dim=-1) @ torch.nn.functional.normalize(W_in @ second_encoder.W_enc, dim=0)
cosine_sims = torch.tril(cosine_sims)

def i_to_row_col(i: int, n_cols: int = first_encoder.d_hidden):
    row = i // n_cols
    col = i % n_cols
    return row, col

all_sims = cosine_sims.flatten().cpu()
top_cosine_similarities, top_cosine_sim_indices = torch.topk(all_sims, 10)

In [None]:
data = []
for top_cosine_index in tqdm(top_cosine_sim_indices):
    first_encoder_dir, second_encoder_dir = i_to_row_col(top_cosine_index)
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, first_encoder_dir, first_encoder_max_activations, first_encoder_max_activation_token_indices, k=100)
    
    original_losses = []
    first_encoder_losses = []
    second_encoder_losses = []
    acts = []
    ablated_acts = []
    for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
        # Direction losses
        original_loss, first_encoder_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, first_encoder, model, first_encoder_dir, first_encoder_cfg, pos=pos)
        _, second_encoder_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, second_encoder_dir, second_encoder_cfg, pos=pos)
        original_losses.append(original_loss)
        first_encoder_losses.append(first_encoder_ablated_loss)
        second_encoder_losses.append(second_encoder_ablated_loss)

        # Second encoder direction activation with and without ablation
        act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[pos, second_encoder_dir].item()
        encoder_hook_point = f"blocks.{first_encoder_cfg.layer}.{first_encoder_cfg.act_name}"
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(first_encoder, first_encoder_dir, pos))]):
            ablated_act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[pos, second_encoder_dir].item()
        acts.append(act)
        ablated_acts.append(ablated_act)

    data.append([first_encoder_dir.item(), second_encoder_dir.item(), np.mean(original_losses), np.mean(first_encoder_losses), np.mean(second_encoder_losses), np.mean(acts), np.mean(ablated_acts)])
df = pd.DataFrame(data, columns=["Encoder 1 direction", "Encoder 2 direction", "Original loss", "Encoder 1 direction ablation loss", "Encoder 2 direction ablation loss", "Second encoder activation", "Second encoder activation after ablation"])
df["Cosine similarity"] = top_cosine_similarities.tolist()

## Prompt co occurrence analysis

In [None]:
# Pick interesting looking prompt
# Save activations of all directions for that prompt
# Save last layer active directions for each earlier direction ablated individually
# Compute AND measure for all active directions in last layer based on previous layers

In [None]:
prompt = "This moral story teaches children that"
second_encoder_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-1]
second_encoder_top_acts, second_encoder_top_dirs = torch.topk(second_encoder_acts, 10)
second_encoder_direction = second_encoder_top_dirs[0].item()
first_encoder_acts = get_acts(prompt, model, first_encoder, first_encoder_cfg)[-1]
active_first_encoder_directions = torch.argwhere(first_encoder_acts > 1).flatten().tolist()
original_second_encoder_act = second_encoder_top_acts[0].item()
#px.histogram(acts.cpu().numpy(), width=700)

In [None]:
data = []
hook_point = first_encoder_cfg.encoder_hook_point
for first_encoder_direction in active_first_encoder_directions:
    ablation_hook = get_direction_ablation_hook(first_encoder, first_encoder_direction, -1)
    with model.hooks([(hook_point, ablation_hook)]):
        ablated_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-1, second_encoder_direction].item()
    data.append([first_encoder_direction, ablated_acts])
and_df = pd.DataFrame(data, columns=["First encoder direction", "Second encoder activation after ablation"])
and_df["Activation difference"] = and_df["Second encoder activation after ablation"] - original_second_encoder_act
and_df

In [None]:
data = []
hook_point = first_encoder_cfg.encoder_hook_point
for prompt_index, prompt in enumerate(prompts[:2]):
    second_encoder_acts_all_pos = get_acts(prompt, model, second_encoder, second_encoder_cfg)
    first_encoder_acts_all_pos = get_acts(prompt, model, first_encoder, first_encoder_cfg)
    num_tokens = second_encoder_acts_all_pos.shape[0]
    for position in range(10, num_tokens):
        first_encoder_acts = first_encoder_acts_all_pos[position]
        second_encoder_acts = second_encoder_acts_all_pos[position]

        second_encoder_top_acts, second_encoder_top_dirs = torch.topk(second_encoder_acts, 10)
        second_encoder_direction = second_encoder_top_dirs[0].item()
        active_first_encoder_directions = torch.argwhere(first_encoder_acts > 1).flatten().tolist()
        original_second_encoder_act = second_encoder_top_acts[0].item()

        for first_encoder_direction in active_first_encoder_directions:
            ablation_hook = get_direction_ablation_hook(first_encoder, first_encoder_direction, position)
            with model.hooks([(hook_point, ablation_hook)]):
                ablated_acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[position, second_encoder_direction].item()
            data.append([prompt_index, position, first_encoder_direction, second_encoder_direction, ablated_acts, original_second_encoder_act])
and_df = pd.DataFrame(data, columns=["Prompt", "Position", "First encoder direction", "Second encoder direction", "Second encoder activation after ablation", "Second encoder activation"])
and_df["Activation difference"] = and_df["Second encoder activation after ablation"] - and_df["Second encoder activation"]
and_df = and_df.sort_values("Activation difference", ascending=False)

In [None]:
and_df.head(10)

In [None]:
print_top_examples(prompts, second_encoder_max_activations, second_encoder_direction, second_encoder, second_encoder_cfg, n=1)

## Get top token occurrences per direction

In [172]:
def get_direction_token_df(max_activations, prompts, model, encoder, encoder_cfg, percentage_threshold=0.5, save_path="/workspace/data/top_token_occurrences"):
    os.makedirs(save_path, exist_ok=True)
    file_name = f"{save_path}/{encoder_cfg.run_name}_direction_token_occurrences.csv"
    if os.path.exists(file_name):
        direction_df = pd.read_csv(file_name)
    else:

        token_wise_activations = eval_direction_tokens_global(max_activations, prompts, model, encoder, encoder_cfg, percentage_threshold=0.5)
        total_occurrences = token_wise_activations.sum(1)
        max_occurrences = token_wise_activations.max(1)[0]
        max_occurring_token = token_wise_activations.argmax(1)
        str_tokens = model.to_str_tokens(torch.LongTensor(list(range(model.cfg.d_vocab))))

        direction_data = []
        for direction in tqdm(range(encoder.d_hidden)):
            total_occurrence = total_occurrences[direction].item()
            top_occurrence = max_occurrences[direction].item()
            top_token = model.to_single_str_token(max_occurring_token[direction].item())
            direction_data.append([direction, total_occurrence, top_token, top_occurrence])

        direction_df = pd.DataFrame(direction_data, columns=["Direction", "Total occurrences", "Top token", "Top token occurrences"])
        direction_df["Top token percent"] = direction_df["Top token occurrences"] / direction_df["Total occurrences"]
        direction_df = direction_df.dropna()

    print(len(direction_df))
    return direction_df

direction_df = get_direction_token_df(first_encoder_max_activations, prompts, model, first_encoder, first_encoder_cfg, percentage_threshold=0.5)

100%|██████████| 21990/21990 [08:52<00:00, 41.33it/s]


  0%|          | 0/16384 [00:00<?, ?it/s]

16291


In [173]:
fig = px.histogram(direction_df, x="Top token percent", width=700, title="Per direction percentage of activations on top token")
fig.update_layout({
    "xaxis_title": "Top token activation percentage",
})
fig.show()

In [177]:
#good_directions = direction_df[(direction_df["Top token percent"] > 0.2) & (direction_df["Top token percent"] < 0.7)]["Direction"].tolist()
good_directions = direction_df[(direction_df["Top token percent"] < 0.2)]["Direction"].tolist()

print(len(good_directions))

2629


In [178]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().unique().tolist()
    max_direction_act = activations[:, direction].max().item()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        if max(direction_act) > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

def print_direction_example(direction, n=10):
    print_top_examples(prompts, first_encoder_max_activations, direction, first_encoder, first_encoder_cfg, n)

# Max activations
interact(print_direction_example, 
         direction=SelectionSlider(options=good_directions, value=good_directions[0], description='Direction'),
         #direction=IntSlider(min=0, max=l0_encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))

interactive(children=(SelectionSlider(description='Direction', options=(5, 11, 16, 20, 24, 29, 30, 33, 38, 46,…

<function __main__.print_direction_example(direction, n=10)>

## Quotation mark

In [9]:
# Capabilities
# Start quotation after "said" or ":" or other obvious tokens (check tokenization)
# End quotation after "." if started

# Model can definitely predict 'said, "' trigram
#test_prompt(test_prompts[1][:-2], " \"", model)

In [88]:
# '."' '?"' and '!"' are single tokens
# Save prompts where next token is '."'

answer_token = model.to_single_token(".\"")
test_prompts = []
for prompt in tqdm(prompts):
    if "said, \"" in prompt:
        start_index = prompt.index("said, \"") + 7
        end_index = prompt.find(".\"", start_index)
        if end_index != -1:
            subprompt = prompt[:end_index+2]
            tokens = model.to_tokens(subprompt)
            last_token = model.to_single_str_token(tokens[0, -1].item())
            if (subprompt[-2:] == ".\"") and (last_token == ".\""):
                test_prompts.append(subprompt)
print(len(test_prompts))


  0%|          | 0/21990 [00:00<?, ?it/s]

4554


In [89]:
print(model.to_str_tokens(model.to_tokens(test_prompts[0])))

['<|endoftext|>', 'Spot', '.', ' Spot', ' saw', ' the', ' shiny', ' car', ' and', ' said', ',', ' "', 'Wow', ',', ' Kitty', ',', ' your', ' car', ' is', ' so', ' bright', ' and', ' clean', '!"', ' Kitty', ' smiled', ' and', ' replied', ',', ' "', 'Thank', ' you', ',', ' Spot', '.', ' I', ' polish', ' it', ' every', ' day', '."']


In [90]:
#test_prompt(test_prompts[0][:-2], '."', model, prepend_space_to_answer=False)

In [91]:
def DLA(prompts: list[str], model: HookedTransformer, pos=-1) -> tuple[Float[Tensor, "component"], list[str]]:
    logit_attributions = []
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)
        answers = tokens[:, 1:]
        tokens = tokens[:, :-1]
        answer_residual_directions = model.tokens_to_residual_directions(answers)[0, pos]  # [batch pos d_model]
        _, cache = model.run_with_cache(tokens)
        accumulated_residual, labels = cache.get_full_resid_decomposition(layer=-1, pos_slice=pos, return_labels=True, expand_neurons=False)
        scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer = -1, pos_slice=pos).squeeze(1)
        logit_attribution = einops.einsum(scaled_residual_stack, answer_residual_directions, "component d_model, d_model -> component") # / answers.shape[0]
        logit_attributions.append(logit_attribution)
    
    logit_attributions = torch.stack(logit_attributions).mean(0)
    return logit_attributions, labels

In [92]:
# DLA of '."', find relevant MLPs if they exist, check encoders of those MLPs for relevant directions
# DLA looks pretty different for different prompts
# Relevant boosting heads:
# Relevant boosting MLP: 
# Relevant deboosting heads: 
# Relevant deboosting MLP:
dlas, labels = DLA(test_prompts, model, pos=-1)
print(dlas.shape)

  0%|          | 0/4554 [00:00<?, ?it/s]

torch.Size([37])


In [93]:
line(dlas.cpu().numpy(), xticks=labels, width=1200, title="Closing quotation DLA", show_legend=False)

In [94]:
second_encoder_cfg

AutoEncoderConfig(layer=1, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0003, d_in=4096, run_name='8_deep_brook')

In [131]:
all_acts = []
for prompt in test_prompts[:1000]:
    acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-2]
    all_acts.append(acts)

all_acts = torch.stack(all_acts)

# Max per direction
max_val, _ = second_encoder_max_activations.max(0)
threshold_per_direction = (max_val * 0.17).cuda()

# Mean activation on all prompts is misleading, prompts could be important on subset of test prompts
num_active_acts = (all_acts > threshold_per_direction).sum(0) + 1e-9
all_acts_tmp = all_acts.clone()
all_acts_tmp[all_acts_tmp <= threshold_per_direction] = 0
# Direction wise mean activation on active prompts
mean_active_acts = all_acts_tmp.sum(0) / num_active_acts
# Filter directions that are active on less than x% of quotation prompts
mean_active_acts[num_active_acts < 0.05*all_acts.shape[0]] = 0
n_non_zero_directions = (mean_active_acts > 0).sum().item()
top_acts, top_dirs = torch.topk(mean_active_acts, min(100, n_non_zero_directions))
print(len(top_acts), top_acts, top_dirs)

24 tensor([1.8794, 1.8002, 1.7505, 1.6188, 1.4893, 1.4747, 1.3057, 1.2482, 1.1837,
        1.0795, 0.9071, 0.7425, 0.7034, 0.6819, 0.6467, 0.5938, 0.5857, 0.5672,
        0.5669, 0.5662, 0.5614, 0.5149, 0.4747, 0.4295], device='cuda:0') tensor([ 2594,  3373,  8842, 14277, 11856, 14011, 15330,  1303,  8676,  7447,
         7758, 10040,  4013, 13506, 12512,  1790, 10241, 11571, 15192,  5063,
        16065,  8032, 13081, 14374], device='cuda:0')


In [115]:
fig = px.histogram(all_acts.mean(0).cpu().numpy(), width=700)
fig.update_layout({
    "showlegend": False,
    "xaxis_title": "Encoder activations",
    "title": "Encoder activations before closing quotation prompts"
})
fig.show()

In [97]:
# Filter test prompts by which prompts the direction activates on
activating_test_prompts_all_dir = torch.zeros((len(test_prompts), second_encoder.d_hidden), dtype=torch.bool)
for i, prompt in tqdm(enumerate(test_prompts), total=len(test_prompts)):
    tokens = model.to_tokens(prompt)
    act_token_index = tokens.shape[1] - 2
    act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[act_token_index]
    act_active = act > 0.1
    activating_test_prompts_all_dir[i] = act_active

  0%|          | 0/4554 [00:00<?, ?it/s]

In [132]:
# Run ablation for activating directions
data = []
for direction in tqdm(top_dirs):
    loss_increases = []
    active_test_prompt_indices = torch.argwhere(activating_test_prompts_all_dir[:, direction]).flatten().tolist()
    active_test_prompts = [test_prompts[i] for i in active_test_prompt_indices]
    num_prompts = min(len(active_test_prompts), 200)
    for prompt in active_test_prompts[:num_prompts]:
        tokens = model.to_tokens(prompt)
        pos = tokens.shape[1]-2
        original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, second_encoder, model, direction.item(), second_encoder_cfg, pos=pos)
        loss_increase = ablated_loss - original_loss
        loss_increases.append(loss_increase)
    loss_increase = np.mean(loss_increases)
    mean_activation =  mean_active_acts[direction].item()
    percentage_activation = num_active_acts[direction].item() / all_acts.shape[0]
    data.append([direction.item(), loss_increase, mean_activation, percentage_activation])
df = pd.DataFrame(data, columns=["Direction", "Loss increase", "Mean activation", "Percentage activation"])
df.sort_values("Loss increase", ascending=False).head(10)

  0%|          | 0/24 [00:00<?, ?it/s]

Unnamed: 0,Direction,Loss increase,Mean activation,Percentage activation
12,4013,-0.050905,0.703426,0.068
11,10040,-0.049067,0.742509,0.14
15,1790,-0.03682,0.593831,0.179
22,13081,-0.017396,0.474668,0.07
13,13506,-0.013377,0.681941,0.063
10,7758,-0.00961,0.907128,0.226
20,16065,-0.001208,0.561357,0.061
18,15192,0.005952,0.566868,0.072
21,8032,0.007661,0.514914,0.082
16,10241,0.00985,0.585729,0.091


In [133]:
df = df.sort_values("Loss increase", ascending=False)
top_directions = df["Direction"].tolist()
print(top_directions)
df.head(10)

[15330, 2594, 3373, 7447, 11856, 8676, 8842, 14277, 14011, 12512, 14374, 11571, 5063, 1303, 10241, 8032, 15192, 16065, 7758, 13506, 13081, 1790, 10040, 4013]


Unnamed: 0,Direction,Loss increase,Mean activation,Percentage activation
6,15330,0.137017,1.305696,0.434
0,2594,0.130624,1.879432,0.064
1,3373,0.121711,1.800152,0.393
9,7447,0.111302,1.079514,0.193
4,11856,0.100708,1.489349,0.052
8,8676,0.090075,1.183704,0.158
2,8842,0.079358,1.750508,0.087
3,14277,0.059188,1.618774,0.098
5,14011,0.05851,1.474702,0.068
14,12512,0.027182,0.646699,0.063


In [102]:
# Compute MLP remainder
# DLA of MLP compared to DLA of reconstructed MLP
# Mean activations over a bunch of test prompts
# Mean reconstructions over a bunch of test prompts
# Apply LN final layer to both
# Compute DLA of both

dlas = []
hook_name = f"blocks.{second_encoder_cfg.layer}.{second_encoder_cfg.act_name}"
for prompt in test_prompts[:200]:
    _, cache = model.run_with_cache(prompt)
    mlp_acts = cache[hook_name][0, -2]
    
    _, x_reconstruct, _, _, _ = second_encoder(mlp_acts)
    
    mlp_res = mlp_acts @ model.W_out[second_encoder_cfg.layer]
    reconstruct_res = x_reconstruct @ model.W_out[second_encoder_cfg.layer]
    remainder = (mlp_res - reconstruct_res) 

    stacked_residual = torch.stack([mlp_res, reconstruct_res, remainder])
    stacked_residual = cache.apply_ln_to_stack(stacked_residual, pos_slice=-2)

    dla = stacked_residual @ model.W_U[:, answer_token]

    dlas.append(dla)

dla = torch.stack(dlas).mean(0)
print(dla.shape)


torch.Size([3])


In [103]:
line(dla.tolist(), xticks=["MLP", "Reconstruct", "Remainder"], show_legend=False, title="Last MLP layer DLA for closing quotation prompts")

In [104]:
# Check neuron basis for directions
# Cosine sims of direction encoder weights
# Check if directions activate together or on separate examples
# Compare set of top 100 contributing neurons for each direction for overlap
# DFA for directions

In [105]:
def direction_dla(direction, max_activations, max_activation_token_indices, encoder, encoder_cfg, n=100, mean_mlp_decomp= None):
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    
    direction_weight = encoder.W_enc[:, direction]
    dlas = []
    for i in range(n):
        prompt = top_prompts[i]
        pos = top_prompt_token_indices[i]
        _, cache = model.run_with_cache(prompt)

        decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer, mlp_input=True, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
        decomposition = decomposition.squeeze(1)

        # Account for GELU in DLA by setting neuron contributions to 0 if they are not activated
        mlp_wise_decomposition = einops.einsum(decomposition, model.W_in[encoder_cfg.layer], "component d_res, d_res d_mlp -> component d_mlp")
        mlp_activations = cache[f"blocks.{encoder_cfg.layer}.mlp.hook_post"][0, pos, :]
        zeroed_neurons = torch.argwhere(mlp_activations <= 0).flatten()
        if mean_mlp_decomp is not None:
            mlp_wise_decomposition[:, zeroed_neurons] = mean_mlp_decomp[:, zeroed_neurons]
        else:
            mlp_wise_decomposition[:, zeroed_neurons] = 0

        dla = einops.einsum(mlp_wise_decomposition, direction_weight, "component d_mlp, d_mlp -> component")
        dlas.append(dla)
    dla = torch.stack(dlas).mean(0).tolist()
    return dla, labels

def get_mean_component_wise_mlp(prompts, encoder_cfg):
    mlp_wise_decompositions = []
    for prompt in prompts:
        _, cache = model.run_with_cache(prompt)

        decomposition = cache.get_full_resid_decomposition(encoder_cfg.layer, mlp_input=True, apply_ln=True, return_labels=False, expand_neurons=False, pos_slice=None)
        decomposition = decomposition.squeeze(1) # Batch
        # Account for GELU in DLA by setting neuron contributions to 0 if they are not activated
        mlp_wise_decomposition = einops.einsum(decomposition, model.W_in[encoder_cfg.layer], "component pos d_res, d_res d_mlp -> component pos d_mlp")
        mlp_wise_decomposition = mlp_wise_decomposition.mean(1)
        mlp_wise_decompositions.append(mlp_wise_decomposition)
    mlp_wise_decompositions = torch.stack(mlp_wise_decompositions).mean(0)
    return mlp_wise_decompositions

directions = [3373, 15330, 2594, 8842, 7447]
dlas = []
mean_mlp_decomp = get_mean_component_wise_mlp(prompts[:100], second_encoder_cfg)
for direction in directions:
    dla, labels = direction_dla(direction, second_encoder_max_activations, second_encoder_max_activation_token_indices, second_encoder, second_encoder_cfg, mean_mlp_decomp=mean_mlp_decomp)
    dlas.append(dla)


In [106]:
multiple_line(dlas, directions, xticks=labels)

In [107]:
prompt = active_test_prompts[0]
_, cache = model.run_with_cache(prompt)
pattern = cache["pattern", 1][:, 6]
print(pattern.shape)

torch.Size([1, 182, 182])


In [108]:
# import circuitsvis as cv
# display(cv.attention.attention_patterns(
#         attention = pattern.cpu(),
#         tokens = model.to_str_tokens(prompt),
#         attention_head_names = ["L1H6"],
#     ))

## Ablate set of important L1 directions

In [134]:
#directions = [3373, 15330, 2594, 8842, 7447]
directions = top_directions
print(len(directions), directions)

24 [15330, 2594, 3373, 7447, 11856, 8676, 8842, 14277, 14011, 12512, 14374, 11571, 5063, 1303, 10241, 8032, 15192, 16065, 7758, 13506, 13081, 1790, 10040, 4013]


In [135]:
# On positions which close quotation
loss_increases = []
for prompt in test_prompts[:200]:
    tokens = model.to_tokens(prompt)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, second_encoder, model, directions, second_encoder_cfg, pos=pos)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.34759918388910593 0.37310643415246686


In [136]:
answer_token = ".\""
n = 200
data = []
for test_prompt_index, prompt in tqdm(enumerate(test_prompts[:n]), total=n):
    answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank = eval_ablation_token_rank(prompt, second_encoder, model, directions, second_encoder_cfg, answer_token, pos=-2)
    data.append([test_prompt_index, answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank])
ablation_df = pd.DataFrame(data, columns=["Prompt index", "Answer logprob", "Ablated answer logprob", "Answer logit", "Ablated answer logit", "Answer rank", "Ablated answer rank"])
# Calculate differences between original and ablated measures
ablation_df['Logprob Difference'] = ablation_df['Answer logprob'] - ablation_df['Ablated answer logprob']
ablation_df['Logit Difference'] = ablation_df['Answer logit'] - ablation_df['Ablated answer logit']
ablation_df['Rank Difference'] = ablation_df['Answer rank'] - ablation_df['Ablated answer rank']

  0%|          | 0/200 [00:00<?, ?it/s]

In [143]:
def plot_ablation_losses(ablation_df):
    # Calculate means, standard deviations, and standard errors for differences
    means_diff = ablation_df[['Logprob Difference', 'Logit Difference', 'Rank Difference']].mean()
    stds_diff = ablation_df[['Logprob Difference', 'Logit Difference', 'Rank Difference']].std()

    # Standard Errors
    stderrs_diff = stds_diff / np.sqrt(n)

    # 95% Confidence Intervals for Differences
    confidence_interval_diff = 1.96 * stderrs_diff

    # Lower and upper bounds for differences
    lower_bounds_diff = means_diff - confidence_interval_diff
    upper_bounds_diff = means_diff + confidence_interval_diff

    # Create a bar plot for differences
    fig = go.Figure()
    metrics = ['Logprob Difference', 'Logit Difference', 'Rank Difference']

    for metric in metrics:
        fig.add_trace(go.Bar(
            x=[metric],
            y=[means_diff[metric]],
            error_y=dict(
                type='data',  # Represents actual data points
                array=[upper_bounds_diff[metric] - means_diff[metric]],
                arrayminus=[means_diff[metric] - lower_bounds_diff[metric]]
            ),
            name=metric
        ))

    fig.update_layout(
        title="Original - Ablated Measures",
        xaxis_title="Metric",
        yaxis_title="Difference",
        barmode='group',
        width=700
    )

    fig.show()


In [145]:
plot_ablation_losses(ablation_df)

In [138]:
# Baseline check: mean ablate MLP1
hook_name = f"blocks.{second_encoder_cfg.layer}.{second_encoder_cfg.act_name}"
mean_activation = []
for prompt in prompts[:10]:
    _, cache = model.run_with_cache(prompt)
    activations = cache[hook_name][0].mean(0)
    mean_activation.append(activations)
mean_activation = torch.stack(mean_activation).mean(0)

torch.Size([4096])


In [139]:
def mean_ablation_hook(value, hook):
    value[0, -2] = mean_activation 
    return value

In [148]:
loss_increases = []
answer_token_index = model.to_single_token(answer_token)
pos = -2
data = []
for prompt in test_prompts[:100]:
    logits = model(prompt, return_type="logits")[0, pos]
    answer_rank = (logits > logits[answer_token_index]).sum().item()
    answer_logprob = logits.log_softmax(dim=-1)[answer_token_index].item()
    answer_logit = logits[answer_token_index].item()
    with model.hooks(fwd_hooks=[(hook_name, mean_ablation_hook)]):
        ablated_logits = model(prompt, return_type="logits")[0, pos]
        ablated_answer_rank = (ablated_logits > ablated_logits[answer_token_index]).sum().item()
        ablated_answer_logprob = ablated_logits.log_softmax(dim=-1)[answer_token_index].item()
        ablated_answer_logit = ablated_logits[answer_token_index].item()
    data.append([test_prompt_index, answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank])
ablation_df = pd.DataFrame(data, columns=["Prompt index", "Answer logprob", "Ablated answer logprob", "Answer logit", "Ablated answer logit", "Answer rank", "Ablated answer rank"])
# Calculate differences between original and ablated measures
ablation_df['Logprob Difference'] = ablation_df['Answer logprob'] - ablation_df['Ablated answer logprob']
ablation_df['Logit Difference'] = ablation_df['Answer logit'] - ablation_df['Ablated answer logit']
ablation_df['Rank Difference'] = ablation_df['Answer rank'] - ablation_df['Ablated answer rank']
plot_ablation_losses(ablation_df)

## Direction 2236

In [206]:
direction =8842
#print_top_examples(prompts, second_encoder_max_activations, direction, second_encoder, second_encoder_cfg, n=10)

In [207]:
# Filter test prompts by which prompts the direction activates on
activating_test_prompt_indices = activating_test_prompts_all_dir[:, direction].nonzero().flatten().tolist()
activating_test_prompts = [test_prompts[i] for i in activating_test_prompt_indices]
print(len(activating_test_prompts))

1263


In [208]:
# Activates on many different tokens within quotation marks    
# Higher activations seem to correspond to possible endings of quotation marks
#print_top_examples(prompts, second_encoder_max_activations, direction, second_encoder, second_encoder_cfg, n=5)

In [209]:
def get_common_tinystories_tokens(prompts, model: HookedTransformer, min_occurrences=100):
    occurrences = torch.zeros(model.cfg.d_vocab, dtype=torch.int32).cuda()
    for prompt in prompts: 
        tokens = model.to_tokens(prompt).flatten()
        occurrences = occurrences.index_add(0, tokens, torch.ones_like(tokens, dtype=torch.int32))
    common_tokens = torch.argwhere(occurrences > min_occurrences).flatten()
    rare_tokens = torch.argwhere(occurrences <= min_occurrences).flatten()
    return occurrences, common_tokens, rare_tokens

occurrences, common_tokens, rare_tokens = get_common_tinystories_tokens(prompts, model)
print(len(common_tokens), len(rare_tokens))

2552 47705


In [210]:
quotation_tokens = ["\"", ".", "!", "?", ".\"", "!\"", "?\""]
for quotation_token in quotation_tokens:
    assert model.to_single_token(quotation_token) in common_tokens


print(answer_token, model.to_single_str_token(answer_token))

526 ."


In [211]:
act_tmp = get_acts(activating_test_prompts[0], model, second_encoder, second_encoder_cfg)
print(act_tmp.shape)
print(act_tmp[-10:, direction])

torch.Size([41, 16384])
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7884, 0.0000, 0.6928,
        0.0000], device='cuda:0')


In [212]:
# Prompt based token boosts
def get_direction_logit_and_logprob_boost(
    prompts: list[str],
    encoder: AutoEncoder,
    encoder_neuron,
    model: HookedTransformer,
    all_ignore: Int[Tensor, "tokens"],
    cfg: AutoEncoderConfig,
    pos: -2
):
    zero_direction_hook = [(f"blocks.{cfg.layer}.{cfg.act_name}", get_direction_ablation_hook(
        encoder, encoder_neuron, pos
    ))]

    logprobs_active = []
    logprobs_inactive = []
    for prompt in prompts:
        logits_active = model(prompt, return_type="logits")[0, pos]
        with model.hooks(zero_direction_hook):
            logits_inactive = model(prompt, return_type="logits")[0, pos]

        logprobs_active_current = logits_active.log_softmax(dim=-1)
        logprobs_inactive_current = logits_inactive.log_softmax(dim=-1)
        logprobs_active.append(logprobs_active_current)
        logprobs_inactive.append(logprobs_inactive_current)
    logprobs_active = torch.stack(logprobs_active).mean(0)
    logprobs_inactive = torch.stack(logprobs_inactive).mean(0)
    print(logprobs_active[answer_token], logprobs_inactive[answer_token])

    boosts = (logprobs_active - logprobs_inactive)
    boosts[logprobs_active < -9] = 0
    boosts[all_ignore] = 0
    top_boosts, top_tokens = torch.topk(boosts, 15)
    non_zero_boosts = top_boosts != 0
    top_deboosts, top_deboosted_tokens = torch.topk(boosts, 15, largest=False)
    non_zero_deboosts = top_deboosts != 0
    boosted_tokens = (
        model.to_str_tokens(top_tokens[non_zero_boosts]),
        top_boosts[non_zero_boosts].tolist(),
    )
    deboosted_tokens = (
        model.to_str_tokens(top_deboosted_tokens[non_zero_deboosts]),
        top_deboosts[non_zero_deboosts].tolist(),
    )
    logging.info(f"Top boosted: {boosted_tokens}")
    logging.info(f"Top deboosted: {deboosted_tokens}")

get_direction_logit_and_logprob_boost(activating_test_prompts[:100], second_encoder, direction, model, rare_tokens, second_encoder_cfg, pos=-2)

(INFO) 10:24:23: Top boosted: (['!"', '".', '."'], [0.17473149299621582, 0.10397529602050781, 0.07069569826126099])
(INFO) 10:24:23: Top deboosted: (['!', '.', ' with', ' and', ' in', ' to', ',', ' on', ',"', ' so', ' now', ' when', ' too', ' for', ' if'], [-0.2674741744995117, -0.24495577812194824, -0.19048261642456055, -0.18987417221069336, -0.18756914138793945, -0.18593645095825195, -0.18051671981811523, -0.18021774291992188, -0.12707233428955078, -0.12655258178710938, -0.11950302124023438, -0.11670160293579102, -0.10678291320800781, -0.06883668899536133, -0.03220367431640625])


tensor(-0.6755, device='cuda:0') tensor(-0.7462, device='cuda:0')


In [214]:
# Weight based token boosts

def get_direction_boosted_tokens(direction, encoder: AutoEncoder, model: HookedTransformer, cfg: AutoEncoderConfig, rare_tokens: Tensor):
    token_boosts = encoder.W_dec[direction] @ model.W_out[cfg.layer] @ model.unembed.W_U
    token_boosts[rare_tokens] = 0
    return token_boosts

def print_token_boosts(boosts, tokens):
    str_tokens = model.to_str_tokens(tokens)
    boost_str = ""
    for token, boost in zip(str_tokens, boosts.tolist()):
        boost_str += f"('{token}': {boost:.2f}), "
    print(boost_str[:-2])

boosts = get_direction_boosted_tokens(direction, second_encoder, model, second_encoder_cfg, rare_tokens)
top_boosts, top_tokens = torch.topk(boosts, 25)
print_token_boosts(top_boosts, top_tokens)
top_boosts, top_tokens = torch.topk(boosts, 25, largest=False)
print_token_boosts(top_boosts, top_tokens)

('!"': 0.61), ('".': 0.52), ('!".': 0.51), ('."': 0.50), (' then': 0.38), (' if': 0.34), ('"': 0.33), (' right': 0.32), (' today': 0.31), (' for': 0.30), ('?".': 0.29), (' while': 0.29), (' yourself': 0.29), (' knowing': 0.28), (' quick': 0.28), (' tomorrow': 0.28), (' here': 0.27), ('?"': 0.27), (' no': 0.26), (' too': 0.25), (' when': 0.24), (' home': 0.24), (' now': 0.24), (' ok': 0.24), (' sometimes': 0.23)
(' their': -0.35), (' couldn': -0.35), (' shared': -0.34), (' had': -0.34), (' wasn': -0.31), (' was': -0.30), (' could': -0.30), (' played': -0.28), (' tasted': -0.28), (' would': -0.27), (' grew': -0.27), (' were': -0.26), (' offered': -0.26), (' bought': -0.26), (' Their': -0.25), (''d': -0.25), (' drank': -0.25), (' stayed': -0.25), (' sprayed': -0.25), (' worked': -0.25), (' danced': -0.24), (' took': -0.24), (' stood': -0.24), (' ate': -0.24), (' wore': -0.24)


In [215]:
prompt = test_prompts[0]
print(prompt)
test_prompt('Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day', '.\"', model, prepend_space_to_answer=False)

Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."
Tokenized prompt: ['<|endoftext|>', 'Spot', '.', ' Spot', ' saw', ' the', ' shiny', ' car', ' and', ' said', ',', ' "', 'Wow', ',', ' Kitty', ',', ' your', ' car', ' is', ' so', ' bright', ' and', ' clean', '!"', ' Kitty', ' smiled', ' and', ' replied', ',', ' "', 'Thank', ' you', ',', ' Spot', '.', ' I', ' polish', ' it', ' every', ' day']
Tokenized answer: ['."']


Top 0th token. Logit: 25.26 Prob: 46.09% Token: | to|
Top 1th token. Logit: 24.84 Prob: 30.33% Token: |."|
Top 2th token. Logit: 22.82 Prob:  4.03% Token: |,|
Top 3th token. Logit: 22.62 Prob:  3.30% Token: | so|
Top 4th token. Logit: 22.54 Prob:  3.05% Token: |.|
Top 5th token. Logit: 22.49 Prob:  2.88% Token: | and|
Top 6th token. Logit: 22.45 Prob:  2.77% Token: | with|
Top 7th token. Logit: 22.26 Prob:  2.31% Token: |!"|
Top 8th token. Logit: 21.80 Prob:  1.45% Token: | because|
Top 9th token. Logit: 21.30 Prob:  0.88% Token: | for|


In [216]:
# Does ablating the feature increase loss when the feature is active?

# On positions which close quotation
loss_increases = []
for prompt in activating_test_prompts[:200]:
    tokens = model.to_tokens(prompt)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, second_encoder, model, direction, second_encoder_cfg, pos=pos)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.0793583412701264 0.1212017695966228


In [47]:
# In general on max activating prompts, ablating globally
loss_increases = []
max_activating_prompts, max_activating_token_indices = get_top_activating_examples_for_direction(prompts, direction, second_encoder_max_activations, second_encoder_max_activation_token_indices, k=100)
for prompt, index in zip(max_activating_prompts, max_activating_token_indices.tolist()):
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, direction, second_encoder_cfg)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.006354030966758728 0.00835937654341787


In [48]:
# In general on max activating prompts, ablating one activating position at a time
loss_increases = []
max_activating_prompts, max_activating_token_indices = get_top_activating_examples_for_direction(prompts, direction, second_encoder_max_activations, second_encoder_max_activation_token_indices, k=100)
threshold = second_encoder_max_activations[:, direction].max() * 0.1

for prompt, index in tqdm(zip(max_activating_prompts, max_activating_token_indices.tolist()), total=len(max_activating_prompts)):
    acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[:, direction]
    # No loss for last position, exclude
    active_positions = torch.argwhere(acts[:-1] > threshold).flatten().tolist()
    for position in active_positions:
        original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, direction, second_encoder_cfg, pos=position)
        loss_increase = ablated_loss - original_loss
        loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

  0%|          | 0/100 [00:00<?, ?it/s]

0.10963073582908171 0.4840503690021192


## Direction 15777 (second highest quotation DLA)

In [34]:
direction = 15777

In [35]:
print_top_examples(prompts, second_encoder_max_activations, direction, second_encoder, second_encoder_cfg, n=5)

In [36]:
get_direction_logit_and_logprob_boost(test_prompts[3], second_encoder, direction, model, rare_tokens, second_encoder_cfg, pos=-2)

(INFO) 08:04:06: Top boosted: ([], [])
(INFO) 08:04:06: Top deboosted: ([], [])


tensor(-0.2608, device='cuda:0') tensor(-0.2608, device='cuda:0')


## Check for directions with positive DLA


In [22]:
activation_scaled_direction_dla = (all_acts.unsqueeze(1) * second_encoder.W_dec) @ model.W_out[second_encoder_cfg.layer] @ model.unembed.W_U[:, answer_token]
print(activation_scaled_direction_dla.shape)
top_acts, top_dirs = torch.topk(activation_scaled_direction_dla, 10)
print(top_acts, top_dirs)

torch.Size([24576])
tensor([0.0782, 0.0675, 0.0424, 0.0401, 0.0372, 0.0339, 0.0335, 0.0334, 0.0331,
        0.0320], device='cuda:0') tensor([11440, 14734,  7789,  2484, 18814, 18261, 18186,  6987,  2852, 24056],
       device='cuda:0')


In [23]:
top_acts.mean()

tensor(0.0431, device='cuda:0')

In [24]:
direction_dla = second_encoder.W_dec @ model.W_out[second_encoder_cfg.layer] @ model.unembed.W_U[:, answer_token]
print(direction_dla.shape)

torch.Size([24576])


In [25]:
active_directions = (all_acts > 0.05).cpu()
print(active_directions.sum())

fig = go.Figure(data=go.Scatter(x=direction_dla.cpu()[active_directions], y=all_acts.cpu()[active_directions], mode='markers'))
fig.update_layout(
    title="Gradients of decoder features in Layer 0 wrt a Layer 1 feature against cosine similarities of same",
    xaxis_title='Direction \'.\"\' DLA',
    yaxis_title='Direction activation',
    width=900
)
fig.show()

tensor(237)


In [26]:
active_directions = (all_acts > 0.05).cpu()
print(active_directions.sum())

fig = go.Figure(data=go.Scatter(x=activation_scaled_direction_dla.cpu()[active_directions], y=all_acts.cpu()[active_directions], mode='markers'))
fig.update_layout(
    title="DLA and activation of active Layer 1 encoder directions",
    xaxis_title='Direction \'.\"\' DLA',
    yaxis_title='Direction activation',
    width=900
)
fig.show()

tensor(237)
