In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, generate_with_encoder
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line

# from joblib import Memory
# cachedir = '/workspace/cache'
# os.makedirs(cachedir, exist_ok=True)
# memory = Memory(cachedir, verbose=0, bytes_limit=20e9)

%reload_ext autoreload
%autoreload 2

In [2]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [3]:
@torch.no_grad()
def get_acts(prompt: str | Tensor, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts


def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    activations = []
    indices = []
    for prompt in tqdm(prompts):
        acts = get_acts(prompt, model, encoder, cfg)[:-1]
        value, index = acts.max(0)
        activations.append(value)
        indices.append(index)

    max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc
    max_activation_token_index = torch.stack(indices)

    total_activations = max_activation_per_prompt.sum(0)
    print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
    return max_activation_per_prompt, max_activation_token_index


def get_token_kurtosis_for_decoder(model: HookedTransformer, layer: int, decoder: torch.Tensor):
    '''Return excess kurtosis over all decoder features' cosine sims with the unembed (higher is better)'''
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    
    mean = einops.repeat(cosine_sims.mean(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    std = einops.repeat(cosine_sims.std(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    kurt = torch.mean(((cosine_sims - mean) / std) ** 4, dim=-1) - 3
    return kurt

def get_top_activating_examples_for_direction(prompts, direction, max_activations_per_prompt: Tensor, max_activation_token_indices, k=10, mode: Literal["lower", "middle", "upper", "top"]="top"):
    
    sorted_activations = max_activations_per_prompt[:, direction].sort().values
    num_non_zero_activations = sorted_activations[sorted_activations > 0].shape[0]

    max_activation = sorted_activations[-1]
    if mode=="upper":
        index = torch.argwhere(sorted_activations > ((max_activation // 3) * 2)).min()
    elif mode == "middle":
        index = torch.argwhere(sorted_activations > ((max_activation // 3))).min()
    else:
        index = torch.argwhere(sorted_activations > ((max_activation // 10))).min()
    negative_index = sorted_activations.shape[0] - index

    activations = max_activations_per_prompt[:, direction]
    _, prompt_indices = activations.topk(num_non_zero_activations+1)
    if mode=="top":
        prompt_indices = prompt_indices[:k]
    else:
        prompt_indices = prompt_indices[negative_index:negative_index+k]
    prompt_indices = prompt_indices[:num_non_zero_activations]

    top_prompts = [prompts[i] for i in prompt_indices]
    token_indices = max_activation_token_indices[prompt_indices, direction]
    return top_prompts, token_indices

def get_direction_ablation_hook(encoder, direction, hook_pos=None):
    def subtract_direction_hook(value, hook):
        
        x_cent = value[0, :] - encoder.b_dec
        acts = F.relu(x_cent @ encoder.W_enc[:, direction] + encoder.b_enc[direction])
        direction_impact_on_reconstruction = einops.einsum(acts, encoder.W_dec[direction, :], "pos, d_mlp -> pos d_mlp") # + encoder.b_dec ???
        if hook_pos is not None:
            value[:, hook_pos, :] -= direction_impact_on_reconstruction[hook_pos]
        else:
            value[:, :] -= direction_impact_on_reconstruction
        return value
    return subtract_direction_hook

def evaluate_direction_ablation_single_prompt(prompt: str, encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    if pos is not None:
        original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
    else:
        original_loss = model(prompt, return_type="loss")
    
    with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
        if pos is not None:
            ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            ablated_loss = model(prompt, return_type="loss")
    return original_loss.item(), ablated_loss.item()

In [4]:
# vocab_counts = get_occuring_tokens(model, load_tinystories_validation_prompts())
# percent_ever_occur = {vocab_counts.sum() / len(vocab_counts) * 100}
print(f'25% of tokens ever occur in the validation set')

25% of tokens ever occur in the validation set


In [5]:
# 1. List of clean features
# 2. Sort by indirect ablation increase
# 3. Sort for things in layer 1

In [6]:
l0_encoder, l0_config = load_encoder('18_morning_sun', model_name, model)
l1_encoder, l1_config = load_encoder('8_deep_brook', model_name, model)
prompts = load_tinystories_validation_prompts()
# l1_kurtosis = get_token_kurtosis_for_decoder(model, 1, l1_encoder.W_dec)
# px.histogram(pd.DataFrame({"kurtosis": l1_kurtosis.cpu()}))   

(INFO) 09:56:56: Loaded 21990 TinyStories validation prompts


In [7]:
def get_max_activation_df(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    dfs = []
    direction_index = np.arange(encoder.d_hidden).astype('int32')
    for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        tokens = model.to_tokens(prompt)[0, :-1]
        acts = get_acts(tokens, model, encoder, cfg)
        value, index = acts.max(0)
        max_activating_tokens = tokens[index]
        example_df = pd.DataFrame({"prompt": i, "direction": direction_index, "max_activating_token_index": index.cpu().numpy(), "max_activating_token": max_activating_tokens.cpu().numpy(), "max_activation": value.cpu().numpy()})
        example_df['prompt'] = example_df['prompt'].astype('int32')
        example_df['max_activating_token_index'] = example_df['max_activating_token_index'].astype('int32')
        example_df['max_activating_token'] = example_df['max_activating_token'].astype('int32')
        dfs.append(example_df)
    return pd.concat(dfs)

# ~10GB
#df = get_max_activation_df(prompts, model, l0_encoder, l0_config)

## Clean direction ablation

In [8]:
def evaluate_direction_ablation(prompts: list[str], encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    
    original_losses = []
    ablated_losses = []
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    for prompt in prompts:
        if pos is not None:
            original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            original_loss = model(prompt, return_type="loss")
        
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
            if pos is not None:
                ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
            else:
                ablated_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())
    return np.mean(original_losses), np.mean(ablated_losses)

original_loss, ablated_loss = evaluate_direction_ablation(prompts[0:2], l1_encoder, model, 0, l1_config, pos=None)
print(original_loss, ablated_loss)

1.138980358839035 1.1388558745384216


## Max activating examples

In [9]:
#@memory.cache
def get_activations(encoder, cfg):
    max_activations, max_activation_token_indices = get_max_activations(prompts, model, encoder, cfg)
    return max_activations, max_activation_token_indices

#max_activations_l0, max_activation_token_indices_l0 = get_max_activations(prompts, model, l0_encoder, l0_config)

In [10]:
#max_activations_l1, max_activation_token_indices_l1 = get_max_activations(prompts, model, l1_encoder, l1_config)

In [11]:
# max_activations_2L_data = {
#     "max_activations_l0": max_activations_l0.cpu(),
#     "max_activation_token_indices_l0": max_activation_token_indices_l0.cpu(),
#     "max_activations_l1": max_activations_l1.cpu(),
#     "max_activation_token_indices_l1": max_activation_token_indices_l1.cpu(),
# }

import pickle
# with open("/workspace/max_activations_2L_data.pkl", "wb") as f:
#     pickle.dump(max_activations_2L_data, f)

with open("/workspace/max_activations_2L_data.pkl", "rb") as f:
    max_activations_2L_data = pickle.load(f)
    max_activations_l0 = max_activations_2L_data["max_activations_l0"]
    max_activation_token_indices_l0 = max_activations_2L_data["max_activation_token_indices_l0"]
    max_activations_l1 = max_activations_2L_data["max_activations_l1"]
    max_activation_token_indices_l1 = max_activations_2L_data["max_activation_token_indices_l1"]

## Cosine sims

In [12]:
W_out = model.W_out[0]
W_in = model.W_in[1]
# d_hidden d_mlp d_mlp d_resid
# d_resid d_mlp d_mlp d_hidden
cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec @ W_out, dim=-1) @ torch.nn.functional.normalize(W_in @ l1_encoder.W_enc, dim=0)
cosine_sims = torch.tril(cosine_sims)
print(cosine_sims.shape)

all_sims = cosine_sims.flatten().cpu()

torch.Size([16384, 16384])


In [13]:
px.histogram(all_sims[torch.randperm(len(all_sims))][:10000], title="Randomly sampled cosine similarities between layer 0 and layer 1 features", width=1000)

In [14]:
values, indices = torch.topk(all_sims, 10000)
px.histogram(values.cpu().numpy(), title="Top 10k cosine similarities between layer 0 and layer 1 features", width=1000)

In [15]:
def i_to_row_col(i: int, n_cols: int = len(cosine_sims)):
    row = i // n_cols
    col = i % n_cols
    return row, col

#top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0)

In [16]:
# ablation increase on next token
# ablation increase over the whole prompt
# Baselines: ablate random active features 



In [17]:
threshold = 0.3
n_high_cosine_sims = all_sims[all_sims > threshold].shape[0]
print(n_high_cosine_sims)

591


In [18]:
n = 50#n_high_cosine_sims
top_cosine_similarities, top_cosine_sim_indices = torch.topk(all_sims, n)

## Top cosine sim pairs loss increases

In [19]:
# # Losses for ablating L0 and L1 directions with high cosine sim (individually)
# data = []
# for top_cosine_index in tqdm(top_cosine_sim_indices):
#     l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=100)
#     original_losses = []
#     l0_losses = []
#     l1_losses = []
#     for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
#         original_loss, l0_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, l0_dir, l0_config, pos=pos)
#         _, l1_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l1_encoder, model, l1_dir, l1_config, pos=pos)
#         original_losses.append(original_loss)
#         l0_losses.append(l0_ablated_loss)
#         l1_losses.append(l1_ablated_loss)
#     data.append([l0_dir.item(), l1_dir.item(), np.mean(original_losses), np.mean(l0_losses), np.mean(l1_losses)])
#     #print(f"Direction {l0_dir} -> {l1_dir}: {np.mean(original_losses):.2f}, {np.mean(l0_losses):.2f}, {np.mean(l1_losses):.2f}")
# df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "Original loss", "L0 direction ablation loss", "L1 direction ablation loss"])
# df["Cosine similarity"] = top_cosine_similarities.tolist()

In [20]:
df = pd.read_csv("data/cosine_sim_2L_tinystories.csv")

## Check within layer cosine similarity

In [21]:
# Check how similar high cosine sims are to each other (L0 decoder)
#l0_indices = np.unique([i_to_row_col(index)[0] for index in top_cosine_sim_indices])
#l0_cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec[l0_indices], dim=1) @ torch.nn.functional.normalize(l0_encoder.W_dec[l0_indices].T, dim=0)

l0_cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec, dim=1) @ torch.nn.functional.normalize(l0_encoder.W_dec.T, dim=0)

l0_cosine_sims = torch.tril(l0_cosine_sims, diagonal=-1)
top_l0_sims, top_l0_sim_indices = torch.topk(l0_cosine_sims.flatten(), 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 20k cosine similarities between layer 0 decoder features", width=1000)

Number of L0 cosine sims > 0.7: 643


In [22]:
def get_high_similarity_directions(direction: int, encoder: AutoEncoder, use_enc=False, threshold=0.8):
    if use_enc:
        cosine_sims = torch.nn.functional.normalize(encoder.W_enc.T, dim=1) @ torch.nn.functional.normalize(encoder.W_enc, dim=0)
    else:
        cosine_sims = torch.nn.functional.normalize(encoder.W_dec, dim=1) @ torch.nn.functional.normalize(encoder.W_dec.T, dim=0)

    cosine_sims = torch.tril(cosine_sims, diagonal=-1)[direction]
    similar_directions = torch.argwhere(cosine_sims > threshold).flatten()
    return similar_directions.tolist()

get_high_similarity_directions(14471, l0_encoder, use_enc=False)

[5846, 14403]

In [23]:
# Encoder sims
l0_encoder_sims = torch.nn.functional.normalize(l0_encoder.W_enc.T, dim=1) @ torch.nn.functional.normalize(l0_encoder.W_enc, dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)

rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]


sorted_elements = torch.sort(lower_triangle_elements).values
n_elements = len(sorted_elements)
median = sorted_elements[n_elements // 2]
upper_quartile = sorted_elements[(n_elements // 4) * 3]
top_quartile = sorted_elements[(n_elements // 10) * 9]
top_top_quartile = sorted_elements[(n_elements // 20) * 19]
print(f"Median: {median:.2f}, 75th percentile: {upper_quartile:.2f}, 90th percentile: {top_quartile:.2f}, 95th percentile: {top_top_quartile:.2f}")

top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, k=20000)
num_high_l0_sims = lower_triangle_elements[lower_triangle_elements > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 encoder features", width=1000)

Median: 0.10, 75th percentile: 0.13, 90th percentile: 0.18, 95th percentile: 0.24
Number of L0 cosine sims > 0.7: 1857407


In [24]:
# Baseline: check MLP in weights
l0_encoder_sims = torch.nn.functional.normalize(model.W_in[0].T, dim=1) @ torch.nn.functional.normalize(model.W_in[0], dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)
rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]
top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 MLP in features", width=1000)

Number of L0 cosine sims > 0.7: 12


In [25]:
# Baseline: check MLP in weights
l0_encoder_sims = torch.nn.functional.normalize(model.W_out[0], dim=1) @ torch.nn.functional.normalize(model.W_out[0].T, dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)
rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]
top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 MLP out features", width=1000)

Number of L0 cosine sims > 0.7: 72


## Baseline loss increase

In [26]:
# Baseline
loss_increases = []
for prompt in tqdm(prompts[:100]):
    acts = get_acts(prompt, model, l0_encoder, l0_config)[:-1]
    pos = 20
    active_directions = torch.argwhere(acts[pos, :] > 0).flatten()
    direction = random.choice(active_directions)
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, direction, l0_config, pos=pos)
    loss_increases.append(ablated_loss - original_loss)
print(np.mean(loss_increases))
px.histogram(loss_increases)

  0%|          | 0/100 [00:00<?, ?it/s]

0.008247477585973684


## Causal link L0 -> L1

In [27]:
# Check l1 activation with and without l0 activation
# data = []
# for top_cosine_index in tqdm(top_cosine_sim_indices[:n]):
#     l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=100)
#     # Check l1 activation on the position
#     # Ablate l0 direction and check l1 activation
#     acts = []
#     ablated_acts = []
#     for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
#         act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
#         encoder_hook_point = f"blocks.{l0_config.layer}.{l0_config.act_name}"
#         with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(l0_encoder, l0_dir, index))]):
#             ablated_act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
#         acts.append(act)
#         ablated_acts.append(ablated_act)
#     data.append([l0_dir.item(), l1_dir.item(), np.mean(acts), np.mean(ablated_acts)])
#     #print(f"Ablating L0 direction {l0_dir} -> {l1_dir}: {np.mean(acts):.2f} -> {np.mean(ablated_acts):.2f}")
# tmp_df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "L1 activation", "L1 activation after ablation"])
# # Merge with df matching L0 and L1 direction
# df = df.merge(tmp_df, on=["L0 direction", "L1 direction"])
# df.head()

In [28]:
# Want
# Features that don't look at only current token - dla automation
# Features that increase loss when ablated - test on max activating examples
# Features that causally influence later features - filter for cosine
# Features that are interpretable when looking at activating examples

In [29]:
# Check what other l0 features contribute to l1 activation - maybe we can find set that fully deactivates l1 - feature DLA?
# Look at activating examples manually for interpretability
# Check across positions
# Check if l0 features correspond to single tokens or do something more interesting

In [30]:
df.to_csv("data/cosine_sim_2L_tinystories.csv")

In [31]:
df["L1 ablation loss increase"] = df["L1 direction ablation loss"] - df["Original loss"]
df = df.sort_values("L1 ablation loss increase", ascending=False)
df.head()

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,L0 direction,L1 direction,Original loss,L0 direction ablation loss,L1 direction ablation loss,Cosine similarity,L1 activation,L1 activation after ablation,L1 ablation loss increase
30,30,30,30,7105,880,0.019676,1.894193,1.715879,0.430657,8.816866,1.289062,1.696203
39,39,39,39,10805,8651,0.667334,1.204928,1.995131,0.419653,9.9573,2.479752,1.327797
13,13,13,13,13798,8061,0.030373,0.487021,0.995018,0.448903,7.848553,3.781541,0.964645
12,12,12,12,12704,10081,0.343059,3.065146,1.111156,0.449377,7.530861,1.675186,0.768097
31,31,31,31,3041,576,0.650594,1.707483,1.329974,0.430053,5.63222,1.772597,0.67938


## About to circuit

In [32]:
def print_max_activating_examples(direction, encoder, config, prompts, max_activations, max_activation_token_indices, k=3, mode="top"):
    print(f"Direction {direction}")
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=k, mode=mode)
    max_activation_value = max_activations[:, direction].max().item()
    for prompt in top_prompts:
        acts = get_acts(prompt, model, encoder, config)[:, direction].tolist()
        str_tokens = model.to_str_tokens(prompt)
        haystack_utils.clean_print_strings_as_html(str_tokens, acts, max_value=max_activation_value)

dirs = df["L0 direction"].unique().tolist()
for l0_dir in dirs[:1]:
    print("L0 max activating examples", l0_dir)
    print_max_activating_examples(l0_dir, l0_encoder, l0_config, prompts, max_activations_l0, max_activation_token_indices_l0, k=5, mode="top")

L0 max activating examples 7105
Direction 7105


In [33]:
top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=10, mode="top")
test_prompts = []
for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
    tokens = model.to_str_tokens(prompt)
    assert tokens[index+1] == " to"
    tokens = tokens[:index+1]
    prompt = "".join(tokens)
    test_prompts.append(prompt)

In [34]:
# Test prompt appends answer, "about" pos = -2
ablation_hook = get_direction_ablation_hook(l0_encoder, l0_dir, -2)

for prompt in test_prompts[:1]:
    test_prompt(prompt, " to", model, prepend_space_to_answer=False, prepend_bos=False)
    
    with model.hooks([(f"blocks.{l0_config.layer}.{l0_config.act_name}", ablation_hook)]):
        test_prompt(prompt, " to", model, prepend_space_to_answer=False, prepend_bos=False)


Tokenized prompt: ['<|endoftext|>', 'The', 'odore', ' was', ' building', ' a', ' unique', ' cell', ' in', ' his', ' backyard', '.', ' He', ' wanted', ' it', ' to', ' look', ' special', '.', ' He', ' carefully', ' picked', ' out', ' different', ' pieces', ' and', ' put', ' them', ' together', ' like', ' a', ' puzzle', '.', ' ', '\n', '\n', 'Suddenly', ',', ' he', ' heard', ' a', ' muff', 'led', ' voice', ' coming', ' from', ' inside', ' the', ' cell', '.', ' "', 'Help', ' me', '!', ' Can', ' someone', ' help', ' me', '?"', ' Theodore', ' was', ' shocked', ' and', ' frightened', '.', ' ', '\n', '\n', 'The', 'odore', ' got', ' closer', ' to', ' the', ' cell', ' and', ' said', ',', ' "', 'Who', ' are', ' you', '?"', '\n', '\n', 'The', ' voice', ' replied', ',', ' "', 'I', "'m", ' Daisy', ',', ' help', ' me', '!', ' I', "'m", ' stuck', ' inside', ' here', '."', ' ', '\n', '\n', 'The', 'odore', ' knew', ' he', ' had', ' to', ' help', ' Daisy', '.', ' He', ' used', ' his', ' hand', ' to', ' t

Top 0th token. Logit: 25.08 Prob: 99.94% Token: | to|
Top 1th token. Logit: 16.20 Prob:  0.01% Token: | the|
Top 2th token. Logit: 15.66 Prob:  0.01% Token: |,|
Top 3th token. Logit: 15.00 Prob:  0.00% Token: | done|
Top 4th token. Logit: 14.60 Prob:  0.00% Token: | it|
Top 5th token. Logit: 14.50 Prob:  0.00% Token: | he|
Top 6th token. Logit: 14.29 Prob:  0.00% Token: | how|
Top 7th token. Logit: 14.25 Prob:  0.00% Token: | halfway|
Top 8th token. Logit: 13.92 Prob:  0.00% Token: | his|
Top 9th token. Logit: 13.90 Prob:  0.00% Token: | there|


Tokenized prompt: ['<|endoftext|>', 'The', 'odore', ' was', ' building', ' a', ' unique', ' cell', ' in', ' his', ' backyard', '.', ' He', ' wanted', ' it', ' to', ' look', ' special', '.', ' He', ' carefully', ' picked', ' out', ' different', ' pieces', ' and', ' put', ' them', ' together', ' like', ' a', ' puzzle', '.', ' ', '\n', '\n', 'Suddenly', ',', ' he', ' heard', ' a', ' muff', 'led', ' voice', ' coming', ' from', ' inside', ' the', ' cell', '.', ' "', 'Help', ' me', '!', ' Can', ' someone', ' help', ' me', '?"', ' Theodore', ' was', ' shocked', ' and', ' frightened', '.', ' ', '\n', '\n', 'The', 'odore', ' got', ' closer', ' to', ' the', ' cell', ' and', ' said', ',', ' "', 'Who', ' are', ' you', '?"', '\n', '\n', 'The', ' voice', ' replied', ',', ' "', 'I', "'m", ' Daisy', ',', ' help', ' me', '!', ' I', "'m", ' stuck', ' inside', ' here', '."', ' ', '\n', '\n', 'The', 'odore', ' knew', ' he', ' had', ' to', ' help', ' Daisy', '.', ' He', ' used', ' his', ' hand', ' to', ' t

Top 0th token. Logit: 19.79 Prob: 35.10% Token: | the|
Top 1th token. Logit: 19.50 Prob: 26.17% Token: |,|
Top 2th token. Logit: 19.29 Prob: 21.27% Token: | it|
Top 3th token. Logit: 17.61 Prob:  3.97% Token: | to|
Top 4th token. Logit: 17.42 Prob:  3.29% Token: | he|
Top 5th token. Logit: 17.37 Prob:  3.13% Token: | how|
Top 6th token. Logit: 16.45 Prob:  1.24% Token: | about|
Top 7th token. Logit: 16.41 Prob:  1.19% Token: | his|
Top 8th token. Logit: 15.27 Prob:  0.38% Token: | what|
Top 9th token. Logit: 15.19 Prob:  0.35% Token: | that|


## Activating tokens

In [35]:
def eval_direction_tokens(direction, max_activations, max_activation_token_indices, prompts, model, encoder, cfg, percentage_threshold = 0.25):
    max_activation_value = max_activations[:, direction].max().item()
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    threshold = max_activation_value * percentage_threshold
    print(f"Threshold: {threshold:.2f}")

    activating_tokens = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        str_tokens = model.to_str_tokens(prompt)
        activations = get_acts(prompt, model, encoder, cfg)
        activation = activations[index, direction].item()
        token = str_tokens[index]
        if activation > threshold:
            activating_tokens.append(token)
        else:
            break
    token_counts = Counter(activating_tokens)
    return token_counts, threshold

# Check if direction always activates on tokens
def count_direction_activations_on_token(direction, count_tokens: Tensor, prompts, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig, threshold = 0.25):
    num_active = 0
    num_total = 0
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)[0]
        token_positions = torch.argwhere(torch.isin(tokens, count_tokens)).flatten()

        if token_positions.shape[0] > 0:
            activations = get_acts(prompt, model, encoder, cfg)[token_positions, direction]
            num_active += activations[activations > threshold].shape[0]
            num_total += activations.shape[0]

    return num_active, num_total



In [36]:
# token_counts, threshold = eval_direction_tokens(l0_dir, max_activations_l0, max_activation_token_indices_l0, prompts, model, l0_encoder, l0_config)
# print("L0 direction", l0_dir, token_counts.most_common(20))
# top_token = token_counts.most_common(1)[0][0]

# token_counts = eval_direction_tokens(l1_dir, max_activations_l1, max_activation_token_indices_l1, prompts, model, l1_encoder, l1_config)
# print("L1 direction", l1_dir, token_counts.most_common(20))

# count_tokens = model.to_tokens([top_token], prepend_bos=False).flatten()
# count_direction_activations_on_token(l0_dir, count_tokens, prompts[:5000], model, l0_encoder, l0_config, threshold=threshold)


## Check direction DLA


In [37]:
# def direction_dla(direction, max_activations, max_activation_token_indices, encoder, encoder_cfg):
#     num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    
#     direction_weight = model.W_in[encoder_cfg.layer] @ encoder.W_enc[:, direction]
#     prompt = top_prompts[0]
#     pos = top_prompt_token_indices[0]
#     _, cache = model.run_with_cache(prompt)

#     activation = get_acts(prompt, model, encoder, encoder_cfg)[pos, direction].item()
#     print(activation)
#     decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer+1, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
#     decomposition = decomposition.squeeze(1)
#     dla = einops.einsum(decomposition, direction_weight, "component d_res, d_res -> component")
#     last_mlp_label = f"{encoder_cfg.layer}_mlp_out"
#     last_mlp_index = labels.index(last_mlp_label)
#     dla = [dla[i].item() for i in range(len(dla)) if i != last_mlp_index]
#     labels = [label for i, label in enumerate(labels) if i != last_mlp_index]
#     return dla, labels

def direction_dla(direction, max_activations, max_activation_token_indices, encoder, encoder_cfg, n=100):
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    
    direction_weight = encoder.W_enc[:, direction]
    dlas = []
    for i in range(n):
        prompt = top_prompts[i]
        pos = top_prompt_token_indices[i]
        _, cache = model.run_with_cache(prompt)

        decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer, mlp_input=True, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
        decomposition = decomposition.squeeze(1)

        # Account for GELU in DLA by setting neuron contributions to 0 if they are not activated
        mlp_wise_decomposition = einops.einsum(decomposition, model.W_in[encoder_cfg.layer], "component d_res, d_res d_mlp -> component d_mlp")
        mlp_activations = cache[f"blocks.{encoder_cfg.layer}.mlp.hook_post"][0, pos, :]
        zeroed_neurons = torch.argwhere(mlp_activations <= 0).flatten()
        mlp_wise_decomposition[:, zeroed_neurons] = 0

        dla = einops.einsum(mlp_wise_decomposition, direction_weight, "component d_mlp, d_mlp -> component")
        dlas.append(dla)
    dla = torch.stack(dlas).mean(0).tolist()
    return dla, labels

In [38]:
dla, labels = direction_dla(l0_dir, max_activations_l0, max_activation_token_indices_l0, l0_encoder, l0_config, n=50)
line(dla, xticks=labels, title=f"DLA for L0 direction {l0_dir}", width=1000)

## Boosted tokens

In [39]:
def get_common_tinystories_tokens(prompts, model: HookedTransformer, min_occurrences=100):
    occurrences = torch.zeros(model.cfg.d_vocab, dtype=torch.int32).cuda()
    for prompt in prompts: 
        tokens = model.to_tokens(prompt).flatten()
        occurrences = occurrences.index_add(0, tokens, torch.ones_like(tokens, dtype=torch.int32))
    common_tokens = torch.argwhere(occurrences > min_occurrences).flatten()
    rare_tokens = torch.argwhere(occurrences <= min_occurrences).flatten()
    return occurrences, common_tokens, rare_tokens

def get_direction_boosted_tokens(direction, encoder: AutoEncoder, model: HookedTransformer, cfg: AutoEncoderConfig, rare_tokens: Tensor):
    token_boosts = encoder.W_dec[direction] @ model.W_out[cfg.layer] @ model.unembed.W_U
    token_boosts[rare_tokens] = 0
    return token_boosts

def print_token_boosts(boosts, tokens):
    str_tokens = model.to_str_tokens(tokens)
    boost_str = ""
    for token, boost in zip(str_tokens, boosts.tolist()):
        boost_str += f"('{token}': {boost:.2f}), "
    print(boost_str[:-2])

occurrences, common_tokens, rare_tokens = get_common_tinystories_tokens(prompts, model)
#print(occurrences.shape, common_tokens.shape, rare_tokens.shape)

In [40]:
l0_boosts = get_direction_boosted_tokens(l0_dir, l0_encoder, model, l0_config, rare_tokens)
top_boosts, top_tokens = torch.topk(l0_boosts, 25)
print(f"L0 direction {l0_dir}")
print_token_boosts(top_boosts, top_tokens)

# l1_boosts = get_direction_boosted_tokens(l1_dir, l1_encoder, model, l1_config, rare_tokens)
# top_boosts, top_tokens = torch.topk(l1_boosts, 25)
# print(f"L1 direction {l1_dir}")
# print_token_boosts(top_boosts, top_tokens)

L0 direction 7105
(' bugs': 0.20), (' dogs': 0.20), (' frogs': 0.20), (' birds': 0.19), (' bears': 0.19), (' half': 0.18), (' finding': 0.18), (' three': 0.17), (' butterflies': 0.17), (' hungry': 0.17), (' ducks': 0.17), ('erry': 0.16), (' bees': 0.16), ('es': 0.16), ('op': 0.16), (' 3': 0.16), (' wearing': 0.16), (' four': 0.15), (' boys': 0.15), ('e': 0.14), (' belonged': 0.14), ('ite': 0.14), (' stumbled': 0.13), ('ate': 0.13), ('eer': 0.13)


## Top dataset examples

In [41]:
# def print_max_activating_examples(l0_dir, l1_dir, prompts, max_activations, max_activation_token_indices, k=3, mode="top"):
#     print(f"Direction {l0_dir} -> {l1_dir}")
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations, max_activation_token_indices, k=k, mode=mode)
#     max_activation_value = max_activations[:, l0_dir].max().item()
#     for prompt in top_prompts:
#         acts = get_acts(prompt, model, l0_encoder, l0_config)[:, l0_dir].tolist()
#         l1_acts = get_acts(prompt, model, l1_encoder, l1_config)[:, l1_dir].tolist()
#         str_tokens = model.to_str_tokens(prompt)
#         haystack_utils.clean_print_strings_as_html(str_tokens, acts, max_value=max_activation_value, additional_measures=[l1_acts], additional_measure_names = ["L1 act"])

def print_max_activating_examples(direction, encoder, config, prompts, max_activations, max_activation_token_indices, k=3, mode="top"):
    print(f"Direction {direction}")
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=k, mode=mode)
    max_activation_value = max_activations[:, direction].max().item()
    for prompt in top_prompts:
        acts = get_acts(prompt, model, encoder, config)[:, direction].tolist()
        str_tokens = model.to_str_tokens(prompt)
        haystack_utils.clean_print_strings_as_html(str_tokens, acts, max_value=max_activation_value)

# print("L0 max activating examples")
# print_max_activating_examples(l0_dir, l0_encoder, l0_config, prompts, max_activations_l0, max_activation_token_indices_l0, k=20, mode="top")

# print("L1 max activating examples")
# print_max_activating_examples(l1_dir, l1_encoder, l1_config, prompts, max_activations_l1, max_activation_token_indices_l1, k=5, mode="top")

## Inspect random features

In [42]:
from ipywidgets import interact, IntSlider

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

def print_direction_example(direction, n=10):
    print_top_examples(prompts, max_activations_l0, direction, l0_encoder, l0_config, n)

# Max activations
interact(print_direction_example, 
         direction=IntSlider(min=0, max=l0_encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))

interactive(children=(IntSlider(value=0, description='direction', max=16383), IntSlider(value=5, description='…

<function __main__.print_direction_example(direction, n=10)>

## About to circuit

In [43]:
l0_dir = 7105
l1_dir = 880

In [44]:
print("L0 max activating examples", l0_dir)
print_max_activating_examples(l0_dir, l0_encoder, l0_config, prompts, max_activations_l0, max_activation_token_indices_l0, k=5, mode="top")

print("L1 max activating examples", l1_dir)
print_max_activating_examples(l1_dir, l1_encoder, l1_config, prompts, max_activations_l1, max_activation_token_indices_l1, k=5, mode="top")

L0 max activating examples 7105
Direction 7105


L1 max activating examples 880
Direction 880


In [45]:
token_counts, threshold = eval_direction_tokens(l0_dir, max_activations_l0, max_activation_token_indices_l0, prompts, model, l0_encoder, l0_config)
print("L0 direction", l0_dir, token_counts.most_common(20))

token_counts, threshold = eval_direction_tokens(l1_dir, max_activations_l1, max_activation_token_indices_l1, prompts, model, l1_encoder, l1_config)
print("L1 direction", l1_dir, token_counts.most_common(20))

Threshold: 3.59
L0 direction 7105 [(' about', 328)]
Threshold: 2.50
L1 direction 880 [(' about', 319)]


In [46]:
answer_token = model.to_single_token(" to")

l0_boosts = get_direction_boosted_tokens(l0_dir, l0_encoder, model, l0_config, rare_tokens)
top_boosts, top_tokens = torch.topk(l0_boosts, 25)
print(f"L0 direction {l0_dir}")
print(f"L0 answer token boost: {l0_boosts[answer_token].item():.2f}")
print_token_boosts(top_boosts, top_tokens)

l1_boosts = get_direction_boosted_tokens(l1_dir, l1_encoder, model, l1_config, rare_tokens)
top_boosts, top_tokens = torch.topk(l1_boosts, 25)
print(f"L1 direction {l1_dir}")
print(f"L1 answer token boost: {l1_boosts[answer_token].item():.2f}")
print_token_boosts(top_boosts, top_tokens)

L0 direction 7105
L0 answer token boost: 0.06
(' bugs': 0.20), (' dogs': 0.20), (' frogs': 0.20), (' birds': 0.19), (' bears': 0.19), (' half': 0.18), (' finding': 0.18), (' three': 0.17), (' butterflies': 0.17), (' hungry': 0.17), (' ducks': 0.17), ('erry': 0.16), (' bees': 0.16), ('es': 0.16), ('op': 0.16), (' 3': 0.16), (' wearing': 0.16), (' four': 0.15), (' boys': 0.15), ('e': 0.14), (' belonged': 0.14), ('ite': 0.14), (' stumbled': 0.13), ('ate': 0.13), ('eer': 0.13)
L1 direction 880
L1 answer token boost: 0.90
(' to': 0.90), (' chance': 0.34), (' turn': 0.31), (' 3': 0.30), (' giving': 0.29), (' back': 0.27), (' finish': 0.27), (' not': 0.27), (' outside': 0.27), (' showing': 0.26), (' stairs': 0.26), (' four': 0.26), (' somewhere': 0.25), (' close': 0.25), (' fill': 0.25), (' solve': 0.25), (' give': 0.25), (' three': 0.24), (' pay': 0.24), (' getting': 0.24), (' for': 0.24), (' across': 0.23), (' reach': 0.23), (' add': 0.23), (' get': 0.23)


In [47]:
# Mean activation on top 100 examples
def get_mean_activation(direction: int, max_activations: Float[Tensor, "n_examples d_enc"], n=100):
    top_activations, _ = torch.topk(max_activations[:, direction], n)
    return top_activations.mean().item()

mean_l0_activation = get_mean_activation(l0_dir, max_activations_l0)
mean_l1_activation = get_mean_activation(l1_dir, max_activations_l1)
print(f"L0 direction {l0_dir}: {mean_l0_activation:.2f}, L1 direction {l1_dir}: {mean_l1_activation:.2f}")

L0 direction 7105: 12.97, L1 direction 880: 9.02


In [48]:
l0_direction_mlp_impact = l0_encoder.W_dec[l0_dir] * mean_l0_activation
print(l0_direction_mlp_impact.shape)

torch.Size([4096])


In [49]:
# MLP difference
def get_zero_ablate_encoder_direction_hook(
    encoder: AutoEncoder, encoder_neuron, pos, cfg: AutoEncoderConfig
):
    def zero_feature_hook(value, hook):
        _, x_reconstruct, _, _, _ = custom_forward(
            encoder, value[:, pos], encoder_neuron, 0
        )
        value[:, pos] = x_reconstruct
        return value
    return [(cfg.encoder_hook_point, zero_feature_hook)]

def get_encode_direction_hook(
    encoder: AutoEncoder, pos, cfg: AutoEncoderConfig
):
    def encoder_hook(value, hook):
        _, x_reconstruct, _, _, _ = encoder(value[:, pos])
        value[:, pos] = x_reconstruct
        return value
    return [(cfg.encoder_hook_point, encoder_hook)]

def eval_direction_mlp_impact(direction, max_activations, max_activation_indices, model: HookedTransformer, encoder: AutoEncoder, encoder_cfg: AutoEncoderConfig, n=100):
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_indices, k=n, mode="top")
    hook_name = f"blocks.{encoder_cfg.layer}.{encoder_cfg.act_name}"
    mlp_acts = []
    zero_ablated_mlp_acts = []
    encoded_mlp_acts = []
    direction_mlp_impacts = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        # MLP activation
        _, cache = model.run_with_cache(prompt)
        mlp_activation = cache[hook_name][0, index]
        mlp_acts.append(mlp_activation)

        # Direction impact when running model through autoencoder
        ablate_hook = get_zero_ablate_encoder_direction_hook(encoder, direction, index, encoder_cfg)
        with model.hooks(ablate_hook):
            _, cache = model.run_with_cache(prompt)
            zero_ablated_mlp_activation = cache[hook_name][0, index]
            zero_ablated_mlp_acts.append(zero_ablated_mlp_activation)

        # Overall impact from running model through autoencoder
        encode_hook = get_encode_direction_hook(encoder, index, encoder_cfg)
        with model.hooks(encode_hook):
            _, cache = model.run_with_cache(prompt)
            encoded_mlp_activation = cache[hook_name][0, index]
            encoded_mlp_acts.append(encoded_mlp_activation)

        # Manual direction impact on activation
        x_cent = mlp_activation - encoder.b_dec
        direction_act = F.relu(x_cent @ encoder.W_enc[:, direction] + encoder.b_enc[direction])
        direction_impact_on_reconstruction = direction_act * encoder.W_dec[direction, :]
        direction_mlp_impacts.append(direction_impact_on_reconstruction)
    zero_ablated_mlp_acts = torch.stack(zero_ablated_mlp_acts)
    mlp_acts = torch.stack(mlp_acts)
    encoded_mlp_acts = torch.stack(encoded_mlp_acts)
    direction_mlp_impacts = torch.stack(direction_mlp_impacts)
    return mlp_acts, zero_ablated_mlp_acts, encoded_mlp_acts, direction_mlp_impacts

mlp_acts, zero_ablated_mlp_acts, encoded_mlp_acts, direction_mlp_impacts = eval_direction_mlp_impact(l0_dir, max_activations_l0, max_activation_token_indices_l0, model, l0_encoder, l0_config)

In [50]:
px.histogram(l0_encoder.b_dec.cpu().numpy())

In [51]:
mean_mlp_acts = mlp_acts.mean(0).cpu().numpy()
mean_direction_mlp_impact = direction_mlp_impacts.mean(0).cpu().numpy()
mean_acts_after_ablation = (mlp_acts - direction_mlp_impacts).mean(0).cpu().numpy()
mean_zero_ablated_mlp_acts = zero_ablated_mlp_acts.mean(0).cpu().numpy()
mean_encoded_mlp_acts = encoded_mlp_acts.mean(0).cpu().numpy()
centered_mlp_acts = mean_mlp_acts - l0_encoder.b_dec.cpu().numpy()
direction_input_weight = ((mean_mlp_acts - l0_encoder.b_dec.cpu().numpy()) * l0_encoder.W_enc[:, l0_dir].cpu().numpy())

activation_df = pd.DataFrame({
    'mlp_neuron': np.arange(mean_mlp_acts.shape[0]),
    'mean_mlp_acts': mean_mlp_acts,
    'centered_mlp_acts': centered_mlp_acts,
    "neuron_wise_direction_contribution": direction_input_weight,
    'mean_direction_mlp_impact': mean_direction_mlp_impact,
    'mean_acts_after_ablation': mean_acts_after_ablation,
    'mean_zero_ablated_mlp_acts': mean_zero_ablated_mlp_acts,
    'mean_encoded_mlp_acts': mean_encoded_mlp_acts,
    'mean_mlp_activation_difference': mean_mlp_acts - mean_zero_ablated_mlp_acts,
})

activation_df = activation_df.sort_values("neuron_wise_direction_contribution", ascending=False)
activation_df.head(10)

Unnamed: 0,mlp_neuron,mean_mlp_acts,centered_mlp_acts,neuron_wise_direction_contribution,mean_direction_mlp_impact,mean_acts_after_ablation,mean_zero_ablated_mlp_acts,mean_encoded_mlp_acts,mean_mlp_activation_difference
2084,2084,4.313724,4.360909,0.720992,4.011521,0.302203,0.258023,4.269543,4.055701
888,888,2.571172,2.605393,0.434503,2.496989,0.074183,0.127724,2.624712,2.443448
1341,1341,2.303901,2.341377,0.3439,2.116715,0.187186,0.105707,2.222422,2.198194
234,234,2.805055,2.873334,0.336763,2.280952,0.524103,0.44799,2.728942,2.357065
2801,2801,2.087464,2.134264,0.335977,1.957796,0.129669,0.104301,2.062097,1.983164
3920,3920,2.629687,2.689751,0.296707,2.541306,0.08838,0.111754,2.653061,2.517932
3018,3018,2.948689,3.013937,0.271539,2.776868,0.171821,0.101635,2.878503,2.847055
2992,2992,2.199173,2.237825,0.263453,1.871986,0.327187,0.197159,2.069145,2.002013
3944,3944,1.545581,1.616104,0.262671,1.26818,0.277401,0.155293,1.423473,1.390289
1641,1641,1.876351,1.9527,0.260707,1.444853,0.431498,0.330259,1.775112,1.546091


In [52]:
tmp = activation_df[(activation_df["neuron_wise_direction_contribution"] <= 0) & ((activation_df["centered_mlp_acts"] <= 0))]
print(len(tmp), tmp["neuron_wise_direction_contribution"].sum())

946 -1.2077988


In [53]:
tmp = activation_df[(activation_df["neuron_wise_direction_contribution"] >= 0) & ((activation_df["centered_mlp_acts"] <= 0))]
print(len(tmp), tmp["neuron_wise_direction_contribution"].sum())
l0_encoder.b_enc[l0_dir].item()

1529 2.5981746


-0.16937312483787537

In [54]:
# Not centered
# Maybe some neurons that are not active and have negative weight make the direction slightly more active
# Mainly they would be responsible for not making the direction fire on other prompts when they are active

In [82]:
# Compute baseline for direction activations on random prompts
baseline_mlp_acts = []
for prompt in prompts[:50]:
    if " was about to" not in prompt:
        hook_name = f"blocks.{l0_config.layer}.{l0_config.act_name}"
        _, cache = model.run_with_cache(prompt)
        mlp_activation = cache[hook_name][0]
        baseline_mlp_acts.append(mlp_activation)
baseline_mlp_acts = torch.cat(baseline_mlp_acts, dim=0)
print(baseline_mlp_acts.shape)
direction_input_weight = ((baseline_mlp_acts.mean(0) - l0_encoder.b_dec) * l0_encoder.W_enc[:, l0_dir]).cpu().numpy()
top_negative_indices = np.argsort(direction_input_weight)[:10]
print(top_negative_indices, direction_input_weight[top_negative_indices])
px.histogram(direction_input_weight)

torch.Size([8312, 4096])
[3000  207 3641 1201 3691 2237 1327 2930 3130  162] [-0.04910274 -0.04493179 -0.04400193 -0.04153743 -0.03823366 -0.02972312
 -0.02775414 -0.02123882 -0.01731402 -0.01712427]


In [56]:
px.histogram(activation_df["neuron_wise_direction_contribution"])

In [80]:
# Check subset of neurons fit 
# Calculate direction activation based on subset of neurons
# Check if direction fires on "was about"
# Check if directio doesn't fire otherwise

def get_direction_activation(direction, prompt, encoder: AutoEncoder, model: HookedTransformer, pos: int = -1):
    _, cache = model.run_with_cache(prompt)
    if pos is not None:
        x_cent = cache[f"blocks.{l0_config.layer}.{l0_config.act_name}"][0, pos] - encoder.b_dec
    else:
        x_cent = cache[f"blocks.{l0_config.layer}.{l0_config.act_name}"][0, :] - encoder.b_dec
    direction_act = F.relu(x_cent @ encoder.W_enc[:, direction] + encoder.b_enc[direction])
    return direction_act

def get_direction_activation_from_neuron_subset(neurons, direction, prompt, encoder: AutoEncoder, model: HookedTransformer, pos: int = -1):
    _, cache = model.run_with_cache(prompt)
    if pos is not None:
        x_cent = cache[f"blocks.{l0_config.layer}.{l0_config.act_name}"][0, pos] - encoder.b_dec
        direction_act = F.relu(x_cent[neurons] @ encoder.W_enc[neurons, direction] + encoder.b_enc[direction])
    else:
        x_cent = cache[f"blocks.{l0_config.layer}.{l0_config.act_name}"][0, :] - encoder.b_dec
        direction_act = F.relu(x_cent[:, neurons] @ encoder.W_enc[neurons, direction] + encoder.b_enc[direction])
    return direction_act

top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=50, mode="top")

full_mlp_direction_acts = []
neuron_subset_direction_acts = []
for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
    neuron_subset_direction_act = get_direction_activation_from_neuron_subset(relevant_l0_neurons, l0_dir, prompt, l0_encoder, model, pos=pos)
    full_mlp_direction_act = get_direction_activation(l0_dir, prompt, l0_encoder, model, pos=pos)
    neuron_subset_direction_acts.append(neuron_subset_direction_act.item())
    full_mlp_direction_acts.append(full_mlp_direction_act.item())

print(np.min(neuron_subset_direction_acts), np.min(full_mlp_direction_acts))

7.968398571014404 12.88914966583252


In [107]:
# Check on "about" without "was" what deactivates direction
mlp_acts = []
hook_name = f"blocks.{l0_config.layer}.{l0_config.act_name}"
for prompt in prompts[:1000]:
    if ("was about" not in prompt) and ("about" in prompt):
        about_pos = model.to_str_tokens(prompt).index(" about")
        _, cache = model.run_with_cache(prompt)
        mlp_activation = cache[hook_name][0, about_pos]
        mlp_acts.append(mlp_activation)

mean_mlp_acts = torch.stack(mlp_acts).mean(0).cpu()
direction_input_weight = ((mean_mlp_acts - l0_encoder.b_dec.cpu().numpy()) * l0_encoder.W_enc[:, l0_dir].cpu().numpy())

non_activation_df = pd.DataFrame({
    'mlp_neuron': np.arange(mean_mlp_acts.shape[0]),
    'mean_mlp_acts': mean_mlp_acts,
    "neuron_wise_direction_contribution": direction_input_weight,
})

non_activation_df = non_activation_df.sort_values("neuron_wise_direction_contribution", ascending=True)
print(non_activation_df.head(10))
top_negative_neurons = non_activation_df["mlp_neuron"].tolist()[:40]
print(top_negative_neurons)

      mlp_neuron  mean_mlp_acts  neuron_wise_direction_contribution
153          153       1.109615                           -0.137479
607          607       0.734240                           -0.060292
2997        2997       0.682363                           -0.058325
1848        1848       1.257419                           -0.051779
3553        3553       0.674238                           -0.049763
3000        3000      -0.004596                           -0.048372
2378        2378       0.583403                           -0.046815
3691        3691      -0.058508                           -0.043182
3641        3641      -0.001971                           -0.043008
2913        2913       0.796823                           -0.042306
[153, 607, 2997, 1848, 3553, 3000, 2378, 3691, 3641, 2913, 2405, 2702, 1906, 1201, 207, 1159, 1327, 711, 2237, 609, 2825, 3509, 1457, 965, 276, 1573, 3346, 3130, 162, 263, 1676, 2930, 1905, 1653, 1725, 3227, 663, 2529, 3906, 806]


In [170]:
relevant_positive_l0_neurons = activation_df[(activation_df["neuron_wise_direction_contribution"] > 0.2)]["mlp_neuron"].tolist()
relevant_l0_neurons = top_negative_neurons + relevant_positive_l0_neurons
print(len(relevant_l0_neurons), len(relevant_positive_l0_neurons), len(top_negative_neurons))

39 14 25


In [172]:
# Evaluate which neurons cause the direction to not activate when it incorrectly fires using the subset
full_mlp_direction_acts = []
neuron_subset_direction_acts = []
mlp_acts = []
os.environ["CUDA_LAUNCH_BLOCKING"]="1"
for i, prompt in enumerate(prompts[:1000]):
    #if "was about" not in prompt:
    if ("was about" not in prompt) and ("about" in prompt):
        # Check direction activation using subset of neurons
        neuron_subset_direction_act = get_direction_activation_from_neuron_subset(relevant_positive_l0_neurons, l0_dir, prompt, l0_encoder, model, pos=None)
        # Compare to direction activation using full set of neurons
        full_mlp_direction_act = get_direction_activation(l0_dir, prompt, l0_encoder, model, pos=None)
        
        incorrect_activating_positions = torch.argwhere((neuron_subset_direction_act > 0.5) & (full_mlp_direction_act < 0.2)).flatten()
        _, cache = model.run_with_cache(prompt)
        for position in incorrect_activating_positions.tolist():
            str_tokens = model.to_str_tokens(prompt)
            print(f"Prompt {i}:{position}, activation {neuron_subset_direction_act[position]:.2f}", str_tokens[position-5:position+1])
            mlp_activation = cache[hook_name][0, position]
            mlp_acts.append(mlp_activation)

        neuron_subset_direction_acts.append(neuron_subset_direction_act)
        full_mlp_direction_acts.append(full_mlp_direction_act)

print(torch.cat(neuron_subset_direction_acts).max(), torch.cat(full_mlp_direction_acts).max())
mlp_acts = torch.stack(mlp_acts)
print(mlp_acts.shape)

Prompt 1:148, activation 1.29 ['\n', 'R', 'oxy', ' told', ' Billy', ' about']
Prompt 10:65, activation 0.52 ['\n', 'Tim', ' and', ' Sue', ' thought', ' about']
Prompt 14:142, activation 0.76 [' with', ' you', '."', ' Max', ' thought', ' about']
Prompt 16:56, activation 0.67 ['.', ' She', ' was', ' very', ' curious', ' about']
Prompt 37:105, activation 0.50 ['\n', '\n', 'T', 'oot', ' thought', ' about']
Prompt 45:116, activation 1.09 [' and', ' told', ' all', ' their', ' friends', ' about']
Prompt 45:126, activation 0.63 [' met', '.', ' They', ' also', ' talked', ' about']
Prompt 48:142, activation 0.92 [' hard', ' that', ' he', ' forgot', ' all', ' about']
Prompt 49:35, activation 1.13 [',', ' she', ' told', ' her', ' mom', ' about']
Prompt 49:63, activation 0.50 ['"', 'What', ' did', ' you', ' dream', ' about']
Prompt 49:188, activation 0.77 [' with', ' your', ' toys', ' and', ' forget', ' about']
Prompt 56:41, activation 0.53 [' was', ' very', ' tall', ' and', ' made', ' of']
Prompt 

In [173]:
mean_mlp_acts = mlp_acts.mean(0).cpu()
direction_input_weight = ((mean_mlp_acts - l0_encoder.b_dec.cpu().numpy()) * l0_encoder.W_enc[:, l0_dir].cpu().numpy())

non_activation_df = pd.DataFrame({
    'mlp_neuron': np.arange(mean_mlp_acts.shape[0]),
    'mean_mlp_acts': mean_mlp_acts,
    "neuron_wise_direction_contribution": direction_input_weight,
})

non_activation_df = non_activation_df.sort_values("neuron_wise_direction_contribution", ascending=True)
print(non_activation_df.head(10))
num_negative_neurons = (non_activation_df["neuron_wise_direction_contribution"]<-0.02).sum()
print(len(non_activation_df), num_negative_neurons)
top_negative_neurons = non_activation_df["mlp_neuron"].tolist()[:num_negative_neurons]
print(top_negative_neurons)

      mlp_neuron  mean_mlp_acts  neuron_wise_direction_contribution
153          153       0.845790                           -0.105665
3000        3000      -0.004830                           -0.048411
3641        3641      -0.002681                           -0.043124
3691        3691      -0.054497                           -0.042555
2997        2997       0.479857                           -0.042491
607          607       0.495997                           -0.041947
207          207      -0.028422                           -0.039112
1201        1201      -0.083387                           -0.038899
1848        1848       0.906287                           -0.037239
2378        2378       0.421047                           -0.034917
4096 25
[153, 3000, 3641, 3691, 2997, 607, 207, 1201, 1848, 2378, 3553, 1327, 2702, 2237, 2913, 3509, 609, 2405, 1159, 1906, 711, 2825, 1457, 3130, 2930]


In [174]:
relevant_positive_l0_neurons = activation_df[(activation_df["neuron_wise_direction_contribution"] > 0.2)]["mlp_neuron"].tolist()
relevant_l0_neurons = top_negative_neurons + relevant_positive_l0_neurons
print(len(relevant_l0_neurons), len(relevant_positive_l0_neurons), len(top_negative_neurons))

39 14 25


In [175]:
full_mlp_direction_acts = []
neuron_subset_direction_acts = []
mlp_acts = []
os.environ["CUDA_LAUNCH_BLOCKING"]="1"
for i, prompt in enumerate(prompts[:1000]):
    #if "was about" not in prompt:
    if ("was about" not in prompt) and ("about" in prompt):
        neuron_subset_direction_act = get_direction_activation_from_neuron_subset(relevant_l0_neurons, l0_dir, prompt, l0_encoder, model, pos=None)
        full_mlp_direction_act = get_direction_activation(l0_dir, prompt, l0_encoder, model, pos=None)
        neuron_subset_direction_acts.append(neuron_subset_direction_act)
        full_mlp_direction_acts.append(full_mlp_direction_act)

print(torch.cat(neuron_subset_direction_acts).mean(), torch.cat(full_mlp_direction_acts).mean())


tensor(0.0016, device='cuda:0') tensor(0.0012, device='cuda:0')


In [177]:
subset_activation = torch.cat(neuron_subset_direction_acts).cpu().numpy()
full_activation = torch.cat(full_mlp_direction_acts).cpu().numpy()
print(subset_activation.shape, full_activation.shape)
# Creating a DataFrame
df = pd.DataFrame({
    'Subset Activation': subset_activation,
    'Full Activation': full_activation
})

df_melted = df.melt(var_name='Activation Type', value_name='Values')

fig = px.histogram(df_melted, x='Values', color='Activation Type', barmode='group', nbins=100)
fig.update_layout({
    # set yaxis range to max = 10
    'yaxis': {'range': [0, 100]},
    "width": 1200,
    "title": "Distribution of direction activations using negative examples"
})
fig.show()

(27755,) (27755,)


In [179]:
# Check top activating examples with new subset
top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=100, mode="top")

full_mlp_direction_acts = []
neuron_subset_direction_acts = []
for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
    neuron_subset_direction_act = get_direction_activation_from_neuron_subset(relevant_l0_neurons, l0_dir, prompt, l0_encoder, model, pos=pos)
    full_mlp_direction_act = get_direction_activation(l0_dir, prompt, l0_encoder, model, pos=pos)
    neuron_subset_direction_acts.append(neuron_subset_direction_act.item())
    full_mlp_direction_acts.append(full_mlp_direction_act.item())

print(np.min(neuron_subset_direction_acts), np.min(full_mlp_direction_acts))

subset_activation = np.array(neuron_subset_direction_acts)
full_activation = np.array(full_mlp_direction_acts)
print(subset_activation.shape, full_activation.shape)
# Creating a DataFrame
df = pd.DataFrame({
    'Subset Activation': subset_activation,
    'Full Activation': full_activation
})

df_melted = df.melt(var_name='Activation Type', value_name='Values')

fig = px.histogram(df_melted, x='Values', color='Activation Type', barmode='group', nbins=100)
fig.update_layout({
    # set yaxis range to max = 10
    'yaxis': {'range': [0, 100]},
    "width": 1200,
    "title": "Distribution of direction activation on 'was about to' examples"
})
fig.show()

3.5799736976623535 12.371442794799805
(100,) (100,)


In [125]:
# Get max activating tokens on subset

In [None]:
# Results
# Neurons activating direction are in superposition


# Nice to have
# pin down 2 sets of neurons - was neurons and about neurons

In [70]:
px.histogram(l0_encoder.W_enc[:, l0_dir].cpu().numpy() + l0_encoder.b_enc[l0_dir].item(), title="L0 direction input weights", width=1000)


In [67]:
df_melted = activation_df.melt(var_name='Variable', value_name='Values', value_vars=['mean_mlp_acts', 'mean_direction_mlp_impact', 'mean_acts_after_ablation', 'mean_zero_ablated_mlp_acts', 'mean_encoded_mlp_acts'])
fig = px.histogram(df_melted, x='Values', color='Variable', barmode='overlay')
fig.show()