## Setup

In [56]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json


pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward

%reload_ext autoreload
%autoreload 2

In [None]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

english_activations = {}
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
english_activations[LAYER_TO_ABLATE] = haystack_utils.get_mlp_activations(english_data[:100], LAYER_TO_ABLATE, model, mean=False)
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

# Load bigrams
with open("../context_neuron/data/checkpoint/high_indirect_loss_trigrams.json", "r") as f:
    trigrams = json.load(f)

all_ignore, valid_tokens = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(german_data[:200], model, all_ignore, k=100)

In [76]:
save_name = "25_gallant_monkey"
model_name = 'pythia-70m'
path = Path('pythia-70m')

# with open(f"{model_name}/{save_name}.json", "r") as f:
#     cfg = json.load(f)

In [77]:
# Load 70m dict
layer = 5
act_name = "hook_mlp_out" #"mlp.hook_post"
if act_name == "hook_mlp_out":
    d_in = model.cfg.d_model #d_mlp
else:
    d_in = model.cfg.d_mlp
encoder_hook_point = f"blocks.{layer}.{act_name}"
expansion_factor = 8
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.001

def pickle_pt(name: str, path: Path):
    autoencoder = torch.load(os.path.join(path, name + '.pt'))
    with open(os.path.join(path, name + '.pkl'), 'wb') as f:
        pickle.dump(autoencoder, f)

pickle_pt(name=save_name, path=path)

autoencoder_70m = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
autoencoder_70m_filename = os.path.join(path, save_name + '.pkl')
with open(autoencoder_70m_filename, 'rb') as f:
    autoencoder_70m_state_dict = pickle.load(f)
autoencoder_70m.load_state_dict(autoencoder_70m_state_dict)
autoencoder_70m.to(device)


AutoEncoder()

## Analysis

