# RLHF Safety Brake Test - Base vs Instruct

**Purpose:** Test Gemini's hypothesis that RLHF causes last-layer contraction

**Hypothesis:**
- LLaMA-3.1-8B **Base** (NO RLHF) → EXPANDS like Mistral (~1.37x)
- LLaMA-3.1-8B **Instruct** (WITH RLHF) → CONTRACTS (~0.48x)

**If confirmed:** RLHF acts, in thermodynamic terms, as a "Safety Brake": it dampens energy at the last layer.

**Evidence from Paper #1 Archive:**
```
META (Llama):
├── Normalizes EVERYTHING
├── All statements ~0.58-0.68 (narrow range)
└── "Normalization" of all positions
```

**Connection across 3 Papers:**
- Paper #1: LLaMA shows highest uniformity (flattening)
- Paper #2: Chat templates cause uniformly negative correlations
- Paper #3: RLHF causes last-layer contraction (this test!)

In [None]:
# Cell 1: Setup and Imports
import torch
import numpy as np
import json
import gc
from datetime import datetime
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Colab setup: probe for the Colab helper package to detect the runtime.
IN_COLAB = False
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    # Package not importable -> not running in Colab; IN_COLAB stays False.
    # Only ImportError is caught so Ctrl-C / SystemExit are not swallowed.
    pass

# HuggingFace Login: read the token from Colab's secret store and log in.
# Any failure (including the Colab import failing outside Colab) is reported
# and the notebook continues without an explicit login.
try:
    from huggingface_hub import login
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("✅ HuggingFace login successful!")
except Exception as e:
    print(f"⚠️ HF login: {e}")

# GPU Check: report the PyTorch build and, if present, the first CUDA device.
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: RLHF Test Models Definition

# The two checkpoints under comparison. Both entries carry the same
# architecture metadata; they differ in the Hub repo id, the RLHF flag and
# the outcome the hypothesis predicts.
RLHF_TEST_MODELS = {
    'llama3.1-8b-BASE': dict(
        hf_name='meta-llama/Llama-3.1-8B',
        layers=32,
        family='llama',
        norm='RMSNorm',
        mlp='SwiGLU',
        rlhf=False,
        expected='EXPANSION (~1.37x like Mistral)',
    ),
    'llama3.1-8b-INSTRUCT': dict(
        hf_name='meta-llama/Llama-3.1-8B-Instruct',
        layers=32,
        family='llama',
        norm='RMSNorm',
        mlp='SwiGLU',
        rlhf=True,
        expected='CONTRACTION (~0.48x)',
    ),
}

# Reference values used for comparison in the later cells.
MISTRAL_REFERENCE = dict(last_layer_gain=1.37, initial_explosion=43.86)
PREVIOUS_LLAMA_RESULT = dict(last_layer_gain=0.48, note='From 4-model validation')

# Banner and heading, emitted as one call with newline separators.
print("=" * 60, "RLHF SAFETY BRAKE TEST", "=" * 60, "\nModels to test:", sep="\n")

# One summary paragraph per model, followed by a blank line.
for name, info in RLHF_TEST_MODELS.items():
    if info['rlhf']:
        rlhf = "WITH RLHF"
    else:
        rlhf = "NO RLHF"
    print(
        f"  {name}:",
        f"    → {info['hf_name']}",
        f"    → {rlhf}",
        f"    → Expected: {info['expected']}",
        "",
        sep="\n",
    )

