# Finding induction heads

In [2]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

# Locate the chapter's exercises directory (relative to the current working directory)
# and put it on sys.path so the exercise modules below can be imported.
chapter = r"chapter1_transformers"
exercises_dir = Path("/".join([os.getcwd().split(chapter)[0], chapter, "exercises"])).resolve()
section_dir = exercises_dir.joinpath("part2_intro_to_mech_interp").resolve()
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference
from part1_transformer_from_scratch.solutions import get_log_probs
import part2_intro_to_mech_interp.tests as tests

# Nothing in this notebook needs gradients, so autograd is switched off globally to save compute
t.set_grad_enabled(False)

# Use the GPU when one is available, otherwise fall back to the CPU
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# True only when this code is executed as the top-level module (e.g. in the notebook kernel)
MAIN = (__name__ == "__main__")

  warn(f"Failed to load image Python extension: {e}")


## Introducing our toy attention-only model

- It has only attention blocks.
- The positional embeddings are only added to each key and query vector in the attention layers as opposed to the token embeddings (meaning that the residual stream can't directly encode positional information).
    - This turns out to make it much easier for induction heads to form: it happens 2-3x earlier - see the comparison of two training runs here. (The bump in each curve is the formation of induction heads.)
    - The argument that does this below is positional_embedding_type="shortformer".
- It has no MLP layers, no LayerNorms, and no biases.
- There are separate embed and unembed matrices (i.e. the weights are not tied).

In [3]:
cfg = HookedTransformerConfig(
    # Model dimensions
    n_layers=2,
    n_heads=12,
    d_head=64,
    d_model=768,
    n_ctx=2048,
    # Vocabulary / tokenizer
    d_vocab=50278,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    # Attention-only architecture: no MLP layers and no LayerNorm
    attn_only=True,  # defaults to False
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    attention_dir="causal",
    # Positional embeddings are added only to keys and queries, not to the residual stream
    positional_embedding_type="shortformer",
    use_attn_result=True,
    seed=398,
)

In [4]:
# Despite its name, this is the path of the checkpoint *file* inside the section directory
weights_dir = section_dir.joinpath("attn_only_2L_half.pth").resolve()

# Download the pretrained weights from Google Drive only when they are not already on disk
if not weights_dir.exists():
    url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
    output = str(weights_dir)
    gdown.download(url, output)

Downloading...
From (uriginal): https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu
From (redirected): https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu&confirm=t&uuid=6364e0a6-2a96-413d-b8d5-2ac788f322aa
To: /rds/project/rds-0cKEKVse28g/arena/ARENA_2.0/chapter1_transformers/exercises/part2_intro_to_mech_interp/attn_only_2L_half.pth
100%|██████████| 184M/184M [00:04<00:00, 42.0MB/s] 


In [5]:
# Build the architecture described by cfg, then replace its weights with the downloaded checkpoint
model = HookedTransformer(cfg)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code. The checkpoint
# comes from a Google Drive download, so only load it if that source is trusted.
# map_location=device places the loaded tensors on the device selected earlier.
pretrained_weights = t.load(weights_dir, map_location=device)
# Kept as the cell's final expression so the notebook displays the key-matching report
model.load_state_dict(pretrained_weights)

Downloading (…)okenizer_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/457k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


<All keys matched successfully>

## Exercise: visualise attention patterns

In [6]:
# Sample passage fed through the model so that its attention patterns can be inspected
text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."

# Single forward pass that also records intermediate activations in `cache`.
# NOTE(review): remove_batch_dim=True presumably drops the size-1 batch axis from the cached
# activations; the detectors below index cache["pattern", layer] as [head, query, key] - confirm.
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

Patterns
- Attend to previous
- Attend to start
- Attend to current

## Exercise: write your own detectors

In [35]:
# At x% of tokens, most focused token is current (and over 0.1)

In [96]:
import torch
from einops import reduce, rearrange, pack

# A head is reported only if the fraction (0-1) of query positions showing the target
# pattern is strictly greater than this value.
fraction_of_queries_thresh = 0.3
# A query position counts as a hit only if its largest attention weight is strictly
# greater than this value.
active_thresh = 0.4


def current_attn_detector(cache: ActivationCache) -> List[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads

    A head qualifies when, at more than `fraction_of_queries_thresh` of the query positions
    (a fraction between 0 and 1, not a percentage), the query's own position is the
    most-attended key AND that attention weight exceeds `active_thresh`. Both thresholds
    are read from module-level globals.

    Args:
        cache: activation cache whose cache["pattern", layer] entry is indexed here as
            [head, query_pos, key_pos], i.e. without a batch dimension.

    Returns:
        "layer.head" strings, ordered by layer and then by head index.
    """
    heads = []
    for layer in range(len(cache.model.blocks)):
        pattern = cache["pattern", layer]
        n_toks = pattern.shape[-1]
        n_heads = pattern.shape[-3]
        # Largest attention weight of every (head, query) row. A max reduction gives the
        # same values as sorting each row and taking the first element, without the sort.
        highest_weight = pattern.max(dim=-1).values
        # Weight each query places on its own position: the main diagonal of each head.
        own_weight = torch.diagonal(pattern, dim1=1, dim2=2)
        # Exact float equality is safe: both sides are read from the same tensor.
        is_hit = (highest_weight == own_weight) & (highest_weight > active_thresh)
        hit_fraction = is_hit.sum(dim=-1) / n_toks
        qualifying = torch.arange(n_heads, device=pattern.device)[
            hit_fraction > fraction_of_queries_thresh
        ]
        heads += [f"{layer}.{int(i)}" for i in qualifying]
    return heads


def prev_attn_detector(cache: ActivationCache) -> List[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads

    A head qualifies when, at more than `fraction_of_queries_thresh` of the query positions
    (a fraction between 0 and 1), the immediately preceding position is the most-attended
    key AND that attention weight exceeds `active_thresh`. Both thresholds are read from
    module-level globals.

    Args:
        cache: activation cache whose cache["pattern", layer] entry is indexed here as
            [head, query_pos, key_pos], i.e. without a batch dimension.

    Returns:
        "layer.head" strings, ordered by layer and then by head index.
    """
    heads = []
    for layer in range(len(cache.model.blocks)):
        pattern = cache["pattern", layer]
        n_toks = pattern.shape[-1]
        n_heads = pattern.shape[-3]
        # Largest attention weight of every (head, query) row. A max reduction gives the
        # same values as sorting each row and taking the first element, without the sort.
        highest_weight = pattern.max(dim=-1).values
        # Weight each query places on the position just before it: the first sub-diagonal,
        # which has one entry fewer than there are queries. Query 0 has no predecessor, so
        # its slot is filled with pattern[:, 0, 0] (its weight on itself) to keep the
        # [head, query] shape aligned with highest_weight.
        # NOTE(review): this lets query 0 count as a hit whenever that weight is the row
        # maximum and exceeds active_thresh - confirm this is intended.
        prev_weight = torch.cat(
            [
                pattern[:, 0, 0].unsqueeze(-1),
                torch.diagonal(pattern, offset=-1, dim1=1, dim2=2),
            ],
            dim=-1,
        )
        # Exact float equality is safe: both sides are read from the same tensor.
        is_hit = (highest_weight == prev_weight) & (highest_weight > active_thresh)
        hit_fraction = is_hit.sum(dim=-1) / n_toks
        qualifying = torch.arange(n_heads, device=pattern.device)[
            hit_fraction > fraction_of_queries_thresh
        ]
        heads += [f"{layer}.{int(i)}" for i in qualifying]
    return heads


def first_attn_detector(cache: ActivationCache) -> List[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be first-token heads

    A head qualifies when, at more than `fraction_of_queries_thresh` of the query positions
    (a fraction between 0 and 1), position 0 is the most-attended key AND that attention
    weight exceeds `active_thresh`. Both thresholds are read from module-level globals.

    Args:
        cache: activation cache whose cache["pattern", layer] entry is indexed here as
            [head, query_pos, key_pos], i.e. without a batch dimension.

    Returns:
        "layer.head" strings, ordered by layer and then by head index.
    """
    heads = []
    for layer in range(len(cache.model.blocks)):
        pattern = cache["pattern", layer]
        n_toks = pattern.shape[-1]
        n_heads = pattern.shape[-3]
        # Largest attention weight of every (head, query) row. A max reduction gives the
        # same values as sorting each row and taking the first element, without the sort.
        highest_weight = pattern.max(dim=-1).values
        # Weight each query places on the very first position (key index 0).
        first_weight = pattern[:, :, 0]
        # Exact float equality is safe: both sides are read from the same tensor.
        is_hit = (highest_weight == first_weight) & (highest_weight > active_thresh)
        hit_fraction = is_hit.sum(dim=-1) / n_toks
        qualifying = torch.arange(n_heads, device=pattern.device)[
            hit_fraction > fraction_of_queries_thresh
        ]
        heads += [f"{layer}.{int(i)}" for i in qualifying]
    return heads


# Run each detector on the cached activations and report the heads it flags, one line per pattern
for label, detector in (
    ("Heads attending to current token  = ", current_attn_detector),
    ("Heads attending to previous token = ", prev_attn_detector),
    ("Heads attending to first token    = ", first_attn_detector),
):
    print(label, ", ".join(detector(cache)))


Heads attending to current token  =  0.9, 0.11, 1.6
Heads attending to previous token =  0.7
Heads attending to first token    =  0.3, 1.3, 1.4, 1.8, 1.10


In [9]:
# String form of each token in `text`, used to label the attention plots.
# (Variable name kept as-is; the tokenizer configured above is EleutherAI/gpt-neox-20b.)
gpt2_str_tokens = model.to_str_tokens(text)

# One attention-pattern display per layer. Layer and head counts are read from the model
# config rather than hard-coded, so the cell stays correct if the config changes.
for layer in range(model.cfg.n_layers):
    print(f"Layer {layer} Head Attention Patterns:")
    display(cv.attention.attention_patterns(
        tokens=gpt2_str_tokens,
        attention=cache["pattern", layer],
        attention_head_names=[f"L{layer}H{head}" for head in range(model.cfg.n_heads)],
    ))

Layer 0 Head Attention Patterns:


Layer 1 Head Attention Patterns:


## Exercise: plot per-token loss on repeated sequence