In [16]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append('/Users/paulnguyen/lre-experiment')
from lre import LREModel

## Load Model and Tokenizer

In [17]:
model_name = "Qwen/Qwen3-0.6B"
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
print(f"Model loaded: {model_name}")

Using device: mps
Model loaded: Qwen/Qwen3-0.6B


## Test Case 1: Simple Template (Subject at Start)

In [18]:
# Simple template where subject is at the beginning
template_simple = "{} is commonly associated with"
subject = "wine"
prompt = template_simple.format(subject)

print(f"Prompt: '{prompt}'")
print(f"\nSubject: '{subject}'")
print("="*80)

Prompt: 'wine is commonly associated with'

Subject: 'wine'


In [19]:
# Tokenize the full prompt
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(f"\nFull prompt tokens: {prompt_tokens}")
print(f"Number of tokens in full prompt: {len(prompt_tokens)}")

# Show what each token decodes to
print("\nToken-by-token breakdown:")
for i, token_id in enumerate(prompt_tokens):
    token_text = tokenizer.decode([token_id])
    print(f"  Position {i}: Token ID {token_id:5d} → '{token_text}'")


Full prompt tokens: [71437, 374, 16626, 5815, 448]
Number of tokens in full prompt: 5

Token-by-token breakdown:
  Position 0: Token ID 71437 → 'wine'
  Position 1: Token ID   374 → ' is'
  Position 2: Token ID 16626 → ' commonly'
  Position 3: Token ID  5815 → ' associated'
  Position 4: Token ID   448 → ' with'


In [20]:
# Tokenize the subject separately
subject_tokens = tokenizer.encode(subject, add_special_tokens=False)
print(f"\nSubject '{subject}' tokens: {subject_tokens}")
print(f"Number of tokens in subject: {len(subject_tokens)}")

# Show subject token breakdown
print("\nSubject token breakdown:")
for i, token_id in enumerate(subject_tokens):
    token_text = tokenizer.decode([token_id])
    print(f"  Token {i}: ID {token_id:5d} → '{token_text}'")


Subject 'wine' tokens: [71437]
Number of tokens in subject: 1

Subject token breakdown:
  Token 0: ID 71437 → 'wine'


In [21]:
# CORRECT APPROACH: Use offset mapping to find subject's last token
# This is what find_subject_last_token() does in lre.py

print("="*80)
print("OFFSET MAPPING APPROACH (Used in lre.py)")
print("="*80)

# Step 1: Tokenize with offset mapping to get character positions
inputs_with_offsets = tokenizer(prompt, return_tensors="pt", return_offsets_mapping=True)
offset_mapping = inputs_with_offsets["offset_mapping"][0]  # Shape: (seq_len, 2)

print(f"\nPrompt: '{prompt}'")
print(f"Subject: '{subject}'")
print(f"\nOffset mapping (character positions for each token):")
for i, (start, end) in enumerate(offset_mapping):
    token_text = tokenizer.decode([prompt_tokens[i]])
    print(f"  Token {i}: chars [{start:2d}, {end:2d}) → '{token_text}'")

# Step 2: Find subject's character position in the prompt (last occurrence)
subject_start_char = prompt.rfind(subject)
subject_end_char = subject_start_char + len(subject)

print(f"\nSubject '{subject}' character positions:")
print(f"  Start: {subject_start_char}")
print(f"  End: {subject_end_char}")
print(f"  Characters [{subject_start_char}, {subject_end_char}): '{prompt[subject_start_char:subject_end_char]}'")

# Step 3: Find which tokens overlap with the subject
subject_last_token = None
print(f"\nFinding tokens that overlap with subject:")
for token_idx, (start, end) in enumerate(offset_mapping):
    # Check if this token overlaps with the subject's character range
    overlaps = start < subject_end_char and end > subject_start_char
    if overlaps:
        subject_last_token = token_idx  # Keep updating - we want the LAST one
        marker = "→ LAST TOKEN OF SUBJECT" if token_idx == len(offset_mapping) - 1 or not any(
            s < subject_end_char and e > subject_start_char 
            for s, e in offset_mapping[token_idx+1:]
        ) else ""
        token_text = tokenizer.decode([prompt_tokens[token_idx]])
        print(f"  Token {token_idx}: chars [{start:2d}, {end:2d}) overlaps! → '{token_text}' {marker}")

