In [72]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json
from collections import Counter
from datasets import load_dataset
import requests
import pandas as pd
from ipywidgets import interact, IntSlider

# Render plotly figures with the "notebook_connected" renderer.
pio.renderers.default = "notebook_connected"
# Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Turn gradient tracking off globally; the visible cells only run inference.
# NOTE(review): the two calls below look redundant — confirm one of them can be dropped.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# data = load_dataset("MechInterpResearch/tinystories_tokenized", split="train")
# data.save_to_disk(f"data/tinystories/data.hf")
# del data

In [55]:
# Run overview: identify the model/layer under study and load the exported
# W&B sweep table, ordered by increasing L1 coefficient.
model_name, layer_name = "tiny-stories-2L-33M", "L0"
print_model_name = f"{model_name}-{layer_name}"
df = pd.read_csv(f"{model_name}/wandb_runs.csv").sort_values("l1_coeff")
df.columns

Index(['Name', 'State', 'Notes', 'User', 'Tags', 'Created', 'Runtime', 'Sweep',
       'act', 'batch_size', 'beta1', 'beta2', 'buffer_batches', 'buffer_mult',
       'buffer_size', 'd_mlp', 'data_paths', 'expansion_factor', 'l1_coeff',
       'layer', 'lr', 'model', 'model_batch_size', 'num_eval_batches',
       'num_eval_tokens', 'num_training_tokens', 'seed', 'seq_len',
       'use_wandb', 'wd', 'avg_directions', 'batch', 'bias_mean', 'bias_std',
       'dead_directions', 'epoch', 'l1_loss', 'l2_loss',
       'long term dead directions', 'loss'],
      dtype='object')

In [56]:
# Plot both loss terms and the average number of active directions for each
# run of the sweep against its L1 coefficient.
sweep_metrics = ["l2_loss", "l1_loss", "avg_directions"]
fig = px.line(
    df,
    x="l1_coeff",
    y=sweep_metrics,
    markers=True,
    title=f"{print_model_name}: L1 loss, L2 loss, and average number of active directions",
)
fig.update_layout(
    width=800,
    legend_title="",
    yaxis_title="",
    xaxis_title="L1 coefficient",
    xaxis={'tickformat': '.1e'},
)
fig.update_xaxes(type='linear')
fig.show()

In [12]:
# val_ds = load_dataset("roneneldan/TinyStories", split="validation")
# prompts = [x["text"] for x in val_ds]
# del val_ds

# Download the TinyStories validation split as a Parquet file. The file is
# cached on disk so re-running the notebook does not re-download it; delete
# `validation.parquet` to force a fresh download.
validation_data_path = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/data/validation-00000-of-00001-869c898b519ad725.parquet"
validation_file = Path("validation.parquet")
if not validation_file.exists():
    response = requests.get(validation_data_path, timeout=60)
    # Fail loudly on an HTTP error instead of writing an error page to disk
    # and hitting a confusing Parquet parse error further down.
    response.raise_for_status()
    validation_file.write_bytes(response.content)

# Load the Parquet file into its own DataFrame. A dedicated name is used so the
# W&B run table held in `df` is not silently replaced by the validation data.
validation_df = pd.read_parquet(validation_file)
prompts = validation_df["text"].tolist()

