# E11-T: GQA vs MHA - Territorial Collapse Comparison

**Paper 4: Behavioral Sink Dynamics**

## Motivation

E11 found **NO territorial collapse** in Mistral (MHA architecture):
- Specialization Index: Base 0.7492 → Instruct 0.7806 (+4.2% **INCREASED**)
- Head Correlation: Base 0.2508 → Instruct 0.2194 (-12.5% **DECREASED**)
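
The percentages above are relative changes measured against the Base value. A minimal check of that arithmetic, using only the four numbers quoted above (the helper name `relative_change` is illustrative and not part of the E11 code):

```python
def relative_change(base: float, new: float) -> float:
    """Return the change from base to new as a percentage of base."""
    return (new - base) / base * 100.0


# Values quoted from the E11 Mistral run.
si_change = relative_change(0.7492, 0.7806)
corr_change = relative_change(0.2508, 0.2194)

print(f"Specialization Index: {si_change:+.1f}%")  # +4.2%
print(f"Head Correlation:     {corr_change:+.1f}%")  # -12.5%
```

The two absolute deltas have the same magnitude (0.0314) because this notebook defines the Specialization Index as 1 minus the mean head correlation.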

**E11-T tests whether GQA (Grouped Query Attention) changes this pattern.**

## GQA Architecture

| Architecture | Query Heads | KV Heads | Ratio |
|--------------|-------------|----------|-------|
| MHA (Mistral) | 32 | 32 | 1:1 |
| GQA (LLaMA-3.1) | 32 | 8 | 4:1 |

GQA forces **KV sharing** between query heads:
- 4 query heads share 1 KV head
- This built-in grouping could push the heads within a group toward uniform behavior
- Hypothesis: GQA might show more "territorial collapse" than MHA because of this structural constraint
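
To make the grouping concrete, here is a minimal sketch of the query-to-KV head mapping. It assumes contiguous groups (query heads 0-3 use KV head 0, and so on), which is the same layout the group metrics in Cell 3 assume. The function name `kv_head_for_query` is hypothetical.

```python
NUM_QUERY_HEADS = 32
NUM_KV_HEADS = 8


def kv_head_for_query(q_head: int,
                      num_query_heads: int = NUM_QUERY_HEADS,
                      num_kv_heads: int = NUM_KV_HEADS) -> int:
    """Map a query head index to the KV head it reads from (contiguous groups assumed)."""
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    if not 0 <= q_head < num_query_heads:
        raise IndexError(f"query head {q_head} out of range")
    group_size = num_query_heads // num_kv_heads  # 4 for a 32:8 layout
    return q_head // group_size


# Collect the query heads that share each KV head.
groups = {}
for q in range(NUM_QUERY_HEADS):
    groups.setdefault(kv_head_for_query(q), []).append(q)

print(groups[0])  # [0, 1, 2, 3]
print(groups[7])  # [28, 29, 30, 31]
```

With a 1:1 layout (`num_kv_heads == num_query_heads`) the same function returns the query index unchanged, which is the MHA case.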

## Comparison Points

1. Does GQA show a decrease in Specialization Index (unlike MHA)?
2. Is Head Correlation higher in GQA due to KV sharing?
3. Do the 8 KV heads show different specialization patterns?
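
Both metrics in the first two questions come from one quantity: the mean pairwise Pearson correlation between per-head entropy profiles across layers. The sketch below is a toy illustration using the same (layers × heads) layout and `np.corrcoef` call as `compute_specialization_metrics` in Cell 3. The input arrays are made up for illustration and are not measured data.

```python
import numpy as np


def specialization_index(head_entropies: np.ndarray) -> float:
    """1 - mean pairwise correlation between head profiles.

    head_entropies has shape (num_layers, num_heads).
    """
    num_heads = head_entropies.shape[1]
    corr = np.corrcoef(head_entropies.T)  # (num_heads, num_heads)
    upper = corr[np.triu_indices(num_heads, k=1)]  # each head pair once
    return 1.0 - float(np.nanmean(upper))


layers = np.arange(6, dtype=float)

# Three heads whose entropy rises with depth: perfectly correlated profiles.
redundant = np.stack(
    [0.10 * layers + 0.2, 0.10 * layers + 0.3, 0.05 * layers + 0.5], axis=1
)

# Two heads with opposite trends: perfectly anti-correlated profiles.
opposed = np.stack([0.10 * layers + 0.2, -0.10 * layers + 0.8], axis=1)

print(f"redundant heads: {abs(specialization_index(redundant)):.4f}")
print(f"opposed heads:   {specialization_index(opposed):.4f}")
```

The redundant heads score approximately 0 and the opposed heads approximately 2. Because the index is 1 minus a correlation, it can range from 0 to 2, and a value near 1 corresponds to uncorrelated heads.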

---

In [None]:
# Cell 1: Setup
!pip install -q transformers torch accelerate bitsandbytes scipy matplotlib seaborn

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from scipy.stats import entropy as scipy_entropy
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist, squareform
import json
import hashlib
import warnings
warnings.filterwarnings('ignore')

import os
from pathlib import Path
from datetime import datetime

# E11-v3 STANDARD: 3-Seed Reproducibility
SEEDS = [42, 123, 456]
os.environ['PYTHONHASHSEED'] = '42'

TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
Path('results').mkdir(parents=True, exist_ok=True)
Path('figures').mkdir(parents=True, exist_ok=True)
print(f"Timestamp: {TIMESTAMP}")
print(f"E11-v3 Standard: 3-seed averaging with {SEEDS}")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# HF Login for gated models (LLaMA) - REQUIRED!
try:
    from google.colab import userdata
    from huggingface_hub import login
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("HF Login: SUCCESS (required for LLaMA)")
    else:
        print("WARNING: No HF_TOKEN found! LLaMA requires authentication.")
        print("Go to: Runtime → Secrets → Add HF_TOKEN")
except Exception:  # not in Colab, or the secret could not be read
    print("Not in Colab - ensure HF_TOKEN is set via huggingface-cli login")

In [None]:
# Cell 2: Configuration

# =============================================================================
# E11-v3 METHODOLOGY STANDARD
# =============================================================================

# LLaMA-3.1 GQA Configuration
GQA_CONFIG = {
    'base': 'meta-llama/Llama-3.1-8B',
    'instruct': 'meta-llama/Llama-3.1-8B-Instruct',
    'params': '8B',
    'num_query_heads': 32,
    'num_kv_heads': 8,       # GQA: 4:1 ratio
    'd_head': 128,
    'num_layers': 32,
    'architecture': 'GQA'
}

# E11 Mistral MHA Reference (from previous run)
E11_MHA_REFERENCE = {
    'model': 'Mistral-7B-v0.3',
    'architecture': 'MHA',
    'base_specialization': 0.7492,
    'instruct_specialization': 0.7806,
    'delta_specialization': +0.0314,
    'base_correlation': 0.2508,
    'instruct_correlation': 0.2194,
    'delta_correlation': -0.0314,
    'verdict': 'C_REFUTED'
}

# E11-v3 Standard Parameters
MAX_LENGTH = 128
DTYPE = torch.bfloat16  # E11-v3: bfloat16 (NOT float16!)
EXPECTED_MD5 = "715065bab181f46bf12ed471951141e2"

# Standard-10 v3 Prompt Set (CANONICAL - DO NOT MODIFY!)
STANDARD_PROMPTS = [
    "What is the capital of France and what is its population?",
    "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Explain step by step.",
    "Calculate 47 multiplied by 23 and show your work.",
    "Translate the following to German: 'The quick brown fox jumps over the lazy dog'.",
    "Write a Python function that checks if a number is prime.",
    "Summarize the main points: Machine learning is a subset of artificial intelligence that enables systems to learn from data. It uses algorithms to identify patterns and make decisions with minimal human intervention.",
    "Statement A: 'All birds can fly.' Statement B: 'Penguins are birds that cannot fly.' Are these statements contradictory? Explain.",
    "What are the safety considerations when using a kitchen knife?",
    "Write a haiku about artificial intelligence.",
    "Complete this sentence in a helpful way: 'The best approach to solving complex problems is'",
]

# Verify prompts haven't been modified
def verify_prompts():
    prompt_string = '|||'.join(STANDARD_PROMPTS)
    actual_md5 = hashlib.md5(prompt_string.encode()).hexdigest()
    return actual_md5, actual_md5 == EXPECTED_MD5

actual_md5, md5_valid = verify_prompts()
if not md5_valid:
    raise ValueError(f"PROMPT INTEGRITY ERROR! Expected {EXPECTED_MD5}, got {actual_md5}")

print(f"E11-T: GQA vs MHA Comparison (E11-v3 Standard)")
print(f"\n=== METHODOLOGY ===")
print(f"Seeds: {SEEDS}")
print(f"MAX_LENGTH: {MAX_LENGTH}")
print(f"dtype: {DTYPE}")
print(f"Prompt MD5: {actual_md5} ({'✅ VALID' if md5_valid else '❌ INVALID'})")
print(f"\n=== GQA MODEL ===")
print(f"Model: {GQA_CONFIG['base']}")
print(f"Query Heads: {GQA_CONFIG['num_query_heads']}, KV Heads: {GQA_CONFIG['num_kv_heads']}")
print(f"GQA Ratio: {GQA_CONFIG['num_query_heads'] // GQA_CONFIG['num_kv_heads']}:1")
print(f"\n=== REFERENCE (MHA) ===")
print(f"Model: {E11_MHA_REFERENCE['model']}")
print(f"Verdict: {E11_MHA_REFERENCE['verdict']}")

In [None]:
# Cell 3: Head Specialization Metrics (same as E11)

def extract_head_activations(model, tokenizer, prompts, max_length=128):
    """
    Extract per-head activation patterns across prompts.
    
    For GQA models:
    - attention output is (batch, num_query_heads, seq, seq)
    - but internally KV are shared (4 query heads : 1 KV head)
    - We measure query head behavior (what model actually computes)
    """
    all_attention_patterns = []
    
    for prompt in prompts:
        inputs = tokenizer(
            prompt, 
            return_tensors='pt',
            max_length=max_length,
            truncation=True,
            padding='max_length'
        ).to(model.device)
        
        with torch.no_grad():
            # Hidden states are not used below, so only attention weights are requested
            outputs = model(**inputs, output_attentions=True)
        
        # Stack attention patterns: (num_layers, num_heads, seq, seq)
        attn_stack = torch.stack([a.squeeze(0) for a in outputs.attentions], dim=0)
        all_attention_patterns.append(attn_stack.cpu())
    
    return {
        'attention_patterns': all_attention_patterns,
        'num_layers': len(outputs.attentions),
        'num_heads': outputs.attentions[0].shape[1]  # Query heads
    }


def compute_head_entropy_profiles(attention_patterns):
    """Compute normalized entropy for each head across prompts."""
    num_prompts = len(attention_patterns)
    num_layers = attention_patterns[0].shape[0]
    num_heads = attention_patterns[0].shape[1]
    
    all_entropies = np.zeros((num_prompts, num_layers, num_heads))
    
    for p_idx, attn in enumerate(attention_patterns):
        for layer in range(num_layers):
            for head in range(num_heads):
                attn_weights = attn[layer, head].mean(dim=0).float().cpu().numpy()
                attn_weights = attn_weights / attn_weights.sum()
                attn_weights = attn_weights[attn_weights > 0]
                
                if len(attn_weights) > 1:
                    h = scipy_entropy(attn_weights, base=2)
                    h_max = np.log2(len(attn_weights))
                    h_norm = h / h_max if h_max > 0 else 0
                else:
                    h_norm = 0
                
                all_entropies[p_idx, layer, head] = h_norm
    
    return all_entropies.mean(axis=0)


def compute_specialization_metrics(head_entropies):
    """Compute metrics for territorial collapse / specialization loss."""
    num_layers, num_heads = head_entropies.shape
    
    # 1. Head Variance per Layer
    layer_variances = np.var(head_entropies, axis=1)
    mean_variance = float(np.mean(layer_variances))
    
    # 2. Inter-Head Correlation
    head_profiles = head_entropies.T  # (num_heads, num_layers)
    head_corr_matrix = np.corrcoef(head_profiles)
    upper_tri = head_corr_matrix[np.triu_indices(num_heads, k=1)]
    mean_head_correlation = float(np.nanmean(upper_tri))
    
    # 3. Specialization Index = 1 - mean_correlation
    specialization_index = 1.0 - mean_head_correlation
    
    # 4. Effective Number of Heads (exponentiated Shannon entropy, 2**H, of each head's share of mean entropy)
    head_contributions = np.mean(head_entropies, axis=0)
    head_contributions = head_contributions / head_contributions.sum()
    h_contrib = scipy_entropy(head_contributions, base=2)
    effective_heads = 2 ** h_contrib if h_contrib > 0 else 1.0
    effective_ratio = effective_heads / num_heads
    
    # 5. Layer-wise specialization
    third = num_layers // 3
    early_var = float(np.mean(layer_variances[:third]))
    middle_var = float(np.mean(layer_variances[third:2*third]))
    late_var = float(np.mean(layer_variances[2*third:]))
    
    return {
        'mean_head_variance': mean_variance,
        'mean_head_correlation': mean_head_correlation,
        'specialization_index': specialization_index,
        'effective_heads': float(effective_heads),
        'effective_ratio': float(effective_ratio),
        'layer_variances': layer_variances.tolist(),
        'early_variance': early_var,
        'middle_variance': middle_var,
        'late_variance': late_var,
        'head_correlation_matrix': head_corr_matrix.tolist(),
        'num_layers': num_layers,
        'num_heads': num_heads
    }


def compute_gqa_group_metrics(head_entropies, num_kv_heads=8):
    """
    GQA-specific: Compute specialization within KV groups.
    
    In LLaMA-3.1 GQA (4:1 ratio):
    - Query heads 0-3 share KV head 0
    - Query heads 4-7 share KV head 1
    - etc.
    
    This measures: Do query heads within the same KV group specialize differently?
    """
    num_layers, num_query_heads = head_entropies.shape
    group_size = num_query_heads // num_kv_heads  # 4 for LLaMA-3.1
    
    # Within-group variance (heads sharing same KV)
    within_group_vars = []
    for g in range(num_kv_heads):
        start = g * group_size
        end = start + group_size
        group_entropies = head_entropies[:, start:end]  # (num_layers, group_size)
        group_var = np.var(group_entropies, axis=1).mean()  # Mean variance across layers
        within_group_vars.append(group_var)
    
    mean_within_group_var = float(np.mean(within_group_vars))
    
    # Between-group variance (comparing KV group means)
    group_means = []
    for g in range(num_kv_heads):
        start = g * group_size
        end = start + group_size
        group_mean = head_entropies[:, start:end].mean(axis=1)  # (num_layers,)
        group_means.append(group_mean)
    
    group_means = np.array(group_means).T  # (num_layers, num_kv_heads)
    between_group_vars = np.var(group_means, axis=1)  # (num_layers,)
    mean_between_group_var = float(np.mean(between_group_vars))
    
    # Ratio: Within-group / Between-group
    # Low ratio = groups are internally homogeneous but externally different (good specialization)
    # High ratio = groups are internally heterogeneous (KV sharing doesn't constrain)
    var_ratio = mean_within_group_var / mean_between_group_var if mean_between_group_var > 0 else np.inf
    
    return {
        'within_group_variance': mean_within_group_var,
        'between_group_variance': mean_between_group_var,
        'within_between_ratio': float(var_ratio),
        'group_size': group_size,
        'num_kv_groups': num_kv_heads,
        'per_group_variances': within_group_vars
    }

print("Specialization metrics functions loaded (with GQA extensions).")

In [None]:
# Cell 4: Load and Analyze BASE Model (LLaMA-3.1-8B) with 3-Seed Averaging

results = {'pair': 'llama31_gqa', 'base': {}, 'instruct': {}, 'config': GQA_CONFIG}
seed_results_base = []

print(f"\n{'='*60}")
print(f"E11-T: GQA TERRITORIAL COLLAPSE - LLAMA-3.1-8B (E11-v3)")
print(f"{'='*60}")

print(f"\n[1/4] Loading BASE: {GQA_CONFIG['base']}")

tokenizer_base = AutoTokenizer.from_pretrained(GQA_CONFIG['base'])
model_base = AutoModelForCausalLM.from_pretrained(
    GQA_CONFIG['base'],
    torch_dtype=DTYPE,  # E11-v3: bfloat16
    device_map='auto',
    trust_remote_code=True,
    attn_implementation="eager"  # CRITICAL: SDPA doesn't return attentions!
)

# CRITICAL: Set eval mode to disable dropout
model_base.eval()

if tokenizer_base.pad_token is None:
    tokenizer_base.pad_token = tokenizer_base.eos_token

print(f"\n[2/4] Extracting BASE head activations (3-seed average)...")

for seed in SEEDS:
    print(f"\n  --- Seed {seed} ---")
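    # The forward passes below run in eval mode without sampling, so the seeds
    # are not expected to change the extracted attentions; the loop mainly
    # confirms run-to-run consistency.
    torch.manual_seed(seed)
    np.random.seed(seed)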
    
    base_activations = extract_head_activations(model_base, tokenizer_base, STANDARD_PROMPTS, max_length=MAX_LENGTH)
    base_entropies = compute_head_entropy_profiles(base_activations['attention_patterns'])
    spec_metrics = compute_specialization_metrics(base_entropies)
    gqa_metrics = compute_gqa_group_metrics(base_entropies, num_kv_heads=GQA_CONFIG['num_kv_heads'])
    
    seed_results_base.append({
        'seed': seed,
        'specialization': spec_metrics,
        'gqa_metrics': gqa_metrics,
        'entropies': base_entropies.tolist()
    })
    print(f"  Specialization Index: {spec_metrics['specialization_index']:.4f}")
    print(f"  Mean Head Correlation: {spec_metrics['mean_head_correlation']:.4f}")

# Aggregate across seeds (mean)
print(f"\n  Computing 3-seed average...")
avg_si = np.mean([r['specialization']['specialization_index'] for r in seed_results_base])
avg_corr = np.mean([r['specialization']['mean_head_correlation'] for r in seed_results_base])
avg_var = np.mean([r['specialization']['mean_head_variance'] for r in seed_results_base])
avg_gqa_ratio = np.mean([r['gqa_metrics']['within_between_ratio'] for r in seed_results_base])

# Store aggregated results
results['base']['specialization'] = {
    'specialization_index': float(avg_si),
    'mean_head_correlation': float(avg_corr),
    'mean_head_variance': float(avg_var),
    'effective_heads': float(np.mean([r['specialization']['effective_heads'] for r in seed_results_base])),
    'effective_ratio': float(np.mean([r['specialization']['effective_ratio'] for r in seed_results_base])),
    'layer_variances': seed_results_base[0]['specialization']['layer_variances'],  # Use first seed for structure
    'early_variance': float(np.mean([r['specialization']['early_variance'] for r in seed_results_base])),
    'middle_variance': float(np.mean([r['specialization']['middle_variance'] for r in seed_results_base])),
    'late_variance': float(np.mean([r['specialization']['late_variance'] for r in seed_results_base])),
    'head_correlation_matrix': seed_results_base[0]['specialization']['head_correlation_matrix'],
    'num_layers': seed_results_base[0]['specialization']['num_layers'],
    'num_heads': seed_results_base[0]['specialization']['num_heads']
}
results['base']['gqa_metrics'] = {
    'within_group_variance': float(np.mean([r['gqa_metrics']['within_group_variance'] for r in seed_results_base])),
    'between_group_variance': float(np.mean([r['gqa_metrics']['between_group_variance'] for r in seed_results_base])),
    'within_between_ratio': float(avg_gqa_ratio),
    'group_size': seed_results_base[0]['gqa_metrics']['group_size'],
    'num_kv_groups': seed_results_base[0]['gqa_metrics']['num_kv_groups']
}
results['base']['seed_results'] = seed_results_base

print(f"\n  === BASE AGGREGATED (3-seed) ===")
print(f"  Specialization Index: {avg_si:.4f}")
print(f"  Mean Head Correlation: {avg_corr:.4f}")
print(f"  GQA Within/Between Ratio: {avg_gqa_ratio:.4f}")

# Free memory
del model_base
torch.cuda.empty_cache()
print("\n  [Memory cleared]")

In [None]:
# Cell 5: Load and Analyze INSTRUCT Model (LLaMA-3.1-8B-Instruct) with 3-Seed Averaging

seed_results_inst = []

print(f"\n[3/4] Loading INSTRUCT: {GQA_CONFIG['instruct']}")

tokenizer_inst = AutoTokenizer.from_pretrained(GQA_CONFIG['instruct'])
model_inst = AutoModelForCausalLM.from_pretrained(
    GQA_CONFIG['instruct'],
    torch_dtype=DTYPE,  # E11-v3: bfloat16
    device_map='auto',
    trust_remote_code=True,
    attn_implementation="eager"  # CRITICAL: SDPA doesn't return attentions!
)

# CRITICAL: Set eval mode to disable dropout
model_inst.eval()

if tokenizer_inst.pad_token is None:
    tokenizer_inst.pad_token = tokenizer_inst.eos_token

print(f"\n[4/4] Extracting INSTRUCT head activations (3-seed average)...")

for seed in SEEDS:
    print(f"\n  --- Seed {seed} ---")
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    inst_activations = extract_head_activations(model_inst, tokenizer_inst, STANDARD_PROMPTS, max_length=MAX_LENGTH)
    inst_entropies = compute_head_entropy_profiles(inst_activations['attention_patterns'])
    spec_metrics = compute_specialization_metrics(inst_entropies)
    gqa_metrics = compute_gqa_group_metrics(inst_entropies, num_kv_heads=GQA_CONFIG['num_kv_heads'])
    
    seed_results_inst.append({
        'seed': seed,
        'specialization': spec_metrics,
        'gqa_metrics': gqa_metrics,
        'entropies': inst_entropies.tolist()
    })
    print(f"  Specialization Index: {spec_metrics['specialization_index']:.4f}")
    print(f"  Mean Head Correlation: {spec_metrics['mean_head_correlation']:.4f}")

# Aggregate across seeds (mean)
print(f"\n  Computing 3-seed average...")
avg_si_inst = np.mean([r['specialization']['specialization_index'] for r in seed_results_inst])
avg_corr_inst = np.mean([r['specialization']['mean_head_correlation'] for r in seed_results_inst])
avg_var_inst = np.mean([r['specialization']['mean_head_variance'] for r in seed_results_inst])
avg_gqa_ratio_inst = np.mean([r['gqa_metrics']['within_between_ratio'] for r in seed_results_inst])

# Store aggregated results
results['instruct']['specialization'] = {
    'specialization_index': float(avg_si_inst),
    'mean_head_correlation': float(avg_corr_inst),
    'mean_head_variance': float(avg_var_inst),
    'effective_heads': float(np.mean([r['specialization']['effective_heads'] for r in seed_results_inst])),
    'effective_ratio': float(np.mean([r['specialization']['effective_ratio'] for r in seed_results_inst])),
    'layer_variances': seed_results_inst[0]['specialization']['layer_variances'],
    'early_variance': float(np.mean([r['specialization']['early_variance'] for r in seed_results_inst])),
    'middle_variance': float(np.mean([r['specialization']['middle_variance'] for r in seed_results_inst])),
    'late_variance': float(np.mean([r['specialization']['late_variance'] for r in seed_results_inst])),
    'head_correlation_matrix': seed_results_inst[0]['specialization']['head_correlation_matrix'],
    'num_layers': seed_results_inst[0]['specialization']['num_layers'],
    'num_heads': seed_results_inst[0]['specialization']['num_heads']
}
results['instruct']['gqa_metrics'] = {
    'within_group_variance': float(np.mean([r['gqa_metrics']['within_group_variance'] for r in seed_results_inst])),
    'between_group_variance': float(np.mean([r['gqa_metrics']['between_group_variance'] for r in seed_results_inst])),
    'within_between_ratio': float(avg_gqa_ratio_inst),
    'group_size': seed_results_inst[0]['gqa_metrics']['group_size'],
    'num_kv_groups': seed_results_inst[0]['gqa_metrics']['num_kv_groups']
}
results['instruct']['seed_results'] = seed_results_inst

print(f"\n  === INSTRUCT AGGREGATED (3-seed) ===")
print(f"  Specialization Index: {avg_si_inst:.4f}")
print(f"  Mean Head Correlation: {avg_corr_inst:.4f}")
print(f"  GQA Within/Between Ratio: {avg_gqa_ratio_inst:.4f}")

# Free memory
del model_inst
torch.cuda.empty_cache()
print("\n  [Memory cleared]")

In [None]:
# Cell 6: Hypothesis Test - GQA Territorial Collapse

print(f"\n{'='*70}")
print(f"E11-T: GQA TERRITORIAL COLLAPSE RESULTS")
print(f"{'='*70}")

# Extract key metrics
base_spec = results['base']['specialization']
inst_spec = results['instruct']['specialization']
base_gqa = results['base']['gqa_metrics']
inst_gqa = results['instruct']['gqa_metrics']

# Core metrics
base_si = base_spec['specialization_index']
inst_si = inst_spec['specialization_index']
delta_si = inst_si - base_si

base_corr = base_spec['mean_head_correlation']
inst_corr = inst_spec['mean_head_correlation']
delta_corr = inst_corr - base_corr

base_var = base_spec['mean_head_variance']
inst_var = inst_spec['mean_head_variance']
delta_var = inst_var - base_var

# GQA-specific
base_gqa_ratio = base_gqa['within_between_ratio']
inst_gqa_ratio = inst_gqa['within_between_ratio']
delta_gqa_ratio = inst_gqa_ratio - base_gqa_ratio

print(f"\n{'Metric':<35} {'BASE':>12} {'INSTRUCT':>12} {'Delta':>12}")
print("-" * 75)
print(f"{'Specialization Index':<35} {base_si:>12.4f} {inst_si:>12.4f} {delta_si:>+12.4f}")
print(f"{'Mean Head Correlation':<35} {base_corr:>12.4f} {inst_corr:>12.4f} {delta_corr:>+12.4f}")
print(f"{'Mean Head Variance':<35} {base_var:>12.6f} {inst_var:>12.6f} {delta_var:>+12.6f}")
print(f"{'GQA Within/Between Ratio':<35} {base_gqa_ratio:>12.4f} {inst_gqa_ratio:>12.4f} {delta_gqa_ratio:>+12.4f}")

# Hypothesis Test
print(f"\n{'='*70}")
print("HYPOTHESIS TEST: Does RLHF cause TERRITORIAL COLLAPSE in GQA?")
print(f"{'='*70}")

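# Note: specialization_index = 1 - mean_head_correlation, so criteria [1] and [2]
# always agree; only [3] is an independent signal.
collapse_1 = delta_si < 0  # Specialization decreased
collapse_2 = delta_corr > 0  # Correlation increased
collapse_3 = delta_var < 0  # Variance decreased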

print(f"\n  [1] Specialization decreased:    {'YES' if collapse_1 else 'NO'} ({delta_si:+.4f})")
print(f"  [2] Head correlation increased:  {'YES' if collapse_2 else 'NO'} ({delta_corr:+.4f})")
print(f"  [3] Head variance decreased:     {'YES' if collapse_3 else 'NO'} ({delta_var:+.6f})")

collapse_count = sum([collapse_1, collapse_2, collapse_3])

print(f"\n{'='*70}")
if collapse_count >= 2:
    verdict = "A_CONFIRMED"
    print(f"VERDICT: {verdict}")
    print("RLHF causes TERRITORIAL COLLAPSE in GQA architecture!")
elif collapse_count == 1:
    verdict = "B_PARTIAL"
    print(f"VERDICT: {verdict}")
    print("Partial evidence for territorial collapse in GQA.")
else:
    verdict = "C_REFUTED"
    print(f"VERDICT: {verdict}")
    print("No evidence for territorial collapse - GQA also preserves specialization.")
print(f"{'='*70}")

# Store verdict
results['verdict'] = {
    'code': verdict,
    'specialization_decreased': collapse_1,
    'correlation_increased': collapse_2,
    'variance_decreased': collapse_3,
    'delta_specialization': delta_si,
    'delta_correlation': delta_corr,
    'delta_variance': delta_var,
    'delta_gqa_ratio': delta_gqa_ratio
}

In [None]:
# Cell 7: GQA vs MHA Comparison

print(f"\n{'='*70}")
print("GQA vs MHA ARCHITECTURE COMPARISON")
print(f"{'='*70}")

# Build comparison table
print(f"\n{'Metric':<30} {'MHA (Mistral)':>16} {'GQA (LLaMA)':>16} {'Same Dir?':>12}")
print("-" * 80)

# Specialization Index Delta
mha_delta_si = E11_MHA_REFERENCE['delta_specialization']
gqa_delta_si = delta_si
same_dir_si = (mha_delta_si > 0) == (gqa_delta_si > 0)
print(f"{'Δ Specialization Index':<30} {mha_delta_si:>+16.4f} {gqa_delta_si:>+16.4f} {'YES' if same_dir_si else 'NO':>12}")

# Correlation Delta  
mha_delta_corr = E11_MHA_REFERENCE['delta_correlation']
gqa_delta_corr = delta_corr
same_dir_corr = (mha_delta_corr > 0) == (gqa_delta_corr > 0)
print(f"{'Δ Head Correlation':<30} {mha_delta_corr:>+16.4f} {gqa_delta_corr:>+16.4f} {'YES' if same_dir_corr else 'NO':>12}")

# Verdicts
print(f"\n{'Verdict':<30} {E11_MHA_REFERENCE['verdict']:>16} {verdict:>16}")

# Interpretation
print(f"\n{'='*70}")
print("INTERPRETATION:")
print(f"{'='*70}")

if same_dir_si and same_dir_corr:
    print("\n  ARCHITECTURE-INVARIANT: Both MHA and GQA show same RLHF effect direction.")
    print("  → Territorial collapse (or lack thereof) is RLHF-intrinsic, not architecture-specific.")
else:
    print("\n  ARCHITECTURE-DEPENDENT: MHA and GQA show different RLHF effects.")
    print("  → GQA's KV sharing creates distinct specialization dynamics.")

# Store comparison
results['mha_comparison'] = {
    'mha_model': E11_MHA_REFERENCE['model'],
    'mha_verdict': E11_MHA_REFERENCE['verdict'],
    'mha_delta_specialization': mha_delta_si,
    'mha_delta_correlation': mha_delta_corr,
    'gqa_model': 'LLaMA-3.1-8B',
    'gqa_verdict': verdict,
    'gqa_delta_specialization': gqa_delta_si,
    'gqa_delta_correlation': gqa_delta_corr,
    'same_direction_specialization': same_dir_si,
    'same_direction_correlation': same_dir_corr,
    'architecture_invariant': same_dir_si and same_dir_corr
}

In [None]:
# Cell 8: Visualization

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Plot 1: GQA vs MHA Specialization Comparison
ax1 = axes[0, 0]
archs = ['MHA (Mistral)', 'GQA (LLaMA)']
base_vals = [E11_MHA_REFERENCE['base_specialization'], base_si]
inst_vals = [E11_MHA_REFERENCE['instruct_specialization'], inst_si]

x = np.arange(len(archs))
width = 0.35
bars1 = ax1.bar(x - width/2, base_vals, width, label='Base', color='#2ecc71', alpha=0.8)
bars2 = ax1.bar(x + width/2, inst_vals, width, label='Instruct', color='#e74c3c', alpha=0.8)
ax1.set_ylabel('Specialization Index')
ax1.set_title('GQA vs MHA: Specialization Index\n(Higher = More Unique Roles)')
ax1.set_xticks(x)
ax1.set_xticklabels(archs)
ax1.legend()
ax1.set_ylim(0, 1)
# Annotate deltas
for i, (b, inst) in enumerate(zip(base_vals, inst_vals)):
    delta = inst - b
    ax1.annotate(f'Δ={delta:+.4f}', xy=(i, max(b, inst) + 0.03), ha='center', fontsize=10,
                 color='green' if delta > 0 else 'red', fontweight='bold')

# Plot 2: GQA Specialization (Base vs Instruct)
ax2 = axes[0, 1]
models = ['Base', 'Instruct']
si_vals = [base_si, inst_si]
colors = ['#2ecc71', '#e74c3c']
bars = ax2.bar(models, si_vals, color=colors, alpha=0.8, edgecolor='black')
ax2.set_ylabel('Specialization Index')
ax2.set_title(f'LLaMA-3.1-8B (GQA): Specialization\nΔ = {delta_si:+.4f}')
ax2.set_ylim(0, 1)
for bar, val in zip(bars, si_vals):
    ax2.annotate(f'{val:.4f}', xy=(bar.get_x() + bar.get_width()/2, val),
                 xytext=(0, 5), textcoords='offset points', ha='center', fontsize=12)

# Plot 3: GQA Group Analysis (Within vs Between)
ax3 = axes[0, 2]
metrics = ['Within-Group\nVariance', 'Between-Group\nVariance', 'W/B Ratio']
base_gqa_vals = [base_gqa['within_group_variance'] * 1000,  # Scale for visibility
                 base_gqa['between_group_variance'] * 1000,
                 base_gqa['within_between_ratio']]
inst_gqa_vals = [inst_gqa['within_group_variance'] * 1000,
                 inst_gqa['between_group_variance'] * 1000,
                 inst_gqa['within_between_ratio']]

x = np.arange(len(metrics))
bars1 = ax3.bar(x - width/2, base_gqa_vals, width, label='Base', color='#2ecc71', alpha=0.8)
bars2 = ax3.bar(x + width/2, inst_gqa_vals, width, label='Instruct', color='#e74c3c', alpha=0.8)
ax3.set_ylabel('Value (×1000 for variance)')
ax3.set_title('GQA Group Analysis\n(KV Sharing Effect)')
ax3.set_xticks(x)
ax3.set_xticklabels(metrics)
ax3.legend()

# Plot 4: Layer-wise Variance Comparison
ax4 = axes[1, 0]
base_layer_var = base_spec['layer_variances']
inst_layer_var = inst_spec['layer_variances']
layers = range(len(base_layer_var))
ax4.plot(layers, base_layer_var, 'o-', color='#2ecc71', label='Base', linewidth=2, markersize=3)
ax4.plot(layers, inst_layer_var, 's-', color='#e74c3c', label='Instruct', linewidth=2, markersize=3)
ax4.set_xlabel('Layer')
ax4.set_ylabel('Head Variance')
ax4.set_title('LLaMA-3.1 (GQA): Layer-wise Head Variance')
ax4.grid(True, alpha=0.3)
# Mark L* region (middle third of layers) before building the legend,
# so its label is included
num_layers = len(base_layer_var)
third = num_layers // 3
ax4.axvspan(third, 2*third, alpha=0.2, color='yellow', label='L* Region')
ax4.legend()

# Plot 5: Head Correlation Heatmap (Base)
ax5 = axes[1, 1]
base_corr_matrix = np.array(base_spec['head_correlation_matrix'])
sns.heatmap(base_corr_matrix, cmap='RdBu_r', center=0, vmin=-1, vmax=1,
            ax=ax5, cbar_kws={'label': 'Correlation'})
ax5.set_title('GQA BASE: Head Correlation Matrix')
ax5.set_xlabel('Head')
ax5.set_ylabel('Head')
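# Mark GQA groups (group size derived from the config rather than hardcoded)
group_size = GQA_CONFIG['num_query_heads'] // GQA_CONFIG['num_kv_heads']
for g in range(GQA_CONFIG['num_kv_heads']):
    start = g * group_size
    ax5.axhline(y=start, color='black', linewidth=0.5, alpha=0.5)
    ax5.axvline(x=start, color='black', linewidth=0.5, alpha=0.5)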

# Plot 6: Head Correlation Heatmap (Instruct)
ax6 = axes[1, 2]
inst_corr_matrix = np.array(inst_spec['head_correlation_matrix'])
sns.heatmap(inst_corr_matrix, cmap='RdBu_r', center=0, vmin=-1, vmax=1,
            ax=ax6, cbar_kws={'label': 'Correlation'})
ax6.set_title('GQA INSTRUCT: Head Correlation Matrix')
ax6.set_xlabel('Head')
ax6.set_ylabel('Head')
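# Mark GQA groups (group size derived from the config rather than hardcoded)
group_size = GQA_CONFIG['num_query_heads'] // GQA_CONFIG['num_kv_heads']
for g in range(GQA_CONFIG['num_kv_heads']):
    start = g * group_size
    ax6.axhline(y=start, color='black', linewidth=0.5, alpha=0.5)
    ax6.axvline(x=start, color='black', linewidth=0.5, alpha=0.5)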

plt.tight_layout()
fig_path = f'figures/E11T_gqa_comparison_{TIMESTAMP}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: {fig_path}")

In [None]:
# Cell 9: Save Results with E11-v3 Methodology Block

def convert_to_native(obj):
    """Recursively convert numpy types to native Python types."""
    if isinstance(obj, dict):
        return {k: convert_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_native(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_to_native(v) for v in obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)  # keep booleans as true/false in JSON, not 0/1
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj

filename = f'results/E11T_gqa_comparison_{TIMESTAMP}.json'

output = {
    'experiment': 'E11-T_GQA_Comparison',
    'timestamp': TIMESTAMP,
    'model': 'LLaMA-3.1-8B',
    'architecture': 'GQA',
    'config': GQA_CONFIG,
    
    # E11-v3 METHODOLOGY BLOCK (REQUIRED!)
    'methodology': {
        'standard': 'E11-v3',
        'seeds': SEEDS,
        'max_length': MAX_LENGTH,
        'dtype': str(DTYPE),
        'prompt_md5': actual_md5,
        'num_prompts': len(STANDARD_PROMPTS),
        'quantization': 'NONE (unquantized bfloat16)',
        'use_chat_template': False
    },
    
    'prompt_set': 'Standard-10 v3',
    'hypothesis': 'GQA architecture affects territorial collapse differently than MHA',
    'mha_reference': E11_MHA_REFERENCE,
    'results': convert_to_native(results),
    
    # Runtime info
    'runtime': {
        'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
        'gpu_memory_gb': torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0,
        'dtype': str(DTYPE)
    }
}

with open(filename, 'w') as f:
    json.dump(output, f, indent=2)

print(f"Results saved: {filename}")
print(f"\n=== E11-v3 METHODOLOGY COMPLIANCE ===")
print(f"  Seeds: {SEEDS} ✅")
print(f"  MAX_LENGTH: {MAX_LENGTH} ✅")
print(f"  dtype: {DTYPE} ✅")
print(f"  Prompt MD5: {actual_md5} ✅")
print(f"  Quantization: None (unquantized bfloat16) ✅")

# Download link for Colab
try:
    from google.colab import files
    files.download(filename)
    files.download(fig_path)
except ImportError:
    pass  # not running in Colab

---

## Summary

### E11-T: GQA vs MHA - Territorial Collapse Comparison

**Question:** Does the GQA (Grouped Query Attention) architecture change how RLHF affects head specialization?

**Background:**
- E11 found NO territorial collapse in Mistral (MHA): RLHF INCREASED specialization
- GQA forces KV sharing (4 query heads : 1 KV head)
- This structural constraint might create different specialization dynamics

**GQA-Specific Metrics:**
- **Within-Group Variance**: How much do query heads sharing the same KV head differ from each other?
- **Between-Group Variance**: How much do the 8 KV groups differ from each other?
- **Within/Between Ratio**: Low = groups are internally similar but differ from each other; high = heads within a group differ despite KV sharing
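
As a toy illustration of these three metrics, the sketch below recomputes them on made-up entropy values for one layer with 4 query heads in 2 KV groups. It uses population variance, as `np.var` does in Cell 3. The function name `within_between_ratio` and the input arrays are illustrative assumptions, not measured results.

```python
import numpy as np


def within_between_ratio(head_entropies: np.ndarray, num_kv_heads: int) -> float:
    """Within-group / between-group variance for contiguous KV groups.

    head_entropies has shape (num_layers, num_query_heads).
    """
    num_layers, num_query_heads = head_entropies.shape
    group_size = num_query_heads // num_kv_heads
    grouped = head_entropies.reshape(num_layers, num_kv_heads, group_size)

    # Variance among heads inside each group, averaged over layers and groups.
    within = grouped.var(axis=2).mean()
    # Variance of the group means within each layer, averaged over layers.
    between = grouped.mean(axis=2).var(axis=1).mean()
    return float(within / between) if between > 0 else float("inf")


# Groups are (head 0, head 1) and (head 2, head 3).
homogeneous = np.array([[0.2, 0.2, 0.8, 0.8]])    # identical inside each group
mixed = np.array([[0.2, 0.4, 0.6, 0.8]])          # some spread inside each group
heterogeneous = np.array([[0.2, 0.8, 0.2, 0.8]])  # all spread is inside groups

for name, data in [("homogeneous", homogeneous),
                   ("mixed", mixed),
                   ("heterogeneous", heterogeneous)]:
    print(f"{name:<14} {within_between_ratio(data, num_kv_heads=2):.2f}")
```

The homogeneous case gives a ratio of 0, the mixed case about 0.25, and the heterogeneous case `inf`, because the two group means coincide and the between-group variance is zero.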

**Key Question:**
> Is the "efficiency-fragility trade-off" (RLHF removes redundancy) architecture-invariant?

---


In [None]:
# Cell 11: Artifact Log

artifact_entry = {
    'experiment': 'E11-T',
    'timestamp': TIMESTAMP,
    'model': 'LLaMA-3.1-8B',
    'architecture': 'GQA',
    'verdict': results['verdict']['code'],
    'base_specialization': base_si,
    'instruct_specialization': inst_si,
    'delta_specialization': delta_si,
    'base_correlation': base_corr,
    'instruct_correlation': inst_corr,
    'delta_correlation': delta_corr,
    'architecture_invariant': results['mha_comparison']['architecture_invariant'],
    'mha_verdict': E11_MHA_REFERENCE['verdict'],
    'prompt_count': len(STANDARD_PROMPTS),
    'files': {
        'results': filename,
        'figure': fig_path
    }
}

artifact_log = f'results/E11T_artifact_log.jsonl'
with open(artifact_log, 'a') as f:
    f.write(json.dumps(artifact_entry) + '\n')

print(f"Artifact log appended: {artifact_log}")
print(f"\nEntry: {json.dumps(artifact_entry, indent=2)}")

In [None]:
# ============================================================================
# AUTO-DOWNLOAD RESULTS (Colab only)
# ============================================================================
import glob
import shutil

def auto_download_results():
    try:
        from google.colab import files
    except ImportError:
        print('Not in Colab - skipping auto-download')
        return
    
    print('=' * 60)
    print('AUTO-DOWNLOADING RESULTS...')
    print('=' * 60)
    
    # Find all result files
    json_files = glob.glob('results/*.json') + glob.glob('figures/*.json')
    png_files = glob.glob('results/*.png') + glob.glob('figures/*.png')
    all_files = json_files + png_files
    
    if not all_files:
        print('WARNING: No result files found!')
        return
    
    print(f'Found {len(all_files)} files')
    
    # Download as ZIP
    import os
    zip_name = f'E11_results_{os.path.basename(os.getcwd())}'
    
    # Create combined folder
    os.makedirs('download_package', exist_ok=True)
    for f in all_files:
        shutil.copy(f, 'download_package/')
    
    shutil.make_archive(zip_name, 'zip', 'download_package')
    print(f'Downloading: {zip_name}.zip')
    files.download(f'{zip_name}.zip')
    print('DOWNLOAD COMPLETE!')

auto_download_results()