In [88]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

In [22]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)


Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [33]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts


def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    activations = []
    indices = []
    for prompt in tqdm(prompts):
        acts = get_acts(prompt, model, encoder, cfg)[:-1]
        value, index = acts.max(0)
        activations.append(value)
        indices.append(index)

    max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc
    max_activation_token_index = torch.stack(indices)

    total_activations = max_activation_per_prompt.sum(0)
    print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
    return max_activation_per_prompt, max_activation_token_index


def get_token_kurtosis_for_decoder(model: HookedTransformer, layer: int, decoder: torch.Tensor):
    '''Return excess kurtosis over all decoder features' cosine sims with the unembed (higher is better)'''
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    
    mean = einops.repeat(cosine_sims.mean(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    std = einops.repeat(cosine_sims.std(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    kurt = torch.mean((cosine_sims - mean / std) ** 4, dim=-1) - 3
    return kurt



In [32]:
torch.cuda.empty_cache()

In [3]:
# vocab_counts = get_occuring_tokens(model, load_tinystories_validation_prompts())
# percent_ever_occur = {vocab_counts.sum() / len(vocab_counts) * 100}
print(f'25% of tokens ever occur in the validation set')

25% of tokens ever occur in the validation set


In [4]:
# 1. List of clean features
# 2. Sort by indirect ablation increase
# 3. Sort for things in layer 1

In [99]:
l0_encoder, l0_config = load_encoder('18_morning_sun', model_name, model)
l1_encoder, l1_config = load_encoder('2_upbeat_snowball', model_name, model)
prompts = load_tinystories_validation_prompts()[:10000]
# l1_kurtosis = get_token_kurtosis_for_decoder(model, 1, l1_encoder.W_dec)
# px.histogram(pd.DataFrame({"kurtosis": l1_kurtosis.cpu()}))   

Once upon a time there was an old man who wanted to get to the other side of the river. Unfortunately, the river was too wide for him to cross and he had no boat. He was stuck. As he stood there, looking sadly at the river, he spotted a stubborn swan swimming towards him. The old man was surprised and sad, but he decided to ask the swan for help. 

He shouted to the swan, â€œCan you help me to cross the river?â€
The swan then stopped swimming and replied, â€œI can help, but you will have to split up my feathers to make a raft.â€

The old man agreed and started picking up feathers and piling them up. After a while, he had enough feathers to make a raft. The old man then climbed onto the raft and started paddling across the river. He reached the other side and thanked the swan, who swam away. 

The old man was happy and grateful to the stubborn swan for helping him. He went on his way, never forgetting what the swan had done for him.


## Clean direction ablation

In [55]:
def get_direction_ablation_hook(encoder, direction, hook_pos=None):
    def subtract_direction_hook(value, hook):
        x_cent = value[0, :] - encoder.b_dec
        acts = F.relu(x_cent @ encoder.W_enc[:, direction] + encoder.b_enc[direction])
        x_reconstruct = einops.einsum(acts, encoder.W_dec[direction, :], "pos, d_mlp -> pos d_mlp") + encoder.b_dec
        if hook_pos is not None:
            value[:, hook_pos, :] -= x_reconstruct[hook_pos]
        else:
            value[:, :] -= x_reconstruct
        return value
    return subtract_direction_hook

def evaluate_direction_ablation(prompts: list[str], encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    
    original_losses = []
    ablated_losses = []
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    for prompt in prompts:
        if pos is not None:
            original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            original_loss = model(prompt, return_type="loss")
        
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
            if pos is not None:
                ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
            else:
                ablated_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())
    return np.mean(original_losses), np.mean(ablated_losses)

original_loss, ablated_loss = evaluate_direction_ablation(prompts[:2], l1_encoder, model, 0, l1_config, pos=None)
print(original_loss, ablated_loss)

1.1389805674552917 1.2179880142211914


## Max activating examples

In [34]:
max_activations_l0, max_activation_token_indices_l0 = get_max_activations(prompts[:5000], model, l0_encoder, l0_config)

  0%|          | 0/5000 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [35]:
max_activations_l0.shape

torch.Size([5000, 16384])

## Cosine sims

In [9]:
W_out = model.W_out[0]
W_in = model.W_in[1]
# d_hidden d_mlp d_mlp d_resid
# d_resid d_mlp d_mlp d_hidden
cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec @ W_out, dim=-1) @ torch.nn.functional.normalize(W_in @ l1_encoder.W_enc, dim=0)
cosine_sims = torch.tril(cosine_sims)
print(cosine_sims.shape)

all_sims = cosine_sims.flatten().cpu()

torch.Size([16384, 16384])


In [10]:
px.histogram(all_sims[torch.randperm(len(all_sims))][:10_000])

In [11]:
values, indices = torch.topk(all_sims, 100)
print(values[:10])

In [13]:
def i_to_row_col(i: int, n_cols: int = len(cosine_sims)):
    row = i // n_cols
    col = i % n_cols
    return row, col

l0_dir, l1_dir = i_to_row_col(indices[0], len(cosine_sims))
print(l0_dir, l1_dir)

tensor(14478) tensor(11523)


In [102]:
def get_top_activating_examples_for_direction(prompts, direction, max_activations_per_prompt, max_activation_token_indices, k=10):
    activations = max_activations_per_prompt[:, direction]
    _, prompt_indices = activations.topk(10)
    top_prompts = [prompts[i] for i in prompt_indices]
    token_indices = max_activation_token_indices[prompt_indices, direction]
    return top_prompts, token_indices

top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0)

In [16]:
# ablation increase on next token
# ablation increase over the whole prompt
# Baselines: ablate random active features 

def evaluate_direction_ablation_single_prompt(prompt: str, encoder: AutoEncoder, model: HookedTransformer, direction: int, cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    if pos is not None:
        original_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
    else:
        original_loss = model(prompt, return_type="loss")
    
    with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(encoder, direction, pos))]):
        if pos is not None:
            ablated_loss = model(prompt, return_type="loss", loss_per_token=True)[0, pos]
        else:
            ablated_loss = model(prompt, return_type="loss")
    return original_loss.item(), ablated_loss.item()

