# L* Cross-Heritage Validation (v3 - Sign Change Definition)

**Paper #3: Thermodynamic Constraints in Transformer Architectures**

**Author:** Davide D'Elia

**Date:** 2026-01-06

---

## Purpose

Validate the L* transition point formula using the **correct definition**:

> **L* = Layer where d/dl[Tr(L_F)] changes sign**

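This is a local maximum (peak) of the trace curve: the layer where the trace derivative goes from positive to negative. (Strictly speaking it is not an inflection point — that would be a sign change of the *second* derivative.)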

## v3 Fixes (vs v2)

- **CRITICAL**: Changed L* definition from `argmax(|gradient|)` to `sign_change(gradient)`
- This matches the original calibration methodology

## The Formula

```
L* = L × (0.11 + 0.012×L + 4.9/H)
```

---

In [None]:
# Cell 1: Setup
!pip install -q transformers accelerate scipy seaborn pandas huggingface_hub

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tqdm.auto import tqdm
import json
import gc
import os
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Configure visualization
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("paper", font_scale=1.2)

# Global timestamp, used to name every artifact written by this session.
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"Session timestamp: {TIMESTAMP}")

# HF TOKEN: prefer the Colab secret, fall back to the HF_TOKEN environment variable.
HF_TOKEN = None
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
except Exception:
    # Not running in Colab, or the secret is missing / access was not granted.
    # (Exception, not a bare except, so Ctrl-C still interrupts the cell.)
    HF_TOKEN = None
if not HF_TOKEN:
    HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    print("HF_TOKEN loaded")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Model Configuration

In [None]:
# Models under test: the calibration heritage (EleutherAI) plus models from
# other labs (cross-heritage). Fields of each entry:
#   lab                - organisation that released the model
#   L, H               - layer count and attention-head count fed to the formula
#   expected           - behaviour label (DAMPEN / EXPAND); not read by later cells
#   calibration_L_star - L* from the original calibration run, or None when the
#                        model was not part of calibration
MODELS_TO_TEST = {
    # EleutherAI (calibration heritage)
    "EleutherAI/pythia-160m": dict(
        lab="EleutherAI", L=12, H=12, expected="DAMPEN", calibration_L_star=7,
    ),
    "EleutherAI/pythia-410m": dict(
        lab="EleutherAI", L=24, H=16, expected="DAMPEN", calibration_L_star=16,
    ),
    # Meta (cross-heritage)
    "facebook/opt-125m": dict(
        lab="Meta", L=12, H=12, expected="EXPAND", calibration_L_star=8,
    ),
    "facebook/opt-350m": dict(
        lab="Meta", L=24, H=16, expected="EXPAND", calibration_L_star=None,
    ),
    # BigScience (ALiBi - cross-heritage)
    "bigscience/bloom-560m": dict(
        lab="BigScience", L=24, H=16, expected="EXPAND", calibration_L_star=None,
    ),
    # OpenAI (cross-heritage)
    "openai-community/gpt2": dict(
        lab="OpenAI", L=12, H=12, expected="EXPAND", calibration_L_star=9,
    ),
}

# Short, varied English prompts; each model's empirical L* is the mean over them.
TEST_PROMPTS = [
    "The capital of France is Paris, which is known for the Eiffel Tower.",
    "In mathematics, the derivative of x squared equals two times x.",
    "Climate change affects global temperatures and weather patterns significantly.",
    "The quick brown fox jumps over the lazy dog near the riverbank.",
    "Once upon a time in a land far away, there lived a wise old king.",
]

def predict_l_star_v3(L, H):
    """Predict the transition layer L* from model depth and head count.

    Implements L* = L × (0.11 + 0.012×L + 4.9/H).

    Args:
        L: number of transformer layers.
        H: number of attention heads (must be non-zero).

    Returns:
        The predicted L*, a (generally fractional) layer position.
    """
    depth_fraction = 0.11 + 0.012 * L + 4.9 / H
    return L * depth_fraction

# Report how many models and prompts this run will evaluate.
print(f"Models: {len(MODELS_TO_TEST)}, Prompts: {len(TEST_PROMPTS)}")

## 3. Architecture-Aware Functions

In [None]:
def get_architecture_info(model):
    """Return a short architecture tag for a Hugging Face causal LM.

    The tag is inferred purely from attribute names on the module tree:

    - ``'pythia'`` : model has ``gpt_neox``
    - ``'opt'``    : model has ``model.decoder``
    - ``'bloom'``  : first block of ``transformer.h`` has ``self_attention``
    - ``'gpt2'``   : first block of ``transformer.h`` has ``attn``
    - ``'llama'``  : model has ``model.layers``
    - ``'unknown'``: none of the above, including a ``transformer.h`` whose
      first block matches neither the BLOOM nor the GPT-2 pattern

    Checks run in the order listed; the first match wins.
    """
    if hasattr(model, 'gpt_neox'):
        return 'pythia'

    has_inner_model = hasattr(model, 'model')
    if has_inner_model and hasattr(model.model, 'decoder'):
        return 'opt'

    if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        first_block = model.transformer.h[0]
        if hasattr(first_block, 'self_attention'):
            return 'bloom'
        if hasattr(first_block, 'attn'):
            return 'gpt2'
        return 'unknown'

    if has_inner_model and hasattr(model.model, 'layers'):
        return 'llama'
    return 'unknown'


def get_layers(model, arch):
    """Return the sequence of transformer blocks for an architecture tag.

    Args:
        model: Hugging Face causal LM.
        arch: tag produced by ``get_architecture_info``.

    Returns:
        The model's block container (e.g. ``model.transformer.h``), or a new
        empty list when ``arch`` is not one of the supported tags.
    """
    accessors = {
        'pythia': lambda m: m.gpt_neox.layers,
        'opt': lambda m: m.model.decoder.layers,
        'bloom': lambda m: m.transformer.h,
        'gpt2': lambda m: m.transformer.h,
        'llama': lambda m: m.model.layers,
    }
    accessor = accessors.get(arch)
    if accessor is None:
        return []
    return accessor(model)


def extract_v_from_interleaved_qkv(qkv_weight, n_heads):
    """Return the value-projection rows of a per-head interleaved fused QKV weight.

    GPT-NeoX (Pythia) and BLOOM fuse Q, K and V into one ``query_key_value``
    Linear whose output features are grouped *per head*:
    ``[q_0, k_0, v_0, q_1, k_1, v_1, ...]`` (each chunk ``head_dim`` rows).
    Simply taking the last third of the rows therefore does NOT select V; it
    selects the Q/K/V rows of the last third of the heads.

    Args:
        qkv_weight: tensor of shape ``(3 * n_heads * head_dim, in_features)``.
        n_heads: number of attention heads.

    Returns:
        Tensor of shape ``(n_heads * head_dim, in_features)`` holding only the
        V rows, in head order.

    Raises:
        ValueError: if the row count cannot be split into ``3 * n_heads`` chunks.
    """
    out_features, in_features = qkv_weight.shape
    if n_heads <= 0 or out_features % (3 * n_heads) != 0:
        raise ValueError(
            f"fused QKV weight with {out_features} rows cannot be split into "
            f"3 x {n_heads} heads"
        )
    head_dim = out_features // (3 * n_heads)
    per_head = qkv_weight.reshape(n_heads, 3, head_dim, in_features)
    return per_head[:, 2].reshape(n_heads * head_dim, in_features)


def get_W_V(model, arch, layer_idx):
    """
    Extract one layer's value-projection matrix W_V as a float32 CPU tensor.

    The result uses nn.Linear layout (out_features, in_features); GPT-2's
    Conv1D weight, stored as (in, out), is transposed to match.

    Args:
        model: Hugging Face causal LM.
        arch: tag from ``get_architecture_info``.
        layer_idx: 0-based transformer layer index.

    Returns:
        The W_V tensor, or None if ``arch`` is unsupported or extraction failed
        (the error is printed, not raised).
    """
    try:
        layer = get_layers(model, arch)[layer_idx]

        if arch == 'pythia':
            # GPT-NeoX interleaves the fused QKV per head: [q_h, k_h, v_h] for each head h.
            qkv = layer.attention.query_key_value.weight.data.float()
            return extract_v_from_interleaved_qkv(qkv, model.config.num_attention_heads).cpu()

        if arch == 'bloom':
            # BLOOM uses the same per-head [q, k, v] interleaving as GPT-NeoX.
            qkv = layer.self_attention.query_key_value.weight.data.float()
            return extract_v_from_interleaved_qkv(qkv, model.config.num_attention_heads).cpu()

        if arch in ('opt', 'llama'):
            return layer.self_attn.v_proj.weight.data.float().cpu()

        if arch == 'gpt2':
            # GPT-2's Conv1D c_attn output is [Q | K | V] in contiguous thirds of the columns.
            c_attn = layer.attn.c_attn.weight.data.float()
            d = c_attn.shape[1] // 3
            return c_attn[:, 2*d:].T.cpu()

    except Exception as e:
        print(f"    W_V extraction error (layer {layer_idx}): {e}")

    return None


print("Architecture-aware functions defined.")

## 4. L* Computation with SIGN CHANGE Definition

**CRITICAL FIX**: L* is defined as the layer where the trace derivative **changes sign** (from positive to negative), NOT the layer of maximum gradient magnitude.

In [None]:
def find_l_star_sign_change(traces):
    """
    Find L* as the layer where the trace derivative changes sign from + to -.

    L* = first layer l at which d/dl[Tr(L_F)] goes from positive to negative,
    i.e. the first local maximum (peak) of the per-layer trace curve. This is
    an extremum of the curve, not an inflection point.

    The derivative is approximated by first differences ``np.diff(traces)``:
    gradient i is the change from layer i to layer i+1.

    Detection cascade:
      1. 'sign_change'  : first i with grad[i] > 0 and grad[i+1] < 0 -> L* = i+1.
      2. 'peak'         : otherwise, argmax of the traces if it is an interior
                          layer (e.g. a flat top where a zero gradient sits
                          between the rising and falling runs).
      3. 'gradient_drop': otherwise (e.g. monotonic traces), the layer just
                          after the last gradient whose magnitude is at least
                          50% of the largest finite gradient magnitude.
      4. 'fallback'     : fewer than 3 layers, or no finite gradient at all
                          -> midpoint n_layers // 2.

    Args:
        traces: sequence of per-layer trace values.

    Returns:
        (L_star, method): 0-based layer index and one of the method names above.
    """
    traces_arr = np.asarray(traces, dtype=float)
    n_layers = len(traces_arr)

    if n_layers < 3:
        return n_layers // 2, 'fallback'

    gradients = np.diff(traces_arr)

    # Method 1: first + to - transition of the first difference.
    for i in range(len(gradients) - 1):
        if gradients[i] > 0 and gradients[i + 1] < 0:
            return i + 1, 'sign_change'

    # Method 2: interior maximum without a strict +/- pair (flat-topped peak).
    peak_idx = int(np.argmax(traces_arr))
    if 0 < peak_idx < n_layers - 1:
        return peak_idx, 'peak'

    # Method 3: transition from steep to flat. Non-finite gradients (from NaN or
    # inf traces) are ignored; comparing against a NaN maximum would otherwise
    # match nothing and leave the result undefined.
    grad_magnitude = np.abs(gradients)
    finite = np.isfinite(grad_magnitude)
    if finite.any():
        threshold = 0.5 * grad_magnitude[finite].max()
        significant = np.flatnonzero(finite & (grad_magnitude >= threshold))
        return int(significant[-1]) + 1, 'gradient_drop'

    # Fallback: midpoint
    return n_layers // 2, 'fallback'


def compute_traces_and_l_star(model, tokenizer, prompt, arch, device='cuda'):
    """
    Compute the per-layer Sheaf Laplacian trace proxy and find L* via sign change.

    Trace formula used here:
        Tr(L_F) = |sum(A) - n| * ||W_V||_F^2
    where A is the layer's attention matrix averaged over heads (single batch
    element), n is the number of prompt tokens and W_V is the layer's value
    projection from ``get_W_V``. If W_V cannot be extracted, the layer falls
    back to |sum(A) - n| alone, so such layers are on a different scale.

    NOTE(review): Hugging Face returns attention weights *after* the softmax,
    so every row of A sums to 1 and sum(A) == n in exact arithmetic. The
    factor (sum(A) - n) is then floating-point rounding noise and the
    resulting L* is not meaningful. A warning is printed when this degenerate
    case is detected; the trace formula itself should be revisited.

    Args:
        model: causal LM that returns attentions (e.g. loaded with
            attn_implementation="eager").
        tokenizer: tokenizer matching ``model``.
        prompt: text to run through the model.
        arch: architecture tag from ``get_architecture_info``.
        device: device the tokenized inputs are moved to; it should be the
            device holding the model's input embeddings.

    Returns:
        (traces, L_star, method): one float per layer (0.0 where a layer's
        attention was not returned) plus the result of
        ``find_l_star_sign_change``; or (None, None, None) if the model
        returned no attentions.
    """
    # |sum(A) - n| below this fraction of n on every layer is treated as rounding noise.
    noise_rel_tol = 1e-2

    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    n_tokens = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model(
            **inputs,
            output_attentions=True,
            # Hidden states are never used here; don't keep one tensor per layer.
            output_hidden_states=False,
        )

    attentions = outputs.attentions
    if attentions is None:
        print(f"    WARNING: attentions is None!")
        return None, None, None

    traces = []
    mass_deficits = []  # |sum(A) - n| per layer, for the degeneracy check below
    for layer_idx, attn in enumerate(attentions):
        if attn is None:
            traces.append(0.0)
            continue

        # Average attention over heads for the single batch element -> (seq, seq)
        A = attn[0].float().mean(dim=0).cpu()
        deficit = A.sum().item() - n_tokens
        mass_deficits.append(abs(deficit))

        W_V = get_W_V(model, arch, layer_idx)
        if W_V is not None:
            W_V_frob_sq = (W_V ** 2).sum().item()
            trace = abs(deficit * W_V_frob_sq)
        else:
            trace = abs(deficit)
        traces.append(trace)

    if mass_deficits and max(mass_deficits) < noise_rel_tol * n_tokens:
        print(f"    WARNING: |sum(A) - n| < {noise_rel_tol:g}*n on every layer "
              f"(max {max(mass_deficits):.2e}); attention rows are softmax-normalized, "
              f"so these traces are rounding noise.")

    L_star, method = find_l_star_sign_change(traces)
    return traces, L_star, method

print("Trace computation with SIGN CHANGE L* definition ready.")

## 5. Run Validation

In [None]:
from collections import Counter
import traceback

results = []

print("=" * 80)
print("L* CROSS-HERITAGE VALIDATION (v3 - SIGN CHANGE DEFINITION)")
print("=" * 80)

for model_name, config in tqdm(MODELS_TO_TEST.items(), desc="Models"):
    print(f"\n{'='*70}")
    print(f"Model: {model_name}")
    print(f"Lab: {config['lab']} | L={config['L']} | H={config['H']}")
    calib_l_star = config.get('calibration_L_star')
    if calib_l_star is not None:
        print(f"Calibration L*: {calib_l_star}")

    # Pre-bind so the `finally` cleanup works even if loading fails part-way.
    model = None
    tokenizer = None
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=HF_TOKEN if HF_TOKEN else None
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Eager attention so attention weights are materialized
        # (fused SDPA / flash kernels do not return them).
        print("  Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=HF_TOKEN if HF_TOKEN else None,
            # NOTE(review): every model in MODELS_TO_TEST is natively supported by
            # transformers; trust_remote_code=True lets a Hub repo execute its own
            # code. Consider False unless a model genuinely needs it.
            trust_remote_code=True,
            attn_implementation="eager",
            output_attentions=True,
            output_hidden_states=True
        )
        model.eval()

        # Detect architecture
        arch = get_architecture_info(model)
        print(f"  Architecture: {arch}")

        all_l_stars = []
        all_methods = []

        for i, prompt in enumerate(TEST_PROMPTS):
            # Send inputs to the model's own device instead of assuming 'cuda'
            # (device_map="auto" decides placement).
            traces, l_star, method = compute_traces_and_l_star(
                model, tokenizer, prompt, arch, device=model.device
            )
            if traces is not None:
                all_l_stars.append(l_star)
                all_methods.append(method)
                print(f"    Prompt {i+1}: L* = {l_star} ({method})")
            else:
                print(f"    Prompt {i+1}: FAILED")

        if all_l_stars:
            L_star_empirical = np.mean(all_l_stars)
            L_star_std = np.std(all_l_stars)
            L_star_predicted = predict_l_star_v3(config["L"], config["H"])
            # Absolute error as a percentage of model depth L (not of L*);
            # later cells average this value and report it as "MAPE".
            error = abs(L_star_predicted - L_star_empirical) / config["L"] * 100

            # Most common detection method across prompts
            dominant_method = Counter(all_methods).most_common(1)[0][0]

            result = {
                "model": model_name,
                "lab": config["lab"],
                "L": config["L"],
                "H": config["H"],
                "arch": arch,
                "L_star_predicted": float(L_star_predicted),
                "L_star_empirical": float(L_star_empirical),
                "L_star_std": float(L_star_std),
                "L_star_calibration": calib_l_star,
                "error_pct": float(error),
                "detection_method": dominant_method,
                "individual_L_stars": all_l_stars,
            }
            results.append(result)

            print(f"\n  RESULTS (Sign Change Definition):")
            print(f"    L* predicted:    {L_star_predicted:.1f}")
            print(f"    L* empirical:    {L_star_empirical:.1f} +/- {L_star_std:.1f}")
            if calib_l_star is not None:
                print(f"    L* calibration:  {calib_l_star}")
                calib_diff = abs(L_star_empirical - calib_l_star)
                print(f"    Calib. diff:     {calib_diff:.1f}")
            print(f"    Error:           {error:.1f}%")
            print(f"    Method:          {dominant_method}")

    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
    finally:
        # Free the model even when loading or inference failed; otherwise its
        # GPU memory stays allocated while the next model loads.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print(f"\n{'='*80}")
print(f"COMPLETE: {len(results)}/{len(MODELS_TO_TEST)} models")
print("="*80)

## 6. Results Analysis

In [None]:
if results:
    # Overall error reported by the previous (v2, argmax |gradient|) run.
    V2_ARGMAX_MAPE = 15.7

    df = pd.DataFrame(results)

    print("\n" + "="*80)
    print("CROSS-HERITAGE L* VALIDATION RESULTS (SIGN CHANGE DEFINITION)")
    print("="*80)

    # Display key columns
    display_cols = ['model', 'lab', 'L', 'H', 'L_star_predicted', 'L_star_empirical',
                    'L_star_calibration', 'error_pct', 'detection_method']
    print(df[display_cols].to_string(index=False))

    # Calibration consistency check
    print("\n" + "="*80)
    print("CALIBRATION CONSISTENCY CHECK")
    print("="*80)

    for _, row in df.iterrows():
        calib = row['L_star_calibration']
        # pandas stores the missing (None) calibrations of this mixed int/None
        # column as NaN, so an `is not None` test would let them through and
        # report bogus MISMATCH rows with Diff=nan.
        if pd.isna(calib):
            continue
        diff = abs(row['L_star_empirical'] - calib)
        status = "MATCH" if diff < 2 else "MISMATCH"
        print(f"{row['model'].split('/')[-1]:20} Empirical={row['L_star_empirical']:.1f} Calib={calib:.0f} Diff={diff:.1f} [{status}]")

    # Summary by lab. error_pct is |predicted - empirical| as a percentage of
    # depth L; its mean is labelled "MAPE" for continuity with v2.
    print("\n" + "="*80)
    print("SUMMARY BY LAB")
    print("="*80)

    for lab in sorted(df['lab'].unique()):
        lab_df = df[df['lab'] == lab]
        errors = lab_df['error_pct'].values
        print(f"\n{lab}: n={len(lab_df)}, MAPE={np.mean(errors):.1f}%")

    # Overall
    overall_mape = df['error_pct'].mean()
    print(f"\n{'='*80}")
    print(f"OVERALL MAPE: {overall_mape:.1f}%")
    print("="*80)

    # Compare with v2 (argmax) results
    print(f"\n{'='*80}")
    print("COMPARISON: Sign Change vs Argmax Gradient")
    print("="*80)
    print(f"v2 (argmax):      {V2_ARGMAX_MAPE:.1f}% MAPE")
    print(f"v3 (sign change): {overall_mape:.1f}% MAPE")
    improvement = V2_ARGMAX_MAPE - overall_mape
    print(f"Improvement:      {improvement:+.1f}pp")
else:
    print("No results!")

## 7. Visualization

In [None]:
if results:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    lab_colors = {
        "EleutherAI": "#E74C3C",
        "Meta": "#3498DB",
        "BigScience": "#27AE60",
        "OpenAI": "#9B59B6"
    }

    # Plot 1: Predicted vs Empirical
    ax1 = axes[0]
    labelled_labs = set()
    for r in results:
        color = lab_colors.get(r['lab'], 'gray')
        # Label only the first point of each lab so the legend maps colours to labs
        # (previously the legend showed only the diagonal).
        label = r['lab'] if r['lab'] not in labelled_labs else '_nolegend_'
        labelled_labs.add(r['lab'])
        ax1.scatter(r['L_star_predicted'], r['L_star_empirical'],
                    c=color, s=150, alpha=0.8, edgecolors='white', linewidths=2,
                    label=label)
        ax1.annotate(r['model'].split('/')[-1],
                     (r['L_star_predicted'], r['L_star_empirical']),
                     fontsize=8, xytext=(5, 5), textcoords='offset points')

    max_val = max(max(r['L_star_predicted'] for r in results),
                  max(r['L_star_empirical'] for r in results))
    ax1.plot([0, max_val*1.1], [0, max_val*1.1], 'k--', alpha=0.5, label='Perfect')
    ax1.set_xlabel('L* Predicted', fontsize=12)
    ax1.set_ylabel('L* Empirical (Sign Change)', fontsize=12)
    ax1.set_title('L* Validation (Sign Change Definition)', fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Plot 2: Error by Lab (bar = lab mean, dots = individual models)
    ax2 = axes[1]
    labs = sorted(set(r['lab'] for r in results))
    for i, lab in enumerate(labs):
        errors = [r['error_pct'] for r in results if r['lab'] == lab]
        ax2.bar(i, np.mean(errors), color=lab_colors.get(lab, 'gray'), alpha=0.7)
        ax2.scatter([i]*len(errors), errors, c='black', s=50, zorder=5)

    ax2.axhline(y=10, color='green', linestyle='--', label='10% threshold')
    ax2.axhline(y=15, color='orange', linestyle='--', label='15% threshold')
    ax2.set_xticks(range(len(labs)))
    ax2.set_xticklabels(labs, rotation=45, ha='right')
    ax2.set_ylabel('Error (%)', fontsize=12)
    ax2.set_title('L* Formula Error by Lab', fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')

    # Plot 3: Empirical vs Calibration (where available)
    ax3 = axes[2]
    calib_results = [r for r in results if r['L_star_calibration'] is not None]
    if calib_results:
        for r in calib_results:
            color = lab_colors.get(r['lab'], 'gray')
            ax3.scatter(r['L_star_calibration'], r['L_star_empirical'],
                        c=color, s=150, alpha=0.8, edgecolors='white', linewidths=2)
            ax3.annotate(r['model'].split('/')[-1],
                         (r['L_star_calibration'], r['L_star_empirical']),
                         fontsize=8, xytext=(5, 5), textcoords='offset points')

        max_calib = max(max(r['L_star_calibration'] for r in calib_results),
                        max(r['L_star_empirical'] for r in calib_results))
        ax3.plot([0, max_calib*1.1], [0, max_calib*1.1], 'k--', alpha=0.5)
        ax3.set_xlabel('L* Calibration (Original)', fontsize=12)
        ax3.set_ylabel('L* Empirical (This Run)', fontsize=12)
        ax3.set_title('Calibration Consistency', fontsize=14)
        ax3.grid(True, alpha=0.3)
    else:
        ax3.text(0.5, 0.5, 'No calibration data', ha='center', va='center', fontsize=14)
        ax3.set_title('Calibration Consistency', fontsize=14)

    plt.tight_layout()
    PNG_FILE = f"l_star_sign_change_{TIMESTAMP}.png"
    plt.savefig(PNG_FILE, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {PNG_FILE}")

## 8. Final Verdict

In [None]:
def make_serializable(obj):
    """Recursively convert numpy values into JSON-serializable Python objects.

    - dicts, lists and tuples are walked recursively (tuples become lists,
      since JSON has no tuple type);
    - numpy scalars (integers, floats, bools, ...) become the matching Python
      scalar via ``.item()``, so integers stay integers instead of becoming
      floats and ``np.bool_`` no longer breaks ``json.dump``;
    - numpy arrays become (nested) lists via ``.tolist()``;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_serializable(v) for v in obj]
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


if results:
    # Overall error reported by the previous (v2, argmax |gradient|) run.
    V2_ARGMAX_MAPE = 15.7
    # |empirical - calibration| strictly below this many layers counts as a match.
    CALIB_TOLERANCE = 2

    # Mean of per-model error_pct (error as % of depth L), reported as "MAPE".
    overall_mape = np.mean([r['error_pct'] for r in results])

    # Check calibration consistency
    calib_rows = [r for r in results if r['L_star_calibration'] is not None]
    calib_total = len(calib_rows)
    calib_matches = sum(
        abs(r['L_star_empirical'] - r['L_star_calibration']) < CALIB_TOLERANCE
        for r in calib_rows
    )

    if overall_mape < 10:
        verdict = "FORMULA VALIDATED"
    elif overall_mape < 15:
        verdict = "PARTIAL GENERALIZATION"
    else:
        verdict = "CALIBRATION NEEDED"

    # Banner: 80 columns on every line (the title line was 78 wide before).
    print("\n" + "#"*80)
    print("#" + "FINAL VERDICT".center(78) + "#")
    print("#"*80)
    print(f"""
    L* Definition: SIGN CHANGE (d/dl[Tr] changes from + to -)
    
    Formula: L* = L × (0.11 + 0.012×L + 4.9/H)
    
    Models tested:       {len(results)}
    Labs tested:         {len(set(r['lab'] for r in results))}
    Overall MAPE:        {overall_mape:.1f}%
    
    Calibration check:   {calib_matches}/{calib_total} matches (tolerance ±{CALIB_TOLERANCE})
    
    v2 (argmax):         {V2_ARGMAX_MAPE:.1f}% MAPE
    v3 (sign change):    {overall_mape:.1f}% MAPE
    
    VERDICT: {verdict}
    """)
    print("#"*80)

    # Save
    output = {
        "experiment": "L* Cross-Heritage Validation v3 (Sign Change)",
        "timestamp": TIMESTAMP,
        "l_star_definition": "Layer where d/dl[Tr(L_F)] changes sign (+ to -)",
        "formula": "L* = L × (0.11 + 0.012×L + 4.9/H)",
        "n_models": len(results),
        "overall_mape": float(overall_mape),
        "calibration_consistency": f"{calib_matches}/{calib_total}",
        "comparison_v2_argmax": V2_ARGMAX_MAPE,
        "improvement_pp": float(V2_ARGMAX_MAPE - overall_mape),
        "verdict": verdict,
        "results": make_serializable(results)
    }

    JSON_FILE = f"l_star_sign_change_{TIMESTAMP}.json"
    with open(JSON_FILE, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved: {JSON_FILE}")

In [None]:
# Download the generated artifacts when running inside Google Colab.
try:
    from google.colab import files
except ImportError:
    # Not running in Colab: the artifacts are already on local disk.
    print("Files saved locally.")
else:
    try:
        # JSON_FILE / PNG_FILE only exist if the earlier cells produced results.
        for artifact_var in ('JSON_FILE', 'PNG_FILE'):
            if artifact_var in globals():
                files.download(globals()[artifact_var])
        print("Downloads started!")
    except Exception as e:
        # Report real download failures instead of masking them (old bare except).
        print(f"Download failed ({e}); files saved locally.")