# Optimal Temperature Scaling for Noisy Language Models

## Weight Noise Injection Experiments

This notebook explores the relationship between **weight noise** and optimal sampling temperature.

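**Key Result (hypothesis tested below):**
$$T^* = \sqrt{1 + \frac{\sigma^2}{\tau^2}} = \sqrt{1 + \alpha}$$

Where:
- $\tau^2$ = variance of clean logits
- $\sigma^2$ = variance of noise-induced logit perturbation
- $\alpha = \sigma^2 / \tau^2$ = noise-to-signal ratio
- $T^*$ = optimal temperature

Rationale: if the perturbation is independent of the clean logits, the noisy logits have variance $\tau^2 + \sigma^2$; dividing them by $T^*$ restores the clean variance $\tau^2$.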

## Setup

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
from contextlib import contextmanager
import pandas as pd

plt.style.use('seaborn-v0_8-whitegrid')
# Notebook-wide figure defaults; individual cells override figsize as needed.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 12,
})

In [2]:
# Configuration
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # Change as needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Weight-noise draws use torch's global RNG (randn_like / rand_like), so seed
# it once here to make Restart & Run All reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

print(f"Device: {DEVICE}")

Device: cuda


## Load Model

In [None]:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Loading model...")
# bfloat16 stores 2 bytes per weight; device_map="auto" delegates device
# placement to transformers/accelerate.
load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
print("Done!")

Loading tokenizer...


## Core Functions

In [None]:
_NOISE_TYPES = ("gaussian", "uniform", "laplace")


def _sample_noise(like: torch.Tensor, noise_type: str, std: float) -> torch.Tensor:
    """Draw zero-mean noise shaped like ``like`` whose standard deviation is ``std``."""
    if noise_type == "gaussian":
        return torch.randn_like(like) * std
    if noise_type == "uniform":
        # U(-a, a) has std a / sqrt(3), so a = sqrt(3) * std.
        return (torch.rand_like(like) - 0.5) * 2 * std * np.sqrt(3)
    if noise_type == "laplace":
        # The difference of two i.i.d. Exp(1) draws is Laplace(0, 1), whose std
        # is sqrt(2). Sampled in float32 for portability, then cast back.
        unit = (torch.empty_like(like, dtype=torch.float32).exponential_()
                - torch.empty_like(like, dtype=torch.float32).exponential_())
        return (unit * (std / 2 ** 0.5)).to(like.dtype)
    raise ValueError(f"Unknown noise_type {noise_type!r}; expected one of {_NOISE_TYPES}")


@contextmanager
def weight_noise_context(model, noise_scale: float, noise_type: str = "gaussian",
                         target_layers: Optional[List[str]] = None):
    """Context manager: temporarily add random noise to model weights.

    On entry, every selected weight tensor ``W`` is perturbed in place by noise
    with standard deviation ``noise_scale * W.std()``, so the noise is relative
    to each tensor's own spread. On exit the original values are copied back.
    Restoration also happens if the with-body raises, or if noise injection
    itself fails part-way.

    A parameter is selected when all of the following hold:
    - its name contains ``'weight'``;
    - it is at least 2-D, which skips every norm scale vector, including a
      final ``model.norm.weight`` whose name does not mention ``layernorm``;
    - its lower-cased name contains none of ``'layernorm'``, ``'ln'``,
      ``'embed'``;
    - if ``target_layers`` is given, its name contains one of those patterns.

    NOTE(review): noise is rounded into the parameter's own dtype. With
    bfloat16 weights (about 8 significant bits), very small relative noise
    scales may be largely rounded away.

    Args:
        model: ``torch.nn.Module`` to inject noise into.
        noise_scale: Noise std relative to each weight tensor's std
            (``0`` is a no-op).
        noise_type: ``"gaussian"``, ``"uniform"`` or ``"laplace"``. All are
            scaled to the same standard deviation.
        target_layers: Substring patterns selecting parameter names
            (``None`` = all eligible weight matrices).

    Yields:
        ``(model, noise_info)``, where ``noise_info`` maps each perturbed
        parameter name to ``{'weight_std', 'noise_std', 'snr'}``.

    Raises:
        ValueError: if ``noise_type`` is not supported. This is raised on entry,
            before any weight is touched.
    """
    if noise_type not in _NOISE_TYPES:
        raise ValueError(f"Unknown noise_type {noise_type!r}; expected one of {_NOISE_TYPES}")

    if noise_scale == 0:
        yield model, {}
        return

    # (parameter, original copy) pairs, so restoring needs no name lookups.
    saved: List[Tuple[torch.nn.Parameter, torch.Tensor]] = []
    noise_info: Dict[str, Dict[str, float]] = {}

    try:
        for name, param in model.named_parameters():
            # Weight matrices only: biases lack 'weight', and norm scales are 1-D.
            if 'weight' not in name or param.dim() < 2:
                continue
            if any(tag in name.lower() for tag in ('layernorm', 'ln', 'embed')):
                continue
            if target_layers is not None and not any(p in name for p in target_layers):
                continue

            # Save before mutating so the finally block can always undo it.
            saved.append((param, param.data.clone()))

            weight_std = param.data.std().item()
            absolute_scale = noise_scale * weight_std
            param.data.add_(_sample_noise(param.data, noise_type, absolute_scale))

            noise_info[name] = {
                'weight_std': weight_std,
                'noise_std': absolute_scale,
                'snr': weight_std / (absolute_scale + 1e-10),
            }

        yield model, noise_info
    finally:
        for param, original in saved:
            param.data.copy_(original)

In [None]:
def get_logits(model, tokenizer, prompt: str) -> torch.Tensor:
    """Return the next-token logits for ``prompt`` as a float32 CPU vector.

    The prompt is tokenized, moved to ``model.device`` and run through a single
    forward pass with autograd disabled. The logits at the final position of the
    first (only) sequence are returned.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        all_logits = model(**encoded).logits
    last_position = all_logits[0, -1, :]
    return last_position.float().cpu()


def compute_statistics(logits_clean: torch.Tensor, logits_noisy: torch.Tensor) -> Dict:
    """Estimate the logit noise-to-signal ratio and the temperature it implies.

    Returns a dict with:
        tau_sq:   variance of the clean logits (torch default ``.var()``).
        sigma_sq: variance of the perturbation ``logits_noisy - logits_clean``.
        alpha:    ``sigma_sq / tau_sq``, or ``0`` when ``tau_sq`` is not positive.
        t_star:   ``sqrt(1 + alpha)``.
    """
    perturbation = logits_noisy - logits_clean
    signal_var = logits_clean.var().item()
    noise_var = perturbation.var().item()

    ratio = 0 if not signal_var > 0 else noise_var / signal_var

    return {
        "tau_sq": signal_var,
        "sigma_sq": noise_var,
        "alpha": ratio,
        "t_star": np.sqrt(1 + ratio),
    }


def evaluate_temperatures(logits_noisy: torch.Tensor, logits_clean: torch.Tensor,
                          temperatures) -> Dict:
    """Score temperature-scaled noisy logits against the clean distribution.

    For each ``temp`` in ``temperatures``, the noisy logits are divided by
    ``temp`` and compared with ``softmax(logits_clean)``.

    Args:
        logits_noisy: Next-token logits from the perturbed model (1-D).
        logits_clean: Next-token logits from the clean model (same shape).
        temperatures: Iterable of positive temperatures.

    Returns:
        ``{temp: {'kl_div', 'prob_correct', 'top_match'}}`` where:
        - ``kl_div`` is KL(clean || noisy_T), summed over the vocabulary;
        - ``prob_correct`` is the probability noisy_T assigns to the clean argmax;
        - ``top_match`` says whether the noisy and clean argmax agree. It does
          not depend on ``temp``, since a positive scale preserves order.
    """
    clean_probs = F.softmax(logits_clean, dim=-1)
    clean_top = logits_clean.argmax()
    # Temperature cannot change the argmax, so compute this once.
    top_match = (logits_noisy.argmax() == clean_top).item()

    results = {}
    for temp in temperatures:
        # log_softmax stays finite where softmax would underflow to 0 in
        # float32; softmax(...).log() would then give -inf and an infinite KL.
        noisy_log_probs = F.log_softmax(logits_noisy / temp, dim=-1)
        results[temp] = {
            'kl_div': F.kl_div(noisy_log_probs, clean_probs, reduction='sum').item(),
            'prob_correct': noisy_log_probs[clean_top].exp().item(),
            'top_match': top_match,
        }

    return results

## Test Prompts

In [None]:
TEST_PROMPTS = [
    # Five egg emoji (U+1F95A), written as escapes so the prompt survives
    # re-encoding; a previous cp1252 round-trip had turned them into mojibake.
    "Count the eggs: \U0001F95A\U0001F95A\U0001F95A\U0001F95A\U0001F95A. How many eggs are there?",
    "I have 3 apples and 4 oranges. How many fruits in total?",
    "Count: 1, 2, 3, 4, 5, 6, 7. What's the last number?",
    "There are 2 cats, 3 dogs, and 1 bird. How many animals?",
]

# Weight-noise scales, as a fraction of each weight tensor's own std.
# 0.01 and 0.02 must stay in the list: later cells look those rows up.
NOISE_SCALES = [0.0, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05]

## Experiment 1: Weight Noise → Logit Noise → T*

How does weight noise translate to logit noise?

In [None]:
# Clean (noise-free) next-token logits for each test prompt; the experiments
# below measure noisy logits against these.
print("Computing clean baseline logits...")
clean_logits = {prompt: get_logits(model, tokenizer, prompt) for prompt in TEST_PROMPTS}
print("Done!")

Computing clean baseline logits...
Done!


In [None]:
# Run experiment: for each scale, draw ONE noise realisation for the whole
# model and measure alpha / T* on every test prompt under that draw.
results = []

for noise_scale in NOISE_SCALES:
    print(f"\nTesting weight noise scale: {noise_scale}")

    with weight_noise_context(model, noise_scale) as (noisy_model, noise_info):
        all_stats = [
            compute_statistics(clean_logits[prompt], get_logits(noisy_model, tokenizer, prompt))
            for prompt in TEST_PROMPTS
        ]

    t_stars = [s['t_star'] for s in all_stats]
    avg_alpha = np.mean([s['alpha'] for s in all_stats])
    avg_t_star = np.mean(t_stars)
    std_t_star = np.std(t_stars)  # spread across prompts, not across noise draws

    results.append({
        'weight_noise': noise_scale,
        'alpha': avg_alpha,
        't_star': avg_t_star,
        't_star_std': std_t_star,
    })

    print(f"  α = {avg_alpha:.4f}, T* = {avg_t_star:.4f} ± {std_t_star:.4f}")


Testing weight noise scale: 0.0
  α = 0.0000, T* = 1.0000 ± 0.0000

Testing weight noise scale: 0.001
  α = 0.0000, T* = 1.0000 ± 0.0000

Testing weight noise scale: 0.002


In [None]:
# One row per weight-noise scale from Experiment 1
df_results = pd.DataFrame(data=results)
df_results

In [None]:
# Plot results
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Weight noise vs logit noise (α)
ax1 = axes[0]
ax1.plot(df_results['weight_noise'], df_results['alpha'], 'bo-', linewidth=2, markersize=8)
ax1.set_xlabel('Weight Noise Scale')
ax1.set_ylabel('Logit Noise Ratio (α = σ²/τ²)')
ax1.set_title('Weight Noise → Logit Noise')
ax1.grid(True, alpha=0.3)

# Plot 2: Weight noise vs T* (error bars = spread across prompts)
ax2 = axes[1]
ax2.errorbar(df_results['weight_noise'], df_results['t_star'],
             yerr=df_results['t_star_std'], fmt='ro-', capsize=4, linewidth=2, markersize=8)
ax2.set_xlabel('Weight Noise Scale')
ax2.set_ylabel('Optimal Temperature T*')
ax2.set_title('Weight Noise → Optimal Temperature')
ax2.grid(True, alpha=0.3)

# Plot 3: α vs T*.
# NOTE: the measured T* is the per-prompt mean of sqrt(1 + α_i), so these points
# sit on (or, by Jensen's inequality, slightly below) the curve by construction.
# This is a consistency check, not an independent test of the theory; see
# Experiment 4 for the comparison against the KL-optimal temperature.
ax3 = axes[2]
alpha_range = np.linspace(0, df_results['alpha'].max() * 1.1, 100)
theory_t = np.sqrt(1 + alpha_range)
# Raw string: '\s' is not a valid Python escape sequence.
ax3.plot(alpha_range, theory_t, 'b-', linewidth=2, label=r'Theory: $T^* = \sqrt{1+\alpha}$')
ax3.scatter(df_results['alpha'], df_results['t_star'], s=100, c='red', zorder=5, label='Measured')
ax3.set_xlabel('α (noise-to-signal ratio)')
ax3.set_ylabel('T*')
ax3.set_title('Theory Consistency Check')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('weight_noise_t_star.png', dpi=150, bbox_inches='tight')
plt.show()

## Experiment 2: Layer-Specific Sensitivity

Which layers are most sensitive to noise?

In [None]:
# Parameter-name substring patterns for each group (see target_layers in
# weight_noise_context). Depth buckets come from the loaded model: first
# quarter / middle half / last quarter of the decoder stack. For a 32-layer
# model this reproduces the 0-7 / 8-23 / 24-31 split; other depths no longer
# get empty or truncated buckets.
n_layers = model.config.num_hidden_layers
early_end = n_layers // 4
late_start = n_layers - n_layers // 4

layer_groups = {
    'all_layers': None,
    'attention': ['self_attn', 'q_proj', 'k_proj', 'v_proj', 'o_proj'],
    'mlp': ['mlp', 'gate_proj', 'up_proj', 'down_proj'],
    # The trailing '.' stops 'layers.1.' from also matching 'layers.10.' etc.
    f'early (0-{early_end - 1})': [f'layers.{i}.' for i in range(early_end)],
    f'middle ({early_end}-{late_start - 1})': [f'layers.{i}.' for i in range(early_end, late_start)],
    f'late ({late_start}-{n_layers - 1})': [f'layers.{i}.' for i in range(late_start, n_layers)],
    'lm_head': ['lm_head'],
}

In [None]:
noise_scale = 0.01
prompt = TEST_PROMPTS[0]
logits_clean = clean_logits[prompt]

layer_results = []

for group_name, patterns in layer_groups.items():
    with weight_noise_context(model, noise_scale, target_layers=patterns) as (noisy_model, noise_info):
        logits_noisy = get_logits(noisy_model, tokenizer, prompt)

    # Statistics only need the logits, so compute them after weights are restored.
    stats = compute_statistics(logits_clean, logits_noisy)
    n_tensors = len(noise_info)  # number of perturbed weight TENSORS, not scalar params

    layer_results.append({
        'group': group_name,
        'n_params': n_tensors,
        'alpha': stats['alpha'],
        't_star': stats['t_star'],
    })

    print(f"{group_name:15s}: {n_tensors:3d} tensors, α = {stats['alpha']:.4f}, T* = {stats['t_star']:.4f}")
    if n_tensors == 0:
        # Nothing was perturbed, so α = 0 here says nothing about sensitivity.
        print(f"  warning: no weight tensors matched {patterns!r}")

df_layers = pd.DataFrame(layer_results)

In [None]:
# Plot layer sensitivity
fig, ax = plt.subplots(figsize=(10, 5))

colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(df_layers)))
ax.barh(df_layers['group'], df_layers['alpha'], color=colors)

ax.set_xlabel('Logit Noise Ratio (α)')
ax.set_title(f'Layer Sensitivity to Weight Noise (scale={noise_scale})')
ax.grid(True, alpha=0.3, axis='x')

# Annotate each bar with its implied T*; categorical barh puts bar i at y = i.
for i, (alpha_val, t_star_val) in enumerate(zip(df_layers['alpha'], df_layers['t_star'])):
    ax.annotate(f'T*={t_star_val:.2f}',
                xy=(alpha_val, i),
                xytext=(5, 0), textcoords='offset points',
                va='center', fontsize=10)

plt.tight_layout()
plt.savefig('layer_sensitivity.png', dpi=150, bbox_inches='tight')
plt.show()

## Experiment 3: Noise Evolution During Generation

Does the effective noise ratio change during autoregressive generation?

In [None]:
def measure_noise_evolution(model, tokenizer, prompt, noise_scale, max_steps=15):
    """Track the logit noise ratio alpha(t) along a greedy generation.

    1. Greedily decode ``max_steps`` tokens from ``prompt`` with the clean
       model, recording the next-token logits at each step.
    2. Inside ``weight_noise_context``, replay the same token sequence
       (teacher-forced on the clean tokens). At each step, compare the noisy
       next-token logits with the clean ones via ``compute_statistics``.

    Returns:
        dict with ``'alphas'`` and ``'t_stars'`` (one value per step) and
        ``'tokens'`` (each clean token decoded individually).
    """
    def _append(ids, token_id):
        return torch.cat([ids, torch.tensor([[token_id]]).to(model.device)], dim=1)

    def _last_logits(net, ids):
        with torch.no_grad():
            return net(ids).logits[0, -1, :].float().cpu()

    # Pass 1: clean greedy trajectory.
    clean_steps = []
    ids = tokenizer(prompt, return_tensors="pt").to(model.device).input_ids.clone()
    for _ in range(max_steps):
        step_logits = _last_logits(model, ids)
        next_id = step_logits.argmax().item()
        clean_steps.append({'logits': step_logits, 'token': next_id})
        ids = _append(ids, next_id)

    # Pass 2: the same tokens through the perturbed weights.
    alphas, t_stars = [], []
    with weight_noise_context(model, noise_scale) as (noisy_model, _):
        ids = tokenizer(prompt, return_tensors="pt").to(model.device).input_ids.clone()
        for step in clean_steps:
            stats = compute_statistics(step['logits'], _last_logits(noisy_model, ids))
            alphas.append(stats['alpha'])
            t_stars.append(stats['t_star'])
            # Feed the clean token so both passes see identical contexts.
            ids = _append(ids, step['token'])

    decoded = [tokenizer.decode([step['token']]) for step in clean_steps]
    return {'alphas': alphas, 't_stars': t_stars, 'tokens': decoded}

In [None]:
# Test on multiple prompts (each call draws its own noise realisation)
evolution_prompts = [
    "Count from 1 to 10: 1, 2,",
    "The capital of France is",
    "def fibonacci(n):\n    if n <= 1:",
]

noise_scale = 0.01
evolution_results = []

for prompt in evolution_prompts:
    result = measure_noise_evolution(model, tokenizer, prompt, noise_scale)
    evolution_results.append({'prompt': prompt, **result})
    # !r keeps prompts containing newlines (e.g. code) on one output line.
    print(f"Prompt: {prompt[:30]!r}...")
    print(f"  α range: {min(result['alphas']):.4f} to {max(result['alphas']):.4f}")

In [None]:
# Plot evolution
SLOPE_TOL = 1e-3  # |trend slope| (change in α per step) below this counts as flat

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Individual trajectories
ax1 = axes[0]
colors = plt.cm.tab10(np.linspace(0, 1, len(evolution_results)))

for color, result in zip(colors, evolution_results):
    # Prompts may contain newlines (e.g. code); keep each legend entry on one line.
    label = result['prompt'][:25].replace('\n', '\\n') + '...'
    ax1.plot(range(len(result['alphas'])), result['alphas'], 'o-', color=color,
             label=label, linewidth=2, markersize=5)

ax1.set_xlabel('Generation Step')
ax1.set_ylabel('α (noise ratio)')
ax1.set_title('Noise Evolution During Generation')
ax1.legend(fontsize=9)
ax1.grid(True, alpha=0.3)

# Average across prompts, with a least-squares linear trend
ax2 = axes[1]
max_len = max(len(r['alphas']) for r in evolution_results)
avg_alphas = [
    np.mean([r['alphas'][t] for r in evolution_results if t < len(r['alphas'])])
    for t in range(max_len)
]

steps = np.arange(len(avg_alphas))
ax2.plot(steps, avg_alphas, 'bo-', linewidth=2, markersize=8, label='Average α')

slope, intercept = np.polyfit(steps, avg_alphas, 1)
ax2.plot(steps, slope * steps + intercept, 'r--', linewidth=2, label=f'Trend (slope={slope:.4f})')

ax2.set_xlabel('Generation Step')
ax2.set_ylabel('Average α')
ax2.set_title('Average Noise Ratio with Trend')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('noise_evolution.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nTrend slope: {slope:.6f}")
if slope > SLOPE_TOL:
    print("→ α INCREASES → T should INCREASE during generation")
elif slope < -SLOPE_TOL:
    print("→ α DECREASES → T should DECREASE during generation")
else:
    print("→ α is STABLE → Constant T is fine")

## Experiment 4: Temperature Optimization

Find the best temperature for a given noise level.

In [None]:
noise_scale = 0.01
temperatures = np.arange(0.8, 1.5, 0.02)

# Score every temperature on every prompt under ONE noise draw, reusing
# evaluate_temperatures instead of duplicating its KL / P(correct) logic.
temp_eval = {temp: {'kl_divs': [], 'prob_corrects': []} for temp in temperatures}
# Theory prediction for THIS noise realisation (Experiment 1 used a different draw).
same_draw_t_stars = []

with weight_noise_context(model, noise_scale) as (noisy_model, _):
    for prompt in TEST_PROMPTS:
        logits_clean = clean_logits[prompt]
        logits_noisy = get_logits(noisy_model, tokenizer, prompt)
        same_draw_t_stars.append(compute_statistics(logits_clean, logits_noisy)['t_star'])

        for temp, scores in evaluate_temperatures(logits_noisy, logits_clean, temperatures).items():
            temp_eval[temp]['kl_divs'].append(scores['kl_div'])
            temp_eval[temp]['prob_corrects'].append(scores['prob_correct'])

# Average results over prompts
df_temp = pd.DataFrame([
    {
        'temperature': temp,
        'avg_kl': np.mean(data['kl_divs']),
        'avg_prob_correct': np.mean(data['prob_corrects']),
    }
    for temp, data in temp_eval.items()
]).sort_values('temperature')

print(f"Predicted T* for this noise draw: {np.mean(same_draw_t_stars):.4f}")

In [None]:
# Plot temperature optimization
# Predicted T* comes from Experiment 1 (same noise scale, different noise draw).
t_star_rows = df_results.loc[np.isclose(df_results['weight_noise'], noise_scale), 't_star']
if t_star_rows.empty:
    raise ValueError(
        f"Experiment 1 has no row for weight noise {noise_scale}; "
        "add it to NOISE_SCALES and re-run Experiment 1."
    )
predicted_t_star = t_star_rows.iloc[0]
best_temp = df_temp.loc[df_temp['avg_kl'].idxmin(), 'temperature']

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (column, y-label, title) per panel; both panels mark predicted vs best T.
panels = [
    ('avg_kl', 'KL Divergence', f'KL Divergence vs Temperature (noise={noise_scale})'),
    ('avg_prob_correct', 'P(correct token)', 'Probability of Correct Token vs Temperature'),
]
for ax, (column, ylabel, title) in zip(axes, panels):
    ax.plot(df_temp['temperature'], df_temp[column], 'b-', linewidth=2)
    ax.axvline(x=predicted_t_star, color='red', linestyle='--', label=f'Predicted T* = {predicted_t_star:.3f}')
    ax.axvline(x=best_temp, color='green', linestyle=':', label=f'Best T = {best_temp:.3f}')
    ax.set_xlabel('Temperature')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('temperature_optimization.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPredicted T* (theory): {predicted_t_star:.4f}")
print(f"Best T (min KL): {best_temp:.4f}")
print(f"Difference: {abs(predicted_t_star - best_temp):.4f}")

## Summary

In [None]:
print("=" * 60)
print("SUMMARY")
print("=" * 60)

print("\n1. Weight Noise → Logit Noise Relationship:")
print("   α relative to squared weight-noise scale (α / scale²):")
for _, row in df_results[df_results['weight_noise'] > 0].iterrows():
    amplification = row['alpha'] / row['weight_noise'] ** 2
    print(f"   Weight σ={row['weight_noise']:.3f} → α={row['alpha']:.4f} (amplification: {amplification:.1f}x)")

# Report the measured comparison from Experiment 4 instead of asserting that
# the formula holds regardless of the data.
print("\n2. T* Formula Check (Experiment 4):")
print(f"   Predicted T* = √(1 + α): {predicted_t_star:.4f}")
print(f"   KL-optimal T:            {best_temp:.4f}")
print(f"   |difference|:            {abs(predicted_t_star - best_temp):.4f}")

print("\n3. Most Sensitive Layers:")
sensitive = df_layers.nlargest(3, 'alpha')
for _, row in sensitive.iterrows():
    print(f"   {row['group']}: α = {row['alpha']:.4f}")

print("\n4. Noise Evolution:")
print(f"   Trend slope: {slope:.6f}")
if abs(slope) < 0.001:
    print("   → Constant temperature is sufficient")
else:
    direction = "increases" if slope > 0 else "decreases"
    print(f"   → α {direction} during generation; consider a temperature schedule")

print("\n5. Recommended Settings:")
print("   For typical weight noise levels:")
for scale in (0.01, 0.02):
    rows = df_results[np.isclose(df_results['weight_noise'], scale)]
    if rows.empty:
        print(f"   - Noise {scale} → (not measured)")
    else:
        print(f"   - Noise {scale} → T* ≈ {rows['t_star'].iloc[0]:.2f}")

---

## Key Equations

**Optimal Temperature:**
$$T^* = \sqrt{1 + \alpha} = \sqrt{1 + \frac{\sigma^2}{\tau^2}}$$

**Temperature Schedule (if needed):**
$$T(t) = \frac{T_{\max}}{\sqrt{1 + \beta t}}$$

where $t$ is the generation step, $T_{\max} = T(0)$ is the starting temperature, and $\beta$ is the empirical noise decay rate.