In [1]:
import wandb
import torch
import hydra
import random
import numpy as np
from trl import setup_chat_format
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, PeftModel, PeftConfig
from omegaconf import DictConfig, OmegaConf
from trl import (
    SFTTrainer,
    SFTConfig,
    setup_chat_format,
    DPOConfig,
    DPOTrainer
)
from datasets import load_dataset
from peft import PeftModelForCausalLM
from tqdm import tqdm
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory handed to `from_pretrained` when the tokenizer and model are loaded.
# NOTE(review): this is an absolute path on one machine; adjust it before re-running elsewhere.
FINE_TUNE_PATH = '/Users/remak/Documents/misc/lora_dpo/experiments/llama_3.2_1B_sft/checkpoint-200'
SEED = 42

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the long-standing MPS probe; the
# `torch.mps.is_available` shortcut does not exist on older PyTorch releases.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f'Running on {device}')

Running on mps


In [3]:
# Seed torch, the stdlib RNG and NumPy (in that order) with the shared SEED.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

# Tokenizer and model both come from the fine-tuned checkpoint directory;
# the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(FINE_TUNE_PATH)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(FINE_TUNE_PATH)
model = model.to(device)

In [4]:
# Initialise the W&B run; `mode='disabled'` is passed so the notebook runs without logging.
wandb_settings = dict(project='lora_mechinterp', entity='keram', mode='disabled')
wandb.init(**wandb_settings)

In [5]:
# Put the model in evaluation mode before any forward passes are run.
model.eval()
# Load the train split of the Alpaca dataset; the texts analysed in later cells
# are taken from its 'output' column (despite the variable being called `prompts`).
eval_dataset = load_dataset('tatsu-lab/alpaca', split='train')
prompts = eval_dataset['output']

In [None]:
# Most recent module outputs captured by hooks, keyed by the name given to `get_activation`.
activation = {}

def get_activation(name):
    """Build a forward hook that stores the hooked module's detached output under `name`.

    The returned callable has the (module, input, output) signature expected by
    `register_forward_hook`; each call overwrites `activation[name]`.
    """
    def record(module, input, output):
        activation.update({name: output.detach()})
    return record

print('Pasted below are the top-5 (by absolute value) activations from a learned low-rank adapter for a number of input sequences. Specifically, the activations come from x @ A_matrix space within the q_proj matrix (Query matrix of the self-attention mechanism) of the self attention mechanisms in a llama 3.2 model, in the first self-attention layer. Please find an interpretation of the features. NOTE: The features might exist in superposition, as well as encoding different concepts with a negative value than with a positive value. Please analyse the positive and negative values separately. Additionally, below the activations are the top-5 features by lora B_matrix weight magnitude that feed into each of the 32 attention heads. Please try to interpret the attention heads as well.\n')
activations_list = []

# Number of features reported per token and per head. The loop below used this
# name without it being assigned anywhere in the notebook (NameError on a fresh
# kernel); 5 matches the "top-5" wording of the header printed above.
TOPK = 5

layers = model.base_model.layers
LAYER = 0
idx_ext = LAYER
layer = layers[LAYER]
hook_key = f'attn_{idx_ext}'

# Capture the output of the LoRA A projection inside q_proj for the chosen layer.
lora_A = layer.self_attn.q_proj.lora_A
hook = lora_A['default'].register_forward_hook(get_activation(hook_key))

# try/finally (rather than `except Exception` plus a trailing remove) guarantees the
# hook is detached on every exit path, including KeyboardInterrupt.
try:
    # Inference only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for prompt in prompts[:15]:
            encoded = tokenizer([prompt], padding=True, truncation=True, return_tensors='pt').to(device)
            _ = model(**encoded)
            # assumes the adapter output is (batch=1, seq, rank), so this is (seq, rank) - TODO confirm
            latents = activation[hook_key].squeeze(0)
            k = min(TOPK, latents.size(1))  # never ask topk for more features than exist
            idxs = latents.abs().topk(k, dim=1).indices
            values = torch.gather(latents, dim=1, index=idxs)  # signed values at the top-|.| positions
            token_strs = [tokenizer.decode(t) for t in encoded['input_ids'].tolist()[0]]
            for idx, magnitude, token in zip(idxs.tolist(), values.tolist(), token_strs):
                feats = [f'f{i}: {str(round(m, 2))}' for i, m in zip(idx, magnitude)]
                print(f'Token: \"{token}\" → {", ".join(feats)}')
            print()
