# Cross-Model Validation: Llama-3.1-8B Layer Analysis

**Purpose:** Validate whether the phase-structured embedding-output relationship generalizes beyond Pythia.

**Research Question:** Does Llama-3.1-8B show the same three-phase pattern?
1. Positive correlation in early/mid layers
2. Transition zone
3. Negative correlation in late layers

**Method:**
- Same 230 pairs as Pythia experiments
- Layer-wise uniformity-asymmetry (UA) analysis (every 4th layer)
- Pair-level metrics (centroid_asymmetry, cross_centroid, within_pair_sim)
- Bootstrap 95% confidence intervals (10,000 resamples)

**Expected Runtime:** ~2h on A100

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03  
**Model:** Meta-Llama-3.1-8B (Base, not Instruct)

## 1. Setup

In [None]:
# Install the third-party packages used by this notebook (-q keeps pip output quiet).
!pip install -q transformers accelerate torch numpy scipy matplotlib scikit-learn huggingface_hub

In [None]:
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from transformers import AutoModelForCausalLM, AutoTokenizer

# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# One seed for both RNGs; the bootstrap below draws from NumPy's global RNG.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Bootstrap settings used by the pair-level analysis.
CI_LEVEL = 0.95
N_BOOTSTRAP = 10000

# Report the runtime environment; query CUDA availability once and reuse it.
_cuda_available = torch.cuda.is_available()
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {_cuda_available}')
if _cuda_available:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {_gpu_props.total_memory / 1e9:.1f} GB')

In [None]:
# ========================================
# HUGGINGFACE LOGIN
# ========================================
# Authenticate with the Hugging Face Hub. Three options, tried in order:

# OPTION 1: Colab Secrets (recommended)
# In Colab: open the key icon in the left sidebar and add a secret named "HF_TOKEN".
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    print("✅ Token aus Colab Secrets geladen")
except Exception:
    # Not running in Colab, or the secret could not be read.
    # `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt
    # and SystemExit propagate, so the cell can still be interrupted.
    HF_TOKEN = None

# OPTION 2: Paste the token manually if Colab Secrets did not work.
if not HF_TOKEN:
    HF_TOKEN = ''  # <-- paste token here if needed

# OPTION 3: Interactive login widget when no token is available.
if not HF_TOKEN:
    from huggingface_hub import notebook_login
    notebook_login()
    print("✅ Interaktiv eingeloggt")
else:
    from huggingface_hub import login
    login(token=HF_TOKEN)
    print("✅ Mit Token eingeloggt")

## 2. Load Model

In [None]:
MODEL_NAME = 'meta-llama/Llama-3.1-8B'
MODEL_DISPLAY = 'Llama-3.1-8B'

print(f'Loading {MODEL_DISPLAY}...')
print('NOTE: Using BASE model (not Instruct) for fair comparison with Pythia')

# Pass the explicit token when one was set; otherwise True asks the hub
# client to fall back to the cached login.
token_arg = HF_TOKEN or True

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token_arg)

