# 🔬 LLM Sensitivity & Super-Weight Prioritization Analysis

**Research Objective**: Identify and analyze critical weights in large language models through gradient-based sensitivity metrics.

## 📋 Analysis Pipeline:
1. **Model Loading**: Load target model with proper tokenization
2. **Baseline Evaluation**: Compute initial perplexity
3. **Sensitivity Computation**: Calculate gradient-based metrics
4. **Weight Ranking**: Identify top-K critical weights
5. **Perturbation Analysis**: Mask weights and evaluate impact
6. **Results Export**: Save findings for further analysis

## 🎯 Key Metrics:
- **Gradient × Weight**: `|grad × weight|`
- **Gradient²**: `grad²` (curvature-based sensitivity)
- **Top-K Analysis**: Most critical weights by layer
- **Impact Assessment**: Perplexity change after masking

In [None]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm

# Project imports
sys.path.append('../src')
from models.loader import load_model
from eval.perplexity import compute_perplexity
from sensitivity.metrics import compute_sensitivity, get_model_layers
from sensitivity.rank import rank_weights_topk

# Configure plotting
# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8*';
# older releases only ship the un-prefixed 'seaborn' name. Pick whichever is
# installed so the setup cell does not fail on older environments.
plot_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(plot_style)
sns.set_palette("husl")