In [57]:
# Load the pretrained model onto the selected device with LayerNorm folding
# and centering of the writing/unembed weights enabled.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/323M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [65]:
def load_encoder(save_name, model_name):
    """Load a trained autoencoder and its config from the ``model_name`` directory.

    Reads the hyperparameters from ``{model_name}/{save_name}.json`` and the
    weights from ``{model_name}/{save_name}.pt``.

    Args:
        save_name: File stem shared by the ``.json`` and ``.pt`` files.
        model_name: Directory that contains the saved files.

    Returns:
        Tuple ``(encoder, cfg)``: the ``AutoEncoder`` moved to ``device`` and
        the ``AutoEncoderConfig`` built from the JSON file.

    Note:
        Relies on the notebook globals ``model`` (for the input width) and
        ``device``.
    """
    with open(f"{model_name}/{save_name}.json", "r") as f:
        raw_cfg = json.load(f)

    cfg = AutoEncoderConfig(
        raw_cfg["layer"], raw_cfg["act"], raw_cfg["expansion_factor"], raw_cfg["l1_coeff"]
    )

    # The encoder input width depends on the hooked activation: the MLP output
    # has width d_model, any other hook point is treated as having width d_mlp.
    if cfg.act_name == "hook_mlp_out":
        d_in = model.cfg.d_model
    else:
        d_in = model.cfg.d_mlp
    d_hidden = d_in * cfg.expansion_factor

    encoder = AutoEncoder(d_hidden, cfg.l1_coeff, d_in)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-/MPS-only machine.
    state_dict = torch.load(os.path.join(model_name, save_name + ".pt"), map_location=device)
    encoder.load_state_dict(state_dict)
    encoder.to(device)
    return encoder, cfg

# Discover every saved checkpoint in the model directory and load it.
# - os.path.splitext strips only the final ".pt", so stems that contain dots
#   (e.g. an L1 value embedded in the file name) are preserved intact.
# - sorted() makes the encoder order deterministic; os.listdir returns entries
#   in arbitrary order.
save_names = sorted(os.path.splitext(f)[0] for f in os.listdir(model_name) if f.endswith('.pt'))
encoders = [load_encoder(save_name, model_name) for save_name in save_names]

In [66]:
# Display the loaded (encoder, config) pairs to check which sweep runs were found.
encoders

[(AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=5e-05)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0001)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0005)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0002)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0003)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.0004)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.001)),
 (AutoEncoder(),
  AutoEncoderConfig(layer=0, act_name='mlp.hook_post', expansion_factor=4, l1_coeff=0.00015))]

In [73]:
# Measure, for every encoder in the sweep, the three losses returned by
# evaluate_autoencoder_reconstruction on the first 200 validation prompts.
# The loop variables are deliberately NOT named `encoder` / `cfg`: those names
# are notebook globals read by print_top_examples and plot_direction, and
# re-running this cell must not silently switch the encoder being inspected.
loss_data = []
for sweep_encoder, sweep_cfg in encoders:
    original_loss, encoder_loss, zero_ablation_loss = evaluate_autoencoder_reconstruction(
        sweep_encoder, sweep_cfg.encoder_hook_point, prompts[:200], model
    )
    loss_data.append([sweep_cfg.l1_coeff, original_loss, encoder_loss, zero_ablation_loss])
