# Layer-wise UA Analysis: Finding the Point of Divergence

**Purpose:** Locate the layer of the transformer at which the correlation between embedding-space Uniformity Asymmetry (UA) and the model's output preference changes.

**Research Question:** Does the correlation r(UA, Output) change across layers?

**Hypotheses:**
1. **Complexity Penalty:** High UA = high output entropy
2. **Late-Stage Corruption:** r flips from positive to negative in late layers
3. **Superposition:** High UA in ambiguous pairs

**Based on:** KV Cache Paper (arXiv:2511.12752) showing early layers encode "topic trajectory" while late layers encode "local discourse."

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Model:** Pythia-6.9B (strongest effect: r = -0.87)

## Setup

In [None]:
# Install dependencies
!pip install -q transformers accelerate torch numpy scipy matplotlib

In [None]:
import os
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer

# Silence all Python warnings so the cell output stays readable.
warnings.filterwarnings('ignore')

# Configuration: one seed shared by PyTorch and NumPy.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Environment report: query CUDA once and reuse the answer.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Load Model and Dataset

In [None]:
# Model configuration: Hub checkpoint id plus a short label used in log output.
MODEL_NAME = "EleutherAI/pythia-6.9b"
MODEL_DISPLAY = "Pythia-6.9B"

print(f"Loading {MODEL_DISPLAY}...")

# Loading options gathered in one place.
model_load_kwargs = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    output_hidden_states=True,  # IMPORTANT: the analysis needs all layer outputs
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_kwargs)

print(f"Model loaded on: {model.device}")
print(f"Number of layers: {model.config.num_hidden_layers}")

In [None]:
# Load dataset
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

total_pairs = sum(len(cat['pairs']) for cat in DATASET.values())
print(f"Loaded {total_pairs} statement pairs across {len(DATASET)} categories")

## Core Functions: Multiple Embedding Strategies

