# Option B: Pair-Level Analysis (New Metric, n=230)

**Research Question:**
> Can we define a pair-level UA proxy that correlates with output preference at n=230?

**IMPORTANT:** This is a **NEW METRIC**, not the same as category-level UA.

**Method:**
- Define "Local UA" for each pair (A_i, B_i)
- Correlate with output preference
- n=230 gives much more statistical power

**Proposed Local UA Metrics:**
1. **Centroid Similarity Asymmetry:** sim(A_i, centroid_A) - sim(B_i, centroid_B), where sim is the cosine similarity of a statement's embedding to its own side's centroid
2. **kNN Density Asymmetry:** density around A_i vs density around B_i
3. **Relative Position:** Where does this pair sit relative to cluster structure?

**Caveat:** Results are NOT directly comparable to category-level UA.

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Model:** Pythia-6.9B  
**Status:** EXPERIMENTAL

## 1. Setup

In [None]:
!pip install -q transformers accelerate torch numpy scipy matplotlib scikit-learn

In [None]:
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from transformers import AutoModelForCausalLM, AutoTokenizer

# Suppress every warning for the rest of the session to keep cell output short.
warnings.filterwarnings('ignore')

# Analysis constants: shared seed, bootstrap resample count, CI coverage.
RANDOM_SEED = 42
N_BOOTSTRAP = 10000
CI_LEVEL = 0.95

# Seed the global generators of both libraries with the same value.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Report the runtime environment.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Hub identifier used for loading, and the human-readable label used in messages.
MODEL_NAME = "EleutherAI/pythia-6.9b"
MODEL_DISPLAY = "Pythia-6.9B"

print(f"Loading {MODEL_DISPLAY}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load options: half-precision weights, automatic device placement, and
# hidden states requested at load time.
_model_load_options = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    output_hidden_states=True,
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_model_load_options)

print(f"Model loaded. Layers: {model.config.num_hidden_layers}")

In [None]:
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

# Parse the downloaded dataset. The encoding is given explicitly so decoding
# does not depend on the platform's default locale encoding.
with open('dataset.json', 'r', encoding='utf-8') as f:
    DATASET = json.load(f)

# Flatten the per-category structure into one record per statement pair.
# Each category entry is read through its 'pairs' list, and each pair through
# its first two items (statement A, statement B).
ALL_PAIRS = []
for cat_name, cat_data in DATASET.items():
    for pair in cat_data['pairs']:
        ALL_PAIRS.append({
            'stmt_a': pair[0],
            'stmt_b': pair[1],
            'category': cat_name
        })

print(f"Total pairs: {len(ALL_PAIRS)}")

## 2. Core Functions

In [None]:
def get_layer_embedding(text: str, model, tokenizer, layer_idx: int) -> np.ndarray:
    """Return a mean-pooled float32 embedding of `text` from one hidden-state entry.

    Args:
        text: The statement to embed.
        model: Callable model exposing `.device` and returning an object with
            a `hidden_states` sequence when called with
            `output_hidden_states=True`.
        tokenizer: Callable tokenizer returning PyTorch tensors.
        layer_idx: Index into `hidden_states`.

    Returns:
        1-D float32 NumPy array: the mean over token positions 1..end of the
        selected hidden state for the single sequence in the batch.

    Note:
        Position 0 is excluded from the average, so a one-token input leaves
        nothing to average over.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        per_layer_states = model(**encoded, output_hidden_states=True).hidden_states

    # Batch item 0, every position except the first, all hidden dimensions.
    token_states = per_layer_states[layer_idx][0, 1:, :]
    pooled = token_states.mean(dim=0)

    return pooled.cpu().numpy().astype(np.float32)


def get_output_preference(text_a: str, text_b: str, model, tokenizer) -> float:
    """Return the model's output preference for A over B as loss(B) - loss(A).

    Each text is scored by calling the model with the text's own token ids as
    labels and reading the scalar `.loss` from the result. A positive return
    value means `text_a` received the lower loss.

    Args:
        text_a: First statement of the pair.
        text_b: Second statement of the pair.
        model: Callable model exposing `.device` and returning an object with
            a scalar `.loss` when given `labels`.
        tokenizer: Callable tokenizer returning PyTorch tensors.

    Returns:
        The loss of `text_b` minus the loss of `text_a`, as a Python float.
    """
    def _score(text):
        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            result = model(**encoded, labels=encoded["input_ids"])
        return result.loss.item()

    # B is scored first, then A.
    nll_b = _score(text_b)
    nll_a = _score(text_a)
    return nll_b - nll_a


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of two vectors as a Python float.

    A small constant (1e-10) is added to the product of the norms, so a
    zero-length vector gives 0.0 instead of a division by zero.
    """
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return float(dot_product / (norm_product + 1e-10))


