In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal
from transformer_lens.utils import test_prompt
import pickle
from ipywidgets import interact, IntSlider, SelectionSlider
from transformer_lens.utils import test_prompt
import os
import plotly.graph_objects as go

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import AutoEncoderConfig, custom_forward, eval_direction_tokens_global, get_encode_activations_hook, get_activations, get_acts, load_encoder, eval_ablation_token_rank, get_direction_ablation_hook, get_top_activating_examples_for_direction, evaluate_direction_ablation_single_prompt
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line, multiple_line
%reload_ext autoreload
%autoreload 2

In [2]:
model_name = "tiny-stories-2L-33M"
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
model.set_use_attn_result(True)

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [3]:
run_name = "185_upbeat_field"#"143_graceful_darkness"# 
encoder, cfg = load_encoder(run_name, model_name, model)
cfg.run_name = run_name
print(cfg.run_name, cfg.layer, cfg.l1_coeff)

185_upbeat_field 1 0.0004


In [4]:
prompts = load_tinystories_validation_prompts()
run_name = cfg.run_name
max_activations, max_activation_token_indices = get_activations(encoder, cfg, run_name, prompts, model, save_activations=True)

(INFO) 08:26:12: Loaded 21990 TinyStories validation prompts
100%|██████████| 21990/21990 [03:53<00:00, 94.17it/s] 


Active directions on validation data: 16384 out of 16384


## Plot results from different runs

In [8]:
# Step 1: Read each CSV file and store in a list
dataframes = []
for filename in os.listdir("./data/low_density_analysis"):
    if filename.endswith(".csv"):
        filepath = os.path.join("./data/low_density_analysis", filename)
        df = pd.read_csv(filepath)
        dataframes.append(df)

# Step 2: Concatenate all DataFrames into one
combined_df = pd.concat(dataframes)

# Categorize columns
l2_loss_columns = [col for col in combined_df.columns if 'l2_loss' in col]
remaining_loss_columns = [col for col in combined_df.columns if 'loss' in col and col not in l2_loss_columns]
other_columns = [col for col in combined_df.columns if 'loss' not in col and col != 'run']

# Melt the DataFrames for plotting
l2_loss_df_melted = combined_df.melt(id_vars='run', value_vars=l2_loss_columns, var_name='variable', value_name='value')
remaining_loss_df_melted = combined_df.melt(id_vars='run', value_vars=remaining_loss_columns, var_name='variable', value_name='value')
other_df_melted = combined_df.melt(id_vars='run', value_vars=other_columns, var_name='variable', value_name='value')

# Plot 1: 'l2_loss' columns
fig_l2_loss = px.bar(l2_loss_df_melted, 
                x='variable', 
                y='value', 
                color='run', 
                barmode='group',
                title='Ablation of low density cluster autoencoder L2 loss')

fig_l2_loss.update_layout(
    xaxis_title="L2 Loss Variable",
    yaxis_title="Value",
    legend_title="Run"
)

fig_l2_loss.show()

# Plot 2: Remaining 'loss' columns
fig_remaining_loss = px.bar(remaining_loss_df_melted, 
                   x='variable', 
                   y='value', 
                   color='run', 
                   barmode='group',
                   title='Reconstruction loss ablation of low density cluster')

fig_remaining_loss.update_layout(
    xaxis_title="Remaining Loss Variable",
    yaxis_title="Value",
    legend_title="Run"
)

fig_remaining_loss.show()

# Plot 3: Other columns
fig_other = px.bar(other_df_melted, 
                   x='variable', 
                   y='value', 
                   color='run', 
                   barmode='group',
                   title='Low density clusters')

fig_other.update_layout(
    xaxis_title="Other Variable",
    yaxis_title="Value",
    legend_title="Run"
)

fig_other.show()


## Identifying low density clusters

### Cosine sim thresholds

In [5]:
normalized_W_enc = F.normalize(encoder.W_enc, dim=0)
cosine_sims = (normalized_W_enc.T @ normalized_W_enc)
mask = torch.tril(torch.ones_like(cosine_sims), diagonal=-1).flatten().bool()
unique_cosine_sims = cosine_sims.flatten()[mask]

In [6]:
unique_cosine_sims_np = unique_cosine_sims.cpu().numpy()

# Compute histogram
bin_counts, bin_edges = np.histogram(unique_cosine_sims_np, bins=50) # You can adjust the number of bins

