## Setup

In [18]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider
from transformer_lens.utils import test_prompt

import plotly.graph_objects as go

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import AutoEncoderConfig, train_autoencoder_evaluate_autoencoder_reconstruction, evaluate_autoencoder_reconstruction, eval_direction_tokens_global, get_activations, get_acts, load_encoder, eval_ablation_token_rank, get_direction_ablation_hook, get_top_activating_examples_for_direction, evaluate_direction_ablation_single_prompt
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line, multiple_line, plot_square_heatmap
%reload_ext autoreload
%autoreload 2

In [2]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)
prompts = load_tinystories_validation_prompts()

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


(INFO) 10:01:29: Loaded 21990 TinyStories validation prompts


In [3]:
run_name = "98_hardy_firefly"#"8_deep_brook"#
encoder, cfg = load_encoder(run_name, model_name, model)
cfg.run_name = run_name
print(cfg.run_name, cfg.layer, cfg.l1_coeff)

98_hardy_firefly 1 [0.0001, 0.00015]


In [4]:
run_name = cfg.run_name
max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name, prompts, model, save_activations=True)

## Loss recovered

In [15]:
# Baseline check: mean ablate MLP1
hook_name = f"blocks.{cfg.layer}.{cfg.act_name}"

def zero_ablation_hook(value, hook):
    value[:, :] = 0
    return value

def encode_activations_hook(value, hook):
    _, x_reconstruct, _, _, _ = encoder(value.squeeze(0))
    return x_reconstruct.unsqueeze(0)

zero_ablate_hook = [(hook_name, zero_ablation_hook)]
encode_mlp_hook = [(hook_name, encode_activations_hook)]

In [16]:
# Zero ablate MLP
# ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
# Mean ablate MLP
# Enable only top 2 directions
# Enable full autoencoder

losses = []
zero_abl_losses = []
recons_losses = []

for prompt in tqdm(prompts[:1000]):
    loss = model(prompt, return_type="loss", loss_per_token=False).item()
    with model.hooks(fwd_hooks=zero_ablate_hook):
        zero_abl_loss = model(prompt, return_type="loss", loss_per_token=False).item()
    with model.hooks(fwd_hooks=encode_mlp_hook):
        recons_loss = model(prompt, return_type="loss", loss_per_token=False).item()
    losses.append(loss)
    zero_abl_losses.append(zero_abl_loss)
    recons_losses.append(recons_loss)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [17]:
import plotly.graph_objects as go
import numpy as np

def plot_loss_comparison(loss_groups, group_names):
    """
    Plots a bar chart comparing different loss groups.

    Parameters:
    - loss_groups: List of lists, each containing loss values for a group.
    - group_names: List of names for each loss group.
    """

    # Function to calculate standard error
    def standard_error(data):
        return np.std(data) / np.sqrt(len(data))

    # Calculate means and standard errors for each group
    means = [np.mean(group) for group in loss_groups]
    std_errors = [standard_error(group) for group in loss_groups]
    ci_95 = [se * 1.96 for se in std_errors]

    # Create the bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=group_names,
            y=means,
            error_y=dict(type='data', array=ci_95)
        )
    ])

    # Update layout
    fig.update_layout(
        title="Ablation loss comparison for closing quotation prompts '.\"'",
        xaxis_title="Ablation",
        yaxis_title="Loss",
        showlegend=False,
        width=600
    )

    fig.show()


loss_groups = [losses, zero_abl_losses, recons_losses]
group_names = ['Original loss', 'MLP zero ablation loss', 'SAE reconstruction loss']

plot_loss_comparison(loss_groups, group_names)


In [10]:
loss_recovered = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
print(loss_recovered)

0.602249678414853


## Print examples

In [24]:
def print_top_examples(prompts: list[str], activations: Float[Tensor, "n_prompts d_enc"], direction: int, encoder: AutoEncoder, cfg: AutoEncoderConfig, n=5):
    top_idxs = activations[:, direction].argsort(descending=True)[:n].cpu().tolist()
    for prompt_index in top_idxs:
        prompt = prompts[prompt_index]
        prompt_tokens = model.to_str_tokens(model.to_tokens(prompt))
        acts = get_acts(prompt, model, encoder, cfg)
        direction_act = acts[:, direction].cpu().tolist()
        max_direction_act = max(direction_act)
        if max_direction_act > 0:
            haystack_utils.clean_print_strings_as_html(prompt_tokens, direction_act, max_value=max_direction_act)

