In [1]:
## Activation Patching
## Used to understand how different parts of a model contribute to its behavior:
## we modify ("patch") the activations of selected model components and observe the impact on the model's output.

import nnsight
from nnsight import CONFIG

from nnsight import LanguageModel,util

# Load the checkpoint through nnsight so its submodules can be traced and patched below.
llm = LanguageModel(
    "Qwen/Qwen3-0.6B",
    device_map="cuda",
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dump the module hierarchy to find the attribute paths used later
# (llm.model.layers[i], llm.lm_head).
module_tree = str(llm)
print(module_tree)

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [3]:
# IOI prompt pair: the two prompts differ only in who "gave" the milk, so the
# expected completion flips between " John" (clean) and " Mary" (corrupted).
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = (
    "After John and Mary went to the store, John gave a bottle of milk to"
)


def _single_token_id(text):
    """Return the token id of `text`, asserting it encodes to exactly one token.

    The logit-difference metric reads a single vocabulary column per answer, so a
    multi-token answer would silently compare the wrong logits.
    ``add_special_tokens=False`` keeps a BOS token (for tokenizers that prepend
    one) from being mistaken for the answer token.
    """
    ids = llm.tokenizer(text, add_special_tokens=False)["input_ids"]
    assert len(ids) == 1, f"{text!r} tokenizes to {len(ids)} tokens: {ids}"
    return ids[0]


correct_index = _single_token_id(" John")  # leading space matters for BPE tokenizers
incorrect_index = _single_token_id(" Mary")

print(f"' John': {correct_index}")
print(f"' Mary': {incorrect_index}")

' John': 3757
' Mary': 10244


In [4]:
# Number of decoder layers; sizes the per-layer caches and the patching sweep.
N_LAYERS=len(llm.model.layers)

# Clean run: cache the prompt's token ids, every decoder layer's output hidden
# states, and the clean logit difference so the patching runs can reuse them.
with llm.trace(clean_prompt) as tracer:
    # Token ids of the clean prompt as fed to the model.
    # NOTE(review): the indexing presumably selects the first invoke's input_ids
    # and its first batch row — confirm against the installed nnsight version.
    clean_tokens=tracer.invoker.inputs[0][0]['input_ids'][0] 
    # output[0] of each decoder layer is used as the residual-stream hidden states.
    # NOTE(review): assumes the layer returns a tuple whose first element is a
    # (batch, seq, hidden) tensor — verify for the installed transformers version.
    clean_hs=[
        llm.model.layers[layer_idx].output[0].save() for layer_idx in range(N_LAYERS)
    ]
    clean_logits=llm.lm_head.output
    # Logit difference at the last position: " John" minus " Mary".
    # Positive means the model prefers " John".
    clean_logit_diff=(
        clean_logits[0,-1,correct_index] - clean_logits[0,-1,incorrect_index]
    ).save()

In [5]:
# Corrupted run: baseline logit difference (" John" minus " Mary" at the last
# position). Each patched run is normalized against this value.
with llm.trace(corrupted_prompt) as tracer:
    corrupted_logits=llm.lm_head.output
    
    corrupted_logit_diff=(
        corrupted_logits[0,-1,correct_index] - corrupted_logits[0,-1,incorrect_index]
    ).save()

In [6]:
# Sweep every (layer, position) pair. At each pair, rerun the corrupted prompt
# with the clean activation patched in at exactly that spot, then record how much
# of the clean-vs-corrupted logit-difference gap is recovered:
#   0 -> patched run matches the corrupted baseline (no effect)
#   1 -> patched run matches the clean run (fully restored)

# Position-wise patching only makes sense if both prompts share the same token
# layout; otherwise token_idx would point at different words in the two runs.
corrupted_token_count = len(llm.tokenizer(corrupted_prompt)["input_ids"])
assert corrupted_token_count == len(clean_tokens), (
    f"clean prompt has {len(clean_tokens)} tokens but corrupted prompt has "
    f"{corrupted_token_count}; positions would not line up"
)

ioi_patching_results = []  # one row per layer, one entry per token position

for layer_idx in range(N_LAYERS):
    _ioi_patching_results = []

    # Iterate through all token positions.
    for token_idx in range(len(clean_tokens)):

        # Patch the corrupted run at this layer and token position.
        with llm.trace(corrupted_prompt) as tracer:
            # NOTE(review): output[0] is treated as a (batch, seq, hidden) tensor,
            # matching how clean_hs was cached — confirm for the installed transformers.
            llm.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]
            patched_logits = llm.lm_head.output

            patched_logit_diff = (
                patched_logits[0, -1, correct_index] - patched_logits[0, -1, incorrect_index]
            )
            # Normalize so the corrupted baseline maps to 0 and the clean run maps to 1.
            patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                clean_logit_diff - corrupted_logit_diff
            )
            _ioi_patching_results.append(patched_result.item().save())

    ioi_patching_results.append(_ioi_patching_results)

In [7]:
from nnsight.tracing.graph import Proxy
import plotly.express as px
import plotly.io as pio

def plot_ioi_patching_results(model,
                              ioi_patching_results,
                              x_labels,
                              plot_title="Normalized Logit Difference After Patching Residual Stream on the IOI Task"):
    """Render a layer-by-position heatmap of normalized patching scores.

    Args:
        model: Not used by the function body; kept in the signature for callers.
        ioi_patching_results: Nested list of nnsight ``Proxy`` objects. Each is
            unwrapped to its ``.value`` before plotting.
        x_labels: Labels for the position axis, one per column.
        plot_title: Title shown above the figure.

    Returns:
        The plotly figure. The caller decides when to show it.
    """
    # Replace every Proxy in the nested structure with its concrete value so
    # plotly receives plain numbers.
    unwrapped_scores = util.apply(ioi_patching_results, lambda proxy: proxy.value, Proxy)

    axis_labels = {"x": "Position", "y": "Layer", "color": "Norm. Logit Diff"}
    return px.imshow(
        unwrapped_scores,
        x=x_labels,
        labels=axis_labels,
        title=plot_title,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,  # center the diverging colour scale at 0
    )

In [8]:
print(f"Clean logit difference: {clean_logit_diff:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff:.3f}")

# Label each heatmap column "<decoded token>_<position>" so repeated words
# (e.g. the two " Mary" tokens) stay distinguishable.
clean_decoded_tokens = [llm.tokenizer.decode(token) for token in clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_decoded_tokens)]

# Derive the display name from the loaded checkpoint (e.g. "Qwen/Qwen3-0.6B" ->
# "Qwen3-0.6B") instead of hard-coding it. The previous title said "GPT-2-small",
# which did not match the model actually being patched.
model_display_name = llm.tokenizer.name_or_path.split("/")[-1] or "model"
fig = plot_ioi_patching_results(
    llm,
    ioi_patching_results,
    token_labels,
    f"Patching {model_display_name} Residual Stream on IOI task",
)
fig.show()

Clean logit difference: 4.330
Corrupted logit difference: -5.813


In [9]:
# Show each clean-prompt position next to its decoded token to help read the heatmap axis.
for position, token_id in enumerate(clean_tokens):
    decoded_piece = llm.tokenizer.decode(token_id)
    print(position, repr(decoded_piece))


0 'After'
1 ' John'
2 ' and'
3 ' Mary'
4 ' went'
5 ' to'
6 ' the'
7 ' store'
8 ','
9 ' Mary'
10 ' gave'
11 ' a'
12 ' bottle'
13 ' of'
14 ' milk'
15 ' to'
