## Setup

In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm import tqdm
from transformer_lens import HookedTransformer
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider
import plotly.graph_objects as go

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import (
    AutoEncoderConfig, 
    get_all_activating_test_prompts, 
    eval_direction_tokens_global, 
    get_encode_activations_hook, 
    get_activations, 
    get_acts, 
    load_encoder, 
    eval_ablation_token_rank, 
    get_direction_ablation_hook, 
    get_top_activating_examples_for_direction, 
    evaluate_direction_ablation_single_prompt,
    eval_encoder_reconstruction_single_position,
    get_top_direction_ablation_df,
    get_mean_component_wise_mlp,
    get_custom_forward_hook
)
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line, multiple_line
%reload_ext autoreload
%autoreload 2

In [2]:
autoencoder_directions = {}

In [3]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)

config.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/323M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [4]:
run_names = ["54_serene_plasma", "189_giddy_water"] 
encoders = []
for run_name in run_names:
    encoder, cfg = load_encoder(run_name, model_name, model)
    cfg.run_name = run_name
    print(cfg.run_name, cfg.layer, cfg.l1_coeff)
    encoders.append((encoder, cfg))

{'cfg_file': None, 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 'batch_size': 4096, 'buffer_mult': 128, 'seq_len': 128, 'model': 'tiny-stories-2L-33M', 'layer': 0, 'act': 'mlp.hook_post', 'expansion_factor': 4, 'seed': 47, 'lr': 0.0001, 'l1_coeff': 0.0003, 'l1_target': None, 'wd': 0.01, 'beta1': 0.9, 'beta2': 0.99, 'num_eval_prompts': 200, 'save_checkpoint_models': False, 'reg': 'l1', 'finetune_encoder': None, 'dead_direction_frequency': 0.0005, 'model_batch_size': 32, 'buffer_size': 524288, 'buffer_batches': 4096, 'num_eval_batches': 195, 'd_in': 4096, 'wandb_name': 'serene-plasma-54', 'save_name': '54_serene_plasma'}


54_serene_plasma 0 0.0003
{'cfg_file': None, 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 'batch_size': 5080, 'buffer_mult': 128, 'seq_len': 127, 'model': 'tiny-stories-2L-33M', 'layer': 1, 'act': 'mlp.hook_post', 'expansion_factor': 4, 'seed': 47, 'lr': 0.0001, 'l1_coeff': [0.0001, 0.00015], 'l1_target': None, 'wd': 0.01, 'beta1': 0.9, 'beta2': 0.99, 'num_eval_prompts': 200, 'save_checkpoint_models': False, 'reg': 'combined_hoyer_sqrt', 'finetune_encoder': None, 'dead_direction_frequency': 1e-05, 'model_batch_size': 40, 'buffer_size': 650240, 'buffer_batches': 5120, 'num_eval_batches': 157, 'd_in': 4096, 'wandb_name': 'giddy-water-189', 'save_name': '189_giddy_water'}
189_giddy_water 1 [0.0001, 0.00015]


In [5]:
prompts = load_tinystories_validation_prompts()

(INFO) 03:32:04: Loaded 21990 TinyStories validation prompts


In [6]:
max_activation_data = {}
for encoder, cfg in encoders:
    run_name = cfg.run_name
    max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name, prompts, model)
    max_activation_data[run_name] = {
        "max_activations": max_activations.cpu(),
        "max_activation_token_indices": max_activation_token_indices.cpu()
    }

In [7]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            print(f"Prompt: {prompt_index}")
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [8]:
first_encoder, first_encoder_cfg = encoders[0]
second_encoder, second_encoder_cfg = encoders[1]

first_encoder_max_activations = max_activation_data[first_encoder_cfg.run_name]["max_activations"]
first_encoder_max_activation_token_indices = max_activation_data[first_encoder_cfg.run_name]["max_activation_token_indices"]
second_encoder_max_activations = max_activation_data[second_encoder_cfg.run_name]["max_activations"]
second_encoder_max_activation_token_indices = max_activation_data[second_encoder_cfg.run_name]["max_activation_token_indices"]

In [9]:
def get_common_tinystories_tokens(prompts, model: HookedTransformer, min_occurrences=100):
    occurrences = torch.zeros(model.cfg.d_vocab, dtype=torch.int32).cuda()
    for prompt in prompts: 
        tokens = model.to_tokens(prompt).flatten()
        occurrences = occurrences.index_add(0, tokens, torch.ones_like(tokens, dtype=torch.int32))
    common_tokens = torch.argwhere(occurrences > min_occurrences).flatten()
    rare_tokens = torch.argwhere(occurrences <= min_occurrences).flatten()
    return occurrences, common_tokens, rare_tokens

occurrences, common_tokens, rare_tokens = get_common_tinystories_tokens(prompts, model)
print(len(common_tokens), len(rare_tokens))

2552 47705


In [10]:
# Filter test prompts following 'said, " [...] ."' pattern
# '."' '?"' and '!"' are single tokens

answer_token = model.to_single_token(".\"")
test_prompts = []
for prompt in tqdm(prompts):
    if "said, \"" in prompt:
        start_index = prompt.index("said, \"") + 7
        end_index = prompt.find(".\"", start_index)
        # Exclude long prompts - model performance degrades
        if (end_index != -1) and (end_index < 1600):
            subprompt = prompt[:end_index+2]
            tokens = model.to_tokens(subprompt)
            last_token = model.to_single_str_token(tokens[0, -1].item())
            if (subprompt[-2:] == ".\"") and (last_token == ".\""):
                test_prompts.append(subprompt)
print(len(test_prompts))

 95%|█████████▍| 20824/21990 [00:01<00:00, 11562.15it/s]

100%|██████████| 21990/21990 [00:01<00:00, 11298.37it/s]

4500





## Model DLA

In [11]:
def DLA(prompts: list[str], model: HookedTransformer, pos=-1) -> tuple[Float[Tensor, "component"], list[str]]:
    logit_attributions = []
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)
        answers = tokens[:, 1:]
        tokens = tokens[:, :-1]
        answer_residual_directions = model.tokens_to_residual_directions(answers)[0, pos]  # [batch pos d_model]
        _, cache = model.run_with_cache(tokens)
        accumulated_residual, labels = cache.get_full_resid_decomposition(layer=-1, pos_slice=pos, return_labels=True, expand_neurons=False)
        scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer = -1, pos_slice=pos).squeeze(1)
        logit_attribution = einops.einsum(scaled_residual_stack, answer_residual_directions, "component d_model, d_model -> component") # / answers.shape[0]
        logit_attributions.append(logit_attribution)
    
    logit_attributions = torch.stack(logit_attributions).mean(0)
    return logit_attributions, labels

In [12]:
# DLA of '."', find relevant MLPs if they exist, check encoders of those MLPs for relevant directions
# DLA looks pretty different for different prompts
# Relevant boosting heads:
# Relevant boosting MLP: 
# Relevant deboosting heads: 
# Relevant deboosting MLP:
dfas, labels = DLA(test_prompts, model, pos=-1)
print(dfas.shape)

  0%|          | 0/4500 [00:00<?, ?it/s]

100%|██████████| 4500/4500 [00:33<00:00, 135.01it/s]

torch.Size([37])





In [13]:
l1_attention_dla = sum([dla.item() for dla, label in zip(dfas, labels) if "L1H" in label])
print(l1_attention_dla)

6.00066873896867


In [14]:
line(dfas.cpu().numpy(), xticks=labels, width=1200, title="Closing quotation DLA", show_legend=False)

## Ablate L1 Directions

In [15]:
activating_test_prompts_l1 = get_all_activating_test_prompts(test_prompts, second_encoder, model, second_encoder_cfg, active_threshold=0.1)

  0%|          | 22/4500 [00:00<00:42, 104.45it/s]

100%|██████████| 4500/4500 [00:42<00:00, 107.03it/s]


In [16]:
original_loss, ablated_loss, encoded_loss = eval_encoder_reconstruction_single_position(test_prompts[200:500], second_encoder, model, second_encoder_cfg)
print(f"Original loss: {original_loss:.3f}, ablated loss: {ablated_loss:.3f}, encoded loss: {encoded_loss:.3f}")
print(f"Loss recovered: {(ablated_loss - encoded_loss)/(ablated_loss - original_loss):.4f}")

  3%|▎         | 10/300 [00:00<00:06, 46.11it/s]

100%|██████████| 300/300 [00:06<00:00, 45.50it/s]

Original loss: 0.968, ablated loss: 1.832, encoded loss: 1.045
Loss recovered: 0.9109





In [17]:
df_l1, loss_increases_l1 = get_top_direction_ablation_df(activating_test_prompts_l1, test_prompts, model, second_encoder, second_encoder_cfg, second_encoder_max_activations)
df_l1.head(10)

Number of directions with mean activation > 0: 34


100%|██████████| 34/34 [02:21<00:00,  4.17s/it]


Unnamed: 0,Direction,Loss increase,Loss increase (encoded),Mean activation,Percentage activation
0,82,0.07321,-0.012924,5.005048,0.026
1,11865,0.15087,0.167747,3.859326,0.03
2,8093,0.49626,0.477907,3.815932,0.471
3,14643,0.177023,0.144021,3.262734,0.021
4,15796,0.136471,0.166351,2.788814,0.069
5,6011,0.195737,0.392007,2.756625,0.119
6,5000,0.007704,0.106521,2.548186,0.02
7,5914,0.049325,0.032349,2.516097,0.02
8,794,0.312786,0.395805,2.493319,0.411
9,11157,0.028837,-0.086961,2.467405,0.023


