# Option A: More Bins (Same Metric, More Power)

**Research Question:**
> With more statistical power (n=20-30 bins instead of n=6 categories), does the mid-layer positive regime become statistically stable?

**Method:**
- Keep the SAME UA metric (group-level uniformity difference)
- Randomly assign each of the 230 pairs to one of k bins (sweep k = 10, 15, 20, 25, 30; k = 20 for the bootstrap CI)
- Compute UA per bin
- Correlate with output preference per bin
- Bootstrap CI with larger n

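**Advantage:** No change to the metric: the UA score and the output-preference measure are identical to the original analysis, so the interpretation carries over. Only the grouping changes (random bins instead of semantic categories).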

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Model:** Pythia-6.9B

## 1. Setup

In [None]:
# Install dependencies (IPython shell escape: only valid inside a notebook).
!pip install -q transformers accelerate torch numpy scipy matplotlib

In [None]:
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer

# Silence all library warnings to keep notebook output readable.
warnings.filterwarnings('ignore')

# Reproducibility: seed NumPy's global RNG and PyTorch with the same value.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Bootstrap settings used by the confidence-interval analysis.
N_BOOTSTRAP = 10000
CI_LEVEL = 0.95

# Report the runtime environment.
cuda_available = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
MODEL_NAME = "EleutherAI/pythia-6.9b"
MODEL_DISPLAY = "Pythia-6.9B"

print(f"Loading {MODEL_DISPLAY}...")

# Load options: fp16 weights and automatic device placement.
# output_hidden_states=True is passed at load time; presumably it is forwarded
# to the model config so forward passes also return per-layer hidden states.
load_options = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    output_hidden_states=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_options)

print(f"Model loaded. Layers: {model.config.num_hidden_layers}")

In [None]:
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

# Flatten all pairs into a single list
ALL_PAIRS = []
for cat_name, cat_data in DATASET.items():
    for pair in cat_data['pairs']:
        ALL_PAIRS.append({
            'stmt_a': pair[0],
            'stmt_b': pair[1],
            'original_category': cat_name
        })

print(f"Total pairs: {len(ALL_PAIRS)}")

## 2. Core Functions

In [None]:
def get_layer_embedding(text: str, model, tokenizer, layer_idx: int) -> np.ndarray:
    """Get a mean-pooled embedding of ``text`` from one hidden-state layer.

    Runs a single forward pass. It averages ``outputs.hidden_states[layer_idx]``
    over token positions 1..end; position 0 is excluded, matching the original
    pooling. If the text encodes to a single token, that token's state is used
    instead. Averaging an empty slice would give NaN and silently poison every
    uniformity score computed from it.

    Args:
        text: Input string.
        model: Causal LM whose output exposes ``hidden_states``.
        tokenizer: Tokenizer matching ``model``.
        layer_idx: Index into ``outputs.hidden_states``.

    Returns:
        1-D float32 NumPy array (mean over token positions).
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Batch index 0: exactly one string was encoded.
    layer_hidden = outputs.hidden_states[layer_idx][0]

    # NOTE(review): Pythia/GPT-NeoX tokenizers may not prepend a BOS token, in
    # which case skipping position 0 drops the first real token. Kept as-is for
    # comparability with earlier results; confirm this is intended.
    token_states = layer_hidden[1:] if layer_hidden.shape[0] > 1 else layer_hidden
    embedding = token_states.mean(dim=0).cpu().numpy().astype(np.float32)

    return embedding


def get_output_preference(text_a: str, text_b: str, model, tokenizer) -> float:
    """Return the model's output preference for ``text_a`` over ``text_b``.

    Defined as NLL(B) - NLL(A). Each NLL is the model's ``loss`` when the text
    is scored against its own input ids. A positive value means ``text_a``
    received the lower loss.
    """
    nll_by_key = {}
    # B is scored before A, in the same order as before.
    for key, text in (("b", text_b), ("a", text_a)):
        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            nll_by_key[key] = model(**encoded, labels=encoded["input_ids"]).loss.item()
    return nll_by_key["b"] - nll_by_key["a"]


def uniformity_score(embeddings: np.ndarray) -> float:
    """Return the mean pairwise cosine similarity between rows of ``embeddings``.

    Each row is L2-normalised, with a 1e-10 epsilon in the denominator guarding
    zero vectors. Similarities are averaged over all distinct pairs i < j.
    Fewer than two rows yields 0.0.
    """
    row_count = len(embeddings)
    if row_count < 2:
        return 0.0
    unit_rows = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10)
    cosine = unit_rows @ unit_rows.T
    upper_i, upper_j = np.triu_indices(cosine.shape[0], k=1)
    return float(cosine[upper_i, upper_j].mean())


def bootstrap_correlation(x: np.ndarray, y: np.ndarray, n_bootstrap: int = 10000,
                          ci_level: float = 0.95) -> Tuple[float, float, float, float]:
    """Pearson correlation with a percentile-bootstrap confidence interval.

    Resamples (x, y) pairs with replacement ``n_bootstrap`` times using NumPy's
    global RNG. Resamples in which either variable has zero spread are skipped.

    Returns:
        ``(r, ci_lower, ci_upper, p_value)``.
        - With fewer than 3 points: ``(0.0, -1.0, 1.0, 1.0)``.
        - With fewer than 100 usable resamples: the CI falls back to
          ``[-1.0, 1.0]``, while r and p are still reported.
    """
    n = len(x)
    if n < 3:
        return 0.0, -1.0, 1.0, 1.0

    r_observed, p_value = stats.pearsonr(x, y)

    resampled_rs = []
    for _ in range(n_bootstrap):
        pick = np.random.choice(n, size=n, replace=True)
        xs, ys = x[pick], y[pick]
        # Pearson r is undefined when either resample is constant.
        if not (np.std(xs) > 0 and np.std(ys) > 0):
            continue
        resampled_rs.append(stats.pearsonr(xs, ys)[0])

    if len(resampled_rs) < 100:
        return float(r_observed), -1.0, 1.0, float(p_value)

    resampled = np.array(resampled_rs)
    alpha = 1 - ci_level
    lower = np.percentile(resampled, alpha/2 * 100)
    upper = np.percentile(resampled, (1 - alpha/2) * 100)

    return float(r_observed), float(lower), float(upper), float(p_value)

## 3. Collect All Embeddings

In [None]:
# Layers to analyze (indices into the model's hidden_states tuple).
LAYERS_TO_TEST = [0, 4, 8, 12, 16, 20, 24, 28, 32]


def _encode_statement(text: str) -> Tuple[float, Dict[int, np.ndarray]]:
    """Run ONE forward pass over ``text``; return ``(nll, {layer: embedding})``.

    The loss is computed exactly as in ``get_output_preference`` (labels = input
    ids). Embeddings are pooled exactly as in ``get_layer_embedding``: position 0
    is skipped, except for single-token inputs, where the slice would be empty
    and its mean NaN.

    Calling ``get_layer_embedding`` once per layer instead re-runs the full model
    for every layer: 2 + 2*len(LAYERS_TO_TEST) passes per pair, versus 2 here.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"], output_hidden_states=True)

    embeddings = {}
    for layer_idx in LAYERS_TO_TEST:
        token_states = outputs.hidden_states[layer_idx][0]  # batch index 0
        if token_states.shape[0] > 1:
            token_states = token_states[1:]
        embeddings[layer_idx] = token_states.mean(dim=0).cpu().numpy().astype(np.float32)
    return outputs.loss.item(), embeddings


print(f"Collecting embeddings for {len(ALL_PAIRS)} pairs across {len(LAYERS_TO_TEST)} layers...")
print("This may take several minutes (one forward pass per statement).")

# pair_data[i] = {'pref': float, 'original_category': str,
#                 'layer_embeddings': {layer: {'emb_a': ndarray, 'emb_b': ndarray}}}
pair_data = []

for i, pair in enumerate(ALL_PAIRS):
    if (i + 1) % 25 == 0:
        print(f"  [{i+1:03d}/{len(ALL_PAIRS)}]")

    nll_a, embs_a = _encode_statement(pair['stmt_a'])
    nll_b, embs_b = _encode_statement(pair['stmt_b'])

    pair_data.append({
        # Output preference = NLL(B) - NLL(A), the same definition as get_output_preference.
        'pref': nll_b - nll_a,
        'original_category': pair['original_category'],
        'layer_embeddings': {
            layer_idx: {'emb_a': embs_a[layer_idx], 'emb_b': embs_b[layer_idx]}
            for layer_idx in LAYERS_TO_TEST
        },
    })

print(f"\nDone! Collected data for {len(pair_data)} pairs.")

## 4. Random Binning Analysis

In [None]:
def analyze_with_k_bins(pair_data: list, layer_idx: int, k: int, n_permutations: int = 100) -> dict:
    """Correlate per-bin UA with per-bin mean preference over random binnings.

    Each of ``n_permutations`` rounds works as follows:
    - Every pair is assigned uniformly at random to one of ``k`` bins. Bin sizes
      vary, and bins with fewer than 2 pairs are skipped.
    - Per bin, UA = uniformity(emb_a rows) - uniformity(emb_b rows) at
      ``layer_idx``, and the bin's preference is the mean ``pref`` of its pairs.
    - If at least 3 bins remain, the Pearson r between UA and preference across
      bins is recorded, provided r is finite.

    Returns:
        dict with:
        - 'r_mean', 'r_std': mean and std of r over valid rounds.
        - 'p_mean': mean p-value over valid rounds (a descriptive summary, not a
          valid combined test).
        - 'n_valid': number of rounds that produced a finite r.
        - 'n_bins_mean': average number of bins actually used per valid round.
    """
    n_pairs = len(pair_data)

    # Stack this layer's embeddings and the preferences once, instead of
    # re-walking every pair for every bin of every round.
    embs_a = np.array([item['layer_embeddings'][layer_idx]['emb_a'] for item in pair_data])
    embs_b = np.array([item['layer_embeddings'][layer_idx]['emb_b'] for item in pair_data])
    prefs = np.array([item['pref'] for item in pair_data])

    all_rs = []
    all_ps = []
    bins_used = []

    for _ in range(n_permutations):
        # One randint draw per round, as before, so seeded results are unchanged.
        bin_assignments = np.random.randint(0, k, size=n_pairs)

        bin_uas = []
        bin_prefs = []
        for bin_idx in range(k):
            members = np.flatnonzero(bin_assignments == bin_idx)
            if members.size < 2:
                continue
            bin_uas.append(uniformity_score(embs_a[members]) - uniformity_score(embs_b[members]))
            bin_prefs.append(np.mean(prefs[members]))

        if len(bin_uas) < 3:
            continue
        r, p = stats.pearsonr(bin_uas, bin_prefs)
        # A constant input makes pearsonr return NaN (warnings are silenced
        # globally). A single such round would otherwise make r_mean NaN.
        if not np.isfinite(r):
            continue
        all_rs.append(r)
        all_ps.append(p)
        bins_used.append(len(bin_uas))

    if not all_rs:
        return {'r_mean': 0, 'r_std': 0, 'p_mean': 1, 'n_valid': 0, 'n_bins_mean': 0.0}

    return {
        'r_mean': float(np.mean(all_rs)),
        'r_std': float(np.std(all_rs)),
        'p_mean': float(np.mean(all_ps)),
        'n_valid': len(all_rs),
        'n_bins_mean': float(np.mean(bins_used)),
    }

In [None]:
# Sweep the number of random bins to see how r(UA, output) depends on k.
K_VALUES = [10, 15, 20, 25, 30]
N_PERMUTATIONS = 200

print(f"Testing k = {K_VALUES} bins with {N_PERMUTATIONS} permutations each...")
print(f"Layers: {LAYERS_TO_TEST}")

# results_by_k[k][layer] -> summary dict returned by analyze_with_k_bins
results_by_k = {}

for k in K_VALUES:
    print(f"\n--- k = {k} bins ---")
    per_layer = results_by_k[k] = {}
    for layer_idx in LAYERS_TO_TEST:
        summary = analyze_with_k_bins(pair_data, layer_idx, k, N_PERMUTATIONS)
        per_layer[layer_idx] = summary
        print(f"  Layer {layer_idx:2d}: r = {summary['r_mean']:+.3f} ± {summary['r_std']:.3f}")

print("\nDone!")

## 5. Bootstrap CI with Chosen k (k=20)

In [None]:
# Use k=20 as a good balance (enough bins, enough pairs per bin)
K_OPTIMAL = 20

print(f"Computing Bootstrap CI with k={K_OPTIMAL} bins...")
print(f"This gives n≈{K_OPTIMAL} for correlation (vs n=6 before)")

# Create one stable binning for the CI analysis. Reseeding makes it independent
# of how much randomness earlier cells consumed.
np.random.seed(RANDOM_SEED)
stable_bins = np.random.randint(0, K_OPTIMAL, size=len(pair_data))

# Row indices of each bin. Bins with fewer than 2 pairs have no pairwise
# uniformity and are skipped, so the effective n (reported below) can be < K_OPTIMAL.
bin_members = [np.flatnonzero(stable_bins == b) for b in range(K_OPTIMAL)]
bin_members = [members for members in bin_members if members.size >= 2]

# Preferences do not depend on the layer, so compute per-bin means once.
prefs = np.array([item['pref'] for item in pair_data])
bin_prefs = np.array([np.mean(prefs[members]) for members in bin_members])

bootstrap_results = {}

for layer_idx in LAYERS_TO_TEST:
    # Stack this layer's embeddings once, then slice per bin.
    embs_a = np.array([item['layer_embeddings'][layer_idx]['emb_a'] for item in pair_data])
    embs_b = np.array([item['layer_embeddings'][layer_idx]['emb_b'] for item in pair_data])

    # UA per bin = uniformity(A statements) - uniformity(B statements).
    bin_uas = np.array([
        uniformity_score(embs_a[members]) - uniformity_score(embs_b[members])
        for members in bin_members
    ])

    # Bootstrap CI over bins.
    r, ci_lower, ci_upper, p_value = bootstrap_correlation(
        bin_uas, bin_prefs, N_BOOTSTRAP, CI_LEVEL
    )

    bootstrap_results[layer_idx] = {
        'r': r,
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'p_value': p_value,
        'includes_zero': ci_lower <= 0 <= ci_upper,
        'n_bins': len(bin_uas),
    }

# Label derived from CI_LEVEL ("95% CI" for 0.95) so it cannot drift from the setting.
ci_label = f"{CI_LEVEL:.0%} CI"

print(f"\n{'='*70}")
print(f" BOOTSTRAP CI RESULTS (k={K_OPTIMAL} bins, n≈{K_OPTIMAL})")
print(f"{'='*70}")
print(f"\n{'Layer':<10} {'r':<10} {ci_label:<25} {'Includes 0?':<15} {'n bins'}")
print("-" * 75)

for layer_idx in LAYERS_TO_TEST:
    res = bootstrap_results[layer_idx]
    ci_str = f"[{res['ci_lower']:+.3f}, {res['ci_upper']:+.3f}]"
    zero_str = "YES" if res['includes_zero'] else "NO ***"
    print(f"Layer {layer_idx:<4} {res['r']:+.3f}     {ci_str:<25} {zero_str:<15} {res['n_bins']}")

## 6. Comparison: n=6 vs n=20

In [None]:
# Original n=6 results (from previous analysis), keyed by layer index.
original_results = {
    0: {'r': -0.60, 'ci_lower': -1.00, 'ci_upper': 0.95, 'includes_zero': True},
    4: {'r': -0.81, 'ci_lower': -1.00, 'ci_upper': 0.75, 'includes_zero': True},
    8: {'r': 0.47, 'ci_lower': -0.93, 'ci_upper': 0.96, 'includes_zero': True},
    12: {'r': 0.21, 'ci_lower': -0.99, 'ci_upper': 0.91, 'includes_zero': True},
    16: {'r': -0.25, 'ci_lower': -1.00, 'ci_upper': 0.64, 'includes_zero': True},
    20: {'r': -0.25, 'ci_lower': -1.00, 'ci_upper': 0.61, 'includes_zero': True},
    24: {'r': -0.63, 'ci_lower': -1.00, 'ci_upper': 0.03, 'includes_zero': True},
    28: {'r': -0.81, 'ci_lower': -1.00, 'ci_upper': -0.55, 'includes_zero': False},
    32: {'r': -0.87, 'ci_lower': -1.00, 'ci_upper': -0.12, 'includes_zero': False},
}

# Column headers derive from K_OPTIMAL so they stay correct if k is changed.
new_n_label = f"n={K_OPTIMAL}"

print(f"\n{'='*80}")
print(f" COMPARISON: n=6 Categories vs n={K_OPTIMAL} Bins")
print(f"{'='*80}")
print(f"\n{'Layer':<8} {'n=6 r':<10} {'n=6 CI':<20} {new_n_label + ' r':<10} {new_n_label + ' CI':<20} {'Change?'}")
print("-" * 85)

for layer_idx in LAYERS_TO_TEST:
    orig = original_results.get(layer_idx, {})
    new = bootstrap_results[layer_idx]

    # Missing original layers default to r=0 and the uninformative CI [-1, 1].
    orig_r = orig.get('r', 0)
    orig_lo = orig.get('ci_lower', -1)
    orig_hi = orig.get('ci_upper', 1)
    orig_ci = f"[{orig_lo:+.2f}, {orig_hi:+.2f}]"
    orig_inc = orig.get('includes_zero', True)

    new_r = new['r']
    new_ci = f"[{new['ci_lower']:+.2f}, {new['ci_upper']:+.2f}]"
    new_inc = new['includes_zero']

    # Did significance (CI excluding 0) change? When both results are
    # significant, also flag a change of direction, which "No change" would hide.
    if orig_inc and not new_inc:
        change = "NOW SIGNIFICANT!"
    elif not orig_inc and new_inc:
        change = "Lost significance"
    elif not orig_inc and not new_inc and (orig_lo > 0) != (new['ci_lower'] > 0):
        change = "Sign flipped"
    else:
        change = "No change"

    print(f"Layer {layer_idx:<3} {orig_r:+.2f}      {orig_ci:<20} {new_r:+.2f}      {new_ci:<20} {change}")

## 7. Visualization

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

layers = LAYERS_TO_TEST
x = np.arange(len(layers))
width = 0.35

# Plot 1: side-by-side r(UA, output) per layer for both groupings.
ax1 = axes[0]

orig_rs = [original_results.get(layer, {}).get('r', 0) for layer in layers]
new_rs = [bootstrap_results[layer]['r'] for layer in layers]

ax1.bar(x - width/2, orig_rs, width, label='n=6 (original)', color='blue', alpha=0.7)
ax1.bar(x + width/2, new_rs, width, label=f'n={K_OPTIMAL} (bins)', color='green', alpha=0.7)

ax1.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('r(UA, Output)', fontsize=12)
# Title derives n from K_OPTIMAL so it stays correct if k is changed.
ax1.set_title(f'Correlation: n=6 Categories vs n={K_OPTIMAL} Bins', fontsize=14, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(layers)
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Plot 2: CI width per layer. Missing original layers default to the full [-1, 1] width.
# orig_widths / new_widths are also read by the summary cell.
ax2 = axes[1]

orig_widths = [
    original_results.get(layer, {}).get('ci_upper', 1) - original_results.get(layer, {}).get('ci_lower', -1)
    for layer in layers
]
new_widths = [bootstrap_results[layer]['ci_upper'] - bootstrap_results[layer]['ci_lower'] for layer in layers]

ax2.bar(x - width/2, orig_widths, width, label='n=6 (original)', color='blue', alpha=0.7)
ax2.bar(x + width/2, new_widths, width, label=f'n={K_OPTIMAL} (bins)', color='green', alpha=0.7)

ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('CI Width', fontsize=12)
ax2.set_title('CI Width: More Bins = Narrower CI?', fontsize=14, fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(layers)
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('option_a_more_bins.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPlot saved to: option_a_more_bins.png")

## 8. Summary

In [None]:
print("\n" + "#"*70)
print("# OPTION A SUMMARY: More Bins (Same Metric)")
print("#"*70)

# Count significant layers (CI excludes 0, either sign).
orig_sig = sum(1 for layer in layers if not original_results.get(layer, {}).get('includes_zero', True))
new_sig = sum(1 for layer in layers if not bootstrap_results[layer]['includes_zero'])

print(f"\nSignificant layers (CI excludes 0):")
print(f"  n=6 (original): {orig_sig} layers")
print(f"  n={K_OPTIMAL} (bins):    {new_sig} layers")

# Check mid-layers specifically.
mid_layers = [8, 12]
mid_now_sig = [layer for layer in mid_layers if not bootstrap_results[layer]['includes_zero']]
# The research question concerns a *positive* mid-layer regime. Only a CI
# entirely above 0 supports it; a CI entirely below 0 is significant but contradicts it.
mid_pos_sig = [layer for layer in mid_now_sig if bootstrap_results[layer]['ci_lower'] > 0]
mid_neg_sig = [layer for layer in mid_now_sig if bootstrap_results[layer]['ci_upper'] < 0]

print(f"\nMid-layer regime (Layer {', '.join(str(layer) for layer in mid_layers)}):")
if mid_pos_sig:
    print(f"  NOW SIGNIFICANT: Layers {mid_pos_sig}")
    print(f"  >>> The mid-layer positive regime IS statistically stable with more power!")
elif mid_neg_sig:
    print(f"  Significant but NEGATIVE: Layers {mid_neg_sig}")
    print(f"  >>> The mid-layer positive regime is NOT supported (CI lies below zero)")
else:
    print(f"  Still not significant")
    print(f"  >>> The mid-layer regime remains a trend, even with n={K_OPTIMAL}")

# Average CI width reduction (orig_widths / new_widths come from the plotting cell).
orig_mean_width = np.mean(orig_widths)
new_mean_width = np.mean(new_widths)
width_reduction = (orig_mean_width - new_mean_width) / orig_mean_width * 100

print(f"\nCI Width:")
print(f"  n=6 mean width:  {orig_mean_width:.2f}")
print(f"  n={K_OPTIMAL} mean width: {new_mean_width:.2f}")
print(f"  Reduction: {width_reduction:.1f}%")

In [None]:
# Save results. A single timestamp is used so the JSON field and the filename agree.
run_time = datetime.now()

save_data = {
    'timestamp': run_time.isoformat(),
    'model': MODEL_NAME,
    'method': 'Option A: Random Bins (same metric)',
    'k_bins': K_OPTIMAL,
    'n_bootstrap': N_BOOTSTRAP,
    'n_permutations': N_PERMUTATIONS,
    'bootstrap_results': {str(k): v for k, v in bootstrap_results.items()},
    # Keep the k-sweep too (JSON keys must be strings).
    'results_by_k': {
        str(k): {str(layer): summary for layer, summary in per_layer.items()}
        for k, per_layer in results_by_k.items()
    },
    'comparison': {
        'original_n': 6,
        'new_n': K_OPTIMAL,
        'original_sig_layers': orig_sig,
        'new_sig_layers': new_sig,
        'mid_layer_now_significant': bool(mid_now_sig),
        # Significant AND positive (CI entirely above 0), which is what the research question asks.
        'mid_layer_positive_significant': any(
            bootstrap_results[layer]['ci_lower'] > 0 for layer in mid_layers
        ),
        'ci_width_reduction_pct': float(width_reduction)
    }
}

output_file = f"option_a_bins_{run_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)

print(f"\nResults saved to: {output_file}")

# Trigger browser downloads only when running in Google Colab; elsewhere the
# files simply stay on disk instead of the cell crashing on the import.
try:
    from google.colab import files
except ImportError:
    print("Not running in Colab; skipping downloads.")
else:
    files.download(output_file)
    files.download('option_a_more_bins.png')