## Setup

In [85]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"

In [86]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [87]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [88]:
# Initial task accuracy
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


## Prompt Generation

In [89]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
    "After Sam and Matt went to the park,{} gave a drink to",
    "After Adam and Miles went to the park,{} gave a drink to",
    "After Oscar and William went to the park,{} gave a drink to",
    "After Sally and Kate went to the park,{} gave a drink to",
    "After Karen and Lisa went to the park,{} gave a drink to",
    "After Emily and Laura went to the park,{} gave a drink to",
    "After Jacob and Cam went to the park,{} gave a drink to",
]
name_pairs = [
    (" John", " Mary"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
    (" Sam", " Matt"),
    (" Adam", " Miles"),
    (" Oscar", " William"),
    (" Sally", " Kate"),
    (" Karen", " Lisa"),
    (" Emily", " Laura"),
    (" Jacob", " Cam"),
]

# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)
prompts = [
    prompt.format(name) 
    for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1] 
]
# Define the answers for each prompt, in the form (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]
# Define the answer tokens (same shape as the answers)
answer_tokens = torch.concat([
    model.to_tokens(names, prepend_bos=False).T for names in answers
])

# print(prompts)
# print(answers)
# print(answer_tokens)

In [90]:
table = Table("Prompt", "Correct", "Incorrect", "Token 1", "Token 2", title="Prompts & Answers:")

for prompt, answer, token in zip(prompts, answers, answer_tokens.tolist()):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]), str(token[0]), str(token[1]))

rprint(table)

## Average Logit Differences

In [91]:
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)

original_logits, cache = model.run_with_cache(tokens)
original_cache = cache

In [158]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False,
    print_accuracy = True
):
    '''
    Returns logit difference between the correct and incorrect answer. The first token is assumed to be correct.

    If per_prompt=True, return the array of differences rather than the average.
    '''
    # Only the final logits are relevant for the answer
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Get the logits corresponding to the indirect object / subject tokens respectively
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(dim=-1, index=answer_tokens)
    # Find logit difference
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits

    accuracy = (answer_logit_diff > 0).float().mean()
    if print_accuracy:
        print("Logit difference accuracy: ", f"{accuracy.item():.3f}")
    
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()


original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", original_per_prompt_diff)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold")
]
table = Table(*cols, title="Logit differences")

for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]), f"{logit_diff.item():.3f}")

rprint(table)

Logit difference accuracy:  1.000
Per prompt logit difference: tensor([3.2016, 3.3367, 2.7095, 3.7975, 1.7204, 5.2812, 2.6008, 5.7674, 3.0110,
        2.4750, 1.9531, 2.2397, 2.2629, 3.5962, 2.1930, 3.0245, 1.5002, 2.1427,
        2.8699, 2.3140, 3.8759, 3.6036], device='cuda:0',
       grad_fn=<SubBackward0>)
Logit difference accuracy:  1.000
Average logit difference: tensor(2.9762, device='cuda:0', grad_fn=<MeanBackward0>)


## Logit attribution

In [146]:
def get_logit_diff_directions(model, answer_tokens) -> Float[Tensor, "batch d_model"]:
    """
    The ideal logit difference direction in the unembed, formed by substracting the incorrect answer token's direction from the correct answer token's direction.
    """
    
    # Token unembed with layer norm handled
    answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
    correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
    logit_diff_directions = correct_residual_directions - incorrect_residual_directions
    # print("Answer residual directions shape:", answer_residual_directions.shape)
    # print(f"Logit difference directions shape:", logit_diff_directions.shape)
    return logit_diff_directions


logit_diff_directions = get_logit_diff_directions(model, answer_tokens)

In [147]:
# Cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream: Float[Tensor, "batch seq d_model"] = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions,
    "batch d_model, batch d_model ->"
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff}")


Final residual stream shape: torch.Size([22, 15, 768])
Calculated average logit diff: 2.9762184620
Original logit difference:     2.976217269897461


