In [32]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json
from collections import Counter
from datasets import load_dataset


# Use plotly's "notebook_connected" renderer for every figure in this notebook.
pio.renderers.default = "notebook_connected"

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Disable autograd globally for the notebook. torch.autograd.set_grad_enabled is the
# same callable as torch.set_grad_enabled, so a single call is sufficient.
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

[autoreload of utils.autoencoder_utils failed: Traceback (most recent call last):
  File "c:\Users\heind\miniconda3\envs\mats\Lib\site-packages\IPython\extensions\autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "c:\Users\heind\miniconda3\envs\mats\Lib\site-packages\IPython\extensions\autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "c:\Users\heind\miniconda3\envs\mats\Lib\site-packages\IPython\extensions\autoreload.py", line 397, in update_generic
    update(a, b)
  File "c:\Users\heind\miniconda3\envs\mats\Lib\site-packages\IPython\extensions\autoreload.py", line 309, in update_function
    setattr(old, name, getattr(new, name))
ValueError: evaluate_autoencoder_reconstruction() requires a code object with 0 free vars, not 1361504632834
]


In [2]:
# data = load_dataset("MechInterpResearch/tinystories_tokenized", split="train")
# data.save_to_disk(f"data/tinystories/data.hf")
# del data

In [23]:
# Run overview: load the exported sweep table for this model and order it by L1 coefficient.
model_name, layer_name = "tiny-stories-1L-21M", "L0"
print_model_name = "-".join([model_name, layer_name])
df = pd.read_csv(f"{model_name}/wandb_runs.csv").sort_values("l1_coeff")
df.columns

Index(['Name', 'State', 'Notes', 'User', 'Tags', 'Created', 'Runtime', 'Sweep',
       'act', 'batch_size', 'beta1', 'beta2', 'buffer_batches', 'buffer_mult',
       'buffer_size', 'd_mlp', 'data_paths', 'expansion_factor', 'l1_coeff',
       'layer', 'lr', 'model', 'model_batch_size', 'num_eval_batches',
       'num_eval_tokens', 'num_training_tokens', 'seed', 'seq_len',
       'use_wandb', 'wd', 'avg_directions', 'batch', 'bias_mean', 'bias_std',
       'dead_directions', 'epoch', 'l1_loss', 'l2_loss',
       'long term dead directions', 'loss'],
      dtype='object')

In [15]:
# Line plot of the sweep metrics against the L1 coefficient.
fig = px.line(
    df,
    x="l1_coeff",
    y=["l2_loss", "l1_loss", "avg_directions"],
    markers=True,
    title=f"{print_model_name}: L1 loss, L2 loss, and average number of active directions",
)
(
    fig.update_layout(
        xaxis_title="L1 coefficient",
        yaxis_title="",
        legend_title="",
        width=800,
        xaxis=dict(tickformat=".1e"),
    )
    .update_xaxes(type="linear")
    .show()
)

In [16]:
# Keep only the raw story text of the validation split; the dataset object itself
# is dropped once the strings have been copied out.
val_ds = load_dataset(
    "roneneldan/TinyStories",
    split="validation",
)
prompts = [record["text"] for record in val_ds]
del val_ds

Repo card metadata block was not found. Setting CardData to empty.