if subject_last_token is None:
    raise ValueError(f"Could not find subject '{subject}' in tokenized prompt")

print(f"\n✓ Subject's last token is at position: {subject_last_token}")
print(f"  Token ID: {prompt_tokens[subject_last_token]}")
print(f"  Token text: '{tokenizer.decode([prompt_tokens[subject_last_token]])}'")

# Store for later use
subject_last_token_pos = subject_last_token

print(f"\n{'='*80}")
print("WHY THIS WORKS:")
print("="*80)
print("✓ Uses character positions from the ACTUAL tokenized prompt")
print("✓ No tokenization mismatch (subject is already in context)")
print("✓ Finds LAST occurrence (important for few-shot templates)")
print("✓ Handles multi-token subjects correctly")

OFFSET MAPPING APPROACH (Used in lre.py)

Prompt: 'wine is commonly associated with'
Subject: 'wine'

Offset mapping (character positions for each token):
  Token 0: chars [ 0,  4) → 'wine'
  Token 1: chars [ 4,  7) → ' is'
  Token 2: chars [ 7, 16) → ' commonly'
  Token 3: chars [16, 27) → ' associated'
  Token 4: chars [27, 32) → ' with'

Subject 'wine' character positions:
  Start: 0
  End: 4
  Characters [0, 4): 'wine'

Finding tokens that overlap with subject:
  Token 0: chars [ 0,  4) overlaps! → 'wine' → LAST TOKEN OF SUBJECT

✓ Subject's last token is at position: 0
  Token ID: 71437
  Token text: 'wine'

WHY THIS WORKS:
✓ Uses character positions from the ACTUAL tokenized prompt
✓ No tokenization mismatch (subject is already in context)
✓ Finds LAST occurrence (important for few-shot templates)
✓ Handles multi-token subjects correctly


## Test Case 2: Few-Shot Template (Subject in Middle)

In [22]:
# Few-shot template with examples before the subject
template_fewshot = "oil is commonly associated with fuel.\njuice is commonly associated with orange.\n{} is commonly associated with"
subject2 = "wine"
prompt2 = template_fewshot.format(subject2)

print(f"Prompt:\n{prompt2}")
print(f"\nSubject: '{subject2}'")
print("="*80)

Prompt:
oil is commonly associated with fuel.
juice is commonly associated with orange.
wine is commonly associated with

Subject: 'wine'


In [23]:
# Tokenize the full prompt
prompt2_tokens = tokenizer.encode(prompt2, add_special_tokens=False)
print(f"\nFull prompt tokens: {prompt2_tokens}")
print(f"Number of tokens in full prompt: {len(prompt2_tokens)}")

# Show what each token decodes to
print("\nToken-by-token breakdown:")
for i, token_id in enumerate(prompt2_tokens):
    token_text = tokenizer.decode([token_id])
    print(f"  Position {i:2d}: Token ID {token_id:5d} → '{token_text}'")


Full prompt tokens: [73813, 374, 16626, 5815, 448, 10416, 624, 8613, 558, 374, 16626, 5815, 448, 18575, 624, 71437, 374, 16626, 5815, 448]
Number of tokens in full prompt: 20

Token-by-token breakdown:
  Position  0: Token ID 73813 → 'oil'
  Position  1: Token ID   374 → ' is'
  Position  2: Token ID 16626 → ' commonly'
  Position  3: Token ID  5815 → ' associated'
  Position  4: Token ID   448 → ' with'
  Position  5: Token ID 10416 → ' fuel'
  Position  6: Token ID   624 → '.
'
  Position  7: Token ID  8613 → 'ju'
  Position  8: Token ID   558 → 'ice'
  Position  9: Token ID   374 → ' is'
  Position 10: Token ID 16626 → ' commonly'
  Position 11: Token ID  5815 → ' associated'
  Position 12: Token ID   448 → ' with'
  Position 13: Token ID 18575 → ' orange'
  Position 14: Token ID   624 → '.
