## Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json
from collections import Counter


pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward

%reload_ext autoreload
%autoreload 2

In [2]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

english_activations = {}
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
english_activations[LAYER_TO_ABLATE] = haystack_utils.get_mlp_activations(english_data[:100], LAYER_TO_ABLATE, model, mean=False)
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

# Load bigrams
with open("./data/low_indirect_loss_trigrams.json", "r") as f:
    trigrams = json.load(f)

all_ignore, valid_tokens = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(german_data[:200], model, all_ignore, k=100)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [3]:
save_name = "33_dandy_silence"
model_name = 'pythia-70m'
path = Path('pythia-70m')

with open(f"{model_name}/{save_name}.json", "r") as f:
    cfg = json.load(f)

In [4]:
# Load 70m dict
layer = cfg["layer"]
act_name =  cfg["act"] #"hook_mlp_out", "mlp.hook_post"
expansion_factor = cfg["expansion_factor"]
l1_coeff = cfg["l1_coeff"]

if act_name == "hook_mlp_out":
    d_in = model.cfg.d_model #d_mlp
else:
    d_in = model.cfg.d_mlp
encoder_hook_point = f"blocks.{layer}.{act_name}"
autoencoder_dim = d_in * expansion_factor

def pickle_pt(name: str, path: Path):
    autoencoder = torch.load(os.path.join(path, name + '.pt'))
    with open(os.path.join(path, name + '.pkl'), 'wb') as f:
        pickle.dump(autoencoder, f)

pickle_pt(name=save_name, path=path)

autoencoder_70m = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
autoencoder_70m_filename = os.path.join(path, save_name + '.pkl')
with open(autoencoder_70m_filename, 'rb') as f:
    autoencoder_70m_state_dict = pickle.load(f)
autoencoder_70m.load_state_dict(autoencoder_70m_state_dict)
autoencoder_70m.to(device)


AutoEncoder()

## Analysis

