In [2]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from functools import lru_cache
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')


sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, generate_with_encoder
from utils.haystack_utils import get_occurring_tokens
from sparse_coding.model_qual_eval import get_max_activations

%reload_ext autoreload
%autoreload 2

In [3]:
# Run overview
model_name = "tiny-stories-33M"
layer_name = "L1"
print_model_name = f"{model_name}-{layer_name}"

# df = pd.read_csv(f"{model_name}/wandb_runs.csv")
# df = df.sort_values(by="l1_coeff", ascending=True)
# df.columns

# save_names = [f.split(".")[0] for f in os.listdir(model_name) if f.endswith('.pt')]
save_names = ['72_confused_silence']

In [4]:
# Wandb data
# fig = px.line(df, x="l1_coeff", y=["l2_loss", "l1_loss", "avg_directions"], markers=True, title=f"{print_model_name}: L1 loss, L2 loss, and average number of active directions")
# fig.update_layout(
#     xaxis_title="L1 coefficient",
#     yaxis_title="",
#     legend_title="",
#     width = 800,
#     xaxis={'tickformat':'.1e'}
# )
# fig.update_xaxes(type='linear')
# fig.show()

In [5]:
prompts = load_tinystories_validation_prompts()

(INFO) 05:24:42: Loaded 21990 TinyStories validation prompts


In [6]:
model = HookedTransformer.from_pretrained(
        model_name,
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        device=device,
    )

Loaded pretrained model tiny-stories-33M into HookedTransformer


In [7]:
loss_data = []
for save_name in tqdm(save_names):
    encoder, cfg = load_encoder(save_name, model_name, model)
    original_loss, encoder_loss, zero_ablation_loss = evaluate_autoencoder_reconstruction(encoder, cfg.encoder_hook_point, prompts[:200], model)
    loss_data.append([cfg.l1_coeff, original_loss, encoder_loss, zero_ablation_loss])