loss_df = pd.DataFrame(loss_data, columns=["L1 coefficient", "Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
loss_df = loss_df.sort_values(by="L1 coefficient", ascending=True)
# Stored as strings after sorting numerically, for use as the plot's x values.
loss_df["L1 coefficient"] = loss_df["L1 coefficient"].astype(str)

100%|██████████| 200/200 [00:05<00:00, 38.60it/s]
100%|██████████| 200/200 [00:05<00:00, 39.10it/s]
100%|██████████| 200/200 [00:05<00:00, 38.92it/s]
100%|██████████| 200/200 [00:05<00:00, 38.96it/s]
100%|██████████| 200/200 [00:05<00:00, 38.86it/s]
100%|██████████| 200/200 [00:05<00:00, 39.10it/s]
100%|██████████| 200/200 [00:05<00:00, 38.83it/s]
100%|██████████| 200/200 [00:05<00:00, 38.71it/s]


In [74]:
# Reshape the wide loss table to long form (one row per coefficient and loss
# type) for plotting. The result gets its own name so `loss_df` stays wide and
# this cell can be re-run; overwriting `loss_df` made a second run fail
# because the value columns no longer existed.
loss_long_df = loss_df.melt(id_vars=["L1 coefficient"], var_name="Loss Type", value_name="Loss", value_vars=["Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
# NOTE(review): the title and y-axis say "loss increase", but the plotted
# columns look like the raw losses returned by the evaluation — confirm the labels.
fig = px.line(loss_long_df, x="L1 coefficient", y="Loss", color="Loss Type", markers=True,  title=f"{print_model_name}: Encoder reconstruction loss increase")
fig.update_layout(
    xaxis_title="L1 coefficient",
    yaxis_title="Loss increase",
    width = 800,
    xaxis={'tickformat':'.1e'}
)
fig.update_xaxes(type='linear')
fig.show()

In [76]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    """Return the encoder's hidden activations for a single prompt.

    Runs ``model`` on ``prompt`` while caching only the hook point named by
    ``cfg.encoder_hook_point``, drops the leading dimension of the cached
    activations, and feeds them through ``encoder``.

    Returns:
        The third element of the encoder's 5-tuple output.
    """
    hook_name = cfg.encoder_hook_point
    _, activation_cache = model.run_with_cache(prompt, names_filter=hook_name)
    hooked_acts = activation_cache[hook_name].squeeze(0)
    _, _, hidden_acts, _, _ = encoder(hooked_acts)
    return hidden_acts

def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    """Compute, for every prompt, the maximum encoder activation per direction.

    Each prompt is run through ``get_acts`` and reduced with a max over the
    first dimension. The per-prompt maxima are stacked into one tensor
    (n_prompt x d_enc), and the number of directions with a non-zero summed
    maximum is printed.

    Returns:
        The stacked per-prompt maxima.
    """
    per_prompt_max = torch.stack([
        get_acts(text, model, encoder, cfg).max(dim=0).values
        for text in tqdm(prompts)
    ])  # n_prompt x d_enc

    summed = per_prompt_max.sum(dim=0)
    n_active = summed.nonzero().size(0)
    n_total = summed.size(0)
    print(f"Active directions on validation data: {n_active} out of {n_total}")
    return per_prompt_max

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, n=5):
    """Render the ``n`` prompts with the highest activation for one direction.

    Prompts are ranked by their entry in ``activations[:, direction]``. Each
    selected prompt is re-run through ``get_acts`` and its tokens are passed to
    ``haystack_utils.clean_print_strings_as_html`` with the per-token
    activations. Prompts whose maximum activation is not positive are skipped.

    Note:
        ``model``, ``encoder`` and ``cfg`` are not parameters; they are read
        from the notebook's global namespace.
    """
    ranked_prompts = activations[:, direction].argsort(descending=True)
    for prompt_index in ranked_prompts[:n].cpu().tolist():
        text = prompts[prompt_index]
        str_tokens = model.to_str_tokens(model.to_tokens(text))
        per_token_act = get_acts(text, model, encoder, cfg)[:, direction].cpu().tolist()
        peak_act = max(per_token_act)
        if not peak_act > 0:
            continue
        haystack_utils.clean_print_strings_as_html(str_tokens, per_token_act, max_value=peak_act)

In [88]:
# L1 coefficient of the autoencoder to inspect; it is looked up among the loaded encoders below.
l1 = 2e-4
def get_encoder_by_l1(encoders, l1_coeff):
    """Return the ``(encoder, cfg)`` pair trained with the given L1 coefficient.

    Coefficients are compared with ``math.isclose`` rather than ``==`` so that
    a value produced by arithmetic (e.g. ``0.1 + 0.2``) still matches the
    stored coefficient despite floating-point rounding.

    Args:
        encoders: Iterable of ``(encoder, cfg)`` pairs; each ``cfg`` exposes
            an ``l1_coeff`` attribute.
        l1_coeff: The L1 coefficient to look up.

    Returns:
        The first ``(encoder, cfg)`` pair whose coefficient matches.

    Raises:
        ValueError: If no pair has a matching coefficient.
    """
    import math  # local import keeps this cell self-contained

    for encoder, cfg in encoders:
        if math.isclose(cfg.l1_coeff, l1_coeff):
            return encoder, cfg
    raise ValueError(f"Encoder with L1 coefficient {l1_coeff} not found")
# Select the autoencoder to inspect. `encoder` and `cfg` are notebook globals
# that print_top_examples and plot_direction read directly, so the cells below
# must be re-run after changing `l1`.
encoder, cfg = get_encoder_by_l1(encoders, l1)
print(f"Encoder L1 coefficient: {cfg.l1_coeff}")

Encoder L1 coefficient: 0.0002


In [89]:
# feature_frequencies = get_encoder_feature_frequencies(prompts, model, encoder, cfg)
# zero_activating_features = (feature_frequencies == 0).sum(0).item()
# low_density = ((feature_frequencies > 0) & (feature_frequencies < 1e-6)).sum(0).item()
# high_density = (feature_frequencies > 1e-6).sum(0).item()
# print(zero_activating_features, low_density, high_density)
# fig = px.histogram(feature_frequencies.cpu().numpy(), histnorm='probability', title=f"{print_model_name} L1={cfg.l1_coeff}: Histogram of feature frequencies", nbins=40)
# fig.update_yaxes(type='log')
# fig.update_layout(xaxis_title="Feature frequency", yaxis_title="Probability", showlegend=False, width=600)

In [90]:
# Per-prompt maximum activation of every encoder direction over all validation
# prompts. The model is run once per prompt, so this cell is slow.
max_activation_per_prompt = get_max_activations(prompts, model, encoder, cfg)

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [92]:
def plot_direction(direction, n=5):
    """Show the activation histogram and top examples for one encoder direction.

    Plots a log-scale probability histogram of the per-prompt maximum
    activations of ``direction``, then prints the ``n`` top-activating prompts
    via ``print_top_examples``.

    Note:
        Reads the notebook globals ``max_activation_per_prompt``, ``prompts``,
        ``cfg`` and ``print_model_name``.
    """
    direction_maxima = max_activation_per_prompt[:, direction].tolist()
    hist_title = f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}"
    fig = px.histogram(direction_maxima, title=hist_title, histnorm="probability")
    fig.update_layout(
        showlegend=False,
        width=800,
        xaxis_title="Activation",
        yaxis_title="Probability",
    )
    fig.update_yaxes(type='log')
    fig.show()
    print_top_examples(prompts, max_activation_per_prompt, direction, n)

# Interactive browser: one slider picks the encoder direction, the other the
# number of top-activating prompts to print.
direction_slider = IntSlider(min=0, max=encoder.d_hidden - 1, step=1, value=0)
n_examples_slider = IntSlider(min=1, max=20, step=1, value=5)
interact(plot_direction, direction=direction_slider, n=n_examples_slider)


interactive(children=(IntSlider(value=0, description='direction', max=16383), IntSlider(value=5, description='…

<function __main__.plot_direction(direction, n=5)>

In [None]:
# NOTE(review): looks like a leftover scratch/debug cell — confirm it can be removed.
print("test")

In [None]:
# Histogram of the encoder activations in the last row returned by get_acts
# (presumably the final token position) for a single example prompt.
prompt = "One day, a little girl named Lily went for a walk in the park"
acts = get_acts(prompt, model, encoder, cfg)[-1]
px.histogram(
    acts.cpu().numpy(),
    nbins=40,
    histnorm="probability",
    title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for prompt",
)

In [38]:
# direction = 9000
# fig = px.histogram(max_activation_per_prompt[:, direction].tolist(), title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}", histnorm="probability")
# fig.update_layout(
#     xaxis_title="Activation",
#     yaxis_title="Probability",
#     width = 800,
#     showlegend=False
# )
# fig.update_yaxes(type='log')
# fig.show()
# print_top_examples(prompts, max_activation_per_prompt, direction)

In [None]:
# Look for features that are active on specific tokens in a prompt
# Baseline: run the same analysis on the raw MLP neurons
# Train with smaller L1 coefficients
# At some point the features should become non-monosemantic, since the encoder can just copy the MLP
# Train without an L1 penalty and see what happens