finally:
    hook.remove()

# Static view: for every attention head, rank adapter features by the L2 norm of
# their column inside that head's slice of the LoRA B weight.
layer   = model.base_model.layers[LAYER]
q_proj  = layer.self_attn.q_proj
B_mod   = q_proj.lora_B["default"]
B       = B_mod.weight
d_head  = layer.self_attn.head_dim
# Derive the head count from the weight itself so the reshape below is always
# consistent with it (previously hard-coded to 32).
n_heads = B.size(0) // d_head
r       = B.size(1)

B_by_head = B.detach().view(n_heads, d_head, r)

for h in range(n_heads):
    scores = B_by_head[h].norm(dim=0)          # one score per adapter feature
    top    = torch.topk(scores, min(TOPK, r))
    feats  = [f"f{int(i)}:{scores[i]:.3f}" for i in top.indices]
    print(f"head {h:02}:  {', '.join(feats)}")


Pasted below are the top-5 (by absolute value) activations from a learned low-rank adapter for a number of input sequences. Specifically, the activations come from x @ A_matrix space within the q_proj matrix (Query matrix of the self-attention mechanism) of the self attention mechanisms in a llama 3.2 model, in the first self-attention layer. Please find an interpretation of the features. NOTE: The features might exist in superposition, as well as encoding different concepts with a negative value than with a positive value. Please analyse the positive and negative values separately. Additionally, below the activations are the top-5 features by lora B_matrix weight magnitude that feed into each of the 32 attention heads. Please try to interpret the attention heads as well.

Token: "<|begin_of_text|>" → f4: 0.12, f5: -0.11, f10: 0.11, f7: 0.1, f15: -0.07
Token: "1" → f7: -0.5, f11: -0.41, f9: -0.41, f15: 0.4, f4: -0.4
Token: ".E" → f2: -0.5, f13: -0.28, f9: 0.23, f1: -0.16, f7: -0.15
Tok

# ChatGPT Experiments

In [65]:
from contextlib import nullcontext

# ---------- pick layer & convenience ----------
# Everything below inspects the LoRA adapter named "default" on q_proj of one decoder layer.
LAYER = 0
layer   = model.base_model.layers[LAYER]
q_proj  = layer.self_attn.q_proj
A_mod   = q_proj.lora_A["default"]        # nn.Linear wrapper
B_mod   = q_proj.lora_B["default"]
# Here `scaling` is the value stored for the "default" adapter (a later cell
# rebinds the same name to the whole `q_proj.scaling` object instead).
scaling = q_proj.scaling['default']

# NOTE(review): head count is hard-coded; it must satisfy n_heads * d_head == B.size(0)
# for the `view` below to succeed - confirm against the model config.
n_heads = 32
d_head  = layer.self_attn.head_dim        # d_model // n_heads
r       = A_mod.weight.size(0)            # rank

# ---------- 1) static head-feature map ----------
B = B_mod.weight     # shape (d_model, r)
# reshape rows into (H, d_head, r)
B_by_head = B.view(n_heads, d_head, r)          # no .T because PEFT stores B as (out, in)

# simple importance score: ℓ₂-norm of each feature’s column inside the head’s rows
# NOTE(review): B_importance is not read by any other cell visible in this notebook.
B_importance = B_by_head.norm(dim=1)            # (H, r)

# ---------- 2) dynamic hooks ----------
# Latest tensors captured by the two forward hooks below ("A" and "Q").
activations = {}

def hook_A(_, __, out):
    """Forward hook: cache `out` (detached, dim 0 squeezed) under activations["A"]."""
    captured = out.detach().squeeze(0)
    activations["A"] = captured