In [109]:
# Losses for ablating L0 and L1 directions with high cosine sim (individually)
n = 50
top_cosine_similarities, top_cosine_sim_indices = torch.topk(all_sims, n)
data = []
for top_cosine_index in tqdm(top_cosine_sim_indices):
    l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=25)
    original_losses = []
    l0_losses = []
    l1_losses = []
    for prompt, pos in zip(top_prompts, top_prompt_token_indices.tolist()):
        original_loss, l0_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, l0_dir, l0_config, pos=pos)
        _, l1_ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l1_encoder, model, l1_dir, l1_config, pos=pos)
        original_losses.append(original_loss)
        l0_losses.append(l0_ablated_loss)
        l1_losses.append(l1_ablated_loss)
    data.append([l0_dir.item(), l1_dir.item(), np.mean(original_losses), np.mean(l0_losses), np.mean(l1_losses)])
    #print(f"Direction {l0_dir} -> {l1_dir}: {np.mean(original_losses):.2f}, {np.mean(l0_losses):.2f}, {np.mean(l1_losses):.2f}")
df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "Original loss", "L0 direction ablation loss", "L1 direction ablation loss"])
df["Cosine similarity"] = top_cosine_similarities.tolist()
df

  0%|          | 0/50 [00:00<?, ?it/s]