In [24]:
# Load the pretrained model named by `model_name` onto `device`, with LayerNorm
# folding and writing-weight / unembed centering switched on.
model = HookedTransformer.from_pretrained(
    model_name,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    device=device,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development



Downloading pytorch_model.bin:   0%|          | 0.00/269M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


In [25]:
def load_encoder(save_name, model_name):
    """Load one trained autoencoder and its config from the ``model_name`` directory.

    Reads ``{model_name}/{save_name}.json`` for the hyperparameters and
    ``{model_name}/{save_name}.pt`` for the weights.

    Args:
        save_name: Checkpoint base name (file name without its extension).
        model_name: Directory that holds the ``.json`` / ``.pt`` pair.

    Returns:
        Tuple ``(encoder, cfg)``: the AutoEncoder moved to ``device`` and the
        AutoEncoderConfig built from the JSON file.

    Note:
        Reads the notebook-level ``model`` (for the layer widths) and ``device``.
    """
    with open(f"{model_name}/{save_name}.json", "r") as f:
        raw_cfg = json.load(f)

    cfg = AutoEncoderConfig(
        raw_cfg["layer"], raw_cfg["act"], raw_cfg["expansion_factor"], raw_cfg["l1_coeff"]
    )

    # hook_mlp_out activations are d_model wide; any other hook point is taken to be d_mlp wide.
    if cfg.act_name == "hook_mlp_out":
        d_in = model.cfg.d_model
    else:
        d_in = model.cfg.d_mlp
    d_hidden = d_in * cfg.expansion_factor

    encoder = AutoEncoder(d_hidden, cfg.l1_coeff, d_in)
    # map_location lets a checkpoint that was saved on a GPU load on a CPU/MPS-only machine.
    state_dict = torch.load(os.path.join(model_name, save_name + ".pt"), map_location=device)
    encoder.load_state_dict(state_dict)
    encoder.to(device)
    return encoder, cfg

# Every *.pt file in the model directory is treated as one trained autoencoder.
# os.path.splitext removes only the final extension, so base names that contain
# dots (e.g. an L1 coefficient) still resolve to the matching .json / .pt pair.
# The list order is whatever os.listdir returns.
save_names = [os.path.splitext(f)[0] for f in os.listdir(model_name) if f.endswith('.pt')]
encoders = [load_encoder(save_name, model_name) for save_name in save_names]

In [26]:
# Evaluate every loaded autoencoder on the first 200 validation prompts and record
# how much the loss rises relative to the first value the helper returns.
# `encoder`, `cfg`, `original_loss` and `encoder_loss` stay bound to the last
# iteration's values after the loop.
loss_data = []
for encoder, cfg in encoders:
    original_loss, encoder_loss = evaluate_autoencoder_reconstruction(
        encoder,
        cfg.encoder_hook_point,
        prompts[:200],
        model,
    )
    loss_increase = encoder_loss - original_loss
    loss_data.append([cfg.l1_coeff, loss_increase])

# Sort numerically first, then turn the coefficient into a string label.
loss_df = (
    pd.DataFrame(loss_data, columns=["L1 coefficient", "Loss Increase"])
    .sort_values(by="L1 coefficient", ascending=True)
    .astype({"L1 coefficient": str})
)

100%|██████████| 200/200 [00:08<00:00, 23.50it/s]
100%|██████████| 200/200 [00:07<00:00, 28.18it/s]
100%|██████████| 200/200 [00:52<00:00,  3.82it/s]
100%|██████████| 200/200 [01:37<00:00,  2.06it/s]
100%|██████████| 200/200 [01:38<00:00,  2.02it/s]
100%|██████████| 200/200 [01:39<00:00,  2.01it/s]


In [30]:
# `original_loss` is whatever the last iteration of the loop that builds `loss_data`
# assigned; presumably the model's loss without an autoencoder spliced in (confirm
# against evaluate_autoencoder_reconstruction).
print(original_loss)

1.2571979832649232


In [28]:
# Loss increase caused by each autoencoder, plotted against its L1 coefficient.
fig = px.line(
    loss_df,
    x="L1 coefficient",
    y="Loss Increase",
    markers=True,
    title=f"{print_model_name}: Encoder reconstruction loss",
)
(
    fig.update_layout(
        xaxis_title="L1 coefficient",
        yaxis_title="Loss increase",
        width=800,
        xaxis=dict(tickformat=".1e"),
    )
    .update_xaxes(type="linear")
    .show()
)

In [31]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    """Return the autoencoder's hidden activations for a single prompt.

    Runs ``model`` on ``prompt`` while caching only ``cfg.encoder_hook_point``,
    drops the leading dimension of the cached activations with ``squeeze(0)``,
    and feeds the result through ``encoder``.

    Returns:
        The third of the five values returned by ``encoder``.
    """
    hook_point = cfg.encoder_hook_point
    _, activation_cache = model.run_with_cache(prompt, names_filter=hook_point)
    model_acts = activation_cache[hook_point].squeeze(0)
    _, _, hidden_acts, _, _ = encoder(model_acts)
    return hidden_acts

def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    """Collect, for every prompt, the maximum activation of each encoder direction.

    Each prompt is run through ``get_acts`` and reduced with a max over dim 0.
    The per-prompt results are stacked into one tensor, and the number of
    directions whose summed maxima are non-zero is printed.

    Returns:
        The stacked tensor of per-prompt maxima (n_prompt x d_enc).
    """
    per_prompt_maxima = [
        get_acts(prompt, model, encoder, cfg).max(dim=0).values
        for prompt in tqdm(prompts)
    ]
    max_activation_per_prompt = torch.stack(per_prompt_maxima)  # n_prompt x d_enc

    summed_maxima = max_activation_per_prompt.sum(0)
    active_count = summed_maxima.nonzero().shape[0]
    print(f"Active directions on validation data: {active_count} out of {summed_maxima.shape[0]}")
    return max_activation_per_prompt

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, n=5,
                       model=None, encoder=None, cfg=None):
    """Render the prompts on which one encoder direction activates most strongly.

    The ``n`` prompts with the largest value in ``activations[:, direction]`` are
    re-run through ``get_acts`` and their tokens are printed as HTML, coloured by
    the per-token activation of ``direction``.

    Args:
        prompts: Prompts in the same order as the rows of ``activations``.
        activations: Per-prompt activation of every encoder direction.
        direction: Index of the encoder direction to inspect.
        n: Number of top prompts to show.
        model: Model used for tokenisation and the forward pass. Defaults to the
            notebook-level ``model``.
        encoder: Autoencoder to evaluate. Defaults to the notebook-level ``encoder``.
        cfg: Config matching ``encoder``. Defaults to the notebook-level ``cfg``.

    Note:
        ``encoder`` / ``cfg`` should be the ones that produced ``activations``;
        pass them explicitly to avoid depending on whichever encoder the
        notebook globals currently hold.
    """
    # Backward compatibility: callers that pass nothing keep using the notebook globals.
    if model is None:
        model = globals()["model"]
    if encoder is None:
        encoder = globals()["encoder"]
    if cfg is None:
        cfg = globals()["cfg"]

    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        # Prompts on which the direction never activates are skipped.
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [33]:
# Pick one autoencoder to inspect. The index is positional in `encoders`, whose order
# comes from os.listdir, so rely on the printed L1 coefficient to confirm which one it is.
# These two globals are also read by print_top_examples.
encoder, cfg = encoders[3]
print(f"Encoder L1 coefficient: {cfg.l1_coeff}")