In [5]:
# Loss increase
def evaluate_dict(autoencoder: AutoEncoder, encoded_hook_name: str, german_data: list):
    def encode_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [(encoded_hook_name, encode_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

    print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

evaluate_dict(autoencoder_70m, encoder_hook_point, german_data=german_data)

  0%|          | 0/200 [00:00<?, ?it/s]

Average loss increase after encoding: 0.2470


In [6]:
num_feature_activations = torch.zeros(autoencoder_dim).to(device)
mean_active = []
total_tokens = 0
for prompt in tqdm(german_data[:200]):
    tokens = model.to_tokens(prompt)
    _, cache = model.run_with_cache(
        tokens, names_filter=f"blocks.{layer}.{act_name}"
        )
    acts = cache[f"blocks.{layer}.{act_name}"].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    num_feature_activations = num_feature_activations + (mid_acts>0).sum(dim=0)
    active_features = (mid_acts > 0).sum(dim=1).float().mean(dim=0).item()
    mean_active.append(active_features)
    total_tokens += torch.numel(tokens)

active_features = (num_feature_activations > 0).sum().item()
feature_frequencies = num_feature_activations / total_tokens
print(f"Number of active features over {total_tokens} tokens: {active_features}")
print(f"Number of average active features per token: {np.mean(mean_active):.2f}")
fig = px.histogram(feature_frequencies.cpu().numpy(), histnorm='probability', log_y=True, title="Histogram of feature frequencies", nbins=40)
fig.update_layout(xaxis_title="Feature frequency", yaxis_title="Probability", showlegend=False, width=600)

  0%|          | 0/200 [00:00<?, ?it/s]

Number of active features over 105423 tokens: 13043
Number of average active features per token: 54.32


In [7]:
def encoder_dla(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)

    W_U_token = model.W_U[:, tokens.flatten()]
    W_out_U_token = model.W_out[layer] @ W_U_token # (n_mlp_neurons, n_tokens)
    W_dec_W_out_U_token = encoder.W_dec @ W_out_U_token # (n_encoder_neurons, n_tokens)
    
    #dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
    dla = mid_acts * W_dec_W_out_U_token[:, 1:].T
    scale = cache["ln_final.hook_scale"][0, :-1]
    dla = dla / scale
    return dla

# for prompt in tqdm(german_data[:200]):
#     tokens = model.to_tokens(prompt)
#     dla = encoder_dla(tokens, model, autoencoder_70m)
#     #fig = px.line(dla[-5].cpu().numpy())
#     #fig.show()
#     fig = px.histogram(dla.flatten().cpu().numpy(), nbins=100)
#     fig.update_layout(
#         # xaxis limit
#         yaxis=dict(
#             range=[0, 120],
#         ))
#     fig.show()
#     break

In [8]:
def encoder_dla_batched(tokens: Int[Tensor, "batch pos"], model: HookedTransformer, encoder: AutoEncoder) -> Float[Tensor, "pos n_neurons"]:
    batch_dim, seq_len = tokens.shape
    _, cache = model.run_with_cache(tokens)
    mlp_activations = cache[encoder_hook_point][:, :-1]
    _, _, mid_acts, _, _ = autoencoder_70m(mlp_activations)
    
    W_U_token = einops.rearrange(model.W_U[:, tokens.flatten()], "d_res (batch pos) -> d_res batch pos", batch=batch_dim, pos=seq_len)
    if act_name == "mlp.hook_post":
        W_out_U_token = einops.einsum(model.W_out[layer], W_U_token, "d_mlp d_res, d_res batch pos -> d_mlp batch pos")
        W_dec_W_out_U_token = einops.einsum(encoder.W_dec, W_out_U_token, "d_dec d_mlp, d_mlp batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_W_out_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    elif act_name == "hook_mlp_out":
        W_dec_U_token = einops.einsum(encoder.W_dec, W_U_token, "d_dec d_res, d_res batch pos -> d_dec batch pos")
        dla = einops.einsum(mid_acts, W_dec_U_token[:, :, 1:], "batch pos d_dec, d_dec batch pos -> batch pos d_dec")
    else:
        raise ValueError("Unknown act_name")
    scale = cache["ln_final.hook_scale"][:, :-1]
    dla = dla / scale
    return dla

In [9]:
trigram = trigrams[3]
last_trigram_token = model.to_str_tokens(model.to_tokens(trigram))[-1]
middle_trigram_token = model.to_str_tokens(model.to_tokens(trigram))[-2]
trigram_tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)
dla = encoder_dla_batched(trigram_tokens, model, autoencoder_70m)[:, -1].mean(0)
px.line(dla.cpu().numpy(), title=f"Average autoencoder DLA for '{trigram}' (100 samples)")

In [54]:
def get_directions_from_dla(dla: Float[Tensor, "n_neurons"], cutoff_dla=0.2, max_directions=3):
    top_dla, top_neurons = torch.topk(dla, max_directions, largest=True)
    directions = []
    for i in range(max_directions):
        if top_dla[i] > cutoff_dla:
            directions.append(top_neurons[i].item())
    return directions

directions = get_directions_from_dla(dla, cutoff_dla=0.2, max_directions=3)

tensor([0.8761, 0.5668, 0.2734], device='cuda:0') tensor([3669,  555, 3747], device='cuda:0')
[3669, 555, 3747]


In [55]:
def get_trigram_token_dla(model: HookedTransformer, encoder_neuron: int, trigram: str, cfg: dict):
    last_trigram_token = model.to_str_tokens(model.to_tokens(trigram))[-1]
    correct_token = model.to_single_token(last_trigram_token)

    if cfg["act"] == "mlp.hook_post":
        boosts = (autoencoder_70m.W_dec[encoder_neuron] @ model.W_out[layer]) @ model.W_U
    else:
        boosts = autoencoder_70m.W_dec[encoder_neuron] @ model.W_U
    trigram_token_dla = boosts[correct_token].item()
    print(f"'{last_trigram_token}' DLA = {trigram_token_dla:.2f}")
    return trigram_token_dla

get_trigram_token_dla(model, autoencoder_70m, directions[0], trigram, cfg)

'ruck' DLA = 0.93


0.9255110025405884

In [None]:
def get_trigram_dataset_examples(model: HookedTransformer, trigram: str, german_data: list[str], cfg: dict):
    # Boosted tokens with reasonable logprob
    middle_trigram_token = model.to_str_tokens(model.to_tokens(trigram))[-2]
    token = model.to_single_token(middle_trigram_token)
    prompt_length = 50
    token_prompts = []
    occurrences = []
    for prompt in german_data[:200]:
        tokenized_prompt = model.to_tokens(prompt).flatten()
        token_indices = torch.where(tokenized_prompt == token)
        if len(token_indices[0]) > 0:
            for i in token_indices[0]:
                if i > prompt_length:
                    new_prompt = tokenized_prompt[i-prompt_length:i+1]
                    occurrences.append("".join(model.to_str_tokens(tokenized_prompt[i-1:i+2])))
                    token_prompts.append(new_prompt)
    token_prompts = torch.stack(token_prompts)
    print(f"Found {len(token_prompts)} prompts with token '{middle_trigram_token}'")
    print(Counter(occurrences))

## Manual investigation


In [39]:
# Investigate voschlägen neuron
encoder_neuron = 3669
correct_token = model.to_single_token(last_trigram_token)
#incorrect_token = model.to_single_token("ge")

import plotly.graph_objects as go
# Which tokens is it boosting
if act_name == "mlp.hook_post":
    boosts = (autoencoder_70m.W_dec[encoder_neuron] @ model.W_out[layer]) @ model.W_U
else:
    boosts = autoencoder_70m.W_dec[encoder_neuron] @ model.W_U
top_boosts, top_tokens = torch.topk(boosts[valid_tokens.cpu()], 15)
top_deboosts, top_deboosted_tokens = torch.topk(boosts[valid_tokens.cpu()], 15, largest=False)
print("Boosted", model.to_str_tokens(top_tokens))
print("Deboosted", model.to_str_tokens(top_deboosted_tokens))
print(f"'{last_trigram_token}' DLA = {boosts[correct_token].item():.2f}")
fig = px.histogram(boosts[valid_tokens].cpu().numpy(), nbins=100, title=f"Histogram of N{encoder_neuron} token-wise DLA", histnorm="probability")
fig.update_layout(
    showlegend=False,
)
fig.add_shape(
    go.layout.Shape(
        type="line",
        x0=boosts[correct_token].item(),
        x1=boosts[correct_token].item(),
        y0=0,
        y1=1,
        yref="paper",
        line=dict(color="Red")
    )
)

fig.add_annotation(
    x=boosts[correct_token].item()+0.03,
    y=0.95,
    yref="paper",
    text=last_trigram_token,
    showarrow=False,
    arrowhead=7,
    ax=0,
    ay=-40
)

Boosted ['KP', 'claimer', 'emic', ' discrep', 'EV', 'ial', 'Т', ' Can', ' FOR', ' Mass', '\x19', 'lation', 'CC', ' Heck', ' Synthesis']
Deboosted [' weeks', ' 63', ' subgroups', ' �', ' sedan', ' love', '}({{\\', '//', ' iOS', '}}({{\\', 'node', ' lectures', ' cz', 'Hig', 'City']
'ruck' DLA = 0.93


In [40]:
# Boosted tokens with reasonable logprob
token = model.to_single_token(middle_trigram_token)
prompt_length = 50
token_prompts = []
occurrences = []
for prompt in german_data[:200]:
    tokenized_prompt = model.to_tokens(prompt).flatten()
    token_indices = torch.where(tokenized_prompt == token)
    if len(token_indices[0]) > 0:
        for i in token_indices[0]:
            if i > prompt_length:
                new_prompt = tokenized_prompt[i-prompt_length:i+1]
                occurrences.append("".join(model.to_str_tokens(tokenized_prompt[i-1:i+2])))
                token_prompts.append(new_prompt)
token_prompts = torch.stack(token_prompts)
print(f"Found {len(token_prompts)} prompts with token '{middle_trigram_token}'")
print(Counter(occurrences))

Found 56 prompts with token 'd'
Counter({' Ausdruck': 13, ' Generaldire': 3, 'ichend zu': 2, 'ivdire': 2, ' Mrd.': 2, 'ädte': 2, 'bald die': 2, 'üdfr': 2, 'ichend,': 1, 'städten': 1, 'ichend.': 1, 'anzdien': 1, 'ichend von': 1, ' rund um': 1, ' Held des': 1, ' Verdacht': 1, 'bald ih': 1, ' rund 10': 1, ' Zeitdruck': 1, 'üd-': 1, 'obald die': 1, 'bald dieser': 1, 'entsdien': 1, 'atsdien': 1, 'umdung': 1, 'chend vol': 1, ' überdacht': 1, 'üdwest': 1, 'umdungen': 1, ' sudanes': 1, 'ckdien': 1, 'ikdien': 1, ' Ausdru': 1, 'ädten': 1, ' UdSS': 1, ' rund 20': 1})


In [41]:
# Compute dataset example boosts and deboosts of the direction

def zero_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -1], encoder_neuron, 0)
    value[:, -1] = x_reconstruct
    return value

