## Key Tutorials
- [Introduction to the Library and Mech Interp](https://arena3-chapter1-transformer-interp.streamlit.app/[1.2]_Intro_to_Mech_Interp)  
- [Demo of Main TransformerLens Features](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Main_Demo.ipynb#scrollTo=vFz9pMahYZJv)

Source code: [TransformerLens on GitHub](https://github.com/TransformerLensOrg/TransformerLens)

In [1]:
import os
DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab by probing for the `google.colab`
# package, which is expected to be importable only inside Colab.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
    IN_COLAB = True
    print("Running as a Colab notebook")
except ImportError:
    # Catch only the "module is not available" case. A bare `except:` would
    # also swallow unrelated failures (including KeyboardInterrupt) and
    # silently report "not in Colab".
    IN_COLAB = False

# Install if in Colab
# The third-party dependencies are installed only on Colab; in any other
# environment the later imports assume they are already installed.
if IN_COLAB:
    # `%pip` (rather than `!pip`) installs into the environment of the running kernel.
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    # NOTE(review): this pipes a remote setup script into a root shell and pins
    # Node 16.x — confirm that source and version are still the intended ones.
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
# NOTE(review): despite the comment above, the condition only checks IN_COLAB;
# DEVELOPMENT_MODE is not consulted here.
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    # Assuming `loaded` is the collection of already-loaded extension names,
    # this guard is true only when NO extension has been loaded yet, so
    # autoreload is skipped whenever any other extension is already active.
    # NOTE(review): confirm whether `'autoreload' not in ip.extension_manager.loaded`
    # was the intended check.
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
# True only when the GITHUB_ACTIONS environment variable is exactly "true".
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
# Earlier environment-dependent renderer selection, currently disabled in
# favour of the fixed assignment below:
# if IN_COLAB or not DEVELOPMENT_MODE:
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "notebook_connected"
pio.renderers.default = 'notebook_connected' # Setting this once is enough!
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv
# Testing that the library works
cv.examples.hello("keno")

Using renderer: notebook_connected


In [2]:
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

# Turn gradient tracking off: nothing in this notebook trains or
# backpropagates, it only runs forward passes.
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a heatmap on a red/blue colour scale centred at zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `show` call; `None` leaves the
            choice of renderer to plotly.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.imshow` (e.g. `title`, `x`).
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a line plot.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `show` call; `None` leaves the
            choice of renderer to plotly.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy` first.
        y: Vertical coordinates; converted with `utils.to_numpy` first.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to the figure's `show` call; `None` leaves the
            choice of renderer to plotly.
        **kwargs: Forwarded unchanged to `px.scatter`.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [3]:
# Let TransformerLens choose the device (the recorded run below used cuda:0).
device = utils.get_device()
# Load the pretrained "gpt2-small" weights into a HookedTransformer placed on that device.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# The same sentence repeated five times, so that later repetitions are
# predictable from earlier ones.
input_text = "Hello, my name is Keno." * 5
# return_type='loss' makes the forward pass return a loss tensor instead of logits
# (presumably the average next-token loss over the prompt — confirm).
loss = model(input_text, return_type='loss')
print("model loss: ", loss)

model loss:  tensor(1.0649, device='cuda:0')


In [5]:
# Tokenize the prompt once and expose the result under both names used by
# later cells, instead of running the tokenizer twice on the same text.
input_text_tokens = model.to_tokens(input_text)
gpt2_tokens = input_text_tokens
print(input_text_tokens)
# Run the model while caching its intermediate activations. With
# remove_batch_dim=True the cached tensors carry no leading batch axis
# (the attention pattern read from the cache later has shape [12, 41, 41]).
gpt2_logits, gpt2_cache = model.run_with_cache(input_text_tokens, remove_batch_dim=True)

tensor([[50256, 15496,    11,   616,  1438,   318,  7148,    78,    13, 15496,
            11,   616,  1438,   318,  7148,    78,    13, 15496,    11,   616,
          1438,   318,  7148,    78,    13, 15496,    11,   616,  1438,   318,
          7148,    78,    13, 15496,    11,   616,  1438,   318,  7148,    78,
            13]], device='cuda:0')


In [6]:
print(type(gpt2_cache))
# Look up the "pattern" activation of layer 0's attention block in the cache.
attention_pattern = gpt2_cache["pattern", 0, "attn"]
# head_index, destination_position, source_position
print(attention_pattern.shape)
# One string per token of the prompt, used to label the attention visualisation.
gpt2_str_tokens = model.to_str_tokens(input_text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 41, 41])


In [7]:
# Print the cache's string form with every ',' replaced by a newline, so each
# cached activation name ends up on its own line.
print(str(gpt2_cache).replace(',', '\n'))

ActivationCache with keys ['hook_embed'
 'hook_pos_embed'
 'blocks.0.hook_resid_pre'
 'blocks.0.ln1.hook_scale'
 'blocks.0.ln1.hook_normalized'
 'blocks.0.attn.hook_q'
 'blocks.0.attn.hook_k'
 'blocks.0.attn.hook_v'
 'blocks.0.attn.hook_attn_scores'
 'blocks.0.attn.hook_pattern'
 'blocks.0.attn.hook_z'
 'blocks.0.hook_attn_out'
 'blocks.0.hook_resid_mid'
 'blocks.0.ln2.hook_scale'
 'blocks.0.ln2.hook_normalized'
 'blocks.0.mlp.hook_pre'
 'blocks.0.mlp.hook_post'
 'blocks.0.hook_mlp_out'
 'blocks.0.hook_resid_post'
 'blocks.1.hook_resid_pre'
 'blocks.1.ln1.hook_scale'
 'blocks.1.ln1.hook_normalized'
 'blocks.1.attn.hook_q'
 'blocks.1.attn.hook_k'
 'blocks.1.attn.hook_v'
 'blocks.1.attn.hook_attn_scores'
 'blocks.1.attn.hook_pattern'
 'blocks.1.attn.hook_z'
 'blocks.1.hook_attn_out'
 'blocks.1.hook_resid_mid'
 'blocks.1.ln2.hook_scale'
 'blocks.1.ln2.hook_normalized'
 'blocks.1.mlp.hook_pre'
 'blocks.1.mlp.hook_post'
 'blocks.1.hook_mlp_out'
 'blocks.1.hook_resid_post'
 'blocks.2.hook_re

In [8]:
print("Layer 0 Head Attention Patterns:")
# Visualise the cached layer-0 attention pattern with circuitsvis, labelled
# with the prompt's string tokens. As the last expression of the cell, the
# returned object is what the notebook displays.
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [9]:
layer_to_ablate = 0
head_index_to_ablate = 8


# Forward hook that knocks out a single attention head.
# The type annotations are optional; they only document the expected layout.
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero out head `head_index_to_ablate` in the hooked activation.

    Args:
        value: Activation passed in by the hook point; annotated as
            (batch, pos, head_index, d_head). It is modified in place.
        hook: The hook point that invoked this function (unused).

    Returns:
        The same `value` tensor, with every entry of the selected head set to 0.
    """
    print(f"Shape of the value tensor: {value.shape}")
    # Indexing the third axis selects one head across all batches, positions
    # and head dimensions.
    value[:, :, head_index_to_ablate] = 0.0
    return value

# Baseline: loss of the unmodified model on the repeated prompt.
original_loss = model(gpt2_tokens, return_type="loss")
# Same forward pass, but with head_ablation_hook attached to the "v"
# activation of layer `layer_to_ablate` for this call.
ablated_loss = model.run_with_hooks(
    gpt2_tokens, 
    return_type="loss", 
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate), 
        head_ablation_hook
        )]
    )
# Compare the two losses to see how much ablating the head changes it.
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")

Shape of the value tensor: torch.Size([1, 41, 12, 64])
Original Loss: 1.065
Ablated Loss: 1.210


In [10]:
# Two prompts that differ only in the name of the person giving the milk:
# in the clean prompt Mary gives, in the corrupted prompt John gives.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Token ids for each prompt, as produced by the model's tokenizer.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return the final-position logit of the correct answer minus that of the incorrect one.

    Only the first batch element and the last sequence position of `logits`
    are read. Both answers are converted to token ids with the notebook-level
    `model`, so each must map to exactly one token.

    Args:
        logits: Model output, indexed as [batch, position, vocab].
        correct_answer: Text of the expected next token.
        incorrect_answer: Text of the competing next token.
    """
    correct_token_id = model.to_single_token(correct_answer)
    incorrect_token_id = model.to_single_token(incorrect_answer)
    final_position_logits = logits[0, -1]
    return final_position_logits[correct_token_id] - final_position_logits[incorrect_token_id]


# Clean run: keep the cache, since its activations are the source for patching later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Logit difference for clean prompt: {clean_logit_diff.item():.3f}")

# Corrupted run: only the logits are needed, so no cache is kept.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted Logit difference: {corrupted_logit_diff.item():.3f}")

Logit difference for clean prompt: 4.276
Corrupted Logit difference: -2.738


In [11]:
def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    position: int,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Overwrite one position of the hooked activation with its clean-run value.

    Args:
        resid_pre: Activation passed in by the hook point; annotated as
            (batch, pos, d_model). It is modified in place.
        hook: The calling hook point; `hook.name` is the key used to look up
            the matching activation in the notebook-level `clean_cache`.
        position: Sequence position to overwrite.

    Returns:
        The same `resid_pre` tensor with `position` replaced by the clean activation.
    """
    clean_activation_at_position = clean_cache[hook.name][:, position, :]
    resid_pre[:, position, :] = clean_activation_at_position
    return resid_pre

# Position-wise patching is only meaningful when both prompts tokenize to the
# same shape, so that position i refers to the same slot in the clean and the
# corrupted run. Fail loudly instead of producing a misaligned result.
assert clean_tokens.shape == corrupted_tokens.shape, (
    "clean and corrupted prompts must tokenize to the same shape"
)

num_positions = len(clean_tokens[0])
# One entry per (layer, position) pair that gets patched.
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

# Loop-invariant normalisation range: a result of 0 means the patch left the
# corrupted logit difference unchanged, 1 means it fully restored the clean one.
logit_diff_range = clean_logit_diff - corrupted_logit_diff

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    # The hook name depends only on the layer, so resolve it once per layer.
    resid_pre_hook_name = utils.get_act_name("resid_pre", layer)
    for position in range(num_positions):
        # Bind the position so the hook matches the (activation, hook) call signature.
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (resid_pre_hook_name, temp_hook_fn)
        ])

        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff) / logit_diff_range

  0%|          | 0/12 [00:00<?, ?it/s]

In [12]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
# Heatmap of the patching results: rows are layers, columns are token positions.
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the IOI Task")