def bootstrap_correlation(x: np.ndarray, y: np.ndarray, n_bootstrap: int = 10000,
                          ci_level: float = 0.95) -> Tuple[float, float, float, float]:
    """Pearson correlation with a percentile-bootstrap confidence interval.

    Pairs (x[i], y[i]) are resampled with replacement using NumPy's global
    random state. Resamples in which either variable has zero standard
    deviation are skipped.

    Args:
        x: First variable; must support integer-array indexing.
        y: Second variable, same length as `x`.
        n_bootstrap: Number of resamples to draw.
        ci_level: Coverage of the interval (e.g. 0.95).

    Returns:
        (r_observed, ci_lower, ci_upper, p_value), where r_observed and
        p_value come from `scipy.stats.pearsonr` on the full data and the CI
        bounds are percentiles of the resampled correlations.
    """
    sample_size = len(x)
    r_observed, p_value = stats.pearsonr(x, y)

    resampled_rs = []
    for _ in range(n_bootstrap):
        draw = np.random.choice(sample_size, size=sample_size, replace=True)
        x_draw, y_draw = x[draw], y[draw]

        # Correlation is undefined when either resampled variable is constant.
        if not (np.std(x_draw) > 0 and np.std(y_draw) > 0):
            continue

        r_draw, _ = stats.pearsonr(x_draw, y_draw)
        resampled_rs.append(r_draw)

    resampled_rs = np.array(resampled_rs)
    alpha = 1 - ci_level
    ci_lower, ci_upper = np.percentile(
        resampled_rs, [alpha/2 * 100, (1 - alpha/2) * 100]
    )

    return float(r_observed), float(ci_lower), float(ci_upper), float(p_value)

## 3. Collect All Embeddings

In [None]:
# Indices into the model's `hidden_states` tuple that are analysed.
LAYERS_TO_TEST = [0, 4, 8, 12, 16, 20, 24, 28, 32]