def encode_activations_hook(value, hook):
    global feature_activations
    _, x_reconstruct, acts, _, _ = autoencoder_70m(value[:, -1])
    feature_activations = acts[:, encoder_neuron]
    value[:, -1] = x_reconstruct
    return value

feature_activations = []

with model.hooks([(encoder_hook_point, encode_activations_hook)]): 
    logits_active = model(trigram_tokens[:, :-1], return_type="logits")[:, -1]

with model.hooks([(encoder_hook_point, zero_feature_hook)]):
    logits_inactive = model(trigram_tokens[:, :-1], return_type="logits")[:, -1]

last_trigram_token_tokenized = model.to_single_token(last_trigram_token)
print(f"Logit '{last_trigram_token}' boosts {logits_active[:, last_trigram_token_tokenized].mean(0).item()}, {logits_inactive[:, last_trigram_token_tokenized].mean(0).item():.2f}, {(logits_active[:, last_trigram_token_tokenized] - logits_inactive[:, last_trigram_token_tokenized]).mean(0).item():.2f}")
average_logit_boosts = (logits_active - logits_inactive).mean()
print(f"Average logit boost: {average_logit_boosts:.2f}")

logprobs_active = logits_active.log_softmax(dim=-1)
logprobs_inactive = logits_inactive.log_softmax(dim=-1)
print(f"Logprob '{last_trigram_token}' boosts {logprobs_active[:, last_trigram_token_tokenized].mean(0).item()}, {logprobs_inactive[:, last_trigram_token_tokenized].mean(0).item():.2f}, {(logprobs_active[:, last_trigram_token_tokenized] - logprobs_inactive[:, last_trigram_token_tokenized]).mean(0).item():.2f}")