def hook_Q(_, __, out):
    """Forward hook: cache `out` (detached, dim 0 squeezed) split per head under activations["Q"].

    The squeezed tensor is viewed as (size(0), n_heads, d_head).
    """
    per_token = out.detach().squeeze(0)
    seq_len = per_token.size(0)
    activations["Q"] = per_token.view(seq_len, n_heads, d_head)


# ---------- 3) one forward pass ----------
model.eval()

# `scaling` is a plain number when set by this cell's setup, but another cell in
# this notebook rebinds it to the whole per-adapter mapping; accept both.
lora_scale = scaling["default"] if isinstance(scaling, dict) else scaling

hdl_A = A_mod.register_forward_hook(hook_A)
hdl_Q = q_proj.register_forward_hook(hook_Q)   # keeps activations["Q"] available for inspection
try:
    with torch.no_grad():
        for text in prompts[:10]:
            inputs  = tokenizer(text, return_tensors="pt").to(device)
            _       = model(**inputs)
            tok_str = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
            a_tok   = activations["A"]          # (seq, r), refreshed by hook_A on this forward pass

            for pos, tok in enumerate(tok_str):
                a = a_tok[pos]                               # (r,)
                # top-k features by absolute activation (signed value is printed)
                top_feat = torch.topk(a.abs(), min(5, a.numel()))
                feats    = [f"f{int(i)}:{a[i]:+.2f}" for i in top_feat.indices]

                # per-head magnitude *due only to LoRA*:
                #   delta_q_head = scaling * B_head @ a   ->  size (d_head)
                delta_q  = (B_by_head @ a) * lora_scale      # (H, d_head)
                head_mag = delta_q.norm(dim=1)               # (H,)

                head_str = " ".join([f"H{h}:{head_mag[h]:.2f}" for h in range(n_heads)])

                print(f"{tok}  |  {', '.join(feats):40} | {head_str}")
finally:
    # Detach the hooks only after *all* prompts have been processed. Removing them
    # inside the prompt loop meant every prompt after the first was printed against
    # the first prompt's stale activations (or failed with an IndexError when longer).
    hdl_A.remove()
    hdl_Q.remove()


    # ---------- 4) print token-wise diagnostics ----------


<|begin_of_text|>  |  f4:+0.12, f5:-0.11, f10:+0.11, f7:+0.10, f15:-0.07 | H0:0.01 H1:0.00 H2:0.01 H3:0.01 H4:0.00 H5:0.00 H6:0.01 H7:0.01 H8:0.01 H9:0.01 H10:0.00 H11:0.01 H12:0.00 H13:0.00 H14:0.01 H15:0.01 H16:0.00 H17:0.01 H18:0.01 H19:0.00 H20:0.00 H21:0.00 H22:0.01 H23:0.01 H24:0.01 H25:0.00 H26:0.01 H27:0.00 H28:0.01 H29:0.01 H30:0.01 H31:0.01
1  |  f7:-0.50, f11:-0.41, f9:-0.41, f15:+0.40, f4:-0.40 | H0:0.04 H1:0.05 H2:0.02 H3:0.04 H4:0.02 H5:0.02 H6:0.04 H7:0.04 H8:0.03 H9:0.05 H10:0.04 H11:0.03 H12:0.02 H13:0.03 H14:0.05 H15:0.03 H16:0.04 H17:0.05 H18:0.04 H19:0.04 H20:0.03 H21:0.02 H22:0.02 H23:0.03 H24:0.04 H25:0.03 H26:0.06 H27:0.03 H28:0.06 H29:0.04 H30:0.05 H31:0.07
.E  |  f2:-0.50, f13:-0.28, f9:+0.23, f1:-0.16, f7:-0.15 | H0:0.01 H1:0.01 H2:0.01 H3:0.01 H4:0.01 H5:0.01 H6:0.01 H7:0.01 H8:0.01 H9:0.01 H10:0.01 H11:0.01 H12:0.01 H13:0.01 H14:0.01 H15:0.02 H16:0.01 H17:0.00 H18:0.01 H19:0.02 H20:0.00 H21:0.01 H22:0.01 H23:0.01 H24:0.01 H25:0.01 H26:0.01 H27:0.01 H28:0.01 