Direction 14478 -> 11523: 2.61, 4.06, 3.16
Direction 12956 -> 12739: 0.82, 3.70, 0.77
Direction 1793 -> 356: 0.00, 5.54, 0.01
Direction 15326 -> 15231: 0.84, 3.22, 1.56
Direction 1931 -> 1698: 3.77, 9.56, 3.08
Direction 16319 -> 13895: 0.66, 2.20, 0.55
Direction 9990 -> 95: 1.16, 2.19, 1.10
Direction 13998 -> 13826: 0.15, 0.71, 0.20
Direction 12704 -> 10081: 0.29, 3.84, 0.85
Direction 14358 -> 1009: 0.01, 2.78, 0.06
Direction 3619 -> 3117: 1.38, 1.55, 1.27
Direction 7020 -> 5658: 1.09, 1.77, 1.00
Direction 13798 -> 8061: 0.01, 0.95, 0.05
Direction 11663 -> 7023: 1.62, 3.17, 2.11
Direction 8593 -> 7932: 0.23, 1.59, 0.37
Direction 12900 -> 8663: 0.57, 2.30, 0.71
Direction 12148 -> 9988: 1.75, 2.79, 1.88
Direction 12058 -> 10477: 1.80, 4.05, 2.05
Direction 10929 -> 7023: 1.89, 1.81, 1.97
Direction 15355 -> 7066: 0.08, 0.55, 0.12
Direction 12472 -> 9712: 1.25, 1.95, 1.38
Direction 14562 -> 9508: 1.14, 2.13, 1.69
Direction 13074 -> 2587: 1.20, 1.75, 1.36
Direction 10805 -> 8651: 0.37, 1.66,

Unnamed: 0,L0 direction,L1 direction,Original loss,L0 direction ablation loss,L1 direction ablation loss,Cosine similarity
0,14478,11523,2.607063,4.055653,3.164762,0.510517
1,12956,12739,0.820649,3.703449,0.770419,0.482807
2,1793,356,0.000417,5.535827,0.005061,0.47027
3,15326,15231,0.837235,3.221702,1.561258,0.467984
4,1931,1698,3.774686,9.558991,3.083883,0.465342
5,16319,13895,0.657223,2.198068,0.553014,0.458912
6,9990,95,1.162095,2.19175,1.096082,0.454281
7,13998,13826,0.14863,0.712602,0.197033,0.450984
8,12704,10081,0.289116,3.836256,0.848559,0.45006
9,14358,1009,0.007969,2.778449,0.056052,0.448014


In [46]:
# Baseline
loss_increases = []
for prompt in tqdm(prompts[:100]):
    acts = get_acts(prompt, model, l0_encoder, l0_config)[:-1]
    pos = 20
    active_directions = torch.argwhere(acts[pos, :] > 0).flatten()
    direction = random.choice(active_directions)
    original_loss, ablated_loss = evaluate_direction_ablation_single_prompt(prompt, l0_encoder, model, direction, l0_config, pos=pos)
    loss_increases.append(ablated_loss - original_loss)
print(np.mean(loss_increases))
px.histogram(loss_increases)

  0%|          | 0/100 [00:00<?, ?it/s]

0.1483508816806716


In [110]:
# Check l1 activation with and without l0 activation
data = []
for top_cosine_index in tqdm(top_cosine_sim_indices[:n]):
    l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=25)
    # Check l1 activation on the position
    # Ablate l0 direction and check l1 activation
    acts = []
    ablated_acts = []
    for prompt, index in zip(top_prompts, top_prompt_token_indices.tolist()):
        act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
        encoder_hook_point = f"blocks.{l0_config.layer}.{l0_config.act_name}"
        with model.hooks(fwd_hooks=[(encoder_hook_point, get_direction_ablation_hook(l0_encoder, l0_dir, index))]):
            ablated_act = get_acts(prompt, model, l1_encoder, l1_config)[index, l1_dir].item()
        acts.append(act)
        ablated_acts.append(ablated_act)
    data.append([l0_dir.item(), l1_dir.item(), np.mean(acts), np.mean(ablated_acts)])
    #print(f"Ablating L0 direction {l0_dir} -> {l1_dir}: {np.mean(acts):.2f} -> {np.mean(ablated_acts):.2f}")
tmp_df = pd.DataFrame(data, columns=["L0 direction", "L1 direction", "L1 activation", "L1 activation after ablation"])
# Merge with df matching L0 and L1 direction
df = df.merge(tmp_df, on=["L0 direction", "L1 direction"])
df.head()

  0%|          | 0/50 [00:00<?, ?it/s]