def get_boosts(boosts: Float[Tensor, "d_vocab"], logprobs_active: Float[Tensor, "d_vocab"]):
    boosts[logprobs_active < -7] = 0
    boosts[all_ignore] = 0
    top_boosts, top_tokens = torch.topk(boosts, 15)
    non_zero_boosts = top_boosts != 0
    top_deboosts, top_deboosted_tokens = torch.topk(boosts, 15, largest=False)
    non_zero_deboosts = top_deboosts != 0
    print("Boosted", model.to_str_tokens(top_tokens[non_zero_boosts]), top_boosts[non_zero_boosts].tolist())
    print("Deboosted", model.to_str_tokens(top_deboosted_tokens[non_zero_deboosts]), top_deboosts[non_zero_deboosts].tolist())

mean_feature_activation = feature_activations.mean().item()
print(f"Mean feature activation: {mean_feature_activation:.2f}")
mean_boosts = (logprobs_active- logprobs_inactive).mean(0)
get_boosts(mean_boosts, logprobs_active.mean(0))

Logit 'ruck' boosts 25.69709014892578, 23.93, 1.76
Average logit boost: -0.00
Logprob 'ruck' boosts -2.1953587532043457, -3.04, 0.84
Mean feature activation: 1.44
Boosted ['ruck', 'aw', 'ien', 'abe', 'rogen', 'au', 'auer', 'ah', 'urch', 'rei', 'iente', 'ok', 'ahl', 'ä', 'ire'] [0.8433094024658203, 0.7798275351524353, 0.7554044127464294, 0.7294979691505432, 0.4687522351741791, 0.4328324794769287, 0.4035177230834961, 0.3722052574157715, 0.3679516017436981, 0.3139191269874573, 0.21151041984558105, 0.18118643760681152, 0.15839944779872894, 0.1319742351770401, 0.11004642397165298]
Deboosted ['ö', 'w', 'ies', 'rit', 'acht', 'amp', 'ru', 'az', 'är', 'ahn', 'ank', 'icht', 'ogen', 're', 'ring'] [-0.787299633026123, -0.6536190509796143, -0.6360180377960205, -0.48028892278671265, -0.4679807126522064, -0.41598203778266907, -0.35135534405708313, -0.3473084270954132, -0.27936992049217224, -0.2411508560180664, -0.22906802594661713, -0.22370198369026184, -0.11480723321437836, -0.06346326321363449, -0.

In [42]:
# Which dataset examples active the direction

for prompt in german_data[:10]:
    _, cache = model.run_with_cache(
        prompt, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    neuron_act = mid_acts[:, encoder_neuron]
    if neuron_act.max() > 0:
        str_tokens = model.to_str_tokens(model.to_tokens(prompt))
        haystack_utils.clean_print_strings_as_html(str_tokens, neuron_act.cpu().numpy(), max_value=mean_feature_activation)

In [48]:
# What happens when we ablate the context neuron

context_active_loss = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
_, cache = model.run_with_cache(
    trigram_tokens, names_filter=encoder_hook_point
    )
acts = cache[encoder_hook_point]
_, _, mid_acts, _, _ = autoencoder_70m(acts)
feature_activation_context_active = mid_acts[:, -2, encoder_neuron].mean().item()

with model.hooks(deactivate_neurons_fwd_hooks):
    context_ablated_loss = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    
    _, cache = model.run_with_cache(
        trigram_tokens, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point]
    _, _, mid_acts, _, _ = autoencoder_70m(acts)
    feature_activation_context_inactive = mid_acts[:, -2, encoder_neuron].mean().item()

print(f"Mean loss context active: {context_active_loss.item():.2f}")
print(f"Mean loss context inactive: {context_ablated_loss.item():.2f}")

print(f"Mean feature activation when context neuron active: {feature_activation_context_active:.2f}")
print(f"Mean feature activation with context neuron inactive: {feature_activation_context_inactive:.2f}")

Mean loss context active: 0.95
Mean loss context inactive: 4.04
Mean feature activation when context neuron active: 1.44
Mean feature activation with context neuron inactive: 1.02


In [49]:
downstream_components = ("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")

original_metric, activated_metric, ablated_metric, direct_effect_metric, indirect_effect_metric = haystack_utils.get_context_effect(trigram_tokens, model, deactivate_neurons_fwd_hooks, context_activation_hooks=[], downstream_components=downstream_components, pos=-1)
print(f"Direct indirect loss split from context neuron: {original_metric.mean():.2f}, {ablated_metric.mean():.2f}, {direct_effect_metric.mean():.2f}, {indirect_effect_metric.mean():.2f}")

Direct indirect loss split from context neuron: 0.95, 4.04, 3.57, 1.25


In [51]:
# Set all features to context active or inactive
_, cache = model.run_with_cache(
    trigram_tokens, names_filter=encoder_hook_point
    )
acts_active = cache[encoder_hook_point][:, -2]

with model.hooks(deactivate_neurons_fwd_hooks):
    _, cache = model.run_with_cache(
        trigram_tokens, names_filter=encoder_hook_point
        )
    acts_inactive = cache[encoder_hook_point][:, -2]

def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = autoencoder_70m(acts_inactive)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]

with model.hooks(activate_hooks):
    encoder_context_active_loss = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print(f"Model loss when patching through encoder with context neuron active: {encoder_context_active_loss.item():.2f}")

with model.hooks(deactivate_hooks):
    encoder_context_inactive_loss = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print(f"Model loss when patching through encoder with context neuron inactive: {encoder_context_inactive_loss.item():.2f}")

print(f"Model loss when not patching through encoder: {context_active_loss.item():.2f}")

Model loss when patching through encoder with context neuron active: 2.20
Model loss when patching through encoder with context neuron inactive: 3.80
Model loss when not patching through encoder: 0.95


In [53]:
# Set single feature to context active or inactive
def activate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, feature_activation_context_active)
    value[:, -2] = x_reconstruct
    return value

