In [53]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json


pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder

%reload_ext autoreload
%autoreload 2

In [61]:
def pickle_pt(name: str, path: Path):
    autoencoder = torch.load(os.path.join(path, name + '.pt'))
    with open(os.path.join(path, name + '_2' + '.pkl'), 'wb') as f:
        pickle.dump(autoencoder, f)

pickle_pt(name='hook_mlp_out_l5_8k_0001_2e', path=Path('pythia-70m'))


In [62]:
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")


model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [63]:
english_activations = {}
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
english_activations[LAYER_TO_ABLATE] = haystack_utils.get_mlp_activations(english_data[:100], LAYER_TO_ABLATE, model, mean=False)
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

  0%|          | 0/100 [00:00<?, ?it/s]

In [77]:
# Load 70m dict
layer = 5
act_name = "hook_mlp_out" #"mlp.hook_post"
if act_name == "hook_mlp_out":
    d_in = model.cfg.d_model #d_mlp
else:
    d_in = model.cfg.d_mlp
encoder_hook_point = f"blocks.{layer}.{act_name}"
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.001

autoencoder_70m = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
autoencoder_70m_filename = "pythia-70m/hook_mlp_out_l5_8k_0001_2e_2.pkl"
with open(autoencoder_70m_filename, 'rb') as f:
    autoencoder_70m_state_dict = pickle.load(f)
autoencoder_70m.load_state_dict(autoencoder_70m_state_dict)
autoencoder_70m.to(device)

AutoEncoder()