In [18]:
directions = df_l1["Direction"].tolist()
# Preparing DataFrame
df = pd.DataFrame([(direction, value) for direction, losses in zip(directions, loss_increases_l1) for value in losses],
                  columns=['Direction', 'Loss Increase'])

# Creating histogram
fig = px.histogram(df, x="Loss Increase", color="Direction", barmode='overlay')

# Updating labels and title
fig.update_layout(
    xaxis_title='Loss Increase',
    yaxis_title='Count',
    title='Loss Increase Histogram by Direction'
)

fig.show()

In [19]:
df = df_l1.sort_values("Loss increase", ascending=False)
top_directions = df["Direction"].tolist()
directions = top_directions[:3]
print(f"The top 3 directions are {directions}")
df.head(10)

The top 3 directions are [8093, 794, 6011]


Unnamed: 0,Direction,Loss increase,Loss increase (encoded),Mean activation,Percentage activation
2,8093,0.49626,0.477907,3.815932,0.471
8,794,0.312786,0.395805,2.493319,0.411
5,6011,0.195737,0.392007,2.756625,0.119
10,13657,0.18129,0.260132,2.457184,0.02
3,14643,0.177023,0.144021,3.262734,0.021
18,10665,0.163724,0.235448,2.031421,0.027
1,11865,0.15087,0.167747,3.859326,0.03
4,15796,0.136471,0.166351,2.788814,0.069
17,8594,0.103299,0.143853,2.041233,0.037
32,6258,0.100347,0.19109,1.181257,0.04


## Within encoder, within decoder, and between encoder similarities

In [20]:
top_directions = df_l1.head(9)["Direction"].tolist()
top_directions_w_enc = second_encoder.W_enc[:, top_directions].cpu().numpy()
top_directions_w_dec = second_encoder.W_dec[top_directions, :].cpu().numpy()

In [21]:
# Cosine sims between top_directions_w_enc
cosine_sims = np.zeros((len(top_directions), len(top_directions)))
for i in range(len(top_directions)):
    for j in range(len(top_directions)):
        cosine_sims[i, j] = np.dot(top_directions_w_enc[i], top_directions_w_enc[j]) / (np.linalg.norm(top_directions_w_enc[i]) * np.linalg.norm(top_directions_w_enc[j]))
print(np.round(cosine_sims,2))

[[ 1.    0.09 -0.35  0.36  0.29  0.21  0.01  0.01  0.44]
 [ 0.09  1.   -0.2   0.25  0.09 -0.53 -0.38  0.49  0.29]
 [-0.35 -0.2   1.   -0.41  0.22  0.31  0.02 -0.13 -0.23]
 [ 0.36  0.25 -0.41  1.   -0.16 -0.25 -0.69  0.51 -0.01]
 [ 0.29  0.09  0.22 -0.16  1.    0.47  0.36  0.05 -0.15]
 [ 0.21 -0.53  0.31 -0.25  0.47  1.    0.57 -0.35 -0.31]
 [ 0.01 -0.38  0.02 -0.69  0.36  0.57  1.   -0.6  -0.11]
 [ 0.01  0.49 -0.13  0.51  0.05 -0.35 -0.6   1.   -0.13]
 [ 0.44  0.29 -0.23 -0.01 -0.15 -0.31 -0.11 -0.13  1.  ]]


In [22]:
# Cosine sims between top_directions_w_dec
cosine_sims = np.zeros((len(top_directions), len(top_directions)))
for i in range(len(top_directions)):
    for j in range(len(top_directions)):
        cosine_sims[i, j] = np.dot(top_directions_w_dec[i], top_directions_w_dec[j]) / (np.linalg.norm(top_directions_w_dec[i]) * np.linalg.norm(top_directions_w_dec[j]))
print(np.round(cosine_sims,2))

[[1.   0.16 0.17 0.27 0.21 0.16 0.14 0.04 0.21]
 [0.16 1.   0.13 0.02 0.15 0.16 0.01 0.05 0.27]
 [0.17 0.13 1.   0.13 0.45 0.34 0.08 0.18 0.43]
 [0.27 0.02 0.13 1.   0.26 0.12 0.06 0.05 0.16]
 [0.21 0.15 0.45 0.26 1.   0.42 0.04 0.12 0.55]
 [0.16 0.16 0.34 0.12 0.42 1.   0.03 0.08 0.53]
 [0.14 0.01 0.08 0.06 0.04 0.03 1.   0.17 0.1 ]
 [0.04 0.05 0.18 0.05 0.12 0.08 0.17 1.   0.15]
 [0.21 0.27 0.43 0.16 0.55 0.53 0.1  0.15 1.  ]]


In [23]:
autoencoder_directions[second_encoder_cfg.run_name] = {
    "df": df_l1,
    "top_directions": top_directions,
    "top_directions_w_enc": top_directions_w_enc,
    "top_directions_w_dec": top_directions_w_dec,
}

In [None]:
# df_l0, _ = get_top_direction_ablation_df(activating_test_prompts_l0, test_prompts, model, first_encoder, first_encoder_cfg, first_encoder_max_activations)
# autoencoder_directions[first_encoder_cfg.run_name] = {
#     "df": df_l0,
#     "top_directions": ,
#     "top_directions_w_enc": ,
#     "top_directions_w_dec": ,
# }

In [24]:
all_mean_directions = []
for key in autoencoder_directions.keys():
    df = autoencoder_directions[key]["df"]
    num_directions = len(df[df["Loss increase"] > 0.05])
    print(f"{key}: {num_directions} directions")
    directions = autoencoder_directions[key]["top_directions_w_dec"][:num_directions]
    mean_direction = directions.mean(0)
    all_mean_directions.append(mean_direction)
all_mean_directions = np.stack(all_mean_directions, 1)
print(all_mean_directions.shape)


189_giddy_water: 19 directions
(4096, 1)


In [25]:
# # Cosine sim between directions
# cosine_sim = np.zeros((4, 4))
# for i in range(4):
#     for j in range(4):
#         cosine_sim[i, j] = np.dot(all_mean_directions[:, i], all_mean_directions[:, j]) / (np.linalg.norm(all_mean_directions[:, i]) * np.linalg.norm(all_mean_directions[:, j]))

# cosine_sim

IndexError: index 1 is out of bounds for axis 1 with size 1

## MLP remainder term

In [26]:
# Compute MLP remainder
# DLA of MLP compared to DLA of reconstructed MLP
# Mean activations over a bunch of test prompts
# Mean reconstructions over a bunch of test prompts
# Apply LN final layer to both
# Compute DLA of both

dfas = []
hook_name = f"blocks.{second_encoder_cfg.layer}.{second_encoder_cfg.act_name}"
for prompt in test_prompts[:200]:
    _, cache = model.run_with_cache(prompt)
    mlp_acts = cache[hook_name][0, -2]
    
    _, x_reconstruct, _, _, _ = second_encoder(mlp_acts)
    
    mlp_res = mlp_acts @ model.W_out[second_encoder_cfg.layer]
    reconstruct_res = x_reconstruct @ model.W_out[second_encoder_cfg.layer]
    remainder = (mlp_res - reconstruct_res) 

    stacked_residual = torch.stack([mlp_res, reconstruct_res, remainder])
    stacked_residual = cache.apply_ln_to_stack(stacked_residual, pos_slice=-2)

    dfa = stacked_residual @ model.W_U[:, answer_token]

    dfas.append(dfa)

dfa = torch.stack(dfas).mean(0)
print(dfa.shape)


torch.Size([3])


In [27]:
line(dfa.tolist(), xticks=["MLP", "Reconstruct", "Remainder"], show_legend=False, title="Last MLP layer DLA for closing quotation prompts")

## Direction DFA 

In [28]:
# Check neuron basis for directions
# Cosine sims of direction encoder weights
# Check if directions activate together or on separate examples
# Compare set of top 100 contributing neurons for each direction for overlap
# DFA for directions

In [29]:
df["Direction"]

0        82
1     11865
2      8093
3     14643
4     15796
5      6011
6      5000
7      5914
8       794
9     11157
10    13657
11     1289
12     8689
13    10095
14     2863
15     7266
16    11471
17     8594
18    10665
19     4178
20    13082
21     6686
22     3789
23    16171
24     4344
25    10493
26     1266
27     1993
28    10949
29    11854
30     5470
31     2447
32     6258
33     9353
Name: Direction, dtype: int64

In [30]:
directions = df["Direction"][:1].tolist()
directions

[82]

In [31]:
# from utils.circuit_discovery_utils import final_token_indices

In [32]:
def direction_dfa(test_prompts, direction, activating_test_prompts, encoder, encoder_cfg, model, mean_mlp_decomp= None, pos=-2):
    activating_test_prompt_indices = torch.argwhere(activating_test_prompts[:, direction] > 0).flatten().tolist()

    direction_weight = encoder.W_enc[:, direction]
    dfas = []
    for prompt_index in activating_test_prompt_indices:
        prompt = test_prompts[prompt_index]
        _, cache = model.run_with_cache(prompt)

        decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer, mlp_input=True, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
        decomposition = decomposition.squeeze(1)

        # Account for GELU in DLA by setting neuron contributions to 0 if they are not activated
        mlp_wise_decomposition = einops.einsum(decomposition, model.W_in[encoder_cfg.layer], "component d_res, d_res d_mlp -> component d_mlp")
        mlp_activations = cache[f"blocks.{encoder_cfg.layer}.mlp.hook_post"][0, pos, :]
        zeroed_neurons = torch.argwhere(mlp_activations <= 0).flatten()
        if mean_mlp_decomp is not None:
            mlp_wise_decomposition[:, zeroed_neurons] = mean_mlp_decomp[:, zeroed_neurons]
        else:
            mlp_wise_decomposition[:, zeroed_neurons] = 0

        dfa = einops.einsum(mlp_wise_decomposition, direction_weight, "component d_mlp, d_mlp -> component")
        dfas.append(dfa)
    dfa = torch.stack(dfas).mean(0).tolist()
    return dfa, labels