In [None]:
# NOTE(review): `perplexity`, `B_loader` and `S_loader` are defined in a cell further
# down this notebook, so this cell only works if that cell has already been executed.
# It repeats the two patched-perplexity calls made at the end of that cell.
ppl_B_mod  = perplexity(B_loader, patched=True)
ppl_S_mod  = perplexity(S_loader, patched=True)


 99%|█████████▉| 247/250 [28:38<00:07,  2.50s/it]

In [None]:
# NOTE(review): the `ppl_*_orig` values read here are only assigned in a cell further
# down this notebook; on a fresh kernel this cell raises NameError.
# Reports original vs. ablated perplexity per corpus, then the difference between
# the two corpora's perplexity increases.
print(f"Boundary corpus — original PPL: {ppl_B_orig:.2f}  | ablated: {ppl_B_mod:.2f}")
print(f"Science   corpus — original PPL: {ppl_S_orig:.2f}  | ablated: {ppl_S_mod:.2f}")
print("ΔPPL (boundary – science):", (ppl_B_mod-ppl_B_orig) - (ppl_S_mod-ppl_S_orig))

In [None]:
import torch, math
from contextlib import contextmanager, nullcontext
from torch.utils.data import DataLoader
import re

################################################################################
# 0. Preamble – choose layer, heads, corpora                                    #
################################################################################
# Index into `model.base_model.layers` selecting the decoder layer whose q_proj is patched.
LAYER   = 0
# Attention-head indices whose query slice is modified by `perturb_q`.
HEADS   = {26, 31}
# Index of the LoRA feature (column of B / last-axis entry of A's output) being ablated.
FEATURE = 15                               # f15
# Passed as `batch_size` to the DataLoaders built by `make_loader`.
BS       = 8                               # evaluation batch size


def boundary_ratio(text):
    """Return how many boundary markers `text` contains per character.

    Markers are ",", ".", a newline, " and " and " or ". The divisor is
    `max(1, len(text))`, so an empty string yields 0.0 instead of dividing by zero.
    """
    markers = (",", ".", "\n", " and ", " or ")
    hits = 0
    for marker in markers:
        hits += text.count(marker)
    return hits / max(1, len(text))
# Texts whose boundary-marker density exceeds 0.02, capped at the first 2000 matches.
boundary_corpus = list(filter(lambda t: boundary_ratio(t) > 0.02, prompts))[:2000]

def is_scientific(text):
    """Heuristic filter: True when `text` has a long sentence and few commas.

    A text qualifies when its longest sentence (split on ".", "!" or "?") has more
    than 30 whitespace-separated words and its comma count is below 0.5% of its length.
    """
    sentences = re.split(r"[.!?]", text)
    longest = max(len(sentence.split()) for sentence in sentences)
    if longest <= 30:
        return False
    return text.count(",") < 0.005 * len(text)
# Texts accepted by `is_scientific`, capped at the first 2000 matches.
science_corpus = list(filter(is_scientific, prompts))[:2000]

################################################################################
# 1. Build dataloaders                                                         #
################################################################################
def make_loader(texts):
    """Wrap `texts` in an unshuffled DataLoader of batch size BS.

    Each batch is tokenized with padding and returned as a pair
    (input_ids, attention_mask), both moved to `device`.
    """
    def collate(batch):
        encoded = tokenizer(batch, padding=True, return_tensors='pt')
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)
        return input_ids, attention_mask

    return DataLoader(texts, shuffle=False, collate_fn=collate, batch_size=BS)

B_loader, S_loader = make_loader(boundary_corpus), make_loader(science_corpus)

################################################################################
# 2. Pre-compute static tensors                                                #
################################################################################
# Resolve the q_proj module of the chosen layer and its "default" LoRA A/B sub-modules.
layer   = model.base_model.layers[LAYER]
q_proj  = layer.self_attn.q_proj
A_mod   = q_proj.lora_A["default"]
B_mod   = q_proj.lora_B["default"]

