# Phase 0: Methodological Enhancements - Ablation Study

This notebook validates the Phase 0 protein-specific distillation enhancements:

1. **Uncertainty-Aware Position Weighting**: Weights distillation loss by position-specific entropy from teacher predictions
2. **Calibration-Aware Distillation**: Applies dynamic label smoothing based on teacher confidence
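
As a compact reference, the two enhancements can be written as follows. The notation is ours and mirrors the test cells in Section 3 rather than the `DistillationTrainer` source, which is not shown here: $p_t$ is the teacher distribution at position $t$, $V$ is the vocabulary size, $s$ is the smoothing factor, $\epsilon$ is a small stabilizing constant, and $H_{\min}$, $H_{\max}$ are taken over the positions of one sequence.

$$
H_t = -\sum_{v=1}^{V} p_t(v)\,\log p_t(v), \qquad
w_t = 0.5 + 0.5\,\frac{H_t - H_{\min}}{H_{\max} - H_{\min} + \epsilon}
$$

$$
\alpha_t = s\,\bigl(1 - \max_{v} p_t(v)\bigr), \qquad
\tilde{p}_t(v) = (1 - \alpha_t)\,p_t(v) + \frac{\alpha_t}{V}
$$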

We compare four configurations:
- Baseline (no enhancements)
- +Uncertainty weighting only
- +Calibration smoothing only
- +Both enhancements

## 1. Setup and Imports

In [None]:
import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Add project root to path
project_root = Path(os.getcwd()).parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from src.distillation import DistillationTrainer
import config

%matplotlib inline

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Load Teacher Model

In [None]:
# Load teacher model
print("Loading teacher model...")
teacher_model = GPT2LMHeadModel.from_pretrained(config.TEACHER_MODEL).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(config.TEACHER_MODEL)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

teacher_model.eval()
print(f"Teacher parameters: {sum(p.numel() for p in teacher_model.parameters()):,}")

## 3. Test Enhancement Methods

Before training, let's verify the enhancement methods work correctly on sample data.

In [None]:
# Test entropy computation
test_sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
inputs = tokenizer(test_sequence, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = teacher_model(**inputs)
    logits = outputs.logits

# Compute entropy using the method from DistillationTrainer
probs = F.softmax(logits, dim=-1)
log_probs = torch.log(probs + 1e-10)
entropy = -torch.sum(probs * log_probs, dim=-1)

print(f"Logits shape: {logits.shape}")
print(f"Entropy shape: {entropy.shape}")
print(f"Entropy range: [{entropy.min().item():.3f}, {entropy.max().item():.3f}]")
print(f"Entropy mean: {entropy.mean().item():.3f}")

In [None]:
# Test position weighting
entropy_squeezed = entropy.squeeze(0)  # Remove batch dimension

# Normalize entropy to [0, 1]
entropy_min = entropy_squeezed.min()
entropy_max = entropy_squeezed.max()
normalized = (entropy_squeezed - entropy_min) / (entropy_max - entropy_min + 1e-8)

# Scale to [0.5, 1.0]
weights = 0.5 + 0.5 * normalized

print(f"Weights shape: {weights.shape}")
print(f"Weights range: [{weights.min().item():.3f}, {weights.max().item():.3f}]")
print(f"Weights mean: {weights.mean().item():.3f}")

In [None]:
# Test calibration smoothing
smoothing_factor = 0.1

# Get max probability (confidence) at each position
max_prob = probs.max(dim=-1, keepdim=True)[0]

# Adaptive smoothing: more smoothing when less confident
adaptive_smoothing = smoothing_factor * (1.0 - max_prob)

# Apply smoothing
vocab_size = probs.size(-1)
uniform = torch.ones_like(probs) / vocab_size
smoothed_probs = (1.0 - adaptive_smoothing) * probs + adaptive_smoothing * uniform

# Verify still valid probability distribution
print(f"Original probs sum (should be 1): {probs.sum(dim=-1).mean().item():.6f}")
print(f"Smoothed probs sum (should be 1): {smoothed_probs.sum(dim=-1).mean().item():.6f}")
print(f"Max adaptive smoothing: {adaptive_smoothing.max().item():.4f}")
print(f"Min adaptive smoothing: {adaptive_smoothing.min().item():.4f}")
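
To make the arithmetic concrete, here is a dependency-free toy example of the same entropy, weighting, and smoothing steps on hand-written three-token distributions. The probabilities and the helper names (`shannon_entropy`, `position_weights`, `smooth_distribution`) are illustrative assumptions, not part of `DistillationTrainer`.

```python
import math

def shannon_entropy(p):
    """Shannon entropy (in nats) of one probability vector."""
    return -sum(x * math.log(x) for x in p if x > 0)

def position_weights(entropies, eps=1e-8):
    """Min-max normalize entropies, then rescale to [0.5, 1.0]."""
    lo, hi = min(entropies), max(entropies)
    return [0.5 + 0.5 * (h - lo) / (hi - lo + eps) for h in entropies]

def smooth_distribution(p, smoothing_factor=0.1):
    """Mix p with the uniform distribution; less confident rows get more smoothing."""
    alpha = smoothing_factor * (1.0 - max(p))
    return [(1.0 - alpha) * x + alpha / len(p) for x in p]

# Toy teacher distributions for three positions (hypothetical values).
toy_teacher = [
    [0.90, 0.05, 0.05],  # confident position
    [0.50, 0.30, 0.20],  # moderately uncertain position
    [1 / 3, 1 / 3, 1 / 3],  # maximally uncertain position
]

toy_entropies = [shannon_entropy(p) for p in toy_teacher]
toy_weights = position_weights(toy_entropies)
toy_smoothed = [smooth_distribution(p) for p in toy_teacher]

for p, h, w, q in zip(toy_teacher, toy_entropies, toy_weights, toy_smoothed):
    print(f"max_prob={max(p):.2f}  entropy={h:.3f}  weight={w:.3f}  smoothed_sum={sum(q):.6f}")
```

The most confident position receives the minimum weight of 0.5 and the least smoothing, while the uniform position receives a weight close to 1.0 and the most smoothing; every smoothed row still sums to 1.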

## 4. Visualize Position-wise Uncertainty

In [None]:
# Visualize entropy and weights along the sequence
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

entropy_np = entropy_squeezed.cpu().numpy()
weights_np = weights.cpu().numpy()
max_prob_np = max_prob.squeeze().cpu().numpy()
positions = range(len(entropy_np))

# Entropy
axes[0].bar(positions, entropy_np, alpha=0.7, color='steelblue')
axes[0].set_ylabel('Entropy')
axes[0].set_title('Teacher Prediction Entropy per Position')
axes[0].axhline(y=entropy_np.mean(), color='red', linestyle='--', label=f'Mean: {entropy_np.mean():.2f}')
axes[0].legend()

# Position weights
axes[1].bar(positions, weights_np, alpha=0.7, color='orange')
axes[1].set_ylabel('Weight')
axes[1].set_title('Position Weights for Uncertainty-Aware Distillation')
axes[1].axhline(y=0.5, color='gray', linestyle='--', label='Min weight (0.5)')
axes[1].axhline(y=1.0, color='gray', linestyle=':', label='Max weight (1.0)')
axes[1].legend()

# Max probability (confidence)
axes[2].bar(positions, max_prob_np, alpha=0.7, color='green')
axes[2].set_ylabel('Max Probability')
axes[2].set_xlabel('Sequence Position')
axes[2].set_title('Teacher Confidence (Max Probability) per Position')
axes[2].axhline(y=max_prob_np.mean(), color='red', linestyle='--', label=f'Mean: {max_prob_np.mean():.2f}')
axes[2].legend()

plt.tight_layout()
plt.show()

## 5. Test ECE Computation

In [None]:
# Add scripts to path and import ECE function
sys.path.insert(0, str(project_root / 'scripts'))
from evaluate import compute_ece, get_test_sequences

# Get test sequences
test_sequences = get_test_sequences(num_sequences=10)
print(f"Number of test sequences: {len(test_sequences)}")

# Compute ECE for the teacher on the first 5 test sequences
teacher_ece = compute_ece(teacher_model, tokenizer, test_sequences[:5], device, n_bins=10)
print(f"\nTeacher ECE: {teacher_ece['ece']:.4f}")
print(f"Teacher MCE (max calibration error): {teacher_ece['mce']:.4f}")
print(f"Overall accuracy: {teacher_ece['overall_accuracy']:.4f}")
print(f"Overall confidence: {teacher_ece['overall_confidence']:.4f}")
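
`compute_ece` lives in `scripts/evaluate.py` and is not shown in this notebook. For orientation, ECE is conventionally the count-weighted average, over confidence bins, of the gap between mean accuracy and mean confidence, and MCE is the largest such gap. A minimal pure-Python sketch of that standard definition follows; the function name `ece_from_predictions` and its inputs are assumptions and may differ from the project's implementation.

```python
def ece_from_predictions(confidences, correct, n_bins=10):
    """Return (ece, mce) using equal-width confidence bins over [0, 1]."""
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Clamp the index so that conf == 1.0 falls into the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, float(hit)))

    total = len(confidences)
    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        avg_acc = sum(a for _, a in members) / len(members)
        gap = abs(avg_acc - avg_conf)
        ece += (len(members) / total) * gap
        mce = max(mce, gap)
    return ece, mce

# Hypothetical per-token confidences and correctness flags.
demo_conf = [0.95, 0.95, 0.55, 0.55]
demo_correct = [True, True, True, False]
demo_ece, demo_mce = ece_from_predictions(demo_conf, demo_correct)
print(f"ECE={demo_ece:.3f}, MCE={demo_mce:.3f}")
```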

In [None]:
# Reliability diagram
fig, ax = plt.subplots(figsize=(6, 6))

# Extract bin statistics
bin_stats = teacher_ece['bin_stats']
confidences = []
accuracies = []
counts = []

for stat in bin_stats:
    if stat['count'] > 0:
        confidences.append(stat['avg_confidence'])
        accuracies.append(stat['avg_accuracy'])
        counts.append(stat['count'])

# Scatter plot with size proportional to count
sizes = [c / max(counts) * 500 + 50 for c in counts]
ax.scatter(confidences, accuracies, s=sizes, alpha=0.7, label='Bins')

# Perfect calibration line
ax.plot([0, 1], [0, 1], 'r--', label='Perfect calibration')

ax.set_xlabel('Mean Confidence')
ax.set_ylabel('Mean Accuracy')
ax.set_title(f'Teacher Model Reliability Diagram (ECE={teacher_ece["ece"]:.4f})')
ax.legend()
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_aspect('equal')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Run Ablation Experiments

This section lists the commands for the ablation study. Because training is time-consuming, run them separately from the notebook.

### Training Commands

```bash
# Baseline (no enhancements)
python scripts/train.py \
    --temperature 2.0 --alpha 0.5 \
    --n_layer 4 --n_head 4 --n_embd 512 \
    --train_size_prop 0.05 \
    --learning_rate 1e-3 \
    --num_train_epochs 1 \
    --output_dir ./models/ablation-baseline

# +Uncertainty weighting only
python scripts/train.py \
    --temperature 2.0 --alpha 0.5 \
    --n_layer 4 --n_head 4 --n_embd 512 \
    --train_size_prop 0.05 \
    --learning_rate 1e-3 \
    --num_train_epochs 1 \
    --use_uncertainty_weighting \
    --output_dir ./models/ablation-uncertainty

# +Calibration smoothing only
python scripts/train.py \
    --temperature 2.0 --alpha 0.5 \
    --n_layer 4 --n_head 4 --n_embd 512 \
    --train_size_prop 0.05 \
    --learning_rate 1e-3 \
    --num_train_epochs 1 \
    --use_calibration_smoothing \
    --smoothing_factor 0.1 \
    --output_dir ./models/ablation-calibration

# +Both enhancements
python scripts/train.py \
    --temperature 2.0 --alpha 0.5 \
    --n_layer 4 --n_head 4 --n_embd 512 \
    --train_size_prop 0.05 \
    --learning_rate 1e-3 \
    --num_train_epochs 1 \
    --use_uncertainty_weighting \
    --use_calibration_smoothing \
    --smoothing_factor 0.1 \
    --output_dir ./models/ablation-both
```
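
Since the four training commands differ only in their enhancement flags and output directory, they can also be generated programmatically. The sketch below only builds and prints the command lines from the flags listed above; the `COMMON_ARGS` and `ABLATIONS` tables and the `build_command` helper are hypothetical conveniences, not part of the repository.

```python
import shlex

# Arguments shared by every ablation run (copied from the commands above).
COMMON_ARGS = [
    "--temperature", "2.0", "--alpha", "0.5",
    "--n_layer", "4", "--n_head", "4", "--n_embd", "512",
    "--train_size_prop", "0.05",
    "--learning_rate", "1e-3",
    "--num_train_epochs", "1",
]

# Extra flags per configuration, keyed by the output-directory suffix.
ABLATIONS = {
    "baseline": [],
    "uncertainty": ["--use_uncertainty_weighting"],
    "calibration": ["--use_calibration_smoothing", "--smoothing_factor", "0.1"],
    "both": ["--use_uncertainty_weighting",
             "--use_calibration_smoothing", "--smoothing_factor", "0.1"],
}

def build_command(config_name):
    """Return the argv list for one ablation configuration."""
    return (["python", "scripts/train.py"] + COMMON_ARGS + ABLATIONS[config_name]
            + ["--output_dir", f"./models/ablation-{config_name}"])

for config_name in ABLATIONS:
    print(shlex.join(build_command(config_name)))
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to launch the run from Python instead of a shell.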

### Evaluation Commands

```bash
for model in baseline uncertainty calibration both; do
    python scripts/evaluate.py \
        --student_model ./models/ablation-${model} \
        --num_samples 100 \
        --compute_ece \
        --output results/ablation_${model}.json
done
```

## 7. Load and Compare Ablation Results

After running the experiments, load and visualize results.

In [None]:
# Load ablation results (if available)
results_dir = project_root / 'results'
ablation_files = {
    'Baseline': results_dir / 'ablation_baseline.json',
    '+Uncertainty': results_dir / 'ablation_uncertainty.json',
    '+Calibration': results_dir / 'ablation_calibration.json',
    '+Both': results_dir / 'ablation_both.json',
}

ablation_results = {}
for name, path in ablation_files.items():
    if path.exists():
        with open(path) as f:
            ablation_results[name] = json.load(f)
        print(f"Loaded {name}")
    else:
        print(f"Not found: {path}")

if ablation_results:
    print(f"\nLoaded {len(ablation_results)} ablation results")
else:
    print("\nNo ablation results found. Run the training commands above first.")

In [None]:
# Compare results if available
if ablation_results:
    print("Ablation Study Results")
    print("=" * 60)
    print(f"{'Configuration':<20} {'PPL Ratio':>12} {'KL Div':>12} {'ECE':>12}")
    print("-" * 60)
    
    def fmt(value):
        """Format numeric metrics to 4 decimals; pass placeholders such as 'N/A' through."""
        return f"{value:.4f}" if isinstance(value, (int, float)) else str(value)

    for name, result in ablation_results.items():
        ppl_ratio = result.get('perplexity_ratio', 'N/A')
        kl_div = result.get('kl_divergence', 'N/A')
        ece = result.get('student_ece', {}).get('ece', 'N/A')

        print(f"{name:<20} {fmt(ppl_ratio):>12} {fmt(kl_div):>12} {fmt(ece):>12}")

In [None]:
# Visualize comparison
if len(ablation_results) >= 2:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    names = list(ablation_results.keys())
    ppl_ratios = [r.get('perplexity_ratio', 0) for r in ablation_results.values()]
    kl_divs = [r.get('kl_divergence', 0) for r in ablation_results.values()]
    eces = [r.get('student_ece', {}).get('ece', 0) for r in ablation_results.values()]
    
    # Perplexity ratio
    bars1 = axes[0].bar(names, ppl_ratios, color=['steelblue', 'orange', 'green', 'red'][:len(names)])
    axes[0].set_ylabel('Perplexity Ratio')
    axes[0].set_title('Perplexity Ratio (lower is better)')
    axes[0].tick_params(axis='x', rotation=45)
    for bar, val in zip(bars1, ppl_ratios):
        if val > 0:
            axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
                        f'{val:.3f}', ha='center', va='bottom')
    
    # KL divergence
    bars2 = axes[1].bar(names, kl_divs, color=['steelblue', 'orange', 'green', 'red'][:len(names)])
    axes[1].set_ylabel('KL Divergence')
    axes[1].set_title('KL Divergence (lower is better)')
    axes[1].tick_params(axis='x', rotation=45)
    for bar, val in zip(bars2, kl_divs):
        if val > 0:
            axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
                        f'{val:.4f}', ha='center', va='bottom')
    
    # ECE
    bars3 = axes[2].bar(names, eces, color=['steelblue', 'orange', 'green', 'red'][:len(names)])
    axes[2].set_ylabel('Expected Calibration Error')
    axes[2].set_title('ECE (lower is better)')
    axes[2].tick_params(axis='x', rotation=45)
    for bar, val in zip(bars3, eces):
        if val > 0:
            axes[2].text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
                        f'{val:.4f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
else:
    print("Need at least 2 ablation results to create comparison plot.")

## 8. Summary

### Expected Results

Based on the literature and our implementation, we expect the following. These are targets for the ablation to confirm, not measured results:

1. **Uncertainty-aware position weighting** should improve perplexity on difficult sequences (15-25% improvement expected)
2. **Calibration-aware distillation** should improve ECE scores (20-30% improvement expected)
3. **Combined enhancements** should retain the benefits of both without significant interference
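
We read the expected percentages above as relative reductions against the baseline run; that reading is an assumption, since the exact definition is not spelled out here. A minimal sketch of the calculation for lower-is-better metrics such as perplexity ratio or ECE follows. The numbers are placeholders chosen for illustration, not experimental results.

```python
def relative_improvement(baseline, enhanced):
    """Percent reduction of a lower-is-better metric relative to the baseline."""
    if baseline == 0:
        raise ValueError("baseline metric must be non-zero")
    return 100.0 * (baseline - enhanced) / baseline

# Placeholder values for illustration only (not experimental results).
example_baseline_ece = 0.20
example_enhanced_ece = 0.15
print(f"{relative_improvement(example_baseline_ece, example_enhanced_ece):.1f}% lower ECE")
```

Applied to the loaded results, `baseline` would be the Baseline entry's metric and `enhanced` the same metric from one of the other configurations.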

### Key Observations

- The entropy visualization shows that teacher uncertainty varies across protein sequence positions
- Position weights fall within [0.5, 1.0], as designed
- Calibration smoothing preserves valid probability distributions (each position still sums to 1)

### Next Steps

1. Run full ablation experiments with the commands in Section 6
2. Analyze results to validate improvements
3. Proceed to Phase 2 (Hyperparameter Sweeps) with optimal enhancement settings