In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')


sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies
import utils.haystack_utils as haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Run overview
model_name = "tiny-stories-2L-33M"
layer_name = "L0"
print_model_name = f"{model_name}-{layer_name}"
df = pd.read_csv(f"{model_name}/wandb_runs.csv")
df = df.sort_values(by="l1_coeff", ascending=True)
df.columns

Index(['Name', 'State', 'Notes', 'User', 'Tags', 'Created', 'Runtime', 'Sweep',
       'act', 'batch_size', 'beta1', 'beta2', 'buffer_batches', 'buffer_mult',
       'buffer_size', 'd_mlp', 'data_paths', 'expansion_factor', 'l1_coeff',
       'layer', 'lr', 'model', 'model_batch_size', 'num_eval_batches',
       'num_eval_tokens', 'num_training_tokens', 'seed', 'seq_len',
       'use_wandb', 'wd', 'avg_directions', 'batch', 'bias_mean', 'bias_std',
       'dead_directions', 'epoch', 'l1_loss', 'l2_loss',
       'long term dead directions', 'loss'],
      dtype='object')

In [4]:
fig = px.line(df, x="l1_coeff", y=["l2_loss", "l1_loss", "avg_directions"], markers=True, title=f"{print_model_name}: L1 loss, L2 loss, and average number of active directions")
fig.update_layout(
    xaxis_title="L1 coefficient",
    yaxis_title="",
    legend_title="",
    width = 800,
    xaxis={'tickformat':'.1e'}
)
fig.update_xaxes(type='linear')
fig.show()

In [3]:
prompts = load_tinystories_validation_prompts()

(INFO) 10:47:46: Loaded 21990 TinyStories validation prompts


In [4]:
model = HookedTransformer.from_pretrained(
        model_name,
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        device=device,
    )

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [5]:
def strip_final_numbers(filename):
    return re.sub(r'_\d+$', '', filename)
    
STRIP_FINAL_NUMBERS = True

def load_encoder(save_name, model_name):
    json_save_name = save_name
    if STRIP_FINAL_NUMBERS:
        json_save_name = strip_final_numbers(save_name)
    with open(f"{model_name}/{json_save_name}.json", "r") as f:
        cfg = json.load(f)

    cfg = AutoEncoderConfig(
        cfg["layer"], cfg["act"], cfg["expansion_factor"], cfg["l1_coeff"]
    )

    if cfg.act_name == "hook_mlp_out":
        d_in = model.cfg.d_model  # d_mlp
    else:
        d_in = model.cfg.d_mlp
    d_hidden = d_in * cfg.expansion_factor

    encoder = AutoEncoder(d_hidden, cfg.l1_coeff, d_in)
    encoder.load_state_dict(torch.load(os.path.join(model_name, save_name + ".pt")))
    encoder.to(device)
    return encoder, cfg

save_names = [f.split(".")[0] for f in os.listdir(model_name) if f.endswith('.pt')]
# encoders = [load_encoder(save_name, model_name) for save_name in save_names]

# Sweep eval

In [42]:
save_names = ["2_upbeat_snowball"]

In [43]:
loss_data = []
for save_name in tqdm(save_names):
    encoder, cfg = load_encoder(save_name, model_name)
    original_loss, encoder_loss, zero_ablation_loss = evaluate_autoencoder_reconstruction(encoder, cfg.encoder_hook_point, prompts[:200], model)
    loss_data.append([cfg.l1_coeff, original_loss, encoder_loss, zero_ablation_loss])
