In [1]:
%pip install ipykernel setuptools transformer_lens plotly circuitsvis nbformat

Note: you may need to restart the kernel to use updated packages.


# Mech Interp introduction

## Transformer lens introduction

### Load and run a HookedTransformer

In [6]:
from transformer_lens import EasyTransformer, EasyTransformerConfig
from tqdm.autonotebook import tqdm as notebook_tqdm
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.mps.empty_cache()
torch.set_default_device("mps")
cfg = EasyTransformerConfig(
    d_model=768,
    d_head=64,
    n_layers=2,
    n_ctx=128,
    d_mlp=3072,
    d_vocab=300,
    act_fn="gelu",
)
model = EasyTransformer(cfg=cfg)

### Understand basic architecture of the model

In [7]:
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

### Use the model's tokenizer to convert text to tokens, and vice versa

In [8]:
def data_generator(batch_size: int, inc_bos_token: bool=True):
    torch.manual_seed(137)
    while True:
        x = torch.randint(1, model.cfg.d_vocab, (batch_size, model.cfg.n_ctx))
        if inc_bos_token: x[:, 0] = 0
        yield x

def loss_fn(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # tokens: [batch, position]
    # logits: [batch, position, d_vocab]
    
    tokens = tokens[:, :-1] # ignore last token
    logits = logits[:, 1:]  # ignore first token logit
    
    log_probs = logits.log_softmax(dim=-1) # [batch, position, d_vocab]
    log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    return -log_probs.mean()

In [9]:
# text = """A core belief I have about learning mechanistic interpretability is that you should spend at least a third of your time writing code and playing around with model internals, not just reading papers. MI has great feedback loops, and a large component of the skillset is the practical, empirical skill of being able to write and run experiments easily. Unlike normal machine learning, once you have the basics of MI down, you should be able to run simple experiments on small models within minutes, not hours or days. Further, because the feedback loops are so tight, I don’t think there's a sharp boundary between reading and doing research. If you want to deeply engage with a paper, you should be playing with the model studied, and testing the paper's basic claims."""
# tokens = model.to_tokens(text)
# print(tokens)
# print(model.to_string(tokens))

### Know how to cache activations, and to access activations from the cache

In [12]:
tokens = next(data_generator(1))
logits, cache = model.run_with_cache(tokens)
for k in cache:
    print(f"{k} {cache[k].shape}")
print(logits.shape)

# next_token_batch = logits.argmax(dim=-1)[0]
# batch_decoded = model.tokenizer.batch_decode(next_token_batch)
# result = list(zip(model.to_str_tokens(text), batch_decoded))
# print(result)

hook_embed torch.Size([1, 128, 768])
hook_pos_embed torch.Size([1, 128, 768])
blocks.0.hook_resid_pre torch.Size([1, 128, 768])
blocks.0.ln1.hook_scale torch.Size([1, 128, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 128, 768])
blocks.0.attn.hook_q torch.Size([1, 128, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 128, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 128, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 128, 128])
blocks.0.attn.hook_pattern torch.Size([1, 12, 128, 128])
blocks.0.attn.hook_z torch.Size([1, 128, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 128, 768])
blocks.0.hook_resid_mid torch.Size([1, 128, 768])
blocks.0.ln2.hook_scale torch.Size([1, 128, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 128, 768])
blocks.0.mlp.hook_pre torch.Size([1, 128, 3072])
blocks.0.mlp.hook_post torch.Size([1, 128, 3072])
blocks.0.hook_mlp_out torch.Size([1, 128, 768])
blocks.0.hook_resid_post torch.Size([1, 128, 768])
blocks.1.hook_resid_pre torch.Size([1, 128, 768])
b

### Use circuitsvis to visualise attention heads

In [13]:
from transformer_lens.utils import to_numpy
import plotly.express as px

def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    labels = {"y":yaxis, "x":xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()
    
def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {"color_continuous_scale":"RdBu", "color_continuous_midpoint":0.0, "labels":{"x":xaxis, "y":yaxis}}
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

In [14]:
from circuitsvis.attention import attention_heads, attention_pattern

imshow(cache["attn", 0].mean([0, 1]), title="Layer 0 Attention Pattern", height=500, width=500)
imshow(cache["attn", 1].mean([0, 1]), title="Layer 1 Attention Pattern", height=500, width=500)

In [None]:
### Train the model

In [15]:
batch_size = 8
epochs = 3000
lr = 1e-4
max_grad_norm = 1.0
wd = 0.1
betas = (0.9, 0.95)
optim = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda i: max(i/100, 1.))

data_loader = data_generator(batch_size=batch_size)

losses = []
model.zero_grad()
for epoch in notebook_tqdm(range(epochs)):
    tokens = next(data_loader)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optim.step()
    optim.zero_grad()
    scheduler.step()
    losses.append(loss.item())
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}: {loss.item()}")

