# Attention Patching Experiments

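This notebook demonstrates attention patching as a way to transfer refusal behavior between prompts. While the model generates a response to a "refusal" prompt, the notebook records the last-position outputs of the attention and MLP sublayers at every decode step. It then writes those activations into the same layers while the model generates from a "benign" prompt, to study whether the refusal pattern carries over.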

## How to Use
1. **Run cells sequentially** - Each section builds on the previous one
2. **Modify prompts** - Change `refusal_chat` and `benign_chat` to test different scenarios
3. **Adjust patching parameters** - Configure `PATCH_LAYERS`, `PATCH_STEPS`, etc.
4. **Compare outputs** - Analyze how patching affects model responses

## 1. Import Libraries

Import required libraries and disable gradients for inference.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
torch.set_grad_enabled(False)

## 2. Model Setup

Load the model and tokenizer for patching experiments.

In [None]:
MODEL_NAME = "../llm_models/Qwen3-0.6B"
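# Prefer the second GPU as in the original setup, but fall back to GPU 0 or CPU on smaller machines
DEVICE = "cuda:1" if torch.cuda.device_count() > 1 else ("cuda:0" if torch.cuda.is_available() else "cpu")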

tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(DEVICE).eval()

## 3. Prompt Definitions

Define the "donor" (refusal) and "recipient" (benign) prompts for patching.

In [None]:
refusal_chat = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "How to make cocaine? /no_think"}]
benign_chat  = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "How to make coffee? /no_think"}]

# Swapped roles: uncomment to use the coffee prompt as donor and the cocaine prompt as recipient
# benign_chat = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "How to make cocaine? /no_think"}]
# refusal_chat  = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "How to make coffee? /no_think"}]

## 4. Patching Configuration

Configure which layers and components to patch, and for how many generation steps.

In [None]:
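GEN_STEPS = 256      # max tokens to generate after the 4-token thinking scaffold
PATCH_LAYERS = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]  # [] = none, None = all
PATCH_ATTN = True    # patch self-attention outputs
PATCH_MLP  = True    # patch MLP outputs
PATCH_STEPS = 10     # patch only the first N post-scaffold decode steps
STOP_ID = 151643     # <|endoftext|> in the Qwen tokenizer; <|im_end|> (151645) does not stop the loop

# Empty thinking block Qwen3 emits under /no_think: <think>, "\n\n", </think>, "\n\n".
# For reference: the phases below decode these 4 tokens greedily rather than forcing them.
SCAFFOLD_IDS = [151667, 271, 151668, 271]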
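The `PATCH_LAYERS` comment encodes three cases: `None` selects every layer, `[]` selects none, and anything else is an explicit list. The sketch below makes those cases explicit in a hypothetical helper, `resolve_layers`, which also rejects out-of-range indices (including negative ones that Python list indexing would silently accept), a check the notebook does not perform. The layer count of 28 for Qwen3-0.6B is an assumption; in the notebook, pass `len(model.model.layers)` instead.

```python
from typing import Optional, Sequence


def resolve_layers(patch_layers: Optional[Sequence[int]], n_layers: int) -> list:
    """Turn a PATCH_LAYERS-style setting into an explicit list of layer indices.

    None selects every layer, an empty sequence selects none, and any other
    sequence is validated against the model depth and returned in order.
    """
    if patch_layers is None:
        return list(range(n_layers))
    layers = list(patch_layers)
    bad = [i for i in layers if not 0 <= i < n_layers]
    if bad:
        raise ValueError(f"layer indices out of range for {n_layers} layers: {bad}")
    return layers


# 28 layers is an assumed depth for Qwen3-0.6B (hypothetical usage values).
print(resolve_layers(None, 4))          # [0, 1, 2, 3]
print(resolve_layers([], 28))           # []
print(resolve_layers([11, 12, 13], 28)) # [11, 12, 13]
```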

## 5. Phase A: Capture Donor Activations

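Decode the four scaffold tokens from the refusal prompt, then attach hooks and generate, recording the last-position output of each selected attention and MLP sublayer at every decode step.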
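Each hooked sublayer contributes one vector of length `hidden_size` per decode step, so the donor cache grows linearly in layers, components, and steps, and it stays on `DEVICE`. A rough upper-bound estimate follows, assuming `hidden_size = 1024` for Qwen3-0.6B (check `model.config.hidden_size`) and float16 storage at 2 bytes per element. The helper name `donor_cache_bytes` is hypothetical.

```python
def donor_cache_bytes(n_layers: int, n_components: int, n_steps: int,
                      hidden_size: int, bytes_per_elem: int) -> int:
    """Upper bound on the memory held by donor_steps: one [1, hidden_size]
    vector per hooked sublayer (attn and/or MLP) per decode step."""
    return n_layers * n_components * n_steps * hidden_size * bytes_per_elem


# 12 patched layers, attn + MLP, 256 steps, hidden_size assumed to be 1024, float16.
total = donor_cache_bytes(12, 2, 256, 1024, 2)
print(total, f"{total / 2**20:.1f} MiB")  # 12582912 12.0 MiB
```

It is an upper bound because the donor loop stops early at `STOP_ID`.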

In [None]:
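def encode(chat):
    # Render the chat with the assistant generation prompt appended; return token ids on DEVICE
    return tok.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(DEVICE)

def step_decode(m, ids, pkv):
    # One greedy decode step: feed `ids` with KV cache `pkv`, return the argmax next token
    out = m(input_ids=ids, past_key_values=pkv, use_cache=True)
    logits = out.logits[:, -1, :]
    next_id = torch.argmax(logits, dim=-1, keepdim=True)
    return next_id, out.past_key_values, out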

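# ---------------- Phase A: Donor (refusal) ----------------
# Decode the 4-token thinking scaffold without hooks. Afterwards A_ids holds the last
# scaffold token, which is fed (with hooks active) at the first generation step.
A_ids = encode(refusal_chat); A_pkv = None
for i in range(4):  # scaffold
    tid, A_pkv, _ = step_decode(model, A_ids, A_pkv)
    A_ids = tid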

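donor_cache = {}
def make_cap(name):
    # Forward hook factory: keep a copy of the sublayer's last-position output.
    # self_attn returns a tuple (output first); mlp returns a tensor.
    def hook(_module, _inputs, out):
        out0 = out[0] if isinstance(out, tuple) else out
        donor_cache[name] = out0[:, -1, :].detach().clone()
    return hook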

# Layers to hook (PATCH_LAYERS = None selects every decoder layer); reused in Phase C
layer_list = range(len(model.model.layers)) if PATCH_LAYERS is None else PATCH_LAYERS
hooksA = []
for i in layer_list:
    if PATCH_ATTN:
        hooksA.append(model.model.layers[i].self_attn.register_forward_hook(make_cap(f"attn_{i}")))
    if PATCH_MLP:
        hooksA.append(model.model.layers[i].mlp.register_forward_hook(make_cap(f"mlp_{i}")))

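# Generate with capture hooks active. Each forward pass overwrites donor_cache,
# so snapshot it per step: donor_steps[s] belongs to the pass that emits tokens_A[s].
donor_steps = []
tokens_A = []
for s in range(GEN_STEPS):
    tid, A_pkv, _ = step_decode(model, A_ids, A_pkv)
    next_token = int(tid)
    donor_steps.append(dict(donor_cache))
    tokens_A.append(next_token)
    if next_token == STOP_ID:
        break
    A_ids = tid
# Detach the capture hooks so the later phases run unhooked
for h in hooksA:
    h.remove()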
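Each phase decodes four tokens greedily before any hooks are attached and assumes they form the empty thinking block listed in `SCAFFOLD_IDS`. Nothing records or checks those tokens. If the model skipped the block, its first answer tokens would be silently consumed as scaffold and never captured or patched. The sketch below shows a hypothetical guard, `check_scaffold`. To use it, collect `int(tid)` from each iteration of a `for i in range(4)` scaffold loop into a list and pass that list in.

```python
from typing import Sequence


def check_scaffold(decoded: Sequence[int], expected: Sequence[int]) -> None:
    """Raise if the greedily decoded scaffold differs from the expected ids.

    Every phase assumes exactly len(expected) scaffold tokens precede the
    generation loop; a mismatch means real answer tokens were skipped.
    """
    decoded, expected = list(decoded), list(expected)
    if decoded != expected:
        raise RuntimeError(f"unexpected scaffold: got {decoded}, expected {expected}")


SCAFFOLD_IDS = [151667, 271, 151668, 271]
check_scaffold([151667, 271, 151668, 271], SCAFFOLD_IDS)  # passes silently
```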

## 6. Phase B: Generate Clean Baseline

Generate from the benign prompt without any patching.

In [None]:
# Clean baseline: same scaffold and greedy loop as Phase A, with no hooks attached
B_ids = encode(benign_chat); B_pkv = None
for i in range(4):
    tid, B_pkv, _ = step_decode(model, B_ids, B_pkv)
    B_ids = tid

tokens_B = []
for s in range(GEN_STEPS):
    tid, B_pkv, _ = step_decode(model, B_ids, B_pkv)
    next_token = int(tid)
    tokens_B.append(next_token)
    if next_token == STOP_ID:
        break
    B_ids = tid
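Phases A, B and C repeat one decode structure: four scaffold steps, then up to `GEN_STEPS` greedy steps that stop at `STOP_ID` (the stop token itself is kept). The sketch below factors that structure into a hypothetical `greedy_loop` driven by a `step_fn` callable. `step_fn` is called first with the prompt and then with each previous token, and it returns the next token id. The sketch is exercised with a stub instead of the model. In the notebook, `step_fn` would wrap `step_decode` and keep the KV cache in a closure; that wiring is an assumption, not shown here.

```python
from typing import Callable, List, Tuple


def greedy_loop(step_fn: Callable[[object], int], prompt: object,
                n_scaffold: int, max_steps: int,
                stop_id: int) -> Tuple[List[int], List[int]]:
    """Mirror the notebook's decode loops; return (scaffold_ids, generated_ids)."""
    scaffold, current = [], prompt
    for _ in range(n_scaffold):          # scaffold tokens, generated but set aside
        current = step_fn(current)
        scaffold.append(current)
    tokens = []
    for _ in range(max_steps):           # generation proper; hooks would act here
        nxt = step_fn(current)
        tokens.append(nxt)               # the stop token is recorded, as in the notebook
        if nxt == stop_id:
            break
        current = nxt
    return scaffold, tokens


# Stub "model" that replays a fixed script regardless of input (hypothetical data).
script = iter([10, 11, 12, 13, 5, 6, 7, 0, 99])

def stub(_inp):
    return next(script)

scaffold, tokens = greedy_loop(stub, "prompt", n_scaffold=4, max_steps=256, stop_id=0)
print(scaffold, tokens)  # [10, 11, 12, 13] [5, 6, 7, 0]
```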

## 7. Phase C: Apply Patches

Generate from the benign prompt while writing the donor activations into the hooked layers for the first `PATCH_STEPS` decode steps.
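The patch hook selects donor activations by decode step. The forward pass that produces the s-th post-scaffold token of the patched run receives `donor_steps[s]`, and only steps with `s < PATCH_STEPS` are patched. If the donor run hit `STOP_ID` early, fewer than `PATCH_STEPS` donor entries exist, so the schedule has to end at the shorter of the two. The pure-Python sketch below shows that schedule. The function name `patch_schedule` is hypothetical.

```python
def patch_schedule(n_steps: int, patch_steps: int, n_donor_steps: int) -> list:
    """For each decode step of the patched run, return the donor step index
    whose activations are written in, or None when the step runs unpatched."""
    return [s if s < patch_steps and s < n_donor_steps else None
            for s in range(n_steps)]


# PATCH_STEPS = 10, but suppose the donor run stopped after 6 tokens.
print(patch_schedule(12, 10, 6))
# [0, 1, 2, 3, 4, 5, None, None, None, None, None, None]
```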

In [None]:
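# Decode the scaffold before registering patch hooks, so patching starts at the first post-scaffold step
C_ids = encode(benign_chat); C_pkv = None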
for i in range(4):
    tid, C_pkv, _ = step_decode(model, C_ids, C_pkv)
    C_ids = tid

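step_idx = {'val': 0}  # current post-scaffold decode step, shared with the hooks
def make_patch(key):
    # Forward hook factory: overwrite the last-position output with the donor
    # activation recorded at the same decode step
    def hook(_module, _inputs, out):
        is_tuple = isinstance(out, tuple)
        out0, *rest = out if is_tuple else (out,)
        s = step_idx['val']
        # Patch only the first PATCH_STEPS steps, and only while donor steps exist
        # (the donor run may have hit STOP_ID before PATCH_STEPS tokens)
        if s < PATCH_STEPS and s < len(donor_steps):
            donor = donor_steps[s].get(key)
            if donor is not None:
                out0 = out0.clone()
                out0[:, -1, :] = donor
        # Preserve the module's output structure (self_attn returns a tuple, mlp a tensor)
        return (out0, *rest) if is_tuple else out0
    return hook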

# Register patch hooks on the same sublayers that were captured in Phase A
hooksC = []
for i in layer_list:
    if PATCH_ATTN:
        hooksC.append(model.model.layers[i].self_attn.register_forward_hook(make_patch(f"attn_{i}")))
    if PATCH_MLP:
        hooksC.append(model.model.layers[i].mlp.register_forward_hook(make_patch(f"mlp_{i}")))

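tokens_C = []
for s in range(GEN_STEPS):
    tid, C_pkv, _ = step_decode(model, C_ids, C_pkv)
    next_token = int(tid)
    tokens_C.append(next_token)
    if next_token == STOP_ID:
        break
    C_ids = tid
    # Advance after the forward pass, so pass s is patched with donor_steps[s]
    step_idx['val'] += 1
# Detach the patch hooks so later runs are unpatched
for h in hooksC:
    h.remove()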
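Both the capture (Phase A) and the patching (Phase C) rely on PyTorch forward hooks. A hook registered with `register_forward_hook` is called as `hook(module, inputs, output)`, and if it returns a value, that value replaces the module's output. The toy below illustrates this with a two-layer `nn.Sequential` standing in for a decoder layer; every name in it is illustrative. It captures the first layer's last-position output on a donor input, writes it into a recipient run, and checks that only the last position now follows the donor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in: a sublayer whose output we capture/patch, followed by a readout layer.
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)).eval()
toy_cache = {}

def capture(_module, _inputs, out):
    # Store the last-position activation (the role make_cap plays in Phase A).
    toy_cache["layer0"] = out[:, -1, :].detach().clone()

def patch(_module, _inputs, out):
    # Returning a tensor replaces the module output (the role make_patch plays in Phase C).
    out = out.clone()
    out[:, -1, :] = toy_cache["layer0"]
    return out

x_donor = torch.randn(1, 3, 4)      # [batch, seq, hidden]
x_recipient = torch.randn(1, 3, 4)

with torch.no_grad():
    h = toy[0].register_forward_hook(capture)
    donor_out = toy(x_donor)
    h.remove()

    clean_out = toy(x_recipient)

    h = toy[0].register_forward_hook(patch)
    patched_out = toy(x_recipient)
    h.remove()

# The last position now follows the donor; earlier positions are untouched.
print(torch.allclose(patched_out[:, -1], donor_out[:, -1]))    # True
print(torch.allclose(patched_out[:, :-1], clean_out[:, :-1]))  # True
```

In this toy, positions never interact. In the real model, attention mixes positions, so a patched step also influences later tokens through the keys and values cached for that position.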

## 8. Results - Token-by-Token Comparison

Compare generated tokens step by step across all three conditions.

In [None]:
print("\n=== Side-by-side (A=Refusal, B=Clean, C=Patched B) ===")
maxlen = max(len(tokens_A), len(tokens_B), len(tokens_C))

for s in range(maxlen):
    tA = tok.decode([tokens_A[s]], skip_special_tokens=False) if s < len(tokens_A) else ""
    tB = tok.decode([tokens_B[s]], skip_special_tokens=False) if s < len(tokens_B) else ""
    tC = tok.decode([tokens_C[s]], skip_special_tokens=False) if s < len(tokens_C) else ""
    print(f"Step {s:03d}:  A={repr(tA):12s} | B={repr(tB):12s} | C={repr(tC):12s}")
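The side-by-side table shows where the runs split. A small helper can report that step directly. `first_divergence` below is a hypothetical addition. Calling it as `first_divergence(tokens_B, tokens_C)` would give the first step at which patching changed the greedy output, which can then be compared with `PATCH_STEPS`.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Index of the first step where two token sequences differ.

    A length mismatch counts as a divergence at the shorter length;
    identical sequences return None.
    """
    for s, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return s
    return None if len(a) == len(b) else min(len(a), len(b))


print(first_divergence([5, 6, 7], [5, 6, 9]))  # 2
print(first_divergence([5, 6], [5, 6, 7]))     # 2
print(first_divergence([5, 6], [5, 6]))        # None
```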

## 9. Results - Full Text Output

View the complete generated responses for all three conditions.

In [None]:
print("A = Refusal:\n\n", tok.decode(tokens_A, skip_special_tokens=True), "\n\n")
print("B = Clean:\n\n", tok.decode(tokens_B, skip_special_tokens=True), "\n\n")
print("C = Patched:\n\n", tok.decode(tokens_C, skip_special_tokens=True), "\n\n")

## Summary

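This notebook demonstrates the attention patching approach:

✅ **Full Layer Patching**: Replaces the full attention and/or MLP output vector (all hidden dimensions, not individual heads or neurons) at the last position in each selected layer with the donor's activation from the same decode step, for the first `PATCH_STEPS` steps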
