In [36]:
import torch
from torch import Tensor

import einops

from typing import Literal
from jaxtyping import Float

from transformer_lens import HookedTransformer, ActivationCache

from ioi_dataset import IOIDataset, format_prompt, make_table


device = "cuda" if torch.cuda.is_available() else "cpu"

# Model, dataset & metric setup

In [37]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [38]:
N = 5
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

make_table(
  colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
  cols = [
    map(format_prompt, clean_dataset.sentences),
    model.to_string(clean_dataset.s_tokenIDs).split(),
    model.to_string(clean_dataset.io_tokenIDs).split(),
    map(format_prompt, corr_dataset.sentences),
  ],
  title = "Sentences from IOI vs ABC distribution",
)

In [39]:
def ave_logit_diff(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    # Get logits for indirect objects
    io_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.io_tokenIDs]
    s_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.s_tokenIDs]
    # Get logits for subject
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()

with torch.no_grad():
    clean_logits = model(clean_dataset.toks)
    corrupt_logits = model(corr_dataset.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
 ):
    patched_logit_diff = ave_logit_diff(logits, ioi_dataset)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -ioi_metric(logits)
    
# Get clean and corrupt logit differences
with torch.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 3.3131515979766846, Corrupt direction: 1.653100609779358
Clean metric: 1.0, Corrupt metric: 0.0


# Gather the 2 forward and 1 backward passes required to approximate the difference in loss

In [40]:
hook_filter = lambda name: name.endswith("ln1.hook_normalized") or name.endswith("attn.hook_result")


def get_3_caches(model, clean_input, corrupted_input, metric, mode: Literal["node", "edge"]="node"):
    # cache the activations and gradients of the clean inputs
    model.reset_hooks()
    clean_cache = {}

    def forward_cache_hook(act, hook):
        clean_cache[hook.name] = act.detach()

    edge_acdcpp_outgoing_filter = lambda name: name.endswith(("hook_result", "hook_mlp_out", "blocks.0.hook_resid_pre", "hook_q", "hook_k", "hook_v"))
    model.add_hook(hook_filter if mode == "node" else edge_acdcpp_outgoing_filter, forward_cache_hook, "fwd")

    clean_grad_cache = {}

    def backward_cache_hook(act, hook):
        clean_grad_cache[hook.name] = act.detach()

    incoming_ends = ["hook_q_input", "hook_k_input", "hook_v_input", f"blocks.{model.cfg.n_layers-1}.hook_resid_post"]
    if not model.cfg.attn_only:
        incoming_ends.append("hook_mlp_in")
    edge_acdcpp_back_filter = lambda name: name.endswith(tuple(incoming_ends + ["hook_q", "hook_k", "hook_v"]))
    model.add_hook(hook_filter if mode=="node" else edge_acdcpp_back_filter, backward_cache_hook, "bwd")
    value = metric(model(clean_input))


    value.backward()

    # cache the activations of the corrupted inputs
    model.reset_hooks()
    corrupted_cache = {}

    def forward_corrupted_cache_hook(act, hook):
        corrupted_cache[hook.name] = act.detach()

    model.add_hook(hook_filter if mode == "node" else edge_acdcpp_outgoing_filter, forward_corrupted_cache_hook, "fwd")
    model(corrupted_input)
    model.reset_hooks()

    clean_cache = ActivationCache(clean_cache, model)
    corrupted_cache = ActivationCache(corrupted_cache, model)
    clean_grad_cache = ActivationCache(clean_grad_cache, model)
    return clean_cache, corrupted_cache, clean_grad_cache

In [41]:
def split_layers_and_heads(act: Tensor, model: HookedTransformer) -> Tensor:
    return einops.rearrange(act, '(layer head) batch seq d_model -> layer head batch seq d_model',
                            layer=model.cfg.n_layers,
                            head=model.cfg.n_heads)

In [42]:
# get the 2 fwd and 1 bwd caches; cache "normalized" and "result" of attn layers
clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(
    model, 
    clean_dataset.toks,
    corr_dataset.toks,
    metric=negative_ioi_metric,
    mode = "edge",
)

In [44]:
clean_head_act = split_layers_and_heads(clean_cache.stack_head_results(), model=model) # (n_layers, n_heads, batch, seq, d_model)
corr_head_act = split_layers_and_heads(corrupted_cache.stack_head_results(), model=model)

In [23]:
stacked_grad_act = torch.zeros(
    3, # QKV
    model.cfg.n_layers,
    model.cfg.n_heads,
    clean_head_act.shape[-3], # Batch
    clean_head_act.shape[-2], # Seq
    clean_head_act.shape[-1], # D
)

# take the gradient for the qkv input vector
for letter_idx, letter in enumerate("qkv"):
    for layer_idx in range(model.cfg.n_layers):
        stacked_grad_act[letter_idx, layer_idx] = einops.rearrange(clean_grad_cache[f"blocks.{layer_idx}.hook_{letter}_input"], "batch seq n_heads d -> n_heads batch seq d")

In [45]:
results = {}
# compute the edge attribution
for upstream_layer_idx in range(model.cfg.n_layers):
    for upstream_head_idx in range(model.cfg.n_heads):
        for downstream_letter_idx, downstream_letter in enumerate("qkv"):
            for downstream_layer_idx in range(upstream_layer_idx+1, model.cfg.n_layers):
                for downstream_head_idx in range(model.cfg.n_heads):
                    results[
                        (
                            upstream_layer_idx,
                            upstream_head_idx,
                            downstream_letter,
                            downstream_layer_idx,
                            downstream_head_idx,
                        )
                    ] = (stacked_grad_act[downstream_letter_idx, downstream_layer_idx, downstream_head_idx].cpu() * (clean_head_act[upstream_layer_idx, upstream_head_idx] - corr_head_act[upstream_layer_idx, upstream_head_idx]).cpu()).sum()

In [46]:
sorted_results = sorted(results.items(), key=lambda x: x[1].abs(), reverse=True)

In [47]:
print("Top 10 most important edges:")
for i in range(10):
    print(
        f"{sorted_results[i][0][0]}:{sorted_results[i][0][1]} -> {sorted_results[i][0][3]}:{sorted_results[i][0][4]}",
    )

Top 10 most important edges:
5:5 -> 8:6
10:7 -> 11:10
5:5 -> 6:9
9:9 -> 11:10
9:6 -> 11:10
4:11 -> 6:9
9:9 -> 10:7
5:5 -> 7:9
9:6 -> 10:7
9:6 -> 10:0