def _get_multi_layer_embeddings(text, model, tokenizer, layer_indices):
    """Return {layer_idx: mean-pooled float32 embedding} from ONE forward pass.

    The pooling is the same as in `get_layer_embedding` (mean over positions
    1..end of batch item 0), but the forward pass is shared by all requested
    layers instead of being repeated once per layer.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        hidden_states = model(**inputs, output_hidden_states=True).hidden_states
    return {
        layer_idx: hidden_states[layer_idx][0, 1:, :]
        .mean(dim=0).cpu().numpy().astype(np.float32)
        for layer_idx in layer_indices
    }


print(f"Collecting embeddings for {len(ALL_PAIRS)} pairs...")

# One record per pair: output preference, category, and per-layer embeddings.
pair_data = []

for i, pair in enumerate(ALL_PAIRS):
    # Progress line every 25 pairs.
    if (i + 1) % 25 == 0:
        print(f"  [{i+1:03d}/{len(ALL_PAIRS)}]")

    stmt_a = pair['stmt_a']
    stmt_b = pair['stmt_b']

    pref = get_output_preference(stmt_a, stmt_b, model, tokenizer)

    # One forward pass per statement; every tested layer is read from it.
    embs_a = _get_multi_layer_embeddings(stmt_a, model, tokenizer, LAYERS_TO_TEST)
    embs_b = _get_multi_layer_embeddings(stmt_b, model, tokenizer, LAYERS_TO_TEST)
    layer_embeddings = {
        layer_idx: {'emb_a': embs_a[layer_idx], 'emb_b': embs_b[layer_idx]}
        for layer_idx in LAYERS_TO_TEST
    }

    pair_data.append({
        'pref': pref,
        'category': pair['category'],
        'layer_embeddings': layer_embeddings
    })

print(f"\nDone! Collected {len(pair_data)} pairs.")

## 4. Define Pair-Level UA Metrics

In [None]:
def _knn_density(embeddings: np.ndarray, k: int) -> np.ndarray:
    """Return 1 / (mean cosine distance to the k nearest OTHER rows) per row."""
    nn = NearestNeighbors(n_neighbors=k, metric='cosine')
    nn.fit(embeddings)
    # Calling kneighbors() without a query returns the neighbors of the fitted
    # points and does not count a point as its own neighbor. Passing the
    # fitted data back in as the query would include every point itself at
    # distance ~0 and silently drop the k-th true neighbor.
    distances, _ = nn.kneighbors()
    return 1 / (distances.mean(axis=1) + 1e-10)


def compute_pair_metrics(pair_data: list, layer_idx: int) -> Dict[str, np.ndarray]:
    """
    Compute multiple pair-level UA proxies for one layer.

    Args:
        pair_data: One dict per pair; each must provide
            p['layer_embeddings'][layer_idx]['emb_a'] and [...]['emb_b'].
        layer_idx: Which stored layer's embeddings to use.

    Returns:
        Dict mapping metric name to an array with one value per pair:
          - 'centroid_asymmetry': sim(A, centroid_A) - sim(B, centroid_B)
          - 'cross_centroid': sim(A, centroid_B) - sim(B, centroid_A)
          - 'within_pair_sim': sim(A, B)
          - 'knn_density_asymmetry': z-scored kNN density of A minus that of B
          - 'margin_asymmetry': Euclidean distance of A to the centroid
            midpoint minus that of B
        where sim is cosine similarity.

    Raises:
        ValueError: If fewer than 2 pairs are given (kNN density needs at
            least one other point per side).
    """
    n_pairs = len(pair_data)
    if n_pairs < 2:
        raise ValueError("compute_pair_metrics needs at least 2 pairs")

    # Stack the A-side and B-side embeddings, one row per pair
    all_embs_a = np.array([p['layer_embeddings'][layer_idx]['emb_a'] for p in pair_data])
    all_embs_b = np.array([p['layer_embeddings'][layer_idx]['emb_b'] for p in pair_data])

    # Per-side centroids over all pairs
    centroid_a = all_embs_a.mean(axis=0)
    centroid_b = all_embs_b.mean(axis=0)

    metrics = {}

    # ----- Metric 1: Centroid Similarity Asymmetry -----
    # Cosine similarity of each A to the A centroid vs each B to the B centroid.
    # High value = A is more "typical" of its side than B is of its side
    own_sim_a = np.array([cosine_similarity(emb, centroid_a) for emb in all_embs_a])
    own_sim_b = np.array([cosine_similarity(emb, centroid_b) for emb in all_embs_b])
    metrics['centroid_asymmetry'] = own_sim_a - own_sim_b

    # ----- Metric 2: Cross-Centroid Similarity -----
    # Cosine similarity of each A to the B centroid vs each B to the A centroid.
    cross_sim_a = np.array([cosine_similarity(emb, centroid_b) for emb in all_embs_a])
    cross_sim_b = np.array([cosine_similarity(emb, centroid_a) for emb in all_embs_b])
    metrics['cross_centroid'] = cross_sim_a - cross_sim_b

    # ----- Metric 3: Within-Pair Similarity -----
    # How similar is A to B within this pair?
    # Low similarity might indicate more "distinctive" pairs
    metrics['within_pair_sim'] = np.array(
        [cosine_similarity(emb_a, emb_b) for emb_a, emb_b in zip(all_embs_a, all_embs_b)]
    )

    # ----- Metric 4: kNN Density Asymmetry -----
    # Is A in a denser region (among all A's) than B is (among all B's)?
    k = min(10, n_pairs - 1)
    density_a = _knn_density(all_embs_a, k)
    density_b = _knn_density(all_embs_b, k)

    # z-score each side separately so the two densities are on a common scale
    density_a_norm = (density_a - density_a.mean()) / (density_a.std() + 1e-10)
    density_b_norm = (density_b - density_b.mean()) / (density_b.std() + 1e-10)

    metrics['knn_density_asymmetry'] = density_a_norm - density_b_norm

    # ----- Metric 5: Margin to Decision Boundary -----
    # Conceptual: How far is each embedding from the "middle" between centroids?
    midpoint = (centroid_a + centroid_b) / 2
    dist_to_mid_a = np.linalg.norm(all_embs_a - midpoint, axis=1)
    dist_to_mid_b = np.linalg.norm(all_embs_b - midpoint, axis=1)
    metrics['margin_asymmetry'] = dist_to_mid_a - dist_to_mid_b

    return metrics


# Human-readable summary of the five metrics, one line each.
print("\n".join([
    "Pair-level metrics defined:",
    "  1. centroid_asymmetry: sim(A, centroid_A) - sim(B, centroid_B)",
    "  2. cross_centroid: sim(A, centroid_B) - sim(B, centroid_A)",
    "  3. within_pair_sim: sim(A, B) for each pair",
    "  4. knn_density_asymmetry: normalized kNN density difference",
    "  5. margin_asymmetry: distance to midpoint difference",
]))

## 5. Pair-Level Correlation Analysis

In [None]:
# Output preference of every pair, in the same order as pair_data
all_prefs = np.array([p['pref'] for p in pair_data])

print(f"\n{'='*80}")
print(f" PAIR-LEVEL ANALYSIS (n={len(pair_data)})")
print(f" NOTE: This is a DIFFERENT metric than category-level UA")
print(f"{'='*80}")

# results[layer_idx][metric_name] -> correlation statistics for that test
results = {}

for layer_idx in LAYERS_TO_TEST:
    print(f"\n--- Layer {layer_idx} ---")

    metrics = compute_pair_metrics(pair_data, layer_idx)
    layer_results = {}

    for metric_name, metric_values in metrics.items():
        r, ci_lower, ci_upper, p = bootstrap_correlation(
            metric_values, all_prefs, N_BOOTSTRAP, CI_LEVEL
        )

        # A test is starred when its bootstrap CI does not contain zero.
        includes_zero = ci_lower <= 0 <= ci_upper

        layer_results[metric_name] = dict(
            r=r,
            ci_lower=ci_lower,
            ci_upper=ci_upper,
            p_value=p,
            includes_zero=includes_zero,
        )

        sig_marker = "***" if not includes_zero else ""
        print(f"  {metric_name:<25} r={r:+.3f}  CI=[{ci_lower:+.3f}, {ci_upper:+.3f}] {sig_marker}")

    results[layer_idx] = layer_results

## 6. Find Best Metric

In [None]:
print(f"\n{'='*80}")
print(f" BEST PAIR-LEVEL METRIC BY LAYER")
print(f"{'='*80}")

# For each layer, report the metric with the highest |r|. The last column
# states whether that metric's bootstrap CI excludes 0; it is NOT used as a
# selection criterion.
print(f"\n{'Layer':<8} {'Best Metric':<25} {'r':<10} {'CI excludes 0?'}")
print("-" * 65)

best_metrics_summary = []

for layer_idx in LAYERS_TO_TEST:
    layer_results = results[layer_idx]

    # Running maximum over |r|. Stays at the initial values when no metric
    # has |r| > 0 (e.g. every r is NaN).
    best_metric = None
    best_r = 0
    best_excludes_zero = False

    for metric_name, res in layer_results.items():
        if abs(res['r']) > abs(best_r):
            best_r = res['r']
            best_metric = metric_name
            best_excludes_zero = not res['includes_zero']

    status = "YES ***" if best_excludes_zero else "no"
    # None does not accept a width format spec, so substitute a placeholder
    # label for display only; the stored summary keeps None.
    metric_label = best_metric if best_metric is not None else "n/a"
    print(f"Layer {layer_idx:<3} {metric_label:<25} {best_r:+.3f}     {status}")

    best_metrics_summary.append({
        'layer': layer_idx,
        'best_metric': best_metric,
        'r': best_r,
        'excludes_zero': best_excludes_zero
    })

## 7. Visualization

In [None]:
# Plot correlation by layer for each metric (one subplot per metric)
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

metric_names = list(results[LAYERS_TO_TEST[0]].keys())

for idx, metric_name in enumerate(metric_names):
    ax = axes[idx]

    rs = [results[l][metric_name]['r'] for l in LAYERS_TO_TEST]
    ci_lowers = [results[l][metric_name]['ci_lower'] for l in LAYERS_TO_TEST]
    ci_uppers = [results[l][metric_name]['ci_upper'] for l in LAYERS_TO_TEST]

    # Error-bar lengths relative to r. A percentile bootstrap CI is not
    # guaranteed to bracket the observed r, and errorbar() rejects negative
    # lengths, so clamp at zero.
    yerr_lower = [max(r - ci_l, 0.0) for r, ci_l in zip(rs, ci_lowers)]
    yerr_upper = [max(ci_u - r, 0.0) for r, ci_u in zip(rs, ci_uppers)]

    ax.errorbar(LAYERS_TO_TEST, rs, yerr=[yerr_lower, yerr_upper],
                fmt='o-', capsize=5, capthick=2, markersize=8,
                color='blue', ecolor='blue', alpha=0.7)

    # Highlight layers whose CI excludes zero
    for l, r in zip(LAYERS_TO_TEST, rs):
        if not results[l][metric_name]['includes_zero']:
            ax.scatter([l], [r], color='red', s=150, zorder=5, marker='*')

    ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax.set_xlabel('Layer')
    ax.set_ylabel('r')
    ax.set_title(metric_name, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.5, 0.5)

# Hide every subplot that has no metric assigned (not only the last one)
for unused_ax in axes[len(metric_names):]:
    unused_ax.axis('off')

plt.suptitle(f'Pair-Level Correlations (n={len(pair_data)}) - Red stars = significant',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('option_b_pair_level.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPlot saved to: option_b_pair_level.png")

## 8. Summary & Interpretation

In [None]:
def _significant_hits(layer_ids):
    """List (layer, metric) pairs whose bootstrap CI excludes zero."""
    return [
        (l, m)
        for l in layer_ids
        for m in results[l]
        if not results[l][m]['includes_zero']
    ]


print("\n" + "#"*80)
print("# OPTION B SUMMARY: Pair-Level Analysis (EXPERIMENTAL)")
print("#"*80)

# Overall tally across every (layer, metric) combination
total_tests = len(LAYERS_TO_TEST) * len(metric_names)
sig_tests = len(_significant_hits(results))

print(f"\nTotal tests: {total_tests}")
print(f"Significant (CI excludes 0): {sig_tests}")
print(f"Significance rate: {sig_tests/total_tests*100:.1f}%")

print("\n--- Key Observations ---")

# Compare a mid-depth band of layers against a late band
mid_layers = [8, 12]
late_layers = [28, 32]

mid_sig = _significant_hits(mid_layers)
late_sig = _significant_hits(late_layers)

print(f"\nMid-layer (8, 12) significant correlations: {len(mid_sig)}")
for l, m in mid_sig:
    print(f"  Layer {l}, {m}: r = {results[l][m]['r']:+.3f}")

print(f"\nLate-layer (28, 32) significant correlations: {len(late_sig)}")
for l, m in late_sig:
    print(f"  Layer {l}, {m}: r = {results[l][m]['r']:+.3f}")

# Fixed interpretation caveat, printed line by line
print("\n".join([
    "\n--- IMPORTANT CAVEAT ---",
    "These pair-level metrics are NOT the same as category-level UA.",
    "They measure different properties:",
    "  - Category-level UA: Group geometry (uniformity within clusters)",
    "  - Pair-level metrics: Individual position relative to clusters",
    "",
    "Pair-level results should be interpreted as:",
    "  'Individual embedding position correlates with output preference'",
    "NOT as:",
    "  'The same UA metric works at pair level'",
]))

In [None]:
# Save results
# One timestamp for both the JSON field and the filename so they always agree.
run_time = datetime.now()

save_data = {
    'timestamp': run_time.isoformat(),
    'model': MODEL_NAME,
    'method': 'Option B: Pair-Level Analysis (EXPERIMENTAL)',
    'n_pairs': len(pair_data),
    'n_bootstrap': N_BOOTSTRAP,
    'metrics_tested': metric_names,
    # JSON object keys must be strings, so layer indices are stringified.
    'results': {str(k): {m: v for m, v in layer_res.items()}
                for k, layer_res in results.items()},
    'best_per_layer': best_metrics_summary,
    'summary': {
        'total_tests': total_tests,
        'significant_tests': sig_tests,
        'mid_layer_significant': len(mid_sig),
        'late_layer_significant': len(late_sig)
    },
    'caveat': 'Pair-level metrics are NOT the same as category-level UA. Different interpretation required.'
}

output_file = f"option_b_pair_level_{run_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w') as f:
    json.dump(save_data, f, indent=2)

print(f"\nResults saved to: {output_file}")

# Browser download is only possible inside Google Colab. Elsewhere the files
# are already on disk, so skip the download instead of failing on the import.
try:
    from google.colab import files
except ImportError:
    print("google.colab not available; skipping browser download.")
else:
    files.download(output_file)
    files.download('option_b_pair_level.png')