In [None]:
# Cell 3: Model Loading
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_key, model_info):
    """Load a causal-LM checkpoint and its tokenizer, reporting failures.

    Args:
        model_key: Short label, used only in the progress messages.
        model_info: Mapping that must contain ``'hf_name'`` (the repo id
            passed to ``from_pretrained``).

    Returns:
        ``(model, tokenizer, dtype)`` on success, with the model switched to
        eval mode; ``(None, None, None)`` if loading raised (the error is
        printed, not re-raised).

    Raises:
        KeyError: If ``model_info`` has no ``'hf_name'`` entry.
    """
    hf_name = model_info['hf_name']
    print(f"\nLoading {model_key} ({hf_name})...")

    # Ask about bf16 only when a CUDA device is present. On CPU-only hosts
    # some PyTorch releases initialise CUDA inside is_bf16_supported() and
    # raise, which would escape the error handling below.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    try:
        # NOTE(review): trust_remote_code=True allows repo-supplied Python to
        # run; confirm it is actually required for these checkpoints.
        tokenizer = AutoTokenizer.from_pretrained(hf_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Reuse EOS as the padding token when none is defined.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            hf_name,
            torch_dtype=dtype,
            device_map='auto',
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        model.eval()

        n_layers = len(model.model.layers)
        print(f"  ✅ Loaded! {n_layers} layers, dtype={dtype}")
        return model, tokenizer, dtype

    except Exception as e:
        # Best-effort loader: callers check for None and skip the model.
        print(f"  ❌ FAILED: {e}")
        return None, None, None

def cleanup_model(model):
    """Drop this function's handle on ``model`` and flush memory caches.

    Runs the garbage collector and, when CUDA is available, empties the
    PyTorch CUDA cache. Only the local reference is released here; the
    object stays alive for as long as the caller still references it.
    """
    # Unbinding the parameter is equivalent whether or not it is None.
    del model
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    print("  Memory cleaned.")

In [None]:
# Cell 4: Residual Stream Analyzer

class ResidualStreamAnalyzer:
    """Record output norms via forward hooks and derive per-layer gains.

    One hook on ``model.model.embed_tokens`` stores the norm of the embedding
    output; one hook per entry of ``model.model.layers`` stores the norm of
    that layer's output. Every norm is taken over the whole output tensor
    after casting it with ``.float()``.
    """

    def __init__(self, model, model_info):
        self.model = model
        self.model_info = model_info
        self.hooks = []  # handles returned by register_forward_hook
        self.residual_norms = []  # (layer_idx, norm) tuples of the latest pass
        self.embedding_norm = None  # filled in by the embedding hook

    def _make_embedding_hook(self):
        """Return a hook that records the embedding norm and starts a new pass."""
        def hook(module, args, output):
            with torch.no_grad():
                self.embedding_norm = output.float().norm().item()
            # Start a fresh record so repeated forward passes do not pile up
            # duplicate per-layer entries. Assumes the embedding module runs
            # before the decoder layers in every forward pass.
            self.residual_norms = []
        return hook

    def _make_layer_hook(self, layer_idx):
        """Return a hook that appends ``(layer_idx, norm)`` for one layer."""
        def hook(module, args, output):
            # Layers may return a tuple; the first element is used as the
            # hidden state.
            hidden = output[0] if isinstance(output, tuple) else output
            with torch.no_grad():
                norm = hidden.float().norm().item()
                self.residual_norms.append((layer_idx, norm))
        return hook

    def register_hooks(self):
        """Attach the embedding hook and one hook per layer."""
        # Remove hooks from an earlier call so each norm is recorded once.
        self.remove_hooks()
        # Embedding
        h = self.model.model.embed_tokens.register_forward_hook(self._make_embedding_hook())
        self.hooks.append(h)
        # Layers
        for i, layer in enumerate(self.model.model.layers):
            h = layer.register_forward_hook(self._make_layer_hook(i))
            self.hooks.append(h)
        print(f"  Registered {len(self.hooks)} hooks")

    def remove_hooks(self):
        """Detach every registered hook; safe to call repeatedly."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def get_layer_gains(self):
        """Compute the norm ratio between consecutive recorded stages.

        Returns:
            ``(gains, norms)`` where ``norms`` is the embedding norm followed
            by the layer norms in ascending layer index, and ``gains[i]`` is
            ``norms[i + 1] / norms[i]`` (``0.0`` when ``norms[i] <= 1e-8``).

        Raises:
            RuntimeError: If no embedding norm has been recorded yet.
        """
        if self.embedding_norm is None:
            raise RuntimeError(
                "No embedding norm recorded; call register_hooks() and run a "
                "forward pass before get_layer_gains()."
            )

        ordered = sorted(self.residual_norms, key=lambda item: item[0])
        norms = [self.embedding_norm] + [value for _, value in ordered]

        # Pair each stage with its successor; guard against division by ~0.
        gains = [
            curr / prev if prev > 1e-8 else 0.0
            for prev, curr in zip(norms, norms[1:])
        ]
        return gains, norms

In [None]:
# Cell 5: RUN THE RLHF TEST
# For each configured checkpoint: load it, hook it, run one forward pass on a
# fixed prompt, derive the per-layer norm gains, store a summary in
# rlhf_results, and release the model before the next one is loaded.

print("="*60)
print("RUNNING RLHF SAFETY BRAKE TEST")
print("="*60)

rlhf_results = {}

for model_key, model_info in RLHF_TEST_MODELS.items():
    print(f"\n{'='*60}")
    print(f"Testing: {model_key}")
    print(f"RLHF: {'YES' if model_info['rlhf'] else 'NO'}")
    print(f"{'='*60}")

    model, tokenizer, dtype = load_model(model_key, model_info)
    if model is None:
        # load_model already printed the failure; skip this checkpoint.
        continue

    analyzer = None
    inputs = None
    try:
        analyzer = ResidualStreamAnalyzer(model, model_info)
        analyzer.register_hooks()

        # Test prompt
        prompt = "The capital of France is"
        inputs = tokenizer(prompt, return_tensors='pt')
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}

        print("  Running forward pass...")
        with torch.no_grad():
            with torch.amp.autocast(device_type='cuda', dtype=dtype):
                # The return value is discarded on purpose: the hooks capture
                # the norms, and keeping the outputs would pin extra memory.
                model(**inputs)

        gains, norms = analyzer.get_layer_gains()

        # Results: gains[0] is embedding -> layer 0, gains[-1] is the last layer.
        last_gain = gains[-1]
        initial_gain = gains[0]

        print(f"\n  RESULTS:")
        print(f"  Initial Gain (Emb→L0): {initial_gain:.2f}x")
        print(f"  Last Layer Gain: {last_gain:.4f}x")
        print(f"  {'>>> EXPANDS <<<' if last_gain > 1.0 else '>>> CONTRACTS <<<'}")

        # Product of all gains, computed in log space; the 1e-10 offset keeps
        # log() finite when a gain is exactly 0.
        cumulative = np.exp(np.sum(np.log(np.array(gains) + 1e-10)))
        print(f"  Cumulative Energy: {cumulative:.2e}")

        rlhf_results[model_key] = {
            'rlhf': model_info['rlhf'],
            'hf_name': model_info['hf_name'],
            'initial_gain': float(initial_gain),
            'last_gain': float(last_gain),
            'last_expands': bool(last_gain > 1.0),
            'cumulative_energy': float(cumulative),
            'gains': [float(g) for g in gains],
            'norms': [float(n) for n in norms]
        }

    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Detach hooks even when the forward pass failed.
        if analyzer is not None:
            analyzer.remove_hooks()
        # Drop every reference to the model *before* collecting. The names
        # bound in this cell (and analyzer.model) would otherwise keep the
        # checkpoint alive while the next one is being loaded.
        analyzer = inputs = model = tokenizer = None
        # Nothing left to pass in; this runs the gc / CUDA cache flush.
        cleanup_model(None)

print(f"\n\n{'='*60}")
print(f"RLHF TEST COMPLETE - {len(rlhf_results)} models tested")
print(f"{'='*60}")

In [None]:
# Cell 6: RLHF Hypothesis Evaluation
# Compares the last-layer gain of the base and instruct runs and stores a
# verdict under rlhf_results['_hypothesis_test'].

print("="*60)
print("RLHF SAFETY BRAKE HYPOTHESIS EVALUATION")
print("="*60)

# Look the two runs up by key instead of counting entries: this cell adds a
# '_hypothesis_test' entry to rlhf_results, so a length check would reject a
# complete result set when the cell is run a second time.
base = rlhf_results.get('llama3.1-8b-BASE')
instruct = rlhf_results.get('llama3.1-8b-INSTRUCT')

if base and instruct:
    print("\n| Model | RLHF | Last Gain | Expands? |")
    print("|-------|------|-----------|----------|")
    print(f"| BASE | NO | {base['last_gain']:.4f}x | {'YES' if base['last_expands'] else 'NO'} |")
    print(f"| INSTRUCT | YES | {instruct['last_gain']:.4f}x | {'YES' if instruct['last_expands'] else 'NO'} |")
    # Reference row comes from the shared constant, not a duplicated literal.
    print(f"| Mistral (ref) | NO | {MISTRAL_REFERENCE['last_layer_gain']:.2f}x | YES |")

    # Hypothesis test: instruct gain relative to base gain (0.0 if base is 0).
    ratio = instruct['last_gain'] / base['last_gain'] if base['last_gain'] != 0 else 0.0

    print("\n" + "="*60)
    print("HYPOTHESIS TEST RESULTS")
    print("="*60)
    print(f"\nBase Last Gain:     {base['last_gain']:.4f}x")
    print(f"Instruct Last Gain: {instruct['last_gain']:.4f}x")
    print(f"Ratio (Inst/Base):  {ratio:.2f}")

    # Verdict: one branch per (base expands?, instruct expands?) combination.
    if base['last_expands'] and not instruct['last_expands']:
        print("\n" + "*"*60)
        print("*** HYPOTHESIS CONFIRMED! ***")
        print("*"*60)
        print("\nRLHF is proven to be a 'SAFETY BRAKE':")
        print("  - Base model EXPANDS (natural behavior)")
        print("  - Instruct model CONTRACTS (RLHF dampening)")
        print("\nThis connects to:")
        print("  - Paper #1: LLaMA shows highest uniformity")
        print("  - Paper #2: Chat templates cause negative correlations")
        print("  - Paper #3: RLHF causes energy dissipation at output")
        verdict = "CONFIRMED"

    elif not base['last_expands'] and not instruct['last_expands']:
        print("\n" + "!"*60)
        print("*** HYPOTHESIS REJECTED ***")
        print("!"*60)
        print("\nBOTH models contract!")
        print("  → Contraction is architectural, NOT RLHF-induced")
        print("  → LLaMA 3.1 differs from Mistral at architecture level")
        verdict = "REJECTED - Architectural"

    elif base['last_expands'] and instruct['last_expands']:
        print("\n" + "?"*60)
        print("*** UNEXPECTED: BOTH EXPAND ***")
        print("?"*60)
        print("\nRLHF does NOT cause contraction in this case")
        verdict = "UNEXPECTED - Both expand"

    else:
        print("\n" + "?"*60)
        print("*** UNEXPECTED: Base contracts, Instruct expands ***")
        print("?"*60)
        verdict = "UNEXPECTED - Reversed"

    # Underscore-prefixed key marks this entry as metadata, not a model run.
    rlhf_results['_hypothesis_test'] = {
        'base_expands': base['last_expands'],
        'instruct_expands': instruct['last_expands'],
        'ratio': ratio,
        'verdict': verdict
    }
else:
    print("\n⚠️ Need both models to evaluate hypothesis!")

In [None]:
# Cell 7: Visualization
# Draws three panels (per-layer gains, last-layer gain bars, cumulative
# product of gains) and saves the figure as a PNG in RESULTS_DIR.
import os

# Create Results directory (works in Colab!)
RESULTS_DIR = './Results'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Only real model runs are plotted; keys starting with '_' are metadata
# (e.g. the hypothesis verdict) and must not count towards the guard below.
model_runs = {k: v for k, v in rlhf_results.items() if not k.startswith('_')}

if len(model_runs) >= 2:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    colors = {'llama3.1-8b-BASE': 'blue', 'llama3.1-8b-INSTRUCT': 'red'}

    # Panel 1: Layer-wise Gains
    ax1 = axes[0]
    for name, res in model_runs.items():
        gains = res['gains']
        ax1.plot(range(len(gains)), gains, 'o-', label=name,
                 color=colors.get(name, 'gray'), alpha=0.7)
    ax1.axhline(y=1.0, color='black', linestyle='--', alpha=0.5)
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Gain')
    ax1.set_title('Layer-wise Gains\n(Base vs Instruct)')
    ax1.legend()
    # Fixed window around 1.0; gains outside [0.4, 1.6] fall outside the view.
    ax1.set_ylim(0.4, 1.6)

    # Panel 2: Last Layer Comparison
    ax2 = axes[1]
    names = list(model_runs)
    last_gains = [model_runs[n]['last_gain'] for n in names]
    bar_colors = [colors.get(n, 'gray') for n in names]

    # Add Mistral reference from the shared constant.
    names.append('Mistral (ref)')
    last_gains.append(MISTRAL_REFERENCE['last_layer_gain'])
    bar_colors.append('purple')

    bars = ax2.bar(range(len(names)), last_gains, color=bar_colors, alpha=0.7)
    ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.5)
    ax2.set_xticks(range(len(names)))
    ax2.set_xticklabels(names, rotation=15, ha='right')
    ax2.set_ylabel('Last Layer Gain')
    ax2.set_title('Last Layer Gain Comparison\n(RLHF Effect)')

    # Label each bar with its value just above the bar top.
    for bar, val in zip(bars, last_gains):
        ax2.annotate(f'{val:.2f}x', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                     ha='center', va='bottom', fontsize=11, fontweight='bold')

    # Panel 3: Cumulative Energy (running product of gains, via log space)
    ax3 = axes[2]
    for name, res in model_runs.items():
        gains = np.array(res['gains'])
        cumulative = np.exp(np.cumsum(np.log(gains + 1e-10)))
        ax3.semilogy(range(len(cumulative)), cumulative, 'o-',
                     label=name, color=colors.get(name, 'gray'), alpha=0.7)
    ax3.set_xlabel('Layer')
    ax3.set_ylabel('Cumulative Energy')
    ax3.set_title('Cumulative Energy\n(Product of Gains)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()

    # Save before show(): the figure is written to disk first, then displayed.
    output_path = f'{RESULTS_DIR}/RLHF_safety_brake_test.png'
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so re-running the cell does not accumulate figures.
    plt.close(fig)
    print(f"Saved: {output_path}")

In [None]:
# Cell 8: Save Results
# Writes rlhf_results plus the reference values to a JSON file and, when
# running in Colab, offers the produced files for download.
import os

# Use same directory as visualization
RESULTS_DIR = './Results'
os.makedirs(RESULTS_DIR, exist_ok=True)

output = {
    'experiment': 'RLHF Safety Brake Test',
    'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'hypothesis': 'RLHF causes last-layer contraction (Safety Brake)',
    'models': rlhf_results,
    'references': {
        'mistral_7b': MISTRAL_REFERENCE,
        'previous_llama': PREVIOUS_LLAMA_RESULT
    }
}

output_path = f'{RESULTS_DIR}/RLHF_safety_brake_test_results.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2)
print(f"Saved: {output_path}")

# The plot exists only if the visualization cell actually drew it, so it is
# listed and downloaded only when the file is present on disk.
png_path = f'{RESULTS_DIR}/RLHF_safety_brake_test.png'
saved_files = [output_path]
if os.path.exists(png_path):
    saved_files.append(png_path)

# Auto-download for Colab
try:
    from google.colab import files
except ImportError:
    files = None  # not running in Colab; nothing to download

if files is not None:
    try:
        for path in saved_files:
            files.download(path)
        print("\n✅ Files downloaded!")
    except Exception as e:
        # Downloading is best-effort: report the problem and carry on.
        print(f"\n⚠️ Download failed: {e}")

print("\n" + "="*60)
print("RLHF SAFETY BRAKE TEST COMPLETE")
print("="*60)
print(f"\nOutput files:")
for path in saved_files:
    print(f"  - {path}")