In [33]:
dfas = []
mean_mlp_decomp = get_mean_component_wise_mlp(prompts[:100], model, second_encoder_cfg)
for direction in directions:
    dfa, labels = direction_dfa(test_prompts, direction, activating_test_prompts_l1, second_encoder, second_encoder_cfg, model, mean_mlp_decomp=mean_mlp_decomp)
    dfas.append(dfa)
multiple_line(dfas, directions, xticks=labels, title="DFA for top directions (mean ablation of inactive MLP neurons)", width=1000)

In [34]:
dfas = []
for direction in directions:
    dfa, labels = direction_dfa(test_prompts, direction, activating_test_prompts_l1, second_encoder, second_encoder_cfg, model, mean_mlp_decomp=None)
    dfas.append(dfa)
multiple_line(dfas, directions, xticks=labels, title="DFA for top directions (zero ablation of inactive MLP neurons)", width=1000)

In [35]:
# Iterate through components, ablate, check ablation loss and top direction activation

def zero_ablate_component_hook(value, hook):
    value[:, -2] = 0
    return value

def mean_ablate_component_hook(value, hook):
    value[:, -2] = component_activations[hook.name]
    return value


components = ['original', 'hook_embed', 'hook_pos_embed', 'blocks.0.hook_attn_out', 'blocks.0.hook_mlp_out', 'blocks.1.hook_attn_out', 'blocks.1.hook_mlp_out']
component_activations = {c:[] for c in components}

for prompt in prompts[:1000]:
    tokens = model.to_tokens(prompt, prepend_bos=False)

    _, cache = model.run_with_cache(tokens)

    for component in components[1:]:
        activation = cache[component].mean((0,1))
        component_activations[component].append(activation)

for component in components[1:]:
    component_activations[component] = torch.stack(component_activations[component]).mean(0)

In [36]:
def get_direction_ablation_df(prompts: list[str], directions: list[int]):
    data = []
    for component in tqdm(components):
        losses = []
        direction_acts = []
        for prompt in prompts:
            tokens = model.to_tokens(prompt, prepend_bos=False)
            if component == "original":
                loss = model(tokens, return_type="loss", loss_per_token=True)[0, -1].item()
                acts = get_acts(tokens, model, second_encoder, second_encoder_cfg)[-2][directions]
            else:
                with model.hooks([(component, mean_ablate_component_hook)]):
                    loss = model(tokens, return_type="loss", loss_per_token=True)[0, -1].item()
                    acts = get_acts(tokens, model, second_encoder, second_encoder_cfg)[-2][directions]
            direction_acts.append(acts)
            losses.append(loss)
        direction_acts = torch.stack(direction_acts)
        mean_direction_acts = direction_acts.mean(0).tolist()
        quantile_acts = torch.quantile(direction_acts, 0.90, dim=0).tolist()
        quantile_loss = np.quantile(losses, 0.90)
        data.append([component, np.mean(loss), quantile_loss] + mean_direction_acts + quantile_acts)
    df = pd.DataFrame(data, columns=["Component", "Loss (mean)", "Loss (90th)"] + [f"Direction {i} (mean)" for i in directions] + [f"Direction {i} (90th)" for i in directions])
    return df

In [37]:
def get_active_test_prompts(test_prompts: list[str], activating_test_prompts: Tensor, direction: int):
    assert len(test_prompts) == activating_test_prompts.shape[0]
    activating_prompt_indices = activating_test_prompts[:, direction].nonzero().flatten().tolist()
    activating_prompts = [test_prompts[i] for i in activating_prompt_indices]
    return activating_prompts

In [38]:
active_test_prompts = get_active_test_prompts(test_prompts, activating_test_prompts_l1, direction)
df_tmp = get_direction_ablation_df(active_test_prompts, [direction])
df_tmp

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:15<00:00,  2.27s/it]


Unnamed: 0,Component,Loss (mean),Loss (90th),Direction 82 (mean),Direction 82 (90th)
0,original,0.522195,1.179955,3.971275,6.959687
1,hook_embed,1.976023,3.356292,0.0,0.0
2,hook_pos_embed,0.338117,2.018333,2.350632,3.997147
3,blocks.0.hook_attn_out,3.804909,6.225432,0.0,0.0
4,blocks.0.hook_mlp_out,0.84401,2.193273,0.606841,1.305548
5,blocks.1.hook_attn_out,6.899246,8.442786,0.986717,2.090104
6,blocks.1.hook_mlp_out,0.944188,2.267761,3.971275,6.959687


In [39]:
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.ho

In [40]:
cache['blocks.1.attn.hook_result'].shape

torch.Size([1, 150, 16, 1024])

In [41]:
# Iterate over attention heads

def get_mean_ablate_attention_head_hook(layer, head):
    def ablate_hook(value, hook):
        value[:, -2, head] = cache[hook.name][:, -2, ]
        return value

In [42]:
data = []
for component in tqdm(components):
    losses = []
    direction_acts = []
    for prompt in test_prompts[:500]:
        tokens = model.to_tokens(prompt, prepend_bos=False)
        if component == "original":
            loss = model(tokens, return_type="loss", loss_per_token=True)[0, -1].item()
            acts = get_acts(tokens, model, second_encoder, second_encoder_cfg)[-2][directions]
        else:
            with model.hooks([(component, mean_ablate_component_hook)]):
                loss = model(tokens, return_type="loss", loss_per_token=True)[0, -1].item()
                acts = get_acts(tokens, model, second_encoder, second_encoder_cfg)[-2][directions]
        direction_acts.append(acts)
        losses.append(loss)
    direction_acts = torch.stack(direction_acts)
    mean_direction_acts = direction_acts.mean(0).tolist()
    quantile_acts = torch.quantile(direction_acts, 0.90, dim=0).tolist()
    quantile_loss = np.quantile(losses, 0.90)
    data.append([component, np.mean(loss), quantile_loss] + mean_direction_acts + quantile_acts)
df = pd.DataFrame(data, columns=["Component", "Loss (mean)", "Loss (90th)"] + [f"Direction {i} (mean)" for i in directions] + [f"Direction {i} (90th)" for i in directions])
df.head(10)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:54<00:00,  7.85s/it]


Unnamed: 0,Component,Loss (mean),Loss (90th),Direction 82 (mean),Direction 82 (90th)
0,original,0.044562,2.248487,0.107998,0.0
1,hook_embed,0.853199,3.388864,0.0,0.0
2,hook_pos_embed,0.448618,3.497252,0.062231,0.0
3,blocks.0.hook_attn_out,3.684969,7.946192,0.0,0.0
4,blocks.0.hook_mlp_out,0.485114,2.667844,0.014991,0.0
5,blocks.1.hook_attn_out,8.031485,10.159293,0.028858,0.0
6,blocks.1.hook_mlp_out,0.987445,3.834849,0.107998,0.0


In [43]:
px.bar(df, x="Component", y="Loss (mean)", title="Mean ablation loss of components", width=700)

In [44]:
# Zero ablate MLP0, check direction activations?
def zero_ablate_hook(value, hook):
    value[0, -2] = 0
    return value

zero_ablate_mlp0_hooks = [(first_encoder_cfg.encoder_hook_point, zero_ablate_hook)]

top_direction_activations = []
for prompt in test_prompts[550:600]:
    with model.hooks(zero_ablate_mlp0_hooks):
        acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-2][directions]
        top_direction_activations.append(acts)

top_direction_activations = torch.stack(top_direction_activations)
print(top_direction_activations.mean(0))

tensor([0.], device='cuda:0')


In [45]:
losses = []
length = [len(prompt) for prompt in test_prompts]
for prompt in test_prompts:
    loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    losses.append(loss)

In [46]:
line(losses)

In [47]:
px.scatter(x=length, y=losses)

In [48]:
# Compare MLP0 patched direction activations to original direction activations
print(df_l1["Mean activation"][:3])

0    5.005048
1    3.859326
2    3.815932
Name: Mean activation, dtype: float64


In [49]:
import circuitsvis as cv
print(len(get_active_test_prompts(test_prompts, activating_test_prompts_l1, direction)))
print(activating_test_prompts_l1.shape)
prompt = get_active_test_prompts(test_prompts, activating_test_prompts_l1, direction)[143] # 504
_, cache = model.run_with_cache(prompt)
pattern = cache["pattern", 1][:, [11, 13]].squeeze(0)
print(pattern.shape)

# display(cv.attention.attention_patterns(
#         attention = pattern.cpu(),
#         tokens = model.to_str_tokens(prompt),
#         attention_head_names = ["L1H11", "L1H13"],
#     ))

144
torch.Size([4500, 16384])
torch.Size([2, 125, 125])


In [50]:
# # Correlation between activating test prompts
# direction_activations = activating_test_prompts_l1[:, directions].to(torch.int32)
# correlation_matrix = torch.corrcoef(direction_activations.T)
# print(direction_activations.shape)
# print(correlation_matrix.triu(1))

torch.Size([4500, 1])


RuntimeError: triu: input tensor must have at least 2 dimensions

## L0 Directions

In [51]:
activating_test_prompts_l0 = get_all_activating_test_prompts(test_prompts, first_encoder, model, first_encoder_cfg, active_threshold=0.05)

  1%|          | 34/4500 [00:00<00:41, 108.07it/s]

