In [1]:
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import circuitsvis as cv

In [3]:
# Inference only: turn autograd off globally so forward passes don't build a gradient graph.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x131331940>

In [4]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a 2-D tensor as a diverging (RdBu) heatmap whose colour scale is centred on 0.

    `xaxis` / `yaxis` become the axis labels, any extra keyword arguments are passed
    straight to `px.imshow`, and `renderer` is forwarded to `Figure.show`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot a tensor as a plotly line chart.

    `xaxis` / `yaxis` label the axes, extra keyword arguments go to `px.line`,
    and `renderer` is forwarded to `Figure.show`.
    """
    values = utils.to_numpy(tensor)
    figure = px.line(values, labels=dict(x=xaxis, y=yaxis), **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x` with plotly.

    `xaxis`, `yaxis` and `caxis` label the x axis, y axis and colour legend; extra
    keyword arguments go to `px.scatter`, and `renderer` is forwarded to `Figure.show`.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = dict(x=xaxis, y=yaxis, color=caxis)
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [5]:
# Let TransformerLens pick the compute device (it resolved to 'mps' when this notebook was last run).
device = utils.get_device()
device

device(type='mps')

In [6]:
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

# Model names known to TransformerLens — presumably the set `from_pretrained` accepts; TODO confirm.
# The list is long, so the rendered output below is truncated.
OFFICIAL_MODEL_NAMES

['gpt2',
 'gpt2-medium',
 'gpt2-large',
 'gpt2-xl',
 'distilgpt2',
 'facebook/opt-125m',
 'facebook/opt-1.3b',
 'facebook/opt-2.7b',
 'facebook/opt-6.7b',
 'facebook/opt-13b',
 'facebook/opt-30b',
 'facebook/opt-66b',
 'EleutherAI/gpt-neo-125M',
 'EleutherAI/gpt-neo-1.3B',
 'EleutherAI/gpt-neo-2.7B',
 'EleutherAI/gpt-j-6B',
 'EleutherAI/gpt-neox-20b',
 'stanford-crfm/alias-gpt2-small-x21',
 'stanford-crfm/battlestar-gpt2-small-x49',
 'stanford-crfm/caprica-gpt2-small-x81',
 'stanford-crfm/darkmatter-gpt2-small-x343',
 'stanford-crfm/expanse-gpt2-small-x777',
 'stanford-crfm/arwen-gpt2-medium-x21',
 'stanford-crfm/beren-gpt2-medium-x49',
 'stanford-crfm/celebrimbor-gpt2-medium-x81',
 'stanford-crfm/durin-gpt2-medium-x343',
 'stanford-crfm/eowyn-gpt2-medium-x777',
 'EleutherAI/pythia-14m',
 'EleutherAI/pythia-31m',
 'EleutherAI/pythia-70m',
 'EleutherAI/pythia-160m',
 'EleutherAI/pythia-410m',
 'EleutherAI/pythia-1b',
 'EleutherAI/pythia-1.4b',
 'EleutherAI/pythia-2.8b',
 'EleutherAI/pyt

In [7]:
# Load Llama-3.1-8B-Instruct as a HookedTransformer on `device`.
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", device=device)
# Alternative checkpoints (swap in one of these to compare models):
# model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device=device)
# model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", device=device)

Downloading shards: 100%|██████████| 4/4 [06:46<00:00, 101.53s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


In [8]:
# Run a short, repetitive prompt through the model and cache every activation.
text = "one two three one two three one two three"
tokens = model.to_tokens(text)
print(tokens.device)  # sanity check: which device the token ids ended up on
# remove_batch_dim=True: cached activations lose the leading batch axis (a single prompt here).
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

mps:0


In [9]:
print(type(cache))
# Layer-0 attention pattern. With the batch axis removed it is (head_index, dest_pos, source_pos);
# the printed shape shows 32 heads over 10 token positions.
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)
# Token strings used as axis labels below — presumably the same tokenisation as `tokens`; TODO confirm.
str_tokens = model.to_str_tokens(text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([32, 10, 10])


In [10]:
print("Layer 0 Head Attention Patterns:")
# Interactive CircuitsVis view of every layer-0 head's attention over the prompt tokens.
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [11]:
def plot_head_detection_scores(
    scores: torch.Tensor,
    zmin: float = -1,
    zmax: float = 1,
    xaxis: str = "Head",
    yaxis: str = "Layer",
    title: str = "Head Matches"
) -> None:
    """Render a grid of per-head detection scores as a heatmap via `imshow`.

    The colour range is pinned to [zmin, zmax]. By default the rows are labelled
    "Layer", the columns "Head", and the figure is titled "Head Matches".
    """
    plot_options = dict(zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)
    imshow(scores, **plot_options)

In [12]:
# Prompts containing repeated sub-sequences, so an induction head has earlier tokens to match.
prompts = [
    "one two three one two three one two three",
    "1 2 3 4 5 1 2 3 4 1 2 3 1 2 3 4 5 6 7",
    "green ideas sleep furiously; green ideas don't sleep furiously"
]

In [13]:
from transformer_lens.head_detector import detect_head


# Score every head against TransformerLens's built-in "induction_head" pattern on `prompts`.
# exclude_bos / exclude_current_token are both False — presumably so attention to BOS and to the
# current token counts toward the score. NOTE(review): check the head_detector docs for exactly
# what error_measure="abs" computes.
head_scores = detect_head(model, prompts, "induction_head", exclude_bos=False, exclude_current_token=False, error_measure="abs")

In [14]:
# 2-D grid of head scores — presumably (n_layers, n_heads), matching the Layer/Head axes used when plotting.
head_scores

tensor([[1.0128e-01, 5.6669e-02, 9.2258e-02,  ..., 1.1806e-03, 4.6726e-03,
         3.0101e-08],
        [1.0698e-02, 6.3325e-03, 3.1475e-03,  ..., 5.7225e-03, 8.0285e-03,
         1.8843e-02],
        [8.7263e-03, 1.2040e-02, 1.1909e-02,  ..., 2.6295e-02, 3.3771e-03,
         3.7565e-03],
        ...,
        [3.1540e-03, 2.9499e-03, 3.4144e-03,  ..., 2.1878e-02, 8.9275e-02,
         7.3325e-02],
        [2.6566e-02, 7.4857e-02, 3.6811e-02,  ..., 5.4309e-02, 6.4347e-02,
         9.2232e-02],
        [3.5941e-02, 5.4023e-02, 2.6753e-02,  ..., 5.2371e-02, 5.8472e-02,
         2.6142e-02]])

In [None]:
# Take the prompt count from `prompts` itself so the title stays correct if the list is edited.
plot_head_detection_scores(head_scores, title=f"Induction head; average across {len(prompts)} prompts")

In [16]:
# A batch of random token sequences, each repeated twice back-to-back. In the second copy an
# induction head can predict the next token by attending to what followed the first occurrence.
batch_size = 10
seq_len = 50
size = (batch_size, seq_len)
torch.manual_seed(0)  # fixed seed so the induction scores below reproduce on Restart & Run All
# Token ids are drawn from [1000, 10000) — presumably to avoid low ids (special tokens); TODO confirm.
input_tensor = torch.randint(1000, 10000, size)

random_tokens = input_tensor.to(model.cfg.device)
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")

In [17]:
# Score every head for induction behaviour on `repeated_tokens`.
# The store lives on the model's device, so the hook never moves data between GPU and CPU.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def induction_score_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    """Write this layer's per-head induction scores into row `hook.layer()` of `induction_score_store`.

    `repeated_tokens` is a random sequence of length `seq_len` repeated twice. At position d of
    the second copy, an induction head should attend to source position d - (seq_len - 1): the
    token that followed the earlier occurrence of the current token. The score is the mean
    attention on that stripe, averaged over the batch and those destination positions.
    """
    # offset=1-seq_len selects pattern[..., d, d - (seq_len - 1)] for d = seq_len-1 .. 2*seq_len-1.
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)
    # Drop the first entry (d = seq_len-1, the last token of the FIRST copy). That token has
    # (barring chance collisions) no earlier occurrence, so attention there is not induction
    # and would only dilute the mean. The remaining seq_len entries are exactly the second copy.
    induction_stripe = induction_stripe[..., 1:]
    # Average over batch and positions -> one score per head.
    induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
    induction_score_store[hook.layer(), :] = induction_score


def pattern_hook_names_filter(name: str) -> bool:
    """True only for attention-pattern hook names (those ending in "pattern")."""
    return name.endswith("pattern")


model.run_with_hooks(
    repeated_tokens,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook,
    )],
)

imshow(induction_score_store, xaxis="Head", yaxis="Layer", title="Induction Score by Head")

In [18]:
# Visualise one head's attention pattern on a single repeated random sequence.
# NOTE(review): layer 8 / head 1 is hard-coded — confirm it scores highly in the
# induction-score heatmap above before reading much into this plot.
induction_head_layer = 8
induction_head_index = 1
size = (1, 20)
torch.manual_seed(1)  # fixed seed so the visualised sequence is reproducible
input_tensor = torch.randint(1000, 10000, size)

single_random_sequence = input_tensor.to(model.cfg.device)
repeated_random_sequence = einops.repeat(single_random_sequence, "batch seq_len -> batch (2 seq_len)")
# Axis-label strings, computed once here instead of inside the hook.
repeated_random_str_tokens = model.to_str_tokens(repeated_random_sequence)

def visualize_pattern_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    """Display head `induction_head_index`'s attention pattern for batch item 0 with CircuitsVis."""
    display(
        cv.attention.attention_patterns(
            tokens=repeated_random_str_tokens,
            # Add a dummy head axis, as CircuitsVis expects 3D (head, dest, source) patterns.
            attention=pattern[0, induction_head_index, :, :][None, :, :],
            # Name the head explicitly. Otherwise the lone sliced head would presumably get a
            # positional default name (index 0) rather than its real layer/head.
            attention_head_names=[f"L{induction_head_layer}H{induction_head_index}"],
        )
    )

model.run_with_hooks(
    repeated_random_sequence,
    return_type=None,
    fwd_hooks=[(
        utils.get_act_name("pattern", induction_head_layer),
        visualize_pattern_hook,
    )],
)