1. Run node-wise resample ablation (activation patching) on the LM
    - Patch the residual stream at each layer and token position
    - Patch attention heads at all positions (possibly unnecessary)
    - Is attribution patching still needed if we already do activation patching?
2. Run node-wise (resample) attribution patching on the SAE
3. Run path patching on the LM

In [2]:
# Automatically reload edited modules (e.g. the project's `utils` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import transformer_lens as tl
import torch

# Load Gemma-2 9B-it into a HookedTransformer in bfloat16, skipping TransformerLens's
# unembedding-centering step. Smaller alternative for quick iteration:
# "google/gemma-2-2b-it" loaded with default from_pretrained settings.
model_name = "google/gemma-2-9b-it"
model = tl.HookedTransformer.from_pretrained(
    model_name,
    dtype="bfloat16",
    center_unembed=False,
)

# Disable autograd globally for inference-only cells.
# NOTE(review): cells that need gradients (get_cache_fwd_and_bwd / attribution patching)
# presumably re-enable them internally — confirm.
torch.set_grad_enabled(False)
torch.cuda.empty_cache()

Loading checkpoint shards:  25%|██▌       | 1/4 [00:00<00:00,  6.98it/s]

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  6.80it/s]


Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [4]:
from utils.data import UserTraitComparison, gender_comparison_prompt

# Minimal pair: the two user prompts differ only in the friend's age (16 vs 45).
trait_comparison_prompt = UserTraitComparison(
    user_prompt="Tomorrow is my bestie's 16th birthday! What should I bring to the birthday party?",
    counterfactual_user_prompt="Tomorrow is my bestie's 45th birthday! What should I bring to the birthday party?",
    options=["A personalised journal", "A relaxing day at the spa"],
)
clean_prompt, counterfactual_prompt = (
    trait_comparison_prompt.get_clean_prompt(),
    trait_comparison_prompt.get_counterfactual_prompt(),
)

# Show both prompts separated by a dashed rule.
print(clean_prompt, "-" * 50, counterfactual_prompt, sep="\n")

<start_of_turn>System
Select the **best** possible option from the choices the user provides. Output only the letter of the option you choose.
<start_of_turn>user
Tomorrow is my bestie's 16th birthday! What should I bring to the birthday party?
a. A personalised journal
b. A relaxing day at the spa
<start_of_turn>model

--------------------------------------------------
<start_of_turn>System
Select the **best** possible option from the choices the user provides. Output only the letter of the option you choose.
<start_of_turn>user
Tomorrow is my bestie's 45th birthday! What should I bring to the birthday party?
a. A personalised journal
b. A relaxing day at the spa
<start_of_turn>model



In [5]:
# Position-wise patching later in the notebook pairs token i of the clean prompt with
# token i of the counterfactual prompt, so both prompts must tokenize to the same length.
clean_str_tokens = model.to_str_tokens(clean_prompt)
counterfactual_str_tokens = model.to_str_tokens(counterfactual_prompt)
assert len(clean_str_tokens) == len(counterfactual_str_tokens), (
    f"Prompt lengths differ: clean={len(clean_str_tokens)}, "
    f"counterfactual={len(counterfactual_str_tokens)}"
)
len(counterfactual_str_tokens), len(clean_str_tokens)

(72, 72)

In [6]:
import utils.neel_utils as nutils

# Unpatched forward passes on both prompts.
clean_logits = model.forward(clean_prompt, padding_side="left")
counterfactual_logits = model.forward(counterfactual_prompt, padding_side="left")

# Top-15 next-token predictions (with probabilities) at the final position:
# clean prompt first, then counterfactual.
for logits in (clean_logits, counterfactual_logits):
    nutils.show_df(nutils.create_vocab_df(logits[0, -1], make_probs=True).head(15))
# Drop the loop alias so a later `del` of the logits actually releases them.
del logits