100%|██████████| 4500/4500 [00:41<00:00, 107.51it/s]


In [52]:
original_loss, ablated_loss, encoded_loss = eval_encoder_reconstruction_single_position(test_prompts[:200], first_encoder, model, first_encoder_cfg)
print(f"Original loss: {original_loss:.3f}, ablated loss: {ablated_loss:.3f}, encoded loss: {encoded_loss:.3f}")
print(f"Loss recovered: {(ablated_loss - encoded_loss)/(ablated_loss - original_loss):.4f}")

  5%|▌         | 10/200 [00:00<00:04, 44.29it/s]

100%|██████████| 200/200 [00:04<00:00, 44.76it/s]

Original loss: 0.827, ablated loss: 2.192, encoded loss: 0.815
Loss recovered: 1.0089





In [53]:
df_l0, loss_increases = get_top_direction_ablation_df(activating_test_prompts_l0, test_prompts, model, first_encoder, first_encoder_cfg, first_encoder_max_activations)
df_l0.head(10)

Number of directions with mean activation > 0: 49


100%|██████████| 49/49 [03:30<00:00,  4.29s/it]


Unnamed: 0,Direction,Loss increase,Loss increase (encoded),Mean activation,Percentage activation
0,15185,-0.014816,0.004133,3.867517,0.026
1,12897,-0.025269,-0.025048,1.850117,0.03
2,5159,0.001342,-0.007032,1.719988,0.029
3,3176,-0.011921,-0.001843,1.719035,0.057
4,9738,-0.017886,0.00659,1.60651,0.035
5,4243,0.004172,-0.032979,1.544437,0.02
6,5469,0.013695,0.056473,1.460816,0.021
7,8812,-0.01526,-0.048391,1.238541,0.091
8,3347,-0.009763,-0.076835,1.211181,0.021
9,12362,-0.001899,0.013309,1.204899,0.02


In [54]:
l0_directions = df_l0["Direction"][:4].tolist()

In [55]:
from utils.autoencoder_utils import get_custom_forward_hook
hook = get_custom_forward_hook(first_encoder, l0_directions, 0, first_encoder_cfg, pos=-2)

loss_increases_encoded_ablation = []
for prompt in test_prompts[:200]:
    tokens = model.to_tokens(prompt, prepend_bos=False)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, encoder, model, l0_directions, cfg, pos=pos)
    with model.hooks(hook):
        encoded_ablated_loss = model(tokens, return_type="loss", loss_per_token=True)[0, -1].item()
    loss_increases_encoded_ablation.append(encoded_ablated_loss - original_loss)
print(np.mean(loss_increases_encoded_ablation))

-0.006296741990372539


In [56]:
# Ablate multiple l0 directions
loss_increases = []
for prompt in test_prompts[:200]:
    tokens = model.to_tokens(prompt)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, first_encoder, model, l0_directions, first_encoder_cfg, pos=pos)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.0011443432793021202 0.023230887545525575


In [57]:
px.histogram(loss_increases, title="Loss increase from ablation of multiple L0 directions")

## Ablate set of important L1 directions

In [58]:
#directions = [3373, 15330, 2594, 8842, 7447]
directions = df_l1["Direction"][:5].tolist()
print(len(directions), directions)

5 [82, 11865, 8093, 14643, 15796]


In [59]:
# On positions which close quotation
loss_increases = []
for prompt in test_prompts[:200]:
    tokens = model.to_tokens(prompt)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, second_encoder, model, directions, second_encoder_cfg, pos=pos)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.29853714456781744 0.5053143324271339


In [60]:
answer_token = ".\""
n = 200
data = []
for test_prompt_index, prompt in tqdm(enumerate(test_prompts[:n]), total=n):
    answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank = eval_ablation_token_rank(prompt, second_encoder, model, directions, second_encoder_cfg, answer_token, pos=-2)
    data.append([test_prompt_index, answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank])
ablation_df = pd.DataFrame(data, columns=["Prompt index", "Answer logprob", "Ablated answer logprob", "Answer logit", "Ablated answer logit", "Answer rank", "Ablated answer rank"])
# Calculate differences between original and ablated measures
ablation_df['Logprob Difference'] = ablation_df['Answer logprob'] - ablation_df['Ablated answer logprob']
ablation_df['Logit Difference'] = ablation_df['Answer logit'] - ablation_df['Ablated answer logit']
ablation_df['Rank Difference'] = ablation_df['Answer rank'] - ablation_df['Ablated answer rank']

  7%|▋         | 14/200 [00:00<00:02, 67.97it/s]

100%|██████████| 200/200 [00:02<00:00, 68.11it/s]


In [61]:
def plot_ablation_losses(ablation_df):
    # Calculate means, standard deviations, and standard errors for differences
    means_diff = ablation_df[['Logprob Difference', 'Logit Difference', 'Rank Difference']].mean()
    stds_diff = ablation_df[['Logprob Difference', 'Logit Difference', 'Rank Difference']].std()

    # Standard Errors
    stderrs_diff = stds_diff / np.sqrt(n)

    # 95% Confidence Intervals for Differences
    confidence_interval_diff = 1.96 * stderrs_diff

    # Lower and upper bounds for differences
    lower_bounds_diff = means_diff - confidence_interval_diff
    upper_bounds_diff = means_diff + confidence_interval_diff

    # Create a bar plot for differences
    fig = go.Figure()
    metrics = ['Logprob Difference', 'Logit Difference', 'Rank Difference']

    for metric in metrics:
        fig.add_trace(go.Bar(
            x=[metric],
            y=[means_diff[metric]],
            error_y=dict(
                type='data',  # Represents actual data points
                array=[upper_bounds_diff[metric] - means_diff[metric]],
                arrayminus=[means_diff[metric] - lower_bounds_diff[metric]]
            ),
            name=metric
        ))

    fig.update_layout(
        title="Original - Ablated Measures",
        xaxis_title="Metric",
        yaxis_title="Difference",
        barmode='group',
        width=700
    )

    fig.show()


In [62]:
plot_ablation_losses(ablation_df)

In [63]:
# Baseline check: mean ablate MLP1
hook_name = f"blocks.{second_encoder_cfg.layer}.{second_encoder_cfg.act_name}"
mean_activation = []
for prompt in prompts[:10]:
    _, cache = model.run_with_cache(prompt)
    activations = cache[hook_name][0].mean(0)
    mean_activation.append(activations)
mean_activation = torch.stack(mean_activation).mean(0)

In [64]:
def mean_ablation_hook(value, hook):
    value[0, -2] = mean_activation 
    return value

In [65]:
loss_increases = []
answer_token_index = model.to_single_token(answer_token)
pos = -2
data = []
for prompt in test_prompts[:100]:
    indirect_logits = model(prompt, return_type="logits")[0, pos]
    answer_rank = (indirect_logits > indirect_logits[answer_token_index]).sum().item()
    answer_logprob = indirect_logits.log_softmax(dim=-1)[answer_token_index].item()
    answer_logit = indirect_logits[answer_token_index].item()
    with model.hooks(fwd_hooks=[(hook_name, mean_ablation_hook)]):
        ablated_logits = model(prompt, return_type="logits")[0, pos]
        ablated_answer_rank = (ablated_logits > ablated_logits[answer_token_index]).sum().item()
        ablated_answer_logprob = ablated_logits.log_softmax(dim=-1)[answer_token_index].item()
        ablated_answer_logit = ablated_logits[answer_token_index].item()
    data.append([test_prompt_index, answer_logprob, ablated_answer_logprob, answer_logit, ablated_answer_logit, answer_rank, ablated_answer_rank])
ablation_df = pd.DataFrame(data, columns=["Prompt index", "Answer logprob", "Ablated answer logprob", "Answer logit", "Ablated answer logit", "Answer rank", "Ablated answer rank"])
# Calculate differences between original and ablated measures
ablation_df['Logprob Difference'] = ablation_df['Answer logprob'] - ablation_df['Ablated answer logprob']
ablation_df['Logit Difference'] = ablation_df['Answer logit'] - ablation_df['Ablated answer logit']
ablation_df['Rank Difference'] = ablation_df['Answer rank'] - ablation_df['Ablated answer rank']
plot_ablation_losses(ablation_df)

## Ablate complement of important directions subspace

In [66]:
def direction_subspace_projection(directions: Float[Tensor, "n_directions d_in"]) -> Float[Tensor, "d_in d_in"]:
    # Take transpose to make each vector a column then orthonormalize
    Q, _ = np.linalg.qr(directions.T)
    # Return the projection matrix
    return torch.tensor(Q @ Q.T)


def get_zero_ablate_complement_hooks(projection_matrix: Float[Tensor, "d_in d_in"], cfg: AutoEncoderConfig, pos=-2):
    def feature_hook(value, hook):
        value[:, pos] = value[:, pos] @ projection_matrix.cuda()
        return value
    return [(f'{cfg.encoder_hook_point}', feature_hook)]

In [67]:
# Get projection matrix of the quotation mark subspace
direction_indices = df_l1["Direction"][:5]
direction_vectors = encoder.W_dec[direction_indices].detach().cpu() # [d_subspace, d_vector_space (i.e. d_hidden)]
projection_matrix = direction_subspace_projection(direction_vectors)

In [68]:
def get_mean_ablate_complement_hooks(projection_matrix: Float[Tensor, "d_in d_in"], mean_subspace: Float[Tensor, "d_in"], cfg: AutoEncoderConfig, pos=-2):
    def feature_hook(value, hook):
        value[:, pos] = value[:, pos] @ projection_matrix.cuda() + mean_subspace
        return value
    return [(f'{cfg.encoder_hook_point}', feature_hook)]