def print_direction_example(direction, n=10):
    print_top_examples(prompts, max_activations, direction, encoder, cfg, n)

In [25]:
_ = interact(print_direction_example, 
         direction=IntSlider(min=0, max=encoder.d_hidden-1, step=1, value=0),
         n=IntSlider(min=1, max=20, step=1, value=5))

interactive(children=(IntSlider(value=0, description='direction', max=16383), IntSlider(value=5, description='…

## Loss recovered

In [11]:
# Max per direction
max_val, _ = max_activations.max(0)
threshold_per_direction = (max_val * 0.17).cuda()

In [14]:
threshold_per_direction.shape

torch.Size([16384])

In [18]:
original_loss, reconstruct_loss, zero_ablation_loss = evaluate_autoencoder_reconstruction(encoder, cfg.encoder_hook_point, prompts[:500], model)

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:14<00:00, 34.01it/s]
(INFO) 10:18:21: Average loss increase after encoding: 0.2420


In [19]:
def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        x_cent = value - encoder.b_dec
        acts = F.relu(x_cent @ encoder.W_enc + encoder.b_enc)
        acts[acts < threshold_per_direction.unsqueeze(0)] = 0
        x_reconstruct = acts @ encoder.W_dec + encoder.b_dec
        return x_reconstruct.unsqueeze(0)
reconstruct_hooks = [(cfg.encoder_hook_point, encode_activations_hook)]

reconstruct_without_low_density_loss = []
for prompt in tqdm(prompts[:500]):
    with model.hooks(reconstruct_hooks):
        loss = model(prompt, return_type="loss")
        reconstruct_without_low_density_loss.append(loss.item())
print(np.mean(reconstruct_without_low_density_loss))
print(original_loss, reconstruct_loss, zero_ablation_loss)

  0%|          | 0/500 [00:00<?, ?it/s]

1.4775838609933853
1.1528477879762649 1.394811465859413 2.2178650829792024


In [23]:
acts = get_acts(prompts[0], model, encoder, cfg)
px.histogram(acts[acts>0].flatten().cpu().numpy())

In [20]:
np.mean(reconstruct_without_low_density_loss) - original_loss

0.3247360730171205

In [21]:
loss_recovered_reconstruct = (reconstruct_loss - original_loss) / (zero_ablation_loss - original_loss)
loss_recovered_reconstruct_without_low_density = (np.mean(reconstruct_without_low_density_loss) - original_loss) / (zero_ablation_loss - original_loss)
print(loss_recovered_reconstruct, loss_recovered_reconstruct_without_low_density)

0.22719225219951084 0.30491154889294525


## Find low density directions

In [52]:
def get_top_prompt_indices(max_activations, direction, k=10):
    top_idxs = max_activations[:, direction].argsort(descending=True).cpu().tolist()[:k]
    # Filter by activation > 0 
    top_idxs = [idx for idx in top_idxs if max_activations[idx, direction] > 0]
    return top_idxs

direction_top_indices = []
for direction in range(max_activations.shape[1]):
    top_idxs = get_top_prompt_indices(max_activations, direction, k=10)
    direction_top_indices.append(top_idxs)

In [53]:
from collections import Counter
top_indices_counter = Counter([idx for top_idxs in direction_top_indices for idx in top_idxs])
top_indices_counter.most_common(10)

[(14764, 1531),
 (3391, 1167),
 (699, 699),
 (12909, 648),
 (4036, 592),
 (8337, 484),
 (20202, 406),
 (13179, 320),
 (21221, 318),
 (18438, 294)]

In [54]:
top_5_indices = [idx for idx, _ in top_indices_counter.most_common(5)]
print(top_5_indices)

[14764, 3391, 699, 12909, 4036]


In [55]:
clustered_direction = []
for direction, top_indices in enumerate(direction_top_indices):
    cluster_direction = False
    for top_index in top_indices:
        if top_index in top_5_indices:
            cluster_direction = True
    if cluster_direction:
        clustered_direction.append(direction)
print(len(clustered_direction))

4013


In [56]:
directions = torch.LongTensor(clustered_direction).cuda().unique()
print(directions.shape)

torch.Size([4013])


In [57]:
# How tied are encoder and decoder
cosine_sim = torch.nn.CosineSimilarity(dim=0)
sims_global = cosine_sim(encoder.W_enc, encoder.W_dec.T)
sims_Low_density = cosine_sim(encoder.W_enc[:, directions], encoder.W_dec[directions].T)
print(sims_global.mean(0), sims_Low_density.mean(0))

tensor(0.0780, device='cuda:0') tensor(-0.0066, device='cuda:0')


In [58]:
# Sims encoder
normalized_W_enc = F.normalize(encoder.W_enc, dim=0)
cosine_sims = (normalized_W_enc.T @ normalized_W_enc)
mask = torch.tril(torch.ones_like(cosine_sims), diagonal=-1).flatten().bool()
cosine_sims = cosine_sims.flatten()[mask].mean()

cosine_sims_low_density = (normalized_W_enc[:, directions].T @ normalized_W_enc[:, directions])
mask = torch.tril(torch.ones_like(cosine_sims_low_density), diagonal=-1).flatten().bool()
cosine_sims_low_density = cosine_sims_low_density.flatten()[mask].mean()

print(cosine_sims, cosine_sims_low_density)

tensor(0.7005, device='cuda:0') tensor(0.9315, device='cuda:0')


In [59]:
# Sims decoder
#normalized_W_dec = F.normalize(encoder.W_dec, dim=1)
normalized_W_dec = encoder.W_dec
cosine_sims = (normalized_W_dec @ normalized_W_dec.T)
mask = torch.tril(torch.ones_like(cosine_sims), diagonal=-1).flatten().bool()
cosine_sims = cosine_sims.flatten()[mask]

cosine_sims_low_density = (normalized_W_dec[directions, :] @ normalized_W_dec[directions, :].T)
mask = torch.tril(torch.ones_like(cosine_sims_low_density), diagonal=-1).flatten().bool()
cosine_sims_low_density = cosine_sims_low_density.flatten()[mask]

print(cosine_sims.mean(), cosine_sims_low_density.mean())
print(cosine_sims.shape, cosine_sims_low_density.shape)

tensor(0.0005, device='cuda:0') tensor(0.0002, device='cuda:0')
torch.Size([134209536]) torch.Size([8050078])


In [60]:
fig = px.histogram(cosine_sims[:50000].cpu().numpy(), nbins=100, title="Decoder weights cosine similarity (sample)")
fig.update_layout(showlegend=False, width=900)

In [62]:
fig = px.histogram(cosine_sims_low_density[:50000].cpu().numpy(), nbins=100, title="Decoder weights cosine similarity (low density)")
fig.update_layout(showlegend=False, width=900)

## Mean ablate cluster

In [None]:
# Mean activation over entire cluster
acts = []
for prompt in prompts[:200]:
    act = get_acts(prompt, model, encoder, cfg)
    acts.append(act[:, directions].cpu())
acts = torch.cat(acts)
print(acts.shape)
mean_act = acts.mean()
print(mean_act.shape)

def get_cluster_ablation_hook(encoder, directions, hook_pos=None):
    def cluster_ablation_hook(value, hook):
        x_cent = value[0, :] - encoder.b_dec
        acts = F.relu(x_cent @ encoder.W_enc[:, directions] + encoder.b_enc[directions])
        mean_acts = torch.full_like(acts, mean_act)
        directions_impact_on_reconstruction = einops.einsum(mean_acts, encoder.W_dec[directions, :], "pos directions, directions d_mlp -> pos d_mlp") # + encoder.b_dec ???
        if hook_pos is not None:
            value[:, hook_pos, :] -= directions_impact_on_reconstruction[hook_pos]
        else:
            value[:, :] -= directions_impact_on_reconstruction
        return value
    return cluster_ablation_hook


def evaluate_cluster_ablation_single_prompt(prompt: str, encoder: AutoEncoder, model: HookedTransformer, direction: int | list[int], cfg: AutoEncoderConfig, pos: None | int = None) -> float:
    """ Pos needs to be the absolute position of the token to ablate, negative indexing does not work """
    encoder_hook_point = f"blocks.{cfg.layer}.{cfg.act_name}"
    
    original_loss = model(prompt, return_type="loss").item()
    with model.hooks(fwd_hooks=[(encoder_hook_point, get_cluster_ablation_hook(encoder, direction, pos))]):
        ablated_loss = model(prompt, return_type="loss").item()
    return original_loss, ablated_loss



losses = []
for prompt in prompts:
    original, ablated = evaluate_cluster_ablation_single_prompt(prompt, encoder, model, directions)
    losses.append(ablated - original)
print(np.mean(losses))

## Residual stream low density direction

In [None]:
# Calculate mean residual direction of low density features
# Take random residual directions of normal features
# Calculate norm of residual directions over a bunch of examples
# Compare if low density direction is outlier