# Mini-Sprint: SAEs for Semantics of Copy-Suppression

In [None]:
import torch
import wandb

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner
from sae_training.utils import LMSparseAutoencoderSessionloader

# run = wandb.init()
# artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v36', type='model')
# artifact_dir = artifact.download()


# Load the model, the layer-10 resid_pre SAE (24576 features) and an activation loader
# from a local checkpoint (presumably the wandb artifact referenced in the commented-out lines above).
path ="artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v36/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)



In [None]:
from sae_training.train_sae_on_language_model import get_recons_loss

# Sanity-check the SAE over 10 batches. Judging by the returned names: a reconstruction
# score, the clean loss, the loss with the SAE reconstruction spliced in, and a zero-ablation baseline.
score, loss, recons_loss, zero_abl_loss = get_recons_loss(
    sparse_autoencoder, model, activations_loader, num_batches=10,
)

print(score, loss, recons_loss, zero_abl_loss)

In [None]:
from transformer_lens import utils
import torch
import einops
import plotly.express as px

import pandas as pd
from functools import partial
import numpy as np
from jaxtyping import Float
from transformer_lens import ActivationCache
import torch.nn.functional as F


from transformer_lens import utils

# Check how the model handles a standard IOI prompt whose expected answer is " Mary".
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# Put tensors on whatever device the model actually lives on instead of hard-coding "mps",
# so the notebook also runs on CUDA / CPU-only machines (and avoids device-mismatch errors).
device = next(model.parameters()).device
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    # Two prompts per template: each name takes a turn as the correct (indirect-object) answer.
    for j in range(2):
        answers.append((names[i][j], names[i][1 - j]))
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
        # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.
        prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

In [None]:
# Clean run: tokenize every prompt (with BOS) and cache all activations.
tokens = model.to_tokens(prompts, prepend_bos=True)
original_logits, cache = model.run_with_cache(tokens)


def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference (correct minus incorrect answer) at the final position.

    Args:
        logits: model output indexed as (batch, pos, d_vocab).
        answer_tokens: (batch, 2) token ids; column 0 is the correct answer,
            column 1 the incorrect one.
        per_prompt: if True, return one difference per prompt; otherwise their mean.
    """
    last_position_logits = logits[:, -1, :]
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    correct_logit, incorrect_logit = picked[:, 0], picked[:, 1]
    diffs = correct_logit - incorrect_logit
    return diffs if per_prompt else diffs.mean()

# Baseline per-prompt logit difference (correct - incorrect) from the clean run.
per_prompt_original_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True).detach()


def replacement_hook(resid, hook, pos, encoder):
    """Overwrite ``resid[:, pos]`` in place with the encoder's reconstruction.

    ``encoder`` is called on the selected slice, and element 0 of its output is used as the
    reconstruction. For the SAE in this notebook that element is ``sae_out``. ``hook`` is
    unused, but it is part of the TransformerLens hook-call signature. Returns the mutated ``resid``.
    """
    reconstruction = encoder(resid[:, pos])[0]
    resid[:, pos] = reconstruction
    return resid


# Splice the SAE reconstruction into layer 10's resid_pre at two positions:
# - the final (query) position;
# - position 4, which with BOS prepended is the second name of every template (e.g. " Mary" in
#   prompt 0), assuming "When"/"After" are single tokens.
model.reset_hooks()
try:
    hook_fn = partial(replacement_hook, encoder=sparse_autoencoder, pos = -1)
    model.add_hook(utils.get_act_name("resid_pre", 10), hook_fn, prepend=True)
    hook_fn = partial(replacement_hook, encoder=sparse_autoencoder, pos = 4)
    model.add_hook(utils.get_act_name("resid_pre", 10), hook_fn, prepend=True)
    tokens = model.to_tokens(prompts, prepend_bos=True)
    new_logits, new_cache = model.run_with_cache(tokens)
finally:
    # Always detach the patching hooks, even if the forward pass raises, so later cells
    # don't silently run on a patched model.
    model.reset_hooks()

per_prompt_new_logit_diff = logits_to_ave_logit_diff(new_logits, answer_tokens, per_prompt=True).detach()

# One bar per prompt; positive means SAE patching reduced that prompt's logit difference.
px.bar(
    y = per_prompt_original_logit_diff.cpu().numpy() - per_prompt_new_logit_diff.cpu().numpy(),
    labels={"x": "Prompt index", "y": "Original - SAE-patched logit difference"},
    title="Change in logit difference from SAE patching at resid_pre 10 (positions 4 and -1)",
)

In [None]:
# So basically, the reconstruction ... (TODO: this note was left unfinished — complete or remove it.)

In [None]:
from circuitsvis.attention import attention_patterns

# Layer-10 attention patterns (all heads) for prompt 0 on the clean run.
patterns = cache["blocks.10.attn.hook_pattern"][0]
attention_patterns(tokens=model.to_str_tokens(prompts[0]), attention=patterns)

In [None]:
# Same view for the SAE-patched run (reconstruction spliced into resid_pre 10).
patterns = new_cache["blocks.10.attn.hook_pattern"][0]
attention_patterns(tokens=model.to_str_tokens(prompts[0]), attention=patterns)

Drill down on the sentence "When John and Mary went to the shops, John gave the bag to":
- The correct answer is Mary.
- L9H9 contributes directly to this.
- L10H7 suppresses Mary (via copy-suppression):
  - L10H7 attends to the previous instance of Mary,
  - and inhibits Mary.


L10H7 attends from the final position (-1) to Mary at position 4 before inhibiting Mary.



In [None]:

# Per-prompt "logit difference direction": the residual-stream direction of the correct
# answer's unembedding minus that of the incorrect answer.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
# NOTE(review): `scaled_final_token_residual_stream` is not used anywhere later in this notebook.
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)



def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
    directions=None,
) -> torch.Tensor:
    """Project a stack of residual-stream components onto the logit-difference directions.

    First, each component is scaled by the final LayerNorm scale cached at the last position
    (``cache.apply_ln_to_stack(..., layer=-1, pos_slice=-1)``). It is then dotted with the
    per-prompt direction and averaged over the batch.

    Args:
        residual_stack: components along the leading dim(s); the last two dims are
            (batch, d_model).
        cache: activation cache that supplies the LayerNorm scales.
        directions: (batch, d_model) per-prompt (correct - incorrect) directions.
            Defaults to the notebook-global ``logit_diff_directions``.

    Returns:
        A tensor of shape ``residual_stack.shape[:-2]``: the batch-mean logit difference
        attributable to each component. It is a tensor, not a Python float.
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    # Dot over d_model and sum over batch, then divide by the batch size to get the mean
    # (rather than relying on the global prompt count matching the stack's batch).
    batch_size = scaled_residual_stack.shape[-2]
    return torch.einsum(
        "...bd,bd->...", scaled_residual_stack, directions
    ) / batch_size