# Session banner: record when the run happened and which runtime is in use.
print("🔬 Critical Weight Analysis - Research Notebook")
print(f"📅 Session: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🔧 PyTorch: {torch.__version__}")
print(f"🎯 CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device name comes from the current device; memory is read from device 0.
    print(f"🚀 GPU: {torch.cuda.get_device_name()}")
    print(f"💾 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## ⚙️ Configuration

In [None]:
# Experiment configuration
# Single source of truth for the run; later cells read these keys directly.
CONFIG = {
    # Model settings
    'model_name': 'gpt2',  # Start with GPT-2, can change to EleutherAI/pythia-410m
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Analysis settings
    'sensitivity_metrics': ['grad_weight', 'grad_squared'],
    'topk_values': [10, 50, 100, 500],  # Top-K weights to analyze
    'perturbation_ratios': [0.1, 0.25, 0.5, 0.75, 1.0],  # Fraction of top-K to mask
    
    # Data settings
    'eval_texts_limit': 100,  # Number of evaluation texts
    'batch_size': 1,  # Individual processing for accuracy
    
    # Output settings
    'save_results': True,
    'output_dir': '../outputs',
    'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
}

print("📊 Experiment Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Create output directory
# parents=True creates any missing intermediate directories, so 'output_dir'
# may point at a nested path that does not exist yet; creating the session
# directory this way also creates output_dir itself.
output_dir = Path(CONFIG['output_dir'])
session_dir = output_dir / f"sensitivity_analysis_{CONFIG['timestamp']}"
session_dir.mkdir(parents=True, exist_ok=True)
print(f"\n💾 Results will be saved to: {session_dir}")

## 📚 Data Loading

In [None]:
# Load evaluation texts
# One text per non-blank line of the data file; if the file is absent a small
# built-in sample set is used instead so the notebook still runs end to end.
data_file = Path('../src/data/dev_small.txt')
if data_file.exists():
    # Explicit encoding so decoding does not depend on the platform default.
    # NOTE(review): assumes the data file is UTF-8 — confirm.
    with open(data_file, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        all_texts = [text for text in stripped_lines if text]
else:
    # Fallback: create sample texts
    all_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning models require careful optimization.",
        "Natural language processing has advanced significantly.",
        "Deep neural networks learn complex patterns.",
        "Gradient-based methods are fundamental to training."
    ] * 20  # Repeat to get more samples

# Limit texts for efficiency
eval_texts = all_texts[:CONFIG['eval_texts_limit']]

# Fail with an actionable message instead of an IndexError on eval_texts[0]
# (and a NaN average) when the file is empty or the limit is zero.
if not eval_texts:
    raise ValueError(
        f"No evaluation texts available: check that {data_file} contains "
        f"non-empty lines and that CONFIG['eval_texts_limit'] is positive"
    )

print(f"📚 Loaded {len(eval_texts)} evaluation texts")
print(f"📝 Sample text: {eval_texts[0][:100]}...")
print(f"📏 Average length: {np.mean([len(text.split()) for text in eval_texts]):.1f} words")

## 🤖 Model Loading

In [None]:
# Load the model/tokenizer pair and print a short structural overview.
print(f"🔄 Loading model: {CONFIG['model_name']}")
model, tokenizer = load_model(CONFIG['model_name'], device=CONFIG['device'])

print(f"✅ Model: {type(model).__name__}")
print(f"✅ Tokenizer: {type(tokenizer).__name__}")
print(f"✅ Device: {next(model.parameters()).device}")
print(f"📊 Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"🔍 Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Analyze model structure
layers = get_model_layers(model)
print(f"\n🏗️ Model Architecture:")
print(f"  Total layers: {len(layers)}")

# Build the name -> module lookup once; named_modules() walks the whole
# module tree, so rebuilding it per layer is wasted work.
modules_by_name = dict(model.named_modules())
for i, layer_name in enumerate(layers[:5]):  # Show first 5
    weight = getattr(modules_by_name[layer_name], 'weight', None)
    if weight is not None:
        print(f"  [{i}] {layer_name}: {tuple(weight.shape)}")
if len(layers) > 5:
    print(f"  ... and {len(layers) - 5} more layers")

## 📏 Baseline Evaluation

In [None]:
# Measure the unperturbed model's perplexity over the evaluation texts; this is
# the reference value later cells compare against.
print("📏 Computing baseline perplexity...")
baseline_ppl = compute_perplexity(model, tokenizer, eval_texts)

num_eval_texts = len(eval_texts)
log_baseline_ppl = np.log(baseline_ppl)

print(f"\n📊 Baseline Results:")
for report_line in (
    f"  Perplexity: {baseline_ppl:.2f}",
    f"  Log Perplexity: {log_baseline_ppl:.4f}",
    f"  Evaluation texts: {num_eval_texts}",
):
    print(report_line)

# Store baseline for comparison
results = dict(
    baseline_perplexity=baseline_ppl,
    model_name=CONFIG['model_name'],
    num_parameters=sum(param.numel() for param in model.parameters()),
    eval_texts_count=num_eval_texts,
)

print(f"\n✅ Baseline established: PPL = {baseline_ppl:.2f}")

## 🧮 Sensitivity Analysis

In [None]:
# Compute per-layer sensitivity for every configured metric and print summary
# statistics over all scores of that metric.
print("🧮 Computing gradient-based sensitivity metrics...")
print(f"📊 Metrics: {', '.join(CONFIG['sensitivity_metrics'])}")

# Compute sensitivity for each metric
# Maps metric name -> whatever compute_sensitivity returns (used below as a
# mapping of layer name -> tensor of scores).
sensitivity_results = {}

for metric in CONFIG['sensitivity_metrics']:
    print(f"\n🔄 Computing {metric} sensitivity...")
    
    # Calculate sensitivity
    layer_sensitivities = compute_sensitivity(
        model, tokenizer, eval_texts, 
        method=metric,
        device=CONFIG['device']
    )
    
    sensitivity_results[metric] = layer_sensitivities
    
    # Summary statistics
    all_sensitivities = np.concatenate([sens.cpu().numpy().flatten() 
                                       for sens in layer_sensitivities.values()])
    # Count elements of the flattened array: len() on a multi-dimensional
    # tensor returns only its first dimension and would under-count weights.
    total_weights = all_sensitivities.size
    
    print(f"  ✅ Computed for {len(layer_sensitivities)} layers")
    print(f"  📊 Total weights: {total_weights:,}")
    print(f"  📈 Mean sensitivity: {all_sensitivities.mean():.6f}")
    print(f"  📉 Std sensitivity: {all_sensitivities.std():.6f}")
    print(f"  🔝 Max sensitivity: {all_sensitivities.max():.6f}")

print(f"\n✅ Sensitivity analysis complete for {len(CONFIG['sensitivity_metrics'])} metrics")

## 🏆 Weight Ranking & Top-K Analysis

In [None]:
# For every metric and every configured K, fetch the top-K ranked weights and
# report how those weights are spread over layers.
print("🏆 Ranking weights by sensitivity...")

# Store ranking results: metric name -> {k -> ranked weight records}
ranking_results = {}

for metric in CONFIG['sensitivity_metrics']:
    print(f"\n🔄 Ranking for {metric}...")

    per_layer_scores = sensitivity_results[metric]
    ranked_by_k = {}

    for k in CONFIG['topk_values']:
        ranked = rank_weights_topk(per_layer_scores, k=k)
        ranked_by_k[k] = ranked

        # Tally how many of the selected weights fall in each layer
        hits_per_layer = {}
        for entry in ranked:
            name = entry['layer']
            hits_per_layer[name] = hits_per_layer.get(name, 0) + 1

        print(f"  Top-{k}: {len(ranked)} weights across {len(hits_per_layer)} layers")

        if not hits_per_layer:
            continue

        # Three layers holding the most selected weights (ties keep insertion order)
        busiest = sorted(hits_per_layer.items(), key=lambda item: -item[1])[:3]
        labels = [f'{layer}({count})' for layer, count in busiest]
        print(f"    Most critical layers: {', '.join(labels)}")

    ranking_results[metric] = ranked_by_k

print(f"\n✅ Weight ranking complete for {len(CONFIG['sensitivity_metrics'])} metrics")

## 📊 Sensitivity Distribution Visualization

In [None]:
# Create visualization of sensitivity distributions
# One row of two panels per metric: a log-scale histogram of all scores and a
# bar chart of per-layer mean scores.
metrics = CONFIG['sensitivity_metrics']

# Size the grid from the number of metrics instead of a fixed 2x2, so one or
# more than two metrics no longer leaves empty rows or raises IndexError.
# squeeze=False keeps `axes` two-dimensional even for a single row.
n_rows = max(1, len(metrics))
fig, axes = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)
fig.suptitle('🔬 Sensitivity Analysis Results', fontsize=16, fontweight='bold')

for row, metric in enumerate(metrics):
    layer_sensitivities = sensitivity_results[metric]
    
    # Plot 1: Distribution histogram
    ax = axes[row, 0]
    all_sens = np.concatenate([sens.cpu().numpy().flatten() 
                              for sens in layer_sensitivities.values()])
    
    # Use log scale for better visualization
    log_sens = np.log10(all_sens + 1e-12)  # Add small epsilon to avoid log(0)
    ax.hist(log_sens, bins=50, alpha=0.7, edgecolor='black')
    ax.set_title(f'{metric.replace("_", " ").title()} Distribution (Log Scale)')
    ax.set_xlabel('Log10(Sensitivity)')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Layer-wise sensitivity
    ax = axes[row, 1]
    layer_names = sorted(layer_sensitivities.keys())
    layer_means = [layer_sensitivities[name].mean().item() for name in layer_names]
    
    # Show only every nth layer for readability
    step = max(1, len(layer_names) // 10)
    x_pos = range(0, len(layer_names), step)
    x_labels = [layer_names[pos].split('.')[-1] for pos in x_pos]  # Short names
    y_values = [layer_means[pos] for pos in x_pos]
    
    ax.bar(x_pos, y_values, alpha=0.7)
    ax.set_title(f'{metric.replace("_", " ").title()} by Layer')
    ax.set_xlabel('Layer')
    ax.set_ylabel('Mean Sensitivity')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(session_dir / 'sensitivity_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"📊 Sensitivity distribution plots saved to {session_dir / 'sensitivity_distributions.png'}")

## 🎯 Perturbation Analysis

In [None]:
# Perturbation analysis: for each metric, top-K set and masking ratio, zero the
# selected weights in place, measure perplexity, then restore the weights.
print("🎯 Starting perturbation analysis...")
print(f"📊 Testing top-K values: {CONFIG['topk_values']}")
print(f"📊 Masking ratios: {CONFIG['perturbation_ratios']}")

perturbation_results = []

# Build the name -> module lookup once instead of once per masked weight.
modules_by_name = dict(model.named_modules())

# Perturbed perplexity is measured on a subset for speed, so the reference
# value must come from the same subset; comparing against the full-set
# baseline would mix the effect of masking with the effect of different data.
eval_subset = eval_texts[:20]
if len(eval_subset) == len(eval_texts):
    subset_baseline_ppl = baseline_ppl
else:
    subset_baseline_ppl = compute_perplexity(model, tokenizer, eval_subset)
print(f"📏 Subset baseline PPL ({len(eval_subset)} texts): {subset_baseline_ppl:.2f}")

# Test different metrics and top-K values
for metric in CONFIG['sensitivity_metrics']:
    print(f"\n🔄 Perturbation analysis for {metric}...")
    
    for k in CONFIG['topk_values']:
        top_weights = ranking_results[metric][k]
        print(f"\n  📊 Top-{k} weights ({len(top_weights)} total)")
        
        for ratio in CONFIG['perturbation_ratios']:
            # Calculate how many weights to mask
            num_to_mask = int(len(top_weights) * ratio)
            if num_to_mask == 0:
                continue
                
            weights_to_mask = top_weights[:num_to_mask]
            
            print(f"    🎯 Masking {num_to_mask} weights ({ratio*100:.0f}%)...", end=" ")
            
            # (layer name, index tuple) -> original value, filled as weights are zeroed
            original_values = {}
            
            try:
                # Apply masking
                for weight_info in weights_to_mask:
                    layer_name = weight_info['layer']
                    index = tuple(weight_info['indices'])
                    
                    weight = getattr(modules_by_name[layer_name], 'weight', None)
                    if weight is None:
                        continue
                    
                    key = (layer_name, index)
                    # Save only the first time a position is seen; a repeated
                    # entry would otherwise record the already-zeroed value.
                    if key not in original_values:
                        original_values[key] = weight.data[index].clone()
                    
                    # Mask the weight (set to zero)
                    weight.data[index] = 0.0
                
                # Evaluate with masked weights
                perturbed_ppl = compute_perplexity(model, tokenizer, eval_subset)
            except Exception as e:
                # Best effort: report the failure and move on to the next setting.
                print(f"Error: {e}")
                continue
            finally:
                # Always restore, on success or failure, so one experiment can
                # never leak masked weights into the next.
                for (layer_name, index), value in original_values.items():
                    modules_by_name[layer_name].weight.data[index] = value
            
            # Calculate impact
            ppl_increase = perturbed_ppl - subset_baseline_ppl
            ppl_ratio = perturbed_ppl / subset_baseline_ppl
            
            # '+' format flag prints the sign for both gains and drops
            print(f"PPL: {perturbed_ppl:.2f} ({ppl_increase:+.2f}, {ppl_ratio:.2f}x)")
            
            # Store results
            perturbation_results.append({
                'metric': metric,
                'topk': k,
                'mask_ratio': ratio,
                'weights_masked': num_to_mask,
                'baseline_ppl': subset_baseline_ppl,
                'perturbed_ppl': perturbed_ppl,
                'ppl_increase': ppl_increase,
                'ppl_ratio': ppl_ratio
            })

print(f"\n✅ Perturbation analysis complete: {len(perturbation_results)} experiments")

## 📈 Results Analysis & Visualization

In [None]:
# Tabulate the perturbation records, print the headline numbers and draw a
# 2x2 summary figure that is also written to the session directory.
df = pd.DataFrame(perturbation_results)

if len(df) == 0:
    print("⚠️ No perturbation results to analyze")
else:
    distinct_metrics = df['metric'].unique()

    print("📈 Perturbation Analysis Results:")
    print(f"  Total experiments: {len(df)}")
    print(f"  Metrics tested: {distinct_metrics}")
    print(f"  Max PPL increase: {df['ppl_increase'].max():.2f}")
    print(f"  Max PPL ratio: {df['ppl_ratio'].max():.2f}x")

    # Ten experiments with the largest perplexity increase
    print("\n🔝 Top 10 Most Impactful Perturbations:")
    for rec in df.nlargest(10, 'ppl_increase').itertuples(index=False):
        print(f"  {rec.metric}, Top-{rec.topk}, {rec.mask_ratio*100:.0f}% masked: "
              f"PPL {rec.baseline_ppl:.1f} → {rec.perturbed_ppl:.1f} "
              f"(+{rec.ppl_increase:.1f})")

    # Figure with four panels
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('🎯 Perturbation Analysis Results', fontsize=16, fontweight='bold')
    (ax_ratio, ax_scatter), (ax_heat, ax_compare) = axes

    # Panel 1: one line per (metric, top-K) pair, PPL increase against mask ratio
    for metric in distinct_metrics:
        metric_rows = df.loc[df['metric'] == metric]
        for k in np.sort(metric_rows['topk'].unique()):
            k_rows = metric_rows.loc[metric_rows['topk'] == k]
            ax_ratio.plot(k_rows['mask_ratio'], k_rows['ppl_increase'],
                          marker='o', label=f'{metric}, Top-{k}')
    ax_ratio.set_xlabel('Masking Ratio')
    ax_ratio.set_ylabel('Perplexity Increase')
    ax_ratio.set_title('PPL Impact vs Masking Ratio')
    ax_ratio.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax_ratio.grid(True, alpha=0.3)

    # Panel 2: PPL ratio against number of masked weights, coloured by top-K
    points = ax_scatter.scatter(df['weights_masked'], df['ppl_ratio'],
                                c=df['topk'], alpha=0.7, s=60)
    ax_scatter.set_xlabel('Number of Weights Masked')
    ax_scatter.set_ylabel('Perplexity Ratio')
    ax_scatter.set_title('PPL Ratio vs Weights Masked')
    plt.colorbar(points, ax=ax_scatter, label='Top-K Value')
    ax_scatter.grid(True, alpha=0.3)

    # Panel 3: mean PPL increase per (mask ratio, top-K) cell
    increase_grid = df.pivot_table(values='ppl_increase',
                                   index='mask_ratio',
                                   columns='topk',
                                   aggfunc='mean')
    sns.heatmap(increase_grid, annot=True, fmt='.1f', ax=ax_heat, cmap='Reds')
    ax_heat.set_title('PPL Increase Heatmap\n(Mask Ratio vs Top-K)')

    # Panel 4: compare metrics when several were run, else show a histogram
    if len(distinct_metrics) <= 1:
        ax_compare.hist(df['ppl_increase'], bins=20, alpha=0.7, edgecolor='black')
        ax_compare.set_title('PPL Increase Distribution')
        ax_compare.set_xlabel('PPL Increase')
        ax_compare.set_ylabel('Frequency')
    else:
        mean_by_metric = (
            df.groupby(['metric', 'mask_ratio'])['ppl_increase']
            .mean()
            .unstack(level=0)
        )
        mean_by_metric.plot(kind='bar', ax=ax_compare)
        ax_compare.set_title('Metric Comparison by Mask Ratio')
        ax_compare.set_xlabel('Mask Ratio')
        ax_compare.set_ylabel('Mean PPL Increase')
        ax_compare.legend(title='Metric')
    ax_compare.grid(True, alpha=0.3)

    figure_path = session_dir / 'perturbation_analysis.png'
    plt.tight_layout()
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\n📊 Perturbation analysis plots saved to {figure_path}")

## 💾 Export Results

In [None]:
# Write every artefact of the run (perturbation table, top-K weight lists,
# one-row summary, configuration) as CSV files into the session directory.
if not CONFIG['save_results']:
    print("📄 Results export disabled in configuration")
else:
    print("💾 Exporting results...")

    # 1. Perturbation table (only when at least one experiment finished)
    if len(perturbation_results) > 0:
        perturbation_csv = session_dir / 'perturbation_results.csv'
        df.to_csv(perturbation_csv, index=False)
        print(f"  ✅ Perturbation results: {perturbation_csv}")

    # 2. One CSV per (metric, K) listing the ranked weights
    for metric in CONFIG['sensitivity_metrics']:
        if metric not in ranking_results:
            continue
        for k in CONFIG['topk_values']:
            if k not in ranking_results[metric]:
                continue

            records = []
            for position, w in enumerate(ranking_results[metric][k], start=1):
                records.append({
                    'rank': position,
                    'layer': w['layer'],
                    'indices': str(w['indices']),
                    'sensitivity': w['sensitivity']
                })
            weights_df = pd.DataFrame(records)

            filename = f'top_{k}_weights_{metric}.csv'
            weights_df.to_csv(session_dir / filename, index=False)
            print(f"  ✅ Top-{k} {metric} weights: {session_dir / filename}")

    # 3. One-row experiment summary
    has_rows = len(df) > 0
    summary = {
        'experiment_id': CONFIG['timestamp'],
        'model_name': CONFIG['model_name'],
        'baseline_perplexity': baseline_ppl,
        'num_parameters': sum(param.numel() for param in model.parameters()),
        'eval_texts_count': len(eval_texts),
        'sensitivity_metrics': CONFIG['sensitivity_metrics'],
        'topk_values': CONFIG['topk_values'],
        'perturbation_ratios': CONFIG['perturbation_ratios'],
        'experiments_completed': len(perturbation_results),
        'max_ppl_increase': df['ppl_increase'].max() if has_rows else None,
        'max_ppl_ratio': df['ppl_ratio'].max() if has_rows else None
    }

    summary_csv = session_dir / 'experiment_summary.csv'
    pd.DataFrame([summary]).to_csv(summary_csv, index=False)
    print(f"  ✅ Experiment summary: {summary_csv}")

    # 4. Configuration as (parameter, value) rows
    config_csv = session_dir / 'config.csv'
    config_df = pd.DataFrame(list(CONFIG.items()), columns=['parameter', 'value'])
    config_df.to_csv(config_csv, index=False)
    print(f"  ✅ Configuration: {config_csv}")

    print(f"\n🎉 All results exported to: {session_dir}")
    print(f"📁 Files created:")
    for created in session_dir.glob('*'):
        print(f"  - {created.name}")

## 📋 Analysis Summary

In [None]:
# Final recap of the session: run metadata, per-metric sensitivity statistics,
# the most impactful perturbation (if any ran) and suggested follow-ups.
print("🔬 Critical Weight Analysis - Complete Summary")
print("=" * 60)
print(f"📅 Session: {CONFIG['timestamp']}")
print(f"🤖 Model: {CONFIG['model_name']}")
print(f"📊 Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"📚 Evaluation texts: {len(eval_texts)}")
print(f"📏 Baseline perplexity: {baseline_ppl:.2f}")

print(f"\n🧮 Sensitivity Analysis:")
for metric in CONFIG['sensitivity_metrics']:
    layer_sensitivities = sensitivity_results[metric]
    all_sensitivities = np.concatenate([sens.cpu().numpy().flatten() 
                                       for sens in layer_sensitivities.values()])
    # Count elements of the flattened array: len() on a multi-dimensional
    # tensor returns only its first dimension and would under-count weights.
    total_weights = all_sensitivities.size
    print(f"  {metric}:")
    print(f"    Layers analyzed: {len(layer_sensitivities)}")
    print(f"    Total weights: {total_weights:,}")
    print(f"    Mean sensitivity: {all_sensitivities.mean():.2e}")
    print(f"    Max sensitivity: {all_sensitivities.max():.2e}")

if len(perturbation_results) > 0:
    print(f"\n🎯 Perturbation Analysis:")
    print(f"  Experiments completed: {len(perturbation_results)}")
    print(f"  Max PPL increase: {df['ppl_increase'].max():.2f}")
    print(f"  Max PPL ratio: {df['ppl_ratio'].max():.2f}x")
    
    # Find most effective perturbation
    best_result = df.loc[df['ppl_increase'].idxmax()]
    print(f"\n🏆 Most Impactful Perturbation:")
    print(f"  Metric: {best_result['metric']}")
    print(f"  Top-K: {best_result['topk']}")
    print(f"  Mask ratio: {best_result['mask_ratio']*100:.0f}%")
    print(f"  Weights masked: {best_result['weights_masked']}")
    print(f"  PPL change: {best_result['baseline_ppl']:.1f} → {best_result['perturbed_ppl']:.1f}")
    # '+' format flag prints the sign for both gains and drops
    print(f"  Impact: {best_result['ppl_increase']:+.1f} ({best_result['ppl_ratio']:.2f}x)")

print(f"\n💾 Results saved to: {session_dir}")
print(f"🎉 Analysis complete!")

# Key insights
print(f"\n🔑 Key Insights:")
if len(perturbation_results) > 0:
    print(f"  • Masking top-{best_result['topk']} {best_result['metric']} weights causes {best_result['ppl_increase']:.1f} PPL increase")
    print(f"  • Most critical weights concentrate in specific layers")
    print(f"  • {best_result['metric'].replace('_', ' ').title()} shows strongest sensitivity correlation")
else:
    print(f"  • Sensitivity analysis completed successfully")
    print(f"  • Weight ranking identified critical parameters")
    print(f"  • Ready for perturbation experiments")

print(f"\n🚀 Next Steps:")
print(f"  • Analyze layer-specific sensitivity patterns")
print(f"  • Test different masking strategies")
print(f"  • Explore attention head importance")
print(f"  • Compare across different model sizes")