In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

## Setup: Load Model and Tokenizer

We'll use GPT-2 for this demo since it's universally accessible and doesn't require authentication.

In [None]:
model_name = "gpt2"  # Using GPT-2 for universal compatibility
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

print("✓ Model loaded successfully")

Using device: mps
Loading gpt2...
✓ Model loaded successfully


## Create a Few-Shot Prompt Template

We'll create a prompt with:
- 3 few-shot examples
- The test subject "wine" at the end

Note: "wine" appears multiple times in the prompt!

In [3]:
# Few-shot examples
examples = [
    ("water", "liquid"),
    ("gold", "metal"),
    ("wine", "red"),  # Note: "wine" appears here in examples!
]

# Template
template = "{} is commonly associated with"

# Test subject
test_subject = "wine"

# Build the full few-shot prompt
few_shot_part = "\n".join([
    template.format(subj) + f" {obj}."
    for subj, obj in examples
])

test_prompt_part = template.format(test_subject)

full_prompt = few_shot_part + "\n" + test_prompt_part

print("Full Prompt:")
print("="*80)
print(full_prompt)
print("="*80)
print(f"\nTest subject: '{test_subject}'")
print(f"Note: '{test_subject}' appears TWICE in this prompt!")
print(f"  - Once in the examples (line 3)")
print(f"  - Once in the test prompt (line 4)")
print(f"\nWe want to extract from the LAST occurrence (line 4)!")

Full Prompt:
water is commonly associated with liquid.
gold is commonly associated with metal.
wine is commonly associated with red.
wine is commonly associated with

Test subject: 'wine'
Note: 'wine' appears TWICE in this prompt!
  - Once in the examples (line 3)
  - Once in the test prompt (line 4)

We want to extract from the LAST occurrence (line 4)!


## Step 1: Tokenize the Prompt

Standard tokenization shows us the token IDs.

In [4]:
# Standard tokenization
prompt_tokens = tokenizer.encode(full_prompt, add_special_tokens=False)

print(f"Token IDs: {prompt_tokens}")
print(f"Total tokens: {len(prompt_tokens)}")
print(f"\nToken-by-token breakdown:")
print("-"*80)

for i, token_id in enumerate(prompt_tokens):
    token_text = tokenizer.decode([token_id])
    print(f"Token {i:2d}: ID {token_id:6d} → '{token_text}'")

Token IDs: [12987, 374, 16626, 5815, 448, 14473, 624, 34537, 374, 16626, 5815, 448, 9317, 624, 71437, 374, 16626, 5815, 448, 2518, 624, 71437, 374, 16626, 5815, 448]
Total tokens: 26

Token-by-token breakdown:
--------------------------------------------------------------------------------
Token  0: ID  12987 → 'water'
Token  1: ID    374 → ' is'
Token  2: ID  16626 → ' commonly'
Token  3: ID   5815 → ' associated'
Token  4: ID    448 → ' with'
Token  5: ID  14473 → ' liquid'
Token  6: ID    624 → '.
'
Token  7: ID  34537 → 'gold'
Token  8: ID    374 → ' is'
Token  9: ID  16626 → ' commonly'
Token 10: ID   5815 → ' associated'
Token 11: ID    448 → ' with'
Token 12: ID   9317 → ' metal'
Token 13: ID    624 → '.
'
Token 14: ID  71437 → 'wine'
Token 15: ID    374 → ' is'
Token 16: ID  16626 → ' commonly'
Token 17: ID   5815 → ' associated'
Token 18: ID    448 → ' with'
Token 19: ID   2518 → ' red'
Token 20: ID    624 → '.
'
Token 21: ID  71437 → 'wine'
Token 22: ID    374 → ' is'
Token 2

## Step 2: Tokenize with Offset Mapping

**This is the key!** Offset mapping tells us the character positions for each token.

In [5]:
# Tokenize with offset mapping
inputs_with_offsets = tokenizer(
    full_prompt, 
    return_tensors="pt", 
    return_offsets_mapping=True
)