In [None]:
def get_all_layer_embeddings(text: str, model, tokenizer) -> Dict[str, np.ndarray]:
    """
    Extract embeddings of ``text`` under several pooling strategies.

    Keys of the returned dict:
    - mean_pooled: mean over token positions 1..end of the final hidden state
    - last_token:  final hidden state at the last attended position
    - first_token: final hidden state at position 0
    - layer_X:     mean over token positions 1..end of ``hidden_states[X]`` for
                   every 4th index X (index 0 is the embedding output), plus
                   the final layer if the stride does not land on it

    Position 0 is left out of the mean on the assumption that it holds a BOS
    token. NOTE(review): confirm that the tokenizer really prepends one; if it
    does not, the first real token is being dropped from every mean.

    A one-token input has nothing after position 0, so the mean falls back to
    that single token instead of averaging an empty slice (which yields NaN).

    Hidden states are upcast to float32 before pooling, so the mean is not
    rounded to a half-precision dtype and the NumPy conversion does not depend
    on the model's dtype.

    Args:
        text: Statement to embed.
        model: Causal LM whose forward output exposes ``hidden_states``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        Dict mapping strategy name to a 1-D float32 vector.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # Tuple of (n_layers + 1) tensors

    # Number of layers (excluding embedding layer)
    n_layers = len(hidden_states) - 1
    final_hidden = hidden_states[-1]

    # Index of the last attended position, taken from the attention mask
    last_idx = (inputs['attention_mask'][0].sum() - 1).item()

    # First position included in the mean: skip position 0 unless it is the
    # only token available.
    pool_start = 1 if final_hidden.shape[1] > 1 else 0

    def to_vector(tensor) -> np.ndarray:
        # Upcast on the tensor side so every dtype converts to float32 NumPy.
        return tensor.float().cpu().numpy().astype(np.float32)

    def mean_pool(layer_hidden) -> np.ndarray:
        return to_vector(layer_hidden[0, pool_start:, :].float().mean(dim=0))

    embeddings = {
        # 1. Mean-pooled, final layer
        'mean_pooled': mean_pool(final_hidden),
        # 2. Last token, final layer
        'last_token': to_vector(final_hidden[0, last_idx, :]),
        # 3. First token, final layer
        'first_token': to_vector(final_hidden[0, 0, :]),
    }

    # 4. Layer-by-layer embeddings: every 4th layer, always including the last
    layer_indices = list(range(0, n_layers + 1, 4))
    if layer_indices[-1] != n_layers:
        layer_indices.append(n_layers)
    for layer_idx in layer_indices:
        embeddings[f'layer_{layer_idx}'] = mean_pool(hidden_states[layer_idx])

    return embeddings


def get_output_preference(text_a: str, text_b: str, model, tokenizer) -> float:
    """
    Score which of two statements the model finds more likely.

    Each text is passed through the model with its own token ids as labels,
    and the resulting ``loss`` is used as that text's NLL. The return value is
    NLL(B) - NLL(A): positive means A is preferred, negative means B.
    """
    losses = []
    for text in (text_a, text_b):  # A is scored first, then B
        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            result = model(**encoded, labels=encoded["input_ids"])
        losses.append(result.loss.item())

    loss_a, loss_b = losses
    return loss_b - loss_a


def uniformity_score(embeddings: np.ndarray) -> float:
    """
    Mean cosine similarity over all distinct pairs of rows in ``embeddings``.

    Rows are scaled to unit length (a 1e-10 term in the denominator guards
    against zero vectors). The strict upper triangle of the resulting Gram
    matrix is then averaged, so each unordered pair counts once and
    self-similarity is excluded.
    """
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / (row_norms + 1e-10)
    gram = unit_rows @ unit_rows.T
    upper = np.triu_indices_from(gram, k=1)
    return float(gram[upper].mean())

## Run Layer-wise Analysis

In [None]:
def run_layer_analysis(model, tokenizer, dataset: dict, verbose: bool = True) -> dict:
    """
    Run the UA analysis for every embedding strategy.

    For each statement pair the function collects the embeddings of both
    statements (all strategies returned by ``get_all_layer_embeddings``) and
    the output preference from ``get_output_preference``. Per strategy and per
    category it then computes UA = uniformity(A side) - uniformity(B side) and
    correlates the per-category UA values with the per-category mean output
    preference (Pearson r).

    Args:
        model: Language model passed through to the embedding/preference helpers.
        tokenizer: Tokenizer passed through to the same helpers.
        dataset: Mapping of category name -> record with a ``"pairs"`` list of
            (statement_a, statement_b) entries.
        verbose: Print progress and the per-strategy result table.

    Returns:
        Dict mapping strategy name to a dict with ``correlation``, ``p_value``,
        ``category_uas`` and ``category_prefs`` (lists ordered like
        ``dataset``'s categories, holding plain Python floats).

    Note:
        ``uniformity_score`` averages over pairs of rows, so a category with
        fewer than two statement pairs has an undefined (NaN) UA.
    """
    # Collect all embeddings and preferences
    embeddings_a = {}  # strategy -> list of embeddings
    embeddings_b = {}
    preferences = []
    pair_categories = []  # category of each processed pair, in processing order

    total_pairs = sum(len(cat["pairs"]) for cat in dataset.values())
    processed = 0

    for category_name, category_data in dataset.items():
        if verbose:
            print(f"\nProcessing: {category_name}")

        for stmt_a, stmt_b in category_data["pairs"]:
            processed += 1
            if verbose and processed % 20 == 0:
                print(f"  [{processed:03d}/{total_pairs}]")

            # Embeddings for all strategies, both sides of the pair
            embs_a = get_all_layer_embeddings(stmt_a, model, tokenizer)
            embs_b = get_all_layer_embeddings(stmt_b, model, tokenizer)
            for strategy, vector in embs_a.items():
                embeddings_a.setdefault(strategy, []).append(vector)
                embeddings_b.setdefault(strategy, []).append(embs_b[strategy])

            preferences.append(get_output_preference(stmt_a, stmt_b, model, tokenizer))
            pair_categories.append(category_name)

    # Category membership and mean preference do not depend on the embedding
    # strategy, so they are computed once instead of once per strategy.
    preferences_arr = np.array(preferences)
    category_masks = [
        np.array([cat == cat_name for cat in pair_categories], dtype=bool)
        for cat_name in dataset
    ]
    category_prefs = [float(np.mean(preferences_arr[mask])) for mask in category_masks]

    results = {}

    if verbose:
        print(f"\n{'='*60}")
        print("RESULTS: Correlation by Embedding Strategy")
        print(f"{'='*60}")

    for strategy in embeddings_a:
        embs_a = np.array(embeddings_a[strategy])
        embs_b = np.array(embeddings_b[strategy])

        # Per-category UA: uniformity of the A side minus that of the B side
        category_uas = [
            uniformity_score(embs_a[mask]) - uniformity_score(embs_b[mask])
            for mask in category_masks
        ]

        r, p = stats.pearsonr(category_uas, category_prefs)

        results[strategy] = {
            'correlation': float(r),
            'p_value': float(p),
            'category_uas': category_uas,
            # Copy so each strategy owns its list, as before the hoisting.
            'category_prefs': list(category_prefs),
        }

        if verbose:
            print(f"{strategy:<20} r = {r:+.3f}  (p = {p:.4f})")

    return results


# Announce the run, then execute it; run_layer_analysis prints its own progress.
print(f"Starting layer-wise analysis on {MODEL_DISPLAY}...")
print("This will take ~30-45 minutes on A100.\n")

results = run_layer_analysis(model=model, tokenizer=tokenizer, dataset=DATASET)

## Visualize: Layer Correlation Curve

In [None]:
# Extract layer-specific results, ordered by layer index
layer_results = sorted(
    ((k, v) for k, v in results.items() if k.startswith('layer_')),
    key=lambda item: int(item[0].split('_')[1]),
)

layer_nums = [int(k.split('_')[1]) for k, _ in layer_results]
layer_corrs = [v['correlation'] for _, v in layer_results]

# Create figure
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Layer correlation curve
ax1 = axes[0]
ax1.plot(layer_nums, layer_corrs, 'o-', linewidth=2, markersize=8, color='blue')
ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('r(UA, Output Preference)', fontsize=12)
ax1.set_title('Correlation by Layer: Point of Divergence', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Mark every divergence point (r changes sign between two sampled layers).
# Only the first marker carries the legend label; the others use a label
# starting with an underscore, which the legend skips, so "Divergence"
# appears once no matter how many sign changes there are.
divergence_marked = False
layer_points = list(zip(layer_nums, layer_corrs))
for (left_layer, left_r), (right_layer, right_r) in zip(layer_points, layer_points[1:]):
    if left_r * right_r < 0:  # Sign change
        ax1.axvline(
            x=(left_layer + right_layer) / 2,
            color='red',
            linestyle=':',
            linewidth=2,
            label='_nolegend_' if divergence_marked else 'Divergence',
        )
        divergence_marked = True
if divergence_marked:
    ax1.legend()

# Plot 2: Compare strategies
ax2 = axes[1]
strategies = ['mean_pooled', 'last_token', 'first_token']
strategy_corrs = [results[s]['correlation'] for s in strategies]
colors = ['blue', 'green', 'orange']

bars = ax2.bar(strategies, strategy_corrs, color=colors, edgecolor='black')
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax2.set_ylabel('r(UA, Output Preference)', fontsize=12)
ax2.set_title('Correlation by Pooling Strategy', fontsize=14, fontweight='bold')

# Add value labels just outside the bar tip: above positive bars, below
# negative ones (a fixed upward offset would draw the label inside a
# negative bar).
for bar, val in zip(bars, strategy_corrs):
    if val >= 0:
        label_y, label_va = bar.get_height() + 0.02, 'bottom'
    else:
        label_y, label_va = bar.get_height() - 0.02, 'top'
    ax2.text(bar.get_x() + bar.get_width()/2, label_y,
             f'{val:.3f}', ha='center', va=label_va, fontsize=11)

plt.tight_layout()
plt.savefig('layer_analysis_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPlot saved to: layer_analysis_results.png")

## Summary Table

In [None]:
# Console report: banner, pooling-strategy table, per-layer table, sign changes.
banner = "=" * 70
print("\n" + banner)
print(f" LAYER-WISE UA ANALYSIS: {MODEL_DISPLAY}")
print(banner)

print("\n--- POOLING STRATEGY COMPARISON ---")
print(f"{'Strategy':<20} {'r(UA, Output)':<15} {'Interpretation'}")
print("-" * 60)

# (exclusive upper bound, label), checked in ascending order; anything that
# falls below none of the bounds gets the final "ALIGNED" label.
interpretation_bands = (
    (-0.5, "INVERSE (strong negative)"),
    (-0.2, "Weak negative"),
    (0.2, "DECOUPLED (near zero)"),
    (0.5, "Weak positive"),
)

for strategy in ('mean_pooled', 'last_token', 'first_token'):
    r = results[strategy]['correlation']
    interp = next(
        (label for upper, label in interpretation_bands if r < upper),
        "ALIGNED (strong positive)",
    )
    print(f"{strategy:<20} {r:+.3f}           {interp}")

print("\n--- LAYER-BY-LAYER CORRELATION ---")
print(f"{'Layer':<10} {'r(UA, Output)':<15}")
print("-" * 30)

for layer, corr in zip(layer_nums, layer_corrs):
    print(f"Layer {layer:<4} {corr:+.3f}")

# Find divergence point: adjacent sampled layers whose correlations differ in sign
print("\n--- DIVERGENCE ANALYSIS ---")
sign_changes = [
    (lo_layer, lo_r, hi_layer, hi_r)
    for lo_layer, lo_r, hi_layer, hi_r in zip(
        layer_nums, layer_corrs, layer_nums[1:], layer_corrs[1:]
    )
    if lo_r * hi_r < 0
]
for lo_layer, lo_r, hi_layer, hi_r in sign_changes:
    print(f"Sign change between Layer {lo_layer} and Layer {hi_layer}")
    print(f"  Layer {lo_layer}: r = {lo_r:+.3f}")
    print(f"  Layer {hi_layer}: r = {hi_r:+.3f}")

divergence_found = bool(sign_changes)

if not divergence_found:
    if all(c < 0 for c in layer_corrs):
        verdict = "No sign change: Correlation is NEGATIVE across all layers"
    elif all(c > 0 for c in layer_corrs):
        verdict = "No sign change: Correlation is POSITIVE across all layers"
    else:
        verdict = "Pattern unclear - check individual layer values"
    print(verdict)

## Save Results

In [None]:
# Prepare results for saving.
# The clock is read once so the timestamp stored in the JSON and the one in
# the filename always refer to the same instant.
run_time = datetime.now()

save_results = {
    'model': MODEL_NAME,
    'model_display': MODEL_DISPLAY,
    'timestamp': run_time.isoformat(),
    # Only the scalar statistics are persisted, not the per-category lists.
    'strategies': {
        strategy: {
            'correlation': data['correlation'],
            'p_value': data['p_value'],
        }
        for strategy, data in results.items()
    },
}

# Save to JSON
output_file = f"layer_analysis_{run_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(save_results, f, indent=2)

print(f"Results saved to: {output_file}")

# Download (Colab only). Outside Colab the import fails; the files have
# already been written at that point, so skip the download instead of crashing.
try:
    from google.colab import files
except ImportError:
    print("google.colab not available - skipping download (files are in the working directory).")
else:
    files.download(output_file)
    files.download('layer_analysis_results.png')

## Summary for Reddit/Discord

In [None]:
# Generate summary text
mean_r = results['mean_pooled']['correlation']
last_r = results['last_token']['correlation']
first_r = results['first_token']['correlation']

# Report the size of the dataset that was actually loaded rather than a
# hard-coded figure, so the text stays correct if dataset.json changes.
n_pairs = sum(len(cat['pairs']) for cat in DATASET.values())
n_categories = len(DATASET)

summary = f"""
## Layer-wise UA Analysis Results: {MODEL_DISPLAY}

Tested different embedding strategies on the same dataset ({n_pairs} pairs, {n_categories} categories).

### Pooling Strategy Comparison

| Strategy | r(UA, Output) |
|----------|---------------|
| Mean-pooled (all tokens) | {mean_r:+.3f} |
| Last-token only | {last_r:+.3f} |
| First-token (BOS) | {first_r:+.3f} |

### Layer-by-Layer Correlation

| Layer | r |
|-------|---|
"""

# One markdown table row per sampled layer
summary += "".join(
    f"| {layer} | {corr:+.3f} |\n" for layer, corr in zip(layer_nums, layer_corrs)
)

# Interpretation
# NOTE(review): this compares the signed r values, so "MORE"/"LESS" means
# "more/less positive", not "stronger/weaker in magnitude" - confirm that is
# the intended reading when the correlations are negative.
if last_r > mean_r + 0.2:
    interp = "Last-token embeddings correlate MORE with output than mean-pooled."
elif last_r < mean_r - 0.2:
    interp = "Last-token embeddings correlate LESS with output than mean-pooled."
else:
    interp = "Pooling strategy has minimal effect on correlation."

summary += f"\n### Interpretation\n\n{interp}\n"

print(summary)