def deactivate_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, feature_activation_context_inactive)
    value[:, -2] = x_reconstruct
    return value

def zero_feature_hook(value, hook):
    _, x_reconstruct, _, _, _ = custom_forward(autoencoder_70m, value[:, -2], encoder_neuron, 0)
    value[:, -2] = x_reconstruct
    return value

activate_hooks = [(encoder_hook_point, activate_feature_hook)]
deactivate_hooks = [(encoder_hook_point, deactivate_feature_hook)]
zero_ablate_hooks = [(encoder_hook_point, zero_feature_hook)]

with model.hooks(activate_hooks):
    loss_encoder_direction_active = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print(f"Mean loss when patching second trigram token through encoder: {loss_encoder_direction_active.item():.2f}")

with model.hooks(deactivate_hooks):
    loss_encoder_direction_inactive = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print(f"Mean loss when patching second trigram token through encoder and setting N{encoder_neuron} to activation with context neuron inactive: {loss_encoder_direction_inactive.item():.2f}")

with model.hooks(zero_ablate_hooks):
    loss_encoder_direction_zeroed = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print(f"Mean loss when patching second trigram token through encoder and setting N{encoder_neuron} to zero: {loss_encoder_direction_zeroed.item():.2f}")

Mean loss when patching second trigram token through encoder: 2.18
Mean loss when patching second trigram token through encoder and setting N3669 to activation with context neuron inactive: 2.40
Mean loss when patching second trigram token through encoder and setting N3669 to zero: 3.04