loss_df = pd.DataFrame(loss_data, columns=["L1 coefficient", "Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
loss_df = loss_df.sort_values(by="L1 coefficient", ascending=True)
loss_df["L1 coefficient"] = loss_df["L1 coefficient"].astype(str)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:05<00:00, 34.47it/s]
(INFO) 10:37:03: Average loss increase after encoding: 0.1116


In [9]:
loss_df = loss_df.melt(id_vars=["L1 coefficient"], var_name="Loss Type", value_name="Loss", value_vars=["Original Loss", "Reconstruction Loss", "Zero Ablation Loss"])
fig = px.line(loss_df, x="L1 coefficient", y="Loss", color="Loss Type", markers=True,  title=f"{print_model_name}: Encoder reconstruction loss increase")
fig.update_layout(
    xaxis_title="L1 coefficient",
    yaxis_title="Loss increase",
    width = 800,
    xaxis={'tickformat':'.1e'}
)
fig.update_xaxes(type='linear')
fig.show()

# Single encoder eval

In [6]:
@torch.no_grad()
def get_acts(prompt: str, model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    _, cache = model.run_with_cache(prompt, names_filter=cfg.encoder_hook_point)
    acts = cache[cfg.encoder_hook_point].squeeze(0)
    _, _, mid_acts, _, _ = encoder(acts)
    return mid_acts

def get_max_activations(prompts: list[str], model: HookedTransformer, encoder: AutoEncoder, cfg: AutoEncoderConfig):
    activations = []
    for prompt in tqdm(prompts):
        acts = get_acts(prompt, model, encoder, cfg)
        max_prompt_activation = acts.max(0)[0]
        activations.append(max_prompt_activation)

    max_activation_per_prompt = torch.stack(activations)  # n_prompt x d_enc

    total_activations = max_activation_per_prompt.sum(0)
    print(f"Active directions on validation data: {total_activations.nonzero().shape[0]} out of {total_activations.shape[0]}")
    return max_activation_per_prompt

def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [7]:
# l1 = 2e-4
# def get_encoder_by_l1(encoders, l1_coeff):
#     for encoder, cfg in encoders:
#         if cfg.l1_coeff == l1_coeff:
#             return encoder, cfg
#     raise ValueError(f"Encoder with L1 coefficient {l1_coeff} not found")
# encoder, cfg = get_encoder_by_l1(encoders, l1)
encoder, cfg = load_encoder("18_morning_sun", model_name)
print(f"Encoder L1 coefficient: {cfg.l1_coeff}")

Encoder L1 coefficient: 0.0001


In [12]:
encoder_2, cfg_2 = load_encoder("2_upbeat_snowball", model_name)
evaluate_autoencoder_reconstruction(encoder_2, cfg_2.encoder_hook_point, prompts[:200], model)

100%|██████████| 200/200 [00:05<00:00, 33.82it/s]
(INFO) 10:55:59: Average loss increase after encoding: 0.1116


(1.0763592836260796, 1.1879558461904525, 2.1065332067012785)

In [14]:
total_tokens = 0
for prompt in prompts: 
    tokens = model.to_tokens(prompt)
    total_tokens += torch.numel(tokens)
print(total_tokens)

4765918


In [9]:
feature_frequencies = get_encoder_feature_frequencies(prompts[:5000], model, encoder, cfg)
zero_activating_features = (feature_frequencies == 0).sum(0).item()
low_density = ((feature_frequencies > 0) & (feature_frequencies < 1e-6)).sum(0).item()
high_density = (feature_frequencies > 1e-6).sum(0).item()
print(zero_activating_features, low_density, high_density)
fig = px.histogram(feature_frequencies.cpu().numpy(), histnorm='probability', title=f"{print_model_name} L1={cfg.l1_coeff}: Histogram of feature frequencies", nbins=40)
fig.update_yaxes(type='log')
fig.update_layout(xaxis_title="Feature frequency", yaxis_title="Probability", showlegend=False, width=600)

100%|██████████| 5000/5000 [01:10<00:00, 70.60it/s]


Number of active features over 1000238 tokens: 16384
Number of average active features per token: 98.69
0 0 16384


In [10]:
evaluate_autoencoder_reconstruction(encoder, cfg.encoder_hook_point, prompts[:200], model)

100%|██████████| 200/200 [00:05<00:00, 33.66it/s]
(INFO) 10:50:08: Average loss increase after encoding: 0.0794


(1.0763592836260796, 1.1557915967702865, 4.7769329905509945)

In [17]:
autoencoder = encoder
encoded_hook_name = cfg.encoder_hook_point

def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)
reconstruct_hooks = [(encoded_hook_name, encode_activations_hook)]

def zero_ablate_hook(value, hook):
    value[:] = 0
    return value
zero_ablate_hooks = [(encoded_hook_name, zero_ablate_hook)]

original_losses = []
reconstruct_losses = []
zero_ablation_losses = []
for prompt in tqdm(prompts[:20]):
    with model.hooks(reconstruct_hooks):
        reconstruct_loss = model(prompt, return_type="loss", loss_per_token=True).flatten().tolist()
    reconstruct_losses.extend(reconstruct_loss)
    original_loss = model(prompt, return_type="loss", loss_per_token=True).flatten().tolist()
    with model.hooks(zero_ablate_hooks):
        zero_ablate_loss = model(prompt, return_type="loss", loss_per_token=True).flatten().tolist()
    original_losses.extend(original_loss)
    zero_ablation_losses.extend(zero_ablate_loss)

print(len(reconstruct_losses), len(original_losses), len(zero_ablation_losses))

  0%|          | 0/20 [00:00<?, ?it/s]

3466 3466 3466


In [21]:
np.mean(loss_increase)

0.0657664147715836

In [20]:
loss_increase = np.array(reconstruct_losses) - np.array(original_losses)
fig = px.histogram(loss_increase, title="Distribution of reconstruction loss - original loss")
fig.update_layout(xaxis_title="Loss increase", yaxis_title="Count", showlegend=False, width=600)

No low frequency features
Should we have them?
If yes: 
- L1 too low
- Expansion too low

In [11]:
max_activation_per_prompt = get_max_activations(prompts, model, encoder, cfg)

  0%|          | 0/21990 [00:00<?, ?it/s]

Active directions on validation data: 16384 out of 16384


In [20]:
def plot_direction_frequency(data: list[str], direction: int, cfg: AutoEncoderConfig):
    activations = []
    for prompt in tqdm(data):
        tokens = model.to_tokens(prompt)
        _, cache = model.run_with_cache(
            tokens, names_filter=f"blocks.{cfg.layer}.{cfg.act_name}"
            )
        acts = cache[f"blocks.{cfg.layer}.{cfg.act_name}"].squeeze(0)
        _, _, mid_acts, _, _ = encoder(acts)
        activations.append(mid_acts[:, direction])
    activations = torch.cat(activations)
    print(activations.shape)

    fig = px.histogram(activations.tolist(), 
                       title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}", 
                       histnorm="probability")
    fig.update_layout(
        xaxis_title="Activation",
        yaxis_title="Probability",
        width = 600,
        showlegend=False
    )
    fig.update_yaxes(type='log')
    fig.show()

