# Role Binding Probe (Speaker Roles)

This notebook builds a two-speaker transcript, extracts activations from Llama 3.1 8B, and runs:
- Role difference vectors (mean Alice minus mean Bob)
- PCA/SVD on speaker-centered activations
- Swap tests to check binding vs surface cues


In [8]:
!pip install -U "transformers>=4.40.0" "accelerate>=0.30.0"

Defaulting to user installation because normal site-packages is not writeable
Collecting transformers>=4.40.0
  Downloading transformers-5.1.0-py3-none-any.whl.metadata (31 kB)
Collecting huggingface-hub<2.0,>=1.3.0 (from transformers>=4.40.0)
  Downloading huggingface_hub-1.4.0-py3-none-any.whl.metadata (13 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers>=4.40.0)
  Downloading tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Collecting typer-slim (from transformers>=4.40.0)
  Downloading typer_slim-0.21.1-py3-none-any.whl.metadata (16 kB)
Collecting hf-xet<2.0.0,>=1.2.0 (from huggingface-hub<2.0,>=1.3.0->transformers>=4.40.0)
  Downloading hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Downloading transformers-5.1.0-py3-none-any.whl (10.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hDownlo

In [9]:
# Core imports
import os
import re
import numpy as np
import torch
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA

from transformers import AutoTokenizer, AutoModelForCausalLM

np.random.seed(7)
torch.manual_seed(7)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

device: cpu


In [10]:
# Optional dependency check (avoid installing into /home on clusters)
import importlib.util
print("accelerate available:", importlib.util.find_spec("accelerate") is not None)

import os

# Fix cluster home quota: redirect caches (edit path if needed)
os.environ.setdefault("ROLE_REP_CACHE_DIR", f"/projects/JeFeSpace/KLM/cache/{os.environ.get('USER','user')}/role-rep")
os.environ.setdefault("HF_HOME", os.path.join(os.environ["ROLE_REP_CACHE_DIR"], "hf"))
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(os.environ["HF_HOME"], "transformers"))
os.environ.setdefault("TORCH_HOME", os.path.join(os.environ["ROLE_REP_CACHE_DIR"], "torch"))
os.environ.setdefault("XDG_CACHE_HOME", os.path.join(os.environ["ROLE_REP_CACHE_DIR"], "xdg"))

# Avoid widget/progress-bar issues + HF xet
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["DISABLE_TQDM"] = "1"
os.environ["HF_HUB_DISABLE_XET"] = "1"

import getpass
if not (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")):
    os.environ["HF_TOKEN"] = getpass.getpass("HF token: ")
    os.environ["HUGGINGFACE_HUB_TOKEN"] = os.environ["HF_TOKEN"]

accelerate available: True


In [11]:
# Two-speaker multi-turn transcript (edit freely)
TRANSCRIPT = """Alice: Hey Bob, before we jump in, I wanted to revisit the design proposal.
Bob: Sure, I skimmed it last night; the latency targets look ambitious.
Alice: The client asked for sub-200ms p95; I think batching can get us there.
Bob: Batching helps, but the cache invalidation could become tricky.
Alice: We can constrain invalidation to product-level keys instead of per-user.
Bob: That might reduce precision, though; would that impact personalization?
Alice: Some, but we can re-rank on the client side for the top 10 results.
Bob: Okay, so you are proposing a hybrid: coarse cache, fine rerank.
Alice: Exactly. And we should log enough to measure drift each week.
Bob: Logging is fine, but data retention policy caps at 30 days.
Alice: Right, I can summarize weekly aggregates and delete raw events.
Bob: Great. Also, the new API endpoint needs a version bump.
Alice: v3 seems reasonable; we can keep v2 for a deprecation window.
Bob: Then we need a migration guide; I can draft it.
Alice: Thanks. Another point: the search index rebuild takes 6 hours.
Bob: Maybe we can parallelize by shard and compress the postings lists.
Alice: If we compress too much, we might slow decoding at query time.
Bob: True; we could trade storage for CPU if latency budget allows.
Alice: We'll benchmark both. Also, what about adding synonyms?
Bob: Synonyms help recall, but they increase false positives.
Alice: We'll tune the threshold and evaluate per-category.
Bob: Sounds good. On another note, QA reported flaky tests.
Alice: I saw that; I think the mock clock isn't resetting in CI.
Bob: I can isolate those tests and add a fixture.
Alice: Appreciate it. Lastly, are we aligned on the rollout plan?
Bob: Staged rollout: internal, then 5% external, then 50%.
Alice: And we monitor error rates and rollback if p95 spikes.
Bob: Yes. I'll write the runbook.
Alice: Great, I'll update the proposal and send it today.
Bob: Perfect; I will review as soon as it lands.
Alice: Thanks, Bob.
Bob: Thanks, Alice."""

In [12]:
# Model load (requires HF token with Llama access)
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B"

from transformers import AutoConfig

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Workaround for older Transformers that expect rope_scaling = {type, factor}
config = AutoConfig.from_pretrained(MODEL_ID)
if isinstance(getattr(config, "rope_scaling", None), dict):
    rs = config.rope_scaling
    if "type" not in rs or "factor" not in rs or len(rs) != 2:
        config.rope_scaling = {"type": "linear", "factor": float(rs.get("factor", 1.0))}

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    config=config,
    torch_dtype=dtype,
    device_map="auto" if device == "cuda" else None,
)

# If special tokens were added, resize embeddings
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