# Reuse EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): weights are loaded in float16 — confirm this matches the
# precision used for the Pythia runs being compared against.
load_kwargs = {
    'token': token_arg,
    'torch_dtype': torch.float16,
    'device_map': 'auto',
    'output_hidden_states': True,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

print(f'✅ Model loaded on: {model.device}')
print(f'Layers: {model.config.num_hidden_layers}')

## 3. Load Dataset

In [None]:
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

ALL_PAIRS = []
for cat_name, cat_data in DATASET.items():
    for pair in cat_data['pairs']:
        ALL_PAIRS.append({
            'stmt_a': pair[0],
            'stmt_b': pair[1],
            'category': cat_name
        })

print(f'Categories: {list(DATASET.keys())}')
print(f'Total pairs: {len(ALL_PAIRS)}')

## 4. Core Functions

In [None]:
def get_layer_embedding(text, model, tokenizer, layer_idx):
    '''Return the mean-pooled hidden state of ``text`` at one layer.

    Args:
        text: Input string; tokenized with truncation to 512 tokens.
        model: Model that exposes ``.device`` and returns ``hidden_states``
            when called with ``output_hidden_states=True``.
        tokenizer: Tokenizer matching ``model``.
        layer_idx: Index into the returned ``hidden_states`` sequence.

    Returns:
        1-D ``np.float32`` array: the mean over token positions 1..end of the
        first (only) sequence in the batch.
    '''
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Skip position 0 — assumes the tokenizer prepends a BOS token (TODO confirm
    # before reusing this with a tokenizer that does not).
    token_states = outputs.hidden_states[layer_idx][0, 1:, :]

    # Upcast BEFORE pooling: averaging in the model dtype (float16) rounds the
    # pooled vector to half precision; casting afterwards cannot recover it.
    return token_states.float().mean(dim=0).cpu().numpy()


def get_output_preference(text_a, text_b, model, tokenizer):
    '''Return the model's loss on ``text_b`` minus its loss on ``text_a``.

    Each text is tokenized on its own (truncated to 512 tokens) and scored
    with its own input ids as labels. A positive result means the model
    reports a lower loss for ``text_a`` than for ``text_b``.
    '''
    losses = []
    # Score B first, then A, matching the order of the subtraction below.
    for text in (text_b, text_a):
        encoded = tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            result = model(**encoded, labels=encoded['input_ids'])
        losses.append(result.loss.item())

    nll_b, nll_a = losses
    return nll_b - nll_a


def cosine_similarity(a, b):
    '''Return the cosine similarity of vectors ``a`` and ``b`` as a Python float.

    A small constant (1e-10) is added to the product of the norms so that
    zero vectors yield 0.0 instead of a division by zero.
    '''
    denominator = np.linalg.norm(a) * np.linalg.norm(b) + 1e-10
    similarity = np.dot(a, b) / denominator
    return float(similarity)


def uniformity_score(embeddings):
    '''Return the mean pairwise cosine similarity among the rows of ``embeddings``.

    Rows are L2-normalized (with a 1e-10 guard on the norm), and the average
    is taken over the strict upper triangle of the similarity matrix, i.e.
    over each unordered pair of distinct rows exactly once.
    '''
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_rows = embeddings / (row_norms + 1e-10)
    similarity = unit_rows @ unit_rows.T

    # k=1 excludes the diagonal (self-similarity).
    upper = np.triu_indices(similarity.shape[0], k=1)
    return float(np.mean(similarity[upper]))


def bootstrap_correlation(x, y, n_bootstrap=10000, ci_level=0.95):
    '''Pearson correlation of ``x`` and ``y`` with a percentile-bootstrap CI.

    Args:
        x, y: Equal-length 1-D sequences (lists, tuples or arrays).
        n_bootstrap: Number of resamples drawn with replacement.
        ci_level: Confidence level of the interval, e.g. 0.95.

    Returns:
        Tuple ``(r, ci_lower, ci_upper, p_value)`` of Python floats. ``r`` and
        ``p_value`` come from ``scipy.stats.pearsonr`` on the full data. The
        CI bounds are NaN if no resample had non-zero variance in both
        variables.

    Resampling uses NumPy's global RNG, so results are reproducible only if
    the caller seeds ``np.random`` beforehand.
    '''
    # Coerce to arrays so list/tuple inputs support the index-array
    # resampling below (plain lists would raise TypeError on x[idx]).
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    r_observed, p_value = stats.pearsonr(x, y)

    bootstrap_rs = []
    for _ in range(n_bootstrap):
        idx = np.random.choice(n, size=n, replace=True)
        x_boot, y_boot = x[idx], y[idx]
        # Pearson r is undefined when either resample is constant; skip those.
        if np.std(x_boot) > 0 and np.std(y_boot) > 0:
            r_boot, _ = stats.pearsonr(x_boot, y_boot)
            bootstrap_rs.append(r_boot)

    if not bootstrap_rs:
        # Every resample was degenerate: report an undefined interval
        # explicitly instead of taking percentiles of an empty array.
        return float(r_observed), float('nan'), float('nan'), float(p_value)

    bootstrap_rs = np.array(bootstrap_rs)
    alpha = 1 - ci_level
    ci_lower = np.percentile(bootstrap_rs, alpha / 2 * 100)
    ci_upper = np.percentile(bootstrap_rs, (1 - alpha / 2) * 100)

    return float(r_observed), float(ci_lower), float(ci_upper), float(p_value)


print('Core functions defined.')

## 5. Collect All Embeddings

In [None]:
N_LAYERS = model.config.num_hidden_layers
# Every 4th index from 0, always including N_LAYERS itself.
# NOTE(review): indexing hidden_states[N_LAYERS] assumes the model returns
# N_LAYERS + 1 hidden states — confirm for the loaded model.
LAYERS_TO_TEST = list(range(0, N_LAYERS + 1, 4))
if N_LAYERS not in LAYERS_TO_TEST:
    LAYERS_TO_TEST.append(N_LAYERS)


def _embed_all_layers(text, layer_indices):
    '''Mean-pool ``text`` at every requested layer using ONE forward pass.

    A single call with ``output_hidden_states=True`` already returns the
    hidden states of all layers, so running one forward pass per layer
    repeats identical work. Returns ``{layer_idx: 1-D float32 array}``.
    '''
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        hidden_states = model(**inputs, output_hidden_states=True).hidden_states
    # Position 0 is skipped (assumes a prepended BOS token — TODO confirm).
    # Upcast before averaging so the pooled vector is not rounded to float16.
    return {
        idx: hidden_states[idx][0, 1:, :].float().mean(dim=0).cpu().numpy()
        for idx in layer_indices
    }


print(f'Testing layers: {LAYERS_TO_TEST}')
print(f'Collecting embeddings for {len(ALL_PAIRS)} pairs...')
print(f'Estimated time: ~1-2 hours on A100')

pair_data = []
start_time = datetime.now()

for i, pair in enumerate(ALL_PAIRS):
    # Progress report with a simple linear ETA every 25 pairs.
    if (i + 1) % 25 == 0:
        elapsed = (datetime.now() - start_time).total_seconds() / 60
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (len(ALL_PAIRS) - i - 1) / rate if rate > 0 else 0
        print(f'  [{i+1:03d}/{len(ALL_PAIRS)}] - {elapsed:.1f} min elapsed, ~{eta:.1f} min remaining')

    stmt_a = pair['stmt_a']
    stmt_b = pair['stmt_b']

    pref = get_output_preference(stmt_a, stmt_b, model, tokenizer)

    # One forward pass per statement instead of one per statement per layer.
    embs_a = _embed_all_layers(stmt_a, LAYERS_TO_TEST)
    embs_b = _embed_all_layers(stmt_b, LAYERS_TO_TEST)
    layer_embeddings = {
        layer_idx: {'emb_a': embs_a[layer_idx], 'emb_b': embs_b[layer_idx]}
        for layer_idx in LAYERS_TO_TEST
    }

    pair_data.append({
        'pref': pref,
        'category': pair['category'],
        'layer_embeddings': layer_embeddings
    })

total_time = (datetime.now() - start_time).total_seconds() / 60
print(f'Done! Collected {len(pair_data)} pairs in {total_time:.1f} minutes.')

## 6. Category-Level UA Analysis (n=6)

In [None]:
print('=' * 80)
print(' CATEGORY-LEVEL UA ANALYSIS (n=6)')
print('=' * 80)

# Group the collected pairs by category once; the grouping is layer-independent.
pairs_by_category = {
    cat_name: [p for p in pair_data if p['category'] == cat_name]
    for cat_name in DATASET
}

category_results = {}

for layer_idx in LAYERS_TO_TEST:
    category_uas = []
    category_prefs = []

    for members in pairs_by_category.values():
        side_a = np.array([p['layer_embeddings'][layer_idx]['emb_a'] for p in members])
        side_b = np.array([p['layer_embeddings'][layer_idx]['emb_b'] for p in members])
        member_prefs = np.array([p['pref'] for p in members])

        # UA for a category: uniformity of its A statements minus that of its B statements.
        category_uas.append(uniformity_score(side_a) - uniformity_score(side_b))
        category_prefs.append(np.mean(member_prefs))

    # Correlate per-category UA with per-category mean output preference.
    r, p = stats.pearsonr(category_uas, category_prefs)
    category_results[layer_idx] = {'r': float(r), 'p_value': float(p)}

    sig = '*' if p < 0.05 else ''
    print(f'Layer {layer_idx:2d}: r = {r:+.3f} (p = {p:.4f}) {sig}')

## 7. Pair-Level Metrics Analysis (n=230)

In [None]:
def compute_pair_metrics(pair_data, layer_idx):
    '''Compute per-pair embedding metrics at one layer.

    Args:
        pair_data: Sequence of records holding
            ``record['layer_embeddings'][layer_idx]['emb_a' | 'emb_b']``.
        layer_idx: Layer whose embeddings are used.

    Returns:
        Dict of arrays with one value per pair, in this key order:
        ``centroid_asymmetry`` — cos(a, centroid_A) - cos(b, centroid_B);
        ``cross_centroid``     — cos(a, centroid_B) - cos(b, centroid_A);
        ``within_pair_sim``    — cos(a, b).
        Centroids are the means of all A (resp. B) embeddings at this layer.
    '''
    side_a = np.array([rec['layer_embeddings'][layer_idx]['emb_a'] for rec in pair_data])
    side_b = np.array([rec['layer_embeddings'][layer_idx]['emb_b'] for rec in pair_data])

    mean_a = side_a.mean(axis=0)
    mean_b = side_b.mean(axis=0)

    def sims_to(rows, target):
        # Cosine similarity of every row to a single target vector.
        return np.array([cosine_similarity(row, target) for row in rows])

    return {
        'centroid_asymmetry': sims_to(side_a, mean_a) - sims_to(side_b, mean_b),
        'cross_centroid': sims_to(side_a, mean_b) - sims_to(side_b, mean_a),
        'within_pair_sim': np.array(
            [cosine_similarity(vec_a, vec_b) for vec_a, vec_b in zip(side_a, side_b)]
        ),
    }

In [None]:
# Output preference for every pair, in the same order as pair_data.
all_prefs = np.array([p['pref'] for p in pair_data])

print('=' * 80)
print(f' PAIR-LEVEL ANALYSIS (n={len(pair_data)})')
print('=' * 80)

pair_results = {}

for layer_idx in LAYERS_TO_TEST:
    print(f'\n--- Layer {layer_idx} ---')

    layer_results = {}
    for metric_name, metric_values in compute_pair_metrics(pair_data, layer_idx).items():
        r, ci_lower, ci_upper, p = bootstrap_correlation(metric_values, all_prefs, N_BOOTSTRAP, CI_LEVEL)

        # A correlation is flagged when its bootstrap CI excludes zero.
        includes_zero = ci_lower <= 0 <= ci_upper
        if includes_zero:
            sig = ''
        else:
            sig = '***'

        layer_results[metric_name] = {
            'r': r,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'p_value': p,
            'includes_zero': includes_zero,
        }

        print(f'  {metric_name:<20} r={r:+.3f}  CI=[{ci_lower:+.3f}, {ci_upper:+.3f}] {sig}')

    pair_results[layer_idx] = layer_results

## 8. Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Category-level UA correlation per layer
ax1 = axes[0, 0]
cat_rs = [category_results[layer]['r'] for layer in LAYERS_TO_TEST]
ax1.plot(LAYERS_TO_TEST, cat_rs, 'o-', linewidth=2, markersize=8, color='blue')
ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax1.set_xlabel('Layer')
ax1.set_ylabel('r(UA, Output)')
# Sample size comes from the data rather than a hard-coded 6.
ax1.set_title(f'Category-Level UA (n={len(DATASET)})', fontweight='bold')
ax1.grid(True, alpha=0.3)

# Plots 2-4: Pair-level metrics with bootstrap confidence intervals
metric_names = ['centroid_asymmetry', 'cross_centroid', 'within_pair_sim']
plot_axes = [axes[0, 1], axes[1, 0], axes[1, 1]]

for ax, metric_name in zip(plot_axes, metric_names):
    rs = [pair_results[layer][metric_name]['r'] for layer in LAYERS_TO_TEST]
    ci_lowers = [pair_results[layer][metric_name]['ci_lower'] for layer in LAYERS_TO_TEST]
    ci_uppers = [pair_results[layer][metric_name]['ci_upper'] for layer in LAYERS_TO_TEST]

    # Error-bar lengths must be non-negative. A percentile-bootstrap CI is not
    # guaranteed to contain the observed r, so clamp at zero.
    yerr_lower = [max(r - ci_l, 0.0) for r, ci_l in zip(rs, ci_lowers)]
    yerr_upper = [max(ci_u - r, 0.0) for r, ci_u in zip(rs, ci_uppers)]

    ax.errorbar(LAYERS_TO_TEST, rs, yerr=[yerr_lower, yerr_upper],
                fmt='o-', capsize=5, capthick=2, markersize=8, color='blue', alpha=0.7)

    # Red star on layers whose CI excludes zero.
    for layer, r in zip(LAYERS_TO_TEST, rs):
        if not pair_results[layer][metric_name]['includes_zero']:
            ax.scatter([layer], [r], color='red', s=150, zorder=5, marker='*')

    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Layer')
    ax.set_ylabel('r')
    # Sample size comes from the data rather than a hard-coded 230.
    ax.set_title(f'{metric_name} (n={len(pair_data)})', fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_DISPLAY}: Layer-wise Correlation Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('llama3_layer_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print('Plot saved to: llama3_layer_analysis.png')

## 9. Comparison with Pythia

In [None]:
# Pythia-6.9B reference values
PYTHIA_CENTROID = {
    0: +0.46, 4: +0.41, 8: +0.46, 12: +0.41, 16: +0.35,
    20: +0.32, 24: +0.21, 28: +0.14, 32: -0.34
}

print('=' * 80)
print(' CROSS-MODEL COMPARISON: Llama-3-8B vs Pythia-6.9B')
print('=' * 80)

print(f'\nLayer    Pythia-6.9B    Llama-3-8B    Same Sign?')
print('-' * 50)

same_sign_count = 0
for layer_idx in LAYERS_TO_TEST:
    pythia_r = PYTHIA_CENTROID.get(layer_idx, None)
    llama_r = pair_results[layer_idx]['centroid_asymmetry']['r']
    
    if pythia_r is not None:
        same_sign = (pythia_r * llama_r) > 0
        if same_sign:
            same_sign_count += 1
        sign_str = 'YES' if same_sign else 'NO'
        print(f'Layer {layer_idx:<3} {pythia_r:+.3f}          {llama_r:+.3f}          {sign_str}')

print(f'\nSame sign: {same_sign_count}/{len(PYTHIA_CENTROID)} layers')

# Phase structure analysis
print('\n--- Phase Structure Analysis ---')
early_layers = [0, 4, 8]
mid_layers = [12, 16, 20]
late_layers = [24, 28, 32]

early_mean = np.mean([pair_results[l]['centroid_asymmetry']['r'] for l in early_layers])
mid_mean = np.mean([pair_results[l]['centroid_asymmetry']['r'] for l in mid_layers])
late_mean = np.mean([pair_results[l]['centroid_asymmetry']['r'] for l in late_layers])

print(f'Early layers (0-8):   mean r = {early_mean:+.3f}')
print(f'Mid layers (12-20):   mean r = {mid_mean:+.3f}')
print(f'Late layers (24-32):  mean r = {late_mean:+.3f}')

# Determine pattern
if early_mean > 0 and mid_mean > 0 and late_mean < 0:
    pattern = 'MATCHES PYTHIA: positive early/mid, negative late'
elif early_mean > 0 and late_mean < 0:
    pattern = 'PARTIAL MATCH: positive early, negative late'
elif late_mean < 0:
    pattern = 'PARTIAL: late inversion confirmed'
else:
    pattern = 'DIFFERENT PATTERN'

print(f'\n>>> {pattern} <<<')

## 10. Save Results

In [None]:
# Take one timestamp so the value stored in the JSON and the one in the
# file name refer to the same instant.
run_time = datetime.now()

save_data = {
    'timestamp': run_time.isoformat(),
    'model': MODEL_NAME,
    'model_display': MODEL_DISPLAY,
    'n_layers': N_LAYERS,
    'layers_tested': LAYERS_TO_TEST,
    'n_pairs': len(pair_data),
    'n_bootstrap': N_BOOTSTRAP,
    # JSON object keys must be strings, so layer indices are stringified.
    'category_level_results': {str(k): v for k, v in category_results.items()},
    'pair_level_results': {
        str(k): {m: dict(v) for m, v in layer_res.items()}
        for k, layer_res in pair_results.items()
    },
    'phase_structure': {
        'early_mean': float(early_mean),
        'mid_mean': float(mid_mean),
        'late_mean': float(late_mean),
        'pattern': pattern
    },
    'pythia_comparison': {
        'same_sign_layers': same_sign_count,
        'total_compared': len(PYTHIA_CENTROID)
    }
}

output_file = f'llama3_cross_validation_{run_time.strftime("%Y%m%d_%H%M%S")}.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)

print(f'Results saved to: {output_file}')

# Browser download only exists in Colab. Guard the import so the cell does not
# fail elsewhere after the results were already written to disk.
try:
    from google.colab import files
except ImportError:
    print('google.colab not available - skipping browser download.')
else:
    files.download(output_file)
    files.download('llama3_layer_analysis.png')

## 11. Summary

In [None]:
# Final banner plus a plain-text recap of the values computed in earlier cells.
banner = '#' * 80
print(banner)
print(f'# CROSS-MODEL VALIDATION SUMMARY: {MODEL_DISPLAY}')
print(banner)

# The empty first and last entries reproduce the leading and trailing blank
# lines of the report.
summary_lines = [
    '',
    f'Pattern: {pattern}',
    '',
    'Phase Structure (centroid_asymmetry):',
    f'  Early (0-8):   {early_mean:+.3f}',
    f'  Mid (12-20):   {mid_mean:+.3f}',
    f'  Late (24-32):  {late_mean:+.3f}',
    '',
    f'Pythia Comparison: {same_sign_count}/{len(PYTHIA_CENTROID)} layers same sign',
    '',
    f'Generated: {datetime.now().isoformat()}',
    '',
]
print('\n'.join(summary_lines))