# Pythia Scaling Analysis: Does the Layer-8 Peak Shift?

**Research Question:**
> Verschiebt sich der Layer-Bereich, in dem |r| maximal ist, mit der Modellgrösse?
> *(Does the layer range in which |r| is maximal shift with model size?)*

**Method:** Run a layer-wise UA (uniformity asymmetry) analysis on the Pythia family (410M, 1.4B, 6.9B, 12B).

**Hypothesis:** Larger models may show different layer-wise correlation profiles.

**Expected Output:** Clear yes/no on whether the "divergence point" is scale-dependent.

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Runtime:** ~2-3 hours on A100 (all 4 models)

## 1. Setup

In [None]:
# Install dependencies (Jupyter/Colab shell magic, not plain Python).
# accelerate is required for device_map="auto" when the models are loaded below.
!pip install -q transformers accelerate torch numpy scipy matplotlib

In [None]:
import os
import gc
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer

# Suppress library warnings so the notebook output stays readable.
# NOTE(review): this global filter also hides numpy/scipy warnings (mean of an empty
# slice, constant input to pearsonr) that would otherwise flag NaN results downstream.
warnings.filterwarnings('ignore')

# Configuration: a single seed shared by NumPy and PyTorch
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Report the runtime environment
print("PyTorch version: " + str(torch.__version__))
print("CUDA available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available():
    print("GPU: " + torch.cuda.get_device_name(0))
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Models to test, smallest to largest: (HF repo id, display name, transformer layer count)
_PYTHIA_SPECS = [
    ("EleutherAI/pythia-410m", "Pythia-410M", 24),
    ("EleutherAI/pythia-1.4b", "Pythia-1.4B", 24),
    ("EleutherAI/pythia-6.9b", "Pythia-6.9B", 32),
    ("EleutherAI/pythia-12b", "Pythia-12B", 36),
]
MODELS = [
    {"name": repo_id, "display": display, "layers": n_layers}
    for repo_id, display, n_layers in _PYTHIA_SPECS
]

print("Models to analyze:")
for spec in MODELS:
    print("  - {display} ({layers} layers)".format(**spec))

In [None]:
# Load dataset
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

total_pairs = sum(len(cat['pairs']) for cat in DATASET.values())
print(f"Loaded {total_pairs} statement pairs across {len(DATASET)} categories")
for cat_name, cat_data in DATASET.items():
    print(f"  - {cat_name}: {len(cat_data['pairs'])} pairs")

## 2. Core Functions

In [None]:
def get_layer_embeddings(text: str, model, tokenizer, layer_step: int = 4) -> Dict[str, np.ndarray]:
    """
    Extract mean-pooled hidden states from every Nth layer.

    ``hidden_states[0]`` is the embedding-layer output and ``hidden_states[i]`` (i >= 1)
    the output of block i. Indices 0, layer_step, 2*layer_step, ... are sampled, and the
    final index is always included.

    Args:
        text: Input text
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        layer_step: Sample every Nth layer (default: 4)

    Returns:
        Dict with 'layer_X' keys containing 1-D float32 mean-pooled embeddings
        (mean over token positions of the first/only batch element).
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # Tuple of (n_layers + 1) tensors

    n_layers = len(hidden_states) - 1  # Exclude embedding layer

    # Skip position 0 only when it really is a BOS token. Tokenizers such as Pythia's
    # (GPT-NeoX) do not prepend BOS by default, so slicing [1:] unconditionally
    # discarded the first real token and left an empty slice (NaN mean) for
    # single-token inputs. At least one position is always kept.
    input_ids = inputs["input_ids"]
    bos_id = getattr(tokenizer, "bos_token_id", None)
    starts_with_bos = (
        bos_id is not None
        and input_ids.shape[1] > 1
        and int(input_ids[0, 0]) == bos_id
    )
    start = 1 if starts_with_bos else 0

    # Sample every layer_step layers + final layer
    layer_indices = list(range(0, n_layers + 1, layer_step))
    if n_layers not in layer_indices:
        layer_indices.append(n_layers)

    # Upcast before averaging so the reduction is not done in fp16 precision.
    return {
        f'layer_{layer_idx}': hidden_states[layer_idx][0, start:, :].float().mean(dim=0).cpu().numpy()
        for layer_idx in layer_indices
    }


def get_output_preference(text_a: str, text_b: str, model, tokenizer) -> float:
    """
    Score which of two statements the model finds more likely.

    Each text is scored by the ``loss`` the model returns when its own input ids are
    passed as labels. The score is ``loss(B) - loss(A)``.
    Positive = prefers A, Negative = prefers B.

    NOTE(review): HuggingFace's causal-LM ``loss`` is the mean cross-entropy over the
    predicted tokens, not a summed NLL. This therefore compares length-normalized NLL;
    confirm that is the intended notion of preference.
    """
    losses = []
    for candidate in (text_a, text_b):
        encoded = tokenizer(candidate, return_tensors="pt").to(model.device)
        with torch.no_grad():
            losses.append(model(**encoded, labels=encoded["input_ids"]).loss.item())

    loss_a, loss_b = losses
    return loss_b - loss_a


def uniformity_score(embeddings: np.ndarray) -> float:
    """
    Return the mean cosine similarity over all distinct row pairs of ``embeddings``.

    Rows are L2-normalized, with 1e-10 added to each norm to avoid division by zero.
    The cosine-similarity Gram matrix is formed and its strict upper triangle
    (pairs i < j) is averaged. Higher values mean the rows point in more similar
    directions. With fewer than two rows there are no pairs, and the result is NaN.
    """
    unit_rows = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10)
    cosine = unit_rows @ unit_rows.T
    upper_i, upper_j = np.triu_indices(cosine.shape[0], k=1)
    return float(cosine[upper_i, upper_j].mean())

In [None]:
def run_layer_analysis_for_model(model, tokenizer, dataset: dict, model_display: str,
                                  layer_step: int = 4) -> dict:
    """
    Run complete layer-wise UA analysis for a single model.

    For every statement pair (A, B) in ``dataset``, this collects:
    * the mean-pooled hidden states of A and B at every ``layer_step``-th layer
      (via get_layer_embeddings), and
    * the output preference NLL(B) - NLL(A) (via get_output_preference).

    Then, for each sampled layer:
    * UA of a category = uniformity(A embeddings) - uniformity(B embeddings)
    * preference of a category = mean output preference over its pairs
    * r, p = Pearson correlation of UA vs. preference across categories

    Args:
        model: HuggingFace causal LM.
        tokenizer: Matching HuggingFace tokenizer.
        dataset: {category_name: {"pairs": [[stmt_a, stmt_b], ...]}}.
        model_display: Name used in progress output.
        layer_step: Sample every Nth layer (the final layer is always included).

    Returns:
        Dict with layer -> {correlation, p_value}. The correlation's sample size is the
        number of categories, not the number of pairs. A category with fewer than two
        pairs yields a NaN UA.

    Raises:
        ValueError: if ``dataset`` has fewer than 2 categories. No correlation can be
            computed then, so this is checked before any forward pass is run.
    """
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 categories to correlate UA with output preference, "
            f"got {len(dataset)}"
        )

    # Collect all embeddings and preferences
    all_embeddings_a: Dict[str, List[np.ndarray]] = {}  # layer key -> per-pair embeddings
    all_embeddings_b: Dict[str, List[np.ndarray]] = {}
    all_preferences: List[float] = []
    pair_categories: List[str] = []  # category name of each collected pair, in order

    total_pairs = sum(len(cat["pairs"]) for cat in dataset.values())
    processed = 0

    print(f"\n{'='*60}")
    print(f" Processing: {model_display}")
    print(f"{'='*60}")

    for category_name, category_data in dataset.items():
        for stmt_a, stmt_b in category_data["pairs"]:
            processed += 1
            if processed % 50 == 0:
                print(f"  [{processed:03d}/{total_pairs}]")

            embs_a = get_layer_embeddings(stmt_a, model, tokenizer, layer_step)
            embs_b = get_layer_embeddings(stmt_b, model, tokenizer, layer_step)
            for layer_key in embs_a:
                all_embeddings_a.setdefault(layer_key, []).append(embs_a[layer_key])
                all_embeddings_b.setdefault(layer_key, []).append(embs_b[layer_key])

            all_preferences.append(get_output_preference(stmt_a, stmt_b, model, tokenizer))
            pair_categories.append(category_name)

    # The preference array, the category masks and the per-category mean preferences
    # do not depend on the layer, so they are built once rather than per layer.
    preferences = np.asarray(all_preferences, dtype=float)
    category_masks = [
        np.array([c == cat_name for c in pair_categories], dtype=bool)
        for cat_name in dataset
    ]
    category_prefs = [float(np.mean(preferences[mask])) for mask in category_masks]

    # Calculate per-layer correlation
    results = {}
    for layer_key in sorted(all_embeddings_a, key=lambda k: int(k.split('_')[1])):
        layer_embs_a = np.stack(all_embeddings_a[layer_key])
        layer_embs_b = np.stack(all_embeddings_b[layer_key])

        category_uas = [
            uniformity_score(layer_embs_a[mask]) - uniformity_score(layer_embs_b[mask])
            for mask in category_masks
        ]

        r, p = stats.pearsonr(category_uas, category_prefs)
        results[int(layer_key.split('_')[1])] = {
            'correlation': float(r),
            'p_value': float(p)
        }

    print(f"\n{model_display} - Layer Correlations:")
    for layer_num in sorted(results):
        r = results[layer_num]['correlation']
        print(f"  Layer {layer_num:2d}: r = {r:+.3f}")

    return results

## 3. Run Scaling Analysis

In [None]:
# Store results for all models
all_model_results = {}

for model_config in MODELS:
    model_name = model_config["name"]
    model_display = model_config["display"]

    print(f"\n{'#'*70}")
    print(f"# Loading: {model_display}")
    print(f"{'#'*70}")

    # Load model (fp16 weights; device_map="auto" places them on the available devices)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        output_hidden_states=True
    )

    try:
        n_layers = int(model.config.num_hidden_layers)
        print(f"Model loaded. Layers: {n_layers}")
        if n_layers != model_config["layers"]:
            # Normalized depth is computed from the config value. Flag it if it
            # disagrees with the layer count listed in MODELS.
            print(f"  WARNING: config reports {n_layers} layers, MODELS lists {model_config['layers']}")

        # Run analysis
        results = run_layer_analysis_for_model(model, tokenizer, DATASET, model_display)
    finally:
        # Release the model even if the analysis raises, so the next attempt does not
        # load a second multi-GB checkpoint next to a stale one. This is best effort:
        # a kept traceback may still reference the model through its frames.
        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

    # Store results (convert keys to strings for JSON)
    all_model_results[model_display] = {
        "model_name": model_name,
        "n_layers": n_layers,
        "layer_results": {str(k): v for k, v in results.items()}
    }

    print(f"\n{model_display} complete. GPU memory cleared.")

print("\n" + "="*70)
print(" ALL MODELS COMPLETE")
print("="*70)

## 4. Visualization: Layer Curves Comparison

In [None]:
# Create comparison plot of r per sampled layer for each model. The left panel uses the
# absolute layer index; the right panel uses depth normalized to % of the layer count.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Colors and markers per model. They are indexed modulo their length, so more than
# four models cycle through them instead of raising IndexError.
colors = ['#2ecc71', '#3498db', '#9b59b6', '#e74c3c']  # green, blue, purple, red
markers = ['o', 's', '^', 'D']

# (axis, normalize x to % of depth?, x-axis label, title)
panels = [
    (axes[0], False, 'Layer Number', 'Layer Correlation by Model Size (Absolute)'),
    (axes[1], True, 'Layer Position (% of total depth)', 'Layer Correlation by Model Size (Normalized)'),
]

for ax, normalize, xlabel, title in panels:
    for idx, (model_display, data) in enumerate(all_model_results.items()):
        layer_results = data["layer_results"]
        layers = sorted(int(k) for k in layer_results)
        correlations = [layer_results[str(layer)]['correlation'] for layer in layers]

        if normalize:
            xs = [layer / data["n_layers"] * 100 for layer in layers]
        else:
            xs = layers

        ax.plot(xs, correlations, f'{markers[idx % len(markers)]}-',
                color=colors[idx % len(colors)], linewidth=2, markersize=8,
                label=model_display, alpha=0.8)

    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('r(UA, Output Preference)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='lower left')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('pythia_scaling_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPlot saved to: pythia_scaling_comparison.png")

## 5. Analysis: Does the Peak Shift?

In [None]:
print("\n" + "="*70)
print(" SCALING ANALYSIS: Does the Divergence Point Shift?")
print("="*70)

# For each model, find:
# 1. Max positive r (and which layer)
# 2. Max negative r (and which layer)
# 3. Max |r| (and which layer) -- the quantity the research question is phrased in
# 4. Sign changes between consecutive sampled layers (divergence points)
#
# Correlations can be NaN, e.g. from constant input to pearsonr or a category with
# fewer than two pairs. Python's max()/min() return order-dependent results when NaN
# is present, so the extremes are located with NaN-aware numpy reductions instead.
# On ties, the first (shallowest) layer wins, as with list.index().

summary_data = []

for model_display, data in all_model_results.items():
    layer_results = data["layer_results"]
    n_layers = data["n_layers"]
    layers = sorted(int(k) for k in layer_results)
    correlations = [layer_results[str(layer)]['correlation'] for layer in layers]
    corr_arr = np.asarray(correlations, dtype=float)

    print(f"\n--- {model_display} ({n_layers} layers) ---")

    if np.all(np.isnan(corr_arr)):
        raise ValueError(f"{model_display}: no non-NaN layer correlations; cannot locate peaks")

    # Max positive
    pos_idx = int(np.nanargmax(corr_arr))
    max_pos_r = correlations[pos_idx]
    max_pos_layer = layers[pos_idx]
    max_pos_pct = max_pos_layer / n_layers * 100

    # Max negative (most negative)
    neg_idx = int(np.nanargmin(corr_arr))
    max_neg_r = correlations[neg_idx]
    max_neg_layer = layers[neg_idx]
    max_neg_pct = max_neg_layer / n_layers * 100

    # Max |r| (signed value reported)
    abs_idx = int(np.nanargmax(np.abs(corr_arr)))
    max_abs_r = correlations[abs_idx]
    max_abs_layer = layers[abs_idx]
    max_abs_pct = max_abs_layer / n_layers * 100

    # Final layer (largest sampled layer index)
    final_layer = layers[-1]
    final_r = layer_results[str(final_layer)]['correlation']

    print(f"  Max positive: r = {max_pos_r:+.3f} at Layer {max_pos_layer} ({max_pos_pct:.0f}%)")
    print(f"  Max negative: r = {max_neg_r:+.3f} at Layer {max_neg_layer} ({max_neg_pct:.0f}%)")
    print(f"  Max |r|:      r = {max_abs_r:+.3f} at Layer {max_abs_layer} ({max_abs_pct:.0f}%)")
    print(f"  Final layer:  r = {final_r:+.3f}")

    # Find sign changes between consecutive sampled layers (NaN never counts as a change)
    sign_changes = []
    for layer_lo, layer_hi, r_lo, r_hi in zip(layers, layers[1:], correlations, correlations[1:]):
        if r_lo * r_hi < 0:
            midpoint_pct = (layer_lo + layer_hi) / 2 / n_layers * 100
            sign_changes.append({
                'layer_before': int(layer_lo),
                'layer_after': int(layer_hi),
                'midpoint_pct': float(midpoint_pct)
            })

    if sign_changes:
        print("  Sign changes:")
        for sc in sign_changes:
            print(f"    Between Layer {sc['layer_before']} and {sc['layer_after']} (~{sc['midpoint_pct']:.0f}% depth)")
    else:
        print("  No sign changes detected")

    summary_data.append({
        'model': model_display,
        'n_layers': int(n_layers),
        'max_pos_r': float(max_pos_r),
        'max_pos_layer': int(max_pos_layer),
        'max_pos_pct': float(max_pos_pct),
        'max_neg_r': float(max_neg_r),
        'max_neg_layer': int(max_neg_layer),
        'max_neg_pct': float(max_neg_pct),
        'max_abs_r': float(max_abs_r),
        'max_abs_layer': int(max_abs_layer),
        'max_abs_pct': float(max_abs_pct),
        'final_r': float(final_r),
        'sign_changes': sign_changes
    })

# Summary table (cells padded to the same widths as the header)
print("\n" + "="*70)
print(" SUMMARY TABLE")
print("="*70)
print(f"{'Model':<15} {'Layers':<8} {'Peak r':<12} {'Peak @':<10} {'Final r':<10}")
print("-" * 60)
for s in summary_data:
    peak_at = f"{s['max_pos_pct']:.0f}%"
    print(f"{s['model']:<15} {s['n_layers']:<8} {s['max_pos_r']:<+12.3f} {peak_at:<10} {s['final_r']:+.3f}")

In [None]:
# Answer the research question
print("\n" + "#"*70)
print("# RESEARCH QUESTION ANSWER")
print("#"*70)

# Peaks count as "stable" when the population std (ddof=0) of their normalized
# positions across models is below this many percentage points of depth.
# NOTE(review): with the default layer_step=4, the sampled layers are 4/36 ≈ 11% to
# 4/24 ≈ 17% of depth apart, so this threshold is about one sampling step.
PEAK_STD_THRESHOLD_PCT = 10

# NOTE(review): the research question is phrased in terms of |r|. This answer uses the
# position of the most *positive* r per model; confirm that this is the intended peak.

if not summary_data:
    # Without data, np.mean/np.std would return NaN and the answer below would wrongly
    # fall through to "SHIFTS".
    raise RuntimeError("No model summaries available; run the analysis cells first.")

# Extract peak positions
peak_positions = [(s['model'], s['max_pos_pct']) for s in summary_data]

# Check if peaks are at similar normalized positions
peak_pcts = [pct for _, pct in peak_positions]
peak_std = float(np.std(peak_pcts))
peak_mean = float(np.mean(peak_pcts))

print(f"\nPeak positive correlation positions (% of model depth):")
for model, pct in peak_positions:
    print(f"  {model}: {pct:.0f}%")

print(f"\nMean: {peak_mean:.1f}%")
print(f"Std:  {peak_std:.1f}%")

peaks_stable = peak_std < PEAK_STD_THRESHOLD_PCT
if peaks_stable:
    print(f"\n>>> ANSWER: The peak position is STABLE across scales (~{peak_mean:.0f}% depth)")
    print(f"    This suggests an architectural constant, not emergent behavior.")
else:
    print(f"\n>>> ANSWER: The peak position SHIFTS with model size")
    print(f"    This suggests scale-dependent processing regimes.")

## 6. Save Results

In [None]:
# One timestamp for both the JSON field and the filename, so they always agree
run_time = datetime.now()

# Prepare results for saving (all values explicitly converted to Python types)
save_data = {
    'timestamp': run_time.isoformat(),
    'research_question': 'Does the layer correlation peak shift with model size?',
    'models': all_model_results,
    'summary': summary_data,
    'peak_analysis': {
        'mean_peak_pct': float(peak_mean),
        'std_peak_pct': float(peak_std),
        'peaks_stable': bool(peaks_stable)
    }
}

# Save to JSON
# NOTE(review): json.dump writes NaN correlations as the bare token NaN (allow_nan
# defaults to True), which strict JSON parsers reject.
output_file = f"pythia_scaling_{run_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)

print(f"Results saved to: {output_file}")

# Download. google.colab only exists inside Colab; elsewhere, the files stay in the
# working directory instead of the cell failing with ImportError.
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab; files left in working directory: {output_file}, pythia_scaling_comparison.png")
else:
    files.download(output_file)
    files.download('pythia_scaling_comparison.png')

## 7. Interpretation Template

In [None]:
# Generate interpretation text (markdown) from summary_data and the peak statistics.
interpretation = """
## Pythia Scaling Analysis: Layer Correlation Across Model Sizes

### Research Question
> Does the layer-wise correlation pattern (the depth of peak r, and where r flips
> from positive to negative) shift with model size?

### Models Tested
"""

for s in summary_data:
    line = (f"- **{s['model']}** ({s['n_layers']} layers): "
            f"Peak r = {s['max_pos_r']:+.3f} at {s['max_pos_pct']:.0f}% depth")
    # Report the measured sign-change depths so statements about the flip are backed by data
    if s['sign_changes']:
        flips = ", ".join(f"~{sc['midpoint_pct']:.0f}%" for sc in s['sign_changes'])
        line += f"; sign changes at {flips} depth"
    else:
        line += "; no sign change"
    interpretation += line + "\n"

interpretation += f"""
### Key Finding

Peak position variability: {peak_std:.1f}% (std across models)

"""

# The stability verdict is based only on the positive-peak position (peak_std), so the
# wording below speaks about the peak, not about where the sign flips.
if peaks_stable:
    interpretation += f"""
**The peak position is STABLE** across model sizes (~{peak_mean:.0f}% depth).

This suggests:
- The layer-wise correlation pattern is an architectural property, not emergent
- The "mid-layer regime" appears at a consistent relative depth
- Scale does not fundamentally change where the positive correlation peaks
"""
else:
    interpretation += """
**The peak position SHIFTS** with model size.

This suggests:
- The correlation pattern is scale-dependent
- Larger models may develop different processing regimes
- Emergent behavior in layer specialization
"""

print(interpretation)