In [78]:
# Loss increase
def evaluate_dict(autoencoder: AutoEncoder, encoded_hook_name: str, german_data: list):
    def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [(encoded_hook_name, encode_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

    print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

evaluate_dict(autoencoder_70m, encoder_hook_point, german_data=german_data)

  0%|          | 0/200 [00:00<?, ?it/s]

Average loss increase after encoding: 0.2400


In [79]:
active_features = torch.zeros(autoencoder_dim).bool().cuda()
for prompt in tqdm(german_data[:200]):
    _, cache = model.run_with_cache(
        prompt, names_filter=f"blocks.{layer}.{act_name}"
        )
    acts = cache[f"blocks.{layer}.{act_name}"].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    active_features = active_features | (mid_acts.sum(dim=0) > 0)
print(active_features.sum())

  0%|          | 0/200 [00:00<?, ?it/s]

tensor(2058, device='cuda:0')


In [80]:
def encoder_dla(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)

    W_U_token = model.W_U[:, tokens.flatten()]
    W_out_U_token = model.W_out[layer] @ W_U_token # (n_mlp_neurons, n_tokens)
    W_dec_W_out_U_token = encoder.W_dec @ W_out_U_token # (n_encoder_neurons, n_tokens)
    
    #dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
    dla = mid_acts * W_dec_W_out_U_token[:, 1:].T
    scale = cache["ln_final.hook_scale"][0, :-1]
    dla = dla / scale
    return dla

# for prompt in tqdm(german_data[:200]):
#     tokens = model.to_tokens(prompt)
#     dla = encoder_dla(tokens, model, autoencoder_70m)
#     #fig = px.line(dla[-5].cpu().numpy())
#     #fig.show()
#     fig = px.histogram(dla.flatten().cpu().numpy(), nbins=100)
#     fig.update_layout(
#         # xaxis limit
#         yaxis=dict(
#             range=[0, 120],
#         ))
#     fig.show()
#     break

In [81]:
def encoder_dla_batched(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    batch_dim, seq_len = tokens.shape
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[encoder_hook_point][:, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)
    W_U_token = einops.rearrange(model.W_U[:, tokens.flatten()], "d_res (batch pos) -> d_res batch pos", batch=batch_dim, pos=seq_len)
    if act_name == "mlp.hook_post":
        W_out_U_token = einops.einsum(model.W_out[layer], W_U_token, "d_mlp d_res, d_res batch pos -> d_mlp batch pos")
        W_dec_W_out_U_token = einops.einsum(encoder.W_dec, W_out_U_token, "d_dec d_mlp, d_mlp batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_W_out_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    elif act_name == "hook_mlp_out":
        W_dec_U_token = einops.einsum(encoder.W_dec, W_U_token, "d_dec d_res, d_res batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    else:
        raise ValueError("Unknown act_name")
    scale = cache["ln_final.hook_scale"][:, :-1]
    dla = dla / scale
    return dla

In [93]:
trigram = " Vorschlägen" #trigrams[0]
tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)
dla = encoder_dla_batched(tokens, model, autoencoder_70m)[:, -1].mean(0)
px.line(dla.cpu().numpy(), title=f"Average autoencoder DLA for '{trigram}' (100 samples)")

In [94]:
# Investigate voschlägen neuron
encoder_neuron = 986
correct_token = model.to_single_token("gen")
incorrect_token = model.to_single_token("ge")

import plotly.graph_objects as go
# Which tokens is it boosting
if act_name == "mlp.hook_post":
    boosts = (autoencoder_70m.W_dec[encoder_neuron] @ model.W_out[layer]) @ model.W_U
else:
    boosts = autoencoder_70m.W_dec[encoder_neuron] @ model.W_U
top_boosts, top_tokens = torch.topk(boosts[valid_tokens.cpu()], 15)
top_deboosts, top_deboosted_tokens = torch.topk(boosts[valid_tokens.cpu()], 15, largest=False)
print("Boosted", model.to_str_tokens(top_tokens))
print("Deboosted", model.to_str_tokens(top_deboosted_tokens))
print(boosts[correct_token], boosts[incorrect_token])
fig = px.histogram(boosts[valid_tokens].cpu().numpy(), nbins=100, title="Histogram of N1752 token-wise DLA", histnorm="probability")
fig.update_layout(
    showlegend=False,
)
# fig.add_shape(
#     go.layout.Shape(
#         type="line",
#         x0=boosts[correct_token].item(),
#         x1=boosts[correct_token].item(),
#         y0=0,
#         y1=1,
#         yref="paper",
#         line=dict(color="Red")
#     )
# )

# fig.add_annotation(
#     x=boosts[correct_token].item()+0.03,
#     y=0.95,
#     yref="paper",
#     text="gen",
#     showarrow=False,
#     arrowhead=7,
#     ax=0,
#     ay=-40
# )

Boosted ['ogn', ' tab', ' her', ' nicer', ' online', ' leaving', ' Kentucky', ' Python', 'aign', ' cam', ' Alb', 'efore', 'Let', ' hanging', 'Vs']
Deboosted ['2011', 'lens', ')]{}', ' Ho', ' purpose', ' acknowledged', ' conclusive', ' Stephan', ' []', 'gay', ' purchaser', ' measles', '561', 'Ui', ' nmol']
tensor(0.6078, device='cuda:0') tensor(0.5691, device='cuda:0')


In [95]:
# Boosted tokens with reasonable logprob
token = model.to_single_token(" ver")
prompt_length = 50
token_prompts = []
for prompt in german_data[:100]:
    tokenized_prompt = model.to_tokens(prompt).flatten()
    token_indices = torch.where(tokenized_prompt == token)
    if len(token_indices[0]) > 0:
        for i in token_indices[0]:
            if i > prompt_length:
                new_prompt = tokenized_prompt[i-prompt_length:i+1]
                token_prompts.append(new_prompt)
token_prompts = torch.stack(token_prompts)

In [86]:
def zero_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -1], encoder_neuron, 0)
    value[:, -1] = x_reconstruct
    return value

def encode_activations_hook(value, hook):
    global feature_activations
    _, x_reconstruct, acts, _, _ = autoencoder_70m(value[:, -1])
    feature_activations = acts[:, encoder_neuron]
    value[:, -1] = x_reconstruct
    return value

feature_activations = []

with model.hooks([(encoder_hook_point, encode_activations_hook)]): 
    logprobs = model(token_prompts, return_type="logits").log_softmax(-1)[:, -1]

with model.hooks([(encoder_hook_point, zero_feature_hook)]):
    logprobs_with_zeroed = model(token_prompts, return_type="logits").log_softmax(-1)[:, -1]

def get_boosts(logprobs_active: Float[Tensor, "d_vocab"], logprobs_inactive: Float[Tensor, "d_vocab"]):
    boosts = logprobs_active - logprobs_inactive
    boosts[logprobs_active < -5] = 0
    boosts[all_ignore] = 0
    top_boosts, top_tokens = torch.topk(boosts, 15)
    top_deboosts, top_deboosted_tokens = torch.topk(boosts, 15, largest=False)
    print("Boosted", model.to_str_tokens(top_tokens))
    print("Deboosted", model.to_str_tokens(top_deboosted_tokens))

get_boosts(logprobs.mean(0), logprobs_with_zeroed.mean(0))


Boosted ['ur', 'urs', 'mitt', 'fol', 'let', 'abs', 'bot', 'lass', 'pf', 'gle', 'gang', 'ste', 'stand', 'mut', 'me']
Deboosted ['&', '%', '#', '$', '<|padding|>', '<|endoftext|>', '!', '"', '*', ')', "'", '(', ',', '+', '-']


In [87]:
feature_activations

tensor([14.5645, 14.5904, 13.9783, 10.8091, 12.9292, 12.8410, 13.7169, 14.5916,
        14.1706, 15.1051, 14.8157,  9.9929, 14.2408, 13.4646, 13.7074, 10.2156,
        13.8812, 14.7810, 14.9022, 13.1998, 13.0426, 14.7562, 13.4728, 14.2456,
        13.7223, 14.8273, 14.4944, 14.3098, 14.4047, 13.6450, 13.4145, 13.8052,
        13.6785, 12.9670, 11.6629, 13.5181, 14.1498, 14.7722, 14.9693, 14.7393,
        13.8400, 13.6212, 14.3838, 14.3533, 14.3533,  8.7597, 15.0152, 14.5938,
        13.9294, 14.9663, 15.1596, 14.7936, 14.7925, 14.4883, 14.1538, 14.7440,
        15.2069, 13.6349, 15.3691, 14.5944, 12.8339, 14.7341, 14.5824, 14.5824,
        13.5379, 14.8867, 14.6598, 15.1083, 12.8749, 13.9419, 13.4518, 11.9226,
        14.0936, 14.5040, 13.5822, 14.5665, 14.4684, 13.5500, 15.2841, 13.9638,
        10.2931,  9.4098, 14.9571, 13.4446, 13.3654, 14.4480, 14.7982, 14.8426,
        15.2994, 14.2140, 15.2116, 14.3005, 15.2117, 13.5310, 11.2197, 14.6147,
        14.7895,  9.7046, 10.9305, 14.11

In [96]:
# Which dataset examples active the direction

for prompt in german_data[:5]:
    _, cache = model.run_with_cache(
        prompt, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    neuron_act = mid_acts[:, encoder_neuron]
    if neuron_act.max() > 0:
        str_tokens = model.to_str_tokens(model.to_tokens(prompt))
        haystack_utils.clean_print_strings_as_html(str_tokens, neuron_act.cpu().numpy(), max_value=20)

In [97]:
# What happens when we ablate the context neuron

loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
print("Mean loss", loss.item())
_, cache = model.run_with_cache(
    tokens, names_filter=encoder_hook_point
    )
acts = cache[encoder_hook_point]
loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)

mean_active = mid_acts[:, -2, encoder_neuron].mean().item()
print("Mean feature activation active", mean_active)

with model.hooks(deactivate_neurons_fwd_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss", loss.item())
    _, cache = model.run_with_cache(
        tokens, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point]
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    mean_inactive = mid_acts[:, -2, encoder_neuron].mean().item()
    print("Mean feature activation inactive", mean_inactive)

Mean loss 1.1678730249404907
Mean feature activation active 8.189324378967285
Mean loss 3.551626443862915
Mean feature activation inactive 7.340190887451172


In [98]:
downstream_components = ("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")

original_metric, activated_metric, ablated_metric, direct_effect_metric, indirect_effect_metric = haystack_utils.get_context_effect(tokens, model, deactivate_neurons_fwd_hooks, context_activation_hooks=[], downstream_components=downstream_components, pos=-1)
print(original_metric.mean(), ablated_metric.mean(), direct_effect_metric.mean(), indirect_effect_metric.mean())

tensor(1.1679, device='cuda:0') tensor(3.5516, device='cuda:0') tensor(4.6354, device='cuda:0') tensor(1.1486, device='cuda:0')


In [90]:
# Set all features to context active or inactive
_, cache = model.run_with_cache(
    tokens, names_filter=encoder_hook_point
    )
acts_active = cache[encoder_hook_point][:, -2]

with model.hooks(deactivate_neurons_fwd_hooks):
    _, cache = model.run_with_cache(
        tokens, names_filter=encoder_hook_point
        )
    acts_inactive = cache[encoder_hook_point][:, -2]

def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_inactive)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]

with model.hooks(activate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss active", loss.item())

with model.hooks(deactivate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss inactive", loss.item())

Mean loss active 5.337618827819824
Mean loss inactive 5.60790491104126


In [72]:
# Set single feature to context active or inactive
def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, mean_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, mean_inactive)
    value[:, -2] = x_reconstruct
    return value

def zero_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, 0)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]
zero_ablate_hooks = [(encoder_hook_point, zero_feature_hook)]

with model.hooks(activate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss active", loss.item())

with model.hooks(deactivate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss inactive", loss.item())

with model.hooks(zero_ablate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss 0 ablated", loss.item())

Mean loss active 5.366043567657471
Mean loss inactive 6.055084228515625
Mean loss 0 ablated 9.260305404663086


In [75]:
# 0 ablate MLP5

def zero_layer_hook(value, hook):
    value = 0
    return value

zero_layer_hooks = [("blocks.5.hook_mlp_out", zero_layer_hook)]

with model.hooks(zero_layer_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss 0 ablated", loss.item())

Mean loss 0 ablated 11.549064636230469


In [21]:
# Cosine sim context neuron and feature
feature_W_enc = autoencoder_70m.W_enc[:, encoder_neuron]
if act_name == "mlp.hook_post":
    feature_W_enc = model.W_in[layer] @ feature_W_enc
else:
    feature_W_enc = model.W_in[layer] @ (model.W_out[layer] @ feature_W_enc)
context_W_out = model.W_out[3, 669]
sim = torch.cosine_similarity(feature_W_enc, context_W_out, dim=0)
print(sim)

tensor(0.0322, device='cuda:0')


In [97]:

answer_residual_directions = einops.repeat(feature_W_enc, "d_res -> repeat d_res", repeat=100)  # [batch pos d_model]
print(answer_residual_directions.shape, tokens.shape)
_, cache = model.run_with_cache(tokens)
accumulated_residual, labels = cache.decompose_resid(layer=5, pos_slice=-2, return_labels=True)
scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=5, pos_slice=-2)
logit_attribution = einops.einsum(scaled_residual_stack, answer_residual_directions, "component batch d_model, batch d_model -> batch component")

logit_attribution = logit_attribution.mean(0).cpu().numpy()
index = list(range(len(labels)))

fig = px.line(x=index, y=logit_attribution, title='Vorschlägen autoencoder direction DLA', width=1000)
fig.update_xaxes(title='Index', tickmode='array', tickvals=list(range(len(labels))), ticktext=labels)
fig.update_yaxes(title='Logit Attribution')
fig.show()

torch.Size([100, 512]) torch.Size([100, 24])


In [None]:
# Train directions for earlier MLPs, check if specific directions activate it

In [104]:
# Decoder - context cosine sims
decoder_W_enc = autoencoder_70m.W_enc
context_neuron_W_out = model.W_out[3, 669].unsqueeze(1)
print(decoder_W_enc.shape, context_neuron_W_out.shape)

sims = torch.cosine_similarity(decoder_W_enc, context_neuron_W_out, dim=0)

px.histogram(sims.flatten().cpu().numpy())

torch.Size([512, 4096]) torch.Size([512, 1])


In [105]:
top_sims, top_directions = torch.topk(sims.flatten(), 10)
print(top_sims, top_directions)

tensor([0.1962, 0.1779, 0.1678, 0.1653, 0.1565, 0.1457, 0.1455, 0.1412, 0.1374,
        0.1368], device='cuda:0') tensor([3571, 3433, 1700, 2572, 1724,  801, 3932, 4076, 1915, 1484],
       device='cuda:0')


In [None]:
neuron = 3571
for prompt in english_data[:50] + german_data[:50]:
    _, cache = model.run_with_cache(
        prompt, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    neuron_act = mid_acts[:, neuron]
    if neuron_act.max() > 1:
        str_tokens = model.to_str_tokens(model.to_tokens(prompt))
        haystack_utils.clean_print_strings_as_html(str_tokens, neuron_act.cpu().numpy(), max_value=5)

In [117]:
# Filter trigrams
data = []
for trigram in tqdm(trigrams + [" orschlägen"]):
    tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)
    original_metric, ablated_metric, context_and_activated_metric, only_activated_metric = haystack_utils.get_direct_effect(tokens, model, deactivate_neurons_fwd_hooks, context_activation_hooks=[], pos=-1)
    data.append([trigram, original_metric.mean(), ablated_metric.mean(), context_and_activated_metric.mean(), only_activated_metric.mean()])

df = pd.DataFrame(data, columns=["trigram", "original", "ablated", "context_and_activated", "only_activated"])

print(df.head())

  0%|          | 0/236 [00:00<?, ?it/s]

In [None]:
df["mlp5_increase"] = df["only_activated"] - df["original"]

px.histogram(df["mlp5_increase"].values, title="Histogram of MLP5 increase for trigrams")

## Other models

In [None]:
d_in = model.cfg.d_model
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.01

our_autoencoder = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
our_autoencoder_filename = "pythia-160m/hook_mlp_out_l8.pt"
our_autoencoder.load_state_dict(torch.load(our_autoencoder_filename))
our_autoencoder.to(device)

with open("pythia-160m/hook_mlp_out_l8.pkl", "wb") as f:
    pickle.dump(our_autoencoder, f)

In [None]:
# Evaluate our dict
with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = our_autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

In [None]:
# Evaluate Logan's one
autoencoder2.to_device(device)
model.to(device)

with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        acts = autoencoder2.encode(value)
        out = autoencoder2.decode(acts)
        return out.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")