# Unlike the earlier cell, `scaling` is bound to the whole `q_proj.scaling` object here;
# `perturb_q` checks whether it is a dict and indexes 'default' itself.
scaling = q_proj.scaling
# NOTE(review): head count is hard-coded; the `view` below requires
# n_heads * d_head == B_mod.weight.size(0) - confirm against the model config.
n_heads = 32
d_head  = layer.self_attn.head_dim

# column for feature 15, reshaped per head
col15 = B_mod.weight[:, FEATURE].view(n_heads, d_head).to(device)  # (H, d_head)

################################################################################
# 3. Hook that zeroes positive branch for chosen heads                         #
################################################################################
def perturb_q(module, inputs, output):
    """Forward hook for q_proj: strip feature FEATURE's positive contribution from chosen heads.

    `output` is viewed as (batch, seq, n_heads, d_head). For every head in HEADS,
    wherever the cached activation `cache["a15"]` is positive, the term
    `scaling * a15 * col15[head]` is subtracted from that head's slice.
    The edit is applied in place on `output`; nothing is returned.
    """
    # Activation of the ablated feature, written to the cache by the hook on A.
    feat_act = cache["a15"]                      # (batch, seq)
    feat_col = feat_act.unsqueeze(-1)            # (batch, seq, 1)
    positive = (feat_act > 0).unsqueeze(-1)      # (batch, seq, 1)

    # Per-head view onto the same storage as `output`.
    batch, seq_len, _ = output.shape
    q_heads = output.view(batch, seq_len, n_heads, d_head)

    for head in HEADS:
        # `scaling` may be the per-adapter dict or an already-extracted number.
        scale = scaling['default'] if type(scaling) == dict else scaling
        contribution = scale * feat_col * col15[head]        # (batch, seq, d_head)
        current = q_heads[:, :, head, :]
        # Subtract only where the feature fired positively; leave the rest untouched.
        q_heads[:, :, head, :] = torch.where(positive, current - contribution, current)

    # Write the per-head view back into the hook's output tensor.
    output.copy_(q_heads.view_as(output))

def grab_a(_, __, out):
    """Forward hook: store the detached FEATURE-th entry of `out`'s last axis in cache["a15"]."""
    feature_slice = out[..., FEATURE]
    cache["a15"] = feature_slice.detach()

################################################################################
# 4. Context manager to enable / disable the patch                             #
################################################################################
@contextmanager
def f15_positive_ablation():
    """Context manager that installs the ablation hooks for the duration of the block.

    On entry `grab_a` is hooked onto A_mod and `perturb_q` onto q_proj; on exit
    (normal or exceptional) both hooks are removed and `cache` is emptied.
    """
    handles = [
        A_mod.register_forward_hook(grab_a),
        q_proj.register_forward_hook(perturb_q),
    ]
    try:
        yield
    finally:
        for handle in handles:
            handle.remove()
        cache.clear()

################################################################################
# 5. Perplexity helper                                                         #
################################################################################
def perplexity(loader, patched=False):
    """Compute token-level perplexity of `model` over every batch in `loader`.

    Args:
        loader: iterable of (input_ids, attention_mask) batches.
        patched: when True, run with the `f15_positive_ablation` hooks installed.

    Returns:
        exp(mean negative log-likelihood per scored token), or NaN when the
        loader yields no scorable tokens (e.g. an empty corpus).
    """
    total_nll, total_tokens = 0.0, 0
    ctx = f15_positive_ablation() if patched else nullcontext()
    model.eval()
    with torch.no_grad(), ctx:
        for ids, mask in tqdm(loader):
            # Padding must not be scored. The pad token is the EOS token, so passing
            # `labels=ids` would count every pad position as a real target; -100 is
            # the label value the causal-LM loss ignores.
            labels = ids.masked_fill(mask == 0, -100)
            # The model shifts labels by one internally and averages the loss over
            # the non-ignored targets, so this is the count that converts the mean
            # loss back into a summed NLL (not `mask.sum()`).
            n_scored = (labels[:, 1:] != -100).sum().item()
            if n_scored == 0:
                continue  # nothing to score in this batch; its loss would be NaN
            out = model(input_ids=ids, attention_mask=mask, labels=labels)
            total_nll    += out.loss.item() * n_scored
            total_tokens += n_scored
    if total_tokens == 0:
        # Empty corpus: perplexity is undefined; avoid a ZeroDivisionError.
        return float('nan')
    return math.exp(total_nll / total_tokens)