def get_mean_direct_sum_complement(cfg: AutoEncoderConfig, prompts, projection):
    projection_complement = torch.eye(projection.shape[0]) - projection
    complement_vals = []
    for prompt in prompts[:200]:
        logits, cache = model.run_with_cache(prompt)
        complement_vals.append(cache[cfg.encoder_hook_point][0] @ projection_complement.cuda())
    return torch.cat(complement_vals, dim=0).mean(0)

mean_subspace = get_mean_direct_sum_complement(cfg, prompts, projection_matrix)


In [69]:
loss_diffs = []
for prompt in test_prompts[:200]:
    loss = model(prompt, return_type='loss', loss_per_token=True)

    with model.hooks(get_mean_ablate_complement_hooks(projection_matrix, mean_subspace, cfg)):
        ablated_loss = model(prompt, return_type='loss', loss_per_token=True)
    loss_diffs.append(loss[0, -1].item() - ablated_loss[0, -1].item())
print(np.mean(loss_diffs), np.std(loss_diffs)) # loss improvements

0.27033707085531206 0.6464786449984219


In [70]:
loss_diffs = []
for prompt in test_prompts[:200]:
    loss = model(prompt, return_type='loss', loss_per_token=True)

    with model.hooks(get_zero_ablate_complement_hooks(projection_matrix, cfg)):
        ablated_loss = model(prompt, return_type='loss', loss_per_token=True)
    loss_diffs.append(loss[0, -1].item() - ablated_loss[0, -1].item())
print(np.mean(loss_diffs), np.std(loss_diffs)) # loss improvements

0.25340040573850275 0.6567619273447303


## Loss recovered analysis

In [71]:
# Baseline check: mean ablate MLP1
hook_name = f"blocks.{second_encoder_cfg.layer}.{second_encoder_cfg.act_name}"

mean_activation = []
for prompt in prompts[:20]:
    _, cache = model.run_with_cache(prompt)
    activations = cache[hook_name][0].mean(0)
    mean_activation.append(activations)
mean_activation = torch.stack(mean_activation).mean(0)

def mean_ablation_hook(value, hook):
    value[0, -2] = mean_activation 
    return value

def zero_ablation_hook(value, hook):
    value[0, -2] = 0
    return value

def encode_activations_hook(value, hook):
    _, x_reconstruct, acts, _, _ = second_encoder(value[:, -2])
    value[:, -2] = x_reconstruct
    return value

ablate_top_directions_hook = [(hook_name, get_direction_ablation_hook(second_encoder, directions, -2))]
encode_mlp_hook = [(hook_name, encode_activations_hook)]

In [72]:
# Zero ablate MLP
# ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
# Mean ablate MLP
# Enable only top 2 directions
# Enable full autoencoder

losses = []
mean_abl_losses = []
zero_abl_losses = []
top_abl_losses = []
recons_losses = []

for prompt in tqdm(test_prompts):
    loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    with model.hooks(fwd_hooks=[(hook_name, mean_ablation_hook)]):
        mean_abl_loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    with model.hooks(fwd_hooks=[(hook_name, zero_ablation_hook)]):
        zero_abl_loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    with model.hooks(fwd_hooks=ablate_top_directions_hook):
        top_abl_loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    with model.hooks(fwd_hooks=encode_mlp_hook):
        recons_loss = model(prompt, return_type="loss", loss_per_token=True)[0, -1].item()
    losses.append(loss)
    mean_abl_losses.append(mean_abl_loss)
    zero_abl_losses.append(zero_abl_loss)
    top_abl_losses.append(top_abl_loss)
    recons_losses.append(recons_loss)


  0%|          | 6/4500 [00:00<02:42, 27.62it/s]

100%|██████████| 4500/4500 [02:49<00:00, 26.56it/s]


In [73]:
import plotly.graph_objects as go
import numpy as np

# Calculate means and standard deviations
means = [
    np.mean(losses),
    np.mean(mean_abl_losses),
    np.mean(zero_abl_losses),
    np.mean(top_abl_losses),
    np.mean(recons_losses)
]

def standard_error(data):
    return np.std(data) / np.sqrt(len(data))

std_errors = [
    standard_error(losses),
    standard_error(mean_abl_losses),
    standard_error(zero_abl_losses),
    standard_error(top_abl_losses),
    standard_error(recons_losses)
]

ci_95 = [se * 1.96 for se in std_errors]

labels = ['Original loss', 'MLP mean ablation loss', 'MLP zero ablation loss', f'Top {len(directions)} SAE directions ablation loss', 'SAE reconstruction loss']
fig = go.Figure(data=[
    go.Bar(
        x=labels,
        y=means,
        error_y=dict(type='data', array=ci_95)
    )
])

fig.update_layout(
    title="Ablation loss comparison for closing quotation prompts '.\"'",
    xaxis_title="Ablation",
    yaxis_title="Loss",
    showlegend=False,
    width=600
)

fig.show()


In [74]:
loss_recovered = (np.mean(zero_abl_losses) - np.mean(recons_losses)) / (np.mean(zero_abl_losses) - np.mean(losses))
print(loss_recovered)

0.8789925760801569


## Individual Directions

In [75]:
direction = 8093		
print(direction)

8093


In [76]:
from ipywidgets import interactive, IntSlider, SelectionSlider

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            print("Prompt:", prompt_index)
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

# def print_direction_example(direction, num_examples=3):
#     print_top_examples(prompts, second_encoder_max_activations, direction, second_encoder, second_encoder_cfg, num_examples)
# widget = interactive(print_direction_example, 
#              direction=SelectionSlider(options=directions, value=directions[0], description='Direction'),
#              num_examples=IntSlider(min=1, max=20, step=1, value=5))
# display(widget)

# 15xxx: activates within quotation when ending the sentence is good
# 8093: Activates always within quotation

#print_top_examples(prompts, second_encoder_max_activations, direction, second_encoder, second_encoder_cfg, 5)

In [77]:
# Filter test prompts by which prompts the direction activates on
activating_test_prompt_indices = activating_test_prompts_l1[:, direction].nonzero().flatten().tolist()
activating_test_prompts = [test_prompts[i] for i in activating_test_prompt_indices]
print(len(activating_test_prompts), "out of", activating_test_prompts_l1.shape[0])

2620 out of 4500


In [78]:
answer_token = model.to_single_token(".\"")
quotation_tokens = ["\"", ".", "!", "?", ".\"", "!\"", "?\""]
for quotation_token in quotation_tokens:
    assert model.to_single_token(quotation_token) in common_tokens.tolist()

print(answer_token, model.to_single_str_token(answer_token))

526 ."


In [79]:
# Logprob boosts (prompt based)
def get_direction_logit_and_logprob_boost(
    prompts: list[str],
    encoder: AutoEncoder,
    encoder_neuron,
    model: HookedTransformer,
    all_ignore: Int[Tensor, "tokens"],
    cfg: AutoEncoderConfig,
    pos: -2
):
    zero_direction_hook = [(f"blocks.{cfg.layer}.{cfg.act_name}", get_direction_ablation_hook(
        encoder, encoder_neuron, pos
    ))]

    logprobs_active = []
    logprobs_inactive = []
    for prompt in prompts:
        logits_active = model(prompt, return_type="logits")[0, pos]
        with model.hooks(zero_direction_hook):
            logits_inactive = model(prompt, return_type="logits")[0, pos]

        logprobs_active_current = logits_active.log_softmax(dim=-1)
        logprobs_inactive_current = logits_inactive.log_softmax(dim=-1)
        logprobs_active.append(logprobs_active_current)
        logprobs_inactive.append(logprobs_inactive_current)
    logprobs_active = torch.stack(logprobs_active).mean(0)
    logprobs_inactive = torch.stack(logprobs_inactive).mean(0)
    print(logprobs_active[answer_token], logprobs_inactive[answer_token])

    boosts = (logprobs_active - logprobs_inactive)
    boosts[logprobs_active < -9] = 0
    boosts[all_ignore] = 0
    top_boosts, top_tokens = torch.topk(boosts, 15)
    non_zero_boosts = top_boosts != 0
    top_deboosts, top_deboosted_tokens = torch.topk(boosts, 15, largest=False)
    non_zero_deboosts = top_deboosts != 0
    boosted_tokens = (
        model.to_str_tokens(top_tokens[non_zero_boosts]),
        top_boosts[non_zero_boosts].tolist(),
    )
    deboosted_tokens = (
        model.to_str_tokens(top_deboosted_tokens[non_zero_deboosts]),
        top_deboosts[non_zero_deboosts].tolist(),
    )
    logging.info(f"Top boosted: {boosted_tokens}")
    logging.info(f"Top deboosted: {deboosted_tokens}")

get_direction_logit_and_logprob_boost(activating_test_prompts[:100], second_encoder, direction, model, rare_tokens, second_encoder_cfg, pos=-2)

(INFO) 04:02:46: Top boosted: (['!".', '!"', '".', '."'], [1.1426973342895508, 0.9855108261108398, 0.8336892127990723, 0.3870839476585388])
(INFO) 04:02:46: Top deboosted: (['.', '!', ' on', ' with', ' for', ',"', ',', ' at', ' that', ' in', ' because', ' and', ' so', ' to', ' when'], [-0.7691488265991211, -0.664771556854248, -0.6166729927062988, -0.6153430938720703, -0.6126704216003418, -0.5856924057006836, -0.5640039443969727, -0.5352458953857422, -0.5021038055419922, -0.4979672431945801, -0.4947929382324219, -0.49103641510009766, -0.46613025665283203, -0.39168882369995117, -0.3858466148376465])