direction = 0
plot_direction_frequency(prompts[:50], direction, cfg)

  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


In [None]:
def print_direction_example(direction, n=5):
    print_top_examples(prompts, max_activation_per_prompt, direction, n)

# Max activations
interact(print_direction_example, 
         direction=IntSlider(min=0, max=encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))


interactive(children=(IntSlider(value=0, description='direction', max=16383), IntSlider(value=5, description='…

<function __main__.print_direction_example(direction, n=5)>

In [22]:
# Activations of different directions on the same token
prompt = "One day, a little girl named Lily went for a walk in the park"
acts = get_acts(prompt, model, encoder, cfg)[-1] # d_enc
print(f"Active directions on last token: {acts.nonzero().shape[0]} out of {acts.shape[0]}")
active_directions = acts.nonzero().squeeze(1)
highly_active_directions = torch.argwhere((acts > 0.5)).squeeze(1)
low_active_directions = torch.argwhere((acts < 0.5) & (acts > 0.1)).squeeze(1)
px.histogram(acts.cpu().numpy(), title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for prompt", histnorm="probability", nbins=40)

Active directions on last token: 152 out of 16384


In [23]:
for active_direction in low_active_directions[:5]:#active_directions[:10]:
    print(f"Direction {active_direction}")
    print_top_examples(prompts, max_activation_per_prompt, active_direction, 2)

Direction 225


Direction 234


Direction 431


Direction 548


Direction 603


In [None]:
max_activation_per_prompt = get_max_activations(prompts, model, encoder, cfg)

In [35]:
layer = cfg.layer
haystack_utils.clean_cache()

def get_token_kurtosis_for_feature(model: HookedTransformer, decoder_feature: torch.Tensor):
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder_feature @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    sims = einops.einsum(resid_dirs, unembed, 'd_model, d_model d_vocab -> d_vocab')

    mean = torch.mean(sims)
    variance = torch.mean((sims - mean) ** 2)
    std = torch.sqrt(variance)
    excess_kurt = torch.mean(((sims - mean) / std) ** 4) - 3
    return excess_kurt

def get_token_kurtosis_for_decoder(model: HookedTransformer, decoder: torch.Tensor):
    '''Return excess kurtosis over all decoder features' cosine sims with the unembed (higher is better)'''
    W_out = model.W_out[layer]
    resid_dirs = torch.nn.functional.normalize(decoder @ W_out, dim=-1)
    unembed = torch.nn.functional.normalize(model.unembed.W_U, dim=0)
    cosine_sims = einops.einsum(resid_dirs, unembed, 'd_hidden d_model, d_model d_vocab -> d_hidden d_vocab')
    
    mean = einops.repeat(cosine_sims.mean(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    std = einops.repeat(cosine_sims.std(dim=-1), f'd_hidden -> d_hidden {cosine_sims.shape[1]}')
    kurt = torch.mean((cosine_sims - mean / std) ** 4, dim=-1) - 3
    return kurt

def top_boosted_tokens(model: HookedTransformer, decoder_feature: torch.Tensor, k=10, plot=False):
    W_out = model.W_out[layer]
    resid_dirs = decoder_feature @ W_out
    tokens = resid_dirs @ model.unembed.W_U

    all_ignore, _ = haystack_utils.get_weird_tokens(model, plot_norms=False)
    values, tokens = haystack_utils.top_k_with_exclude(tokens, k, exclude=all_ignore)
    boosted_labels = model.to_str_tokens(tokens)

    if plot:
        assert k < 300, "Too many tokens to plot"
        fig = haystack_utils.line(x=values.cpu().numpy(), xticks=boosted_labels, title=f"Boosted tokens", width=1200)
        fig.show()

    return boosted_labels

scores = get_token_kurtosis_for_decoder(model, encoder.W_dec)


torch.Size([16384])


In [38]:
px.histogram(pd.DataFrame({"kurtosis": scores.cpu()}))

In [40]:
top_values, top_indices = torch.topk(scores, k=100)
for value, i, in zip(top_values, top_indices):
    print(value)
    print(top_boosted_tokens(model, encoder.W_dec[i]))
    print_top_examples(prompts, max_activation_per_prompt, i, 2)

tensor(-2.8921, device='cuda:0')
[' anthem', 'IE', ' ANY', 'lict', 'umbing', ' SHOW', 'ence', ' NIGHT', 'IFF', ' Japanese']


tensor(-2.8932, device='cuda:0')
['IE', ' ANY', ' NIGHT', ' anthem', 'communication', 'actory', ' SHOW', 'violent', '°', ' "/']


tensor(-2.8944, device='cuda:0')
[',"', '."', '.', '?"', '?".', ',', '..."', ' all', ' one', '".']


tensor(-2.8946, device='cuda:0')
['?".', ' I', 'you', '?!"', ' possessions', '?"', ' emotion', ' you', ' me', ' yourself']


tensor(-2.8947, device='cuda:0')
[' NIGHT', ' anthem', 'ols', ' Sax', 'cham', ' SHOW', 'chief', ' watts', 'ahu', ' ANY']


tensor(-2.8948, device='cuda:0')
[' ANY', 'IE', ' anthem', ' SHOW', 'obos', '987', 'ahu', ' deval', ' condemnation', ' portraying']


tensor(-2.8948, device='cuda:0')
[' ANY', ' SHOW', 'iform', 'IE', ' obj', 'chief', 'mor', 'oxicity', ' telecom', ' ris']


tensor(-2.8949, device='cuda:0')
['ence', 'posing', 'azing', 'lp', 'ants', 'ecided', 'using', 'reath', 'ales', 'expected']


tensor(-2.8949, device='cuda:0')
['IE', ' anthem', '�', ' trib', '°', ' ANY', ' Japanese', ' "/', ' negate', ' Stead']


tensor(-2.8950, device='cuda:0')
['oning', 'ination', 'setting', 'isk', 'osures', 'uting', 'ourage', 'pected', 'ettes', 'opping']


tensor(-2.8952, device='cuda:0')
['hement', 'untarily', ' trib', 'ecause', 'olicy', 'azon', ' Berk', 'obbies', 'SPA', 'apesh']


tensor(-2.8954, device='cuda:0')
[' Dod', ' Mitt', ' Midnight', ' Bark', ' Wil', ' Ros', ' Bug', ' Oct', ' Ch', ' Fro']


tensor(-2.8954, device='cuda:0')
[' apologised', '!",', ' ok', ' colour', ' realise', '--', ' realised', ' leapt', ' -', ' —']


tensor(-2.8955, device='cuda:0')
['ettes', 'allion', 'osures', 'reath', 'ERO', 'ror', 'Parent', 'Available', 'isable', ' Advisor']


tensor(-2.8958, device='cuda:0')
[' ok', '!",', "'d", "!'", '!".', '�', ' pup', '!!"', ' wanna', ' alright']


tensor(-2.8958, device='cuda:0')
['IE', '\n', '�', 'Op', ' NIGHT', ' ANY', ' However', ' Number', ' Stead', ' Cards']


tensor(-2.8960, device='cuda:0')
[' obj', 'apesh', 'MER', 'misc', ' Bulg', 'brand', ' Plymouth', '�', 'casts', ' LF']


tensor(-2.8962, device='cuda:0')
[' nearer', ' compassion', ' motions', ' brave', ' wink', '\n', ' Daisy', ' excited', ' bob', ' Zoe']


tensor(-2.8962, device='cuda:0')
['edience', 'ua', ' MY', ' ment', '�', ' Ped', ' possessed', ' selects', '<', 'Our']


tensor(-2.8968, device='cuda:0')
['\n', 'IE', '�', '====', 'Op', 'lict', ' (', ' NIGHT', ' Number', ' chall']


tensor(-2.8970, device='cuda:0')
['�', 'achable', 'End', '�', ' 30', 'mission', 'ASE', 'isco', '�', '�']


tensor(-2.8971, device='cuda:0')
['Parents', 'Yo', '<', 'Parent', 'His', 'Marsh', 'Allow', ' perpetually', 'pend', 'Sullivan']


tensor(-2.8975, device='cuda:0')
[' Conflict', ' young', ' hungry', ' Fore', ' Moral', ' Bad', ' competitive', ' hunter', ' classroom', ' delicious']


tensor(-2.8977, device='cuda:0')
['€', 'xc', 'ocity', 'uls', 'cient', '16', ' Artist', 'former', 'reed', 'cessive']


tensor(-2.8977, device='cuda:0')
[' Tim', 'emo', ' Remy', ' Bub', ' Benny', ' Nem', ' Lily', ' Wh', ' Bun', ' Pok']


tensor(-2.8977, device='cuda:0')
['ounded', 'osures', 'posing', 'elf', ' tel', 'ptions', 'gravity', 'asant', ' fundra', 'airs']


tensor(-2.8978, device='cuda:0')
['tic', 'using', 'sav', 'Brain', 'enged', 'Mem', 'playing', 'wer', 'actory', ' THERE']


tensor(-2.8978, device='cuda:0')
[' Marie', ' Jos', ' Linda', ' Judy', ' Maria', ' Anne', ' Sus', ' Jenny', ' Jan', ' Bill']


tensor(-2.8978, device='cuda:0')
[' realised', ' vowed', ' realized', ' learnt', ' knew', ' no', ' Nem', ' Bark', ' sailed', ' never']


tensor(-2.8979, device='cuda:0')
['Iron', 'sted', 'renched', 'Steam', 'Angel', 'Cand', 'Dust', 'amber', 'Bul', 'Sand']


tensor(-2.8980, device='cuda:0')
[' and', ' but', ' until', ' too', ' he', ' again', ' she', '!', ' with', ';']


tensor(-2.8983, device='cuda:0')
[' while', ' prematurely', ' so', ' fiercely', ' independently', ' thinly', ' because', ' except', ' and', 'ilit']


tensor(-2.8984, device='cuda:0')
['!",', '?",', '!".', '?".', ' ok', ' learnt', ' Daniel', ' Rebecca', ' Chloe', ' James']


tensor(-2.8984, device='cuda:0')
[' Nem', ' Twe', ' Benny', ' Fin', ' Woo', ' Mitt', ' Tim', ' unexpected', ' Bun', ' Once']


tensor(-2.8990, device='cuda:0')
['Lab', 'Tell', 'Late', ' invasive', '�', '>', '�', 'Re', 'Honest', 'Van']


tensor(-2.8991, device='cuda:0')
[' after', ' towards', ' further', ' inside', ' back', ' as', ' over', ' closer', ' away', ' to']


tensor(-2.8991, device='cuda:0')
['forces', '�', ' possessed', 'Go', 'pired', 'wow', 'ok', ' companies', ' NOT', 'Around']


tensor(-2.8992, device='cuda:0')
[' that', ' then', ' this', ' there', ' now', ' sailing', ' tonight', ' tha', ' playing', ' her']


tensor(-2.8992, device='cuda:0')
[' not', ' his', ' he', ' their', ' she', ' her', ' my', ' something', ' that', ' they']


tensor(-2.8993, device='cuda:0')
[' the', ' helps', ' drowned', ' agreed', ' explained', ' replied', 'o', ' suggests', ' smiled', ' introduces']


tensor(-2.8993, device='cuda:0')
[' Twist', 'Words', 'ossible', ' fou', ' Tim', ' Tweet', ' Bun', ' wi', 'Features', 'osures']


tensor(-2.8994, device='cuda:0')
['Snow', 'Wh', ' Tours', 'Marsh', 'Rep', 'Old', 'Anne', 'Summer', 'Bruce', 'Grand']


tensor(-2.8995, device='cuda:0')
['uld', 'reed', 'nd', 've', 'oked', 'ved', 'mes', ' Pok', ' substitute', 'af']


tensor(-2.8995, device='cuda:0')
[' petertodd', ' loophole', '164', ' representations', 'ا', 'MER', 'misc', ' IU', ' enactment', ' Kurdish']


tensor(-2.8996, device='cuda:0')
[' Oxy', ' sou', ' burgers', ' raft', ' sailing', 'add', ' floats', ' planets', ' dolphins', ' sails']


tensor(-2.8997, device='cuda:0')
[' voyage', ' sym', ' alas', ' plummeted', ' imposing', ' inserted', ' centuries', ' chill', ' jew', ' blazing']


tensor(-2.8997, device='cuda:0')
[' underwater', ' fishing', ' skating', ' underground', ' closer', ' hunting', ' faster', ' biking', ' surfing', ' slower']


tensor(-2.8997, device='cuda:0')
['initely', 'uld', 'cy', 'mit', 'fixed', 'ractive', 'ceed', 'ds', ' 375', 'cor']


tensor(-2.8998, device='cuda:0')
[' three', ' only', ' 3', ' infants', ' celebrating', ' poorly', ' adventurous', ' incredibly', ' lonely', ' inventive']


tensor(-2.8998, device='cuda:0')
[',', ' there', ' a', ' two', ' upon', '--', ';', ' was', ' ,', '…']


tensor(-2.8999, device='cuda:0')
[' commanded', 'Stone', 'adas', ' borrower', 'uda', 'ination', 'ants', 'utions', 'Arm', 'eatures']


tensor(-2.8999, device='cuda:0')


tensor(-2.8999, device='cuda:0')
['inate', 'I', 'reed', 'g', 'Our', 'know', 'Only', 'once', 'Your', 'Behind']


tensor(-2.8999, device='cuda:0')
[',', ' there', ' a', ' upon', ' hundreds', ' two', ' ,', ',,', ' he', ' in']


tensor(-2.8999, device='cuda:0')
['iously', 'ging', 'fing', 'Having', 'ving', 'ized', 'I', ' Seas', 'ised', '>']


tensor(-2.9000, device='cuda:0')
[' NOT', '>', 'him', 'blast', ' ab', '€', ':', ' --', ' Ves', 'per']


tensor(-2.9001, device='cuda:0')
['arrow', ' travelers', 'keys', 'xc', 'running', 'abs', 'abbit', 'coord', 'ebra', 'ree']


tensor(-2.9001, device='cuda:0')
[' her', ' his', ' Randy', ' my', ' Suz', ' Ginny', ' Lia', ' himself', ' Leah', ' Julia']


tensor(-2.9001, device='cuda:0')
[' fleet', ' skeletons', ' sailors', ' dumping', ' loft', ' dumps', ' aboard', ' unh', ' crate', ' padd']


tensor(-2.9001, device='cuda:0')
[' terrified', ' devastated', ' furious', ' curious', ' sincere', ' determined', ' horrified', ' heart', ' relieved', ' thrilled']


tensor(-2.9002, device='cuda:0')
['pering', 'ically', ' �', ' Walk', ' fou', ' ram', 'ico', 'ING', ' Jungle', ' ginger']


tensor(-2.9002, device='cuda:0')
[' pup', ' ok', ' colour', ' wanna', ' colours', ' somebody', ' neighbourhood', ' neighbours', ' Everyday', ' neighbour']


tensor(-2.9002, device='cuda:0')
['nd', 'said', ' role', 'inate', 'iter', '>', 'ocate', 'uted', 'special', 'rust']


tensor(-2.9002, device='cuda:0')
['raction', "''.", 'aks', ',"', '"...', 'Allow', 'lock', ' dive', '�', ",''"]


tensor(-2.9002, device='cuda:0')
['ase', ' possessed', 'port', 'aker', 'uckles', ' perpetually', '>', 'rot', 'ption', ' knows']


tensor(-2.9003, device='cuda:0')
[' agreed', ' discovers', ' practices', ' meets', ' doesn', ' didn', ' explores', ' duties', ' returns', ' learned']


tensor(-2.9003, device='cuda:0')
[' I', ' we', ' you', ' kittens', ' wolves', ' cats', ' dogs', ' diapers', ' anyone', ' anybody']


tensor(-2.9004, device='cuda:0')
[' ever', '..."', ' "', ' mummy', ' Mom', " '", '!",', '".', ' �', '.']


tensor(-2.9005, device='cuda:0')
[' Bees', ' rats', ' coupons', ' dolphins', ' Ar', ' aliens', ' Mir', 'ters', ' Pic', ' wolves']


tensor(-2.9006, device='cuda:0')
['itates', 'atum', 'had', 'Only', 'VID', ' vi', 'ancy', 'Our', 'bringing', 'edience']


tensor(-2.9007, device='cuda:0')
[' mix', ' sail', ' add', ' bury', ' roll', ' look', ' swing', ' dive', ' save', ' unite']


tensor(-2.9008, device='cuda:0')
[' commercial', ' mov', ' descend', 'named', ' bending', 'ihu', 'chain', 'ids', ' strengthen', ' obj']


tensor(-2.9009, device='cuda:0')
['aki', 'ited', 'othes', 'holding', 'acked', 'oted', 'ude', 'atters', 'ause', 'elt']


tensor(-2.9010, device='cuda:0')
[' ok', ' shifting', ' tv', '--', ' voyage', ' supper', ' invent', ' honour', ' neighbour', 'Des']


tensor(-2.9010, device='cuda:0')
['SH', '>', 'Mos', 'bor', 'phone', 'had', 'Squ', 'ains', 'Lady', 'Turkey']


tensor(-2.9011, device='cuda:0')
["'s", ' me', '�', ' us', ' yourselves', ' my', ' yourself', ' her', '`', "'t"]


tensor(-2.9011, device='cuda:0')
[' directly', ' even', ' sadly', ' angrily', ' too', ' more', ' but', ' terribly', ' less', ' really']


tensor(-2.9012, device='cuda:0')
['ried', ' be', 'ant', 'ak', ' bi', 'plates', 'ame', ' come', 'unny', 'forts']


tensor(-2.9013, device='cuda:0')
[' back', ' together', ' further', ' away', ' farther', ' reck', ' freely', ' downstream', ' deeper', ' closer']


tensor(-2.9013, device='cuda:0')
[' steal', ' chew', ' throw', ' kick', ' shoot', ' judge', ' dive', ' eat', ' swim', ' tease']


tensor(-2.9013, device='cuda:0')
['>', 'Get', ' I', 'utions', 'ged', ' ya', 'fing', 'Secret', ' WARN', 'Ger']


tensor(-2.9014, device='cuda:0')
['testing', 'jam', 'planes', 'gravity', 'walking', 'chip', 'mac', 'enemy', 'Mount', 'especially']


tensor(-2.9014, device='cuda:0')
[' girl', ' fish', ' boy', ' bear', ' oct', ' alien', ' spirit', ' ham', ' owl', ' seal']


tensor(-2.9014, device='cuda:0')
[' launch', ' woke', ' awoke', ' sailed', ' unw', ' arrived', ' pedd', ' padd', ' blasted', ' serves']


tensor(-2.9015, device='cuda:0')
['",', ',"', '"', '!",', '�', ' next', ' Unic', '\',"', '".', ' Onion']


tensor(-2.9015, device='cuda:0')
['pit', 'my', 'ophone', 'uses', 'ere', ' abuse', 'cious', 'iously', 'rus', 'main']


tensor(-2.9016, device='cuda:0')
[' girl', ' boy', ' mole', ' boat', ' oct', ' dog', ' goat', ' ot', ' hedge', ' crab']


tensor(-2.9016, device='cuda:0')
[' owl', ' wolf', ' oct', ' shark', ' wife', ' dolphin', ' hunter', ' witch', ' mechanic', ' pilot']


tensor(-2.9017, device='cuda:0')
[' marry', ' soar', ' mix', ' polish', ' splash', ' paint', ' feed', ' sail', ' bounce', ' order']


tensor(-2.9017, device='cuda:0')
['cious', 'ase', 'nder', 'oning', '�', 'ption', 'esty', 'ologist', 'iter', 'redibly']


tensor(-2.9017, device='cuda:0')
['cue', 'OP', 'mediately', 'Get', 'uld', 'ilit', 'arty', ' supervision', ' Story', 'razy']


tensor(-2.9018, device='cuda:0')
[' a', ' WANT', ' Antarctica', 'obiles', 'ophone', ' tur', 'printed', 'icult', ' lemon', ' Australia']


tensor(-2.9018, device='cuda:0')
[' called', '.', ' who', ' named', '--', ';', ' whose', ' —', 'ji', ',']


tensor(-2.9019, device='cuda:0')
['uld', '�', ' �', 'tell', 'ven', 'iverse', 'udding', 'Oct', ' thou', ' annual']


tensor(-2.9019, device='cuda:0')
['per', '>', ' NOT', 'uses', 'xc', 'at', 'cycles', ' bec', ' PLAY', 'astic']


tensor(-2.9019, device='cuda:0')
[' hurt', 'rew', 'gets', ' fou', ' tha', 'ends', 'ave', 'ination', ' li', ' forg']


tensor(-2.9019, device='cuda:0')
[' pick', ' eat', ' play', ' paint', ' look', ' visit', ' wander', ' create', ' explore', ' skip']


tensor(-2.9019, device='cuda:0')
['â', 'without', 'Collect', ' advancing', ' malf', 'gear', ' varieties', 'Land', 'Air', 'perhaps']


tensor(-2.9020, device='cuda:0')
[' there', ',', ' two', ' a', ' hundreds', ' Dave', ' Joe', ' Frank', ' Dale', ' Jim']


tensor(-2.9020, device='cuda:0')
['ail', '>', 'former', 'riend', 'bringing', 'hon', ' NOT', 'cious', 'ter', '�']


In [None]:
for active_direction in highly_active_directions:#active_directions[:10]:
    print(f"Direction {active_direction}")
    print_top_examples(prompts, max_activation_per_prompt, active_direction, 4)
    print(top_boosted_tokens(model, encoder.W_dec[active_direction]))

NameError: name 'highly_active_directions' is not defined

In [None]:
token = model.to_single_token(" park")
token_prompts = []
for prompt in prompts[:1000]:
    tokens = model.to_tokens(prompt)
    if token in tokens:
        token_prompts.append(prompt)
print(len(token_prompts))


163


In [None]:
for direction in active_directions[:10]:
    plot_direction_frequency(prompts[:50], direction, cfg)

  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([8493])


In [None]:
for prompt in token_prompts[:2]:
    for direction in active_directions[:10]:
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

In [None]:
# direction = 9000
# fig = px.histogram(max_activation_per_prompt[:, direction].tolist(), title=f"{print_model_name} L1={cfg.l1_coeff}: Activations for direction {direction}", histnorm="probability")
# fig.update_layout(
#     xaxis_title="Activation",
#     yaxis_title="Probability",
#     width = 800,
#     showlegend=False
# )
# fig.update_yaxes(type='log')
# fig.show()
# print_top_examples(prompts, max_activation_per_prompt, direction)

In [None]:
# Look for active features on specific tokens in prompt
# Baseline: look at neurons
# Train with bigger  L1s
# At some point, it should become non monosemantic as it can just copy MLP
# Train without L1, see what happens