In [2]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

In [40]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

def load_encoder(save_name, model_name):
    with open(f"{model_name}/{save_name}.json", "r") as f:
        cfg = json.load(f)

    cfg = AutoEncoderConfig(
        cfg["layer"], cfg["act"], cfg["expansion_factor"], cfg["l1_coeff"]
    )

    if cfg.act_name == "hook_mlp_out":
        d_in = model.cfg.d_model
    else:
        d_in = model.cfg.d_mlp
    d_hidden = d_in * cfg.expansion_factor

    encoder = AutoEncoder(d_hidden, cfg.l1_coeff, d_in)
    encoder.load_state_dict(torch.load(os.path.join(model_name, save_name + ".pt")))
    encoder.to(device)
    return encoder, cfg


@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts


def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    activations = []
    indices = []
    for prompt in tqdm(prompts):
        acts = get_acts(prompt, model, encoder, cfg)
        value, index = acts.max(0)
        activations.append(value)
        indices.append(index)

    max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc
    max_activation_token_index = torch.stack(indices)

    total_activations = max_activation_per_prompt.sum(0)
    print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
    return max_activation_per_prompt, max_activation_token_index


def get_token_kurtosis_for_decoder(model: HookedTransformer, layer: int, decoder: torch.Tensor):
    '''Return excess kurtosis over all decoder features' cosine sims with the unembed (higher is better)'''
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    
    mean = einops.repeat(cosine_sims.mean(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    std = einops.repeat(cosine_sims.std(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    kurt = torch.mean((cosine_sims - mean / std) ** 4, dim=-1) - 3
    return kurt

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [4]:
# 1. List of clean features
# 2. Sort by indirect ablation increase
# 3. Sort for things in layer 1

In [6]:
l0_encoder, l0_config = load_encoder('18_morning_sun', model_name)
l1_encoder, l1_config = load_encoder('2_upbeat_snowball', model_name)

# l1_kurtosis = get_token_kurtosis_for_decoder(model, 1, l1_encoder.W_dec)
# px.histogram(pd.DataFrame({"kurtosis": l1_kurtosis.cpu()}))   

In [41]:
prompts = load_tinystories_validation_prompts()[:5000]
max_activations, max_activation_token_indies = get_max_activations(prompts, model, l0_encoder, l0_config)

(INFO) 11:21:25: Loaded 21990 TinyStories validation prompts


  0%|          | 0/5000 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [42]:
max_activation_token_indices = max_activation_token_indies
max_activations_l0 = max_activations
max_activation_token_indices_l0 = max_activation_token_indices

In [25]:
cosine_sims = torch.nn.functional.normalize(l0_encoder.W_dec, dim=-1) @ torch.nn.functional.normalize(l1_encoder.W_enc, dim=0)
cosine_sims = torch.tril(cosine_sims)
print(cosine_sims.shape)

all_sims = cosine_sims.flatten().cpu()

torch.Size([16384, 16384])


In [34]:
# px.histogram(all_sims[torch.randperm(len(all_sims))][:10_000])

In [27]:
values, indices = torch.topk(all_sims, 10)

tensor([0.9031, 0.9005, 0.8999, 0.8994, 0.8970, 0.8961, 0.8960, 0.8953, 0.8949,
        0.8942])


In [33]:
def i_to_row_col(i: int, n_cols: int):
    row = i // n_cols
    col = i % n_cols
    return row, col

l0_dir, l1_dir = i_to_row_col(indices[0], len(cosine_sims))
print(l0_dir, l1_dir)

tensor(0.9031)
tensor(15671) tensor(1008)


In [39]:
print(max_activations_l0.shape)

torch.Size([5000, 16384])


In [None]:
# one prompt
# save direction activations
# get index of max direction activation per prompt from pre-existing data
# return loss per token, original and ablated
# index into loss with positions where directions active, calculate loss increase


def custom_forward(
    enc: AutoEncoder, x: Float[Tensor, "batch d_in"], neuron: int, activation: float
):
    x_cent = x - enc.b_dec
    acts = F.relu(x_cent @ enc.W_enc + enc.b_enc)
    acts[:, neuron] = activation
    x_reconstruct = acts @ enc.W_dec + enc.b_dec
    l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
    l1_loss = enc.l1_coeff * (acts.abs().sum())
    loss = l2_loss + l1_loss
    return loss, x_reconstruct, acts, l2_loss, l1_loss


@torch.no_grad()
def evaluate_autoencoder_reconstruction_per_token(autoencoder: AutoEncoder, encoded_hook_name: str, data: list[str], model: HookedTransformer, reconstruction_loss_only: bool = False, show_tqdm=True):
    def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)
    reconstruct_hooks = [(encoded_hook_name, encode_activations_hook)]

    def zero_ablate_hook(value, hook):
        value[:] = 0
        return value
    zero_ablate_hooks = [(encoded_hook_name, zero_ablate_hook)]
    
    original_losses = []
    reconstruct_losses = []
    zero_ablation_losses = []
    for prompt in tqdm(data, disable=(not show_tqdm)):
        with model.hooks(reconstruct_hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        reconstruct_losses.append(reconstruct_loss.item())
        if not reconstruction_loss_only:
            original_loss = model(prompt, return_type="loss")
            with model.hooks(zero_ablate_hooks):
                zero_ablate_loss = model(prompt, return_type="loss")
            original_losses.append(original_loss.item())
            zero_ablation_losses.append(zero_ablate_loss.item())

    if reconstruction_loss_only:
        return np.mean(reconstruct_losses)
    logging.info(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")
    return np.mean(original_losses), np.mean(reconstruct_losses), np.mean(zero_ablation_losses)


dirs = [i_to_row_col(i) for i in indices]
for l0_dir, l1_dir in dirs:
    
    # max_activations_l0[:, l0_dir]
    values, prompt_indices = torch.topk(max_activations_l0[:, l0_dir], k=100)
    
    for i in prompt_indices:
        prompt = prompts[i]
        pos_index = max_activation_token_indices_l0[i, l0_dir]
        with model.hooks():
            loss_per_token = model(prompt, loss_per_token=True)


    # measure ablation loss increase over these prompts for l0 - next token after high activation, return loss per token
    