from fancy_einsum import einsum

# Logit lens on the clean run: the accumulated residual stream at the final position,
# taken after every attention and MLP sub-layer (incl_mid=True) and projected onto the
# logit-difference directions. x is in layers; half steps are the mid-layer points.
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)


logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
fig = px.line(
    y=logit_lens_logit_diffs.detach().cpu().numpy(),
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)
fig.update_layout(width=800, height=400)
fig.show()

# Direct logit attribution per attention head on the clean run. Each head's final-position
# output is LayerNorm-scaled, projected onto the logit-difference directions, and reshaped
# into a (layer, head) grid.
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
fig = px.imshow(
    per_head_logit_diffs.detach().cpu().numpy(),
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
fig.update_layout(width=800, height=400)
fig.show()


# Logit lens on the SAE-patched run (reconstruction spliced into resid_pre 10 at positions
# 4 and -1). The title distinguishes this plot from the otherwise identical clean-run plot.
accumulated_residual, labels = new_cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)


logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, new_cache)
fig = px.line(
    y=logit_lens_logit_diffs.detach().cpu().numpy(),
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream (SAE-patched run)",
)
fig.update_layout(width=800, height=400)
fig.show()

# Direct logit attribution per head for the SAE-patched run. Both the head outputs and the
# LayerNorm scaling must come from `new_cache`; mixing in the clean `cache` head results
# would just re-plot the clean attribution with the wrong LN scale.
per_head_residual, labels = new_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, new_cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
fig = px.imshow(
    per_head_logit_diffs.detach().cpu().numpy(),
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head (SAE-patched run)",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
fig.update_layout(width=800, height=400)
fig.show()

## Which features cause L10H7 to attend to Mary?

Hypothesis 1: Features in resid_pre 10 that unembed as Mary will also be causally responsible for the attention to Mary.
Experiment: Rank features by how strongly their decoder direction projects onto the Mary token's unembedding, and compare this with the contribution they make to the attention score.

In [None]:
# Re-run the clean model, which rebinds `original_logits` and `cache` (and, below, `loss`).
# Then encode prompt 0's layer-10 resid_pre with the SAE.
original_logits, cache = model.run_with_cache(tokens)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(cache["blocks.10.hook_resid_pre"][0])
print(mse_loss)

# Mean L0: number of active features per position, averaged over positions.
(feature_acts > 0).float().sum(1).mean()

In [None]:
# Dense feature-by-feature QK matrix for L10H7, with SAE decoder directions on both sides.
# This materializes a d_sae x d_sae matrix, which is very large.
sae_qk_circuit = (
    sparse_autoencoder.W_dec @ model.blocks[10].attn.QK[7] @ sparse_autoencoder.W_dec.T
).AB.detach().cpu().numpy()
print(sae_qk_circuit.shape)

In [None]:
# Sample entries uniformly from the whole matrix. The first 10,000 flattened entries all lie
# in row 0 (d_sae = 24576 here, per the checkpoint name), i.e. one feature's interactions,
# which is not representative. Re-check the "not very sparse" conclusion against this sample.
_qk_flat = sae_qk_circuit.ravel()
_qk_sample_idx = np.random.default_rng(0).integers(0, _qk_flat.size, size=10_000)
px.histogram(_qk_flat[_qk_sample_idx]) # not very sparse?

In [None]:
# Shape of the per-row norms of the encoder weight (layout sanity check).
sparse_autoencoder.W_enc.norm(dim=1).shape

In [None]:
# Mean L2 norm of the decoder rows (feature directions); ~1 if they are unit-normalized.
sparse_autoencoder.W_dec.norm(dim=1).mean()

In [None]:
example = 0
# SAE feature activations at the key position (4: the earlier name token) and at the
# query position (-1: the final token) for prompt `example`.
key_resid_pre_feature_acts = sparse_autoencoder(cache["blocks.10.hook_resid_pre"][example,4])[1]
query_resid_pre_feature_acts = sparse_autoencoder(cache["blocks.10.hook_resid_pre"][example,-1])[1]

# Scale each decoder row by its activation: row i is feature i's contribution to the residual.
query_feature_contribs = query_resid_pre_feature_acts[:, None] * sparse_autoencoder.W_dec
key_feature_contribs = key_resid_pre_feature_acts[:, None] * sparse_autoencoder.W_dec
# TransformerLens defines QK = W_Q @ W_K^T, so a score is x_query @ QK @ x_key^T: the query
# side goes on the LEFT. Each decoder matrix enters exactly once.
# Caveat: this skips ln1, the b_Q/b_K biases and the SAE decoder bias, so it is approximate.
sae_qk_circuit_instance = query_feature_contribs @ model.blocks[10].attn.QK[7] @ key_feature_contribs.T

# Materialize the dense d_sae x d_sae matrix once (it is multiple GB) and free it afterwards.
qk_dense = sae_qk_circuit_instance.AB.detach()
n_positive_pairs = (qk_dense > 0).sum()
values, indices = torch.topk(qk_dense.flatten(), 25)
del qk_dense

d_enc = sparse_autoencoder.cfg.d_sae
# Rows are query-side features ("start"), columns are key-side features ("end").
start_topk_ind = torch.div(indices, d_enc, rounding_mode="floor")
end_topk_ind = (indices % d_enc)

# Number of (query feature, key feature) pairs that raise the attention score.
n_positive_pairs


In [None]:
# Top-25 feature pairs from the instance-level QK circuit: "start" is the row feature and
# "end" the column feature. Show them as a bar chart, then as a start x end heatmap.
df = pd.DataFrame(
    {
        "start": start_topk_ind.detach().cpu(),
        "end": end_topk_ind.detach().cpu(),
        "value": values.detach().cpu(),
    }
).assign(
    start=lambda frame: frame["start"].map("start_{}".format),
    end=lambda frame: frame["end"].map("end_{}".format),
)
topk_pairs_bar = px.bar(
    df,
    y="value",
    color="value",
    hover_data=["start", "end"],
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
)
topk_pairs_bar.show()

# Pivot to start x end; pairs outside the top-25 become 0.
pivot_df = df.pivot(index="start", columns="end", values="value").fillna(0)

# Tick labels come straight from the pivot's index/columns.
topk_pairs_heatmap = px.imshow(
    pivot_df.values,
    labels={"x": "End", "y": "Start"},
    x=pivot_df.columns,
    y=pivot_df.index,
    title="SAE QK Circuit",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
topk_pairs_heatmap.show()

In [None]:
def plot_feature_embed_bar(feature_id, sparse_autoencoder, feature_name = ""):
    """Bar-plot the 20 tokens whose normalized embeddings best align with an SAE feature.

    Each score is the dot product of the feature's decoder direction
    (``sparse_autoencoder.W_dec[feature_id]``) with a unit-norm token embedding (row of ``W_E``).

    Args:
        feature_id: index of the SAE feature (a row of ``W_dec``).
        sparse_autoencoder: SAE that provides ``W_dec``.
        feature_name: column name and x-axis label for the scores.

    Reads the notebook-global ``model`` for ``W_E`` and its tokenizer.
    """
    # Normalize each token embedding so the ranking reflects direction, not embedding norm.
    norm_embed = model.W_E / model.W_E.norm(dim=1)[:, None]
    feature_embed_scores = sparse_autoencoder.W_dec[feature_id] @ norm_embed.T

    # One label per embedding row, derived from W_E instead of a hard-coded vocab size.
    vocab_size = norm_embed.shape[0]
    feature_embed_df = pd.DataFrame(
        feature_embed_scores.detach().cpu().numpy(),
        columns = [feature_name],
        index = [model.tokenizer.decode(i) for i in range(vocab_size)]
    )

    feature_embed_df = feature_embed_df.sort_values(feature_name, ascending=False).reset_index().rename(columns={'index': 'token'})
    fig = px.bar(
        feature_embed_df.head(20).sort_values(feature_name, ascending=True),
        y = 'token', x = feature_name, orientation='h', color = feature_name, hover_data=[feature_name],
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
    )

    fig.update_layout(
        width=500,
        height=600,
    )

    fig.show()
    
def plot_feature_unembed_bar(feature_id, sparse_autoencoder, feature_name = ""):
    """Bar-plot the 20 tokens whose normalized unembeddings best align with an SAE feature.

    Each score is the dot product of ``sparse_autoencoder.W_dec[feature_id]`` with a
    unit-norm column of ``W_U`` (one column per token).

    Args:
        feature_id: index of the SAE feature (a row of ``W_dec``).
        sparse_autoencoder: SAE that provides ``W_dec``.
        feature_name: column name and x-axis label for the scores.

    Reads the notebook-global ``model`` for ``W_U`` and its tokenizer.
    """
    # W_U is (d_model, d_vocab): normalize each token's column. `[None, :]` is an explicit
    # row vector that broadcasts over d_model.
    norm_unembed = model.W_U / model.W_U.norm(dim=0)[None, :]
    feature_unembed = sparse_autoencoder.W_dec[feature_id] @ norm_unembed

    # One label per unembedding column, derived from W_U instead of a hard-coded vocab size.
    vocab_size = norm_unembed.shape[1]
    feature_unembed_df = pd.DataFrame(
        feature_unembed.detach().cpu().numpy(),
        columns = [feature_name],
        index = [model.tokenizer.decode(i) for i in range(vocab_size)]
    )

    feature_unembed_df = feature_unembed_df.sort_values(feature_name, ascending=False).reset_index().rename(columns={'index': 'token'})
    fig = px.bar(
        feature_unembed_df.head(20).sort_values(feature_name, ascending=True),
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        y = 'token', x = feature_name, orientation='h', color = feature_name, hover_data=[feature_name],
    )

    fig.update_layout(
        width=500,
        height=600,
    )

    fig.show()


# Top tokens by normalized-unembedding alignment for SAE feature 18032 (label now matches the feature plotted).
plot_feature_unembed_bar(18032, sparse_autoencoder, "feature_18032_unembed")

In [None]:


# Tokens whose normalized embeddings align most with SAE feature 1's decoder direction.
# NOTE(review): the label "resid stream change" looks copied from another plot; this is an embedding alignment.
plot_feature_embed_bar(1, sparse_autoencoder, feature_name = "resid stream change")

In [None]:


# topk by change in resid stream (downward)
# NOTE(review): features 1 and 2 are not selected by any top-k computation visible here; confirm the intended ids.
plot_feature_unembed_bar(1, sparse_autoencoder, feature_name = "resid stream change")
plot_feature_unembed_bar(2, sparse_autoencoder, feature_name = "resid stream change")

In [None]:
# Mean of prompt 0's reconstruction after applying layer 10's LayerNorm scaling.
# NOTE(review): `x_reconstruct` is first assigned in a later cell, so this fails on a fresh top-to-bottom run.
cache.apply_ln_to_stack(x_reconstruct[0],layer=10).mean()

In [None]:

# L10H7 pre-softmax attention scores for prompt 0: clean run (facet 0) vs SAE-patched run (facet 1).
# NOTE(review): causally masked entries may be -inf in hook_attn_scores; confirm before comparing.
attn_scores_original = cache["blocks.10.attn.hook_attn_scores"][0,7].detach().cpu()
attn_scores_new = new_cache["blocks.10.attn.hook_attn_scores"][0,7].detach().cpu()
px.imshow(
    torch.stack([attn_scores_original, attn_scores_new]),
    facet_col = 0,
    color_continuous_midpoint=0, 
    color_continuous_scale="RdBu"
)

In [None]:
# Clean vs SAE-patched L10H7 scores, element by element (points on the diagonal are unchanged).
px.scatter(x = attn_scores_original.flatten(), y = attn_scores_new.flatten())

In [None]:
# Top-left 1000 x 1000 corner of the dense feature-by-feature QK matrix.
px.imshow(sae_qk_circuit[:1000,:1000], title="SAE QK Circuit", color_continuous_midpoint=0, color_continuous_scale="RdBu")

In [None]:
# `sae_qk_circuit` is already a NumPy array (it has no `.cpu()`), and the full d_sae x d_sae
# matrix is far too large to render. Show a strided overview of roughly 1000 x 1000 instead.
qk_overview_stride = max(1, sae_qk_circuit.shape[0] // 1000)
px.imshow(
    sae_qk_circuit[::qk_overview_stride, ::qk_overview_stride],
    title=f"SAE QK Circuit (every {qk_overview_stride}th feature)",
)

In [None]:
# Layer-10 head-7 key and query weights (d_model x d_head in TransformerLens's layout),
# applied to the SAE reconstruction. These projections skip ln1 and the b_K / b_Q biases.
# NOTE(review): `x_reconstruct` is first assigned in a later cell; run that cell first on a fresh kernel.
L10H7_W_K = model.W_K[10,7]
L10H7_W_Q = model.W_Q[10,7]  # query weights, not W_K
print(L10H7_W_K.shape)
print(L10H7_W_Q.shape)
reconstructed_keys = x_reconstruct @ L10H7_W_K
reconstructed_queries = x_reconstruct @ L10H7_W_Q
print(reconstructed_keys.shape)
print(reconstructed_queries.shape)

In [None]:
# Display the model's module structure.
model

In [None]:
# List the activation names stored in the clean cache.
cache.keys()

In [None]:
# Encode layer-10 resid_pre and resid_mid (i.e. before and after layer 10's attention) with an SAE.
# NOTE(review): `encoder_resid_pre_10` is not defined anywhere in this notebook. It unpacks as
# (loss, x_reconstruct, acts, ...), whereas `sparse_autoencoder` returns (sae_out, feature_acts, loss, ...).
loss, x_reconstruct, acts, l2_loss, l1_loss = encoder_resid_pre_10(cache["blocks.10.hook_resid_pre"])
_, _, acts_mid, _, _ = encoder_resid_pre_10(cache["blocks.10.hook_resid_mid"])
acts = acts.detach().cpu()
acts_mid = acts_mid.detach().cpu()
# Prompt 0, final position: features active at resid_pre but silent at resid_mid.
fired_in_acts_and_not_in_acts_mid = (acts[0,-1] > 0) & (acts_mid[0,-1] == 0)
fired_in_acts_and_not_in_acts_mid.squeeze()

In [None]:
# Heatmaps of how many SAE features fire at each (prompt, position), skipping position 0 (BOS):
# first at resid_pre (`acts`), then at resid_mid (`acts_mid`).
for layer_feature_acts in (acts, acts_mid):
    active_counts = (layer_feature_acts[:, 1:] > 0).sum(-1).float().detach().cpu()
    px.imshow(active_counts, color_continuous_midpoint=0, color_continuous_scale="RdBu").show()

In [None]:
# Features active at the final position of prompts 0 and 1, and how many the two share.
# NOTE(review): prompt 0's correct answer is " Mary" (answers[0]); "john"/"mary" here may refer
# to the repeated subject rather than the prediction.
pred_john_features = torch.nonzero(acts[0, -1] > 0)
pred_mary_features = torch.nonzero(acts[1, -1] > 0)
common_features = set(pred_john_features.flatten().tolist()) & set(pred_mary_features.flatten().tolist())
print(len(common_features))

In [None]:
girls_names = [" Mary", " Magdalene", " Maria", " Marie", " Marian", " Marianne", " Mara", " Maura", " Marla", " Marta", " Maire", " Maree"]
boys_names = [" John", " Joseph", " Jon", " Jonathan", " Johan", " Johannes", " Sean", " Shaun", " Shane", " Ivan", " Jan"]


def _stack_first_tokens(name_list):
    """Tokenize each name without BOS, print its tokens, and stack the FIRST token of each.

    Multi-token names are truncated to their first token; the printed shape shows which ones.
    """
    first_tokens = []
    for name in name_list:
        name_tokens = model.to_tokens(name, prepend_bos=False)[0]
        print(name, name_tokens, name_tokens.shape)
        first_tokens.append(name_tokens[0])
    return torch.stack(first_tokens, dim=0).to(device)


boys_name_tokens = _stack_first_tokens(boys_names)
print(boys_name_tokens)

girls_names_tokens = _stack_first_tokens(girls_names)
print(girls_names_tokens)
print(model.tokenizer.decode(boys_name_tokens))
print(model.tokenizer.decode(girls_names_tokens))

In [None]:
# Project every SAE decoder direction onto each boy's-name unembedding and compare with how much
# that feature's activation changed across layer 10's attention (prompt 0, final position).
# NOTE(review): `encoder_resid_pre_10` is not defined anywhere in this notebook; confirm which SAE is meant.
feature_proj = encoder_resid_pre_10.W_dec @ model.W_U[:, boys_name_tokens]
feature_proj = feature_proj.detach().cpu().numpy()

feature_proj_df = pd.DataFrame(feature_proj, columns = boys_name_tokens.detach().cpu().numpy())
feature_proj_df_long = feature_proj_df.reset_index().melt(id_vars='index')
feature_proj_df_long.columns = ['feature', 'name', 'value']
# make feature categorical
feature_proj_df_long.feature = feature_proj_df_long.feature.astype(str)
feature_proj_df_long.name = feature_proj_df_long.name.apply(lambda x: model.tokenizer.decode(x))
# px.strip(feature_proj_df_long, x= "name", y = "value", color = "name", hover_data=["feature"]).show()

# most negative change resid pre to resid mid 4795, 3501, 5337, 920
# most positive change resid pre to resid mid 5680, 4470, 1057, 369
# px.bar(feature_proj_df_long[feature_proj_df_long.feature.map(int).isin([4795, 3501, 5337, 920])], y = "value", facet_col= "feature", x = "name", hover_data=["feature"]).show()
# px.bar(feature_proj_df_long[feature_proj_df_long.feature.map(int).isin([5680, 4470, 1057, 369])], y = "value", facet_col= "feature", x = "name", hover_data=["feature"]).show()

# change = resid_pre activation - resid_mid activation (positive: the feature weakened).
resid_stream_change = pd.Series(acts[0,-1] - acts_mid[0,-1])
resid_stream_change.index = resid_stream_change.index.astype(str)
feature_proj_df_long = feature_proj_df_long.merge(resid_stream_change.rename('change'), left_on='feature', right_index=True)


fig = px.scatter(
    feature_proj_df_long,
    x = "value",
    y  = "change",
    facet_col = "name",
    # Plotly Express does not interpolate "{name}" in labels; the facet titles already show each name.
    labels = {'value': 'feature projection onto name unembedding',
              'change': 'Residual stream change'},
    hover_data=["feature"])
fig.show()

In [None]:
# Same analysis as for boys' names, but projecting onto each girl's-name unembedding.
# NOTE(review): `encoder_resid_pre_10` is not defined anywhere in this notebook; confirm which SAE is meant.
feature_proj = encoder_resid_pre_10.W_dec @ model.W_U[:, girls_names_tokens]
feature_proj = feature_proj.detach().cpu().numpy()

feature_proj_df = pd.DataFrame(feature_proj, columns = girls_names_tokens.detach().cpu().numpy())
feature_proj_df_long = feature_proj_df.reset_index().melt(id_vars='index')
feature_proj_df_long.columns = ['feature', 'name', 'value']
# make feature categorical
feature_proj_df_long.feature = feature_proj_df_long.feature.astype(str)
feature_proj_df_long.name = feature_proj_df_long.name.apply(lambda x: model.tokenizer.decode(x))
# px.strip(feature_proj_df_long, x= "name", y = "value", color = "name", hover_data=["feature"]).show()

# most negative change resid pre to resid mid 4795, 3501, 5337, 920
# most positive change resid pre to resid mid 5680, 4470, 1057, 369
# px.bar(feature_proj_df_long[feature_proj_df_long.feature.map(int).isin([4795, 3501, 5337, 920])], y = "value", facet_col= "feature", x = "name", hover_data=["feature"]).show()
# px.bar(feature_proj_df_long[feature_proj_df_long.feature.map(int).isin([5680, 4470, 1057, 369])], y = "value", facet_col= "feature", x = "name", hover_data=["feature"]).show()

# change = resid_pre activation - resid_mid activation (positive: the feature weakened).
resid_stream_change = pd.Series(acts[0,-1] - acts_mid[0,-1])
resid_stream_change.index = resid_stream_change.index.astype(str)
feature_proj_df_long = feature_proj_df_long.merge(resid_stream_change.rename('change'), left_on='feature', right_index=True)


fig = px.scatter(
    feature_proj_df_long,
    x = "value",
    y  = "change",
    facet_col = "name",
    # Plotly Express does not interpolate "{name}" in labels; the facet titles already show each name.
    labels = {'value': 'feature projection onto name unembedding',
              'change': 'Residual stream change'},
    hover_data=["feature"])
fig.show()

In [None]:
# The 10 features with the most negative (resid_pre - resid_mid) activation difference at prompt 0's
# final position, i.e. the largest increase across layer 10's attention. Bars show (mid - pre);
# bar text gives the feature index.
values = torch.topk(acts[0,-1] - acts_mid[0,-1], k = 10, largest=False)
indices = values.indices

px.bar(
    x=values.values.detach().cpu().numpy() * -1,
    text=indices.detach().cpu().numpy(),
    orientation="h",
)

In [None]:
def plot_feature_unembed_bar(feature_id, feature_name):
    """Bar-plot the 20 tokens whose (unnormalized) unembedding has the largest dot product with a feature.

    NOTE: this redefines and shadows the earlier three-argument ``plot_feature_unembed_bar``.
    Unlike that version, it takes no SAE argument, reads ``encoder_resid_pre_10.W_dec`` from
    the notebook globals, and does not normalize ``W_U``.

    Args:
        feature_id: index of the SAE feature (a row of ``W_dec``).
        feature_name: column name and x-axis label for the scores.
    """
    feature_unembed = encoder_resid_pre_10.W_dec[feature_id] @ model.W_U

    # One label per unembedding column, derived from W_U instead of a hard-coded vocab size.
    vocab_size = model.W_U.shape[1]
    feature_unembed_df = pd.DataFrame(
        feature_unembed.detach().cpu().numpy(),
        columns = [feature_name],
        index = [model.tokenizer.decode(i) for i in range(vocab_size)]
    )

    feature_unembed_df = feature_unembed_df.sort_values(feature_name, ascending=False).reset_index().rename(columns={'index': 'token'})
    fig = px.bar(feature_unembed_df.head(20).sort_values(feature_name, ascending=True),
            y = 'token', x = feature_name, orientation='h', color = feature_name, hover_data=[feature_name])

    fig.update_layout(
        width=500,
        height=600,
    )

    fig.show()


# Unembedding views of hand-labelled features, using the two-argument plot_feature_unembed_bar.
# NOTE(review): these feature ids and labels come from `encoder_resid_pre_10`, which is not defined in this notebook.
# topk by change in resid stream (downward)
plot_feature_unembed_bar(4795, 'Bible Feature')
# plot_feature_unembed_bar(3501, '??')
# plot_feature_unembed_bar(920, "'their' family feature")
# plot_feature_unembed_bar(5337, "??")
# plot_feature_unembed_bar(4793, 'American Last names?')
# plot_feature_unembed_bar(3213, 'second token in multitoken word')

# plot_feature_unembed_bar(2434, 'Boy Names Feature')
# plot_feature_unembed_bar(4793, 'American Last names?')
# # plot_feature_unembed_bar(3213, 'second token in multitoken word')
# plot_feature_unembed_bar(4774, 'second token in multitoken word')
# plot_feature_unembed_bar(3732, 'Israel/Palestine Feature')
# plot_feature_unembed_bar(358, 'Them Feature')

# plot_feature_unembed_bar(2440, 'Names, most often women, capitalized.')

In [None]:
# Layer-10 per-head z (head outputs before W_O) for prompt 0, final position: (n_heads, d_head).
cache["blocks.10.attn.hook_z"][0, -1].shape

In [None]:
# Layer-10 output weights: (n_heads, d_head, d_model).
model.W_O[10].shape

In [None]:

# What each layer-10 head writes to the residual stream at prompt 0's final position: (n_heads, d_model).
L10_resid_stream_contribution = einsum("head d_head,  head d_head d_model -> head d_model", cache["blocks.10.attn.hook_z"][0, -1], model.W_O[10]).detach()
L10_resid_stream_contribution.shape

In [None]:
# NOTE(review): `L10H7_resid_stream_contribution` is first assigned in the next cell; this fails on a fresh kernel.
L10H7_resid_stream_contribution.shape

In [None]:
tokens = model.to_tokens(prompts, prepend_bos=True)

# NOTE(review): `encoder_resid_pre_10` is not defined anywhere in this notebook. It unpacks as
# (loss, x_reconstruct, acts, ...), whereas `sparse_autoencoder` returns (sae_out, feature_acts, loss, ...).
loss, x_reconstruct, acts, l2_loss, l1_loss = encoder_resid_pre_10(cache["blocks.10.hook_resid_pre"])
_, _, acts_mid, _, _ = encoder_resid_pre_10(cache["blocks.10.hook_resid_mid"])
acts = acts.detach().cpu()
acts_mid = acts_mid.detach().cpu()
# Prompt 0, final position: features active at resid_pre but silent at resid_mid, and the raw
# activation change (pre - mid) across layer 10's attention.
fired_in_acts_and_not_in_acts_mid = (acts[0,-1] > 0) & (acts_mid[0,-1] == 0)
difference_in_acts_and_acts_mid = acts[0,-1] - acts_mid[0,-1]

# Look up the " Mary" token id instead of hard-coding it.
mary_token_id = model.to_single_token(" Mary")

# L10H7's final-position output projected onto every SAE decoder direction -> (d_sae,).
L10H7_resid_stream_contribution = cache["blocks.10.attn.hook_z"][0,-1,7] @ model.W_O[10,7] @ encoder_resid_pre_10.W_dec.T
# Final-position output of EVERY layer-10 head -> (n_heads, d_model), despite the name.
feature_contributions_L10H7 = einsum("head d_head,  head d_head d_model -> head d_model", cache["blocks.10.attn.hook_z"][0, -1], model.W_O[10]).detach().cpu()
# How strongly each feature's decoder direction writes to the " Mary" logit -> (d_sae,).
mary_unembed_scores = encoder_resid_pre_10.W_dec @ model.W_U[:, mary_token_id]

df = pd.DataFrame(
    {
        'mary_unembed_scores': mary_unembed_scores.detach().cpu(),
        'fired_in_acts_and_not_in_acts_mid': fired_in_acts_and_not_in_acts_mid.detach().cpu(),
        'difference_in_acts_and_acts_mid': difference_in_acts_and_acts_mid.detach().cpu(),
        '_L10H7_resid_stream_contribution': L10H7_resid_stream_contribution.detach().cpu()
    }

)
df = df.reset_index()
df['FeatureId']=df['index']

px.scatter(
    df, x = "_L10H7_resid_stream_contribution", y = "mary_unembed_scores", color = "difference_in_acts_and_acts_mid", hover_data=["FeatureId"],
    color_continuous_scale="RdBu", color_continuous_midpoint=0)



In [None]:
# Per-feature product of L10H7's write along each decoder direction and that direction's " Mary"
# unembedding score. The 10 most negative entries are the features through which L10H7 most
# suppresses Mary (approximate: no LayerNorm scaling is applied).
feature_contributions_to_mary = L10H7_resid_stream_contribution.cpu()* mary_unembed_scores.cpu()
torch.topk(feature_contributions_to_mary,10, largest=False)

In [None]:

# Project each layer-10 head's final-position output onto every SAE decoder direction:
# (n_heads, d_model) @ (d_model, d_sae) -> (n_heads, d_sae).
all_head_proj = feature_contributions_L10H7 @ encoder_resid_pre_10.W_dec.T.cpu()
df2 = pd.DataFrame(all_head_proj.detach().T, columns = [f'L10H{h}' for h in range(all_head_proj.shape[0])])
# Drop head columns left by a previous run of this cell, so that re-running it doesn't make
# merge produce suffixed duplicates (L10H0_x / L10H0_y) that the melt below would miss.
df = df.drop(columns=[col for col in df.columns if col.startswith('L10H')]).merge(df2, left_index=True, right_index=True)
long_df = df.melt(id_vars=['FeatureId','mary_unembed_scores', 'fired_in_acts_and_not_in_acts_mid', 'difference_in_acts_and_acts_mid'],
                  value_vars=[col for col in df.columns if col.startswith('L10H')],
                  var_name='Head',
                  value_name='Score')
# Keep only features whose activation changed across layer 10's attention.
long_df = long_df[long_df.difference_in_acts_and_acts_mid != 0]
long_df.head()
# df_long = pd.wide_to_long(df,stubnames = 'L10', i=['FeatureId', 'mary_unembed_scores', 'fired_in_acts_and_not_in_acts_mid', 'difference_in_acts_and_acts_mid'],  j='value')
# df_long.head()#.head()


# One animation frame per head: head-output projection vs the feature's Mary-unembedding score.
fig = px.scatter( long_df.sort_values(["Head","difference_in_acts_and_acts_mid"]), x = 'Score', y = 'mary_unembed_scores', color ='difference_in_acts_and_acts_mid',
                 animation_frame="Head",
                #  facet_col="Head",
                #  facet_col_wrap=4,
                 color_continuous_scale='RdBu', color_continuous_midpoint=0,
                 hover_data=['FeatureId'],
                #  template='plotly_dark',
)
fig.update_layout(
    width=1000,
    height=800,
)
fig.show()
# px.histogram(feature_contributions_L10H7.detach().cpu()).show()
# torch.topk(feature_contributions_L10H7, 20, largest=False)
# head_contributions = pd.Series(feature_contributions_L10H7.detach().cpu())
# head_contributions.index = head_contributions.index.astype(str)
# # feature_proj_df_long = feature_proj_df_long.merge(resid_stream_change.rename('change'), left_on='feature', right_index=True)
# head_contributions.head()

In [None]:
# Decompose prompt 0's final-position residual (clean run) into per-component contributions
# (neurons not expanded) and project each component directly onto the " Mary" unembedding.
# Assumes decomp is (components, batch, pos, d_model). No final-LayerNorm scaling is applied,
# so these are unscaled direct contributions.
decomp, labels = cache.get_full_resid_decomposition(expand_neurons=False, return_labels=True)
mary_proj = decomp[:,0,-1] @ model.W_U[:, model.to_single_token(" Mary")]
mary_proj_df = pd.DataFrame(mary_proj.detach().cpu().numpy(), index = labels).sort_values(0, ascending=True)
# Only a cell's last expression is auto-displayed, so show both tables explicitly.
display(mary_proj_df.head(20))  # most Mary-suppressing components; this is where we let the neuron fire, lets replace it.
display(mary_proj_df.tail(20))  # most Mary-promoting components

In [None]:
# Shape of the SAE feature activations (presumably (batch, pos, d_sae); confirm).
acts.shape

In [None]:
 # Shape of one decoder row. NOTE(review): 5335 is also used as the " Mary" token id elsewhere; confirm it is meant as a feature index.
 encoder_resid_pre_10.W_dec[5335].shape

In [None]:
# NOTE(review): 5335 is also used as the " Mary" token id elsewhere in this notebook; confirm it is meant as an SAE feature index.
feature = 5335
tokens = model.to_tokens(prompts, prepend_bos=True)
# Per-(prompt, position) activation of `feature` to remove; only prompt 0's final position is set.
# Use a float tensor (an int zeros_like(tokens) would truncate the activation to an integer).
replacement_for_feature = torch.zeros(tokens.shape, dtype=acts.dtype, device=tokens.device)
# `acts` is indexed as (batch, pos, d_sae), so it takes exactly three indices.
replacement_for_feature[0, -1] = acts[0, -1, feature]
# (batch, pos, 1) * (d_model,) -> (batch, pos, d_model): this feature's contribution to the reconstruction.
out_diff = replacement_for_feature[:, :, None] * encoder_resid_pre_10.W_dec[feature].detach()

In [None]:

def remove_bible_feature(mlp_out, hook):
    """Hook that subtracts a precomputed vector from the MLP output in place and returns it.

    It is attached to ``blocks.0.hook_mlp_out`` below. ``hook`` is unused.

    NOTE(review): ``mlp_out0_diff`` is not defined anywhere in this notebook (the previous cell
    builds ``out_diff``), so as written this raises NameError when the hook fires.
    """
    mlp_out[:, :] -= mlp_out0_diff
    return mlp_out
    
model.reset_hooks()
model.blocks[0].hook_mlp_out.add_hook(remove_bible_feature)
try:
    # Only the early MLP outputs are needed, so stop the forward pass at layer 2.
    # NOTE: this rebinds `new_cache`, replacing the SAE-patched cache built earlier.
    _, new_cache = model.run_with_cache(prompts, stop_at_layer=2, names_filter=lambda x: "mlp_out" in x)
finally:
    # Detach the ablation hook even if the run fails (e.g. the hook raises), so later
    # model calls are not silently ablated.
    model.reset_hooks()

with torch.no_grad():
    # NOTE(review): this encodes the clean `cache`, not the ablated run above, so the ablation
    # does not affect these activations. Confirm intent.
    loss, x_reconstruct, acts, l2_loss, l1_loss = encoder_resid_pre_10(cache["blocks.10.hook_resid_pre"])