# E04: LLaMA-3.1 Twin Test - GQA RLHF Validation

**Paper 4: Digital Overcrowding (Behavioral Sink Hypothesis)**

## Critical Question
> Does GQA architecture protect against RLHF-induced fragility?

## Context from Previous Experiments

| Experiment | Finding | Implication |
|------------|---------|-------------|
| E04 Mistral (MHA) | Delta = +0.799 | RLHF CREATES fragility in MHA |
| E03 TinyLlama (GQA) | Fragility = -0.262 | GQA is intrinsically ANTIFRAGILE |
| E06c TinyLlama (GQA) | Baseline = -0.751 | GQA "already healthy" |

## This Experiment Tests

**LLaMA-3.1-8B Base vs LLaMA-3.1-8B-Instruct (GQA 4:1)**

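Fragility is the slope of the 3-gram repetition rate of greedy generations against the standard deviation of Gaussian noise injected into every attention output (positive slope = output degrades under noise). Possible outcomes, with Delta = Instruct fragility - Base fragility:
1. **If Instruct >> Base fragility**: GQA is NOT protective; RLHF damages all architectures
2. **If Instruct ≈ Base fragility**: GQA IS protective and buffers RLHF damage
3. **If Instruct << Base fragility**: unexpected; would need an explanation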
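The three outcomes reduce to a threshold on Delta. Below is a minimal sketch. It assumes the same +0.05 gate that Cell 9 applies for "RLHF damages GQA". The mirror-image -0.05 gate for outcome 3 is our assumption. `classify_outcome` is a hypothetical helper and is not part of the notebook:

```python
def classify_outcome(base_frag, inst_frag, gate=0.05):
    """Map a Base/Instruct fragility pair onto the three outcomes.

    Hypothetical helper: +gate mirrors Cell 9's damage test; -gate for
    outcome 3 is an assumed symmetric counterpart.
    """
    delta = inst_frag - base_frag  # Delta = Instruct - Base
    if delta > gate:
        return delta, "1: Instruct >> Base (GQA not protective)"
    if delta < -gate:
        return delta, "3: Instruct << Base (needs explanation)"
    return delta, "2: Instruct ~ Base (GQA protective)"

# Usage with the Mistral (MHA) reference values from Cell 2
delta, outcome = classify_outcome(-0.8609, -0.0616)
print(f"{delta:+.4f}", outcome)  # +0.7993 1: Instruct >> Base (GQA not protective)
```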

---

In [None]:
# Cell 1: Setup + E11-v3 Standard
!pip install -q transformers torch accelerate bitsandbytes scipy matplotlib huggingface_hub

import os
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from scipy.stats import entropy as scipy_entropy
from scipy.stats import linregress
import json
import hashlib
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# ============ E11-v3 METHODOLOGY STANDARD ============
SEEDS = [42, 123, 456]  # 3-seed averaging
DTYPE = torch.bfloat16  # Standardized precision
EXPECTED_MD5 = "715065bab181f46bf12ed471951141e2"  # Standard-10 v3

def verify_prompts(prompts):
    """Verify Standard-10 prompts via MD5."""
    combined = '|||'.join(prompts)  # Canonical delimiter for MD5
    actual_md5 = hashlib.md5(combined.encode()).hexdigest()
    verified = actual_md5 == EXPECTED_MD5
    print(f"  Prompt MD5: {actual_md5}")
    print(f"  Expected:   {EXPECTED_MD5}")
    print(f"  Verified:   {'✓' if verified else '✗ MISMATCH!'}")
    return verified, actual_md5

# Reproducibility
os.environ['PYTHONHASHSEED'] = '42'
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
Path('results').mkdir(parents=True, exist_ok=True)
Path('figures').mkdir(parents=True, exist_ok=True)
print(f'Timestamp: {TIMESTAMP}')
print(f"E11-v3 Standard: Seeds={SEEDS}, dtype={DTYPE}")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# HF Login for gated models
from huggingface_hub import login, HfFolder

def get_hf_token():
    token = None
    try:
        from google.colab import userdata
        token = userdata.get('HF_TOKEN')
    except Exception:
        pass
    if not token:
        token = os.environ.get('HF_TOKEN') or HfFolder.get_token()
    return token

HF_TOKEN = get_hf_token()
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("HF Login: SUCCESS")
    except Exception as e:
        print(f"HF Login failed: {e}")
else:
    print("WARNING: No HF_TOKEN found! LLaMA requires authentication.")

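`verify_prompts` joins the prompts with `'|||'` before hashing. It does this because a plain concatenation erases the boundaries between prompts. The sketch below shows the collision that the delimiter prevents. `prompt_digest` is a hypothetical helper that mirrors `verify_prompts`. A prompt that itself contains `'|||'` could still collide, so the delimiter is a convention rather than a guarantee:

```python
import hashlib

def prompt_digest(prompts, delimiter='|||'):
    """MD5 over prompts joined with an explicit delimiter (mirrors verify_prompts)."""
    return hashlib.md5(delimiter.join(prompts).encode()).hexdigest()

# Without a delimiter, different prompt lists hash identically (both become 'abc')
naive_a = hashlib.md5(''.join(['ab', 'c']).encode()).hexdigest()
naive_b = hashlib.md5(''.join(['a', 'bc']).encode()).hexdigest()
print(naive_a == naive_b)  # True

# With the delimiter, the prompt boundaries become part of the hashed text
print(prompt_digest(['ab', 'c']) == prompt_digest(['a', 'bc']))  # False
```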
In [None]:
# Cell 2: LLaMA-3.1 Configuration (E11-v3 Standard)

# Model configuration
MODEL_BASE = 'meta-llama/Llama-3.1-8B'
MODEL_INSTRUCT = 'meta-llama/Llama-3.1-8B-Instruct'
MODEL_PARAMS = '8B'
GQA_RATIO = '4:1'  # 32 query heads : 8 KV heads

# Reference from previous experiments (MHA baselines)
MISTRAL_BASE_FRAGILITY = -0.8609
MISTRAL_INST_FRAGILITY = -0.0616
MISTRAL_DELTA = +0.799

TINYLLAMA_GQA_FRAGILITY = -0.262
PYTHIA_1B_MHA_FRAGILITY = +0.53

# Phenotype thresholds
PROBER_THRESHOLD = 0.85
RIGID_THRESHOLD = 0.20

# Noise levels for fragility test
NOISE_LEVELS = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]

# Tokenization (E11-v3 Standard)
MAX_LENGTH = 128

# ============ CANONICAL Standard-10 v3 Prompts ============
# MD5: 715065bab181f46bf12ed471951141e2
STANDARD_PROMPTS = [
    "What is the capital of France and what is its population?",
    "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Explain step by step.",
    "Calculate 47 multiplied by 23 and show your work.",
    "Translate the following to German: 'The quick brown fox jumps over the lazy dog'.",
    "Write a Python function that checks if a number is prime.",
    "Summarize the main points: Machine learning is a subset of artificial intelligence that enables systems to learn from data. It uses algorithms to identify patterns and make decisions with minimal human intervention.",
    "Statement A: 'All birds can fly.' Statement B: 'Penguins are birds that cannot fly.' Are these statements contradictory? Explain.",
    "What are the safety considerations when using a kitchen knife?",
    "Write a haiku about artificial intelligence.",
    "Complete this sentence in a helpful way: 'The best approach to solving complex problems is'",
]

# Verify prompts (E11-v3 Standard)
print("Verifying Standard-10 prompts...")
PROMPTS_VERIFIED, ACTUAL_MD5 = verify_prompts(STANDARD_PROMPTS)
if not PROMPTS_VERIFIED:
    raise ValueError("PROMPT MISMATCH! Check Standard-10 v3 canonical prompts.")

print("\n" + "="*60)
print("E04 LLaMA-3.1 Twin Test: GQA RLHF Validation")
print("="*60)
print(f"\nBase:     {MODEL_BASE}")
print(f"Instruct: {MODEL_INSTRUCT}")
print(f"GQA:      {GQA_RATIO}")
print(f"\nE11-v3 Config: Seeds={SEEDS}, dtype={DTYPE}, MAX_LENGTH={MAX_LENGTH}")
print(f"\nReference (MHA):")
print(f"  Mistral RLHF Delta: {MISTRAL_DELTA:+.2f}")
print(f"\nReference (GQA):")
print(f"  TinyLlama Fragility: {TINYLLAMA_GQA_FRAGILITY:.3f} (ANTIFRAGILE)")

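`GQA_RATIO = '4:1'` means that four query heads share each key/value head. Below is a minimal sketch of that grouping. It assumes the contiguous layout that, as we understand it, Hugging Face's `repeat_kv` expansion produces: query head h reads KV head h // 4. It also assumes LLaMA-3.1-8B's published config of 32 query heads and 8 KV heads. `kv_head_for_query_head` is a hypothetical helper:

```python
def kv_head_for_query_head(q_head, n_q_heads=32, n_kv_heads=8):
    """Index of the shared KV head used by a query head under GQA.

    Assumes contiguous grouping: each KV head serves
    n_q_heads // n_kv_heads consecutive query heads.
    """
    group_size = n_q_heads // n_kv_heads  # 4 query heads per KV head -> "4:1"
    return q_head // group_size

# Group the 32 query heads by the KV head they share
groups = {}
for q in range(32):
    groups.setdefault(kv_head_for_query_head(q), []).append(q)

print(groups[0])    # [0, 1, 2, 3]
print(len(groups))  # 8
```

This has a consequence for Cell 3. On the eager path, `output_attentions` returns one matrix per query head, after the KV heads are expanded. Each layer therefore still contributes 32 entries. With 32 layers and 10 prompts, `n_heads` should come to 10 × 32 × 32 = 10,240, assuming the published config. The 4:1 sharing changes which heads reuse the same keys, not how many heads are counted.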
In [None]:
# Cell 3: Attention Entropy Analysis Functions (E11-v3: mask + chat_template)

def compute_attention_entropy(attention_weights, attention_mask=None):
    """
    Compute normalized entropy of attention weights.
    
    E11-v3 FIX: Use attention_mask to exclude PAD tokens from entropy calculation.
    
    Args:
        attention_weights: Tensor of shape (batch, heads, seq, seq)
        attention_mask: Optional tensor of shape (batch, seq) - 1 for real tokens, 0 for PAD
    
    Returns:
        List of entropy values per head
    """
    # Average over batch and over query positions -> one attention
    # distribution over key positions per head
    attn = attention_weights.float().mean(dim=0).mean(dim=-2)  # (heads, seq)
    
    # E11-v3: Build a key-position mask (1 = real token, 0 = PAD) if provided
    mask_np = None
    if attention_mask is not None:
        mask_np = attention_mask.float().mean(dim=0).cpu().numpy()  # (seq,)
    
    entropies = []
    for head_idx in range(attn.shape[0]):
        probs = attn[head_idx].detach().cpu().numpy()
        
        # E11-v3: Zero out attention paid to PAD positions
        if mask_np is not None:
            probs = probs * mask_np
        
        probs_sum = probs.sum()
        if probs_sum > 0:
            probs = probs / probs_sum  # Renormalize over valid tokens
        probs = probs[probs > 0]  # Drop zero-probability positions
        
        if len(probs) > 1:
            h = scipy_entropy(probs, base=2)
            h_max = np.log2(len(probs))
            h_norm = h / h_max if h_max > 0 else 0
        else:
            h_norm = 0
        
        entropies.append(h_norm)
    
    return entropies


def classify_phenotype(entropy):
    """Classify attention head by normalized entropy."""
    if entropy > PROBER_THRESHOLD:
        return 'PROBER'
    elif entropy < RIGID_THRESHOLD:
        return 'RIGID'
    else:
        return 'HEALTHY'


def analyze_model_phenotypes(model, tokenizer, prompts, use_chat_template=False):
    """
    Analyze phenotype distribution of a model.
    
    E11-v3 FIX: 
    - Use attention_mask to exclude PAD tokens
    - Use chat_template for Instruct models
    
    Args:
        model: The model to analyze
        tokenizer: The tokenizer
        prompts: List of prompts
        use_chat_template: If True, wrap prompts in chat template (for Instruct models)
    
    Returns:
        dict with prober_pct, rigid_pct, healthy_pct, mean_entropy
    """
    all_entropies = []
    
    for prompt in prompts:
        # E11-v3: Apply chat template for Instruct models
        formatted_prompt = prompt
        template_applied = False
        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
            messages = [{"role": "user", "content": prompt}]
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, 
                    tokenize=False, 
                    add_generation_prompt=True
                )
                template_applied = True
            except Exception as e:
                print(f"  Warning: chat_template failed ({e}), using raw prompt")
        
        # E11-v3: Get attention_mask from tokenizer.
        # The rendered LLaMA-3.1 chat template already begins with the BOS token,
        # so skip special tokens in that case to avoid a duplicated BOS.
        inputs = tokenizer(
            formatted_prompt, 
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_attention_mask=True,
            add_special_tokens=not template_applied
        ).to(model.device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        
        # E11-v3: Pass attention_mask to entropy calculation
        attention_mask = inputs.get('attention_mask', None)
        
        for layer_attn in outputs.attentions:
            layer_entropies = compute_attention_entropy(layer_attn, attention_mask)
            all_entropies.extend(layer_entropies)
    
    phenotypes = [classify_phenotype(e) for e in all_entropies]
    
    n_total = len(phenotypes)
    prober_pct = phenotypes.count('PROBER') / n_total * 100
    rigid_pct = phenotypes.count('RIGID') / n_total * 100
    healthy_pct = phenotypes.count('HEALTHY') / n_total * 100
    
    return {
        'prober_pct': prober_pct,
        'rigid_pct': rigid_pct,
        'healthy_pct': healthy_pct,
        'mean_entropy': np.mean(all_entropies),
        'std_entropy': np.std(all_entropies),
        'n_heads': n_total,
        'chat_template_used': use_chat_template  # E11-v3: Track this
    }

print("Analysis functions loaded (E11-v3: mask + chat_template).")

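`compute_attention_entropy` reduces each head to a single number: the entropy of its query-averaged attention, divided by log2 of the number of attended positions, so 1.0 means perfectly uniform attention. Below is a dependency-free sketch of that normalization together with `classify_phenotype`'s thresholds (copied from Cell 2). It runs on three toy distributions that are made up for illustration, not measured:

```python
import math

PROBER_THRESHOLD = 0.85  # copied from Cell 2
RIGID_THRESHOLD = 0.20

def normalized_entropy(probs):
    """Shannon entropy (bits) divided by log2 of the number of non-zero entries."""
    total = sum(probs)
    p = [x / total for x in probs if x > 0]
    if len(p) <= 1:
        return 0.0
    h = -sum(x * math.log2(x) for x in p)
    return h / math.log2(len(p))

def phenotype(h_norm):
    """Same cut-offs as classify_phenotype."""
    if h_norm > PROBER_THRESHOLD:
        return 'PROBER'
    if h_norm < RIGID_THRESHOLD:
        return 'RIGID'
    return 'HEALTHY'

toy = [('uniform', [0.25] * 4),
       ('spread', [0.7, 0.1, 0.1, 0.1]),
       ('peaked', [0.97, 0.01, 0.01, 0.01])]
for name, dist in toy:
    h = normalized_entropy(dist)
    print(f"{name:8s} H_norm={h:.3f} -> {phenotype(h)}")
# expected: 1.000 PROBER, 0.678 HEALTHY, 0.121 RIGID
```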
In [None]:
# Cell 4: Noise Injection for Fragility Test (E11-v3: mask + chat_template)

class AttentionNoiseInjector:
    """Inject Gaussian noise into attention outputs."""
    
    def __init__(self, model, noise_std=0.0):
        self.model = model
        self.noise_std = noise_std
        self.hooks = []
    
    def _make_hook(self, layer_idx):
        def hook(module, input, output):
            if self.noise_std > 0:
                if isinstance(output, tuple):
                    attn_output = output[0]
                    noise = torch.randn_like(attn_output) * self.noise_std
                    return (attn_output + noise,) + output[1:]
                else:
                    noise = torch.randn_like(output) * self.noise_std
                    return output + noise
            return output
        return hook
    
    def attach(self):
        """Attach hooks to all attention layers."""
        for idx, layer in enumerate(self.model.model.layers):
            hook = layer.self_attn.register_forward_hook(self._make_hook(idx))
            self.hooks.append(hook)
    
    def detach(self):
        """Remove all hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    
    def set_noise(self, std):
        """Set noise level."""
        self.noise_std = std


def compute_repetition_rate(text, n=3):
    """Compute n-gram repetition rate."""
    words = text.split()
    if len(words) < n:
        return 0.0
    
    ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
    unique = len(set(ngrams))
    total = len(ngrams)
    
    return 1.0 - (unique / total) if total > 0 else 0.0


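def test_fragility(model, tokenizer, noise_levels, prompts, max_new_tokens=50,
                   use_chat_template=False, seeds=SEEDS):
    """
    Test model fragility under noise injection.
    
    E11-v3 FIX:
    - Use attention_mask in generation
    - Use chat_template for Instruct models
    - Average repetition over all SEEDS (noise is re-drawn per seed)
    - Score only the newly generated tokens
    
    Returns:
        dict with fragility_score (slope of repetition rate vs. noise std),
        degradation_curve
    """
    # Tokenize each prompt once; noise level and seed only affect generation
    encoded_prompts = []
    for prompt in prompts:
        # E11-v3: Apply chat template for Instruct models
        formatted_prompt = prompt
        template_applied = False
        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
            messages = [{"role": "user", "content": prompt}]
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, 
                    tokenize=False, 
                    add_generation_prompt=True
                )
                template_applied = True
            except Exception:
                formatted_prompt = prompt
        
        # E11-v3: Get attention_mask from tokenizer (no second BOS after a chat template)
        encoded_prompts.append(tokenizer(
            formatted_prompt, 
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_attention_mask=True,
            add_special_tokens=not template_applied
        ).to(model.device))
    
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    
    injector = AttentionNoiseInjector(model, noise_std=0.0)
    injector.attach()
    
    degradation_scores = []
    try:
        for noise_std in noise_levels:
            injector.set_noise(noise_std)
            # Greedy decoding without noise is deterministic, so one seed suffices at sigma=0
            active_seeds = seeds if noise_std > 0 else seeds[:1]
            
            rep_rates = []
            for seed in active_seeds:
                torch.manual_seed(seed)  # Reproducible noise draws per seed
                for inputs in encoded_prompts:
                    with torch.no_grad():
                        outputs = model.generate(
                            **inputs,
                            max_new_tokens=max_new_tokens,
                            do_sample=False,
                            pad_token_id=pad_id
                        )
                    # Decode only the new tokens. Slicing the decoded text by
                    # len(formatted_prompt) fails once skip_special_tokens drops
                    # the chat-template markers from the prompt.
                    new_tokens = outputs[0, inputs['input_ids'].shape[1]:]
                    generated_only = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
                    rep_rates.append(compute_repetition_rate(generated_only))
            rep_avg = float(np.mean(rep_rates))
            degradation_scores.append(rep_avg)
            
            print(f"  Noise sigma={noise_std:.2f}: Rep={rep_avg:.3f}")
    finally:
        injector.detach()  # Always remove hooks, even if generation fails
    
    # Compute fragility as slope
    slope, _, _, _, _ = linregress(noise_levels, degradation_scores)
    
    return {
        'fragility_score': slope,
        'degradation_curve': list(zip(noise_levels, degradation_scores)),
        'is_fragile': slope > 0.05,
        'is_antifragile': slope < -0.05,
        'is_neutral': abs(slope) <= 0.05,
        'chat_template_used': use_chat_template,  # E11-v3: Track this
        'seeds': list(seeds)
    }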

print("Fragility test functions loaded (E11-v3: mask + chat_template).")

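The fragility score is built from two pieces in Cell 4. The first is `compute_repetition_rate`, the share of word trigrams that repeat. The second is a least-squares slope over `NOISE_LEVELS`. Below is a pure-Python sketch of both. `ols_slope` stands in for `scipy.stats.linregress` and computes the same least-squares slope. The curve is a hypothetical linear toy, not measured data:

```python
def repetition_rate(text, n=3):
    """Fraction of repeated word n-grams (same rule as compute_repetition_rate)."""
    words = text.split()
    if len(words) < n:
        return 0.0
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    return 1.0 - len(set(ngrams)) / len(ngrams)

def ols_slope(xs, ys):
    """Ordinary least-squares slope, the quantity linregress reports as .slope."""
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    return sxy / sxx

# 9 words -> 7 trigrams, only 3 distinct -> 1 - 3/7
print(f"{repetition_rate('the cat sat the cat sat the cat sat'):.3f}")  # 0.571

noise = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]  # NOISE_LEVELS from Cell 2
toy_curve = [0.5 * s for s in noise]       # hypothetical, not a measured curve
print(f"{ols_slope(noise, toy_curve):.3f}")  # 0.500 -> FRAGILE (> 0.05)
```

Because sigma only spans 0 to 0.2, the ±0.05 gate is tight. For a roughly linear curve, a slope of 0.05 means the repetition rate rises by only 0.05 × 0.2 = 0.01 across the whole sweep.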
In [None]:
# Cell 5: Initialize Results

results = {
    'pair': 'llama3',
    'architecture': 'GQA',
    'gqa_ratio': GQA_RATIO,
    'base': {},
    'instruct': {}
}

print(f"\n{'='*60}")
print(f"E04 TWIN TEST: LLaMA-3.1-8B (GQA {GQA_RATIO})")
print(f"{'='*60}")

In [None]:
# Cell 6: Load and Analyze BASE Model (E11-v3: bfloat16, NO chat_template)

print(f"\n[1/4] Loading BASE: {MODEL_BASE}")
print("      (This is a gated model - requires HuggingFace login)")

tokenizer_base = AutoTokenizer.from_pretrained(MODEL_BASE)
model_base = AutoModelForCausalLM.from_pretrained(
    MODEL_BASE,
    torch_dtype=DTYPE,  # bfloat16 (E11-v3)
    device_map='auto',
    trust_remote_code=True,
    attn_implementation="eager"  # CRITICAL: SDPA doesn't return attentions!
)

if tokenizer_base.pad_token is None:
    tokenizer_base.pad_token = tokenizer_base.eos_token

print(f"\n[2/4] Analyzing BASE phenotypes (E11-v3: mask=True, chat_template=False)...")
results['base']['phenotypes'] = analyze_model_phenotypes(
    model_base, 
    tokenizer_base, 
    STANDARD_PROMPTS,
    use_chat_template=False  # E11-v3: BASE model = no chat template
)
print(f"  Prober%: {results['base']['phenotypes']['prober_pct']:.1f}%")
print(f"  Rigid%:  {results['base']['phenotypes']['rigid_pct']:.1f}%")
print(f"  Mean H:  {results['base']['phenotypes']['mean_entropy']:.3f}")

print(f"\n[2b/4] Testing BASE fragility (E11-v3: mask=True, chat_template=False)...")
results['base']['fragility'] = test_fragility(
    model_base, 
    tokenizer_base, 
    NOISE_LEVELS, 
    STANDARD_PROMPTS,
    use_chat_template=False  # E11-v3: BASE model = no chat template
)
print(f"  Fragility Score: {results['base']['fragility']['fragility_score']:.4f}")

# Classify phenotype
base_frag = results['base']['fragility']['fragility_score']
if base_frag < -0.05:
    print(f"  Classification: ANTIFRAGILE")
elif base_frag > 0.05:
    print(f"  Classification: FRAGILE")
else:
    print(f"  Classification: NEUTRAL")

# Free memory
del model_base
torch.cuda.empty_cache()
print("\n  [Memory cleared]")

In [None]:
# Cell 7: Load and Analyze INSTRUCT Model (E11-v3: bfloat16, WITH chat_template)

print(f"\n[3/4] Loading INSTRUCT: {MODEL_INSTRUCT}")

tokenizer_inst = AutoTokenizer.from_pretrained(MODEL_INSTRUCT)
model_inst = AutoModelForCausalLM.from_pretrained(
    MODEL_INSTRUCT,
    torch_dtype=DTYPE,  # bfloat16 (E11-v3)
    device_map='auto',
    trust_remote_code=True,
    attn_implementation="eager"  # CRITICAL: SDPA doesn't return attentions!
)

if tokenizer_inst.pad_token is None:
    tokenizer_inst.pad_token = tokenizer_inst.eos_token

print(f"\n[4/4] Analyzing INSTRUCT phenotypes (E11-v3: mask=True, chat_template=True)...")
results['instruct']['phenotypes'] = analyze_model_phenotypes(
    model_inst, 
    tokenizer_inst, 
    STANDARD_PROMPTS,
    use_chat_template=True  # E11-v3: INSTRUCT model = WITH chat template
)
print(f"  Prober%: {results['instruct']['phenotypes']['prober_pct']:.1f}%")
print(f"  Rigid%:  {results['instruct']['phenotypes']['rigid_pct']:.1f}%")
print(f"  Mean H:  {results['instruct']['phenotypes']['mean_entropy']:.3f}")

print(f"\n[4b/4] Testing INSTRUCT fragility (E11-v3: mask=True, chat_template=True)...")
results['instruct']['fragility'] = test_fragility(
    model_inst, 
    tokenizer_inst, 
    NOISE_LEVELS, 
    STANDARD_PROMPTS,
    use_chat_template=True  # E11-v3: INSTRUCT model = WITH chat template
)
print(f"  Fragility Score: {results['instruct']['fragility']['fragility_score']:.4f}")

# Classify phenotype
inst_frag = results['instruct']['fragility']['fragility_score']
if inst_frag < -0.05:
    print(f"  Classification: ANTIFRAGILE")
elif inst_frag > 0.05:
    print(f"  Classification: FRAGILE")
else:
    print(f"  Classification: NEUTRAL")

# Free memory
del model_inst
torch.cuda.empty_cache()
print("\n  [Memory cleared]")

In [None]:
# Cell 8: GQA vs MHA RLHF Comparison

print(f"\n{'='*60}")
print(f"E04 TWIN TEST RESULTS: LLaMA-3.1 (GQA {GQA_RATIO})")
print(f"{'='*60}")

# Extract key metrics
base_rigid = results['base']['phenotypes']['rigid_pct']
inst_rigid = results['instruct']['phenotypes']['rigid_pct']
base_prober = results['base']['phenotypes']['prober_pct']
inst_prober = results['instruct']['phenotypes']['prober_pct']
base_frag = results['base']['fragility']['fragility_score']
inst_frag = results['instruct']['fragility']['fragility_score']

# Compute deltas
delta_rigid = inst_rigid - base_rigid
delta_prober = inst_prober - base_prober
delta_frag = inst_frag - base_frag

print(f"\n{'Metric':<25} {'BASE':>12} {'INSTRUCT':>12} {'Delta':>12}")
print("-" * 65)
print(f"{'Rigid% (Beautiful Ones)':<25} {base_rigid:>11.1f}% {inst_rigid:>11.1f}% {delta_rigid:>+11.1f}%")
print(f"{'Prober% (Chaos)':<25} {base_prober:>11.1f}% {inst_prober:>11.1f}% {delta_prober:>+11.1f}%")
print(f"{'Fragility Score':<25} {base_frag:>12.4f} {inst_frag:>12.4f} {delta_frag:>+12.4f}")

# Compare to MHA (Mistral)
print(f"\n{'='*60}")
print("COMPARISON: GQA vs MHA RLHF Impact")
print(f"{'='*60}")

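if delta_frag > 0.05:
    gqa_interp = 'RLHF DAMAGES GQA'
elif delta_frag < -0.05:
    gqa_interp = 'RLHF REDUCES FRAGILITY'
else:
    gqa_interp = 'NO CHANGE (+/-0.05)'

print(f"\n{'Architecture':<15} {'RLHF Delta':>15} {'Interpretation':>25}")
print("-" * 57)
print(f"{'Mistral (MHA)':<15} {MISTRAL_DELTA:>+15.4f} {'RLHF DAMAGES MHA':>25}")
print(f"{'LLaMA-3.1 (GQA)':<15} {delta_frag:>+15.4f} {gqa_interp:>25}")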

In [None]:
# Cell 9: Hypothesis Test

print(f"\n{'='*60}")
print("HYPOTHESIS TEST: Does GQA Protect Against RLHF Damage?")
print(f"{'='*60}")

# Compare GQA delta to MHA delta
gqa_delta = delta_frag
mha_delta = MISTRAL_DELTA

# Hypothesis tests
rlhf_damages_gqa = delta_frag > 0.05  # Instruct more fragile than Base
gqa_less_damage_than_mha = delta_frag < mha_delta  # GQA takes less RLHF damage
base_already_antifragile = base_frag < -0.05  # Base model inherently antifragile

print(f"\n1. RLHF damages GQA (Delta > +0.05)?")
print(f"   Delta = {delta_frag:+.4f}")
print(f"   Answer: {'YES - GQA NOT PROTECTIVE' if rlhf_damages_gqa else 'NO - GQA IS PROTECTIVE'}")

print(f"\n2. GQA takes LESS damage than MHA?")
print(f"   GQA Delta: {delta_frag:+.4f}")
print(f"   MHA Delta: {mha_delta:+.4f}")
print(f"   Difference: {mha_delta - delta_frag:+.4f}")
print(f"   Answer: {'YES - GQA BUFFERS RLHF' if gqa_less_damage_than_mha else 'NO'}")

print(f"\n3. Is GQA Base already antifragile (like TinyLlama)?")
print(f"   Base Fragility: {base_frag:.4f}")
print(f"   TinyLlama Reference: {TINYLLAMA_GQA_FRAGILITY:.4f}")
print(f"   Answer: {'YES - INTRINSIC ANTIFRAGILITY' if base_already_antifragile else 'NO'}")

# Final verdict
print(f"\n{'='*60}")
print("FINAL VERDICT")
print(f"{'='*60}")

if delta_frag < -0.05:
    verdict = "INSTRUCT LESS FRAGILE THAN BASE (UNEXPECTED)"
    evidence = f"Delta ({delta_frag:+.2f}) < -0.05; outcome 3, needs an explanation"
elif not rlhf_damages_gqa:
    verdict = "GQA PROTECTS AGAINST RLHF FRAGILITY"
    evidence = f"Delta ({delta_frag:+.2f}) stays within the +/-0.05 neutral band"
elif gqa_less_damage_than_mha:
    verdict = "GQA PARTIALLY BUFFERS RLHF DAMAGE"
    evidence = f"Delta ({delta_frag:+.2f}) < MHA Delta ({mha_delta:+.2f})"
else:
    verdict = "GQA DOES NOT PROTECT AGAINST RLHF"
    evidence = f"Delta ({delta_frag:+.2f}) >= MHA Delta ({mha_delta:+.2f})"

print(f"\n  VERDICT: {verdict}")
print(f"  Evidence: {evidence}")

# Store verdict
results['verdict'] = {
    'rlhf_damages_gqa': rlhf_damages_gqa,
    'gqa_less_damage_than_mha': gqa_less_damage_than_mha,
    'base_already_antifragile': base_already_antifragile,
    'gqa_delta': float(delta_frag),
    'mha_delta': float(mha_delta),
    'verdict_text': verdict
}

In [None]:
# Cell 10: Visualization

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: Phenotype comparison
ax1 = axes[0]
x = np.arange(3)
width = 0.35
base_vals = [base_prober, results['base']['phenotypes']['healthy_pct'], base_rigid]
inst_vals = [inst_prober, results['instruct']['phenotypes']['healthy_pct'], inst_rigid]

bars1 = ax1.bar(x - width/2, base_vals, width, label='Base', color='#2ecc71', alpha=0.8)
bars2 = ax1.bar(x + width/2, inst_vals, width, label='Instruct', color='#e74c3c', alpha=0.8)

ax1.set_ylabel('Percentage (%)')
ax1.set_title(f'LLaMA-3.1 (GQA {GQA_RATIO}): Phenotype Distribution')
ax1.set_xticks(x)
ax1.set_xticklabels(['Prober\n(Chaos)', 'Healthy', 'Rigid\n(Beautiful One)'])
ax1.legend()
ax1.set_ylim(0, 100)

# Annotate deltas
for i, (b, inst) in enumerate(zip(base_vals, inst_vals)):
    delta = inst - b
    color = 'red' if delta > 0 and i == 2 else ('green' if delta < 0 and i == 0 else 'black')
    ax1.annotate(f'{delta:+.1f}%', xy=(i, max(b, inst) + 3), ha='center', fontsize=10, color=color)

# Plot 2: GQA vs MHA RLHF Delta
ax2 = axes[1]
architectures = ['Mistral\n(MHA)', 'LLaMA-3.1\n(GQA 4:1)']
deltas = [MISTRAL_DELTA, delta_frag]
colors = ['#e74c3c', '#3498db']

bars = ax2.bar(architectures, deltas, color=colors, alpha=0.8, edgecolor='black')
ax2.axhline(y=0, color='black', linestyle='--', linewidth=1)
ax2.axhline(y=0.05, color='red', linestyle=':', linewidth=1, label='Fragile threshold')
ax2.axhline(y=-0.05, color='green', linestyle=':', linewidth=1, label='Antifragile threshold')
ax2.set_ylabel('RLHF Delta (Instruct - Base)')
ax2.set_title('RLHF Impact: MHA vs GQA')
ax2.legend(fontsize=8)

# Annotate bars
for bar, d in zip(bars, deltas):
    ax2.annotate(f'{d:+.2f}', xy=(bar.get_x() + bar.get_width()/2, d),
                 xytext=(0, 10 if d > 0 else -15), textcoords='offset points',
                 ha='center', fontsize=12, fontweight='bold')

# Plot 3: Degradation curves
ax3 = axes[2]
base_curve = results['base']['fragility']['degradation_curve']
inst_curve = results['instruct']['fragility']['degradation_curve']

base_x, base_y = zip(*base_curve)
inst_x, inst_y = zip(*inst_curve)

ax3.plot(base_x, base_y, 'o-', color='#2ecc71', label='Base', linewidth=2, markersize=8)
ax3.plot(inst_x, inst_y, 's-', color='#e74c3c', label='Instruct', linewidth=2, markersize=8)

ax3.set_xlabel('Noise Level (sigma)')
ax3.set_ylabel('Repetition Rate')
ax3.set_title(f'LLaMA-3.1 (GQA {GQA_RATIO}): Degradation Under Noise')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
Path('figures').mkdir(parents=True, exist_ok=True)  # Same directory as Cell 1
fig_path = f'figures/E04_LLaMA31_Twin_Test_{TIMESTAMP}.png'
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: {fig_path}")


In [None]:
# Cell 11: Save Results (E11-v3 FULL STANDARD with methodology block)

filename = f'results/E04_LLaMA31_Twin_Test_{TIMESTAMP}.json'

# Helper to convert numpy types to native Python
def convert_to_native(obj):
    """Recursively convert numpy types to native Python types."""
    if isinstance(obj, dict):
        return {k: convert_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_native(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_to_native(v) for v in obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)  # Keep true/false in JSON instead of 1/0
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj

# Prepare for JSON serialization
output = {
    'experiment': 'E04_Twin_Test',
    'variant': 'LLaMA-3.1 GQA',
    'timestamp': TIMESTAMP,
    'models': {
        'base': MODEL_BASE,
        'instruct': MODEL_INSTRUCT,
        'params': MODEL_PARAMS,
        'gqa_ratio': GQA_RATIO
    },
    'reference_values': {
        'mistral_mha_delta': MISTRAL_DELTA,
        'tinyllama_gqa_fragility': TINYLLAMA_GQA_FRAGILITY
    },
    'thresholds': {
        'prober': PROBER_THRESHOLD,
        'rigid': RIGID_THRESHOLD,
        'fragile_gate': 0.05,
        'antifragile_gate': -0.05
    },
    'noise_levels': NOISE_LEVELS,
    # E11-v3 FULL Methodology Block
    'methodology': {
        'standard': 'E11-v3',
        'seeds': SEEDS,
        'max_length': MAX_LENGTH,
        'dtype': str(DTYPE),
        'prompt_md5': ACTUAL_MD5,
        'prompt_md5_verified': PROMPTS_VERIFIED,
        'num_prompts': len(STANDARD_PROMPTS),
        'prompt_set': 'Standard-10 v3',
        'quantization': 'bfloat16',
        # E11-v3 FULL STANDARD additions
        'attention_mask_used': True,
        'chat_template_base': False,
        'chat_template_instruct': True
    },
    'results': convert_to_native(results)
}

with open(filename, 'w') as f:
    json.dump(output, f, indent=2)

print(f"Results saved: {filename}")

print(f"\n📋 E11-v3 FULL Compliance:")
print(f"   Seeds: {SEEDS} ✓")
print(f"   dtype: {DTYPE} ✓")
print(f"   MD5: {ACTUAL_MD5} {'✓' if PROMPTS_VERIFIED else '✗'}")
print(f"   MAX_LENGTH: {MAX_LENGTH} ✓")
print(f"   attention_mask: True ✓")
print(f"   chat_template (Base): False ✓")
print(f"   chat_template (Instruct): True ✓")

# Download link for Colab
try:
    from google.colab import files
    files.download(filename)
    files.download(fig_path)
except Exception:  # Not running in Colab, or download unavailable
    pass

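`convert_to_native` walks the results tree before `json.dump`. One alternative, offered as an option rather than the notebook's method, is json's `default=` hook. The encoder calls it only for objects it cannot serialize on its own. `np.float64` already subclasses `float` and passes straight through, while `np.bool_` and `np.int64` need conversion. `to_json_native` is a hypothetical name, and the sample values are illustrative:

```python
import json
import numpy as np

def to_json_native(obj):
    """json `default=` hook: called only for objects json can't encode itself."""
    if isinstance(obj, np.bool_):
        return bool(obj)  # keep true/false rather than 1/0
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Not JSON serializable: {type(obj).__name__}")

sample = {
    'is_fragile': np.float64(0.3) > 0.05,  # comparison yields np.bool_
    'n_heads': np.int64(10240),
    'curve': np.array([0.0, 0.5]),
}
print(json.dumps(sample, default=to_json_native))
# {"is_fragile": true, "n_heads": 10240, "curve": [0.0, 0.5]}
```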
---

## Summary

### The GQA Protection Hypothesis

**MHA (Mistral-7B):**
- RLHF Delta: +0.799
- Interpretation: RLHF CREATES significant fragility

**GQA (LLaMA-3.1-8B):**
- RLHF Delta: [TBD after running]
- Question: Does GQA buffer RLHF damage?

### Implications

If GQA protects:
1. Architecture choice matters for robustness
2. GQA may be preferred for safety-critical applications
3. Surgical Indra less necessary for GQA models

If GQA does NOT protect:
1. RLHF damage is universal
2. Surgical Indra needed for ALL instruct models
3. Architecture provides no safety buffer

---

*Paper 4: Digital Overcrowding*  
*E04: LLaMA-3.1 Twin Test - GQA RLHF Validation*