In [75]:
# 0 ablate MLP5

def zero_layer_hook(value, hook):
    value = 0
    return value

zero_layer_hooks = [("blocks.5.hook_mlp_out", zero_layer_hook)]

with model.hooks(zero_layer_hooks):
    loss = model(trigram_tokens, return_type="loss", loss_per_token=True)[:, -1].mean()
    print("Mean loss 0 ablated", loss.item())

Mean loss 0 ablated 11.549064636230469


In [21]:
# Cosine sim context neuron and feature
feature_W_enc = autoencoder_70m.W_enc[:, encoder_neuron]
if act_name == "mlp.hook_post":
    feature_W_enc = model.W_in[layer] @ feature_W_enc
else:
    feature_W_enc = model.W_in[layer] @ (model.W_out[layer] @ feature_W_enc)
context_W_out = model.W_out[3, 669]
sim = torch.cosine_similarity(feature_W_enc, context_W_out, dim=0)
print(sim)

tensor(0.0322, device='cuda:0')


In [97]:

answer_residual_directions = einops.repeat(feature_W_enc, "d_res -> repeat d_res", repeat=100)  # [batch pos d_model]
print(answer_residual_directions.shape, trigram_tokens.shape)
_, cache = model.run_with_cache(trigram_tokens)
accumulated_residual, labels = cache.decompose_resid(layer=5, pos_slice=-2, return_labels=True)
scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=5, pos_slice=-2)
logit_attribution = einops.einsum(scaled_residual_stack, answer_residual_directions, "component batch d_model, batch d_model -> batch component")

logit_attribution = logit_attribution.mean(0).cpu().numpy()
index = list(range(len(labels)))

fig = px.line(x=index, y=logit_attribution, title='Vorschlägen autoencoder direction DLA', width=1000)
fig.update_xaxes(title='Index', tickmode='array', tickvals=list(range(len(labels))), ticktext=labels)
fig.update_yaxes(title='Logit Attribution')
fig.show()

torch.Size([100, 512]) torch.Size([100, 24])


In [None]:
# Train directions for earlier MLPs, check if specific directions activate it

In [104]:
# Decoder - context cosine sims
decoder_W_enc = autoencoder_70m.W_enc
context_neuron_W_out = model.W_out[3, 669].unsqueeze(1)
print(decoder_W_enc.shape, context_neuron_W_out.shape)

sims = torch.cosine_similarity(decoder_W_enc, context_neuron_W_out, dim=0)

px.histogram(sims.flatten().cpu().numpy())

torch.Size([512, 4096]) torch.Size([512, 1])


In [105]:
top_sims, top_directions = torch.topk(sims.flatten(), 10)
print(top_sims, top_directions)

tensor([0.1962, 0.1779, 0.1678, 0.1653, 0.1565, 0.1457, 0.1455, 0.1412, 0.1374,
        0.1368], device='cuda:0') tensor([3571, 3433, 1700, 2572, 1724,  801, 3932, 4076, 1915, 1484],
       device='cuda:0')


In [None]:
neuron = 3571
for prompt in english_data[:50] + german_data[:50]:
    _, cache = model.run_with_cache(
        prompt, names_filter=encoder_hook_point
        )
    acts = cache[encoder_hook_point].squeeze(0)
    loss, x_reconstruct, mid_acts, l2_loss, l1_loss = autoencoder_70m(acts)
    neuron_act = mid_acts[:, neuron]
    if neuron_act.max() > 1:
        str_tokens = model.to_str_tokens(model.to_tokens(prompt))
        haystack_utils.clean_print_strings_as_html(str_tokens, neuron_act.cpu().numpy(), max_value=5)