tensor(-0.8140, device='cuda:0') tensor(-1.2011, device='cuda:0')


In [80]:
# Logit boosts (weight based DLA)
def get_direction_boosted_tokens(direction, encoder: AutoEncoder, model: HookedTransformer, cfg: AutoEncoderConfig, rare_tokens: Tensor):
    token_boosts = encoder.W_dec[direction] @ model.W_out[cfg.layer] @ model.unembed.W_U
    token_boosts[rare_tokens] = 0
    return token_boosts

def print_token_boosts(boosts, tokens):
    str_tokens = model.to_str_tokens(tokens)
    boost_str = ""
    for token, boost in zip(str_tokens, boosts.tolist()):
        boost_str += f"('{token}': {boost:.2f}), "
    print(boost_str[:-2])

boosts = get_direction_boosted_tokens(direction, second_encoder, model, second_encoder_cfg, rare_tokens)
top_boosts, top_tokens = torch.topk(boosts, 25)
print_token_boosts(top_boosts, top_tokens)
top_boosts, top_tokens = torch.topk(boosts, 25, largest=False)
print_token_boosts(top_boosts, top_tokens)

('!".': 0.91), ('!"': 0.88), ('".': 0.81), ('?".': 0.65), ('."': 0.64), ('"': 0.58), (' I': 0.53), ('?"': 0.51), (' ok': 0.44), (' my': 0.43), (' you': 0.42), (' we': 0.40), (' -': 0.38), (' your': 0.38), (' here': 0.38), ('�': 0.37), (' mum': 0.36), ('heart': 0.35), (' yours': 0.35), (' please': 0.35), (' then': 0.34), (' today': 0.34), (' looks': 0.33), (' our': 0.33), (' yourself': 0.33)
(' She': -0.19), (' herself': -0.17), (' her': -0.14), (' train': -0.13), (' yacht': -0.13), ('ule': -0.13), (' peanut': -0.13), (' bubble': -0.13), (' muff': -0.13), ('ooter': -0.12), (' rode': -0.12), (' volcano': -0.12), ('erry': -0.11), (' Bun': -0.11), (' elevator': -0.11), ('affle': -0.11), (' whole': -0.11), ('agged': -0.11), (' statue': -0.11), (' porch': -0.11), (' ch': -0.10), (' jet': -0.10), (' drank': -0.10), (' He': -0.10), (' sprayed': -0.10)


## Additional ablations

In [81]:
# Does ablating the feature increase loss when the feature is active?

# On positions which close quotation
loss_increases = []
for prompt in tqdm(activating_test_prompts):
    tokens = model.to_tokens(prompt)
    pos = tokens.shape[1]-2
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(tokens, second_encoder, model, direction, second_encoder_cfg, pos=pos)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))
px.histogram(loss_increases, width=700, title=f"Loss increase for removing direction {direction} on closing quotation prompts")

  1%|          | 15/2620 [00:00<00:37, 69.56it/s]

100%|██████████| 2620/2620 [00:37<00:00, 69.48it/s]

0.6916987289375988 0.6615271032660808





In [82]:
test_prompt_activations = []
for prompt in tqdm(activating_test_prompts):
    act = get_acts(prompt, model, second_encoder, second_encoder_cfg)[-2, direction]
    test_prompt_activations.append(act.item())

  1%|          | 22/2620 [00:00<00:24, 107.89it/s]

100%|██████████| 2620/2620 [00:23<00:00, 111.69it/s]


In [83]:
# Histogram of activation x loss increases to check if high loss increase on a prompt just comes from the direction being very active
fig = px.scatter(x=test_prompt_activations, y=loss_increases, width=700, title="Loss increase vs activation for closing quotation prompts")
fig.update_layout({
    "xaxis_title": "Activation",
    "yaxis_title": "Loss increase"
})

In [84]:
## Look at high loss increase prompts
# 15796: always after question in prior sentence, usually end in object (e.g. toy, mistake, door, hive)

top_loss_increases, top_loss_increase_prompt_indices = torch.topk(torch.tensor(loss_increases), 10)

def print_examples(model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig, prompts: list[str], direction: int):
    for prompt in prompts:
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

print_examples(model, encoder, cfg, [activating_test_prompts[i] for i in top_loss_increase_prompt_indices], direction)

In [85]:
# In general on max activating prompts, ablating globally
loss_increases = []
max_activating_prompts, max_activating_token_indices = get_top_activating_examples_for_direction(prompts, direction, second_encoder_max_activations, second_encoder_max_activation_token_indices, k=100)
for prompt, index in zip(max_activating_prompts, max_activating_token_indices.tolist()):
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, direction, second_encoder_cfg)
    loss_increase = ablated_loss - original_loss
    loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

0.011457895040512084 0.012963047737776952


In [86]:
# In general on max activating prompts, ablating one activating position at a time
loss_increases = []
max_activating_prompts, max_activating_token_indices = get_top_activating_examples_for_direction(prompts, direction, second_encoder_max_activations, second_encoder_max_activation_token_indices, k=100)
threshold = second_encoder_max_activations[:, direction].max() * 0.1

for prompt, index in tqdm(zip(max_activating_prompts, max_activating_token_indices.tolist()), total=len(max_activating_prompts)):
    acts = get_acts(prompt, model, second_encoder, second_encoder_cfg)[:, direction]
    # No loss for last position, exclude
    active_positions = torch.argwhere(acts[:-1] > threshold).flatten().tolist()
    for position in active_positions:
        original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, second_encoder, model, direction, second_encoder_cfg, pos=position)
        loss_increase = ablated_loss - original_loss
        loss_increases.append(loss_increase)
print(np.mean(loss_increases), np.std(loss_increases))

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:35<00:00,  2.82it/s]

0.11860892018802953 0.6998473607009476





## Check for directions with positive DLA


In [95]:
## Look at high loss increase prompts
# 15796: always after question in prior sentence, usually end in object (e.g. toy, mistake, door, hive)
def get_all_acts(model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig, prompts: list[str], batch_size=16):
    all_acts = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]
        all_acts.append(get_acts(batch, model, encoder, cfg).flatten(0, 1))
    return torch.cat(all_acts, dim=0)

all_acts = get_all_acts(model, encoder, cfg, [activating_test_prompts[i] for i in top_loss_increase_prompt_indices], direction)

torch.Size([2100, 16384])


In [96]:
activation_scaled_direction_dla = (all_acts.mean(0).unsqueeze(1) * second_encoder.W_dec) @ model.W_out[second_encoder_cfg.layer] @ model.unembed.W_U[:, answer_token]
top_acts, top_dirs = torch.topk(activation_scaled_direction_dla, 5)
print(top_acts, top_dirs)

tensor([0.3352, 0.1124, 0.1077, 0.0771, 0.0363], device='cuda:0') tensor([  75, 8288, 6402, 8093, 7266], device='cuda:0')


In [None]:
direction_dfa = second_encoder.W_dec @ model.W_out[second_encoder_cfg.layer] @ model.unembed.W_U[:, answer_token]
active_directions = (all_acts.mean(0) > 0.05).cpu()
print(active_directions.sum())

fig = go.Figure(data=go.Scatter(x=direction_dfa.cpu()[active_directions], y=all_acts.mean(0).cpu()[active_directions], mode='markers'))
fig.update_layout(
    title="Direction DLA vs activation",
    xaxis_title='Direction \'.\"\' DLA',
    yaxis_title='Direction activation',
    width=900
)
fig.show()

tensor(154)


In [None]:
active_directions = (all_acts.mean(0) > 0.05).cpu()
print(active_directions.sum())

fig = go.Figure(data=go.Scatter(x=activation_scaled_direction_dla.cpu()[active_directions], y=all_acts.mean(0).cpu()[active_directions], mode='markers'))
fig.update_layout(
    title="DLA and activation of active Layer 1 encoder directions",
    xaxis_title='Direction \'.\"\' DLA',
    yaxis_title='Direction activation',
    width=900
)
fig.show()

tensor(154)


## Patching Paired Sentences for Logit Diff 

In [None]:
_, labels = cache.get_full_resid_decomposition(return_labels=True, expand_neurons=False)
labels


['L0H0',
 'L0H1',
 'L0H2',
 'L0H3',
 'L0H4',
 'L0H5',
 'L0H6',
 'L0H7',
 'L0H8',
 'L0H9',
 'L0H10',
 'L0H11',
 'L0H12',
 'L0H13',
 'L0H14',
 'L0H15',
 'L1H0',
 'L1H1',
 'L1H2',
 'L1H3',
 'L1H4',
 'L1H5',
 'L1H6',
 'L1H7',
 'L1H8',
 'L1H9',
 'L1H10',
 'L1H11',
 'L1H12',
 'L1H13',
 'L1H14',
 'L1H15',
 '0_mlp_out',
 '1_mlp_out',
 'embed',
 'pos_embed',
 'bias']

In [None]:
# cache
# def label_to_hook(label: str) -> str:
#     if label == "pos_embed":
#         return "hook_pos_embed"
#     elif label == "embed":
#         return "hook_embed"
#     elif label[0].isdigit():
#         return f'blocks.{label[0]}.hook_mlp_out'
#     elif label[0] == "L":


SyntaxError: invalid syntax (1360667297.py, line 9)