'
  Position 15: Token ID 71437 → 'wine'
  Position 16: Token ID   374 → ' is'
  Position 17: Token ID 16626 → ' commonly'
  Position 18: Token ID  5815 → ' associated'
  Position 19: Token ID  

In [25]:
# Tokenize the subject separately
subject2_tokens = tokenizer.encode(subject2, add_special_tokens=False)
print(f"\nSubject '{subject2}' tokens: {subject2_tokens}")
print(f"Number of tokens in subject: {len(subject2_tokens)}")


Subject 'wine' tokens: [71437]
Number of tokens in subject: 1


In [26]:
# Use offset mapping for few-shot template (shows why this is critical!)

print("="*80)
print("OFFSET MAPPING WITH FEW-SHOT TEMPLATE")
print("="*80)

# Step 1: Tokenize with offset mapping
inputs_with_offsets2 = tokenizer(prompt2, return_tensors="pt", return_offsets_mapping=True)
offset_mapping2 = inputs_with_offsets2["offset_mapping"][0]

print(f"\nPrompt (first 100 chars): {prompt2[:100]}...")
print(f"Subject: '{subject2}'")

# Step 2: Find subject's character position (LAST occurrence!)
subject2_start_char = prompt2.rfind(subject2)
subject2_end_char = subject2_start_char + len(subject2)

print(f"\nSubject '{subject2}' character positions:")
print(f"  Start: {subject2_start_char}")
print(f"  End: {subject2_end_char}")
print(f"  Text: '{prompt2[subject2_start_char:subject2_end_char]}'")

# Show context around subject in the string
context_start = max(0, subject2_start_char - 30)
context_end = min(len(prompt2), subject2_end_char + 30)
context = prompt2[context_start:context_end]
subject_pos_in_context = subject2_start_char - context_start
print(f"\nContext: '{context}'")
print(f"          {' ' * subject_pos_in_context}{'↑' * len(subject2)} (subject here)")

# Step 3: Find overlapping tokens
subject2_last_token = None
print(f"\nFinding tokens that overlap with subject:")
for token_idx, (start, end) in enumerate(offset_mapping2):
    overlaps = start < subject2_end_char and end > subject2_start_char
    if overlaps:
        subject2_last_token = token_idx
        token_text = tokenizer.decode([prompt2_tokens[token_idx]])
        print(f"  Token {token_idx}: chars [{start:3d}, {end:3d}) → '{token_text}'")

if subject2_last_token is None:
    raise ValueError(f"Could not find subject '{subject2}' in tokenized prompt")

print(f"\n→ Subject's last token is at position: {subject2_last_token}")
print(f"→ Token ID: {prompt2_tokens[subject2_last_token]}")
print(f"→ Token text: '{tokenizer.decode([prompt2_tokens[subject2_last_token]])}'")

# Show context around this token
print(f"\nContext around extraction point:")
start = max(0, subject2_last_token - 2)
end = min(len(prompt2_tokens), subject2_last_token + 3)
for j in range(start, end):
    marker = "→" if j == subject2_last_token else "  "
    highlight = "*** EXTRACT HERE ***" if j == subject2_last_token else ""
    token_text = tokenizer.decode([prompt2_tokens[j]])
    char_range = offset_mapping2[j]
    print(f"  {marker} Token {j:2d}: chars [{char_range[0]:3d}, {char_range[1]:3d}) → '{token_text}' {highlight}")

subject2_last_token_pos = subject2_last_token

print(f"\n{'='*80}")
print("CRITICAL: rfind() finds the LAST occurrence!")
print("="*80)
print("In few-shot templates, 'wine' appears multiple times:")
print("  - In examples: 'wine is...red.'")
print("  - In test prompt: 'wine is commonly associated with'")
print("We want the LAST one (in the test prompt), not the examples!")
print("rfind() ensures we get the right occurrence.")

OFFSET MAPPING WITH FEW-SHOT TEMPLATE

Prompt (first 100 chars): oil is commonly associated with fuel.
juice is commonly associated with orange.
wine is commonly ass...
Subject: 'wine'

Subject 'wine' character positions:
  Start: 80
  End: 84
  Text: 'wine'

Context: 'monly associated with orange.
wine is commonly associated with'
                                        ↑↑↑↑ (subject here)