In [138]:
# Filter trigrams
data = []
deactivated_components=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.5.hook_mlp_out")
activated_components=("blocks.4.hook_mlp_out",)
for trigram in tqdm(trigrams + [" orschlägen"]):
    tokens = haystack_utils.generate_random_prompts(trigram, model, common_tokens, n=100, length=20)
    original_metric, ablated_metric, context_and_activated_metric, only_activated_metric = haystack_utils.get_direct_effect(tokens, model, deactivate_neurons_fwd_hooks, context_activation_hooks=[], pos=-1, deactivated_components=deactivated_components, activated_components=activated_components)
    data.append([trigram, original_metric.mean().item(), ablated_metric.mean().item(), context_and_activated_metric.mean().item(), only_activated_metric.mean().item()])

df = pd.DataFrame(data, columns=["trigram", "original", "ablated", "context_and_activated", "only_activated"])

print(df.head())

  0%|          | 0/236 [00:00<?, ?it/s]

      trigram  original   ablated  context_and_activated  only_activated
0  en verwend  5.282837  7.305348               6.604833        8.580179
1  ptabel und  3.210173  4.516782               2.676022        3.690480
2   re Beacht  4.353924  6.772032               4.668176        6.773994
3   empfunden  3.853431  4.803860               2.983654        4.742427
4    Amtszeit  1.497416  3.458482               1.766111        3.050396


In [139]:
df["mlp_decrease"] = df["ablated"] - df["only_activated"]
df["mlp_increase"] = df["only_activated"] - df["original"]

px.histogram(df["mlp_decrease"].values, title="Histogram of MLP increase for trigrams")

In [140]:
df.sort_values("mlp_increase", ascending=True).head(10)

Unnamed: 0,trigram,original,ablated,context_and_activated,only_activated,mlp_decrease,mlp_increase
188,cheid zu,8.696993,8.001113,6.523531,8.190983,-0.18987,-0.50601
175,ine freunds,4.701824,3.820682,3.006701,4.294682,-0.474,-0.407142
123,em expandieren,3.54892,4.613147,1.950871,3.30122,1.311928,-0.2477
94,Sie bestät,5.739153,6.686899,3.11712,5.500078,1.186821,-0.239075
226,ig wird und,3.931733,4.232167,2.579332,3.820715,0.411452,-0.111018
206,Vertrags,0.750227,1.434404,0.458119,0.718266,0.716138,-0.031961
8,vorsät,5.445683,5.942938,3.9992,5.466028,0.476911,0.020344
227,immt werden,4.973125,5.730371,3.88962,5.140984,0.589387,0.167858
190,untersucht,2.458399,4.28398,1.453115,2.637133,1.646847,0.178734
116,finden werden,5.961165,6.343742,4.586891,6.148615,0.195127,0.18745


## Other models

In [None]:
d_in = model.cfg.d_model
expansion_factor = 4
autoencoder_dim = d_in * expansion_factor
l1_coeff = 0.01

our_autoencoder = AutoEncoder(autoencoder_dim, l1_coeff, d_in)
our_autoencoder_filename = "pythia-160m/hook_mlp_out_l8.pt"
our_autoencoder.load_state_dict(torch.load(our_autoencoder_filename))
our_autoencoder.to(device)

with open("pythia-160m/hook_mlp_out_l8.pkl", "wb") as f:
    pickle.dump(our_autoencoder, f)

In [None]:
# Evaluate our dict
with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        _, x_reconstruct, _, _, _ = our_autoencoder(value)
        return x_reconstruct.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")

In [None]:
# Evaluate Logan's one
autoencoder2.to_device(device)
model.to(device)

with torch.no_grad():
    def encode_mlp_activations_hook(value, hook):
        value = value.squeeze(0)
        acts = autoencoder2.encode(value)
        out = autoencoder2.decode(acts)
        return out.unsqueeze(0)

    hooks = [("blocks.8.hook_mlp_out", encode_mlp_activations_hook)]

    original_losses = []
    reconstruct_losses = []
    for prompt in tqdm(german_data[:200]):
        original_loss = model(prompt, return_type="loss")
        with model.hooks(hooks):
            reconstruct_loss = model(prompt, return_type="loss")
        original_losses.append(original_loss.item())
        reconstruct_losses.append(reconstruct_loss.item())

print(f"Average loss increase after encoding: {(np.mean(reconstruct_losses) - np.mean(original_losses)):.4f}")