In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder
import utils.haystack_utils as haystack_utils

# from joblib import Memory
# cachedir = '/workspace/cache'
# os.makedirs(cachedir, exist_ok=True)
# memory = Memory(cachedir, verbose=0, bytes_limit=20e9)

%reload_ext autoreload
%autoreload 2

In [2]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/323M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [3]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts


def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    activations = []
    indices = []
    for prompt in tqdm(prompts):
        acts = get_acts(prompt, model, encoder, cfg)[:-1]
        value, index = acts.max(0)
        activations.append(value)
        indices.append(index)

    max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc
    max_activation_token_index = torch.stack(indices)

    total_activations = max_activation_per_prompt.sum(0)
    print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
    return max_activation_per_prompt, max_activation_token_index


def get_token_kurtosis_for_decoder(model: HookedTransformer, layer: int, decoder: torch.Tensor):
    '''Return excess kurtosis over all decoder features' cosine sims with the unembed (higher is better)'''
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    
    mean = einops.repeat(cosine_sims.mean(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    std = einops.repeat(cosine_sims.std(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    kurt = torch.mean(((cosine_sims - mean) / std) ** 4, dim=-1) - 3
    return kurt



In [4]:
# vocab_counts = get_occuring_tokens(model, load_tinystories_validation_prompts())
# percent_ever_occur = {vocab_counts.sum() / len(vocab_counts) * 100}
print(f'25% of tokens ever occur in the validation set')

25% of tokens ever occur in the validation set


In [5]:
# 1. List of clean features
# 2. Sort by indirect ablation increase
# 3. Sort for things in layer 1

In [6]:
l0_encoder, l0_config = load_encoder('18_morning_sun', model_name, model)
l1_encoder, l1_config = load_encoder('2_upbeat_snowball', model_name, model)
prompts = load_tinystories_validation_prompts()
# l1_kurtosis = get_token_kurtosis_for_decoder(model, 1, l1_encoder.W_dec)
# px.histogram(pd.DataFrame({"kurtosis": l1_kurtosis.cpu()}))   

Downloading TinyStories validation prompts


(INFO) 07:54:13: Loaded 21990 TinyStories validation prompts


## Clean direction ablation

In [7]:
def get_direction_ablation_hook(encoder, direction, hook_pos=None):
    def subtract_direction_hook(value, hook):
        
        x_cent = value[0, :] - encoder.b_dec
        acts = F.relu(x_cent @ encoder.W_enc[:, direction] + encoder.b_enc[direction])
        
        direction_impact_on_reconstruction = einops.einsum(acts, encoder.W_dec[direction, :], "pos, d_mlp -> pos d_mlp") # + encoder.b_dec ???
        if hook_pos is not None:
            value[:, hook_pos, :] -= direction_impact_on_reconstruction[hook_pos]
        else:
            value[:, :] -= direction_impact_on_reconstruction
        return value
    return subtract_direction_hook

def evaluate_direction_ablation(prompts: list[str], encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    
    original_losses = []
    ablated_losses = []
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    for prompt in prompts:
        if pos is not None:
            original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            original_loss = model(prompt, return_type="loss")
        
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
            if pos is not None:
                ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
            else:
                ablated_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())
    return np.mean(original_losses), np.mean(ablated_losses)

original_loss, ablated_loss = evaluate_direction_ablation(prompts[0:2], l1_encoder, model, 0, l1_config, pos=None)
print(original_loss, ablated_loss)

1.138980358839035 1.1388814449310303


## Max activating examples

In [8]:
#@memory.cache
def get_activations(encoder, cfg):
    max_activations, max_activation_token_indices = get_max_activations(prompts, model, encoder, cfg)
    return max_activations, max_activation_token_indices

max_activations_l0, max_activation_token_indices_l0 = get_max_activations(prompts, model, l0_encoder, l0_config)

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [9]:
max_activations_l1, max_activation_token_indices_l1 = get_max_activations(prompts, model, l1_encoder, l1_config)

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


## Cosine sims

In [10]:
W_out = model.W_out[0]
W_in = model.W_in[1]
# d_hidden d_mlp d_mlp d_resid
# d_resid d_mlp d_mlp d_hidden
cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec @ W_out, dim=-1) @ torch.nn.functional.normalize(W_in @ l1_encoder.W_enc, dim=0)
cosine_sims = torch.tril(cosine_sims)
print(cosine_sims.shape)

all_sims = cosine_sims.flatten().cpu()

torch.Size([16384, 16384])


In [11]:
px.histogram(all_sims[torch.randperm(len(all_sims))][:10000], title="Randomly sampled cosine similarities between layer 0 and layer 1 features", width=1000)

In [12]:
values, indices = torch.topk(all_sims, 10000)
px.histogram(values.cpu().numpy(), title="Top 10k cosine similarities between layer 0 and layer 1 features", width=1000)

In [13]:
def i_to_row_col(i: int, n_cols: int = len(cosine_sims)):
    row = i // n_cols
    col = i % n_cols
    return row, col

def get_top_activating_examples_for_direction(prompts, direction, max_activations_per_prompt: Tensor, max_activation_token_indices, k=10, mode: Literal["lower", "middle", "upper", "top"]="top"):
    
    sorted_activations = max_activations_per_prompt[:, direction].sort().values
    num_non_zero_activations = sorted_activations[sorted_activations > 0].shape[0]

    max_activation = sorted_activations[-1]
    if mode=="upper":
        index = torch.argwhere(sorted_activations > ((max_activation // 3) * 2)).min()
    elif mode == "middle":
        index = torch.argwhere(sorted_activations > ((max_activation // 3))).min()
    else:
        index = torch.argwhere(sorted_activations > ((max_activation // 10))).min()
    negative_index = sorted_activations.shape[0] - index

    activations = max_activations_per_prompt[:, direction]
    _, prompt_indices = activations.topk(num_non_zero_activations+1)
    if mode=="top":
        prompt_indices = prompt_indices[:k]
    else:
        prompt_indices = prompt_indices[negative_index:negative_index+k]
    prompt_indices = prompt_indices[:num_non_zero_activations]

    top_prompts = [prompts[i] for i in prompt_indices]
    token_indices = max_activation_token_indices[prompt_indices, direction]
    return top_prompts, token_indices

#top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0)

In [14]:
# ablation increase on next token
# ablation increase over the whole prompt
# Baselines: ablate random active features 

def evaluate_direction_ablation_single_prompt(prompt: str, encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    if pos is not None:
        original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
    else:
        original_loss = model(prompt, return_type="loss")
    
    with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
        if pos is not None:
            ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            ablated_loss = model(prompt, return_type="loss")
    return original_loss.item(), ablated_loss.item()

In [15]:
threshold = 0.3
n_high_cosine_sims = all_sims[all_sims > threshold].shape[0]
print(n_high_cosine_sims)

470


In [16]:
n = 470
top_cosine_similarities, top_cosine_sim_indices = torch.topk(all_sims, n)

## Top cosine sim pairs loss increases

In [None]:
# Losses for ablating L0 and L1 directions with high cosine sim (individually)
data = []
for top_cosine_index in tqdm(top_cosine_sim_indices):
    l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=100)
    original_losses = []
    l0_losses = []
    l1_losses = []
    for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
        original_loss, l0_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, l0_dir, l0_config, pos=pos)
        _, l1_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l1_encoder, model, l1_dir, l1_config, pos=pos)
        original_losses.append(original_loss)
        l0_losses.append(l0_ablated_loss)
        l1_losses.append(l1_ablated_loss)
    data.append([l0_dir.item(), l1_dir.item(), np.mean(original_losses), np.mean(l0_losses), np.mean(l1_losses)])
    #print(f"Direction {l0_dir} -> {l1_dir}: {np.mean(original_losses):.2f}, {np.mean(l0_losses):.2f}, {np.mean(l1_losses):.2f}")
df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "Original loss", "L0 direction ablation loss", "L1 direction ablation loss"])
df["Cosine similarity"] = top_cosine_similarities.tolist()
df.head()

In [22]:
df = pd.read_csv("data/cosine_sim_2L_tinystories.csv")

## Check within layer cosine similarity

In [23]:
# Check how similar high cosine sims are to each other (L0 decoder)
#l0_indices = np.unique([i_to_row_col(index)[0] for index in top_cosine_sim_indices])
#l0_cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec[l0_indices], dim=1) @ torch.nn.functional.normalize(l0_encoder.W_dec[l0_indices].T, dim=0)

l0_cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec, dim=1) @ torch.nn.functional.normalize(l0_encoder.W_dec.T, dim=0)

l0_cosine_sims = torch.tril(l0_cosine_sims, diagonal=-1)
top_l0_sims, top_l0_sim_indices = torch.topk(l0_cosine_sims.flatten(), 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 20k cosine similarities between layer 0 decoder features", width=1000)

Number of L0 cosine sims > 0.7: 643


In [24]:
def get_high_similarity_directions(direction: int, encoder: AutoEncoder, use_enc=False, threshold=0.8):
    if use_enc:
        cosine_sims = torch.nn.functional.normalize(encoder.W_enc.T, dim=1) @ torch.nn.functional.normalize(encoder.W_enc, dim=0)
    else:
        cosine_sims = torch.nn.functional.normalize(encoder.W_dec, dim=1) @ torch.nn.functional.normalize(encoder.W_dec.T, dim=0)

    cosine_sims = torch.tril(cosine_sims, diagonal=-1)[direction]
    similar_directions = torch.argwhere(cosine_sims > threshold).flatten()
    return similar_directions.tolist()

get_high_similarity_directions(14471, l0_encoder, use_enc=False)

[5846, 14403]

In [25]:
# Encoder sims
l0_encoder_sims = torch.nn.functional.normalize(l0_encoder.W_enc.T, dim=1) @ torch.nn.functional.normalize(l0_encoder.W_enc, dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)

rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]


sorted_elements = torch.sort(lower_triangle_elements).values
n_elements = len(sorted_elements)
median = sorted_elements[n_elements // 2]
upper_quartile = sorted_elements[(n_elements // 4) * 3]
top_quartile = sorted_elements[(n_elements // 10) * 9]
top_top_quartile = sorted_elements[(n_elements // 20) * 19]
print(f"Median: {median:.2f}, 75th percentile: {upper_quartile:.2f}, 90th percentile: {top_quartile:.2f}, 95th percentile: {top_top_quartile:.2f}")

top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, k=20000)
num_high_l0_sims = lower_triangle_elements[lower_triangle_elements > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 encoder features", width=1000)

Median: 0.10, 75th percentile: 0.13, 90th percentile: 0.18, 95th percentile: 0.24
Number of L0 cosine sims > 0.7: 1857407


In [26]:
# Baseline: check MLP in weights
l0_encoder_sims = torch.nn.functional.normalize(model.W_in[0].T, dim=1) @ torch.nn.functional.normalize(model.W_in[0], dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)
rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]
top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 MLP in features", width=1000)

Number of L0 cosine sims > 0.7: 12


In [27]:
# Baseline: check MLP in weights
l0_encoder_sims = torch.nn.functional.normalize(model.W_out[0], dim=1) @ torch.nn.functional.normalize(model.W_out[0].T, dim=0)
l0_encoder_sims = torch.tril(l0_encoder_sims, diagonal=-1)
rows, cols = l0_encoder_sims.shape
mask = torch.tril(torch.ones(rows, cols, dtype=torch.bool), diagonal=-1)
lower_triangle_elements = l0_encoder_sims[mask]
top_l0_sims, top_l0_sim_indices = torch.topk(lower_triangle_elements, 20000)
num_high_l0_sims = top_l0_sims[top_l0_sims > 0.7].shape[0]
print(f"Number of L0 cosine sims > 0.7: {num_high_l0_sims}")
px.histogram(top_l0_sims.cpu().numpy(), title="Top 10k cosine similarities between layer 0 MLP out features", width=1000)

Number of L0 cosine sims > 0.7: 72


## Baseline loss increase

In [28]:
# Baseline
loss_increases = []
for prompt in tqdm(prompts[:100]):
    acts = get_acts(prompt, model, l0_encoder, l0_config)[:-1]
    pos = 20
    active_directions = torch.argwhere(acts[pos, :] > 0).flatten()
    direction = random.choice(active_directions)
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, direction, l0_config, pos=pos)
    loss_increases.append(ablated_loss - original_loss)
print(np.mean(loss_increases))
px.histogram(loss_increases)

  0%|          | 0/100 [00:00<?, ?it/s]

0.0018270882841898129


In [None]:
# Check l1 activation with and without l0 activation
data = []
for top_cosine_index in tqdm(top_cosine_sim_indices[:n]):
    l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=100)
    # Check l1 activation on the position
    # Ablate l0 direction and check l1 activation
    acts = []
    ablated_acts = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
        encoder_hook_point = f"blocks.{l0_config.layer}.{l0_config.act_name}"
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(l0_encoder, l0_dir, index))]):
            ablated_act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
        acts.append(act)
        ablated_acts.append(ablated_act)
    data.append([l0_dir.item(), l1_dir.item(), np.mean(acts), np.mean(ablated_acts)])
    #print(f"Ablating L0 direction {l0_dir} -> {l1_dir}: {np.mean(acts):.2f} -> {np.mean(ablated_acts):.2f}")
tmp_df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "L1 activation", "L1 activation after ablation"])
# Merge with df matching L0 and L1 direction
df = df.merge(tmp_df, on=["L0 direction", "L1 direction"])
df.head()

In [30]:
df

Unnamed: 0.1,Unnamed: 0,L0 direction,L1 direction,Original loss,L0 direction ablation loss,L1 direction ablation loss,Cosine similarity,L1 activation,L1 activation after ablation
0,0,14478,11523,1.678893,2.459859,2.047634,0.510517,4.571798,2.057087
1,1,12956,12739,0.617141,3.405645,0.648351,0.482806,5.641631,2.152821
2,2,1793,356,0.000336,3.815857,0.001447,0.470270,7.293748,3.115456
3,3,15326,15231,0.847430,2.815483,1.412397,0.467984,6.849681,2.837492
4,4,1931,1698,3.496340,9.480585,2.395638,0.465342,10.039371,2.133101
...,...,...,...,...,...,...,...,...,...
465,465,13244,61,1.186918,1.776977,1.795121,0.300565,7.050161,5.190381
466,466,15330,14383,1.434038,1.741103,1.345110,0.300543,5.730764,1.749411
467,467,11896,8091,0.954345,0.902033,1.079458,0.300487,1.366038,1.193021
468,468,15740,2890,1.298800,1.501370,1.444484,0.300453,3.866854,2.217294


In [None]:
# Want
# Features that don't look at only current token - dla automation
# Features that increase loss when ablated - test on max activating examples
# Features that causally influence later features - filter for cosine
# Features that are interpretable when looking at activating examples

In [None]:
# Check what other l0 features contribute to l1 activation - maybe we can find set that fully deactivates l1 - feature DLA?
# Look at activating examples manually for interpretability
# Check across positions
# Check if l0 features correspond to single tokens or do something more interesting

In [103]:
df.to_csv("data/cosine_sim_2L_tinystories.csv")

## Activating tokens

In [31]:
l0_dir = 12058
l1_dir = 10477

In [96]:
def eval_direction_tokens(direction, max_activations, max_activation_token_indices, prompts, model, encoder, cfg, percentage_threshold = 0.25):
    max_activation_value = max_activations[:, direction].max().item()
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    threshold = max_activation_value * percentage_threshold
    print(f"Threshold: {threshold:.2f}")

    activating_tokens = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        str_tokens = model.to_str_tokens(prompt)
        activations = get_acts(prompt, model, encoder, cfg)
        activation = activations[index, direction].item()
        token = str_tokens[index]
        if activation > threshold:
            activating_tokens.append(token)
        else:
            break
    token_counts = Counter(activating_tokens)
    return token_counts

token_counts = eval_direction_tokens(l0_dir, max_activations_l0, max_activation_token_indices_l0, prompts, model, l0_encoder, l0_config)
print("L0 direction", l0_dir, token_counts.most_common(20))

token_counts = eval_direction_tokens(l1_dir, max_activations_l1, max_activation_token_indices_l1, prompts, model, l1_encoder, l1_config)
print("L1 direction", l1_dir, token_counts.most_common(20))

Threshold: 4.19
L0 direction 12058 [('Who', 188), (' Who', 37), ('who', 3)]
Threshold: 2.87
L1 direction 10477 [('Who', 188), (' Who', 41), (' who', 6), ('who', 3)]


In [98]:
# Check if direction always activates on tokens
def count_direction_activations_on_token(direction, count_tokens: Tensor, prompts, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig, threshold = 0.25):
    num_active = 0
    num_total = 0
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)[0]
        token_positions = torch.argwhere(torch.isin(tokens, count_tokens)).flatten()

        if token_positions.shape[0] > 0:
            activations = get_acts(prompt, model, encoder, cfg)[token_positions, direction]
            num_active += activations[activations > threshold].shape[0]
            num_total += activations.shape[0]
    return num_active, num_total

count_tokens = model.to_tokens(["Who", " Who", "who"], prepend_bos=False).flatten()
count_direction_activations_on_token(l0_dir, count_tokens, prompts, model, l0_encoder, l0_config, threshold=4.19)


  0%|          | 0/21990 [00:00<?, ?it/s]

(249, 257)

## Check direction DLA


In [18]:
# def direction_dla(direction, max_activations, max_activation_token_indices, encoder, encoder_cfg):
#     num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    
#     direction_weight = model.W_in[encoder_cfg.layer] @ encoder.W_enc[:, direction]
#     prompt = top_prompts[0]
#     pos = top_prompt_token_indices[0]
#     _, cache = model.run_with_cache(prompt)

#     activation = get_acts(prompt, model, encoder, encoder_cfg)[pos, direction].item()
#     print(activation)
#     decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer+1, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
#     decomposition = decomposition.squeeze(1)
#     dla = einops.einsum(decomposition, direction_weight, "component d_res, d_res -> component")
#     last_mlp_label = f"{encoder_cfg.layer}_mlp_out"
#     last_mlp_index = labels.index(last_mlp_label)
#     dla = [dla[i].item() for i in range(len(dla)) if i != last_mlp_index]
#     labels = [label for i, label in enumerate(labels) if i != last_mlp_index]
#     return dla, labels

def direction_dla(direction, max_activations, max_activation_token_indices, encoder, encoder_cfg, n=100):
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=num_non_zero_activations, mode="top")
    
    direction_weight = encoder.W_enc[:, direction]
    dlas = []
    for i in range(n):
        prompt = top_prompts[i]
        pos = top_prompt_token_indices[i]
        _, cache = model.run_with_cache(prompt)

        decomposition, labels = cache.get_full_resid_decomposition(encoder_cfg.layer, mlp_input=True, apply_ln=True, return_labels=True, expand_neurons=False, pos_slice=pos)
        decomposition = decomposition.squeeze(1)

        # Account for GELU in DLA by setting neuron contributions to 0 if they are not activated
        mlp_wise_decomposition = einops.einsum(decomposition, model.W_in[encoder_cfg.layer], "component d_res, d_res d_mlp -> component d_mlp")
        mlp_activations = cache[f"blocks.{encoder_cfg.layer}.mlp.hook_post"][0, pos, :]
        zeroed_neurons = torch.argwhere(mlp_activations <= 0).flatten()
        mlp_wise_decomposition[:, zeroed_neurons] = 0

        dla = einops.einsum(mlp_wise_decomposition, direction_weight, "component d_mlp, d_mlp -> component")
        dlas.append(dla)
    dla = torch.stack(dlas).mean(0).tolist()
    return dla, labels

dlas = []
for i in range(5):
    dla, labels = direction_dla(i, max_activations_l0, max_activation_token_indices_l0, l0_encoder, l0_config, n=1)
    embed_index = labels.index("embed")
    embed_dla = dla[embed_index]
    dlas.append(embed_dla)
    if embed_dla < 0:
        print(i, embed_dla)


In [19]:
l0_dir = 144

In [20]:
dla, labels = direction_dla(l0_dir, max_activations_l0, max_activation_token_indices_l0, l0_encoder, l0_config, n=50)

In [21]:
from utils.plotting_utils import line
line(dla, xticks=labels, title=f"DLA for L0 direction {l0_dir}", width=1000)

In [126]:
# Ablation loss
def evaluate_direction_loss_increase(prompts, direction, encoder: AutoEncoder, model: HookedTransformer, encoder_config: AutoEncoderConfig, max_activations, max_activation_token_indices, k=100):
    num_non_zero_activations = max_activations[:, direction].nonzero().shape[0]
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=min(k, num_non_zero_activations), mode="top")
    
    original_losses = []
    ablated_losses = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, encoder, model, direction, encoder_config, pos=index)
        original_losses.append(original_loss)
        ablated_losses.append(ablated_loss)
    return np.mean(original_losses), np.mean(ablated_losses)

evaluate_direction_loss_increase(prompts, l0_dir, l0_encoder, model, l0_config, max_activations_l0, max_activation_token_indices_l0)

(2.735752259194851, 2.72342669531703)

In [132]:
evaluate_direction_ablation_single_prompt("The boy was so happy.", l0_encoder, model, l0_dir, l0_config)

(3.8139283657073975, 3.8016443252563477)

## Boosted tokens

In [122]:
def get_common_tinystories_tokens(prompts, model: HookedTransformer, min_occurrences=100):
    occurrences = torch.zeros(model.cfg.d_vocab, dtype=torch.int32).cuda()
    for prompt in prompts: 
        tokens = model.to_tokens(prompt).flatten()
        occurrences = occurrences.index_add(0, tokens, torch.ones_like(tokens, dtype=torch.int32))
    common_tokens = torch.argwhere(occurrences > min_occurrences).flatten()
    rare_tokens = torch.argwhere(occurrences <= min_occurrences).flatten()
    return occurrences, common_tokens, rare_tokens

def get_direction_boosted_tokens(direction, encoder: AutoEncoder, model: HookedTransformer, cfg: AutoEncoderConfig, rare_tokens: Tensor):
    token_boosts = encoder.W_dec[direction] @ model.W_out[cfg.layer] @ model.unembed.W_U
    token_boosts[rare_tokens] = 0
    return token_boosts

def print_token_boosts(boosts, tokens):
    str_tokens = model.to_str_tokens(tokens)
    boost_str = ""
    for token, boost in zip(str_tokens, boosts.tolist()):
        boost_str += f"('{token}': {boost:.2f}), "
    print(boost_str[:-2])

occurrences, common_tokens, rare_tokens = get_common_tinystories_tokens(prompts, model)
print(occurrences.shape, common_tokens.shape, rare_tokens.shape)

torch.Size([50257]) torch.Size([2552]) torch.Size([47705])


In [123]:
l0_boosts = get_direction_boosted_tokens(l0_dir, l0_encoder, model, l0_config, rare_tokens)
top_boosts, top_tokens = torch.topk(l0_boosts, 25)
print(f"L0 direction {l0_dir}")
print_token_boosts(top_boosts, top_tokens)

# l1_boosts = get_direction_boosted_tokens(l1_dir, l1_encoder, model, l1_config, rare_tokens)
# top_boosts, top_tokens = torch.topk(l1_boosts, 25)
# print(f"L1 direction {l1_dir}")
# print_token_boosts(top_boosts, top_tokens)

L0 direction 144
(' liked': 0.28), (' all': 0.28), (' always': 0.27), (' impressed': 0.27), (' land': 0.27), (' machine': 0.26), (' black': 0.26), (' past': 0.26), (' before': 0.26), (' gray': 0.26), (' Everything': 0.26), (' proud': 0.25), (' forward': 0.25), (' white': 0.25), (' likes': 0.25), (' only': 0.25), (' wish': 0.24), (' amazed': 0.24), (' l': 0.24), (' print': 0.24), (' honey': 0.24), (' drove': 0.24), (' green': 0.23), (' tw': 0.23), (' smart': 0.23)


## Top dataset examples

In [124]:
# def print_max_activating_examples(l0_dir, l1_dir, prompts, max_activations, max_activation_token_indices, k=3, mode="top"):
#     print(f"Direction {l0_dir} -> {l1_dir}")
#     top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations, max_activation_token_indices, k=k, mode=mode)
#     max_activation_value = max_activations[:, l0_dir].max().item()
#     for prompt in top_prompts:
#         acts = get_acts(prompt, model, l0_encoder, l0_config)[:, l0_dir].tolist()
#         l1_acts = get_acts(prompt, model, l1_encoder, l1_config)[:, l1_dir].tolist()
#         str_tokens = model.to_str_tokens(prompt)
#         haystack_utils.clean_print_strings_as_html(str_tokens, acts, max_value=max_activation_value, additional_measures=[l1_acts], additional_measure_names = ["L1 act"])

def print_max_activating_examples(direction, encoder, config, prompts, max_activations, max_activation_token_indices, k=3, mode="top"):
    print(f"Direction {direction}")
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, direction, max_activations, max_activation_token_indices, k=k, mode=mode)
    max_activation_value = max_activations[:, direction].max().item()
    for prompt in top_prompts:
        acts = get_acts(prompt, model, encoder, config)[:, direction].tolist()
        str_tokens = model.to_str_tokens(prompt)
        haystack_utils.clean_print_strings_as_html(str_tokens, acts, max_value=max_activation_value)

print("L0 max activating examples")
print_max_activating_examples(l0_dir, l0_encoder, l0_config, prompts, max_activations_l0, max_activation_token_indices_l0, k=20, mode="top")

# print("L1 max activating examples")
# print_max_activating_examples(l1_dir, l1_encoder, l1_config, prompts, max_activations_l1, max_activation_token_indices_l1, k=5, mode="top")

L0 max activating examples
Direction 144