Encoder L1 coefficient: 0.0003


In [34]:
# Feature frequencies of the selected encoder on the first 1000 prompts,
# shown as a probability histogram with a log-scaled y axis.
feature_frequencies = get_encoder_feature_frequencies(prompts[:1000], model, encoder, cfg)
fig = px.histogram(
    feature_frequencies.cpu().numpy(),
    nbins=40,
    histnorm="probability",
    log_y=True,
    title="Histogram of feature frequencies",
)
fig.update_layout(
    xaxis_title="Feature frequency",
    yaxis_title="Probability",
    showlegend=False,
    width=600,
)

100%|██████████| 1000/1000 [07:52<00:00,  2.12it/s]

Number of active features over 194559 tokens: 13876
Number of average active features per token: 22.16





In [45]:
# Per-prompt maximum activation of every encoder direction over all validation prompts
# (get_max_activations runs one forward pass per prompt).
max_activation_per_prompt = get_max_activations(prompts, model, encoder, cfg)

Encoder L1 coefficient: 0.0005


  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [48]:
# Distribution of per-prompt maximum activations for a single encoder direction,
# followed by the prompts on which that direction activates most strongly.
direction = 9000
fig = px.histogram(
    max_activation_per_prompt[:, direction].tolist(),
    histnorm="probability",
    title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}",
)
fig.update_layout(xaxis_title="Activation", yaxis_title="Probability", width=800).show()
print_top_examples(prompts, max_activation_per_prompt, direction)