In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
import pickle
import os
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
import json
from collections import Counter
import logging

# Only warnings and above are logged, with a short 12-hour timestamp.
logging.basicConfig(
    format='(%(levelname)s) %(asctime)s: %(message)s',
    level=logging.WARNING,
    datefmt='%I:%M:%S',
)
pio.renderers.default = "notebook_connected"

# Prefer CUDA, then Apple MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Gradients are disabled globally for the whole notebook.
# torch.autograd.set_grad_enabled is the same object as torch.set_grad_enabled,
# so a single call is sufficient.
torch.set_grad_enabled(False)

import sys
sys.path.append('../')  # Add the parent directory to the system path
import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
import utils.autoencoder_utils as autils

%reload_ext autoreload
%autoreload 2

In [2]:
# Ablation target: one MLP neuron in layer 3.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

# MLP activations of the target layer on the first 100 English examples, keyed by layer.
english_activations = {
    LAYER_TO_ABLATE: haystack_utils.get_mlp_activations(
        english_data[:100], LAYER_TO_ABLATE, model, mean=False
    )
}

# Mean activation of the target neuron(s) on English text; the ablation hook
# writes this value in place of the neuron's real activation.
MEAN_ACTIVATION_INACTIVE = (
    english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
)


def deactivate_neurons_hook(value, hook):
    """Forward hook that pins the ablated neurons to MEAN_ACTIVATION_INACTIVE.

    Overwrites the NEURONS_TO_ABLATE entries along the third axis of `value`
    in place and returns the same tensor. `hook` is accepted but not used.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value


# (hook point name, hook function) pairs; the hook point is the MLP
# post-activation of LAYER_TO_ABLATE.
deactivate_neurons_fwd_hooks = [
    (f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", deactivate_neurons_hook)
]

# Load trigrams (the file name suggests they were pre-selected for low indirect
# loss; the selection itself is not part of this notebook).
with open("./data/low_indirect_loss_trigrams.json", "r") as f:
    trigrams = json.load(f)

# Token sets used when generating random prompts and when evaluating directions.
# common_tokens is computed from the first 200 German examples with k=100,
# excluding everything in all_ignore.
all_ignore, valid_tokens = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(
    german_data[:200], model, all_ignore, k=100
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [22]:
save_name = "33_dandy_silence"  # "25_gallant_monkey"
model_name = "pythia-70m"
# The checkpoint directory is derived from model_name so the two cannot drift apart.
path = Path(model_name)

# Raw JSON config saved next to the checkpoint. It gets its own name so that
# `cfg` always refers to the AutoEncoderConfig object and never to a plain dict.
with open(path / f"{save_name}.json", "r") as f:
    cfg_dict = json.load(f)

# NOTE(review): act_name_to_d_in is neither defined nor imported in the cells
# visible here, so this line fails on a fresh kernel unless it is defined
# elsewhere. Confirm where it lives (possibly utils.autoencoder_utils) and
# import it explicitly.
d_in = act_name_to_d_in(cfg_dict["act"])

cfg = autils.AutoEncoderConfig(
    cfg_dict["layer"],
    cfg_dict["act"],
    cfg_dict["expansion_factor"],
    cfg_dict["l1_coeff"],
    d_in,
)
cfg

AutoEncoderConfig(layer=5, act_name='mlp.hook_post', expansion_factor=8, l1_coeff=0.001)

In [23]:
# Hidden width of the autoencoder is the input width times the expansion factor.
d_hidden = cfg.d_in * cfg.expansion_factor
encoder = AutoEncoder(d_hidden, cfg.l1_coeff, cfg.d_in)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU still loads on a CPU- or MPS-only machine.
encoder.load_state_dict(
    torch.load(os.path.join(path, save_name + ".pt"), map_location=device)
)
encoder.to(device)

AutoEncoder()

In [24]:
# Sanity check of the loaded encoder at cfg.encoder_hook_point on the first 200
# German examples. The helper returns a pair of numbers.
# NOTE(review): presumably (loss without encoder, loss with encoder reconstruction) — confirm in autils.
autils.evaluate_autoencoder_reconstruction(encoder, cfg.encoder_hook_point, german_data[:200], model)

100%|██████████| 200/200 [00:12<00:00, 15.74it/s]


(3.0272637605667114, 3.274273495078087)

In [5]:
def eval_trigram_direction(trigram, random_prompts, encoder_neuron):
    """Gather all metrics for one (trigram, encoder direction) pair.

    Runs five `autils` helpers in sequence against the notebook-level `model`,
    `encoder`, `cfg`, `common_tokens`, `all_ignore` and
    `deactivate_neurons_fwd_hooks`, and flattens their results into one record.

    Args:
        trigram: The trigram being evaluated.
        random_prompts: Prompts passed to every helper that needs text input.
        encoder_neuron: Index of the encoder direction under evaluation.

    Returns:
        A dict with one entry per metric, suitable as a DataFrame row. The five
        encoder reconstruction losses are converted with `.item()`.
    """
    token_dla = autils.get_trigram_token_dla(
        model, encoder, encoder_neuron, trigram, cfg
    )

    # The helper also returns boosted / deboosted tokens, which are not reported.
    boost_stats = autils.get_direction_logit_and_logprob_boost(
        random_prompts,
        encoder,
        encoder_neuron,
        model,
        trigram,
        common_tokens,
        all_ignore,
        cfg,
    )
    (
        logit_encoded,
        logit_zeroed,
        logprob_encoded,
        logprob_zeroed,
        _boosted,
        _deboosted,
    ) = boost_stats

    context_stats = autils.get_context_effect_on_feature_activations(
        model,
        random_prompts,
        encoder,
        encoder_neuron,
        deactivate_neurons_fwd_hooks,
        cfg,
    )
    (
        loss_ctx_active,
        loss_ctx_ablated,
        activation_ctx_active,
        activation_ctx_inactive,
    ) = context_stats

    token_losses = autils.get_encoder_token_reconstruction_losses(
        random_prompts, model, encoder, deactivate_neurons_fwd_hooks, cfg
    )
    enc_loss_ctx_active, enc_loss_ctx_inactive = token_losses

    # The feature activations measured above are fed back into this helper.
    feature_losses = autils.get_encoder_feature_reconstruction_losses(
        random_prompts,
        encoder,
        model,
        encoder_neuron,
        activation_ctx_active,
        activation_ctx_inactive,
        cfg,
    )
    dir_loss_active, dir_loss_inactive, dir_loss_zeroed = feature_losses

    # Key order determines the column order of the resulting DataFrame.
    record = {
        "Trigram": trigram,
        "Direction": encoder_neuron,
        "Correct token dla": token_dla,
        "Last token logit encoded": logit_encoded,
        "Last token logit zeroed": logit_zeroed,
        "Last token logprob encoded": logprob_encoded,
        "Last token logprob zeroed": logprob_zeroed,
        "Context active loss": loss_ctx_active,
        "Context ablated loss": loss_ctx_ablated,
        "Encoder context active loss": enc_loss_ctx_active.item(),
        "Encoder context inactive loss": enc_loss_ctx_inactive.item(),
        "Encoder direction active loss": dir_loss_active.item(),
        "Encoder direction inactive loss": dir_loss_inactive.item(),
        "Encoder direction zeroed loss": dir_loss_zeroed.item(),
        "Direction activation context active": activation_ctx_active,
        "Direction activation context inactive": activation_ctx_inactive,
    }
    return record


def eval_trigram(trigram: str):
    """Evaluate every encoder direction selected for a single trigram.

    Generates 100 random prompts of length 20 for the trigram, derives the
    candidate encoder directions from the encoder DLA on those prompts, and
    evaluates each direction with `eval_trigram_direction`.

    Returns:
        A list with one metrics dict per selected direction.
    """
    # Random prompts that end in all three trigram tokens.
    # Disabled alternative: dataset examples that don't contain the last trigram
    # token, via autils.get_trigram_dataset_examples(model, trigram, german_data, max_prompts=100).
    prompts = haystack_utils.generate_random_prompts(
        trigram, model, common_tokens, n=100, length=20
    )

    dla = autils.encoder_dla_batched(prompts, model, encoder, cfg)
    # Index -1 on axis 1, then mean over axis 0
    # (presumably: final position, averaged over prompts).
    mean_final_dla = dla[:, -1].mean(0)
    directions = autils.get_directions_from_dla(mean_final_dla)

    return [
        eval_trigram_direction(trigram, prompts, direction)
        for direction in directions
    ]


def eval_trigrams(trigrams: list[str]):
    """Evaluate all trigrams and return their per-direction records as one flat list.

    Progress is shown with a tqdm bar over the trigrams.
    """
    records = []
    for current_trigram in tqdm(trigrams):
        records.extend(eval_trigram(current_trigram))
    return records


# Evaluate every trigram/direction pair and persist the records, one CSV per autoencoder run.
data = eval_trigrams(trigrams)
df = pd.DataFrame(data)
# to_csv also writes the row index as an unnamed first column.
df.to_csv(f"data/trigram_eval_{save_name}.csv")

# To reuse a saved run instead of recomputing, load it back with index_col=0;
# without it the saved index reappears as an "Unnamed: 0" column.
# df = pd.read_csv(f"data/trigram_eval_{save_name}.csv", index_col=0)

  0%|          | 0/228 [00:00<?, ?it/s]

In [6]:
# Extra loss from routing the activation through the encoder, floored at 0
# (low is good).
df["Encoder loss increase"] = (
    df["Encoder direction active loss"]
    .sub(df["Context active loss"])
    .clip(lower=0)
)

# Extra loss when the direction is set to its context-inactive activation,
# measured against the worse of the two un-ablated losses and floored at 0
# (high means the direction matters).
df["Direction ablation loss increase"] = (
    df["Encoder direction inactive loss"]
    .sub(np.maximum(df["Encoder direction active loss"], df["Context active loss"]))
    .clip(lower=0)
)

# Interesting rows: cheap to encode but costly to ablate.
df["Interest"] = df["Direction ablation loss increase"].sub(
    df["Encoder loss increase"]
)

# Same comparison as the ablation column, but with the direction zeroed out.
df["Direction zero ablation loss increase"] = (
    df["Encoder direction zeroed loss"]
    .sub(np.maximum(df["Encoder direction active loss"], df["Context active loss"]))
    .clip(lower=0)
)

In [7]:
# Keep only directions that raise both the logit and the logprob of the correct
# token and have positive DLA. (The first mask name keeps its original spelling.)
boosted_logporb = df["Last token logprob encoded"].gt(df["Last token logprob zeroed"])
boosted_logit = df["Last token logit encoded"].gt(df["Last token logit zeroed"])
pos_dla = df["Correct token dla"].gt(0)
filtered_df = df.loc[boosted_logporb & boosted_logit & pos_dla].copy()
# Row counts: all, kept, and how many pass each individual criterion.
print(
    len(df),
    len(filtered_df),
    boosted_logporb.sum(),
    boosted_logit.sum(),
    pos_dla.sum(),
)

267 244 248 261 267


In [18]:
# Mean difference, over the filtered rows, between the loss with the encoder in
# the loop and the plain model loss.
average_trigram_loss_increase = (
    filtered_df["Encoder context active loss"]
    .sub(filtered_df["Context active loss"])
    .mean()
)
average_trigram_loss_increase

2.2359863315693667

In [8]:
# The 20 filtered rows with the highest "Interest" score.
filtered_df.sort_values(by="Interest", ascending=False).iloc[:20]

Unnamed: 0,Trigram,Direction,Correct token dla,Last token logit encoded,Last token logit zeroed,Last token logprob encoded,Last token logprob zeroed,Context active loss,Context ablated loss,Encoder context active loss,Encoder context inactive loss,Encoder direction active loss,Encoder direction inactive loss,Encoder direction zeroed loss,Direction activation context active,Direction activation context inactive,Encoder loss increase,Direction ablation loss increase,Interest,Direction zero ablation loss increase
112,ländliche,13162,0.13606,20.794052,16.936703,-2.764619,-4.716274,2.731569,4.778275,2.764603,3.191539,2.749074,3.249552,4.716255,8.016915,5.547169,0.017505,0.500479,0.482974,1.967182
242,ren besch,688,0.955422,24.012661,19.067404,-1.301255,-2.69925,1.222352,3.273979,1.301255,1.939751,1.298244,1.723461,2.699253,8.068145,4.557231,0.075892,0.425217,0.349325,1.401009
120,oft besch,688,0.955422,23.81509,19.102703,-1.701061,-3.116453,1.335787,3.935906,1.701062,2.755648,1.698619,2.200395,3.116461,7.626274,4.140427,0.362832,0.501777,0.138945,1.417842
108,im Zeichen,688,0.581965,27.42061,27.060787,-1.513794,-1.566893,2.153482,5.494951,1.513776,1.255562,1.512599,1.566598,1.566873,0.641201,0.003033,0.0,0.0,0.0,0.0
223,it gewes,14747,0.505645,23.69965,20.887188,-0.840129,-1.387269,1.675176,3.630307,0.840117,1.27633,0.837669,1.025429,1.387253,4.978337,2.830535,0.0,0.0,0.0,0.0
195,estern abend,14747,0.255666,23.408102,21.290066,-1.161237,-1.648828,1.513013,4.269652,1.161207,1.392663,1.157081,1.389264,1.648797,4.291055,1.91992,0.0,0.0,0.0,0.135784
181,dern wür,8837,0.130447,21.039242,18.953444,-2.089862,-2.560258,2.787728,4.814742,2.089807,1.810706,2.088469,2.176852,2.560201,6.857302,5.079157,0.0,0.0,0.0,0.0
180,dern wür,5594,0.459458,21.039242,18.232399,-2.089862,-2.718118,2.787728,4.814742,2.089807,1.810706,2.081852,2.088267,2.718057,8.805533,8.385005,0.0,0.0,0.0,0.0
224,it gewes,5594,0.359598,23.69965,22.724873,-0.840129,-0.863228,1.675176,3.630307,0.840117,1.27633,0.8338,0.833533,0.863216,2.95224,2.868789,0.0,0.0,0.0,0.0
179,cksichtig,15497,0.252909,21.418335,20.914911,-2.016298,-2.181461,2.730982,5.050739,2.016263,1.799329,2.01322,2.107349,2.18142,4.667688,1.636636,0.0,0.0,0.0,0.0


In [9]:
# Distribution of the correct-token DLA over the filtered directions.
px.histogram(
    filtered_df,
    x="Correct token dla",
    title="Feature directions trigram token DLA",
    histnorm="probability",
    width=800,
)

In [10]:
# Distribution of the loss increase from ablating a direction to its context-inactive activation.
px.histogram(
    filtered_df,
    x="Direction ablation loss increase",
    title="Loss increase when ablating feature directions to activation when ctxt neuron inactive",
    histnorm="probability",
    width=800,
)

In [11]:
# Distribution of the loss increase from zeroing a direction entirely.
px.histogram(
    filtered_df,
    x="Direction zero ablation loss increase",
    title="Loss increase when zero ablating feature directions",
    histnorm="probability",
    width=800,
)

In [12]:
# Distribution of the loss increase from running the model through the encoder.
px.histogram(
    filtered_df,
    x="Encoder loss increase",
    title="Loss increase when running model through encoder",
    histnorm="probability",
    width=800,
)

In [13]:
# Side-by-side distributions of the direction activation with the context
# neuron active vs. inactive.
px.histogram(
    filtered_df,
    x=[
        "Direction activation context active",
        "Direction activation context inactive",
    ],
    title="Trigram feature activation with and without ctxt neuron active",
    barmode="group",
    histnorm="probability",
    width=1200,
)

In [14]:
# Side-by-side distributions of the plain model loss and the loss with the encoder in the loop.
px.histogram(
    filtered_df,
    x=["Context active loss", "Encoder context active loss"],
    title="Trigram losses with and without using the encoder on MLP5",
    barmode="group",
    histnorm="probability",
    width=1200,
)

In [16]:
# Relative drop in direction activation when the context neuron is ablated.
# Despite the column name this is a fraction (0.25 means 25%); rows whose active
# activation is 0 come out as inf/NaN.
df["Direction activation percent decrease"] = (
    df["Direction activation context active"]
    .sub(df["Direction activation context inactive"])
    .div(df["Direction activation context active"])
)

In [17]:
# Fraction of all rows passing each criterion on its own.
print((df["Encoder loss increase"] < 0.02).sum() / len(df))
print((df["Direction activation percent decrease"] > 0.25).sum() / len(df))
print((df["Direction ablation loss increase"] > 0.05).sum() / len(df))

0.04119850187265917
0.7228464419475655
0.651685393258427


In [107]:
# Rows that satisfy all three criteria at once: activation drops by more than
# 25%, the encoder costs less than 0.02 loss, and ablating the direction costs
# more than 0.02 loss.
tmp_df = df.loc[
    df["Direction activation percent decrease"].gt(0.25)
    & df["Encoder loss increase"].lt(0.02)
    & df["Direction ablation loss increase"].gt(0.02)
]
tmp_df

Unnamed: 0,Trigram,Direction,Correct token dla,Last token logit encoded,Last token logit zeroed,Last token logprob encoded,Last token logprob zeroed,Context active loss,Context ablated loss,Encoder context active loss,...,Encoder direction active loss,Encoder direction inactive loss,Encoder direction zeroed loss,Direction activation context active,Direction activation context inactive,Encoder loss increase,Direction ablation loss increase,Interest,Direction zero ablation loss increase,Direction activation percent decrease
147,erwartet,519,0.904442,27.498091,26.287764,-0.949611,-1.277169,1.001729,3.322693,0.949634,...,0.950618,1.023155,1.277128,1.351653,0.99957,0.0,0.021426,0.021426,0.2754,0.260484
188,innergeme,3197,0.832652,21.623207,20.882915,-2.951739,-3.448537,3.080269,5.2395,2.951724,...,2.924867,3.393865,3.448493,0.550125,0.054774,0.0,0.313596,0.313596,0.368224,0.900434
291,berschrift,1926,0.61724,28.175371,27.935961,-2.009027,-2.11517,2.024949,5.178033,2.009142,...,1.993715,2.114429,2.115256,0.509732,0.00349,0.0,0.08948,0.08948,0.090307,0.993154
442,wei Abge,1926,0.671799,23.793924,23.582016,-2.537454,-2.560196,2.528239,5.055126,2.537446,...,2.537069,2.56004,2.560199,0.568847,0.00316,0.008831,0.022971,0.01414,0.02313,0.994445


In [108]:
# Reshape the selected rows to long form (one row per trigram/direction/metric)
# so each metric becomes a group of bars coloured by trigram.
melt_df = pd.melt(
    tmp_df,
    id_vars=["Trigram", "Direction"],
    value_vars=[
        "Direction activation context active",
        "Direction activation context inactive",
        "Context active loss",
        "Encoder context active loss",
        "Encoder direction inactive loss",
        "Context ablated loss",
    ],
    var_name="Type",
    value_name="Value",
)
px.bar(
    melt_df,
    x="Type",
    y="Value",
    color="Trigram",
    title="Trigram feature activation with and without ctxt neuron active",
    barmode="group",
    width=1200,
    height=800,
)