offset_mapping = inputs_with_offsets["offset_mapping"][0]  # Shape: (seq_len, 2)

print("Offset Mapping (character positions for each token):")
print("="*80)
print(f"Format: Token [start_char, end_char) → 'text'")
print("-"*80)

for i, (start, end) in enumerate(offset_mapping):
    token_text = tokenizer.decode([prompt_tokens[i]])
    chars = full_prompt[start:end] if start < len(full_prompt) else ""
    print(f"Token {i:2d}: [{start:3d}, {end:3d}) → '{token_text}'  (chars: '{chars}')")

Offset Mapping (character positions for each token):
Format: Token [start_char, end_char) → 'text'
--------------------------------------------------------------------------------
Token  0: [  0,   5) → 'water'  (chars: 'water')
Token  1: [  5,   8) → ' is'  (chars: ' is')
Token  2: [  8,  17) → ' commonly'  (chars: ' commonly')
Token  3: [ 17,  28) → ' associated'  (chars: ' associated')
Token  4: [ 28,  33) → ' with'  (chars: ' with')
Token  5: [ 33,  40) → ' liquid'  (chars: ' liquid')
Token  6: [ 40,  42) → '.
'  (chars: '.
')
Token  7: [ 42,  46) → 'gold'  (chars: 'gold')
Token  8: [ 46,  49) → ' is'  (chars: ' is')
Token  9: [ 49,  58) → ' commonly'  (chars: ' commonly')
Token 10: [ 58,  69) → ' associated'  (chars: ' associated')
Token 11: [ 69,  74) → ' with'  (chars: ' with')
Token 12: [ 74,  80) → ' metal'  (chars: ' metal')
Token 13: [ 80,  82) → '.
'  (chars: '.
')
Token 14: [ 82,  86) → 'wine'  (chars: 'wine')
Token 15: [ 86,  89) → ' is'  (chars: ' is')
Token 16: [ 89,  9

## Step 3: Find Subject in the String

Use `rfind()` to find the **LAST** occurrence of the subject.

In [6]:
# Find the subject in the prompt string
print(f"Looking for subject: '{test_subject}'")
print("="*80)

# Find ALL occurrences
print("\nAll occurrences of 'wine' in the prompt:")
start_pos = 0
occurrence_num = 1
while True:
    pos = full_prompt.find(test_subject, start_pos)
    if pos == -1:
        break
    end_pos = pos + len(test_subject)
    # Show context
    context_start = max(0, pos - 20)
    context_end = min(len(full_prompt), end_pos + 20)
    context = full_prompt[context_start:context_end]
    print(f"  Occurrence {occurrence_num}: chars [{pos}, {end_pos}) → ...{context}...")
    start_pos = pos + 1
    occurrence_num += 1

# Find the LAST occurrence
subject_start_char = full_prompt.rfind(test_subject)
subject_end_char = subject_start_char + len(test_subject)

print(f"\n→ Using rfind() to get LAST occurrence:")
print(f"  Character range: [{subject_start_char}, {subject_end_char})")
print(f"  Text: '{full_prompt[subject_start_char:subject_end_char]}'")

# Show context
context_start = max(0, subject_start_char - 30)
context_end = min(len(full_prompt), subject_end_char + 30)
context = full_prompt[context_start:context_end]
marker_pos = subject_start_char - context_start

print(f"\nContext around the subject:")
print(f"  '{context}'")
print(f"  {' ' * marker_pos}{'↑' * len(test_subject)}")
print(f"  {' ' * marker_pos}Subject here")

Looking for subject: 'wine'

All occurrences of 'wine' in the prompt:
  Occurrence 1: chars [82, 86) → ...ociated with metal.
wine is commonly associa...
  Occurrence 2: chars [120, 124) → ...ssociated with red.
wine is commonly associa...

→ Using rfind() to get LAST occurrence:
  Character range: [120, 124)
  Text: 'wine'

Context around the subject:
  'commonly associated with red.
wine is commonly associated with'
                                ↑↑↑↑
                                Subject here


## Step 4: Map Character Positions to Token Indices

Find which tokens overlap with the subject's character range.

In [7]:
print("Finding tokens that overlap with subject:")
print("="*80)
print(f"Subject character range: [{subject_start_char}, {subject_end_char})")
print("\nChecking each token:")
print("-"*80)

subject_last_token = None
overlapping_tokens = []

for token_idx, (start, end) in enumerate(offset_mapping):
    # Check if this token overlaps with the subject's character range
    overlaps = start < subject_end_char and end > subject_start_char
    
    token_text = tokenizer.decode([prompt_tokens[token_idx]])
    
    if overlaps:
        subject_last_token = token_idx  # Keep updating - we want the LAST one
        overlapping_tokens.append(token_idx)
        print(f"✓ Token {token_idx:2d}: [{start:3d}, {end:3d}) overlaps! → '{token_text}'")
    else:
        # Only show tokens near the subject
        if abs(start - subject_start_char) < 50:
            print(f"  Token {token_idx:2d}: [{start:3d}, {end:3d}) no overlap → '{token_text}'")

print("-"*80)
print(f"\nOverlapping tokens: {overlapping_tokens}")
print(f"→ Last overlapping token: {subject_last_token}")
print(f"→ Token text: '{tokenizer.decode([prompt_tokens[subject_last_token]])}'")

Finding tokens that overlap with subject:
Subject character range: [120, 124)

Checking each token:
--------------------------------------------------------------------------------
  Token 12: [ 74,  80) no overlap → ' metal'
  Token 13: [ 80,  82) no overlap → '.
'
  Token 14: [ 82,  86) no overlap → 'wine'
  Token 15: [ 86,  89) no overlap → ' is'
  Token 16: [ 89,  98) no overlap → ' commonly'
  Token 17: [ 98, 109) no overlap → ' associated'
  Token 18: [109, 114) no overlap → ' with'
  Token 19: [114, 118) no overlap → ' red'
  Token 20: [118, 120) no overlap → '.
'
✓ Token 21: [120, 124) overlaps! → 'wine'
  Token 22: [124, 127) no overlap → ' is'
  Token 23: [127, 136) no overlap → ' commonly'
  Token 24: [136, 147) no overlap → ' associated'
  Token 25: [147, 152) no overlap → ' with'
--------------------------------------------------------------------------------

Overlapping tokens: [21]
→ Last overlapping token: 21
→ Token text: 'wine'


## Step 5: Visualize the Extraction Point

Show tokens around the extraction point with context.

In [8]:
print("Context around the extraction point:")
print("="*80)

# Show tokens around the extraction point
start_idx = max(0, subject_last_token - 5)
end_idx = min(len(prompt_tokens), subject_last_token + 6)

print(f"\nTokens {start_idx} to {end_idx-1}:")
print("-"*80)

for i in range(start_idx, end_idx):
    token_text = tokenizer.decode([prompt_tokens[i]])
    char_range = offset_mapping[i]
    
    if i == subject_last_token:
        marker = ">>>>"
        highlight = " *** EXTRACT HERE ***"
    elif i in overlapping_tokens:
        marker = "  →"
        highlight = " (part of subject)"
    else:
        marker = "    "
        highlight = ""
    
    print(f"{marker} Token {i:2d}: [{char_range[0]:3d}, {char_range[1]:3d}) → '{token_text}'{highlight}")

print("-"*80)
print(f"\n✓ We extract the hidden state from Token {subject_last_token}")
print(f"  This is the LAST token of the subject 'wine' in the test prompt!")

Context around the extraction point:

Tokens 16 to 25:
--------------------------------------------------------------------------------
     Token 16: [ 89,  98) → ' commonly'
     Token 17: [ 98, 109) → ' associated'
     Token 18: [109, 114) → ' with'
     Token 19: [114, 118) → ' red'
     Token 20: [118, 120) → '.
'
>>>> Token 21: [120, 124) → 'wine' *** EXTRACT HERE ***
     Token 22: [124, 127) → ' is'
     Token 23: [127, 136) → ' commonly'
     Token 24: [136, 147) → ' associated'
     Token 25: [147, 152) → ' with'
--------------------------------------------------------------------------------

✓ We extract the hidden state from Token 21
  This is the LAST token of the subject 'wine' in the test prompt!


## Step 6: Verify with LREModel.find_subject_last_token()

Confirm that the `find_subject_last_token()` method gives the same result.

In [9]:
from lre import LREModel

# Initialize LRE
lre = LREModel(model, tokenizer, device)

# Use the method
lre_position = lre.find_subject_last_token(full_prompt, test_subject)

print("Verification:")
print("="*80)
print(f"Manual calculation:        Token {subject_last_token}")
print(f"LREModel method result:    Token {lre_position}")
print()

if lre_position == subject_last_token:
    print("✓ ✓ ✓ PERFECT MATCH! ✓ ✓ ✓")
    print("\nThe find_subject_last_token() method works correctly!")
else:
    print(f"✗ MISMATCH! Difference: {abs(lre_position - subject_last_token)} tokens")

Loading Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attenti

HFValidationError: Repo id must use alphanumeric chars, '-', '_' or '.'. The name cannot start or end with '-' or '.' and the maximum length is 96: 'Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): Qwen3RMSNorm((1024,), eps=1e-06)
    (rotary_emb): Qwen3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)'.

