# Inspect the tokens immediately before the assistant response

In [1]:
import torch
import os
import sys
import numpy as np
import pandas as pd

sys.path.append('.')
sys.path.append('..')

In [2]:
# Short model identifier used to build on-disk paths below.
model_name = "qwen-3-32b"
# NOTE(review): `layer` is not read by any later cell in this notebook (the
# activation cell iterates over every layer instead); kept for reference.
layer = 32
# Root directory holding this model's saved responses and activations.
base_dir = "/workspace/" + model_name

In [3]:
from utils.internals import ProbingModel, ConversationEncoder, ActivationExtractor

# Hugging Face model id, shared by the model wrapper and the conversation encoder
# so the two cannot drift apart.
hf_model_id = "Qwen/Qwen3-32B"

pm = ProbingModel(hf_model_id)
encoder = ConversationEncoder(pm.tokenizer, model_name=hf_model_id)
extractor = ActivationExtractor(pm, encoder)

config.json:   0%|          | 0.00/728 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 17 files:   0%|          | 0/17 [00:00<?, ?it/s]



model-00002-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00003-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00005-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00007-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00006-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00004-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00008-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00001-of-00017.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]



model-00015-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00014-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00012-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00010-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00009-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00013-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00011-of-00017.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)

In [None]:
import json

resp_path = f"{base_dir}/roles_240/responses/aberration.jsonl"

# Only the first JSONL record (a single conversation) is inspected.
with open(resp_path) as f:
    first_record = f.readline()
    data = json.loads(first_record)
    conversation = data['conversation']

# Preview the first three messages; long ones are truncated to 80 characters.
system_text = conversation[0]['content']
print(f"System: {system_text[:80]}...")
user_text = conversation[1]['content']
print(f"User: {user_text}")
assistant_text = conversation[2]['content']
print(f"Assistant: {assistant_text[:80]}...")

In [None]:
# Tokenize the full conversation with the tokenizer's chat template.
# NOTE(review): `encoder` / `extractor` tokenize the conversation themselves;
# the indices computed below are only meaningful if they produce this same
# token sequence (same chat-template arguments) — confirm.
full_ids = pm.tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=False
)

# Positions where assistant content starts (per turn, as reported by the
# encoder); [0][0] is taken as the start of the first assistant response.
response_indices = encoder.response_indices(conversation, per_turn=True)
first_response_start = response_indices[0][0]

# Window of tokens immediately preceding the response.
n_tokens_before = 7
# Clamp at 0: a negative start would make Python/torch index from the END of
# the sequence and silently display/analyse unrelated tokens.
window_start = max(0, first_response_start - n_tokens_before)
before_indices = list(range(window_start, first_response_start))

# Display what these tokens are.
print(f"{len(before_indices)} tokens before response:")
for idx in before_indices:
    token_id = full_ids[idx]
    token_str = pm.tokenizer.decode([token_id])
    print(f"  [{idx}] id={token_id:6d}  {repr(token_str)}")

In [None]:
# Get activations for the full conversation at ALL layers.
n_layers = len(pm.get_layers())
print(f"Model has {n_layers} layers")

# NOTE(review): one extractor call per layer; if each call runs its own forward
# pass, this is n_layers forward passes of the full model — check whether the
# extractor can return several layers from a single pass.
all_layer_activations = [
    extractor.full_conversation(conversation, layer=layer_idx)  # assumed (seq_len, hidden_size)
    for layer_idx in range(n_layers)
]

# Stack into (n_layers, seq_len, hidden_size).
all_activations = torch.stack(all_layer_activations)
print(f"All activations shape: {all_activations.shape}")

# `before_indices` were derived from `full_ids` (chat-template tokenization).
# If the extractor tokenized the conversation differently, those indices would
# silently pick the wrong rows, so fail loudly on a length mismatch.
if all_activations.shape[1] != len(full_ids):
    raise ValueError(
        f"Activation sequence length {all_activations.shape[1]} does not match "
        f"tokenized length {len(full_ids)}; token indices would be misaligned."
    )

# Activations at the positions just before the response:
# (n_layers, len(before_indices), hidden_size).
before_activations = all_activations[:, before_indices, :]
print(f"Before activations shape: {before_activations.shape}")

In [None]:
# Load saved response activation
act_path = f"{base_dir}/roles_240/response_activations/aberration.pt"
saved_acts = torch.load(act_path, weights_only=False)
print(f"Keys: {list(saved_acts.keys())[:5]}")

# Get first response activation (all layers)
first_key = list(saved_acts.keys())[0]
response_act = saved_acts[first_key]  # Shape: (n_layers, hidden_size)
print(f"Response activation shape: {response_act.shape}")

In [None]:
import torch.nn.functional as F

# Cosine similarity of each pre-response token with the saved response
# activation, at every layer at once -> cos_sims of shape (n_layers, n_tokens).
n_layers = before_activations.shape[0]
n_tokens = before_activations.shape[1]

# Work in float32 on a single device: the saved activation (from torch.load)
# is not guaranteed to live on the same device as the extracted activations.
before_f32 = before_activations.float()                     # (n_layers, n_tokens, hidden_size)
response_f32 = response_act.float().to(before_f32.device)   # expected (n_layers, hidden_size)

# Layer-by-layer comparison is only valid if both tensors index the same layers
# (e.g. a saved tensor that also includes the embedding output would be off by one).
if response_f32.shape[0] != n_layers:
    raise ValueError(
        f"Saved response activation has {response_f32.shape[0]} layers but the "
        f"extracted activations have {n_layers}; layers would be misaligned."
    )

# Unit-normalize along the hidden dimension, then dot every token with its
# layer's response vector.
before_norm = F.normalize(before_f32, dim=-1)
response_norm = F.normalize(response_f32, dim=-1)
cos_sims = torch.einsum("lth,lh->lt", before_norm, response_norm).detach().cpu()

print(f"Cosine similarities shape: {cos_sims.shape}")

In [None]:
import matplotlib.pyplot as plt

# Readable x-axis labels: repr() makes whitespace/newline tokens visible.
token_labels = [repr(pm.tokenizer.decode([full_ids[idx]])) for idx in before_indices]

# Heatmap of layer x token cosine similarity on a fixed [-1, 1] diverging scale.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(cos_sims.numpy(), aspect='auto', cmap='RdBu_r', vmin=-1, vmax=1)
fig.colorbar(im, ax=ax, label='Cosine Similarity')
ax.set_xlabel('Token Position Before Response')
ax.set_ylabel('Layer')
ax.set_xticks(range(n_tokens))
ax.set_xticklabels(token_labels, rotation=45, ha='right')
ax.set_title('Cosine Similarity: Tokens Before Response vs Mean Response Activation')
fig.tight_layout()
plt.show()

In [None]:
# Mean cosine similarity per token (averaged over all layers).
mean_per_token = cos_sims.mean(dim=0)
print("Mean cosine sim per token (across all layers):")
for idx, sim in zip(before_indices, mean_per_token):
    token_str = pm.tokenizer.decode([full_ids[idx]])
    print(f"  {repr(token_str):20s}  mean_cos_sim = {sim.item():.4f}")

# Mean cosine similarity per layer (averaged over the inspected tokens).
mean_per_layer = cos_sims.mean(dim=1)
best_layer = mean_per_layer.argmax().item()
worst_layer = mean_per_layer.argmin().item()
print(f"\nLayer with highest mean similarity: {best_layer} ({mean_per_layer[best_layer].item():.4f})")
print(f"Layer with lowest mean similarity: {worst_layer} ({mean_per_layer[worst_layer].item():.4f})")