loss_df = pd.DataFrame(loss_data, columns=["L1 coefficient", "Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
loss_df = loss_df.sort_values(by="L1 coefficient", ascending=True)
loss_df["L1 coefficient"] = loss_df["L1 coefficient"].astype(str)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:07<00:00, 27.56it/s]
(INFO) 05:24:55: Average loss increase after encoding: 0.0652


In [8]:
loss_df = loss_df.melt(id_vars=["L1 coefficient"], var_name="Loss Type", value_name="Loss", value_vars=["Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
fig = px.line(loss_df, x="L1 coefficient", y="Loss", color="Loss Type", markers=True,  title=f"{print_model_name}: Encoder reconstruction loss increase")
fig.update_layout(
    xaxis_title="L1 coefficient",
    yaxis_title="Loss increase",
    width = 800,
    xaxis={'tickformat':'.1e'}
)
fig.update_xaxes(type='linear')
fig.show()

In [9]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts

# def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
#     activations = []
#     for prompt in tqdm(prompts):
#         acts = get_acts(prompt, model, encoder, cfg)
#         max_prompt_activation = acts.max(0)[0]
#         activations.append(max_prompt_activation)

#     max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc

#     total_activations = max_activation_per_prompt.sum(0)
#     print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
#     return max_activation_per_prompt

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [10]:
# l1 = 2e-4
# def get_encoder_by_l1(encoders, l1_coeff):
#     for encoder, cfg in encoders:
#         if cfg.l1_coeff == l1_coeff:
#             return encoder, cfg
#     raise ValueError(f"Encoder with L1 coefficient {l1_coeff} not found")
# encoder, cfg = get_encoder_by_l1(encoders, l1)
encoder, cfg = load_encoder(save_names[0], model_name, model)
print(f"Encoder L1 coefficient: {cfg.l1_coeff}")

Encoder L1 coefficient: 0.0001


In [11]:
# total_tokens = 0
# for prompt in prompts: 
#     tokens = model.to_tokens(prompt)
#     total_tokens += torch.numel(tokens)
# print(total_tokens)

# 4_765_918

In [13]:
# feature_frequencies = get_encoder_feature_frequencies(prompts, model, encoder, cfg)
# zero_activating_features = (feature_frequencies == 0).sum(0).item()
# low_density = ((feature_frequencies > 0) & (feature_frequencies < 1e-6)).sum(0).item()
# high_density = (feature_frequencies > 1e-6).sum(0).item()
# print(zero_activating_features, low_density, high_density)
# fig = px.histogram(feature_frequencies.cpu().numpy(), histnorm='probability', title=f"{print_model_name} L1={cfg.l1_coeff}: Histogram of feature frequencies", nbins=40)
# fig.update_yaxes(type='log')
# fig.update_layout(xaxis_title="Feature frequency", yaxis_title="Probability", showlegend=False, width=600)

No low frequency features
Should we have them?
If yes: 
- L1 too low
- Expansion too low

In [14]:
max_activation_per_prompt = get_max_activations(tuple(prompts), model, encoder, cfg)

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 24576 out of 24576



Persisting input arguments took 3.19s to run.If this happens often in your code, it can cause performance problems (results will be correct in all cases). The reason for this is probably some large input arguments for a wrapped function.



In [15]:
def plot_direction_frequency(data: list[str], direction: int, cfg: AutoEncoderConfig):
    activations = []
    for prompt in tqdm(data):
        tokens = model.to_tokens(prompt)
        _, cache = model.run_with_cache(
            tokens, names_filter=f"blocks.{cfg.layer}.{cfg.act_name}"
            )
        acts = cache[f"blocks.{cfg.layer}.{cfg.act_name}"].squeeze(0)
        _, _, mid_acts, _, _ = encoder(acts)
        activations.append(mid_acts[:, direction])
    activations = torch.cat(activations)
    print(activations.shape)

    fig = px.histogram(activations.tolist(), 
                       title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}", 
                       histnorm="probability")
    fig.update_layout(
        xaxis_title="Activation",
        yaxis_title="Probability",
        width = 600,
        showlegend=False
    )
    fig.update_yaxes(type='log')
    fig.show()

direction = 0
plot_direction_frequency(prompts[:50], direction, cfg)

  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


In [16]:
def print_direction_example(direction, n=10):
    print_top_examples(prompts, max_activation_per_prompt, direction, n)

# Max activations
interact(print_direction_example, 
         direction=IntSlider(min=0, max=encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))

# 528: "The end." and preceeding text
# 549 activates on "they hadn't" and relevant tokens before hand
# 556 tracks objects
# 561 pronouns
# Most other directions max activate on "butterfly" and "out" - the ultra low density cluster perhaps?



interactive(children=(IntSlider(value=0, description='direction', max=24575), IntSlider(value=5, description='…

<function __main__.print_direction_example(direction, n=10)>

In [17]:
# cosine sims of butterfly dirs for 18_kind_sound
# butterfly_dirs = [597, 598, 599, 600, 601, 602]
# for direction in butterfly_dirs:
#     print(F.normalize(encoder.W_dec[597], dim=0) @ F.normalize(encoder.W_dec[direction], dim=0))



In [18]:
# Activations of different directions on the same token
prompt = "One day, a little girl named Lily went for a walk in the park"
acts = get_acts(prompt, model, encoder, cfg)[-1] # d_enc
print(f"Active directions on last token: {acts.nonzero().shape[0]} out of {acts.shape[0]}")
active_directions = acts.nonzero().squeeze(1)
highly_active_directions = torch.argwhere((acts > 0.5)).squeeze(1)
low_active_directions = torch.argwhere((acts < 0.5) & (acts > 0.1)).squeeze(1)
px.histogram(acts.cpu().numpy(), title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for prompt", histnorm="probability", nbins=40)

Active directions on last token: 204 out of 24576


In [19]:
# for active_direction in low_active_directions[:5]:#active_directions[:10]:
#     print(f"Direction {active_direction}")
#     print_top_examples(prompts, max_activation_per_prompt, active_direction, 2)

In [20]:
# Refresh on normalize dimension
# F.normalize(torch.rand(3, 5), p=1, dim=0)

In [21]:
def get_token_kurtosis_for_feature(model: HookedTransformer, decoder_feature: torch.Tensor, layer: int):
    resid_dirs = F.normalize(decoder_feature @ model.W_out[layer], dim=-1)
    unembed = F.normalize(model.unembed.W_U, dim=0)
    sims = einops.einsum(resid_dirs, unembed, 'd_model, d_model d_vocab -> d_vocab')

    # filter out tokens that don't occur in the validation set (ideally should be over the training set)
    vocab_occurs = get_occurring_tokens(model, tuple(load_tinystories_validation_prompts()))
    sims = sims[vocab_occurs == 1]

    mean = torch.mean(sims)
    variance = torch.mean((sims - mean) ** 2)
    std = torch.sqrt(variance)
    excess_kurt = torch.mean(((sims - mean) / std) ** 4) - 3
    return excess_kurt


def top_boosted_tokens(model: HookedTransformer, decoder_feature: torch.Tensor, layer: int, k=10, plot=False):
    token_boosts = F.normalize(decoder_feature @ model.W_out[layer], dim=0) @ model.unembed.W_U

    all_ignore, _ = haystack_utils.get_weird_tokens(model, plot_norms=False)
    values, tokens = haystack_utils.top_k_with_exclude(token_boosts, k, exclude=all_ignore)
    boosted_labels = model.to_str_tokens(tokens)

    if plot:
        assert k < 300, "Too many tokens to plot"
        fig = haystack_utils.line(x=values.cpu().numpy(), xticks=boosted_labels, title=f"Boosted tokens", width=1200)
        fig.show()

    return boosted_labels


def get_token_kurtosis_for_decoder(
        model: HookedTransformer, 
        decoder: Float[Tensor, 'd_hidden'], 
        layer: int, 
        vocab_mask=None):
    '''Excess kurtosis over cosine sims of decoder features and unembed items (higher is better)'''
    resid_dirs: Float[Tensor, 'd_hidden d_model'] = F.normalize(decoder @ model.W_out[layer], dim=-1)
    unembed: Float[Tensor, 'd_model d_vocab'] = F.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    if vocab_mask is not None:
        cosine_sims = cosine_sims[:, vocab_mask == 1]
    
    mean = cosine_sims.mean(dim=-1).unsqueeze(1)
    std = cosine_sims.std(dim=-1).unsqueeze(1) + 1e-9
    kurt = torch.mean(((cosine_sims - mean) / std) ** 4, dim=-1) - 3
    
    
    return kurt

vocab_occurs = get_occurring_tokens(model, tuple(load_tinystories_validation_prompts()))
scores = get_token_kurtosis_for_decoder(model, encoder.W_dec, cfg.layer, vocab_mask=None)

In [22]:
# Viz feature-token kurtosis
# px.histogram(scores.cpu())

In [23]:
# Top kurtosis feature-token clusters
values, indices = torch.topk(scores, k=100)
for value, i in zip(values, indices):
    print(f'{value.item():2f}')
    print(top_boosted_tokens(model, encoder.W_dec[i], cfg.layer))

13.847544
[' Lily', ' Tim', ' Emily', ' Sarah', ' Lucy', ' Billy', ' Sara', ' Max', ' Emma', ' Tommy']
12.035836
[' Bella', ' Lisa', ' Anna', ' Kate', ' Mia', ' Daisy', ' Emma', ' Annie', ' Beth', ' Sally']
11.910196
['tight', ' boasting', ' unintended', ' smack', ' thrust', ' settles', 'toe', ' deepening', ' liquids', 'eating']
11.213412
[' thrust', 'life', ' richness', ' successes', 'feel', ' liquids', ' upl', 'collect', ' enthusi', ' powered']
10.907064
[' Alex', ' Jake', ' Joe', ' Tony', ' Max', ' Josh', ' Jacob', ' Harry', ' James', ' Bobby']
10.675691
['rupt', ' intact', ' Building', ' arrivals', ' unintended', 'eat', ' practices', 'oke', 'sing', 'usted']
10.449274
['onsense', ' Direct', ' retire', ' robes', 'Words', ' overpowered', ' wills', ' thumbs', 'Enc', ' enc']
10.227289
[' hardest', 'mares', ' stroke', ' Rewards', ' races', ' 17', 'ted', 'Cast', ' county', 'pping']
10.010903
['Rh', 'Ti', ' tribe', ' eruption', 'Graham', ' eve', 'keleton', ' Jets', ' Ka', 'rises']
9.633762

In [24]:
print(f'25% of tokens ever occur in the validation set')
# vocab_counts = get_occuring_tokens(model, load_tinystories_validation_prompts())
# percent_ever_occur = {vocab_counts.sum() / len(vocab_counts) * 100}

25% of tokens ever occur in the validation set


In [25]:
# Viz max activating examples for most active directions in sample prompt
# for active_direction in highly_active_directions:#active_directions[:10]:
#     print(f"Direction {active_direction}")
#     print_top_examples(prompts, max_activation_per_prompt, active_direction, 4)

In [26]:
# Get prompts ending in sample token
token = model.to_single_token(" park")
token_prompts = []
for prompt in prompts[:1000]:
    tokens = model.to_tokens(prompt)
    if token in tokens:
        token_prompts.append(prompt)
print(len(token_prompts))


163


In [27]:
# # Direction frequency histograms
# for direction in active_directions[:10]:
#     plot_direction_frequency(prompts[:50], direction, cfg)

In [28]:
# # Max activating examples for most active direction in each sample prompt
# for prompt in token_prompts[:2]:
#     for direction in active_directions[:10]:
#         prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
#         acts = get_acts(prompt, model, encoder, cfg)
#         direction_act = acts[:, direction].cpu().tolist()
#         max_direction_act = max(direction_act)
#         if max_direction_act > 0:
#             haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [29]:
# direction = 9000
# fig = px.histogram(max_activation_per_prompt[:, direction].tolist(), title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}", histnorm="probability")
# fig.update_layout(
#     xaxis_title="Activation",
#     yaxis_title="Probability",
#     width = 800,
#     showlegend=False
# )
# fig.update_yaxes(type='log')
# fig.show()
# print_top_examples(prompts, max_activation_per_prompt, direction)

In [30]:
# Look for active features on specific tokens in prompt
# Baseline: look at neurons
# Train with bigger  L1s
# At some point, it should become non monosemantic as it can just copy MLP
# Train without L1, see what happens

In [31]:
# Generate text with and without encoder
for prompt in ["Harry went to the park", "Once upon a time", "In the beginning"]:
    print("Original:", model.generate(prompt, 40, verbose=False, temperature=0, use_past_kv_cache=False))
    print("With autoencoder:", generate_with_encoder(model, encoder, cfg, prompt, 40))

Original: Harry went to the park with his mom. He saw a big slide and wanted to go down it. His mom said, "Be careful, Harry. Don't go too fast." Harry said, "I won't,
With autoencoder: Harry went to the park with his mom. He saw a big slide and wanted to go on it. His mom said, "Let's go on the slide together." Harry was happy and said, "Yay!"

Original: Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, scary dog. The dog barked and growled at her. Lily
With autoencoder: Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red ball in the grass. She picked it up and started to play
Original: In the beginning of a long journey, the sky was dark and the clouds were thick. Suddenly, a tornado appeared! It was spinning and swirling around the town. Everyone was scared and ran away.

The
With autoencoder: In the beginning of the day, Timmy and his mom w

In [32]:
# Mean loss per token
def encode_activations_hook(value, hook):
    value = value.squeeze(0)
    _, x_reconstruct, _, _, _ = encoder(value)
    return x_reconstruct.unsqueeze(0)
encoded_hook_name = f'blocks.{cfg.layer}.{cfg.act_name}'
reconstruct_hooks = [(encoded_hook_name, encode_activations_hook)]

running_loss_diff_sum = torch.zeros(model.cfg.d_vocab, device=device)
running_encoder_loss_sum = torch.zeros(model.cfg.d_vocab, device=device)
n = torch.zeros(model.cfg.d_vocab, device=device)

for prompt in prompts[:200]:
    tokens = model.to_tokens(prompt)
    original_loss = model(tokens, return_type='loss', loss_per_token=True)
    with model.hooks(reconstruct_hooks):
        encoder_loss = model(tokens, return_type='loss', loss_per_token=True)
    # BOS does not have loss
    running_loss_diff_sum[tokens[:, 1:]] += encoder_loss - original_loss
    running_encoder_loss_sum[tokens[:, 1:]] += encoder_loss
    n[tokens] += 1

absolute_encoded_token_losses = running_encoder_loss_sum / n
mean_token_loss_increases = running_loss_diff_sum / n

# Filter out non-occuring tokens
occurring_tokens = haystack_utils.get_occurring_tokens(model, tuple(prompts))

In [33]:
px.histogram(mean_token_loss_increases.cpu()[occurring_tokens == 1], title=f"Distribution of mean token loss increases with autoencoder - {model_name} L{cfg.layer}, {save_names[0]}")

In [34]:
px.histogram(absolute_encoded_token_losses.cpu()[occurring_tokens == 1], title=f"Distribution of mean token losses with autoencoder - {model_name} L{cfg.layer}, {save_names[0]}")

In [35]:
haystack_utils.clean_cache()
# Mean activation difference per neuron
def encode_activations_hook(value, hook):
    value = value.squeeze(0)
    _, x_reconstruct, _, _, _ = encoder(value)
    hook.ctx['running_sum'] = value - x_reconstruct
    hook.ctx['active_neurons_mask'] = value > 0
reconstruct_hook_label = f'blocks.{cfg.layer}.{cfg.act_name}'
reconstruct_hooks = [(reconstruct_hook_label, encode_activations_hook)]

all_neurons_running_sum = torch.zeros(model.cfg.d_mlp, device=device)
active_neurons_running_sum = torch.zeros(model.cfg.d_mlp, device=device)

all_neurons_n = 0
active_neurons_n = torch.zeros(model.cfg.d_mlp, device=device)

for prompt in prompts[:200]:
    with model.hooks(reconstruct_hooks):
        model(prompt)
        all_neuron_act_diffs = model.hook_dict[reconstruct_hook_label].ctx["running_sum"]
        all_neurons_running_sum += all_neuron_act_diffs.sum(0)
        all_neurons_n += all_neuron_act_diffs.shape[0]

        active_neurons = model.hook_dict[reconstruct_hook_label].ctx["active_neurons_mask"]
        all_neuron_act_diffs[all_neuron_act_diffs < 0] = 0
        active_neurons_running_sum += all_neuron_act_diffs.sum(0)
        active_neurons_n += active_neurons.sum(0)

mean_neuron_diffs = all_neurons_running_sum / all_neurons_n
mean_active_neuron_diffs = active_neurons_running_sum / active_neurons_n

In [36]:
print(active_neurons_n.shape)
print(active_neurons_n[:100])
print(active_neurons_running_sum[:100])

torch.Size([3072])
tensor([ 6650.,  8802.,  9500.,  5881.,  6457.,  7479.,  7147., 10307.,  6744.,
        13093.,  8661.,  8358.,  8762.,  7420., 13040.,  8050., 10493.,  7946.,
         6577., 16642.,  4777.,  5368., 17354.,  9916., 10992.,  8660.,  6263.,
        11949.,  4730.,  7819.,  3572., 14489., 11450.,  8636.,  8199.,  8389.,
         8290., 11133.,  4937., 10473.,  7372.,  7997.,  6629.,  7172.,  6025.,
        10641., 10086.,  4697.,  9519.,  7657.,  8411., 10669.,  6155., 10091.,
        11407.,  6927., 10725.,  6013., 13503.,  8861.,  5945.,  6982.,  9506.,
         7400.,  9809., 12965.,  8544.,  4179.,  7130., 10234., 14786.,  5878.,
         6824.,  9876.,  6445., 12402.,  7216.,  7483.,  7870.,  6647., 12984.,
         4625.,  8983.,  9323.,  5130.,  5645.,  9488.,  2892.,  5840.,  8063.,
         7426.,  8536.,  8396.,  8606.,  9611.,  6969.,  8188.,  7036.,  9914.,
         9967.], device='cuda:0')
tensor([1331.1434, 1540.0059, 1761.1873, 1313.7500, 1267.5094, 1498

In [42]:
px.histogram(mean_neuron_diffs.cpu(), title=f"Distribution of mean neuron activation differences from autoencoder - {model_name} L{cfg.layer}, {save_names[0]}")

In [43]:
px.histogram(mean_active_neuron_diffs.cpu(), title=f"Distribution of mean neuron activation differences from reconstruction - active neurons only - {model_name} L{cfg.layer}, {save_names[0]}")