In [None]:
def patch_plots(pos: str, neg: str, ans_position=-2, patch_position=-2, title=""):

    assert model.to_tokens(pos).shape[-1] == model.to_tokens(neg).shape[-1]
    
    # The final word helps a lot with prediction
    neg_logits, neg_cache = model.run_with_cache(neg, prepend_bos=False)
    pos_logits, pos_cache = model.run_with_cache(pos, prepend_bos=False)
    answer = model.to_single_token(".\"")
    alt_answer = model.to_single_token("!\"")
    print(f"Pos logits: '.\"'={pos_logits[0, ans_position, answer].item():.2f}, '!\"'={pos_logits[0, ans_position, alt_answer].item():.2f}")
    print(f"Neg logits: '.\"'={neg_logits[0, ans_position, answer].item():.2f}, '!\"'= {neg_logits[0, ans_position, alt_answer].item():.2f}")
    # Patch non-quotey word acts into various components and see what works less well, and what encoder features turn off

    head_hook_labels = [f"blocks.{layer}.attn.hook_result" for layer in [0, 1]]
    head_labels = [f"L{layer}H{head}" for layer in [0, 1] for head in range(model.cfg.n_heads)]
    other_hook_labels = ["hook_embed"] + [f"blocks.{layer}.hook_mlp_out" for layer in [0, 1]]
        
    def get_patch(head_index: int | None=None):
        def ablation_hook(value, hook):
            if not head_index:
                value[:, patch_position] = neg_cache[hook.name][:, patch_position]
            else:
                value[:, patch_position, head_index] = neg_cache[hook.name][:, patch_position, head_index]
            return value
        return ablation_hook

    def get_pos(head_index: int | None=None):
        def ablation_hook(value, hook):
            if not head_index:
                value[:, patch_position] = pos_cache[hook.name][:, patch_position]
            else:
                value[:, patch_position, head_index] = pos_cache[hook.name][:, patch_position, head_index]
            return value
        return ablation_hook

    indirect_logits = []
    direct_logits = []
    ablation_logits = []
    for label in other_hook_labels:
        if label == f"blocks.{0}.hook_mlp_out":
            downstream_components = [f"blocks.{1}.hook_mlp_out", f"blocks.{1}.attn.hook_result"]
        elif label == "hook_embed" or label == "hook_pos_embed":
            downstream_components = [f"blocks.{0}.hook_mlp_out", f"blocks.{0}.attn.hook_result", f"blocks.{1}.hook_mlp_out", f"blocks.{1}.attn.hook_result"]
        else:
            downstream_components = []
        original_logits, activated_logits, ablated_logits, direct_effect_logits, indirect_effect_logits = haystack_utils.get_context_effect(
            pos, model, context_ablation_hooks=[(label, get_patch())], context_activation_hooks=[(label, get_pos())],
            downstream_components=downstream_components, pos=ans_position, return_type="logits", prepend_bos=False)
        direct_logits.append((direct_effect_logits[0, answer].item(), direct_effect_logits[0, alt_answer].item()))
        indirect_logits.append((indirect_effect_logits[0, answer].item(), indirect_effect_logits[0, alt_answer].item()))
        ablation_logits.append((ablated_logits[0, answer].item(), ablated_logits[0, alt_answer].item()))
    for label in head_hook_labels:
        if label == f"blocks.{0}.attn.hook_result":
            downstream_components = [f"blocks.{0}.hook_mlp_out", f"blocks.{1}.hook_mlp_out", f"blocks.{1}.attn.hook_result"]
        elif label == f"blocks.{1}.attn.hook_result":
            downstream_components = [f"blocks.{1}.hook_mlp_out"]

        for head in range(model.cfg.n_heads):
            original_logits, activated_logits, ablated_logits, direct_effect_logits, indirect_effect_logits = haystack_utils.get_context_effect(
            pos, model, context_ablation_hooks=[(label, get_patch(head))], context_activation_hooks=[(label, get_pos(head))],
            downstream_components=downstream_components, pos=ans_position, return_type="logits", prepend_bos=False)
            direct_logits.append((direct_effect_logits[0, answer].item(), direct_effect_logits[0, alt_answer].item()))
            indirect_logits.append((indirect_effect_logits[0, answer].item(), indirect_effect_logits[0, alt_answer].item()))
            ablation_logits.append((ablated_logits[0, answer].item(), ablated_logits[0, alt_answer].item()))


    diffs_ablation = [np.mean((pos_logits[0, ans_position, answer].item() - patched, pos_logits[0, ans_position, alt_answer].item() - patched_alt)) for patched, patched_alt in ablation_logits]
    diffs_direct = [np.mean((pos_logits[0, ans_position, answer].item() - patched, pos_logits[0, ans_position, alt_answer].item() - patched_alt)) for patched, patched_alt in indirect_logits]
    diffs_indirect = [np.mean((pos_logits[0, ans_position, answer].item() - patched, pos_logits[0, ans_position, alt_answer].item() - patched_alt)) for patched, patched_alt in direct_logits]

    df = pd.DataFrame({
        'Label': other_hook_labels + head_labels,
        'Total Effect': diffs_ablation,
        'Indirect Abaltion Effect': diffs_indirect,
        'Direct Ablation Effect': diffs_direct
    })

    fig = px.line(df, x="Label", y=["Total Effect", "Indirect Abaltion Effect", "Direct Ablation Effect"], title=title)
    fig.update_layout({"yaxis_title": "Original logit - Patched logit"})
    fig.show()

In [None]:
prepend = "Mary was a girl who lived in a village with a large forest. Today, she and her mother were going for a hike in the forest. After a while it was getting dark. "
pos = "Mary said \"I don’t want to go home.\""
neg = "Mary said \"I don’t want to go to.\""
title = "Patching: 'I don't want to go to.' -> 'I don’t want to go home.'" 
patch_plots(prepend + pos, prepend + neg, ans_position=-2, patch_position=-2, title=title)

Pos logits: '."'=18.42, '!"'=18.17
Neg logits: '."'=10.29, '!"'= 8.78


In [None]:
to_tokens = lambda x: print(f"{model.to_str_tokens(x, prepend_bos=False)} ({len(model.to_str_tokens(x, prepend_bos=False))})")
to_tokens("Mary said \"I am hungry.\"")
to_tokens("Mary said that she is hungry.\"")

['Mary', ' said', ' "', 'I', ' am', ' hungry', '."'] (7)
['Mary', ' said', ' that', ' she', ' is', ' hungry', '."'] (7)


In [None]:
prepend = "Tim was a boy who lived with his mum and his dog in a large house. Sometimes, they lose the dog because the house is so big. \"Do you know where the dog is?\", asks his mum. "
pos = "Luckily, Tim has seen the dog playing outside and answers \"The dog is in the garden.\""
neg = "Luckily, Tim has seen the dog playing outside and answers that the dog is in the garden.\""
title = "Patching: '...and answers: \"The dog is in the garden.\"' -> '...and answers that the dog is in the garden.\"'"
patch_plots(prepend + pos, prepend + neg, -2, np.s_[:], title=title)

Pos logits: '."'=18.55, '!"'=19.05
Neg logits: '."'=14.66, '!"'= 13.60


In [None]:
# Predict open quotes by final word
# pos = "Sally stamped her foot and said \""
# neg = "Sally stamped her foot and ran \""

pos = "Mary said \"I don't want to go home.\""
neg = "Mary said that she doesn\'t want to go home.\""
title = "Patching: 'Mary said \"I don't want to go home.\"' -> 'Mary said that she doesn't want to go home.\"'"
patch_plots(prepend + pos, prepend + neg, -2, np.s_[:], title=title)

Pos logits: '."'=18.05, '!"'=16.83
Neg logits: '."'=12.94, '!"'= 10.17


In [None]:
prepend = "Mary was a girl who loved the beach. She went to the beach every day. Today, she was going to the beach with her mother. When it was getting dark, her mother told her it was time to go home. "
pos = "Mary said: \"I don\'t want to go home. I want to stay.\""
neg = "Mary said: \"I don\'t want to go home.\" I want to stay.\""
title= "Patching: '\"I don\'t want to go home.\" I want to stay.\"' -> '\"I don\'t want to go home. I want to stay.\"'"
patch_plots(prepend+pos, prepend+neg, -2, np.s_[:], title=title)

Pos logits: '."'=20.75, '!"'=20.27
Neg logits: '."'=19.01, '!"'= 18.67


In [None]:
from transformer_lens import utils
pos = "Sally said: \"I don\'t want to go home. I want to stay"
neg = "Sally said: \"I don\'t want to go home.\" I want to stay"
print(utils.test_prompt(neg, ".\"", model, prepend_space_to_answer=False, prepend_bos=False))
print(utils.test_prompt(pos, ".\"", model, prepend_space_to_answer=False, prepend_bos=False))

Tokenized prompt: ['S', 'ally', ' said', ':', ' "', 'I', ' don', "'t", ' want', ' to', ' go', ' home', '."', ' I', ' want', ' to', ' stay']
Tokenized answer: ['."']


Top 0th token. Logit: 22.62 Prob: 62.55% Token: | and|
Top 1th token. Logit: 20.23 Prob:  5.70% Token: | with|
Top 2th token. Logit: 19.99 Prob:  4.51% Token: |,|
Top 3th token. Logit: 19.75 Prob:  3.54% Token: |.|
Top 4th token. Logit: 19.70 Prob:  3.37% Token: | in|
Top 5th token. Logit: 19.56 Prob:  2.93% Token: | outside|
Top 6th token. Logit: 19.48 Prob:  2.70% Token: | at|
Top 7th token. Logit: 19.23 Prob:  2.11% Token: | a|
Top 8th token. Logit: 18.94 Prob:  1.57% Token: | here|
Top 9th token. Logit: 18.93 Prob:  1.56% Token: |."|