# Create bin centers from bin edges for plotting
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

total_elements = unique_cosine_sims_np.size
bin_percentages = (bin_counts / total_elements) * 100

# Plot histogram
fig = go.Figure(data=go.Bar(x=bin_centers, y=bin_percentages))
fig.update_layout(title=f"{run_name}: Histogram of pairwise encoder cosine sims",
                  xaxis_title="Cosine Similarity",
                  yaxis_title="Percentage (%)",
                  width=900)
fig.show()

In [7]:
## Plot histogram of 90th percentile cosine sim per direction
percentile_cosine_sims = []
for direction in range(encoder.d_hidden):
    direction_cosine_sims = cosine_sims[direction]
    percentile_cosine_sim = np.percentile(direction_cosine_sims.cpu().numpy(), 90)
    percentile_cosine_sims.append(percentile_cosine_sim)

print(len(percentile_cosine_sims))

16384


In [8]:
px.histogram(percentile_cosine_sims, nbins=50, histnorm='percent', title=f"{run_name}: Histogram of 90th percentile cosine sim per direction")

In [9]:
# Number of > 0.8 cosine sims per direction
percent_high_cosine_sims = []
for direction in range(encoder.d_hidden):
    direction_cosine_sims = cosine_sims[direction]
    percent_above_threshold = (direction_cosine_sims > 0.80).sum() / encoder.d_hidden
    percent_high_cosine_sims.append(percent_above_threshold.item())

fig = px.histogram(percent_high_cosine_sims, nbins=50, histnorm='percent', title=f"{run_name}: Histogram of number of > 0.8 cosine sims per direction")
fig.update_layout(xaxis_title="Percentage of directions", yaxis_title="Percent")

In [10]:
# Thresholds
min_cosine_sim = 0.8
min_high_cosine_sims = 0.8

high_cosine_sim_cluster = []

for direction in range(encoder.d_hidden):
    direction_cosine_sims = cosine_sims[direction]
    percent_above_threshold = (direction_cosine_sims > min_cosine_sim).sum() / encoder.d_hidden
    if percent_above_threshold > min_high_cosine_sims:
        high_cosine_sim_cluster.append(direction)

print(f"Number of high cosine sim directions: {len(high_cosine_sim_cluster)} ({(len(high_cosine_sim_cluster) / encoder.d_hidden):.2f}%)")

Number of high cosine sim directions: 0 (0.00%)


### Identifying through OOD validation prompt activations

In [11]:
num_top_prompts_per_direction=10
num_most_common_prompts_considered=5

In [12]:
def get_top_prompt_indices(max_activations, direction, k):
    top_idxs = max_activations[:, direction].argsort(descending=True).cpu().tolist()[:k]
    # Filter by activation > 0 
    top_idxs = [idx for idx in top_idxs if max_activations[idx, direction] > 0]
    return top_idxs

direction_top_indices = []
for direction in range(encoder.d_hidden):
    top_idxs = get_top_prompt_indices(max_activations, direction, k=num_top_prompts_per_direction)
    direction_top_indices.append(top_idxs)

top_indices_counter = Counter([idx for top_idxs in direction_top_indices for idx in top_idxs])
top_5_indices = [idx for idx, _ in top_indices_counter.most_common(num_most_common_prompts_considered)]

print(f"Top prompt occurrences: {top_indices_counter.most_common(num_most_common_prompts_considered)}")

clustered_direction = []
for direction, top_indices in enumerate(direction_top_indices):
    cluster_direction = False
    for top_index in top_indices:
        if top_index in top_5_indices:
            cluster_direction = True
    if cluster_direction:
        clustered_direction.append(direction)

print(f"Number of clustered directions: {len(clustered_direction)} ({(len(clustered_direction) / encoder.d_hidden):.2f})")

Top prompt occurrences: [(20232, 64), (3995, 64), (2171, 63), (18927, 62), (13668, 62)]
Number of clustered directions: 281 (0.02)


### Feature density


In [13]:
# Fraction of tokens for which each features has a nonzero value

feature_activations = torch.zeros(encoder.d_hidden).to(torch.long)
total_num_tokens = 0
for prompt in tqdm(prompts):
    acts = get_acts(prompt, model, encoder, cfg).cpu()
    num_tokens = acts.shape[0]
    acts = (acts>0).sum(0)
    total_num_tokens += num_tokens
    feature_activations += acts

