In [1]:
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [15]:
import re

template = "The mother tongue of {} is"
candidates = [
    "JRR Tolkien is a native speaker of",
    "The mother tongue of Douglas Adams is",
    "Barack Obama, speaker of"]

# Convert template to regex pattern (escape and replace {})
template_regex = re.escape(template).replace(r"\{\}", r".+")
pattern = re.compile(f"^{template_regex}$")

matches = [pattern.fullmatch(c) for c in candidates]
print(matches)  # True means same structure

[None, <re.Match object; span=(0, 37), match='The mother tongue of Douglas Adams is'>, None]


In [9]:
import torch.nn.functional as F

corrupt_tokens = model.to_tokens(
    ["JRR Tolkien is a native speaker of",
    "The mother tongue of Douglas Adams is",
    "Barack Obama, speaker of"]).to(dtype=torch.float32)
original_tokens = model.to_tokens("The mother tongue of Daniel Radcliffe is").to(dtype=torch.float32)

print(original_tokens.dtype, corrupt_tokens.dtype)

similarities = F.cosine_similarity(original_tokens, corrupt_tokens)

print(similarities.shape)
print(similarities)

torch.float32 torch.float32
torch.Size([3])
tensor([0.8345, 0.6179, 0.7897], device='cuda:0')


In [21]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

embed.W_E
pos_embed.W_pos
blocks.0.attn.W_Q
blocks.0.attn.W_O
blocks.0.attn.b_Q
blocks.0.attn.b_O
blocks.0.attn.W_K
blocks.0.attn.W_V
blocks.0.attn.b_K
blocks.0.attn.b_V
blocks.0.mlp.W_in
blocks.0.mlp.b_in
blocks.0.mlp.W_out
blocks.0.mlp.b_out
blocks.1.attn.W_Q
blocks.1.attn.W_O
blocks.1.attn.b_Q
blocks.1.attn.b_O
blocks.1.attn.W_K
blocks.1.attn.W_V
blocks.1.attn.b_K
blocks.1.attn.b_V
blocks.1.mlp.W_in
blocks.1.mlp.b_in
blocks.1.mlp.W_out
blocks.1.mlp.b_out
blocks.2.attn.W_Q
blocks.2.attn.W_O
blocks.2.attn.b_Q
blocks.2.attn.b_O
blocks.2.attn.W_K
blocks.2.attn.W_V
blocks.2.attn.b_K
blocks.2.attn.b_V
blocks.2.mlp.W_in
blocks.2.mlp.b_in
blocks.2.mlp.W_out
blocks.2.mlp.b_out
blocks.3.attn.W_Q
blocks.3.attn.W_O
blocks.3.attn.b_Q
blocks.3.attn.b_O
blocks.3.attn.W_K
blocks.3.attn.W_V
blocks.3.attn.b_K
blocks.3.attn.b_V
blocks.3.mlp.W_in
blocks.3.mlp.b_in
blocks.3.mlp.W_out
blocks.3.mlp.b_out
blocks.4.attn.W_Q
blocks.4.attn.W_O
blocks.4.attn.b_Q
blocks.4.attn.b_O
blocks.4.attn.W_K
blocks.4.att