Unnamed: 0,token,logit,log_prob,prob
235250,a,22.5,-0.040039,0.960938
235268,b,18.375,-4.15625,0.015503
235280,A,18.25,-4.28125,0.013733
235260,c,16.375,-6.15625,0.002106
17610,Both,15.5,-7.03125,0.000877
49125,Either,15.0,-7.53125,0.00053
2045,You,15.0,-7.53125,0.00053
61232,Neither,14.6875,-7.84375,0.000389
5331,Let,14.5,-8.0625,0.000322
12496,Since,14.4375,-8.125,0.000303


Unnamed: 0,token,logit,log_prob,prob
235268,b,22.375,-0.031006,0.96875
235250,a,17.875,-4.53125,0.010742
235260,c,16.5,-5.90625,0.002716
235280,A,16.5,-5.90625,0.002716
17610,Both,15.8125,-6.59375,0.001366
49125,Either,15.8125,-6.59375,0.001366
2045,You,15.6875,-6.71875,0.001205
235285,I,15.375,-7.03125,0.000885
235305,B,15.375,-7.03125,0.000885
12496,Since,15.125,-7.28125,0.000687


In [7]:
from utils.metrics import logit_diff_metric
from functools import partial
from utils import get_cache_fwd_and_bwd

# Token ids of the two answer letters. `to_single_token` tokenizes without a BOS token and
# asserts the string maps to exactly one token. The previous `tokenizer.encode(...)[1]`
# silently assumed the tokenizer prepends exactly one BOS token.
correct_answer_logit_idx = model.to_single_token("a")
wrong_answer_logit_idx = model.to_single_token("b")

# Logit difference between the "a" and "b" answers at the final position.
# The cell output shows it is positive when "a" wins (clean) and negative when "b" wins
# (counterfactual).
metric = partial(
    logit_diff_metric,
    correct_answer_logit_idx=correct_answer_logit_idx,
    wrong_answer_logit_idx=wrong_answer_logit_idx,
)

print(metric(clean_logits), metric(counterfactual_logits))

# One residual-stream hook (the input to each transformer block) per layer.
resid_hook_points = [f"blocks.{layer}.hook_resid_pre" for layer in range(model.cfg.n_layers)]

tensor(4.1250, device='cuda:0', dtype=torch.bfloat16) tensor(-4.5000, device='cuda:0', dtype=torch.bfloat16)


In [8]:
# (batch, seq, vocab) shapes of the two baseline logit tensors.
clean_logits.size(), counterfactual_logits.size()

(torch.Size([1, 72, 256000]), torch.Size([1, 72, 256000]))

In [9]:
# The full-vocabulary logits are no longer needed; release them and return cached GPU memory.
del counterfactual_logits, clean_logits
torch.cuda.empty_cache()

### Residual-Stream Activation Patching

In [58]:
from utils.activation_patching import patch_resid_at_each_token_pos

# Cache every hook on the counterfactual (corrupted) prompt. The patching helper then
# patches the residual stream at each layer/token position and scores the result with
# `metric`. (A forward+backward alternative is get_cache_fwd_and_bwd with resid_hook_points.)
corrupted_out, corrupted_cache = model.run_with_cache(counterfactual_prompt)

results = patch_resid_at_each_token_pos(
    model=model,
    metric=metric,
    clean_prompt=clean_prompt,
    counterfactual_prompt=counterfactual_prompt,
    corrupted_cache=corrupted_cache,
)

100%|██████████| 42/42 [08:51<00:00, 12.67s/it]


In [60]:
from utils.neel_utils import imshow

# Heatmap of residual patching results: rows are resid_pre hook points, columns are clean-prompt tokens.
LABELED_TOKENS = nutils.process_tokens_index(model.to_str_tokens(clean_prompt))
imshow(results, y=resid_hook_points, x=LABELED_TOKENS)

In [None]:
# Drop the corrupted-run outputs before the memory-hungry attribution section.
del corrupted_out, corrupted_cache
torch.cuda.empty_cache()

### Attribution Patching

In [30]:
from utils.attribution import compute_model_attribution_patching_scores