feature_density = feature_activations / total_num_tokens
print(np.mean(feature_density.cpu().numpy()))

  0%|          | 0/21990 [00:00<?, ?it/s]

0.0023473988


In [20]:
# Calculate histogram with numpy

hist, bin_edges = np.histogram(feature_density, bins=np.logspace(np.log10(1e-7), np.log10(1), 50), density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
hist_percent = (hist / hist.sum()) * 100

fig = go.Figure()
for i in range(len(hist_percent)):
    fig.add_trace(go.Bar(
        x=[bin_centers[i]], 
        y=[hist_percent[i]],
        width=[bin_edges[i+1] - bin_edges[i]], # explicit width
        showlegend=False,
        marker_color='blue'
    ))

fig.update_layout(
    xaxis_title="Feature density",
    yaxis_title="Percent",
    title=f"{run_name}: Histogram of feature density",
    xaxis=dict(
        type="log",
        tickvals=[1e-7, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1],
        ticktext=['1e-7', '1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1', '1'],
        range=[np.log10(1e-7), np.log10(1)]
    ),
    yaxis=dict(
        range=[np.log10(0.001), np.log10(100)], # Adjusted for percent
        type="log"
    ),
    barmode='overlay'
)
fig.show()

In [21]:
# Cutoff at 1e-5

low_density_features = []
for direction in range(encoder.d_hidden):
    if feature_density[direction] < 1e-5:
        low_density_features.append(direction)

print(f"Number of low density features: {len(low_density_features)} ({(len(low_density_features) / encoder.d_hidden):.2f})")

Number of low density features: 7 (0.00)


In [22]:
print(feature_density[:10])

tensor([0.0015, 0.0013, 0.0026, 0.0014, 0.0011, 0.0020, 0.0003, 0.0008, 0.0025,
        0.0019])


In [23]:
percent_ood_directions = len(clustered_direction) / encoder.d_hidden
percent_high_cosine_sims = len(high_cosine_sim_cluster) / encoder.d_hidden
percent_low_density_features = len(low_density_features) / encoder.d_hidden

print(percent_high_cosine_sims, percent_ood_directions, percent_low_density_features)

0.0 0.01715087890625 0.00042724609375


In [24]:
all_ood_directions = list(set(clustered_direction + high_cosine_sim_cluster + low_density_features))
percent_good_directions = 1 - (len(all_ood_directions) / encoder.d_hidden)
print(f"Percent good directions: {percent_good_directions}")

Percent good directions: 0.982421875


In [25]:
# Check overlap of sets
print(len(set(clustered_direction).intersection(set(high_cosine_sim_cluster))) / min(len(clustered_direction), len(high_cosine_sim_cluster)))
print(len(set(clustered_direction).intersection(set(low_density_features))) / min(len(clustered_direction), len(low_density_features)))
print(len(set(high_cosine_sim_cluster).intersection(set(low_density_features))) / min(len(high_cosine_sim_cluster), len(low_density_features)))

ZeroDivisionError: division by zero

In [26]:
# Barplot of percents
fig = go.Figure()
fig.add_trace(go.Bar(
    x=['High cosine sim directions', 'Low density directions', 'High OOD activation directions', 'Good directions'],
    y=[percent_high_cosine_sims, percent_low_density_features, percent_ood_directions, percent_good_directions],
    marker_color='blue'
))
fig.update_layout(
    yaxis_title="Percent",
    title=f"{run_name}: Percent of low density directions",
    width=600
)
fig.show()

## Activation frequency


In [27]:
low_density_directions = torch.LongTensor(sorted(list(set(clustered_direction + high_cosine_sim_cluster + low_density_features))))
low_density_directions = low_density_directions.to(device)
print(low_density_directions.shape, low_density_directions[:10])

torch.Size([288]) tensor([ 29,  49,  96, 123, 209, 212, 273, 310, 311, 325], device='cuda:0')


In [28]:
low_density_acts = torch.zeros_like(low_density_directions).cuda()
all_acts = torch.zeros(encoder.d_hidden).cuda()
total_tokens = 0
for prompt in tqdm(prompts[:2000]):
    acts = get_acts(prompt, model, encoder, cfg)
    num_tokens = acts.shape[0]
    acts = (acts>0).sum(0)
    total_tokens += num_tokens
    all_acts += acts
    low_density_acts += acts[low_density_directions]

print(low_density_acts.sum() / total_tokens)
print(all_acts.sum() / total_tokens)

  0%|          | 0/2000 [00:00<?, ?it/s]

tensor(0.2797, device='cuda:0')
tensor(37.8890, device='cuda:0')


In [245]:
print(low_density_acts.sum())
print(all_acts.sum())

tensor(2038189, device='cuda:0')
tensor(14733732., device='cuda:0')


In [246]:
token_active_dirs = []
for prompt in tqdm(prompts[:2000]):
    acts = get_acts(prompt, model, encoder, cfg)[1:]
    num_tokens = acts.shape[0]
    acts = (acts>0).sum(1).tolist()
    token_active_dirs.extend(acts)

print(len(token_active_dirs))
print(np.mean(token_active_dirs))

  0%|          | 0/2000 [00:00<?, ?it/s]

387121
38.04942640673071


In [247]:
px.histogram(token_active_dirs, nbins=50, title=f"{run_name}: Histogram of number of active directions per token")

## Low density ablation analysis

In [248]:
# Mean activation over 2000 prompts
low_density_acts = torch.zeros(len(low_density_directions)).to(device)
total_tokens = 0
for prompt in tqdm(prompts[:2000]):
    acts = get_acts(prompt, model, encoder, cfg)
    low_density_act = acts[:, low_density_directions].sum(0)
    low_density_acts += low_density_act
    total_tokens += acts.shape[0]
low_density_acts /= total_tokens
print(low_density_acts.shape)

  0%|          | 0/2000 [00:00<?, ?it/s]

torch.Size([9508])


In [249]:
px.histogram((low_density_acts/total_tokens).cpu().numpy(), nbins=50, title=f"{run_name}: Histogram of mean activation of low density directions", width=800)

In [250]:
def custom_forward(
    enc: AutoEncoder, x: Float[Tensor, "batch d_in"], neuron: int | Tensor, activation: float | Tensor
):
    x_cent = x - enc.b_dec
    acts = F.relu(x_cent @ enc.W_enc + enc.b_enc)
    acts[:, neuron] = activation
    x_reconstruct = acts @ enc.W_dec + enc.b_dec
    l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
    return x_reconstruct, l2_loss

In [251]:
# Zero ablate low density

original_l2_losses = []
zero_ablated_l2_losses = []
mean_ablated_l2_losses = []

original_reconstruct_losses = []
zero_reconstruct_losses = []
mean_reconstruct_losses = []
original_model_losses = []

for prompt in tqdm(prompts[:1000]):
    _, cache = model.run_with_cache(prompt)
    mlp_activation = cache[cfg.encoder_hook_point][0]
    _, x_reconstruct_original, _, l2_loss, _ = encoder(mlp_activation)
    original_l2_losses.append(l2_loss.item())
    x_reconstruct_zero, l2_loss = custom_forward(encoder, mlp_activation, low_density_directions, activation=0)
    zero_ablated_l2_losses.append(l2_loss.item())
    x_reconstruct_mean, l2_loss = custom_forward(encoder, mlp_activation, low_density_directions, activation=low_density_acts/total_tokens)
    mean_ablated_l2_losses.append(l2_loss.item())

    model_loss = model(prompt, return_type="loss").item()

    def reconstruct_hook(value, hook):
        value = x_reconstruct_original.unsqueeze(0)
        return value

    def zero_ablate_hook(value, hook):
        value = x_reconstruct_zero.unsqueeze(0)
        return value
    
    def mean_ablate_hook(value, hook):
        value = x_reconstruct_mean.unsqueeze(0)
        return value
        
    with model.hooks([(cfg.encoder_hook_point, zero_ablate_hook)]):
        zero_reconstruct_loss = model(prompt, return_type="loss").item()
    
    with model.hooks([(cfg.encoder_hook_point, mean_ablate_hook)]):
        mean_reconstruct_loss = model(prompt, return_type="loss").item()
    
    with model.hooks([(cfg.encoder_hook_point, reconstruct_hook)]):
        original_reconstruct_loss = model(prompt, return_type="loss").item()

    original_reconstruct_losses.append(original_reconstruct_loss)
    zero_reconstruct_losses.append(zero_reconstruct_loss)
    mean_reconstruct_losses.append(mean_reconstruct_loss)
    original_model_losses.append(model_loss)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [252]:
fig = go.Figure()
fig.add_trace(go.Bar(
    x=["Original L2 loss", "Zero ablated L2 loss", "Mean ablated L2 loss"],
    y=[np.mean(original_l2_losses), np.mean(zero_ablated_l2_losses), np.mean(mean_ablated_l2_losses)],
    marker_color='blue'
))
fig.update_layout(
    yaxis_title="Percent",
    title=f"{run_name}: Autoencoder L2 loss",
    width=600
)
fig.show()

In [253]:
fig = go.Figure()
fig.add_trace(go.Bar(
    y=[np.mean(original_model_losses), np.mean(original_reconstruct_losses), np.mean(zero_reconstruct_losses), np.mean(mean_reconstruct_losses)],
    x=["Original model loss", "Original reconstruction loss", "Zero ablated reconstruction loss", "Mean ablated reconstruction loss"],
    marker_color='blue'
))
fig.update_layout(
    yaxis_title="Percent",
    title=f"{run_name}: Autoencoder reconstruction loss from ablating low density directions",
    width=600
)
fig.show()

## Summary DF

In [254]:
res = {
    "run": run_name,
    "original_l2_loss": np.mean(original_l2_losses),
    "zero_ablated_l2_loss": np.mean(zero_ablated_l2_losses),
    "mean_ablated_l2_loss": np.mean(mean_ablated_l2_losses),
    "original_model_loss": np.mean(original_model_losses),
    "original_reconstruct_loss": np.mean(original_reconstruct_losses),
    "zero_reconstruct_loss": np.mean(zero_reconstruct_losses),
    "mean_reconstruct_loss": np.mean(mean_reconstruct_losses),
    "percent_good_directions": percent_good_directions,
    "percent_high_cosine_sims": percent_high_cosine_sims,
    "percent_ood_directions": percent_ood_directions,
    "percent_low_density_features": percent_low_density_features,
}

In [255]:
df = pd.DataFrame(res, index=[0])
df.to_csv(f"./data/low_density_analysis/{run_name}_low_density_eval.csv", index=False)

## Dumb comparison - MLP density

In [None]:
mlp_activations = torch.zeros(model.cfg.d_mlp).to(torch.long)
total_num_tokens = 0

for prompt in tqdm(prompts):
    _, cache = model.run_with_cache(prompt)
    mlp_activations = cache[cfg.encoder_hook_point][0]
    active_neurons = (mlp_activations > 0).sum(0)
    total_tokens = mlp_activations.shape[0]
    total_num_tokens += total_tokens
    mlp_activations += active_neurons

print(mlp_activations.mean())

  0%|          | 0/21990 [00:00<?, ?it/s]

tensor(45.6379, device='cuda:0')


In [None]:
# Calculate histogram with numpy

mlp_density = mlp_activations / total_num_tokens
hist, bin_edges = np.histogram(mlp_density.cpu().numpy(), bins=np.logspace(np.log10(1e-8), np.log10(1), 50), density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
hist_percent = (hist / hist.sum()) * 100

fig = go.Figure()
for i in range(len(hist_percent)):
    fig.add_trace(go.Bar(
        x=[bin_centers[i]], 
        y=[hist_percent[i]],
        width=[bin_edges[i+1] - bin_edges[i]], # explicit width
        showlegend=False,
        marker_color='blue'
    ))

fig.update_layout(
    xaxis_title="Feature density",
    yaxis_title="Percent",
    title=f"{run_name}: Histogram of feature density",
    xaxis=dict(
        type="log",
        tickvals=[1e-8, 1e-7, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1],
        ticktext=['1e-8', '1e-7', '1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1', '1'],
        range=[np.log10(1e-8), np.log10(1)]
    ),
    yaxis=dict(
        range=[np.log10(0.001), np.log10(100)], # Adjusted for percent
        type="log"
    ),
    barmode='overlay'
)
fig.show()

In [None]:
# Average MLP neurons active per token
active_neurons = []
for prompt in tqdm(prompts[:1000]):
    _, cache = model.run_with_cache(prompt)
    mlp_activations = cache[cfg.encoder_hook_point][0]
    active_neurons_per_token = (mlp_activations > 0).sum(1).float().mean(0)
    active_neurons.append(active_neurons_per_token.item())    
print(np.mean(active_neurons))

  0%|          | 0/1000 [00:00<?, ?it/s]

705.6655244750976
