# Bootstrap CI for Layer-wise Correlations

**Research Question:**
> Ist das Mid-Layer-Regime statistisch stabil oder nur ein Trend?  
> *(Is the mid-layer regime statistically stable, or just a trend?)*

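**Method:** Percentile-bootstrap resampling of the category-level data to compute a 95% CI for r(UA, Output) — the Pearson correlation between uniformity asymmetry (UA) and output preference — at each sampled layer.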

**Key Test:** Does the CI for the mid-layers (≈20–50% of depth, i.e. Layers 8, 12 and 16 of Pythia-6.9B's 32 layers when every 4th layer is sampled) exclude 0?
- If CI excludes 0 → Statistically distinguishable from no correlation
- If CI includes 0 → Could be noise

**Note:** With n=6 categories, CIs will be wide. This is expected and honest.

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Model:** Pythia-6.9B (strongest mid-layer effect)

## 1. Setup

In [None]:
# Install dependencies
!pip install -q transformers accelerate torch numpy scipy matplotlib

In [None]:
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer

# Silence only noisy library notices (deprecations / future-behaviour changes).
# A blanket filterwarnings('ignore') would also hide numerical RuntimeWarnings -
# e.g. scipy's ConstantInputWarning when pearsonr receives constant input and
# returns NaN - which flag degenerate data in this analysis.
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Configuration
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)      # global NumPy state, used by the bootstrap resampling
torch.manual_seed(RANDOM_SEED)

N_BOOTSTRAP = 10000  # Number of bootstrap iterations
CI_LEVEL = 0.95      # Two-sided confidence level of the bootstrap CIs

print(f"Bootstrap iterations: {N_BOOTSTRAP}")
print(f"Confidence level: {CI_LEVEL*100}%")
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Model configuration - Pythia-6.9B (strongest effect)
MODEL_NAME = "EleutherAI/pythia-6.9b"
MODEL_DISPLAY = "Pythia-6.9B"

print("Loading " + MODEL_DISPLAY + "...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Half-precision weights, automatic device placement, and hidden states of every
# layer enabled in the config so intermediate representations can be read out.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    output_hidden_states=True,
    torch_dtype=torch.float16,
)

print("Model loaded. Layers: {}".format(model.config.num_hidden_layers))

In [None]:
# Load dataset.
# Download to an explicit target path: a plain `wget URL` without -O keeps an
# existing dataset.json and saves the new copy as dataset.json.1, .2, ..., so
# re-running this cell would silently keep reading a stale file.
import urllib.request

DATASET_URL = "https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json"
DATASET_PATH = "dataset.json"
urllib.request.urlretrieve(DATASET_URL, DATASET_PATH)  # overwrites; raises on HTTP errors

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    DATASET = json.load(f)

# Expected schema (from how the file is used below): {category: {"pairs": [[stmt_a, stmt_b], ...]}}
CATEGORIES = list(DATASET.keys())
print(f"Categories (n={len(CATEGORIES)}): {CATEGORIES}")

## 2. Core Functions

In [None]:
def get_layer_embeddings(text: str, model, tokenizer, layer_step: int = 4) -> Dict[int, np.ndarray]:
    """Mean-pool the hidden states of *text* at every ``layer_step``-th layer.

    Args:
        text: Input string.
        model: Causal LM whose forward output exposes ``hidden_states``.
        tokenizer: Matching tokenizer.
        layer_step: Sample layer indices 0, step, 2*step, ...; the last layer is
            always included.

    Returns:
        ``{layer_index: float32 vector}`` - the mean over token positions of the
        (single) input sequence, excluding position 0 when there is more than one
        token.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        hidden_states = model(**inputs, output_hidden_states=True).hidden_states

    # The first entry is the embedding output (HF convention), hence len - 1 layers.
    n_layers = len(hidden_states) - 1
    layer_indices = list(range(0, n_layers + 1, layer_step))
    if n_layers not in layer_indices:
        layer_indices.append(n_layers)

    embeddings = {}
    for layer_idx in layer_indices:
        tokens = hidden_states[layer_idx][0]  # (seq_len, hidden_dim) of the single input
        # Position 0 is skipped, as in the original analysis.
        # NOTE(review): the Pythia tokenizer does not appear to prepend a BOS token,
        # so this may drop the first real token - confirm this is intended.
        # For a single-token text the slice would be empty (mean = NaN), so fall
        # back to using all positions.
        if tokens.shape[0] > 1:
            tokens = tokens[1:]
        # Upcast before reducing: averaging in float16 loses precision.
        embeddings[layer_idx] = tokens.float().mean(dim=0).cpu().numpy().astype(np.float32)

    return embeddings


def get_output_preference(text_a: str, text_b: str, model, tokenizer) -> float:
    """Return the output preference ``NLL(B) - NLL(A)`` for a statement pair.

    Each NLL is the model's language-modeling loss on the text with its own token
    ids as labels. A positive value means text A gets the lower loss.
    """
    def mean_nll(sentence: str) -> float:
        encoded = tokenizer(sentence, return_tensors="pt").to(model.device)
        with torch.no_grad():
            loss = model(**encoded, labels=encoded["input_ids"]).loss
            return loss.item()

    # B is scored before A, matching the left-to-right evaluation of the original.
    nll_b = mean_nll(text_b)
    nll_a = mean_nll(text_a)
    return nll_b - nll_a


def uniformity_score(embeddings: np.ndarray) -> float:
    """Return the mean cosine similarity over all distinct pairs of rows.

    ``embeddings`` is a 2-D array with one vector per row; a small epsilon in the
    denominator guards against zero-norm rows.
    """
    unit_rows = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10)
    cosine = unit_rows @ unit_rows.T
    # Strict upper triangle: each unordered pair once, self-similarities excluded.
    upper_r, upper_c = np.triu_indices(cosine.shape[0], k=1)
    return float(np.mean(cosine[upper_r, upper_c]))


def bootstrap_correlation(x: np.ndarray, y: np.ndarray, n_bootstrap: int = 10000, 
                          ci_level: float = 0.95,
                          rng: "np.random.Generator | None" = None) -> Tuple[float, float, float, float]:
    """
    Compute a percentile-bootstrap CI for the Pearson correlation of paired data.

    The pairs (x[i], y[i]) are resampled with replacement ``n_bootstrap`` times.
    Resamples in which x or y consists of a single repeated value have an undefined
    correlation and are discarded, so the CI is conditional on non-degenerate resamples.

    Args:
        x, y: 1-D arrays of equal length (at least 2 elements).
        n_bootstrap: Number of bootstrap resamples.
        ci_level: Two-sided confidence level, e.g. 0.95.
        rng: Optional ``np.random.Generator``. If None, the global NumPy random state
            is used, so ``np.random.seed`` controls reproducibility.

    Returns:
        (r_observed, ci_lower, ci_upper, p_value). r_observed and p_value come from
        ``scipy.stats.pearsonr`` on the full data. ci_lower and ci_upper are NaN if
        no resample is non-degenerate.

    Raises:
        ValueError: if x and y have different shapes.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError(f"x and y must have the same shape, got {x.shape} and {y.shape}")
    n = len(x)
    r_observed, p_value = stats.pearsonr(x, y)

    # Draw all resample indices at once: row b holds the indices of resample b.
    if rng is None:
        idx = np.random.randint(0, n, size=(n_bootstrap, n))
    else:
        idx = rng.integers(0, n, size=(n_bootstrap, n))
    x_boot = x[idx]
    y_boot = y[idx]

    # Vectorised Pearson r for every resample.
    xc = x_boot - x_boot.mean(axis=1, keepdims=True)
    yc = y_boot - y_boot.mean(axis=1, keepdims=True)
    cov = (xc * yc).sum(axis=1)
    denom = np.sqrt((xc * xc).sum(axis=1) * (yc * yc).sum(axis=1))

    # Degenerate = all x (or all y) values identical, tested exactly via max == min.
    # np.std can be a tiny non-zero number for identical values (rounding in the mean);
    # pearsonr then returns NaN, and a single NaN makes the whole percentile CI NaN.
    valid = (
        (x_boot.max(axis=1) > x_boot.min(axis=1))
        & (y_boot.max(axis=1) > y_boot.min(axis=1))
        & (denom > 0)
    )
    if not valid.any():
        return float(r_observed), float('nan'), float('nan'), float(p_value)

    bootstrap_rs = np.clip(cov[valid] / denom[valid], -1.0, 1.0)

    # Percentile CI
    alpha = 1 - ci_level
    ci_lower = np.percentile(bootstrap_rs, alpha / 2 * 100)
    ci_upper = np.percentile(bootstrap_rs, (1 - alpha / 2) * 100)

    return float(r_observed), float(ci_lower), float(ci_upper), float(p_value)

## 3. Collect Category-Level Data

In [None]:
def collect_category_data(model, tokenizer, dataset: dict) -> Dict[int, Dict]:
    """
    Gather per-category uniformity asymmetry (UA) and mean output preference per layer.

    For every (stmt_a, stmt_b) pair of every category, the pooled layer embeddings of
    both statements (get_layer_embeddings) and the pair's output preference
    (get_output_preference) are recorded. Then, per layer and category:
    UA = uniformity(A embeddings) - uniformity(B embeddings), and the preference is
    averaged over the category's pairs.

    Args:
        model, tokenizer: the loaded causal LM and its tokenizer.
        dataset: {category_name: {"pairs": [(stmt_a, stmt_b), ...]}}.

    Returns:
        {layer_num: {'uas': np.ndarray, 'prefs': np.ndarray, 'categories': [...]}},
        layers ascending, categories in dataset order.
    """
    category_names = list(dataset.keys())
    emb_store_a = {}   # layer -> category -> list of A embeddings
    emb_store_b = {}   # layer -> category -> list of B embeddings
    pref_store = {}    # category -> list of per-pair preferences

    n_total = sum(len(entry["pairs"]) for entry in dataset.values())
    seen = 0

    print(f"Collecting embeddings for {n_total} pairs...")

    for cat, entry in dataset.items():
        pref_store[cat] = []

        for text_a, text_b in entry["pairs"]:
            seen += 1
            if seen % 50 == 0:
                print(f"  [{seen:03d}/{n_total}]")

            layer_embs_a = get_layer_embeddings(text_a, model, tokenizer)
            layer_embs_b = get_layer_embeddings(text_b, model, tokenizer)

            for layer, vec_a in layer_embs_a.items():
                if layer not in emb_store_a:
                    # Pre-create every category so each layer has the same keys.
                    emb_store_a[layer] = {name: [] for name in category_names}
                    emb_store_b[layer] = {name: [] for name in category_names}
                emb_store_a[layer][cat].append(vec_a)
                emb_store_b[layer][cat].append(layer_embs_b[layer])

            pref_store[cat].append(get_output_preference(text_a, text_b, model, tokenizer))

    print("\nComputing per-category UA...")

    layer_data = {}
    for layer in sorted(emb_store_a.keys()):
        ua_values = [
            uniformity_score(np.array(emb_store_a[layer][name]))
            - uniformity_score(np.array(emb_store_b[layer][name]))
            for name in category_names
        ]
        pref_means = [np.mean(pref_store[name]) for name in category_names]
        layer_data[layer] = {
            'uas': np.array(ua_values),
            'prefs': np.array(pref_means),
            'categories': list(category_names),
        }

    return layer_data


# Run the (slow) forward passes once; layer_data feeds every analysis below.
print("Starting data collection for {}...".format(MODEL_DISPLAY))
layer_data = collect_category_data(model, tokenizer, DATASET)
print("\nData collected for {} layers.".format(len(layer_data)))

## 4. Bootstrap CI Analysis

In [None]:
print(f"\n{'='*70}")
print(f" BOOTSTRAP CI ANALYSIS: {MODEL_DISPLAY}")
print(f" n = {len(CATEGORIES)} categories, {N_BOOTSTRAP} bootstrap iterations")
print(f"{'='*70}")

# Per-layer bootstrap CI of r(UA, output preference) across categories.
results = {}
ci_label = f"{CI_LEVEL:.0%} CI"

print(f"\n{'Layer':<10} {'r':<10} {ci_label:<25} {'CI includes 0?':<15} {'p-value'}")
print("-" * 75)

for layer_num in sorted(layer_data.keys()):
    data = layer_data[layer_num]
    uas = data['uas']
    prefs = data['prefs']
    
    r, ci_lower, ci_upper, p_value = bootstrap_correlation(
        uas, prefs, n_bootstrap=N_BOOTSTRAP, ci_level=CI_LEVEL
    )
    
    # Every comparison with NaN is False, so the plain chained test would mark an
    # undefined CI as "excludes 0" (i.e. significant). Treat it as inconclusive.
    ci_undefined = bool(np.isnan(ci_lower) or np.isnan(ci_upper))
    includes_zero = ci_undefined or bool(ci_lower <= 0 <= ci_upper)
    
    results[layer_num] = {
        'r': r,
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'p_value': p_value,
        'includes_zero': includes_zero
    }
    
    ci_str = f"[{ci_lower:+.3f}, {ci_upper:+.3f}]"
    zero_str = "YES" if includes_zero else "NO ***"
    sig_marker = "*" if p_value < 0.05 else ""
    
    print(f"Layer {layer_num:<4} {r:+.3f}     {ci_str:<25} {zero_str:<15} {p_value:.4f}{sig_marker}")

## 5. Visualization

In [None]:
# Create visualization with error bars
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

ci_label = f"{CI_LEVEL:.0%}"
layers = sorted(results.keys())
rs = [results[l]['r'] for l in layers]
ci_lowers = [results[l]['ci_lower'] for l in layers]
ci_uppers = [results[l]['ci_upper'] for l in layers]
includes_zeros = [results[l]['includes_zero'] for l in layers]

# Asymmetric error-bar lengths. A percentile bootstrap CI is not guaranteed to
# contain the observed r, and errorbar() rejects negative lengths, so clip at 0.
yerr_lower = [max(r - ci_l, 0.0) for r, ci_l in zip(rs, ci_lowers)]
yerr_upper = [max(ci_u - r, 0.0) for r, ci_u in zip(rs, ci_uppers)]

# Plot 1: Layer correlation with CI
ax1 = axes[0]

ax1.errorbar(layers, rs, yerr=[yerr_lower, yerr_upper],
             fmt='o', capsize=5, capthick=2, markersize=10,
             color='blue', ecolor='blue', alpha=0.7)

# Highlight points where CI excludes 0
sig_layers = [layer for layer, inc in zip(layers, includes_zeros) if not inc]
sig_rs = [r for r, inc in zip(rs, includes_zeros) if not inc]
if sig_layers:
    ax1.scatter(sig_layers, sig_rs, color='red', s=150, zorder=5, marker='*')

ax1.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=2)
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('r(UA, Output Preference)', fontsize=12)
ax1.set_title(f'{MODEL_DISPLAY}: Layer Correlation with {ci_label} Bootstrap CI\n(Red stars = CI excludes 0)',
              fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-1.2, 1.2)

# Plot 2: CI width by layer
ax2 = axes[1]

ci_widths = [ci_u - ci_l for ci_l, ci_u in zip(ci_lowers, ci_uppers)]
ax2.bar(layers, ci_widths, color='steelblue', edgecolor='black', alpha=0.7)

ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('CI Width', fontsize=12)
ax2.set_title(f'{ci_label} CI Width by Layer\n(Wider = more uncertainty)', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

# Add mean CI width line (undefined CIs, if any, are ignored)
mean_width = np.nanmean(ci_widths)
ax2.axhline(y=mean_width, color='red', linestyle='--', label=f'Mean: {mean_width:.2f}')
ax2.legend()

plt.tight_layout()
plt.savefig('bootstrap_ci_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPlot saved to: bootstrap_ci_results.png")

## 6. Summary: Is the Mid-Layer Regime Statistically Stable?

In [None]:
print("\n" + "#"*70)
print("# KEY QUESTION: Is the mid-layer positive regime statistically stable?")
print("#"*70)

# Depth windows as fractions of the number of transformer layers. Keeping them as
# constants means the code and the printed labels cannot drift apart.
MID_START_FRAC, MID_END_FRAC = 0.2, 0.5
LATE_START_FRAC = 0.75

n_layers = model.config.num_hidden_layers
mid_layer_range = (int(n_layers * MID_START_FRAC), int(n_layers * MID_END_FRAC))

print(f"\nModel depth: {n_layers} layers")
print(f"Mid-layer range: Layer {mid_layer_range[0]} to {mid_layer_range[1]} "
      f"(~{MID_START_FRAC:.0%}-{MID_END_FRAC:.0%} depth)")

mid_layers = [l for l in layers if mid_layer_range[0] <= l <= mid_layer_range[1]]

print(f"\nMid-layer analysis:")
mid_layer_stable = []  # one entry per POSITIVE mid-layer r: does its CI exclude 0?

for layer in mid_layers:
    r = results[layer]['r']
    ci_l = results[layer]['ci_lower']
    ci_u = results[layer]['ci_upper']
    inc_zero = results[layer]['includes_zero']
    
    status = "INCLUDES 0 (not stable)" if inc_zero else "EXCLUDES 0 (stable)"
    print(f"  Layer {layer}: r = {r:+.3f}, CI = [{ci_l:+.3f}, {ci_u:+.3f}] → {status}")
    
    if r > 0:  # Positive correlation
        mid_layer_stable.append(not inc_zero)

# Late layers
late_layers = [l for l in layers if l >= int(n_layers * LATE_START_FRAC)]

print(f"\nLate-layer analysis (>={LATE_START_FRAC:.0%} depth):")
late_layer_stable = []  # one entry per NEGATIVE late-layer r: does its CI exclude 0?

for layer in late_layers:
    r = results[layer]['r']
    ci_l = results[layer]['ci_lower']
    ci_u = results[layer]['ci_upper']
    inc_zero = results[layer]['includes_zero']
    
    status = "INCLUDES 0" if inc_zero else "EXCLUDES 0 (significant)"
    print(f"  Layer {layer}: r = {r:+.3f}, CI = [{ci_l:+.3f}, {ci_u:+.3f}] → {status}")
    
    if r < 0:  # Negative correlation
        late_layer_stable.append(not inc_zero)

# Final answer
print("\n" + "="*70)
print(" ANSWER")
print("="*70)

any_mid_stable = any(mid_layer_stable)
any_late_stable = any(late_layer_stable)

if any_mid_stable:
    print("\n>>> Mid-layer positive regime: STATISTICALLY DISTINGUISHABLE from 0")
    print("    At least one positive mid-layer correlation has a CI that excludes 0.")
    print(f"    NOTE: {len(mid_layers)} mid-layers were tested; no multiple-comparison correction applied.")
else:
    print("\n>>> Mid-layer positive regime: NOT statistically stable")
    print("    No positive mid-layer correlation has a CI that excludes 0 - could be noise.")
    print(f"    NOTE: With n={len(CATEGORIES)} categories, wide CIs are expected.")

if any_late_stable:
    print("\n>>> Late-layer negative regime: STATISTICALLY SIGNIFICANT")
    print("    CI excludes 0 - this is a robust finding.")
else:
    print("\n>>> Late-layer negative regime: Not statistically significant")

## 7. Save Results

In [None]:
# Prepare results for saving
save_data = {
    'timestamp': datetime.now().isoformat(),
    'model': MODEL_NAME,
    'model_display': MODEL_DISPLAY,
    'n_categories': len(CATEGORIES),
    'n_bootstrap': N_BOOTSTRAP,
    'ci_level': CI_LEVEL,
    'layer_results': {str(k): v for k, v in results.items()},  # JSON keys must be strings
    'summary': {
        'mid_layer_range': list(mid_layer_range),
        'mid_layers_tested': mid_layers,
        'any_mid_layer_stable': bool(any_mid_stable),
        'late_layers_tested': late_layers,
        'any_late_layer_stable': bool(any_late_stable)
    }
}

# Save to JSON
output_file = f"bootstrap_ci_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)

print(f"Results saved to: {output_file}")

# Download - google.colab only exists inside Colab; elsewhere keep the files on disk
# instead of failing the cell with an ImportError.
try:
    from google.colab import files
except ImportError:
    print("Not running in Google Colab - skipping download; files are in the working directory.")
else:
    files.download(output_file)
    files.download('bootstrap_ci_results.png')

## 8. Interpretation

In [None]:
# Build a Markdown summary of the bootstrap results.
ci_label = f"{CI_LEVEL:.0%}"
n_categories = len(CATEGORIES)
# Total number of statement pairs = sample size a pair-level analysis would have.
n_pairs = sum(len(cat["pairs"]) for cat in DATASET.values())

interpretation = f"""
## Bootstrap CI Analysis: {MODEL_DISPLAY}

### Method
- n = {n_categories} categories
- {N_BOOTSTRAP} bootstrap iterations
- {ci_label} confidence intervals

### Results by Layer

| Layer | r | {ci_label} CI | CI includes 0? |
|-------|---|--------|----------------|
"""

table_rows = []
for layer in sorted(results.keys()):
    res = results[layer]
    inc = "Yes" if res['includes_zero'] else "**No**"
    table_rows.append(
        f"| {layer} | {res['r']:+.3f} | [{res['ci_lower']:+.3f}, {res['ci_upper']:+.3f}] | {inc} |\n"
    )
interpretation += "".join(table_rows)

interpretation += """
### Key Finding

"""

if any_mid_stable:
    interpretation += """The mid-layer positive regime is **statistically distinguishable from 0**.
This is not just noise - there is evidence for a qualitatively different regime in mid-layers.
"""
else:
    interpretation += f"""The mid-layer positive regime **cannot be distinguished from 0** with n={n_categories} categories.
This doesn't mean it's not real - just that we lack statistical power to confirm it.
Wide CIs are expected with small n.
"""

if any_late_stable:
    interpretation += """\nThe late-layer negative regime **is statistically significant** (CI excludes 0).
This is the most robust finding.
"""
else:
    interpretation += """\nThe late-layer negative regime is **not statistically significant**
(no negative late-layer correlation has a CI that excludes 0).
"""

interpretation += f"""
### Honest Assessment

With n={n_categories} categories, bootstrap CIs will be wide. This is **honest uncertainty**.
The pattern (positive mid-layer, negative late-layer) is **consistent** across models,
but statistical significance is limited by small sample size.

**Recommendation:** Pair-level analysis (n={n_pairs}) would provide more statistical power.
"""

print(interpretation)