Ablating L0 direction 14478 -> 11523: 4.13 -> 1.93
Ablating L0 direction 12956 -> 12739: 5.90 -> 2.50
Ablating L0 direction 1793 -> 356: 7.08 -> 3.25
Ablating L0 direction 15326 -> 15231: 5.91 -> 2.17
Ablating L0 direction 1931 -> 1698: 9.57 -> 1.85
Ablating L0 direction 16319 -> 13895: 4.79 -> 2.37
Ablating L0 direction 9990 -> 95: 5.25 -> 2.46
Ablating L0 direction 13998 -> 13826: 3.82 -> 2.20
Ablating L0 direction 12704 -> 10081: 6.00 -> 2.36
Ablating L0 direction 14358 -> 1009: 6.16 -> 2.08
Ablating L0 direction 3619 -> 3117: 0.00 -> 0.27
Ablating L0 direction 7020 -> 5658: 0.18 -> 0.15
Ablating L0 direction 13798 -> 8061: 5.72 -> 4.00
Ablating L0 direction 11663 -> 7023: 5.21 -> 2.47
Ablating L0 direction 8593 -> 7932: 6.93 -> 3.47
Ablating L0 direction 12900 -> 8663: 6.38 -> 3.13
Ablating L0 direction 12148 -> 9988: 3.74 -> 1.95
Ablating L0 direction 12058 -> 10477: 10.04 -> 2.25
Ablating L0 direction 10929 -> 7023: 4.60 -> 1.97
Ablating L0 direction 15355 -> 7066: 5.69 -> 3.67
A

Unnamed: 0,L0 direction,L1 direction,Original loss,L0 direction ablation loss,L1 direction ablation loss,Cosine similarity,L1 activation,L1 activation after ablation
0,14478,11523,2.607063,4.055653,3.164762,0.510517,4.133888,1.934463
1,12956,12739,0.820649,3.703449,0.770419,0.482807,5.902056,2.501862
2,1793,356,0.000417,5.535827,0.005061,0.47027,7.080084,3.247411
3,15326,15231,0.837235,3.221702,1.561258,0.467984,5.9066,2.174021
4,1931,1698,3.774686,9.558991,3.083883,0.465342,9.568707,1.848159


In [None]:
# Check what other l0 features contribute to l1 activation - maybe we can find set that fully deactivates l1 - feature DLA?
# Look at activating examples manually for interpretability
# Check across positions
# Check if l0 features correspond to single tokens or do something more interesting

In [112]:
df.to_csv("data/cosine_sim_2L_tinystories.csv")

In [104]:
def print_max_activating_examples(l0_dir, l1_dir, prompts, max_activations, max_activation_token_indices, k=3):
    print(f"Direction {l0_dir} -> {l1_dir}")
    top_prompts, top_prompt_token_indices = get_top_activating_examples_for_direction(prompts, l0_dir, max_activations_l0, max_activation_token_indices_l0, k=k)
    for prompt in top_prompts:
        acts = get_acts(prompt, model, l0_encoder, l0_config)[:, l0_dir].tolist()
        l1_acts = get_acts(prompt, model, l1_encoder, l1_config)[:, l1_dir].tolist()
        str_tokens = model.to_str_tokens(prompt)
        tokens = model.to_tokens(prompt)
        haystack_utils. clean_print_strings_as_html(str_tokens, acts, max_value=10, additional_measures=[l1_acts], additional_measure_names = ["L1 act"])

for top_cosine_index in tqdm(top_cosine_sim_indices[:10]):
    l0_dir, l1_dir = i_to_row_col(top_cosine_index, len(cosine_sims))
    print_max_activating_examples(l0_dir, l1_dir, prompts, max_activations, max_activation_token_indices)

  0%|          | 0/10 [00:00<?, ?it/s]

Direction 14478 -> 11523


Direction 12956 -> 12739


Direction 1793 -> 356


Direction 15326 -> 15231


Direction 1931 -> 1698


Direction 16319 -> 13895


Direction 9990 -> 95


Direction 13998 -> 13826


Direction 12704 -> 10081


Direction 14358 -> 1009