In [148]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"], 
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Gets the avg logit difference between the correct and incorrect answer for a given 
    stack of components in the residual stream.
    '''
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return einops.einsum(
        scaled_residual_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    ) / batch_size

# torch.testing.assert_close(
#     residual_stack_to_logit_diff(final_token_residual_stream, cache),
#     original_average_logit_diff
# )

In [149]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs.detach().cpu().numpy(), 
    #hovermode="x unified",
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    #xaxis_tickvals=labels,
    width=800
)

In [150]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs.detach().cpu().numpy(), 
    #hovermode="x unified",
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    #xaxis_tickvals=labels,
    width=800
)

In [151]:
def imshow(tensor, renderer=None, **kwargs):
    preset_kwargs = {
        "color_continuous_midpoint": 0.0,
        "color_continuous_scale": "RdBu",
        "text_auto":".2f"
    }

    fig = px.imshow(utils.to_numpy(tensor), **{**preset_kwargs, **kwargs})
    fig.show(renderer=renderer)

In [152]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(
    per_head_residual, 
    "(layer head) ... -> layer head ...", 
    layer=model.cfg.n_layers
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

imshow(
    per_head_logit_diffs, 
    labels={"x":"Head", "y":"Layer"}, 
    title="Logit Difference From Each Head",
    width=600
)

## Try to find backup name movers - mean ablations

In [153]:
def plot_per_head_attribution(cache: ActivationCache, logit_diff_directions=logit_diff_directions, title=""):
    per_head_residual, _ = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
    per_head_residual = einops.rearrange(
        per_head_residual, 
        "(layer head) ... -> layer head ...", 
        layer=model.cfg.n_layers
    )
    per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache, logit_diff_directions)

    imshow(
        per_head_logit_diffs, 
        labels={"x":"Head", "y":"Layer"}, 
        title=title,
        width=600
    )

In [154]:
dataset = load_dataset("NeelNanda/c4-10k", split="train")
dataloader = torch.utils.data.DataLoader(dataset["text"], batch_size=1)

Found cached dataset parquet (/root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--c4-10k-dc1f5fce0477f6d0/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


In [155]:
import gc
gc.collect()
torch.cuda.empty_cache()

### Store mean activations of interesting heads for use in mean ablations

In [103]:
def get_mean_activations(layer, head, n=50, sequence_length=15):
    activations = []

    for i, batch in enumerate(dataloader):
        tokens = model.to_tokens(batch[0])[0, :sequence_length]
        _, cache = model.run_with_cache(tokens)
        attention_pattern = cache[f"blocks.{layer}.attn.hook_pattern"]
        head_results = attention_pattern[0, head, :, :]
        activations.append(head_results)
        if i == n:
            break

    all_activations = torch.stack(activations, dim=0)
    #print("Combined:", all_activations.shape)
    return torch.mean(all_activations, dim=(0))


# Heads of interest we identified using direct logit attribution, tuple(layer, head_index). Slightly different to IOI paper.
positive_dla_heads = [(7, 3), (7, 9), (8,10), (9, 6), (9, 9), (10, 0), (10, 6), (10, 10)] # I took another stab at this using ~0.2 as a cutoff and found [(11, 2), (11, 10), (10, 0), (10, 7), (10, 10), (9, 6), (9, 9)]

# Heads of interest identified in IOI paper. 
s_inhibition_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]
positive_name_movers = [(9, 9), (9, 6), (10, 0)]
negative_name_movers = [(10, 7), (11, 10)]
name_movers = positive_name_movers + negative_name_movers
# backup_name_movers = [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (11, 9), (11, 3), (9, 7)]

all_heads = set(s_inhibition_heads + name_movers + positive_dla_heads)

# Store heads' mean activations in a dictionary for use in ablations.
mean_activations_dict = {}
for layer, head in tqdm(all_heads):
    mean_activations = get_mean_activations(layer, head)
    mean_activations_dict[(layer, head)] = mean_activations


  0%|          | 0/11 [00:00<?, ?it/s]

### Utility function to ablate selected heads and show logit difference information

In [156]:
def head_mean_ablation_logit_diff(heads: list[(int, int)], answer_tokens=answer_tokens, prompts=prompts, mean_activations_dict=mean_activations_dict):
        def hook_all_attention_patterns(attn_pattern, hook: HookPoint, head_idx: int, layer_idx: int):
            if hook.layer() == layer_idx:
                attn_pattern[:, head_idx] = mean_activations_dict[(layer_idx, head_idx)]
                return attn_pattern
            
        model.reset_hooks()
        model.set_use_attn_result(True)
        tokens = model.to_tokens(prompts)
        logit_diff_directions = get_logit_diff_directions(model, answer_tokens)

        # Original cache
        original_logits, original_cache = model.run_with_cache(tokens)

        # Ablated cache
        name_filter = lambda name: name.endswith("pattern")
        hooks = []
        for layer, head in heads:
            hooks.append((name_filter, functools.partial(hook_all_attention_patterns, head_idx=head, layer_idx=layer)))

        with model.hooks(fwd_hooks=hooks):
            ablated_logits, ablated_cache = model.run_with_cache(tokens)

        # Average logit differences
        original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
        print("Average logit difference:", original_average_logit_diff)
        ablated_average_logit_diff = logits_to_ave_logit_diff(ablated_logits, answer_tokens)
        print("Average logit difference:", ablated_average_logit_diff)

        # Plot logit differences from each head
        plot_per_head_attribution(ablated_cache, logit_diff_directions, title=f"Ablated Average Logit Difference From Heads <br> {heads}")
        plot_per_head_attribution(original_cache, logit_diff_directions, title="Unablated Average Logit Difference From Each Head", )

### Zoom in on L9H9 - get ablated average logit difference, compare L9H9 activations with L9H6 and other L9H

In [131]:
head_mean_ablation_logit_diff(heads=[(9, 9)], answer_tokens=answer_tokens, prompts=prompts, mean_activations_dict=mean_activations_dict)


Logit difference accuracy:  1.000
Average logit difference: 2.976
Logit difference accuracy:  1.000
Average logit difference: 2.718


In [132]:
# Get output of important head in layer 9 for last position
layer_9 = cache.stack_head_results(layer=9, apply_ln=True)[model.cfg.n_heads*8:]
layer_9_other = layer_9[[0, 1,2,3,4,5,7,8,10,11], :, -1].flatten().detach().cpu().numpy()
layer_9_head_6 = layer_9[6, :, -1].flatten().detach().cpu().numpy()
layer_9_head_9 = layer_9[9, :, -1].flatten().detach().cpu().numpy()
print(layer_9_other.shape) # batch d_model

(168960,)


In [133]:
dfs = [pd.DataFrame({"activation": layer_9_head_9, "name": "Head 9"}), 
        pd.DataFrame({"activation": layer_9_head_6, "name": "Head 6"}),
        pd.DataFrame({"activation": layer_9_other, "name": "Other heads"})]
df = pd.concat(dfs)
df.groupby("name").agg(["mean", "std"])

Unnamed: 0_level_0,activation,activation
Unnamed: 0_level_1,mean,std
name,Unnamed: 1_level_2,Unnamed: 2_level_2
Head 6,-7.269343e-10,0.218672
Head 9,1.362588e-10,0.051043
Other heads,5.291054e-11,0.096656


In [135]:
layer_9_head_6 = layer_9[6, 0, -1]
layer_9_head_9 = layer_9[9, 0, -1]
print(layer_9_head_6.shape, layer_9_head_9.shape)

cos = torch.nn.CosineSimilarity(dim=-1)
# Cosine similarities between head outputs and directions, shape=(d_head, pos)
similarities = cos(layer_9_head_9, layer_9_head_6).item()
print(similarities)

other_similarities = []
for i in range(12):
    if i not in [6, 9]:
        other_head = layer_9[i, 0, -1]
        other_similarities.append(cos(layer_9_head_9, other_head).item())
print(other_similarities)

torch.Size([768]) torch.Size([768])
0.13419874012470245
[0.1889655739068985, -0.0439043864607811, -0.0717644989490509, 0.07560774683952332, -0.0708644837141037, 0.10396718978881836, -0.058510538190603256, -0.20690611004829407, -0.2218116819858551, 0.03350242227315903]


### Ablate positive_dla_heads

In [125]:
def hook_all_attention_patterns(attn_pattern, hook: HookPoint, head_idx: int, layer_idx: int):
    if hook.layer() == layer_idx:
        attn_pattern[:, head_idx] = mean_activations_dict[(layer_idx, head_idx)]
        return attn_pattern

name_filter = lambda name: name.endswith("pattern")
hooks = []
for layer, head in positive_dla_heads:
    hooks.append((name_filter, functools.partial(hook_all_attention_patterns, head_idx=head, layer_idx=layer)))

with model.hooks(fwd_hooks=hooks):
    ablated_logits, ablated_cache = model.run_with_cache(tokens)

In [126]:
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

ablated_average_logit_diff = logits_to_ave_logit_diff(ablated_logits, answer_tokens)
print("Average logit difference:", ablated_average_logit_diff)

Logit difference accuracy:  1.000
Average logit difference: 2.976
Logit difference accuracy:  0.955
Average logit difference: 1.745


In [127]:
plot_per_head_attribution(ablated_cache, title="Ablated Logit Difference From Each Head")

In [137]:
# mean ablation of all heads we identified with DLA
with model.hooks(fwd_hooks=hooks):
    ablated_logits, ablated_cache = model.run_with_cache(tokens[1])
attention_pattern = ablated_cache["blocks.10.attn.hook_pattern"].squeeze()
str_tokens = model.to_str_tokens(prompts[1])
circuitsvis.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [139]:
# attention_pattern = original_cache["blocks.10.attn.hook_pattern"][0].squeeze()
# print(attention_pattern.shape)
# str_tokens = model.to_str_tokens(prompts[0])
# print(len(str_tokens))
# circuitsvis.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [113]:
#sns.histplot(df[df["activation"]], x="activation", hue="name", stat="probability")

- Layer 10 head 7 has negative logit difference on original prompts
- Ablating ALL name movers causes L10H7 to have positive logit difference
- On both original and ablated prompt, L10H7 has roughly the same attention pattern
- On both prompts, it mostly attends to the Mary position
- Since the logit differences are very much not the same, the value in the residual stream for "mary" must be different in both runs
- The previous name mover heads must write something important to the mary position

## Try to investigate negative name mover heads

- Introduce factors where it is ambiguous which head is correct 
- Add 3rd person to the prompt
- See if the negative name mover activates more for the subject or one of the two possible objects


In [115]:
# ablate the name mover heads
# calculate logit difference using backup name mover head
# ablate the S-Inhibition heads used in the original circuit
# calculate logit difference again and see if the backup name mover head can recover the correct answer (+ve logit diff)
    

In [169]:
from transformer_lens import patching


def ioi_metric(
            logits: Float[Tensor, "batch seq d_vocab"], 
            answer_tokens: Float[Tensor, "batch 2"],
            corrupted_logit_diff: float,
            clean_logit_diff: float,
        ) -> Float[Tensor, ""]:
            '''
            Linear function of logit diff, calibrated so that it equals 0 when performance is 
            same as on corrupted input, and 1 when performance is same as on clean input.
            '''
            patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, print_accuracy=False)
            return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff  - corrupted_logit_diff)

# Version of head_mean_ablation_logit_diff that patches in activations from a corrupted prompt rather than the mean activations of many random prompts
def head_activation_patching_logit_diff(heads: list[(int, int)], answer_tokens, clean_prompts, corrupted_prompts):
        model.reset_hooks()
        model.set_use_attn_result(True)

        clean_tokens = model.to_tokens(clean_prompts)
        clean_logits, clean_cache = model.run_with_cache(clean_tokens)
        clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
        
        corrupted_tokens = model.to_tokens(corrupted_prompts)
        corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
        corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)

        
        ioi_metric_for_run = functools.partial(ioi_metric, answer_tokens=first_name_correct_answer_tokens, corrupted_logit_diff=corrupted_logit_diff, clean_logit_diff=clean_logit_diff)

        # switched this to test out other patches
        act_patch_resid_pre = patching.get_act_patch_attn_head_all_pos_every(
            model = model,
            corrupted_tokens = corrupted_tokens,
            clean_cache = clean_cache,
            metric = ioi_metric_for_run
        )

        labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

        imshow(
            act_patch_resid_pre, 
            labels={"x": "Position", "y": "Layer"},
            x=labels,
            title="resid_pre Activation Patching",
            width=600
        )


# Uses prompts with the indirect object and subject swapped around as the corrupted prompt but we could also use 
# a prompt with a previously unused name as the subject, making the indirect object ambiguous
first_name_correct_prompts = prompts[0::2]
second_name_correct_prompts = prompts[1::2]
first_name_correct_answer_tokens = answer_tokens[0::2]
# second_name_correct_answer_tokens = answer_tokens[1::2]

head_activation_patching_logit_diff([head for head in all_heads if head != (10, 7)], answer_tokens=first_name_correct_answer_tokens, clean_prompts=first_name_correct_prompts, corrupted_prompts=second_name_correct_prompts)

Logit difference accuracy:  1.000
Logit difference accuracy:  0.000


  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [117]:
print("Positive name mover heads")
head_mean_ablation_logit_diff(positive_name_movers)
print("S-inhibition heads")
head_mean_ablation_logit_diff(s_inhibition_heads)

Positive name mover heads
Logit difference accuracy:  1.000
Average logit difference: 2.976
Logit difference accuracy:  1.000
Average logit difference: 3.234


S-inhibition heads
Logit difference accuracy:  1.000
Average logit difference: 2.976
Logit difference accuracy:  0.818
Average logit difference: 0.588


In [118]:
attention_pattern = original_cache["blocks.9.attn.hook_pattern"][4].squeeze()
print(attention_pattern.shape)
str_tokens = model.to_str_tokens(prompts[4])
print("Length of prompt 4:", len(str_tokens))
circuitsvis.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

model.reset_hooks()
def hook_all_attention_patterns(attn_pattern, hook: HookPoint, head_idx: int, layer_idx: int):
    if hook.layer() == layer_idx:
        attn_pattern[:, head_idx] = mean_activations_dict[(layer_idx, head_idx)]
        return attn_pattern

name_filter = lambda name: name.endswith("pattern")
hooks = []
for layer, head in heads:
    hooks.append((name_filter, functools.partial(hook_all_attention_patterns, head_idx=head, layer_idx=layer)))

with model.hooks(fwd_hooks=hooks):
    ablated_logits, ablated_cache = model.run_with_cache(tokens)


original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", original_per_prompt_diff)

ablated_per_prompt_diff = logits_to_ave_logit_diff(ablated_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", ablated_per_prompt_diff)

torch.Size([12, 15, 15])
Length of prompt 4: 15
Logit difference accuracy:  1.000
Per prompt logit difference: tensor([3.2016, 3.3367, 2.7095, 3.7975, 1.7204, 5.2812, 2.6008, 5.7674, 3.0110,
        2.4750, 1.9531, 2.2397, 2.2629, 3.5962, 2.1930, 3.0245, 1.5002, 2.1427,
        2.8699, 2.3140, 3.8759, 3.6036], device='cuda:0',
       grad_fn=<SubBackward0>)
Logit difference accuracy:  1.000
Per prompt logit difference: tensor([2.3279, 1.9115, 1.1121, 1.6898, 0.1710, 3.4720, 1.7161, 1.6984, 2.3755,
        0.6135, 0.6171, 2.5768, 1.7879, 2.0059, 1.1177, 1.6975, 1.3885, 1.4684,
        2.9865, 0.6232, 2.2222, 2.2420], device='cuda:0',
       grad_fn=<SubBackward0>)


In [119]:
first_name_correct_prompts = prompts[0::2]
second_name_correct_prompts = prompts[1::2]

first_name_correct_answer_tokens = answer_tokens[0::2]
second_name_correct_answer_tokens = answer_tokens[1::2]

head_mean_ablation_logit_diff([head for head in all_heads if head != (10, 7)], answer_tokens=first_name_correct_answer_tokens, prompts=first_name_correct_prompts, mean_activations_dict=mean_activations_dict)
head_mean_ablation_logit_diff([head for head in all_heads if head != (10, 7)], answer_tokens=second_name_correct_answer_tokens, prompts=second_name_correct_prompts, mean_activations_dict=mean_activations_dict)

Logit difference accuracy:  1.000
Average logit difference: 2.536
Logit difference accuracy:  0.818
Average logit difference: 0.593


Logit difference accuracy:  1.000
Average logit difference: 3.416
Logit difference accuracy:  0.545
Average logit difference: 0.732


## Non-Dropout Model Replication