ValueError: `rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}

In [None]:
def parse_turns(text, speakers=("Alice", "Bob")):
    turns = []
    cursor = 0
    pattern = re.compile(rf"^({'|'.join(speakers)}):\\s*")
    for line in text.splitlines():
        line_start = cursor
        line_end = cursor + len(line)
        match = pattern.match(line)
        if match:
            speaker = match.group(1)
            content_start = line_start + match.end()
        else:
            speaker = "UNKNOWN"
            content_start = line_start
        turns.append(
            {
                "speaker": speaker,
                "line_start": line_start,
                "line_end": line_end,
                "content_start": content_start,
            }
        )
        cursor = line_end + 1
    return turns


def tokenize_with_offsets(text):
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )
    input_ids = torch.tensor([enc["input_ids"]])
    offsets = enc["offset_mapping"]
    return input_ids, offsets


def label_tokens_by_turn(offsets, turns):
    token_turn_idx = []
    turn_idx = 0
    for start, end in offsets:
        while turn_idx < len(turns) and start >= turns[turn_idx]["line_end"]:
            turn_idx += 1
        if turn_idx >= len(turns):
            token_turn_idx.append(None)
        else:
            token_turn_idx.append(turn_idx)
    token_speakers = [turns[i]["speaker"] if i is not None else "UNKNOWN" for i in token_turn_idx]
    return token_turn_idx, token_speakers


@torch.no_grad()
def extract_hidden_states(text):
    input_ids, offsets = tokenize_with_offsets(text)
    input_ids = input_ids.to(model.device)
    outputs = model(input_ids, output_hidden_states=True)
    # hidden_states: tuple (embeddings + each layer)
    hidden_states = [hs[0].float().cpu().numpy() for hs in outputs.hidden_states]
    return hidden_states, offsets

In [None]:
# Extract activations for the main transcript
turns = parse_turns(TRANSCRIPT)
hidden_states, offsets = extract_hidden_states(TRANSCRIPT)
token_turn_idx, token_speakers = label_tokens_by_turn(offsets, turns)

print("tokens:", len(token_speakers), "layers:", len(hidden_states))

In [None]:
# Role difference vectors per layer (mean Alice minus mean Bob)
def compute_role_diff_vectors(hidden_states, token_speakers):
    role_diffs = []
    for hs in hidden_states:
        hs = np.asarray(hs)
        mask_a = np.array([s == "Alice" for s in token_speakers])
        mask_b = np.array([s == "Bob" for s in token_speakers])
        mean_a = hs[mask_a].mean(axis=0)
        mean_b = hs[mask_b].mean(axis=0)
        role_diffs.append(mean_a - mean_b)
    return role_diffs


role_diffs = compute_role_diff_vectors(hidden_states, token_speakers)
role_norms = [np.linalg.norm(v) for v in role_diffs]

plt.figure(figsize=(6, 3))
plt.plot(role_norms, marker="o")
plt.title("Role diff norm by layer")
plt.xlabel("Layer (0 = embeddings)")
plt.ylabel("L2 norm")
plt.show()

In [None]:
# PCA on speaker-centered activations (choose a layer)
LAYER_FOR_PCA = -1

hs = np.asarray(hidden_states[LAYER_FOR_PCA])
mask_a = np.array([s == "Alice" for s in token_speakers])
mask_b = np.array([s == "Bob" for s in token_speakers])

mean_a = hs[mask_a].mean(axis=0)
mean_b = hs[mask_b].mean(axis=0)

# Speaker-centered
hs_centered = hs.copy()
hs_centered[mask_a] -= mean_a
hs_centered[mask_b] -= mean_b

pca = PCA(n_components=2)
proj_raw = pca.fit_transform(hs)

pca_centered = PCA(n_components=2)
proj_centered = pca_centered.fit_transform(hs_centered)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.scatter(proj_raw[:, 0], proj_raw[:, 1], c=mask_a, s=8, alpha=0.6)
plt.title("PCA (raw)")

plt.subplot(1, 2, 2)
plt.scatter(proj_centered[:, 0], proj_centered[:, 1], c=mask_a, s=8, alpha=0.6)
plt.title("PCA (speaker-centered)")
plt.show()

print("Raw PCA explained variance:", pca.explained_variance_ratio_)
print("Centered PCA explained variance:", pca_centered.explained_variance_ratio_)

In [None]:
# Swap test: flip speaker tags but keep content order
def swap_speaker_tags(text):
    tmp = text.replace("Alice:", "__TEMP__")
    tmp = tmp.replace("Bob:", "Alice:")
    return tmp.replace("__TEMP__", "Bob:")


SWAPPED = swap_speaker_tags(TRANSCRIPT)

# Role vector from original transcript (use chosen layer)
layer_idx = LAYER_FOR_PCA
role_vec = role_diffs[layer_idx]
role_vec = role_vec / (np.linalg.norm(role_vec) + 1e-8)

# Hidden states for swapped transcript
swapped_turns = parse_turns(SWAPPED)
swapped_states, swapped_offsets = extract_hidden_states(SWAPPED)
_, swapped_tag_speakers = label_tokens_by_turn(swapped_offsets, swapped_turns)

# Content speakers come from original turn order
content_speakers = [turns[i]["speaker"] if i is not None else "UNKNOWN" for i in token_turn_idx]

swapped_hs = np.asarray(swapped_states[layer_idx])
proj_swapped = swapped_hs @ role_vec

def mean_projection_by_label(proj, labels, label_name):
    labels = np.array(labels)
    mean_a = proj[labels == "Alice"].mean()
    mean_b = proj[labels == "Bob"].mean()
    print(f"{label_name} mean projection - Alice: {mean_a:.4f}, Bob: {mean_b:.4f}")


mean_projection_by_label(proj_swapped, swapped_tag_speakers, "Swapped tag")
mean_projection_by_label(proj_swapped, content_speakers, "Content label")