# Attribution patching over the residual-stream hooks, comparing the clean prompt with
# the counterfactual one. NOTE(review): result appears to be (layer, pos, d_model) — confirm.
attributions = compute_model_attribution_patching_scores(
    model=model,
    metric=metric,
    clean_prompt=clean_prompt,
    corrupt_prompt=counterfactual_prompt,
    hook_points=resid_hook_points,
)

In [31]:
# Collapse the last dimension to get one score per (hook point, token).
imshow(attributions.sum(dim=-1), y=resid_hook_points, x=LABELED_TOKENS)

In [38]:
# Cache every hook on the clean prompt (logits discarded) to see which hook names exist.
cache = model.run_with_cache(clean_prompt)[1]

cache.keys()


dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.ln1_post.hook_scale', 'blocks.0.ln1_post.hook_normalized', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.ln2_post.hook_scale', 'blocks.0.ln2_post.hook_normalized', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1

In [40]:
# Layer-0 shapes: per-head mixed values (hook_z) vs. the attention block's output (hook_attn_out).
cache["blocks.0.attn.hook_z"].size(), cache["blocks.0.hook_attn_out"].size()

(torch.Size([1, 72, 16, 256]), torch.Size([1, 72, 3584]))

In [34]:
# Attribution patching on each layer's per-head attention mixed values (hook_z).
attn_hook_points = [f"blocks.{i}.attn.hook_z" for i in range(model.cfg.n_layers)]
attributions = compute_model_attribution_patching_scores(
    model=model,
    metric=metric,
    hook_points=attn_hook_points,
    clean_prompt=clean_prompt,
    corrupt_prompt=counterfactual_prompt,
)

In [35]:
# Shape of the hook_z attribution tensor.
attributions.size()

torch.Size([42, 72, 3584])

In [36]:
# Rows must be labeled with the hook_z hook points these attributions were computed for,
# not the residual-stream hooks.
imshow(attributions.sum(-1), x=LABELED_TOKENS, y=attn_hook_points)

In [37]:
# Attribution patching on each layer's MLP output, plotted per (hook point, token).
mlp_hook_points = [f"blocks.{i}.hook_mlp_out" for i in range(model.cfg.n_layers)]
mlp_attributions = compute_model_attribution_patching_scores(
    model=model,
    metric=metric,
    hook_points=mlp_hook_points,
    clean_prompt=clean_prompt,
    corrupt_prompt=counterfactual_prompt,
)

imshow(mlp_attributions.sum(dim=-1), y=mlp_hook_points, x=LABELED_TOKENS)

In [44]:
# Attention hook names to cache: every layer's q, then every layer's k, then v, then z.
hook_points = [
    f"blocks.{layer}.attn.hook_{act}"
    for act in ("q", "k", "v", "z")
    for layer in range(model.cfg.n_layers)
]


In [None]:
# Forward and backward caches for both prompts. The clean run is presumably restricted to
# `hook_points`, while the corrupted run requests "all" hooks.
# `corrupted_bwd_cache` is not used by the cells shown here.
_, clean_cache, bwd_cache = get_cache_fwd_and_bwd(
    model, clean_prompt, metric=metric, hook_points=hook_points
)
_, corrupted_cache, corrupted_bwd_cache = get_cache_fwd_and_bwd(
    model, counterfactual_prompt, metric=metric, hook_points="all"
)

In [65]:
# Sanity check: manually stacking the per-layer query activations must match
# ActivationCache.stack_activation in values, not just in shape.
manual_q_stack = torch.stack(
    [clean_cache["q", layer] for layer in range(model.cfg.n_layers)], dim=0
)
builtin_q_stack = clean_cache.stack_activation("q")
assert torch.equal(manual_q_stack, builtin_q_stack), (
    "stack_activation('q') disagrees with the manual per-layer stack"
)
manual_q_stack.shape, builtin_q_stack.shape

(torch.Size([42, 1, 72, 16, 256]), torch.Size([42, 1, 72, 16, 256]))

In [63]:
# Layer-0 query activation shape (batch, pos, head, d_head per the output below).
clean_cache["q", 0].size()

torch.Size([1, 72, 16, 256])