In [69]:
# Loss increase
def evaluate_dict(autoencoder: AutoEncoder, encoded_hook_name: str, german_data: list):
    def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [(encoded_hook_name, encode_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

    print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

evaluate_dict(autoencoder_70m, encoder_hook_point, german_data=german_data)

  0%|          | 0/200 [00:00<?, ?it/s]

Average loss increase after encoding: 0.1797


In [71]:
active_features = torch.zeros(autoencoder_dim).bool().cuda()
for prompt in tqdm(german_data[:200]):
    _, cache = model.run_with_cache(
        prompt, names_filter=f"blocks.{layer}.{act_name}"
        )
    acts = cache[f"blocks.{layer}.{act_name}"].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    active_features = active_features | (mid_acts.sum(dim=0) > 0)
print(active_features.sum())

  0%|          | 0/200 [00:00<?, ?it/s]

tensor(1984, device='cuda:0')


In [72]:
def encoder_dla(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)

    W_U_token = model.W_U[:, tokens.flatten()]
    W_out_U_token = model.W_out[layer] @ W_U_token # (n_mlp_neurons, n_tokens)
    W_dec_W_out_U_token = encoder.W_dec @ W_out_U_token # (n_encoder_neurons, n_tokens)
    
    #dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
    dla = mid_acts * W_dec_W_out_U_token[:, 1:].T
    scale = cache["ln_final.hook_scale"][0, :-1]
    dla = dla / scale
    return dla

# for prompt in tqdm(german_data[:200]):
#     tokens = model.to_tokens(prompt)
#     dla = encoder_dla(tokens, model, autoencoder_70m)
#     #fig = px.line(dla[-5].cpu().numpy())
#     #fig.show()
#     fig = px.histogram(dla.flatten().cpu().numpy(), nbins=100)
#     fig.update_layout(
#         # xaxis limit
#         yaxis=dict(
#             range=[0, 120],
#         ))
#     fig.show()
#     break

In [73]:
# Load bigrams
with open("../context_neuron/data/checkpoint/high_indirect_loss_trigrams.json", "r") as f:
    trigrams = json.load(f)

print(len(trigrams))

235


In [74]:
all_ignore, valid_tokens = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(german_data[:200], model, all_ignore, k=100)

  0%|          | 0/200 [00:00<?, ?it/s]

In [82]:
def encoder_dla_batched(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    batch_dim, seq_len = tokens.shape
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[encoder_hook_point][:, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)
    W_U_token = einops.rearrange(model.W_U[:, tokens.flatten()], "d_res (batch pos) -> d_res batch pos", batch=batch_dim, pos=seq_len)
    if act_name == "mlp.hook_post":
        W_out_U_token = einops.einsum(model.W_out[layer], W_U_token, "d_mlp d_res, d_res batch pos -> d_mlp batch pos")
        W_dec_W_out_U_token = einops.einsum(encoder.W_dec, W_out_U_token, "d_dec d_mlp, d_mlp batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_W_out_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    elif act_name == "hook_mlp_out":
        W_dec_U_token = einops.einsum(encoder.W_dec, W_U_token, "d_dec d_res, d_res batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    else:
        raise ValueError("Unknown act_name")
    scale = cache["ln_final.hook_scale"][:, :-1]
    dla = dla / scale
    return dla

In [83]:
trigram = " Vorschlägen"#trigrams[4]
tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)
dla = encoder_dla_batched(tokens, model, autoencoder_70m)[:, -1].mean(0)
px.line(dla.cpu().numpy(), title=f"Average autoencoder DLA for '{trigram}' (100 samples)")

In [86]:
# Investigate voschlägen neuron
encoder_neuron = 558 # 1752
correct_token = model.to_single_token("gen")
incorrect_token = model.to_single_token("ge")

import plotly.graph_objects as go
# Which tokens is it boosting
if act_name == "mlp.hook_post":
    boosts = (autoencoder_70m.W_dec[encoder_neuron] @ model.W_out[layer]) @ model.W_U
else:
    boosts = autoencoder_70m.W_dec[encoder_neuron] @ model.W_U
top_boosts, top_tokens = torch.topk(boosts[valid_tokens.cpu()], 25)
print(model.to_str_tokens(top_tokens))
print(boosts[correct_token], boosts[incorrect_token])
fig = px.histogram(boosts[valid_tokens].cpu().numpy(), nbins=100, title="Histogram of N1752 token-wise DLA", histnorm="probability")
fig.update_layout(
    showlegend=False,
)
fig.add_shape(
    go.layout.Shape(
        type="line",
        x0=boosts[correct_token].item(),
        x1=boosts[correct_token].item(),
        y0=0,
        y1=1,
        yref="paper",
        line=dict(color="Red")
    )
)

fig.add_annotation(
    x=boosts[correct_token].item()+0.03,
    y=0.95,
    yref="paper",
    text="gen",
    showarrow=False,
    arrowhead=7,
    ax=0,
    ay=-40
)

[' nicer', ' her', ' leaving', ' tables', ' tab', 'oud', ' online', 'aign', '=""', ' par', ' Python', ' injured', 'ogn', ' injury', ' France', ' woods', ' Impact', ' hem', ' Kentucky', 'TAIN', 'OO', ' ID', ' eth', ' Divine', ' ophthal']
tensor(0.5374, device='cuda:0') tensor(0.4440, device='cuda:0')


In [87]:
# Which dataset examples active the direction

for prompt in german_data[:15]:
    _, cache = model.run_with_cache(
        prompt, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    neuron_act = mid_acts[:, encoder_neuron]
    if neuron_act.max() > 0:
        str_tokens = model.to_str_tokens(model.to_tokens(prompt))
        haystack_utils.clean_print_strings_as_html(str_tokens, neuron_act.cpu().numpy(), max_value=20)

In [91]:
# What happens when we ablate the context neuron

trigram = " Vorschlägen"#trigrams[4]
tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)

loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
print("Mean loss", loss.item())
_, cache = model.run_with_cache(
    tokens, names_filter=encoder_hook_point
    )
acts = cache[encoder_hook_point]
loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)

mean_active = mid_acts[:, -2, encoder_neuron].mean().item()
print("Mean feature activation inactive", mean_active)

with model.hooks(deactivate_neurons_fwd_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss", loss.item())
    _, cache = model.run_with_cache(
        tokens, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point]
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)

    

    mean_inactive = mid_acts[:, -2, encoder_neuron].mean().item()
    print("Mean feature activation inactive", mean_inactive)


Mean loss 1.225911021232605
Mean feature activation inactive 6.047533988952637
Mean loss 3.191844940185547
Mean feature activation inactive 5.4116621017456055


In [92]:
trigram = " Vorschlägen"#trigrams[4]
tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)

_, cache = model.run_with_cache(
    tokens, names_filter=encoder_hook_point
    )
acts_active = cache[encoder_hook_point][:, -2]

with model.hooks(deactivate_neurons_fwd_hooks):
    _, cache = model.run_with_cache(
        tokens, names_filter=encoder_hook_point
        )
    acts_inactive = cache[encoder_hook_point][:, -2]

def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_inactive)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]