None
Tokenized prompt: ['S', 'ally', ' said', ':', ' "', 'I', ' don', "'t", ' want', ' to', ' go', ' home', '.', ' I', ' want', ' to', ' stay']
Tokenized answer: ['."']


Top 0th token. Logit: 21.55 Prob: 42.79% Token: | and|
Top 1th token. Logit: 20.12 Prob: 10.28% Token: | here|
Top 2th token. Logit: 19.99 Prob:  9.01% Token: |."|
Top 3th token. Logit: 19.85 Prob:  7.83% Token: |!"|
Top 4th token. Logit: 19.72 Prob:  6.90% Token: | with|
Top 5th token. Logit: 19.22 Prob:  4.20% Token: | in|
Top 6th token. Logit: 19.20 Prob:  4.11% Token: | at|
Top 7th token. Logit: 18.46 Prob:  1.95% Token: | outside|
Top 8th token. Logit: 18.25 Prob:  1.58% Token: |.|
Top 9th token. Logit: 18.23 Prob:  1.56% Token: |,"|


None


## Attention head encoder

In [None]:
prepend = "Mary was a girl who lived in a village with a large forest. Today, she and her mother were going for a hike in the forest. After a while it was getting dark. "
pos = "Mary said: \"I don\'t want to go home. I want to stay here forever.\""
neg = "Mary said: \"I don\'t want to go home. I want to stay here and.\""
prompt = prepend + pos
test_prompt(prompt[:-2], ".\"", model, prepend_space_to_answer=False, prepend_bos=False)
test_prompt(prepend + neg[:-2], ".\"", model, prepend_space_to_answer=False, prepend_bos=False)
patch_plots(prepend + pos, prepend + neg, ans_position=-2, patch_position=np.s_[:], title="Patching logit diff")

Tokenized prompt: ['Mary', ' was', ' a', ' girl', ' who', ' lived', ' in', ' a', ' village', ' with', ' a', ' large', ' forest', '.', ' Today', ',', ' she', ' and', ' her', ' mother', ' were', ' going', ' for', ' a', ' hike', ' in', ' the', ' forest', '.', ' After', ' a', ' while', ' it', ' was', ' getting', ' dark', '.', ' Mary', ' said', ':', ' "', 'I', ' don', "'t", ' want', ' to', ' go', ' home', '.', ' I', ' want', ' to', ' stay', ' here', ' forever']
Tokenized answer: ['."']


Top 0th token. Logit: 23.77 Prob: 58.06% Token: |!"|
Top 1th token. Logit: 23.26 Prob: 34.91% Token: |."|
Top 2th token. Logit: 21.00 Prob:  3.62% Token: |".|
Top 3th token. Logit: 19.55 Prob:  0.85% Token: |!".|
Top 4th token. Logit: 19.44 Prob:  0.76% Token: | and|
Top 5th token. Logit: 18.64 Prob:  0.34% Token: |.|
Top 6th token. Logit: 18.31 Prob:  0.25% Token: | in|
Top 7th token. Logit: 18.22 Prob:  0.23% Token: |,|
Top 8th token. Logit: 17.66 Prob:  0.13% Token: |!|
Top 9th token. Logit: 17.49 Prob:  0.11% Token: |,"|


Tokenized prompt: ['Mary', ' was', ' a', ' girl', ' who', ' lived', ' in', ' a', ' village', ' with', ' a', ' large', ' forest', '.', ' Today', ',', ' she', ' and', ' her', ' mother', ' were', ' going', ' for', ' a', ' hike', ' in', ' the', ' forest', '.', ' After', ' a', ' while', ' it', ' was', ' getting', ' dark', '.', ' Mary', ' said', ':', ' "', 'I', ' don', "'t", ' want', ' to', ' go', ' home', '.', ' I', ' want', ' to', ' stay', ' here', ' and']
Tokenized answer: ['."']


Top 0th token. Logit: 21.50 Prob: 56.61% Token: | play|
Top 1th token. Logit: 19.44 Prob:  7.22% Token: | explore|
Top 2th token. Logit: 18.80 Prob:  3.82% Token: | be|
Top 3th token. Logit: 18.76 Prob:  3.67% Token: | look|
Top 4th token. Logit: 18.51 Prob:  2.84% Token: | watch|
Top 5th token. Logit: 18.09 Prob:  1.88% Token: | see|
Top 6th token. Logit: 18.06 Prob:  1.82% Token: | enjoy|
Top 7th token. Logit: 18.01 Prob:  1.74% Token: | keep|
Top 8th token. Logit: 17.73 Prob:  1.31% Token: | have|
Top 9th token. Logit: 17.70 Prob:  1.27% Token: | help|


NameError: name 'patch_plots' is not defined

In [None]:
# Headds L0H0 L1H0 L1H13import circuitsvis as cv
continuation = " But her mother told her they had to go home."
_, cache = model.run_with_cache(prompt + continuation)
pattern = torch.cat([cache["pattern", 0][:, [0,8]].squeeze(0), cache["pattern", 1][:, [0, 13]].squeeze(0)], dim=0)
print(pattern.shape)

import circuitsvis as cv
display(cv.attention.attention_patterns(
        attention = pattern.cpu(),
        tokens = model.to_str_tokens(prompt + continuation),
        attention_head_names = ["L0H0", "L0H8", "L1H11", "L1H13"],
    ))

torch.Size([4, 68, 68])


In [None]:
encoder, cfg = load_encoder("253_peachy_bee", model_name, model)
cfg.head_idx = 0

{'cfg_file': 'config/tiny-stories-attn.json', 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 'batch_size': 5080, 'buffer_mult': 256, 'seq_len': 127, 'model': 'tiny-stories-2L-33M', 'layer': 1, 'act': 'attn.hook_result', 'expansion_factor': 4, 'seed': 47, 'lr': 0.0001, 'l1_coeff': [9e-06, 1e-05], 'l1_target': None, 'wd': 0.01, 'beta1': 0.9, 'beta2': 0.99, 'num_eval_prompts': 150, 'save_checkpoint_models': False, 'reg': 'combined_hoyer_sqrt', 'finetune_encoder': None, 'dead_direction_frequency': 1e-05, 'head_idx': 0, 'tried': [[0.0001, 0.00015], [5e-05, 7.5e-05], [1.25e-05, 1.875e-05], [9e-06, 1e-05]], 'model_batch_size': 40, 'buffer_size': 1300480, 'buffer_batches': 10240, 'num_eval_batches': 157, 'd_in': 1024, 'wandb_name': 'peachy-bee-253', 'save_name': '253_peachy_bee'}


In [None]:
encoder, cfg = load_encoder("252_ruby_donkey", model_name, model)
cfg.head_idx=13

{'cfg_file': 'config/tiny-stories-attn.json', 'data_path': '/workspace/data/tinystories', 'save_path': '/workspace', 'use_wandb': True, 'num_eval_tokens': 800000, 'num_training_tokens': 500000000.0, 'batch_size': 5080, 'buffer_mult': 256, 'seq_len': 127, 'model': 'tiny-stories-2L-33M', 'layer': 1, 'act': 'attn.hook_result', 'expansion_factor': 4, 'seed': 47, 'lr': 0.0001, 'l1_coeff': [9e-06, 1e-05], 'l1_target': None, 'wd': 0.01, 'beta1': 0.9, 'beta2': 0.99, 'num_eval_prompts': 150, 'save_checkpoint_models': False, 'reg': 'combined_hoyer_sqrt', 'finetune_encoder': None, 'dead_direction_frequency': 1e-05, 'head_idx': 13, 'tried': [[0.0001, 0.00015], [5e-05, 7.5e-05], [1.25e-05, 1.875e-05], [9e-06, 1e-05]], 'model_batch_size': 40, 'buffer_size': 1300480, 'buffer_batches': 10240, 'num_eval_batches': 157, 'd_in': 1024, 'wandb_name': 'ruby-donkey-252', 'save_name': '252_ruby_donkey'}


In [92]:
prompt = prompts[0]
prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
acts = get_acts(prompt, model, encoder, cfg)

In [None]:
run_name = cfg.run_name
max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name, prompts, model)

100%|██████████| 21990/21990 [03:36<00:00, 101.42it/s]


Active directions on validation data: 2763 out of 4096


In [None]:
direction_counter = torch.zeros(encoder.d_hidden).cuda()
total_tokens = 0
for prompt in tqdm(prompts):
    acts = get_acts(prompt, model, encoder, cfg)
    direction_counter += (acts>0).sum(0)
    total_tokens += acts.shape[0]
direction_counter /= total_tokens

dead_directions = (direction_counter < 1e-5).nonzero().flatten().tolist()
active_directions = (direction_counter > 1e-5).nonzero().flatten().tolist()
print(len(dead_directions), len(active_directions))

100%|██████████| 21990/21990 [03:01<00:00, 121.43it/s]

3488





In [None]:
active_directions = (direction_counter > 1e-5).nonzero().flatten().tolist()
print(len(active_directions))

608


In [None]:
print_top_examples(prompts, max_activations, active_directions[0], encoder, cfg, 5)

Prompt: 11208


Prompt: 16420


Prompt: 12177


Prompt: 20335


Prompt: 5457


In [None]:
prompt = prompts[11208]
# Headds L0H0 L1H0 L1H13import circuitsvis as cv
_, cache = model.run_with_cache(prompt)
pattern = cache["pattern", 1][:, [0, 13]].squeeze(0)
print(pattern.shape)

import circuitsvis as cv
display(cv.attention.attention_patterns(
        attention = pattern.cpu(),
        tokens = model.to_str_tokens(prompt + continuation),
        attention_head_names = ["L1H11", "L1H13"],
    ))

torch.Size([2, 293, 293])