################################################################################
# 6. Run evaluation                                                            #
################################################################################
# Shared scratch dict: `grab_a` writes "a15" into it, `perturb_q` reads it, and
# `f15_positive_ablation` clears it on exit. It must exist before any patched run.
cache = {}

# Baseline perplexities first, then the same corpora with the ablation hooks active.
ppl_B_orig = perplexity(B_loader, patched=False)
ppl_S_orig = perplexity(S_loader, patched=False)
ppl_B_mod  = perplexity(B_loader, patched=True)
ppl_S_mod  = perplexity(S_loader, patched=True)

# Report per-corpus perplexities and the difference between the two corpora's increases.
print(f"Boundary corpus — original PPL: {ppl_B_orig:.2f}  | ablated: {ppl_B_mod:.2f}")
print(f"Science   corpus — original PPL: {ppl_S_orig:.2f}  | ablated: {ppl_S_mod:.2f}")
print("ΔPPL (boundary – science):", (ppl_B_mod-ppl_B_orig) - (ppl_S_mod-ppl_S_orig))


In [None]:

# Key under which a forward hook is expected to store the lm_head LoRA-A output.
# NOTE(review): no cell visible in this notebook registers a hook that writes this
# key into `activation`; it must be installed (and still attached) before this
# cell runs, otherwise the values printed here are missing or stale.
LM_HEAD_KEY = 'lm_head.lora_A.default'

# Inference only: no autograd graph is needed for these forward passes.
with torch.no_grad():
    for prompt in prompts[:50]:
        encoded = tokenizer([prompt], padding=True, truncation=True, return_tensors='pt').to(device)
        _ = model(**encoded)
        if LM_HEAD_KEY not in activation:
            raise KeyError(
                f'{LM_HEAD_KEY!r} not found in `activation`; register a forward hook '
                'via get_activation(...) on the lm_head LoRA A module first'
            )
        # assumes the captured tensor is (batch=1, seq, rank), so this is (seq, rank) - TODO confirm
        latents = activation[LM_HEAD_KEY].squeeze(0)
        k = min(3, latents.size(1))  # never ask topk for more features than exist
        idxs = latents.abs().topk(k, dim=1).indices
        values = torch.gather(latents, dim=1, index=idxs)  # signed values at the top-|.| positions
        print('The input sequence is:', prompt)
        token_strs = [tokenizer.decode(t) for t in encoded['input_ids'].tolist()[0]]
        for idx, magnitude, token in zip(idxs.tolist(), values.tolist(), token_strs):
            feats = [f'f{i}: {str(round(m, 2))}' for i, m in zip(idx, magnitude)]
            print(f'Token: \"{token}\" → {", ".join(feats)}')
        print()


Below is a list of tokens from an input sequence, each followed by the top-5 (by absolute value) activations from a learned low-rank adapter. The activations reflect what features the adapter responds to in each token. Please find an interpretation of the features.

The input sequence is: 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. 
2. Exercise regularly to keep your body active and strong. 
3. Get enough sleep and maintain a consistent sleep schedule.
Token: "<|begin_of_text|>" → f0: -47.5, f8: 35.75, f17: -33.76
Token: "1" → f29: -14.54, f18: -10.36, f15: -9.38
Token: ".E" → f22: -10.27, f7: -8.24, f20: 7.61
Token: "at" → f22: -12.59, f29: -6.23, f15: -5.82
Token: " a" → f2: -7.01, f24: 6.47, f27: 4.84
Token: " balanced" → f22: -7.87, f17: 6.68, f21: -6.23
Token: " diet" → f22: -14.99, f21: -10.97, f13: 10.88
Token: " and" → f22: -8.12, f15: -6.2, f2: -6.17
Token: " make" → f15: -6.66, f22: -6.48, f26: 6.02
Token: " sure" → f22: -12.75, f13: 7.31, 