with model.hooks(activate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss", loss.item())

with model.hooks(deactivate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss", loss.item())

Mean loss 2.3075780868530273
Mean loss 4.91679048538208


In [93]:
def custom_forward(enc: AutoEncoder, x: Float[Tensor, "batch d_in"], neuron: int, activation: float):
    x_cent = x - enc.b_dec
    acts = F.relu(x_cent @ enc.W_enc + enc.b_enc)
    acts[:, neuron] = activation
    x_reconstruct = acts @ enc.W_dec + enc.b_dec
    l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
    l1_loss = enc.l1_coeff * (acts.abs().sum())
    loss = l2_loss + l1_loss
    return loss, x_reconstruct, acts, l2_loss, l1_loss


def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, mean_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, mean_inactive)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]

with model.hooks(activate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss active", loss.item())

with model.hooks(deactivate_hooks):
    loss = model(tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss inactive", loss.item())

Mean loss active 2.301732063293457
Mean loss inactive 2.37518572807312


In [96]:
# Check what feature is looking at
feature_W_enc = autoencoder_70m.W_enc[:, encoder_neuron]
if act_name == "mlp.hook_post":
    feature_W_enc = model.W_in[layer] @ feature_W_enc
else:
    feature_W_enc = model.W_in[layer] @ (model.W_out[layer] @ feature_W_enc)
context_W_out = model.W_out[3, 669]
sim = torch.cosine_similarity(feature_W_enc, context_W_out, dim=0)
print(sim)

tensor(0.0322, device='cuda:0')


In [97]:

answer_residual_directions = einops.repeat(feature_W_enc, "d_res -> repeat d_res", repeat=100)  # [batch pos d_model]
print(answer_residual_directions.shape, tokens.shape)
_, cache = model.run_with_cache(tokens)
accumulated_residual, labels = cache.decompose_resid(layer=5, pos_slice=-2, return_labels=True)
scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=5, pos_slice=-2)
logit_attribution = einops.einsum(scaled_residual_stack, answer_residual_directions, "component batch d_model, batch d_model -> batch component")

logit_attribution = logit_attribution.mean(0).cpu().numpy()
index = list(range(len(labels)))

fig = px.line(x=index, y=logit_attribution, title='Vorschlägen autoencoder direction DLA', width=1000)
fig.update_xaxes(title='Index', tickmode='array', tickvals=list(range(len(labels))), ticktext=labels)
fig.update_yaxes(title='Logit Attribution')
fig.show()

torch.Size([100, 512]) torch.Size([100, 24])


In [None]:
# Train directions for earlier MLPs, check if specific directions activate it

## Other models

In [None]:
d_in = model.cfg.d_model
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.01

our_autoencoder = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
our_autoencoder_filename = "pythia-160m/hook_mlp_out_l8.pt"
our_autoencoder.load_state_dict(torch.load(our_autoencoder_filename))
our_autoencoder.to(device)

with open("pythia-160m/hook_mlp_out_l8.pkl", "wb") as f:
    pickle.dump(our_autoencoder, f)

In [None]:
# Evaluate our dict
with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = our_autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

In [None]:
# Evaluate Logan's one
autoencoder2.to_device(device)
model.to(device)

with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        acts = autoencoder2.encode(value)
        out = autoencoder2.decode(acts)
        return out.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")