px.line(losses, labels={"x": "Epoch", "y": "Loss"})

  0%|          | 2/3000 [00:00<17:31,  2.85it/s]

Epoch: 0: 5.986537933349609


  3%|▎         | 102/3000 [00:16<07:46,  6.21it/s]

Epoch: 100: 3.2244489192962646


  7%|▋         | 202/3000 [00:33<07:31,  6.20it/s]

Epoch: 200: 0.0015244388487190008


 10%|█         | 302/3000 [00:49<07:12,  6.23it/s]

Epoch: 300: 0.00013035595475230366


 13%|█▎        | 402/3000 [01:05<06:57,  6.22it/s]

Epoch: 400: 8.215304660552647e-06


 17%|█▋        | 502/3000 [01:21<06:41,  6.22it/s]

Epoch: 500: 5.797370477012009e-07


 20%|██        | 602/3000 [01:37<06:26,  6.21it/s]

Epoch: 600: 3.121029834574074e-08


 23%|██▎       | 702/3000 [01:53<06:10,  6.21it/s]

Epoch: 700: 1.1615863826364148e-08


 27%|██▋       | 802/3000 [02:10<05:54,  6.21it/s]

Epoch: 800: 1.055987652875956e-08


 30%|███       | 902/3000 [02:26<05:37,  6.22it/s]

Epoch: 900: 1.1029204216583821e-08


 33%|███▎      | 1002/3000 [02:42<05:20,  6.22it/s]

Epoch: 1000: 1.0911872294627756e-08


 37%|███▋      | 1102/3000 [02:58<05:05,  6.21it/s]

Epoch: 1100: 1.4783827495534752e-08


 40%|████      | 1202/3000 [03:14<04:48,  6.23it/s]

Epoch: 1200: 1.5487819027271144e-08


 43%|████▎     | 1302/3000 [03:30<04:33,  6.21it/s]

Epoch: 1300: 1.2437188168235025e-08


 47%|████▋     | 1402/3000 [03:46<04:18,  6.19it/s]

Epoch: 1400: 1.0090547952756879e-08


 50%|█████     | 1502/3000 [04:02<04:01,  6.21it/s]

Epoch: 1500: 1.1381199982452017e-08


 53%|█████▎    | 1602/3000 [04:19<03:44,  6.22it/s]

Epoch: 1600: 1.1967859592232344e-08


 57%|█████▋    | 1702/3000 [04:35<03:29,  6.20it/s]

Epoch: 1700: 1.3493175465839613e-08


 60%|██████    | 1802/3000 [04:51<03:12,  6.21it/s]

Epoch: 1800: 1.1615863826364148e-08


 63%|██████▎   | 1902/3000 [05:07<02:56,  6.22it/s]

Epoch: 1900: 7.626577591679506e-09


 67%|██████▋   | 2002/3000 [05:23<02:40,  6.22it/s]

Epoch: 2000: 6.5705898499857085e-09


 70%|███████   | 2102/3000 [05:39<02:26,  6.12it/s]

Epoch: 2100: 1.1029204216583821e-08


 73%|███████▎  | 2202/3000 [05:55<02:08,  6.23it/s]

Epoch: 2200: 7.661776635359274e-08


 77%|███████▋  | 2302/3000 [06:12<01:54,  6.12it/s]

Epoch: 2300: 3.5316919166916705e-08


 80%|████████  | 2402/3000 [06:28<01:37,  6.14it/s]

Epoch: 2400: 2.3583723418596492e-08


 83%|████████▎ | 2502/3000 [06:45<01:19,  6.24it/s]

Epoch: 2500: 1.4783827495534752e-08


 87%|████████▋ | 2602/3000 [07:01<01:05,  6.12it/s]

Epoch: 2600: 3.519958768904985e-09


 90%|█████████ | 2702/3000 [07:17<00:48,  6.16it/s]

Epoch: 2700: 8.79989769941858e-09


 93%|█████████▎| 2802/3000 [07:33<00:31,  6.20it/s]

Epoch: 2800: 3.6607570308433424e-08


 97%|█████████▋| 2902/3000 [07:50<00:15,  6.21it/s]

Epoch: 2900: 1.1733195748320213e-08


100%|██████████| 3000/3000 [08:06<00:00,  6.17it/s]


In [18]:
big_data_loader = data_generator(10)
big_tokens = next(big_data_loader)
logits, cache = model.run_with_cache(big_tokens)
print(f"Logits: {loss_fn(logits, big_tokens).item()}")

imshow(cache["attn", 0].mean([0, 1]), title="Layer 0 Attention Pattern", height=500, width=500)
imshow(cache["attn", 1].mean([0, 1]), title="Layer 1 Attention Pattern", height=500, width=500)

Logits: 5.2189257360168995e-08