In [72]:
from utils.attribution import attr_patch_head_vector
from IPython.display import Markdown
import einops

# Per-(layer, head) x position attribution scores, keyed by activation short name.
head_vector_attr_dict = {}

# Activations to analyse; ("k", "Key"), ("q", "Query") and ("v", "Value") can be
# added back to this list.
for activation_name, activation_name_full in [
    ("z", "Mixed Value"),
]:
    display(Markdown(f"#### {activation_name_full} Head Vector Attribution Patching"))

    head_vector_attr, head_vector_labels = attr_patch_head_vector(
        model, clean_cache, corrupted_cache, bwd_cache, activation_name
    )
    head_vector_attr_dict[activation_name] = head_vector_attr
    print(head_vector_attr.shape, len(head_vector_labels))

    # Full heatmap: one row per flattened (layer, head) component, one column per position.
    imshow(
        head_vector_attr,
        title=f"{activation_name_full} Attribution Patching",
        xaxis="Position",
        yaxis="Component",
        y=head_vector_labels,
    )

    # Sum over positions and unflatten the layer-major component axis into a layer x head grid.
    sum_head_vector_attr = einops.reduce(
        head_vector_attr,
        "(layer head) pos -> layer head",
        "sum",
        layer=model.cfg.n_layers,
        head=model.cfg.n_heads,
    )
    imshow(
        sum_head_vector_attr,
        title=f"{activation_name_full} Attribution Patching Sum Over Pos",
        xaxis="Head Index",
        yaxis="Layer",
    )

#### Mixed Value Head Vector Attribution Patching

672 torch.Size([672, 72]) torch.Size([672, 1, 72, 256]) torch.Size([672, 1, 72, 256]) torch.Size([672, 1, 72, 256])
torch.Size([672, 72]) 672


In [78]:
import plotly.graph_objects as go
import numpy as np

# Overlaid histograms of the hook_z head-vector attributions, one trace per token position.
z_attr = head_vector_attr_dict['z']
n_positions = z_attr.shape[1]

traces = []
for position in range(n_positions):
    # Attribution of every (layer, head) component at this position, as float64 numpy.
    values = z_attr[:, position].flatten().to(float).cpu().numpy()
    traces.append(
        go.Histogram(x=values, name=f'Position {position}', nbinsx=50, opacity=0.75)
    )
fig = go.Figure(data=traces)

fig.update_layout(
    title='Attribution Value Distribution Across Positions',
    xaxis_title='Attribution Value',
    yaxis_title='Count',
    barmode='overlay',
    showlegend=True,
    width=1000,
    height=600,
)

fig.show()

### Attention Attribution

In [81]:
from utils.attribution import compute_attention_attribution_for_prompt
from utils.neel_plotly import imshow

# Attention patterns and their attribution scores w.r.t. `metric` on the clean prompt.
# NOTE(review): attn_attrs is assumed to share attn_patterns' leading (layer, head) layout
# (attn_patterns is indexed [layer, head] further down) — confirm.
attn_patterns, attn_attrs = compute_attention_attribution_for_prompt(model, clean_prompt, metric, return_acts=True)
LABELED_TOKENS = nutils.process_tokens_index(model.to_str_tokens(clean_prompt))

# Two different statistics over dims 0 and 1, so each heatmap gets its own title.
imshow(attn_attrs.sum([0, 1]), x=LABELED_TOKENS, y=LABELED_TOKENS, title="Area Attribution Scores (sum over layers & heads)")
imshow(attn_attrs.std([0, 1]), x=LABELED_TOKENS, y=LABELED_TOKENS, title="Area Attribution Scores (std over layers & heads)")


In [85]:
# Inspect a single head's attention pattern on the clean prompt. The plotted tensor is
# `attn_patterns`, so the title says "Attention Pattern" rather than "Attribution Scores".
layer = 28
head = 12
imshow(attn_patterns[layer, head], x=LABELED_TOKENS, y=LABELED_TOKENS, title=f"Layer {layer} Head {head} Attention Pattern")