Finding tokens that overlap with subject:
  Token 15: chars [ 80,  84) → 'wine'

→ Subject's last token is at position: 15
→ Token ID: 71437
→ Token text: 'wine'

Context around extraction point:
     Token 13: chars [ 71,  78) → ' orange' 
     Token 14: chars [ 78,  80) → '.
' 
  → Token 15: chars [ 80,  84) → 'wine' *** EXTRACT HERE ***
     Token 16: chars [ 84,  87) → ' is' 
     Token 17: chars [ 87,  96) → ' commonly' 

CRITICAL: rfind() finds the LAST occurrence!
In few-shot templates, 'wine' appears multiple times:
  - In examples: 'wine is...red.'
  - In test prompt: 'wine is commonly ass

## Test Case 3: Using the LRE Class

In [27]:
# Test that lre.find_subject_last_token() gives the same result

lre = LREModel(model, tokenizer, device)

# Test with simple template
layer_name = "model.layers.10"
print(f"Testing lre.find_subject_last_token() method")
print(f"="*80)

# Get the position using the LRE method
lre_position = lre.find_subject_last_token(prompt, subject)

print(f"\nPrompt: '{prompt}'")
print(f"Subject: '{subject}'")
print(f"\nManual calculation: position {subject_last_token_pos}")
print(f"LRE method result:  position {lre_position}")

if lre_position == subject_last_token_pos:
    print(f"\n✓ MATCH! Both methods found the same position.")
else:
    print(f"\n✗ MISMATCH! Different positions found.")
    print(f"  Difference: {abs(lre_position - subject_last_token_pos)} tokens")

# Test with few-shot template too
print(f"\n{'='*80}")
print("Testing with few-shot template:")
print(f"{'='*80}")

lre_position2 = lre.find_subject_last_token(prompt2, subject2)

print(f"\nPrompt (first 100 chars): {prompt2[:100]}...")
print(f"Subject: '{subject2}'")
print(f"\nManual calculation: position {subject2_last_token_pos}")
print(f"LRE method result:  position {lre_position2}")

if lre_position2 == subject2_last_token_pos:
    print(f"\n✓ MATCH! Both methods found the same position.")
else:
    print(f"\n✗ MISMATCH! Different positions found.")
    print(f"  Difference: {abs(lre_position2 - subject2_last_token_pos)} tokens")

print(f"\n{'='*80}")
print("This is exactly how train_lre() and evaluate() work internally!")
print(f"{'='*80}")

Loading Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attenti

HFValidationError: Repo id must use alphanumeric chars, '-', '_' or '.'. The name cannot start or end with '-' or '.' and the maximum length is 96: 'Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): Qwen3RMSNorm((1024,), eps=1e-06)
    (rotary_emb): Qwen3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)'.

In [28]:
# Extract hidden state using the correct position

print("="*80)
print("EXTRACTING HIDDEN STATE")
print("="*80)

hidden_state = lre.get_hidden_state(prompt, layer_name, token_position=subject_last_token_pos)

print(f"\nExtracted from layer: {layer_name}")
print(f"Token position: {subject_last_token_pos}")
print(f"Token: '{tokenizer.decode([prompt_tokens[subject_last_token_pos]])}'")
print(f"\n✓ Hidden state shape: {hidden_state.shape}")
print(f"  Hidden state norm: {np.linalg.norm(hidden_state):.4f}")
print(f"  First 5 values: {hidden_state[:5]}")

EXTRACTING HIDDEN STATE


NameError: name 'lre' is not defined

## Test Case 4: Compare with Template End Extraction

In [29]:
# Extract from template end (last token of prompt)
last_token_pos = len(prompt_tokens) - 1
hidden_state_template_end = lre.get_hidden_state(prompt, layer_name, token_position="last")

print(f"Extraction from TEMPLATE END (position {last_token_pos}):")
print(f"  Token: '{tokenizer.decode([prompt_tokens[last_token_pos]])}'")
print(f"  Hidden state norm: {torch.norm(hidden_state_template_end).item():.4f}")
print(f"\nExtraction from SUBJECT END (position {subject_last_token_pos}):")
print(f"  Token: '{tokenizer.decode([prompt_tokens[subject_last_token_pos]])}'")
print(f"  Hidden state norm: {torch.norm(hidden_state).item():.4f}")

