# Experiment 7: Effective Noise at LM Head

**Problem:** LayerNorm, Residual, and Attention modify noise as it propagates through layers.

**Solution:** 
1. Measure "effective noise" that reaches the LM Head
2. Inject noise directly at LM Head input to test T* without LayerNorm interference

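**Formula:** $T^* = T_{base} \cdot \sqrt{1 + \alpha}$, where $\alpha = \sigma^2 / \tau^2$ is the ratio of the noise-induced logit variance ($\sigma^2$) to the clean logit variance ($\tau^2$).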

**Model structure:**
```
Input → [Layers + LN + Attn + Residual] → Hidden → LM Head → Logits
                                            ↑
                                    Measure/inject noise here
```

In [1]:
# ============================================================
# IMPORTS & UTILITIES
# ============================================================

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import contextmanager
from datasets import load_dataset

# ============================================================
# UTILITY FUNCTIONS
# ============================================================

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    The CUDA generators are seeded as well whenever CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def extract_answer(text):
    """Extract the final numerical answer from a model response.

    Strategies are tried in priority order and the first that yields a
    number wins:

    1. The last ``\\boxed{...}`` expression (last number inside it).
    2. Explicit answer patterns: "answer is X", a line ending in "= X",
       or a line ending in "$X".
    3. The last number anywhere in the text.

    Thousands separators (commas) are stripped before number matching.
    Within each pattern the LAST occurrence is used, because a step-by-step
    solution lists intermediate results before the final one.

    Args:
        text: The decoded model response.

    Returns:
        The answer as a string (e.g. ``"18"`` or ``"-3.5"``), or ``''`` if
        the text contains no number.
    """
    number = r'-?[\d,]+\.?\d*'

    # Try boxed format first
    boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
    if boxed:
        nums = re.findall(number, boxed[-1].replace(',', ''))
        if nums:
            return nums[-1]

    # Strip thousands separators once for all remaining strategies.
    cleaned = text.replace(',', '')

    # Look for patterns like "answer is X" or "= X"
    patterns = [
        r'answer is[:\s]*\$?(-?[\d,]+\.?\d*)',
        r'=\s*\$?(-?[\d,]+\.?\d*)\s*$',
        r'\$(-?[\d,]+\.?\d*)\s*$',
    ]
    for pattern in patterns:
        # findall + [-1]: prefer the final statement over earlier
        # intermediate steps that happen to match the same pattern.
        matches = re.findall(pattern, cleaned, re.IGNORECASE | re.MULTILINE)
        if matches:
            return matches[-1]

    # Fall back to last number
    nums = re.findall(number, cleaned)
    return nums[-1] if nums else ''

def extract_ground_truth(answer_text):
    """Return the number following the ``####`` marker of a GSM8K answer.

    Commas are removed before matching. An empty string is returned when
    no ``####`` marker followed by a number is present.
    """
    without_commas = answer_text.replace(',', '')
    found = re.search(r'####\s*(-?[\d,]+\.?\d*)', without_commas)
    if found is None:
        return ''
    return found.group(1)

def check_answer(pred, truth):
    """Check whether a predicted answer matches the ground truth.

    Both values are compared numerically (absolute tolerance 1e-5) when
    they parse as floats; otherwise the stripped strings are compared.

    Args:
        pred: Predicted answer string (may be empty).
        truth: Ground-truth answer string (may be empty).

    Returns:
        True if the answers match; False if either is empty/falsy or they
        differ.
    """
    if not pred or not truth:
        return False
    try:
        return abs(float(pred) - float(truth)) < 1e-5
    except (TypeError, ValueError):
        # Only parse failures fall back to string comparison; anything else
        # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
        return pred.strip() == truth.strip()

print("Imports and utilities loaded!")

Imports and utilities loaded!


In [2]:
# ============================================================
# LOAD MODEL
# ============================================================

# Seed every RNG before the model is loaded.
SEED = 42
set_seed(SEED)

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Weights are loaded in float32.
# NOTE(review): the recorded run of this cell warns that `torch_dtype` is
# deprecated in favour of `dtype`; switch once the minimum supported
# transformers version accepts `dtype`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map="auto"
)
# Put the model in inference mode.
model.eval()
print(f"Model loaded on {model.device}")
print(f"LM Head shape: {model.lm_head.weight.shape}")

Loading deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).
`torch_dtype` is deprecated! Use `dtype` instead!


Model loaded on cuda:0
LM Head shape: torch.Size([151936, 1536])


In [3]:
# ============================================================
# LOAD GSM8K
# ============================================================

# Load the GSM8K test split.
gsm8k = load_dataset('gsm8k', 'main', split='test')
print(f"GSM8K loaded: {len(gsm8k)} examples")

# Prepare prompts
# Only the first `num_examples` test problems are used by the experiments.
num_examples = 10
test_examples = [gsm8k[i] for i in range(num_examples)]

# Render each question through the tokenizer's chat template as an untokenized
# string with the generation prompt appended; later cells tokenize these strings.
gsm_prompts = []
for ex in test_examples:
    messages = [{"role": "user", "content": f"Solve this step by step:\n{ex['question']}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    gsm_prompts.append(prompt)

print(f"Prepared {len(gsm_prompts)} prompts")

GSM8K loaded: 1319 examples
Prepared 10 prompts


## Step 1: Find T_base (Optimal Temperature for Clean Model)

Before testing noise, find the optimal temperature for the clean model.

In [None]:
# ============================================================
# FIND T_BASE
# ============================================================

print("Finding T_base (optimal T for clean model)...")
print("="*60)

baseline_temps = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
baseline_results = {}

# Upper bound on the number of tokens sampled per example.
max_new_tokens = 512

for temp in baseline_temps:
    # Re-seed for every temperature so each sweep starts from the same RNG state.
    set_seed(SEED)
    correct = 0

    for ex in test_examples:
        ground_truth = extract_ground_truth(ex['answer'])
        messages = [{"role": "user", "content": f"Solve this step by step:\n{ex['question']}"}]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

        # Generate with temperature (sampling).
        # The first forward pass consumes the whole prompt; afterwards only
        # the newly sampled token is fed, with the key/value cache carrying
        # the context. This avoids re-running the full, growing sequence for
        # every generated token.
        generated_ids = []
        step_input = inputs.input_ids
        past_key_values = None
        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = model(step_input, past_key_values=past_key_values, use_cache=True)
                past_key_values = outputs.past_key_values
                logits = outputs.logits[0, -1, :].float()

                probs = F.softmax(logits / temp, dim=-1)
                # multinomial returns a tensor already on the model's device.
                next_token = torch.multinomial(probs, num_samples=1)
                token_id = next_token.item()
                generated_ids.append(token_id)
                step_input = next_token.view(1, 1)

                if token_id == tokenizer.eos_token_id:
                    break

        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        pred = extract_answer(response)

        if check_answer(pred, ground_truth):
            correct += 1

    acc = correct / len(test_examples)
    baseline_results[temp] = acc
    print(f"T={temp}: {acc:.0%}")

# On ties, max() keeps the first maximal key, i.e. the lowest such temperature.
T_base = max(baseline_results.keys(), key=lambda t: baseline_results[t])
print(f"\n>>> T_base = {T_base} (accuracy: {baseline_results[T_base]:.0%})")

In [None]:
# Plot T_base results
# Accuracy-vs-temperature curve for the clean model, using `baseline_results`
# and `T_base` from the T_base sweep cell.
fig, ax = plt.subplots(figsize=(8, 5))
temps = list(baseline_results.keys())
accs = [baseline_results[t] for t in temps]

ax.plot(temps, accs, 'bo-', linewidth=2, markersize=10)
# Red dashed line: selected T_base; gray dotted line: T=1.0 for reference.
ax.axvline(x=T_base, color='red', linestyle='--', label=f'T_base={T_base}')
ax.axvline(x=1.0, color='gray', linestyle=':', alpha=0.5, label='T=1.0')
ax.set_xlabel('Temperature')
ax.set_ylabel('GSM8K Accuracy')
ax.set_title('Clean Model: Finding T_base')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

plt.tight_layout()
# Save the figure to disk before displaying it.
plt.savefig('t_base_optimization.png', dpi=150)
plt.show()

## Step 2: LM Head Hooks and Noise Injection

In [4]:
# ============================================================
# HOOKS FOR CAPTURING HIDDEN STATES
# ============================================================

class HiddenStateCapture:
    """Forward-hook helper that remembers the first input passed to a module.

    Register ``hook`` as a forward hook; after a forward pass ``hidden``
    holds a detached copy of that module's first positional input.
    """

    def __init__(self):
        # Nothing captured until the hook fires.
        self.hidden = None

    def hook(self, module, input, output):
        """Store a detached clone of the first positional input."""
        first_input = input[0]
        detached = first_input.detach()
        self.hidden = detached.clone()

    def clear(self):
        """Forget the captured tensor."""
        self.hidden = None


def get_hidden_before_lm_head(model, tokenizer, prompt):
    """Run one clean forward pass and capture the input to ``model.lm_head``.

    Returns:
        A tuple ``(hidden, logits)``: the detached tensor fed into the LM
        head and the logits produced by the forward pass. The temporary hook
        is always removed, even if the forward pass raises.
    """
    recorder = HiddenStateCapture()
    hook_handle = model.lm_head.register_forward_hook(recorder.hook)

    try:
        encoded = tokenizer(prompt, return_tensors='pt').to(model.device)
        with torch.no_grad():
            forward_out = model(encoded.input_ids)
    finally:
        hook_handle.remove()

    return recorder.hidden, forward_out.logits


def get_logits_with_head_noise(model, tokenizer, prompt, noise_scale):
    """
    Get last-position logits with noise injected ONLY at the LM head input.

    The clean hidden state is captured with a normal forward pass, Gaussian
    noise is added to it, and the LM head is applied manually, so the noise
    does not pass through any earlier layer.

    Args:
        model: Causal LM exposing ``lm_head``.
        tokenizer: Tokenizer used to encode ``prompt``.
        prompt: Prompt string.
        noise_scale: Noise std RELATIVE to the std of the whole captured
            hidden tensor (all positions), as in ``activation_noise_context``.

    Returns:
        ``(clean_logits, noisy_logits, hidden)`` for the last position:
        the clean logits, the logits computed from the noised hidden state,
        and the clean hidden vector.
    """
    hidden, _ = get_hidden_before_lm_head(model, tokenizer, prompt)

    # Add RELATIVE noise (same as activation_noise_context).
    # The noise is still drawn for the full tensor so the RNG stream matches
    # a full-sequence injection; only the last position is projected below.
    noise = torch.randn_like(hidden) * noise_scale * hidden.std()

    # Only the last position is returned, so project just that slice through
    # the LM head instead of every position in the sequence.
    last_hidden = hidden[:, -1:, :]
    noisy_last_hidden = last_hidden + noise[:, -1:, :]

    # Apply LM head manually
    with torch.no_grad():
        noisy_logits = model.lm_head(noisy_last_hidden)
        clean_logits = model.lm_head(last_hidden)

    return clean_logits[0, -1, :], noisy_logits[0, -1, :], hidden[0, -1, :]


# Test
# Smoke test: inject head noise at scale 0.1 on the first prompt and print
# summary statistics for the last-position hidden vector and logits.
print("Testing hooks...")
clean_logits, noisy_logits, hidden = get_logits_with_head_noise(model, tokenizer, gsm_prompts[0], noise_scale=0.1)
print(f"Hidden shape: {hidden.shape}")
print(f"Hidden std: {hidden.std():.4f}")
print(f"Clean logits std: {clean_logits.std():.4f}")
print(f"Noisy logits std: {noisy_logits.std():.4f}")
# Std of the logit perturbation caused by the injected noise.
print(f"Diff std: {(noisy_logits - clean_logits).std():.4f}")
print("Hooks working!")


Testing hooks...
Hidden shape: torch.Size([1536])
Hidden std: 3.1193
Clean logits std: 1.7382
Noisy logits std: 1.7523
Diff std: 0.2439
Hooks working!


## Step 3: Measure Effective Noise at LM Head

When noise is injected throughout the model, how much actually reaches the LM head?

In [5]:
# ============================================================
# ACTIVATION NOISE CONTEXT
# ============================================================

@contextmanager
def activation_noise_context(model, noise_scale):
    """Temporarily perturb the input of every ``torch.nn.Linear`` in ``model``.

    While the context is active, each Linear layer's first positional input
    ``x`` is replaced by ``x + N(0, 1) * noise_scale * x.std()``. Yields
    ``(model, handles)``; all hooks are removed on exit.
    """

    def _perturb_first_arg(module, args, kwargs):
        # Keyword-only calls are passed through untouched.
        if not args:
            return args, kwargs
        first, rest = args[0], args[1:]
        jitter = torch.randn_like(first) * noise_scale * first.std()
        return (first + jitter,) + rest, kwargs

    handles = []
    try:
        for _, submodule in model.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            handles.append(
                submodule.register_forward_pre_hook(_perturb_first_arg, with_kwargs=True)
            )
        yield model, handles
    finally:
        # Always detach the hooks so the model is clean afterwards.
        for registered in handles:
            registered.remove()


def get_hidden_with_noise(model, tokenizer, prompt, noise_scale):
    """Capture the LM head input while noise is injected throughout the model.

    Returns:
        A tuple ``(hidden, logits)`` from a forward pass run inside
        ``activation_noise_context``. The capture hook is always removed.
    """
    recorder = HiddenStateCapture()
    hook_handle = model.lm_head.register_forward_hook(recorder.hook)

    try:
        encoded = tokenizer(prompt, return_tensors='pt').to(model.device)
        with activation_noise_context(model, noise_scale) as (noisy_model, _), torch.no_grad():
            forward_out = noisy_model(encoded.input_ids)
    finally:
        hook_handle.remove()

    return recorder.hidden, forward_out.logits

In [6]:
# ============================================================
# MEASURE EFFECTIVE NOISE
# ============================================================

# Relative noise scales passed to activation_noise_context.
noise_scales = [0.01, 0.02, 0.05, 0.1, 0.2]

print("Measuring effective noise at LM head input...")
print("="*70)

# One summary dict per noise scale.
results_effective = []

for noise_scale in noise_scales:
    all_effective_stds = []
    all_effective_means = []
    all_preservation_ratios = []

    for prompt in gsm_prompts:
        # Get clean hidden (last position of the LM head input)
        clean_hidden, _ = get_hidden_before_lm_head(model, tokenizer, prompt)
        clean_h = clean_hidden[0, -1, :]
        clean_std = clean_h.std().item()
        # NOTE(review): clean_mean is computed but never used below.
        clean_mean = clean_h.mean().item()

        # Get noisy samples: three independent noisy forward passes per prompt
        # NOTE(review): activation_noise_context hooks every nn.Linear; if
        # model.lm_head is one of them, the captured hidden presumably already
        # includes noise injected directly at the head input — confirm.
        for _ in range(3):
            noisy_hidden, _ = get_hidden_with_noise(model, tokenizer, prompt, noise_scale)
            noisy_h = noisy_hidden[0, -1, :]

            # Effective noise = deviation of the noisy hidden from the clean one.
            diff = noisy_h - clean_h
            effective_std = diff.std().item()
            effective_mean = diff.mean().item()

            all_effective_stds.append(effective_std)
            all_effective_means.append(effective_mean)

            # Ratio of the observed noise std to the nominal injected std
            # (noise_scale * clean last-position std); skipped when clean_std is 0.
            if clean_std > 0:
                ratio = effective_std / (noise_scale * clean_std)
                all_preservation_ratios.append(ratio)

    # Aggregate over all prompts and samples for this noise scale.
    mean_std = np.mean(all_effective_stds)
    std_std = np.std(all_effective_stds)
    mean_mean = np.mean(all_effective_means)
    std_mean = np.std(all_effective_means)
    mean_ratio = np.mean(all_preservation_ratios)

    results_effective.append({
        'noise_scale': noise_scale,
        'effective_std_mean': mean_std,
        'effective_std_std': std_std,
        'effective_mean_mean': mean_mean,
        'effective_mean_std': std_mean,
        'preservation_ratio': mean_ratio,
    })

    print(f"σ={noise_scale:.3f}:")
    print(f"  Effective std:  {mean_std:.4f} ± {std_std:.4f}")
    print(f"  Effective mean: {mean_mean:.6f} ± {std_mean:.6f}")
    print(f"  Preservation:   {mean_ratio:.1%}")

print("\n" + "="*70)
print("Note: If mean ≠ 0, there's a systematic bias in the noise.")
print("="*70)


Measuring effective noise at LM head input...
σ=0.010:
  Effective std:  0.3296 ± 0.1409
  Effective mean: 0.001380 ± 0.008933
  Preservation:   1077.5%
σ=0.020:
  Effective std:  0.5635 ± 0.1572
  Effective mean: 0.004007 ± 0.013124
  Preservation:   918.5%
σ=0.050:
  Effective std:  1.5490 ± 0.5410
  Effective mean: 0.008546 ± 0.022395
  Preservation:   1006.4%
σ=0.100:
  Effective std:  2.6211 ± 0.4483
  Effective mean: 0.024509 ± 0.043847
  Preservation:   851.6%
σ=0.200:
  Effective std:  3.7051 ± 0.3425
  Effective mean: 0.054292 ± 0.046791
  Preservation:   601.9%

Note: If mean ≠ 0, there's a systematic bias in the noise.


## Step 4: Test T* = T_base × √(1+α) with Direct Head Noise

In [7]:
# ============================================================
# T* VALIDATION FUNCTIONS
# ============================================================

def compute_alpha_at_head(clean_logits, noisy_logits):
    """Return α, the variance of the logit perturbation over the clean logit variance.

    Returns 0 when the clean logit variance is not positive.
    """
    signal_var = clean_logits.var().item()
    if not signal_var > 0:
        return 0
    noise_var = (noisy_logits - clean_logits).var().item()
    return noise_var / signal_var


def evaluate_temperature_head_noise(model, tokenizer, prompt, noise_std, temperatures, T_base):
    """
    Evaluate different temperatures with noise injected at LM head.
    Uses T_base for clean model comparison.

    Args:
        model: Causal LM exposing ``lm_head``.
        tokenizer: Tokenizer used to encode ``prompt``.
        prompt: Prompt string.
        noise_std: Relative noise scale forwarded to
            ``get_logits_with_head_noise``.
        temperatures: Iterable of candidate temperatures for the noisy logits.
        T_base: Temperature applied to the clean logits to form the
            reference distribution.

    Returns:
        ``(alpha, results)`` where ``alpha`` comes from
        ``compute_alpha_at_head`` and ``results[temp]`` is a dict with
        ``'kl_div'`` (KL of the clean reference from the noisy distribution
        at ``temp``) and ``'prob_correct'`` (noisy probability of the clean
        argmax token).
    """
    clean_logits, noisy_logits, _ = get_logits_with_head_noise(
        model, tokenizer, prompt, noise_std
    )

    alpha = compute_alpha_at_head(clean_logits, noisy_logits)

    # Clean probs with T_base (not T=1!)
    clean_probs = F.softmax(clean_logits / T_base, dim=-1)
    correct_token = (clean_logits / T_base).argmax().item()

    results = {}
    for temp in temperatures:
        # log_softmax keeps log-probabilities finite; softmax(...).log() can
        # underflow to -inf for low-probability tokens and turn the KL into
        # inf/nan, which would corrupt the min-KL temperature selection.
        noisy_log_probs = F.log_softmax(noisy_logits / temp, dim=-1)
        kl_div = F.kl_div(noisy_log_probs, clean_probs, reduction='sum').item()
        prob_correct = noisy_log_probs[correct_token].exp().item()

        results[temp] = {
            'kl_div': kl_div,
            'prob_correct': prob_correct,
        }

    return alpha, results


In [10]:
# ============================================================
# TEST T* WITH T_BASE
# ============================================================
# T_base is fixed by hand here; this overrides any value computed by the
# T_base sweep cell.
T_base = 0.8
print(f"Testing T* = T_base × √(1+α) with T_base = {T_base}")
print("="*70)

# Relative noise scales injected at the LM head input.
noise_stds = [0.05, 0.075, 0.1, 0.2, 0.5, 0.75]
# Temperature range around T_base
temperatures = [round(T_base * m, 2) for m in [0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.8, 2.0]]
temperatures = sorted(set(temperatures))  # Remove duplicates
print(f"Testing temperatures: {temperatures}")

# One summary dict per noise level.
results_7d = []

for noise_std in noise_stds:
    print(f"\nNoise std: {noise_std}")

    all_alphas = []
    all_best_t_kl = []
    all_best_t_prob = []

    for prompt in gsm_prompts:
        for _ in range(3):  # Multiple samples
            alpha, temp_results = evaluate_temperature_head_noise(
                model, tokenizer, prompt, noise_std, temperatures, T_base
            )

            # Grid temperature with the lowest KL / the highest probability
            # of the clean argmax token for this sample.
            best_t_kl = min(temperatures, key=lambda t: temp_results[t]['kl_div'])
            best_t_prob = max(temperatures, key=lambda t: temp_results[t]['prob_correct'])

            all_alphas.append(alpha)
            all_best_t_kl.append(best_t_kl)
            all_best_t_prob.append(best_t_prob)

    # Averages over all prompts and samples; the mean "best T" values
    # therefore need not lie on the temperature grid.
    mean_alpha = np.mean(all_alphas)
    mean_best_t_kl = np.mean(all_best_t_kl)
    mean_best_t_prob = np.mean(all_best_t_prob)

    # T* = T_base × √(1+α)
    predicted_t_star = T_base * np.sqrt(1 + mean_alpha)

    results_7d.append({
        'noise_std': noise_std,
        'alpha': mean_alpha,
        'predicted_t_star': predicted_t_star,
        'best_t_kl': mean_best_t_kl,
        'best_t_prob': mean_best_t_prob,
    })

    print(f"  α = {mean_alpha:.4f}")
    print(f"  T* = {T_base} × √(1+{mean_alpha:.3f}) = {predicted_t_star:.3f}")
    print(f"  Best T (min KL) = {mean_best_t_kl:.2f}")
    print(f"  Best T (max prob) = {mean_best_t_prob:.2f}")

Testing T* = T_base × √(1+α) with T_base = 0.8
Testing temperatures: [0.56, 0.64, 0.72, 0.8, 0.88, 0.96, 1.04, 1.12, 1.2, 1.28, 1.44, 1.6]

Noise std: 0.05
  α = 0.0058
  T* = 0.8 × √(1+0.006) = 0.802
  Best T (min KL) = 0.78
  Best T (max prob) = 0.56

Noise std: 0.075
  α = 0.0129
  T* = 0.8 × √(1+0.013) = 0.805
  Best T (min KL) = 0.78
  Best T (max prob) = 0.56

Noise std: 0.1
  α = 0.0229
  T* = 0.8 × √(1+0.023) = 0.809
  Best T (min KL) = 0.78
  Best T (max prob) = 0.56

Noise std: 0.2
  α = 0.0926
  T* = 0.8 × √(1+0.093) = 0.836
  Best T (min KL) = 0.81
  Best T (max prob) = 0.58

Noise std: 0.5
  α = 0.5757
  T* = 0.8 × √(1+0.576) = 1.004
  Best T (min KL) = 0.96
  Best T (max prob) = 0.67

Noise std: 0.75
  α = 1.3186
  T* = 0.8 × √(1+1.319) = 1.218
  Best T (min KL) = 1.13
  Best T (max prob) = 0.93


In [None]:
# ============================================================
# SUMMARY TABLE
# ============================================================

# Tabulate predicted T* against the empirically best temperatures stored in
# `results_7d`.
print("\n" + "="*70)
print(f"SUMMARY: T* = T_base × √(1+α) with T_base = {T_base}")
print("="*70)
print(f"{'Noise σ':<10} {'α':<10} {'T* pred':<10} {'Best T(KL)':<12} {'Best T(prob)':<12} {'Match?'}")
print("-"*70)

for r in results_7d:
    # A row counts as a match when the prediction is within 0.15 of the mean
    # min-KL temperature.
    match = "✓" if abs(r['predicted_t_star'] - r['best_t_kl']) < 0.15 else "✗"
    print(f"{r['noise_std']:<10.3f} {r['alpha']:<10.4f} {r['predicted_t_star']:<10.3f} "
          f"{r['best_t_kl']:<12.2f} {r['best_t_prob']:<12.2f} {match}")

In [None]:
# ============================================================
# PLOT RESULTS
# ============================================================

# Two-panel figure built from `results_7d`: predicted vs. empirical optimal
# temperature (left) and α as a function of the injected noise (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Pull each plotted series out of the per-noise-level summary dicts.
noise_stds_plot = [r['noise_std'] for r in results_7d]
predicted = [r['predicted_t_star'] for r in results_7d]
best_kl = [r['best_t_kl'] for r in results_7d]
best_prob = [r['best_t_prob'] for r in results_7d]
alphas = [r['alpha'] for r in results_7d]

# Left: T* comparison
ax1 = axes[0]
ax1.plot(noise_stds_plot, predicted, 'b-o', linewidth=2, markersize=10, label=f'T* = {T_base}×√(1+α)')
ax1.plot(noise_stds_plot, best_kl, 'r--s', linewidth=2, markersize=8, label='Best T (min KL)')
ax1.plot(noise_stds_plot, best_prob, 'g--^', linewidth=2, markersize=8, label='Best T (max prob)')
# Horizontal reference line at the clean-model temperature.
ax1.axhline(y=T_base, color='gray', linestyle=':', alpha=0.5, label=f'T_base={T_base}')
ax1.set_xlabel('Noise Std at LM Head')
ax1.set_ylabel('Temperature')
ax1.set_title('Predicted vs Actual Optimal Temperature')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: α vs noise
ax2 = axes[1]
ax2.plot(noise_stds_plot, alphas, 'bo-', linewidth=2, markersize=10)
ax2.set_xlabel('Noise Std at LM Head')
ax2.set_ylabel('α (noise-to-signal ratio)')
ax2.set_title('α vs Injected Noise')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# Save the figure to disk before displaying it.
plt.savefig('t_star_validation_with_tbase.png', dpi=150)
plt.show()

## Summary

**Key Findings:**

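1. **T_base found:** the temperature at which the clean (noise-free) model achieves its highest GSM8K accuracy
2. **Effective noise measured:** how much of the noise injected throughout the model reaches the LM head input after passing through LayerNorm, residual connections, and attention
3. **T* formula tested:** $T^* = T_{base} \times \sqrt{1 + \alpha}$, using noise injected directly at the LM head input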