<a href="https://colab.research.google.com/github/kgwiazdak/Ablation-Experiment/blob/main/Ablation%20Experiment" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This experiment presents the differences between ablation methods and shows when a given ablation type can break. I consider several head types and three ablation types: zero, mean, and resample. I then compare zero and mean ablation against resample ablation for each head type. Because the model contains backup mechanisms, this lets me check whether all head activations differ by the same percentage. Recently, optimal ablation was introduced ("Optimal ablation for interpretability", arXiv:2409.09951). It is not considered in this Colab because I ran into obstacles running the authors' code, but their work is a valuable update to the field of ablation methods.

## Setup

### Imports

In [None]:
%pip install transformer_lens
%pip install circuitsvis
%pip install ploty
%pip install numpy

In [None]:
from functools import partial
from typing import List, Optional, Union
from transformer_lens.head_detector import get_supported_heads, detect_head
import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
# from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer


from torch import Tensor
from transformer_lens.hook_points import (
    HookPoint,
)
from jaxtyping import Int, Float
import torch.nn.functional as F
from tqdm import tqdm
import functools

In [None]:
# Ask TransformerLens which device to run on, and show the choice.
device: torch.device
device = utils.get_device()
print(device)

### Plot utils


In [4]:
def imshow(tensor, **kwargs):
    """Show ``tensor`` as a heatmap on the diverging "RdBu" scale, centred on 0.

    Any extra keyword arguments (e.g. ``title``, ``labels``) are passed
    straight through to ``plotly.express.imshow``.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        **kwargs,
    )
    figure.show()

### Setup model and prompt to all problems

The model we operate on is gpt2-small. The activation-patching experiments below use indirect-object-identification (IOI) prompts; the head-detection section at the end uses randomly generated tokens repeated twice.

In [None]:
# Seed torch's global RNG so that later random draws (e.g. the random tokens
# used for head detection) are reproducible.
torch.manual_seed(0)

device: torch.device = utils.get_device()
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

## Logit difference

In [None]:
# IOI templates; `{}` is filled with the name of the person who gives the item.
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The two candidate names for each template, aligned with `prompt_format`.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

prompts = []        # filled-in prompt strings
answers = []        # (correct, incorrect) name strings, one pair per prompt
answer_tokens = []  # (correct_token, incorrect_token) ids, one pair per prompt
for template, name_pair in zip(prompt_format, names):
    # Every template yields two prompts, one with each name as the giver.
    for first in (0, 1):
        correct, incorrect = name_pair[first], name_pair[1 - first]
        answers.append((correct, incorrect))
        answer_tokens.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        # The *incorrect* name is inserted as the giver, which makes the
        # correct answer the indirect object.
        prompts.append(template.format(incorrect))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

In [7]:
# Tokenise the clean prompts (with a BOS token), then run the model once and
# keep every activation; `cache` is the clean cache the patching hooks read.
tokens = model.to_tokens(
    prompts,
    prepend_bos=True,
)
original_logits, cache = model.run_with_cache(tokens)

In [8]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference (correct minus incorrect answer) at the final position.

    Args:
        logits: rank-3 tensor indexed ``[batch, pos, vocab]``; only the last
            position is used.
        answer_tokens: integer tensor of shape ``(batch, 2)`` holding
            ``(correct_token, incorrect_token)`` for each prompt.
        per_prompt: if true, return the per-prompt differences instead of
            their mean.

    Returns:
        A ``(batch,)`` tensor when ``per_prompt`` is true, else a 0-d mean.
    """
    last_position_logits = logits[:, -1, :]
    picked = torch.gather(last_position_logits, -1, answer_tokens)
    differences = picked[:, 0] - picked[:, 1]
    return differences if per_prompt else differences.mean()

# Clean baseline: mean (correct - incorrect) logit difference over all prompts.
original_average_logit_diff = logits_to_ave_logit_diff(
    original_logits, answer_tokens, per_prompt=False
)

## Activation Patching

In [None]:
# Corrupted prompts: swap each adjacent pair, so every corrupted prompt uses
# the same template as its clean counterpart but the other name as the giver.
corrupted_prompts = []
for even_prompt, odd_prompt in zip(prompts[0::2], prompts[1::2]):
    corrupted_prompts.extend([odd_prompt, even_prompt])
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)
# Scored against the clean answers, so this baseline is expected to be low.
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

In [10]:
def normalize_patched_logit_diff(patched_logit_diff):
    """Rescale a logit difference so the corrupted baseline maps to 0 and the clean one to 1.

    Reads the module-level ``corrupted_average_logit_diff`` and
    ``original_average_logit_diff``.
    """
    baseline = corrupted_average_logit_diff
    clean_minus_corrupted = original_average_logit_diff - baseline
    return (patched_logit_diff - baseline) / clean_minus_corrupted

#### Activation patching experiment

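If you have limited compute, run only one or two ablation methods per run. As configured below, only resample and zero ablation are run. Mean ablation is skipped, so `patched_head_mean_diff` (and therefore the mean-ablation plot and percentages) stays all zeros unless mean ablation is added to both the list of hook functions and the list of result tensors.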

In [11]:
def patch_head_vector(
    corrupted_head_vector,
    hook,
    head_index,
    clean_cache,
):
    """Resample-ablate one head by copying its clean ``z`` into the corrupted run.

    Only the final sequence position of head ``head_index`` is overwritten.
    The tensor is modified in place and also returned, so this can be used
    directly as a forward-hook function.
    """
    source = clean_cache[hook.name]
    final_pos = -1
    corrupted_head_vector[:, final_pos, head_index, :] = source[:, final_pos, head_index, :]
    return corrupted_head_vector

def zero_ablation_patch_head_vector(
    corrupted_head_vector,
    hook,
    head_index,
    clean_cache,
):
    """Zero-ablate head ``head_index`` at the final position, in place.

    ``hook`` and ``clean_cache`` are unused. They are accepted so that all
    patching functions share one signature.
    """
    corrupted_head_vector[:, -1, head_index, :].zero_()
    return corrupted_head_vector

def mean_ablation_patch_head_vector(
    corrupted_head_vector,
    hook,
    head_index,
    clean_cache,
):
    """Mean-ablate head ``head_index`` at the final position, in place.

    The head's final-position ``z`` vector in every prompt is replaced by the
    average of that vector over the batch of clean-cache prompts. The tensor
    is modified in place and returned, so this can be used as a forward hook.
    ``hook.name`` selects the cached activation.
    """
    # Indexing [:, -1, head_index, :] leaves a (batch, d_head) slice, assuming
    # TransformerLens's (batch, pos, head_index, d_head) hook_z layout.
    clean_head_z = clean_cache[hook.name][:, -1, head_index, :]
    # Average over the batch (dim 0) so that every prompt receives the same
    # mean vector; keepdim gives (1, d_head), which broadcasts over the batch.
    # Averaging over dim 1 would collapse each prompt's d_head features to a
    # single number instead, which is not a mean ablation.
    corrupted_head_vector[:, -1, head_index, :] = clean_head_z.mean(dim=0, keepdim=True)
    return corrupted_head_vector

# Patching functions run for every head, paired index-for-index with
# `patched_head_diffs`. The complete set would be
#     [patch_head_vector, zero_ablation_patch_head_vector, mean_ablation_patch_head_vector]
# but to save compute only resample (clean-activation) patching and zero
# ablation are run here, so mean-ablation results are not computed.
patch_head_vector_functions = [patch_head_vector, zero_ablation_patch_head_vector]

In [None]:
# One (n_layers, n_heads) grid of normalised logit differences per ablation:
# resample (clean `z` patching), zero, and mean.
patched_head_z_diff, patched_head_zero_diff, patched_head_mean_diff = (
    torch.zeros(
        model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
    )
    for _ in range(3)
)

# Must line up index-for-index with `patch_head_vector_functions`. Only the
# resample and zero grids are listed, so the patching loop never writes
# `patched_head_mean_diff`, and it stays all zeros unless mean ablation is
# added to both lists.
patched_head_diffs = [patched_head_z_diff, patched_head_zero_diff]

In [13]:
# No gradients are needed here. Without no_grad the patched logits carry
# autograd history (model parameters require grad by default), and writing
# them into the score grids makes the grids part of the graph, so every
# patched forward graph would be kept alive in memory.
with torch.no_grad():
    for layer in range(model.cfg.n_layers):
        # Same hook point for every head in this layer: the per-head `z` output.
        hook_name = utils.get_act_name("z", layer, "attn")
        for head_index in range(model.cfg.n_heads):
            for function, diff in zip(patch_head_vector_functions, patched_head_diffs):
                hook_fn = partial(function, head_index=head_index, clean_cache=cache)
                patched_logits = model.run_with_hooks(
                    corrupted_tokens,
                    fwd_hooks=[(hook_name, hook_fn)],
                    return_type="logits",
                )
                patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
                diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Resample ablation: the clean `z` of one head patched into the corrupted run.
imshow(
    patched_head_z_diff,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Patched Head Output",
)

In [None]:
# Zero ablation of one head's final-position `z` in the corrupted run.
imshow(
    patched_head_zero_diff,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Zero Ablated Patched Head Output",
)

In [None]:
# Mean ablation. NOTE: this grid is only filled if mean ablation is included
# in the patching loop's function/result lists; otherwise it is all zeros.
imshow(
    patched_head_mean_diff,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Mean Ablated Patched Head Output",
)

The heads of each type are hard-coded; the full reasoning behind these choices is in [Exploratory_Analysis_Demo.ipynb](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=rJ59h9vXihax).

In [14]:
# (layer, head) positions of known gpt2-small heads, grouped by role.
specified_heads = dict([
    ("previous token heads", [(2, 2), (4, 11)]),
    ("duplicate token heads", [(0, 1), (3, 0), (0, 10)]),
    ("induction heads", [(5, 5), (6, 9), (5, 8), (5, 9)]),
    ("negative name mover heads", [(10, 7), (11, 10)]),
    ("name mover heads", [(9, 9), (9, 6), (10, 0)]),
    ("s-inhibition heads", [(7, 3), (7, 9), (8, 6), (8, 10)]),
    (
        "backup name movers heads",
        [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (11, 9), (11, 3), (9, 7)],
    ),
])

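I consider only heads where resample ablation produced a score greater than 0.05. The metric divides by the resample score, and scores near zero would make the ratio blow up. Note that this also excludes every head with a negative resample score. The result is still divided by the total number of heads of that type, so excluded heads effectively count as 0%.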

In [37]:
def count_procentage_difference(a, b, threshold=0.05):
    """Return the summed ratio ``|b / a|`` over selected heads as an integer percentage.

    Only positions where the reference score ``a`` is greater than
    ``threshold`` contribute. The sum of ``|b_i / a_i|`` over those positions
    is divided by ``len(a)``, the number of *all* heads, so filtered-out heads
    effectively count as 0. The result is then multiplied by 100 and
    truncated to ``int``.

    Args:
        a: 1-D tensor (or sequence) of reference (resample-ablation) scores.
        b: 1-D tensor (or sequence) of scores to compare, aligned with ``a``.
        threshold: minimum reference score for a head to be included; the
            default matches the previously hard-coded 0.05.

    Returns:
        The truncated percentage, or 0 when ``a`` is empty.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    # Guard: with no heads the division below is 0 / 0 = NaN, and int(NaN) raises.
    if a.numel() == 0:
        return 0
    keep = a > threshold
    ratios = (b[keep] / a[keep]).abs()
    return int((ratios.sum() / len(a) * 100).item())

As we can see, the biggest change is observed in S-inhibition heads, followed by name mover heads and backup name mover heads. This indicates that zero and mean ablation are not good at identifying S-inhibition heads. It also suggests that zero-ablating an S-inhibition head would not change the overall average logit difference much, because of backup mechanisms.

In [None]:
for head_type, head_positions in specified_heads.items():
    print(head_type)
    # Normalised scores of this head type under each ablation, one per (layer, head).
    resample_scores = torch.tensor([patched_head_z_diff[pos].item() for pos in head_positions])
    zero_scores = torch.tensor([patched_head_zero_diff[pos].item() for pos in head_positions])
    mean_scores = torch.tensor([patched_head_mean_diff[pos].item() for pos in head_positions])
    zero_pct = count_procentage_difference(resample_scores, zero_scores)
    mean_pct = count_procentage_difference(resample_scores, mean_scores)
    print(f"Procentage error between resample and zero ablation: {zero_pct}%")
    print(f"Procentage error between resample and mean ablation: {mean_pct}%")
    print()

### Consideration about L1H10

Previous token heads, duplicate token heads, induction heads and negative name mover heads showed a 0% difference, but this is partly an artifact of my calculation method: heads with a resample score of 0.05 or less, including every head with a negative score, are excluded. I can see that heads L1H10 and L2H0 are active under zero or mean ablation. L2H0 is one of the less strongly activated previous token heads. It is hard to determine what L1H10 does. It is too early in the model to be a negative name mover head, and the visualisations below also indicate that it is not a previous token, duplicate token, or induction head.





In [15]:
seq_len = 100
batch_size = 2
# Sampled with the CPU generator and then moved to `device`, so the tokens
# depend only on the CPU RNG state.
original_tokens = torch.randint(
    100, 20000, size=(batch_size, seq_len), device="cpu"
).to(device)
# Each row becomes the sequence followed by an exact copy of itself.
repeated_tokens = einops.repeat(
    original_tokens,
    "batch seq_len -> batch (2 seq_len)",
).to(device)
repeated_str = model.to_string(repeated_tokens)

In [16]:
# Score every head against each built-in pattern, in this order.
induction_heads, previous_token_heads, duplicate_token_heads = (
    detect_head(model, repeated_str, pattern)
    for pattern in ("induction_head", "previous_token_head", "duplicate_token_head")
)

In [None]:
for head_scores, plot_title in (
    (duplicate_token_heads, "Duplicate Token Head Scores  (implementation)"),
    (previous_token_heads, "Previous Token Head Scores (implementation)"),
    (induction_heads, "Induction Head Scores (implementation)"),
):
    imshow(head_scores, labels={"x": "Head", "y": "Layer"}, title=plot_title)

In [20]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    """Build an HTML snippet with CircuitsVis attention patterns for ``heads``.

    Args:
        heads: A flattened head index, or a list / tensor of them, where a
            flattened index is ``layer * model.cfg.n_heads + head_index``.
        local_cache: Activation cache holding the attention patterns.
        local_tokens: Tokens the cache was computed on (used for axis labels).
        title: Heading shown above the plot; ``None`` gives an empty heading.
        max_width: Maximum container width in pixels; ``None`` means no limit.

    Returns:
        Raw HTML, e.g. for ``IPython.display.HTML``.

    Raises:
        ValueError: If ``heads`` is empty.
    """
    # The module-level circuitsvis import is commented out, so import it here;
    # otherwise `attention_heads` below is an undefined name.
    from circuitsvis.attention import attention_heads

    # Normalise `heads` to a list of Python ints. Iterating a tensor yields
    # 0-d tensors, which would give labels like "Ltensor(4)H..." and break the
    # cache lookup below.
    if isinstance(heads, torch.Tensor):
        heads = heads.flatten().tolist()
    elif isinstance(heads, int):
        heads = [heads]
    heads = [int(head) for head in heads]
    if not heads:
        raise ValueError("visualize_attention_patterns needs at least one head")

    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Only the first batch item is visualised.
    batch_index = 0

    for head in heads:
        layer, head_index = divmod(head, model.cfg.n_heads)
        labels.append(f"L{layer}H{head_index}")
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    stacked_patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis plot as raw code, so it can be concatenated with the title.
    plot = attention_heads(
        attention=stacked_patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    title_html = f"<h2>{title if title is not None else ''}</h2><br/>"
    style = f" style='max-width: {max_width}px;'" if max_width is not None else ""
    return f"<div{style}>{title_html + plot}</div>"

In [21]:
# Collapse the (batch, 2 * seq_len) token grid into one 1-D sequence and cache
# all activations of a single run over it.
repeated_tokens = repeated_tokens.reshape(-1)
_, repeated_cache = model.run_with_cache(repeated_tokens)

In [None]:
# Flattened head index = layer * n_heads + head (59 for L4H11 in gpt2-small).
# `specified_heads` lists L4H11 as a previous token head, so the title says so
# rather than "Induction Heads".
layer_to_show, head_to_show = 4, 11
code = visualize_attention_patterns(
    [layer_to_show * model.cfg.n_heads + head_to_show],
    repeated_cache,
    repeated_tokens,
    title=f"Previous Token Head (L{layer_to_show}H{head_to_show})",
    max_width=800,
)
HTML(code)