# Check if they're different
difference = torch.norm(hidden_state - hidden_state_template_end).item()
print(f"\nDifference between the two: {difference:.4f}")
if difference > 0.001:
    print("✓ The two extraction points give DIFFERENT hidden states (as expected!)")
else:
    print("⚠ The two extraction points give the SAME hidden state")

NameError: name 'lre' is not defined

## Test Case 5: Multi-Token Subject

In [30]:
# Test with a multi-token subject
subject_multi = "New York"
prompt_multi = template_simple.format(subject_multi)

print(f"Prompt: '{prompt_multi}'")
print(f"Subject: '{subject_multi}'")
print("="*80)

# Tokenize
prompt_multi_tokens = tokenizer.encode(prompt_multi, add_special_tokens=False)
subject_multi_tokens = tokenizer.encode(subject_multi, add_special_tokens=False)

print(f"\nFull prompt tokens: {prompt_multi_tokens}")
print(f"Subject tokens: {subject_multi_tokens}")
print(f"Subject has {len(subject_multi_tokens)} tokens")

# Find subject
subject_multi_len = len(subject_multi_tokens)
for i in range(len(prompt_multi_tokens) - subject_multi_len + 1):
    if prompt_multi_tokens[i:i+subject_multi_len] == subject_multi_tokens:
        subject_multi_start = i
        subject_multi_last = i + subject_multi_len - 1
        print(f"\n✓ Subject spans positions {subject_multi_start} to {subject_multi_last}")
        print(f"→ Extracting from position {subject_multi_last} (last token of subject)")
        print(f"→ Token text: '{tokenizer.decode([prompt_multi_tokens[subject_multi_last]])}'")
        break

Prompt: 'New York is commonly associated with'
Subject: 'New York'

Full prompt tokens: [3564, 4261, 374, 16626, 5815, 448]
Subject tokens: [3564, 4261]
Subject has 2 tokens

✓ Subject spans positions 0 to 1
→ Extracting from position 1 (last token of subject)
→ Token text: ' York'


# Summary

print("="*80)
print("COMPLETE WORKFLOW SUMMARY")
print("="*80)

print("\n1. OFFSET MAPPING APPROACH (find_subject_last_token):")
print("   ✓ Tokenize prompt with return_offsets_mapping=True")
print("   ✓ Get character positions: [(0,4), (4,7), ...]")
print("   ✓ Find subject in string: prompt.rfind(subject)")
print("   ✓ Map character positions to token indices")
print("   ✓ Return last token that overlaps with subject")

print("\n2. WHY IT WORKS:")
print("   ✓ No tokenization mismatch (uses actual tokenized prompt)")
print("   ✓ Context-aware (subject already in its final tokenized form)")
print("   ✓ Finds LAST occurrence (critical for few-shot templates)")
print("   ✓ Handles multi-token subjects correctly")

print("\n3. HOW train_lre() USES IT:")
print("   For each training sample:")
print("     a. Format prompt: template.format(subject)")
print("     b. Find position: find_subject_last_token(prompt, subject)")
print("     c. Extract hidden state: get_hidden_state(prompt, layer, position)")
print("     d. Get target: embedding(object)")
print("     e. Fit: W * h_subject + b ≈ embedding(object)")

print("\n4. HOW evaluate() USES IT:")
print("   For each test sample:")
print("     a. Format prompt: template.format(subject)")
print("     b. Find position: find_subject_last_token(prompt, subject)")
print("     c. Extract hidden state: get_hidden_state(prompt, layer, position)")
print("     d. Predict: z_pred = W * h_subject + b")
print("     e. Decode: argmax(z_pred @ embedding_matrix.T)")
print("     f. Compare with expected object")

print("\n5. KEY ADVANTAGE OVER TEMPLATE SPLITTING:")
print("   OLD: Split template at {}, count tokens before/after")
print("     → Can fail if tokenizer treats parts differently")
print("   NEW: Use character offsets from actual tokenization")
print("     → Always correct because it uses the real tokens")

print("\n" + "="*80)
print("This notebook demonstrates the EXACT logic used in lre/lre.py!")
print("="*80)