## Summary: How It All Works Together

### The Problem:
- Subject "wine" appears MULTIPLE times in few-shot prompts
- We need to extract from the LAST occurrence (the test prompt, not examples)
- Separately tokenizing the subject can give different tokens due to context

### The Solution (Offset Mapping):
1. **Tokenize with offsets**: Get character positions for each token
2. **Find subject in string**: Use `rfind()` to get LAST occurrence
3. **Map chars to tokens**: Find which tokens overlap with subject
4. **Return last token**: The last overlapping token is what we want!

### Why It's Robust:
- ✓ Uses the ACTUAL tokenized prompt (no tokenization mismatch)
- ✓ Character positions are ground truth (unambiguous)
- ✓ Finds LAST occurrence automatically (critical for few-shot)
- ✓ Handles multi-token subjects correctly (maps all overlapping tokens)
- ✓ Context-aware (subject is already in its final tokenized form)

### Used by:
- `train_lre()`: Extracts subject hidden states for training
- `evaluate()`: Extracts subject hidden states for prediction
- Both use `extract_from="subject_end"` by default

In [11]:
from data_utils import create_loo_prompt

# LOO (Leave-One-Out) Template
# In LOO, we REMOVE the test example from the few-shot examples

print("Original Few-Shot Examples:")
print("="*80)
for subj, obj in examples:
    print(f"  {subj} → {obj}")

print(f"\nTest subject: {test_subject}")
print(f"Test subject matches example: {test_subject in [subj for subj, obj in examples]}")

# Create LOO prompt using data_utils
loo_full_prompt = create_loo_prompt(examples, test_subject, template)

# Extract LOO examples for display
loo_examples = [(subj, obj) for subj, obj in examples if subj != test_subject]

print(f"\nLOO Examples (removed '{test_subject}'):")
print("="*80)
for subj, obj in loo_examples:
    print(f"  {subj} → {obj}")

print("\nLOO Full Prompt:")
print("="*80)
print(loo_full_prompt)
print("="*80)

print(f"\nKey Difference:")
print(f"  Original prompt: '{test_subject}' appears {full_prompt.count(test_subject)} times")
print(f"  LOO prompt:      '{test_subject}' appears {loo_full_prompt.count(test_subject)} time(s)")
print(f"\n✓ LOO ensures the model only sees '{test_subject}' in the test position!")
print(f"  This prevents the model from simply copying from the examples.")


ImportError: cannot import name 'create_loo_prompt' from 'data_utils' (/Users/paulnguyen/lre-experiment/data_utils.py)