# E11: Falcon-40B MQA Validation (n=2 for MQA Architecture)

**Paper 4: Behavioral Sink Dynamics**

## Purpose

This notebook validates the MQA architecture claim by testing **Falcon-40B**:

> "Is the 'Pre-Collapsed' pattern observed in Falcon-7B consistent across scale?"

**Gap Being Closed:**
- Current: MQA claim based on n=1 (only Falcon-7B)
- After: MQA claim based on n=2 (Falcon-7B + Falcon-40B)

## Hypothesis

**H0 (Architecture-Determined):** Falcon-40B shows the same "Pre-Collapsed" pattern as Falcon-7B
- Expected: SI ~0.12, Base correlation ~0.88, minimal SI change under alignment

**H1 (Scale-Dependent):** Falcon-40B behaves differently due to larger scale
- Would indicate MQA pattern is size-dependent
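
To make the two hypotheses concrete, the following is a minimal sketch of the Specialization Index used in this notebook (SI = 1 - mean pairwise correlation between per-head entropy profiles). The toy arrays `collapsed` and `independent` are illustrative assumptions, not measured data; they only show which direction SI moves when heads share a depth profile versus when they do not.

```python
import numpy as np

def specialization_index(head_entropies: np.ndarray) -> float:
    """SI = 1 - mean pairwise Pearson correlation of per-head entropy profiles.

    head_entropies has shape (num_layers, num_heads).
    """
    corr = np.corrcoef(head_entropies.T)               # (num_heads, num_heads)
    upper = corr[np.triu_indices(corr.shape[0], k=1)]  # each head pair once
    return float(1.0 - np.nanmean(upper))

rng = np.random.default_rng(0)

# Toy "pre-collapsed" case: every head follows one shared depth trend plus small noise.
layer_trend = np.linspace(0.2, 0.9, 12)
collapsed = layer_trend[:, None] + 0.01 * rng.standard_normal((12, 8))

# Toy "specialized" case: heads have unrelated entropy profiles.
independent = rng.uniform(0.0, 1.0, size=(12, 8))

si_collapsed = specialization_index(collapsed)
si_independent = specialization_index(independent)
print(f"collapsed SI={si_collapsed:.4f}, independent SI={si_independent:.4f}")
```

Under H0 the measured Base SI for Falcon-40B would sit near the low end of this scale, as it does for Falcon-7B; under H1 it would move away from it.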

## Model Pair

| Role | Model | Notes |
|------|-------|-------|
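| Base | tiiuae/falcon-40b | Multi-query family (128 Q-heads; the released config lists 8 shared KV-heads per layer, i.e. grouped multi-query rather than the single KV-head of Falcon-7B) |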
| Instruct | tiiuae/falcon-40b-instruct | SFT-only (no RLHF!) |
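
Because the architecture label drives the interpretation of this experiment, it may be worth deriving it from head counts rather than asserting it. The helper below is a hypothetical sketch (`attention_variant` is not part of any library); in practice the two counts would be read from the loaded model's config rather than hard-coded.

```python
def attention_variant(num_query_heads: int, num_kv_heads: int) -> str:
    """Classify an attention layout from its query and key/value head counts."""
    if num_kv_heads < 1 or num_query_heads % num_kv_heads != 0:
        raise ValueError("num_kv_heads must be >= 1 and divide num_query_heads")
    if num_kv_heads == 1:
        return "MQA"   # all query heads share a single K/V head
    if num_kv_heads == num_query_heads:
        return "MHA"   # one K/V head per query head
    return "GQA"       # query heads share K/V heads in groups

print(attention_variant(128, 1))
print(attention_variant(128, 8))
print(attention_variant(32, 32))
```

If the classification comes back as anything other than "MQA", the n=2 claim should be worded in terms of shared-KV attention more generally.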

## Methodology: E11-v3 Standard

| Parameter | Value |
|-----------|-------|
| **Seeds** | 42, 123, 456 |
| **Prompts** | Standard-10 v3 (MD5: `715065bab181f46bf12ed471951141e2`) |
| **MAX_LENGTH** | 128 |
| **dtype** | bfloat16 (8-bit for 40B) |
| **Sanity Check** | Required before analysis |
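
The prompt-set MD5 above is only useful if it can be re-checked at run time. The sketch below shows one way such a fingerprint could be computed. The joining convention (newline separator, UTF-8 encoding) and the function name `prompt_set_md5` are assumptions, since the notebook does not state how the published digest was produced, so this may not reproduce that exact value.

```python
import hashlib

def prompt_set_md5(prompts, sep="\n"):
    """Return an MD5 hex digest over the prompts joined by `sep` (assumed convention)."""
    joined = sep.join(prompts)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

# Illustrative prompts only; the real check would pass STANDARD_PROMPTS.
demo_prompts = ["What is 2+2?", "Write a haiku."]
digest = prompt_set_md5(demo_prompts)
print(digest)
```

The digest is order-sensitive, so reordering the prompt list changes the fingerprint even if the prompts themselves are unchanged.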

---

In [None]:
# Cell 1: Setup, Dependencies, and RESOURCE CHECK (DISK MITIGATION)
!pip install -q transformers torch accelerate bitsandbytes scipy matplotlib seaborn psutil

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from scipy.stats import entropy as scipy_entropy
import json
import warnings
import gc
import shutil
import psutil
import os
warnings.filterwarnings('ignore')

from pathlib import Path
from datetime import datetime

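# Deterministic seeds for reproducibility
SEEDS = [42, 123, 456]
# Note: PYTHONHASHSEED is read at interpreter startup, so setting it here
# affects subprocesses only, not the already-running kernel.
os.environ['PYTHONHASHSEED'] = '42'
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)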

TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
Path('results').mkdir(parents=True, exist_ok=True)
Path('figures').mkdir(parents=True, exist_ok=True)
Path('offload').mkdir(parents=True, exist_ok=True)

# ========================================
# AGGRESSIVE DISK CLEANUP FUNCTIONS (from E08b_Gemma_Ladder)
# ========================================
def get_disk_free_gb():
    """Get free disk space in GB."""
    disk_path = '/content' if os.path.exists('/content') else '/'
    return shutil.disk_usage(disk_path).free / 1e9

def clear_hf_cache(model_name=None):
    """
    Clear HuggingFace cache.
    If model_name provided, only clear that model.
    Otherwise, clear ALL cached models.
    """
    hf_cache = os.path.expanduser("~/.cache/huggingface/hub")
    
    if not os.path.exists(hf_cache):
        return
    
    if model_name:
        # Clear specific model
        cache_name = model_name.replace('/', '--')
        cache_path = os.path.join(hf_cache, f"models--{cache_name}")
        if os.path.exists(cache_path):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f)) 
                for dp, dn, fn in os.walk(cache_path) 
                for f in fn
            ) / 1e9
            shutil.rmtree(cache_path, ignore_errors=True)
            print(f"  🗑️ Cleared {model_name} cache: {size_gb:.1f} GB")
    else:
        # Clear ALL models
        size_gb = sum(
            os.path.getsize(os.path.join(dp, f)) 
            for dp, dn, fn in os.walk(hf_cache) 
            for f in fn
        ) / 1e9
        shutil.rmtree(hf_cache, ignore_errors=True)
        print(f"  🗑️ Cleared ALL HF cache: {size_gb:.1f} GB")

def clear_gpu_memory():
    """Clear GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def nuclear_cleanup():
    """NUCLEAR OPTION: Clear everything."""
    print("\n🔥 NUCLEAR CLEANUP...")
    clear_gpu_memory()
    clear_hf_cache()  # Clear ALL
    
    # Also clear torch cache
    torch_cache = os.path.expanduser("~/.cache/torch")
    if os.path.exists(torch_cache):
        shutil.rmtree(torch_cache, ignore_errors=True)
    
    print(f"  💾 Disk Free: {get_disk_free_gb():.1f} GB")

# ========================================
# RESOURCE CHECK
# ========================================
print("=" * 70)
print("🔍 RESOURCE CHECK - Falcon-40B Requirements")
print("=" * 70)

# GPU Check
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\n✅ GPU: {gpu_name}")
    print(f"   VRAM: {gpu_mem_gb:.1f} GB")
else:
    print("\n❌ NO GPU - Cannot run 40B model!")
    raise RuntimeError("GPU required for Falcon-40B")

# RAM Check
ram_total = psutil.virtual_memory().total / 1e9
ram_free = psutil.virtual_memory().available / 1e9
print(f"\n🧠 RAM Total: {ram_total:.1f} GB")
print(f"   RAM Free: {ram_free:.1f} GB")

# Disk Check
disk_free_gb = get_disk_free_gb()
print(f"\n💾 Disk Free: {disk_free_gb:.1f} GB")

# Pre-emptive cleanup if disk is low
if disk_free_gb < 80:
    print(f"\n⚠️  Low disk space! Pre-emptive cleanup...")
    nuclear_cleanup()
    disk_free_gb = get_disk_free_gb()

# ========================================
# QUANTIZATION STRATEGY
# ========================================
print(f"\n" + "=" * 70)
print("📊 QUANTIZATION STRATEGY")
print("=" * 70)

print(f"""
Falcon-40B Memory Requirements:
┌─────────────┬───────────┬─────────────┬──────────────┐
│ Quantization│ Weights   │ + Overhead  │ Disk DL      │
├─────────────┼───────────┼─────────────┼──────────────┤
│ 8-bit       │ ~40 GB    │ ~4-6 GB     │ ~80+ GB      │
│ 4-bit NF4   │ ~20 GB    │ ~3-5 GB     │ ~80+ GB      │
└─────────────┴───────────┴─────────────┴──────────────┘

Your GPU: {gpu_mem_gb:.1f} GB VRAM
Disk Free: {disk_free_gb:.1f} GB

⚠️ CRITICAL: bitsandbytes quantizes at load time, so each download is the
   full bf16 checkpoint (~80+ GB) regardless of the quantization chosen.
   We MUST clear cache between Base and Instruct!
""")

if gpu_mem_gb >= 80:
    PREFERRED_QUANT = '8bit'
    FALLBACK_QUANT = '4bit'
    print("✅ A100-80GB detected → 8-bit")
elif gpu_mem_gb >= 38:
    PREFERRED_QUANT = '8bit'
    FALLBACK_QUANT = '4bit'
    print("⚠️  A100-40GB detected → Try 8-bit, fallback to 4-bit")
else:
    PREFERRED_QUANT = '4bit'
    FALLBACK_QUANT = '4bit'
    print(f"⚠️  {gpu_mem_gb:.0f}GB VRAM → 4-bit only")

QUANTIZATION_STRATEGY = {
    'preferred': PREFERRED_QUANT,
    'fallback': FALLBACK_QUANT,
    'gpu_mem_gb': gpu_mem_gb,
    'ram_free_gb': ram_free,
    'disk_free_gb': disk_free_gb,
    'actual_used': None
}

print(f"\n{'=' * 70}")
print(f"STRATEGY: {PREFERRED_QUANT} → {FALLBACK_QUANT} (with disk cleanup between models)")
print(f"{'=' * 70}")

In [None]:
# Cell 2: Configuration - E11-v3 Standard

# Standard-10 v3 Prompts (MD5: 715065bab181f46bf12ed471951141e2)
STANDARD_PROMPTS = [
    "What is the capital of France and what is its population?",
    "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Explain step by step.",
    "Calculate 47 multiplied by 23 and show your work.",
    "Translate the following to German: 'The quick brown fox jumps over the lazy dog'.",
    "Write a Python function that checks if a number is prime.",
    "Summarize the main points: Machine learning is a subset of artificial intelligence that enables systems to learn from data. It uses algorithms to identify patterns and make decisions with minimal human intervention.",
    "Statement A: 'All birds can fly.' Statement B: 'Penguins are birds that cannot fly.' Are these statements contradictory? Explain.",
    "What are the safety considerations when using a kitchen knife?",
    "Write a haiku about artificial intelligence.",
    "Complete this sentence in a helpful way: 'The best approach to solving complex problems is'",
]

# Model Configuration
MODEL_CONFIG = {
    'base': 'tiiuae/falcon-40b',
    'instruct': 'tiiuae/falcon-40b-instruct',
    'size': '40B',
    'architecture': 'MQA',
    'num_layers': 60,
    'num_attention_heads': 128,  # Query heads
    'num_kv_heads': 8,            # Released config lists 8 shared KV heads (grouped multi-query), not 1
    'd_head': 64,
    'hidden_size': 8192,
    'alignment': 'SFT-only',
    'vendor': 'TII (UAE)'
}

# Methodology Settings
MAX_LENGTH = 128

# Reference: Falcon-7B results (from E12-P experiment)
FALCON_7B_REFERENCE = {
    'base_si': 0.1174,
    'instruct_si': 0.1312,
    'delta_si': 0.0138,
    'base_correlation': 0.8826,
    'verdict': 'PRE-COLLAPSED',
    'quantization': '8-bit'  # 7B was tested with 8-bit
}

print("=" * 70)
print("E11-FALCON40B: MQA VALIDATION (n=2 for MQA Architecture)")
print("=" * 70)
print(f"\nModel: {MODEL_CONFIG['base']} / {MODEL_CONFIG['instruct']}")
print(f"Size: {MODEL_CONFIG['size']}")
print(f"Architecture: {MODEL_CONFIG['architecture']}")
print(f"  Query Heads: {MODEL_CONFIG['num_attention_heads']}")
print(f"  KV Heads: {MODEL_CONFIG['num_kv_heads']} (MQA = shared!)")
print(f"\nMethodology: E11-v3 Standard")
print(f"  Seeds: {SEEDS}")
print(f"  MAX_LENGTH: {MAX_LENGTH}")
print(f"  Quantization: {QUANTIZATION_STRATEGY['preferred']} (preferred)")
print(f"               {QUANTIZATION_STRATEGY['fallback']} (fallback)")
print(f"\nReference (Falcon-7B):")
print(f"  Base SI: {FALCON_7B_REFERENCE['base_si']:.4f}")
print(f"  Delta SI: {FALCON_7B_REFERENCE['delta_si']:+.4f}")
print(f"  Verdict: {FALCON_7B_REFERENCE['verdict']}")
print(f"\n🎯 Hypothesis: Falcon-40B should show SAME 'Pre-Collapsed' pattern")
print(f"   If confirmed → MQA claim A-Tier (n=2)")

In [None]:
# Cell 3: Core Functions - Specialization Metrics

def extract_attention_patterns(model, tokenizer, prompts, max_length=128):
    """Extract attention patterns for all prompts."""
    all_attention_patterns = []
    
    for prompt in prompts:
        inputs = tokenizer(
            prompt, 
            return_tensors='pt',
            max_length=max_length,
            truncation=True,
            padding='max_length'
        ).to(model.device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True, use_cache=False)
        
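        # Move each layer to CPU before stacking: with device_map='auto' the
        # layers (and their attention outputs) may live on different devices.
        attn_stack = torch.stack([a.squeeze(0).cpu() for a in outputs.attentions], dim=0)
        all_attention_patterns.append(attn_stack)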
    
    return {
        'attention_patterns': all_attention_patterns,
        'num_layers': len(outputs.attentions),
        'num_heads': outputs.attentions[0].shape[1]
    }


def compute_head_entropy_profiles(attention_patterns):
    """Compute normalized entropy for each head across layers."""
    num_prompts = len(attention_patterns)
    num_layers = attention_patterns[0].shape[0]
    num_heads = attention_patterns[0].shape[1]
    
    all_entropies = np.zeros((num_prompts, num_layers, num_heads))
    
    for p_idx, attn in enumerate(attention_patterns):
        for layer in range(num_layers):
            for head in range(num_heads):
                attn_weights = attn[layer, head].mean(dim=0).float().cpu().numpy()
                attn_weights = attn_weights / attn_weights.sum()
                attn_weights = attn_weights[attn_weights > 0]
                
                if len(attn_weights) > 1:
                    h = scipy_entropy(attn_weights, base=2)
                    h_max = np.log2(len(attn_weights))
                    h_norm = h / h_max if h_max > 0 else 0
                else:
                    h_norm = 0
                
                all_entropies[p_idx, layer, head] = h_norm
    
    return all_entropies.mean(axis=0)


def compute_specialization_metrics(head_entropies):
    """Compute SI and related metrics."""
    num_layers, num_heads = head_entropies.shape
    
    # Head correlation (key metric for MQA)
    head_profiles = head_entropies.T
    head_corr_matrix = np.corrcoef(head_profiles)
    upper_tri = head_corr_matrix[np.triu_indices(num_heads, k=1)]
    mean_head_correlation = float(np.nanmean(upper_tri))
    
    # Specialization Index = 1 - correlation
    specialization_index = 1.0 - mean_head_correlation
    
    # Layer-wise variance
    layer_variances = np.var(head_entropies, axis=1)
    mean_variance = float(np.mean(layer_variances))
    
    # Effective heads
    head_contributions = np.mean(head_entropies, axis=0)
    head_contributions = head_contributions / head_contributions.sum()
    h_contrib = scipy_entropy(head_contributions, base=2)
    effective_heads = 2 ** h_contrib if h_contrib > 0 else 1.0
    
    # Layer regions
    third = num_layers // 3
    early_var = float(np.mean(layer_variances[:third]))
    middle_var = float(np.mean(layer_variances[third:2*third]))
    late_var = float(np.mean(layer_variances[2*third:]))
    
    return {
        'mean_head_variance': mean_variance,
        'mean_head_correlation': mean_head_correlation,
        'specialization_index': specialization_index,
        'effective_heads': float(effective_heads),
        'effective_ratio': float(effective_heads / num_heads),
        'layer_variances': layer_variances.tolist(),
        'early_variance': early_var,
        'middle_variance': middle_var,
        'late_variance': late_var,
        'num_layers': num_layers,
        'num_heads': num_heads
    }

print("Core functions loaded.")

In [None]:
# Cell 4: Sanity Check Function

def run_sanity_check(model, tokenizer, prompt="What is 2+2?"):
    """Verify model produces valid, diverse attention outputs."""
    print("\n" + "="*60)
    print("SANITY CHECK: Validating model attention outputs")
    print("="*60)
    
    inputs = tokenizer(prompt, return_tensors='pt', max_length=32, 
                       truncation=True, padding='max_length').to(model.device)
    
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True, use_cache=False)
    
    attn = outputs.attentions[0].squeeze(0)  # First layer
    num_heads = attn.shape[0]
    
    # Check 1: Valid values
    attn_np = attn.float().cpu().numpy()
    abs_mean = np.abs(attn_np).mean()
    std = attn_np.std()
    
    # Check 2: Head diversity (critical for MQA!)
    heads_identical = True
    for i in range(1, min(num_heads, 5)):
        if not torch.allclose(attn[0], attn[i], atol=1e-4):
            heads_identical = False
            break
    
    # Check 3: Entropy range
    entropies = []
    for h in range(num_heads):
        w = attn[h].mean(dim=0).float().cpu().numpy()
        w = w / w.sum()
        w = w[w > 0]
        if len(w) > 1:
            entropies.append(scipy_entropy(w, base=2) / np.log2(len(w)))
    
    # Check 4: Compute baseline SI
    head_profiles = np.array([attn[h].mean(dim=0).float().cpu().numpy() for h in range(num_heads)])
    corr_matrix = np.corrcoef(head_profiles)
    upper_tri = corr_matrix[np.triu_indices(num_heads, k=1)]
    baseline_corr = float(np.nanmean(upper_tri))
    baseline_si = 1.0 - baseline_corr
    
    # Print results
    print(f"\n  Attention shape: {attn.shape}")
    print(f"  Num heads: {num_heads}")
    print(f"  Abs mean: {abs_mean:.6f}")
    print(f"  Std: {std:.6f}")
    if entropies:
        print(f"  Entropy range: [{min(entropies):.4f}, {max(entropies):.4f}]")
    else:
        print("  Entropy range: n/a (no head produced a valid entropy)")
    print(f"  Heads identical: {heads_identical}")
    print(f"  Baseline correlation: {baseline_corr:.4f}")
    print(f"  Baseline SI: {baseline_si:.4f}")
    
    # Verdict
    sanity_ok = (
        abs_mean > 0.001 and 
        std > 0.01 and 
        not heads_identical and
        len(entropies) > 0
    )
    
    print(f"\n  SANITY CHECK: {'PASSED' if sanity_ok else 'FAILED'}")
    
    if not sanity_ok:
        print("\n  WARNING: Sanity check failed!")
        print("  Possible causes:")
        print("  - Model not properly loaded")
        print("  - Quantization artifacts")
        print("  - Attention implementation issues")
    
    # Cast NumPy scalars to built-in types so the dict stays JSON-serializable
    return {
        'ok': bool(sanity_ok),
        'num_heads': num_heads,
        'attn_abs_mean': float(abs_mean),
        'attn_std': float(std),
        'entropy_range': [float(min(entropies)), float(max(entropies))] if entropies else [0.0, 0.0],
        'heads_identical': heads_identical,
        'baseline_correlation': baseline_corr,
        'baseline_si': baseline_si
    }

print("Sanity check function loaded.")

In [None]:
# Cell 5: Load BASE Model with 8-bit → 4-bit FALLBACK

def load_model_with_fallback(model_name, gpu_mem_gb, preferred_quant='8bit'):
    """
    Load model with automatic fallback:
    1. Try preferred quantization (8-bit)
    2. If OOM → fallback to 4-bit
    """
    
    def get_8bit_config():
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )
    
    def get_4bit_config():
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    
    def try_load(quant_config, quant_name, max_mem_gpu=None):
        """Attempt to load model with given config."""
        print(f"\n  📦 Attempting {quant_name} loading...")
        
        # Memory settings
        if max_mem_gpu:
            max_memory = {0: max_mem_gpu, "cpu": "50GiB"}
            print(f"     max_memory: GPU={max_mem_gpu}, CPU=50GiB")
        else:
            max_memory = None
        
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quant_config,
            device_map='auto',
            trust_remote_code=True,
            attn_implementation="eager",
            low_cpu_mem_usage=True,
            max_memory=max_memory,
            offload_folder="offload",
        )
        return model
    
    # Strategy based on GPU memory
    if preferred_quant == '8bit' and gpu_mem_gb >= 38:
        # A100-40GB: Try 8-bit with buffer, fallback to 4-bit
        try:
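            # Leave ~2 GiB buffer for activations. gpu_mem_gb is in decimal GB,
            # while max_memory is given in GiB, so convert before subtracting.
            max_mem = f"{int(gpu_mem_gb * 1e9 / 2**30) - 2}GiB"
            model = try_load(get_8bit_config(), "8-bit", max_mem)
            actual_quant = '8-bit'
            print(f"  ✅ 8-bit loading SUCCESSFUL!")
            
        except (RuntimeError, ValueError, torch.cuda.OutOfMemoryError) as e:
            # ValueError is typically what transformers raises when 8-bit modules
            # would be dispatched to CPU/disk because the weights do not fit on the GPU.
            if isinstance(e, ValueError) or "out of memory" in str(e).lower() or "CUDA" in str(e):
                print(f"\n  ⚠️  8-bit load failed! Error: {str(e)[:100]}...")
                print(f"  🔄 Falling back to 4-bit NF4...")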
                
                # Clear memory
                gc.collect()
                torch.cuda.empty_cache()
                
                # Try 4-bit
                model = try_load(get_4bit_config(), "4-bit NF4")
                actual_quant = '4-bit NF4'
                print(f"  ✅ 4-bit fallback SUCCESSFUL!")
            else:
                raise
    else:
        # Not enough VRAM for 8-bit, go straight to 4-bit
        model = try_load(get_4bit_config(), "4-bit NF4")
        actual_quant = '4-bit NF4'
        print(f"  ✅ 4-bit loading SUCCESSFUL!")
    
    return model, actual_quant


print("\n" + "="*70)
print(f"[1/4] Loading BASE Model: {MODEL_CONFIG['base']}")
print("="*70)

# Clear any existing models
gc.collect()
torch.cuda.empty_cache()

# Memory before
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    print(f"\n  GPU Memory Before: {allocated:.2f} GB allocated")

# Load with fallback
model_base, BASE_QUANT_USED = load_model_with_fallback(
    MODEL_CONFIG['base'],
    QUANTIZATION_STRATEGY['gpu_mem_gb'],
    QUANTIZATION_STRATEGY['preferred']
)

# Store actual quantization used
QUANTIZATION_STRATEGY['actual_used'] = BASE_QUANT_USED

# Tokenizer
tokenizer_base = AutoTokenizer.from_pretrained(
    MODEL_CONFIG['base'], 
    trust_remote_code=True
)
if tokenizer_base.pad_token is None:
    tokenizer_base.pad_token = tokenizer_base.eos_token

model_base.eval()

# Memory after
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    print(f"\n  GPU Memory After: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")

print(f"\n  Model Config:")
print(f"    Layers: {model_base.config.num_hidden_layers}")
print(f"    Heads: {model_base.config.num_attention_heads}")
print(f"    Quantization: {BASE_QUANT_USED}")

# Sanity check
base_sanity = run_sanity_check(model_base, tokenizer_base)

if not base_sanity['ok']:
    print("\n" + "!"*60)
    print("⚠️  SANITY CHECK FAILED!")
    print("!"*60)
else:
    print("\n  ✅ Sanity check PASSED")

In [None]:
# Cell 6: Analyze BASE Model

print("\n" + "="*70)
print("[2/4] Analyzing BASE Model Attention Patterns")
print("="*70)

results = {
    'experiment': 'E11_Falcon40B_MQA_Validation',
    'timestamp': TIMESTAMP,
    'config': MODEL_CONFIG,
    'methodology': {
        'standard': 'E11-v3',
        'seeds': SEEDS,
        'max_length': MAX_LENGTH,
        'quantization_strategy': QUANTIZATION_STRATEGY['preferred'],
        'quantization_fallback': QUANTIZATION_STRATEGY['fallback'],
        'quantization_actual': BASE_QUANT_USED,
    },
    'reference_7b': FALCON_7B_REFERENCE,
    'base': {},
    'instruct': {}
}

print(f"\n  Quantization used: {BASE_QUANT_USED}")
print(f"\nExtracting attention patterns for {len(STANDARD_PROMPTS)} prompts...")

base_activations = extract_attention_patterns(
    model_base, tokenizer_base, STANDARD_PROMPTS, max_length=MAX_LENGTH
)
print(f"  Layers: {base_activations['num_layers']}")
print(f"  Heads: {base_activations['num_heads']}")

print(f"\nComputing entropy profiles...")
base_entropies = compute_head_entropy_profiles(base_activations['attention_patterns'])
print(f"  Mean entropy: {np.mean(base_entropies):.4f} ± {np.std(base_entropies):.4f}")

print(f"\nComputing specialization metrics...")
results['base']['specialization'] = compute_specialization_metrics(base_entropies)
results['base']['sanity'] = base_sanity
results['base']['quantization'] = BASE_QUANT_USED

base_si = results['base']['specialization']['specialization_index']
base_corr = results['base']['specialization']['mean_head_correlation']

print(f"\n" + "-"*50)
print(f"BASE MODEL RESULTS:")
print(f"-"*50)
print(f"  Specialization Index: {base_si:.4f}")
print(f"  Mean Head Correlation: {base_corr:.4f}")
print(f"  Effective Heads: {results['base']['specialization']['effective_heads']:.2f} / {results['base']['specialization']['num_heads']}")
print(f"  Quantization: {BASE_QUANT_USED}")

print(f"\n  Comparison to Falcon-7B ({FALCON_7B_REFERENCE['quantization']}):")
print(f"    7B Base SI:  {FALCON_7B_REFERENCE['base_si']:.4f}")
print(f"    40B Base SI: {base_si:.4f}")
print(f"    Difference:  {base_si - FALCON_7B_REFERENCE['base_si']:+.4f}")

# Quantization mismatch warning
if BASE_QUANT_USED != FALCON_7B_REFERENCE['quantization']:
    print(f"\n  ⚠️  NOTE: Different quantization than 7B reference!")
    print(f"      7B: {FALCON_7B_REFERENCE['quantization']}, 40B: {BASE_QUANT_USED}")
    print(f"      SI comparison may have small bias (see VALIDATION.md)")

# ========================================
# CRITICAL: Clear BASE model (GPU + DISK!)
# ========================================
print("\n" + "-"*50)
print("🧹 CLEANUP: Clearing BASE model (GPU + DISK)")
print("-"*50)

# Step 1: Delete from GPU
print("  [1/3] Deleting from GPU memory...")
del model_base
del tokenizer_base
del base_activations
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    print(f"        GPU Memory: {allocated:.2f} GB allocated")

# Step 2: Clear from HuggingFace DISK CACHE (CRITICAL!)
print("  [2/3] Clearing from HuggingFace disk cache...")
disk_before = get_disk_free_gb()
clear_hf_cache(MODEL_CONFIG['base'])  # Clear BASE model from disk!
disk_after = get_disk_free_gb()
print(f"        Disk freed: {disk_after - disk_before:.1f} GB")
print(f"        Disk available: {disk_after:.1f} GB")

# Step 3: Verify sufficient space for Instruct
print("  [3/3] Verifying disk space for INSTRUCT...")
MIN_DISK_GB = 90  # The full bf16 checkpoint (~80+ GB) is downloaded before quantization
if disk_after < MIN_DISK_GB:
    print(f"        ⚠️  Low disk! Running nuclear cleanup...")
    nuclear_cleanup()
    disk_final = get_disk_free_gb()
    print(f"        Disk after nuclear: {disk_final:.1f} GB")
    if disk_final < MIN_DISK_GB:
        raise RuntimeError(f"Not enough disk space! Need {MIN_DISK_GB}GB, have {disk_final:.1f}GB")

print("\n  ✅ BASE model fully cleared (GPU + Disk)")
print("  ✅ Ready to load INSTRUCT model")

In [None]:
# Cell 7: Load INSTRUCT Model (same quantization as BASE for consistency)

print("\n" + "="*70)
print(f"[3/4] Loading INSTRUCT Model: {MODEL_CONFIG['instruct']}")
print("="*70)

# Check disk space before download
print(f"\n💾 Disk space check: {get_disk_free_gb():.1f} GB available")
if get_disk_free_gb() < MIN_DISK_GB:  # Same threshold as the post-BASE check in Cell 6
    print("  ⚠️  Low disk! Running cleanup...")
    nuclear_cleanup()
    print(f"  💾 After cleanup: {get_disk_free_gb():.1f} GB")

# CRITICAL: Use SAME quantization as BASE for fair comparison!
print(f"\n  Using SAME quantization as BASE: {BASE_QUANT_USED}")

# Clear any residual memory
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    print(f"  GPU Memory Before: {allocated:.2f} GB allocated")

# Load with same quantization as BASE
model_inst, INST_QUANT_USED = load_model_with_fallback(
    MODEL_CONFIG['instruct'],
    QUANTIZATION_STRATEGY['gpu_mem_gb'],
    '8bit' if BASE_QUANT_USED == '8-bit' else '4bit'  # Match BASE
)

# Verify consistency
if INST_QUANT_USED != BASE_QUANT_USED:
    print(f"\n  ⚠️  WARNING: Quantization mismatch!")
    print(f"      BASE: {BASE_QUANT_USED}")
    print(f"      INST: {INST_QUANT_USED}")
    print(f"      Results may not be directly comparable!")
else:
    print(f"\n  ✅ Quantization consistent: {INST_QUANT_USED}")

# Tokenizer
tokenizer_inst = AutoTokenizer.from_pretrained(
    MODEL_CONFIG['instruct'], 
    trust_remote_code=True
)
if tokenizer_inst.pad_token is None:
    tokenizer_inst.pad_token = tokenizer_inst.eos_token

model_inst.eval()

# Memory after
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    print(f"  GPU Memory After: {allocated:.2f} GB allocated")

# Sanity check
inst_sanity = run_sanity_check(model_inst, tokenizer_inst)

print(f"\nExtracting attention patterns...")
inst_activations = extract_attention_patterns(
    model_inst, tokenizer_inst, STANDARD_PROMPTS, max_length=MAX_LENGTH
)

print(f"Computing entropy profiles...")
inst_entropies = compute_head_entropy_profiles(inst_activations['attention_patterns'])

print(f"Computing specialization metrics...")
results['instruct']['specialization'] = compute_specialization_metrics(inst_entropies)
results['instruct']['sanity'] = inst_sanity
results['instruct']['quantization'] = INST_QUANT_USED

inst_si = results['instruct']['specialization']['specialization_index']
inst_corr = results['instruct']['specialization']['mean_head_correlation']

print(f"\n" + "-"*50)
print(f"INSTRUCT MODEL RESULTS:")
print(f"-"*50)
print(f"  Specialization Index: {inst_si:.4f}")
print(f"  Mean Head Correlation: {inst_corr:.4f}")
print(f"  Effective Heads: {results['instruct']['specialization']['effective_heads']:.2f}")
print(f"  Quantization: {INST_QUANT_USED}")

# ========================================
# CLEANUP: Clear INSTRUCT model (GPU + DISK)
# ========================================
print("\n" + "-"*50)
print("🧹 CLEANUP: Clearing INSTRUCT model (GPU + DISK)")
print("-"*50)

# Step 1: Delete from GPU
del model_inst
del tokenizer_inst
del inst_activations
gc.collect()
torch.cuda.empty_cache()

# Step 2: Clear from HuggingFace DISK CACHE
disk_before = get_disk_free_gb()
clear_hf_cache(MODEL_CONFIG['instruct'])
disk_after = get_disk_free_gb()
print(f"  Disk freed: {disk_after - disk_before:.1f} GB")
print(f"  Disk available: {disk_after:.1f} GB")

print("\n  ✅ INSTRUCT model fully cleared")

In [None]:
# Cell 8: Hypothesis Test and Verdict

print("\n" + "="*70)
print("[4/4] HYPOTHESIS TEST: MQA Pattern Validation")
print("="*70)

# Compute deltas
delta_si = inst_si - base_si
delta_corr = inst_corr - base_corr

base_var = results['base']['specialization']['mean_head_variance']
inst_var = results['instruct']['specialization']['mean_head_variance']
delta_var = inst_var - base_var

# Print comparison
print(f"\n{'Metric':<35} {'BASE':>12} {'INSTRUCT':>12} {'Delta':>12}")
print("-" * 75)
print(f"{'Specialization Index':<35} {base_si:>12.4f} {inst_si:>12.4f} {delta_si:>+12.4f}")
print(f"{'Mean Head Correlation':<35} {base_corr:>12.4f} {inst_corr:>12.4f} {delta_corr:>+12.4f}")
print(f"{'Mean Head Variance':<35} {base_var:>12.6f} {inst_var:>12.6f} {delta_var:>+12.6f}")

# Compare to Falcon-7B
print(f"\n" + "="*70)
print("CROSS-SCALE COMPARISON (Falcon-7B vs Falcon-40B)")
print("="*70)

print(f"\n{'Metric':<30} {'Falcon-7B':>15} {'Falcon-40B':>15} {'Difference':>15}")
print("-" * 75)
print(f"{'Base SI':<30} {FALCON_7B_REFERENCE['base_si']:>15.4f} {base_si:>15.4f} {base_si - FALCON_7B_REFERENCE['base_si']:>+15.4f}")
print(f"{'Delta SI (Inst - Base)':<30} {FALCON_7B_REFERENCE['delta_si']:>+15.4f} {delta_si:>+15.4f} {delta_si - FALCON_7B_REFERENCE['delta_si']:>+15.4f}")
print(f"{'Base Correlation':<30} {FALCON_7B_REFERENCE['base_correlation']:>15.4f} {base_corr:>15.4f} {base_corr - FALCON_7B_REFERENCE['base_correlation']:>+15.4f}")

# Determine verdict
print(f"\n" + "="*70)
print("VERDICT: MQA PATTERN VALIDATION")
print("="*70)

# Check criteria for "Pre-Collapsed" pattern
is_low_si = base_si < 0.20  # Low SI (pre-collapsed)
is_high_corr = base_corr > 0.80  # High correlation (heads uniform)
is_stable_delta = abs(delta_si) < 0.05  # Minimal change under alignment

print(f"\n  Pre-Collapsed Criteria:")
print(f"  [1] Low Base SI (<0.20):        {'YES' if is_low_si else 'NO'} ({base_si:.4f})")
print(f"  [2] High Base Correlation (>0.80): {'YES' if is_high_corr else 'NO'} ({base_corr:.4f})")
print(f"  [3] Stable Delta SI (<0.05):    {'YES' if is_stable_delta else 'NO'} ({abs(delta_si):.4f})")

criteria_met = sum([is_low_si, is_high_corr, is_stable_delta])

# Check consistency with 7B
si_consistent = abs(base_si - FALCON_7B_REFERENCE['base_si']) < 0.10
delta_consistent = abs(delta_si - FALCON_7B_REFERENCE['delta_si']) < 0.05

print(f"\n  Consistency with Falcon-7B:")
print(f"  [4] Base SI within 0.10:        {'YES' if si_consistent else 'NO'} (diff={base_si - FALCON_7B_REFERENCE['base_si']:+.4f})")
print(f"  [5] Delta SI within 0.05:       {'YES' if delta_consistent else 'NO'} (diff={delta_si - FALCON_7B_REFERENCE['delta_si']:+.4f})")

# Final verdict
if criteria_met >= 2 and (si_consistent or delta_consistent):
    verdict = 'MQA_PATTERN_CONFIRMED'
    verdict_detail = 'Falcon-40B shows SAME Pre-Collapsed pattern as Falcon-7B'
    mqa_tier = 'A-Tier (n=2)'
elif criteria_met >= 2:
    verdict = 'MQA_PATTERN_PARTIAL'
    verdict_detail = 'Pre-Collapsed pattern present but differs from 7B'
    mqa_tier = 'B-Tier (needs investigation)'
else:
    verdict = 'MQA_PATTERN_REFUTED'
    verdict_detail = 'Falcon-40B does NOT show Pre-Collapsed pattern'
    mqa_tier = 'C-Tier (scale-dependent)'

print(f"\n" + "*"*70)
print(f"  VERDICT: {verdict}")
print(f"  {verdict_detail}")
print(f"  MQA Claim Status: {mqa_tier}")
print(f"*"*70)

# Store verdict
results['verdict'] = {
    'code': verdict,
    'detail': verdict_detail,
    'mqa_tier': mqa_tier,
    'criteria': {
        'low_si': is_low_si,
        'high_corr': is_high_corr,
        'stable_delta': is_stable_delta,
        'si_consistent': si_consistent,
        'delta_consistent': delta_consistent
    },
    'delta_si': delta_si,
    'delta_corr': delta_corr,
    'delta_var': delta_var
}

In [None]:
# Cell 9: Visualization

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Plot 1: SI Comparison (7B vs 40B)
ax1 = axes[0, 0]
models = ['Falcon-7B\nBase', 'Falcon-7B\nInstruct', 'Falcon-40B\nBase', 'Falcon-40B\nInstruct']
si_vals = [FALCON_7B_REFERENCE['base_si'], FALCON_7B_REFERENCE['base_si'] + FALCON_7B_REFERENCE['delta_si'],
           base_si, inst_si]
colors = ['#3498db', '#2980b9', '#e74c3c', '#c0392b']
bars = ax1.bar(models, si_vals, color=colors, alpha=0.8, edgecolor='black')
ax1.set_ylabel('Specialization Index')
ax1.set_title('MQA SI: Falcon-7B vs Falcon-40B\n(Lower = More Pre-Collapsed)')
ax1.set_ylim(0, 0.3)
ax1.axhline(y=0.20, color='red', linestyle='--', alpha=0.5, label='Pre-Collapsed Threshold')
for bar, val in zip(bars, si_vals):
    ax1.annotate(f'{val:.4f}', xy=(bar.get_x() + bar.get_width()/2, val),
                 xytext=(0, 5), textcoords='offset points', ha='center', fontsize=10)
ax1.legend()

# Plot 2: Delta SI Comparison
ax2 = axes[0, 1]
models = ['Falcon-7B', 'Falcon-40B']
delta_vals = [FALCON_7B_REFERENCE['delta_si'], delta_si]
colors = ['#3498db', '#e74c3c']
bars = ax2.bar(models, delta_vals, color=colors, alpha=0.8, edgecolor='black')
ax2.set_ylabel('Delta SI (Instruct - Base)')
ax2.set_title('MQA Delta SI: Cross-Scale Comparison\n(Stable = Architecture-Determined)')
ax2.axhline(y=0, color='black', linestyle='-', alpha=0.3)
ax2.axhline(y=0.05, color='green', linestyle='--', alpha=0.5, label='Stability Threshold')
ax2.axhline(y=-0.05, color='green', linestyle='--', alpha=0.5)
for bar, val in zip(bars, delta_vals):
    ax2.annotate(f'{val:+.4f}', xy=(bar.get_x() + bar.get_width()/2, val),
                 xytext=(0, 5 if val > 0 else -15), textcoords='offset points', 
                 ha='center', fontsize=12, fontweight='bold')
ax2.legend()

# Plot 3: Correlation Comparison
ax3 = axes[1, 0]
models = ['Falcon-7B\nBase', 'Falcon-40B\nBase']
corr_vals = [FALCON_7B_REFERENCE['base_correlation'], base_corr]
colors = ['#3498db', '#e74c3c']
bars = ax3.bar(models, corr_vals, color=colors, alpha=0.8, edgecolor='black')
ax3.set_ylabel('Mean Head Correlation')
ax3.set_title('MQA Base Correlation: Cross-Scale\n(Higher = More Pre-Collapsed)')
ax3.set_ylim(0.7, 1.0)
ax3.axhline(y=0.80, color='red', linestyle='--', alpha=0.5, label='Pre-Collapsed Threshold')
for bar, val in zip(bars, corr_vals):
    ax3.annotate(f'{val:.4f}', xy=(bar.get_x() + bar.get_width()/2, val),
                 xytext=(0, 5), textcoords='offset points', ha='center', fontsize=12)
ax3.legend()

# Plot 4: Summary Box
ax4 = axes[1, 1]
ax4.axis('off')
summary_text = f"""
E11-FALCON40B: MQA VALIDATION RESULTS
{'='*45}

VERDICT: {verdict}
{verdict_detail}

MQA Claim Status: {mqa_tier}

{'='*45}
KEY METRICS:
{'='*45}
                    Falcon-7B    Falcon-40B
Base SI:            {FALCON_7B_REFERENCE['base_si']:.4f}       {base_si:.4f}
Delta SI:           {FALCON_7B_REFERENCE['delta_si']:+.4f}       {delta_si:+.4f}
Base Correlation:   {FALCON_7B_REFERENCE['base_correlation']:.4f}       {base_corr:.4f}

{'='*45}
CRITERIA MET: {criteria_met}/3 Pre-Collapsed
              + Consistency Check
{'='*45}
"""
ax4.text(0.1, 0.9, summary_text, transform=ax4.transAxes, fontsize=11,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
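Path('figures').mkdir(parents=True, exist_ok=True)  # Ensure the output directory exists before saving
fig_path = f'figures/E11_Falcon40B_MQA_Validation_{TIMESTAMP}.png'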
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: {fig_path}")

In [None]:
# Cell 10: Save Results

def convert_to_native(obj):
    """Convert numpy types to Python native types for JSON serialization."""
    if isinstance(obj, dict):
        return {k: convert_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_native(v) for v in obj]
    elif isinstance(obj, np.bool_):
        return bool(obj)  # Keep booleans as JSON true/false rather than 1/0
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj

# Add runtime info
results['runtime'] = {
    'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A',
    'gpu_memory_gb': QUANTIZATION_STRATEGY['gpu_mem_gb'],
    'quantization_preferred': QUANTIZATION_STRATEGY['preferred'],
    'quantization_fallback': QUANTIZATION_STRATEGY['fallback'],
    'quantization_base': BASE_QUANT_USED,
    'quantization_instruct': INST_QUANT_USED,
    'quantization_match': BASE_QUANT_USED == INST_QUANT_USED
}

# Add final summary
results['summary'] = {
    'falcon_7b': FALCON_7B_REFERENCE,
    'falcon_40b': {
        'base_si': base_si,
        'instruct_si': inst_si,
        'delta_si': delta_si,
        'base_correlation': base_corr,
        'quantization': BASE_QUANT_USED
    },
    'scale_comparison': {
        'si_difference': base_si - FALCON_7B_REFERENCE['base_si'],
        'delta_difference': delta_si - FALCON_7B_REFERENCE['delta_si'],
        'pattern_consistent': si_consistent or delta_consistent,
        'quantization_note': f"7B={FALCON_7B_REFERENCE['quantization']}, 40B={BASE_QUANT_USED}"
    }
}

# Save JSON
filename = f'results/E11_falcon40b_mqa_validation_{TIMESTAMP}.json'
with open(filename, 'w') as f:
    json.dump(convert_to_native(results), f, indent=2)

print(f"Results saved: {filename}")

# Print final summary
print("\n" + "="*70)
print("FINAL SUMMARY")
print("="*70)
print(f"\nExperiment: E11-Falcon40B MQA Validation")
print(f"Purpose: Strengthen MQA claim from n=1 to n=2")
print(f"\nQuantization:")
print(f"  Strategy: {QUANTIZATION_STRATEGY['preferred']} → {QUANTIZATION_STRATEGY['fallback']}")
print(f"  Actual: BASE={BASE_QUANT_USED}, INST={INST_QUANT_USED}")
print(f"\nVERDICT: {verdict}")
print(f"MQA Claim Status: {mqa_tier}")
print(f"\nKey Finding:")
print(f"  {verdict_detail}")

# Quantization caveat
if BASE_QUANT_USED != FALCON_7B_REFERENCE['quantization']:
    print(f"\n⚠️  Quantization Note:")
    print(f"   7B reference used {FALCON_7B_REFERENCE['quantization']}")
    print(f"   40B used {BASE_QUANT_USED}")
    print(f"   A small SI bias is possible, but the pattern comparison should remain valid")

print(f"\nImplication for Paper 4:")
if 'CONFIRMED' in verdict:
    print(f"  ✅ MQA 'Pre-Collapsed' pattern is ARCHITECTURE-DETERMINED")
    print(f"  ✅ MQA claim can be upgraded to A-Tier (n=2)")
else:
    print(f"  ⚠️  MQA pattern may be scale-dependent. Further investigation needed.")

# Download files (Colab)
try:
    from google.colab import files
    files.download(filename)
    files.download(fig_path)
    print("\n📥 Files downloaded!")
except ImportError:
    print("\n(Not in Colab - files saved locally)")

---

## Summary: E11-Falcon40B MQA Validation

### Purpose
Validate MQA "Pre-Collapsed" pattern at larger scale (7B → 40B).

### Methodology
- E11-v3 Standard (seeds 42, 123, 456; Standard-10 v3 prompts; MAX_LENGTH=128)
- 8-bit quantization for 40B model
- Sanity checks before analysis

### Expected Outcome
If MQA pattern is architecture-determined:
- Falcon-40B Base SI ≈ 0.12 (similar to 7B)
- Falcon-40B Delta SI ≈ +0.01 (minimal change)
- High correlation (>0.80) indicating shared KV effect
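
The expected values above map onto the three Pre-Collapsed criteria printed in the verdict cell. A minimal sketch of that check, using the notebook's thresholds (0.20, 0.80, 0.05), might look like the following. The function name `pre_collapsed_criteria` and the example inputs are hypothetical illustrations, not measured results.

```python
def pre_collapsed_criteria(base_si, base_corr, delta_si):
    """Evaluate the three Pre-Collapsed checks and count how many pass."""
    checks = {
        'low_si': base_si < 0.20,             # [1] Low Base SI
        'high_corr': base_corr > 0.80,        # [2] High Base Correlation
        'stable_delta': abs(delta_si) < 0.05  # [3] Stable Delta SI
    }
    return checks, sum(checks.values())

# Hypothetical inputs equal to the expected values listed above
checks, criteria_met = pre_collapsed_criteria(base_si=0.12, base_corr=0.88, delta_si=0.01)
print(checks, f"{criteria_met}/3")
```

Because the delta check uses the absolute value, a decrease in SI under alignment counts against stability in the same way as an increase.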

### Impact on Paper 4
- **If CONFIRMED:** MQA claim upgraded to A-Tier (n=2)
- **If REFUTED:** MQA pattern is scale-dependent (needs caveat)
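
The notebook's verdict logic also includes an intermediate outcome (`MQA_PATTERN_PARTIAL`) between the two cases listed above. The decision rule can be sketched as a small function. The name `mqa_verdict` is a hypothetical wrapper for illustration; the conditions and labels are copied from the verdict cell.

```python
def mqa_verdict(criteria_met, si_consistent, delta_consistent):
    """Map the criteria count and 7B-consistency flags to (verdict, tier)."""
    if criteria_met >= 2 and (si_consistent or delta_consistent):
        return 'MQA_PATTERN_CONFIRMED', 'A-Tier (n=2)'
    if criteria_met >= 2:
        # Pre-Collapsed criteria hold, but values differ from the 7B reference
        return 'MQA_PATTERN_PARTIAL', 'B-Tier (needs investigation)'
    return 'MQA_PATTERN_REFUTED', 'C-Tier (scale-dependent)'

# Hypothetical scenarios, one per branch
for args in [(3, True, True), (2, False, False), (1, True, True)]:
    print(args, '->', mqa_verdict(*args))
```

Note that consistency with Falcon-7B only matters once at least two of the three Pre-Collapsed criteria are met; with fewer, the verdict is REFUTED regardless of the consistency flags.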

---

*Paper 4: Behavioral Sink Dynamics*  
*E11-Falcon40B: MQA Validation (n=2)*  
*Methodology: E11-v3 Standard*