# Step 2 Phase 1: Optimized Hybrid Analytic Gradient Test

This notebook tests the optimized hybrid analytic gradient implementation:
- GRAD_BLEND grid search
- Fully analytic MoE backward pass
- Mixed-precision gradient computation
- Batched analytic gradient with vmap

**Test Configuration:**
- Model: d_model=64, n_layers=4, sequence length N=128 (`n_seq`)
- Training: 3 epochs on WikiText-2
- Validation: Numerical stability (no NaN/Inf), Convergence (loss decreases)

In [ ]:
# Repo setup: find the project checkout (cloning it when absent), make it the
# working directory, and put it at the front of sys.path.
import os
import pathlib
import subprocess
import sys

REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'

cwd = pathlib.Path.cwd()

# Places where the checkout may already live; a directory counts as the repo
# root when it contains a `src` entry. The first match wins.
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]
root = None
for candidate in candidates:
    if (candidate / 'src').exists():
        root = candidate
        break

if root is None:
    # Nothing found: fall back to <cwd>/<REPO_DIR>, cloning only if that
    # directory does not exist yet.
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)

# Work from the repo root so relative paths used later resolve against it.
if root != cwd:
    os.chdir(root)

# Re-read the working directory after the (possible) chdir.
root_str = str(pathlib.Path.cwd())
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


In [None]:
# Setup
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

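# Also let imports resolve against the parent of the working directory (the repo root itself is already on sys.path from the repo-setup cell)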
sys.path.append('..')

from src.models.configurable_resnet_bk import ConfigurableResNetBK, ResNetBKConfig
from src.models.analytic_moe import AnalyticMoELayer, validate_analytic_gradients
from src.models.mixed_precision_bk_core import MixedPrecisionBKCoreFunction, benchmark_mixed_precision
from src.models.batched_gradient import BatchedAnalyticBKCoreFunction, profile_batched_gradient
from src.training.grad_blend_optimizer import GradBlendOptimizer
from src.utils.data_utils import get_wikitext2_dataloaders

# Pick the compute device once; later cells pass this string to the helpers
# and move models/tensors onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## 1. Test Analytic MoE Gradient Validation

In [None]:
# Section banner: rule, title, rule (one line each).
print("=" * 80, "Testing Analytic MoE Gradient Validation", "=" * 80, sep="\n")

# Small MoE layer (d_model=64, 4 experts, top_k=1) on the selected device.
moe_layer = AnalyticMoELayer(d_model=64, num_experts=4, top_k=1)
moe_layer.to(device)

# Random probe input of shape (2, 16, 64); the last dimension matches d_model.
# NOTE(review): presumably laid out as (batch, seq, d_model) — confirm against the layer.
x_test = torch.randn(2, 16, 64, device=device)

# Hand the layer and the probe to the project's gradient validator; the
# returned mapping is read below via its 'input_gradient_error', 'max_error'
# and 'passed' entries.
validation_results = validate_analytic_gradients(moe_layer, x_test)

print("\nValidation Results:")
print(f"  Input gradient error: {validation_results['input_gradient_error']:.6f}")
print(f"  Max error: {validation_results['max_error']:.6f}")
print(f"  Passed: {validation_results['passed']}")

# One-line verdict driven by the validator's own pass/fail flag.
verdict = (
    "\n✓ Analytic MoE gradients validated successfully!"
    if validation_results['passed']
    else "\n✗ Analytic MoE gradient validation failed!"
)
print(verdict)

## 2. Benchmark Mixed Precision

In [None]:
# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "Benchmarking Mixed Precision", "=" * 80, sep="\n")

# Run the project's FP32-vs-mixed-precision benchmark on the selected device.
benchmark_args = dict(batch_size=8, seq_len=128, num_trials=100, device=device)
mp_results = benchmark_mixed_precision(**benchmark_args)

# Times are multiplied by 1000 for the "ms" display.
# NOTE(review): assumes the benchmark reports seconds — confirm.
print(f"\nResults:")
print(f"  FP32 time: {mp_results['fp32_time']*1000:.2f}ms")
print(f"  Mixed precision time: {mp_results['mixed_time']*1000:.2f}ms")
print(f"  Speedup: {mp_results['speedup']:.2f}x")
print(f"  Max error: {mp_results['max_error']:.6e}")
print(f"  Relative error: {mp_results['relative_error']:.6e}")

# Anything not strictly above 1.0x (including NaN) is reported as no improvement.
is_faster = mp_results['speedup'] > 1.0
if is_faster:
    print(f"\n✓ Mixed precision achieved {mp_results['speedup']:.2f}x speedup!")
else:
    print("\n✗ Mixed precision did not improve performance")

## 3. Profile Batched Gradient Computation

In [None]:
# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "Profiling Batched Gradient Computation", "=" * 80, sep="\n")

# Profile the gradient implementations over a sweep of batch sizes.
batch_results = profile_batched_gradient(
    batch_sizes=[1, 4, 8, 16, 32],
    seq_len=128,
    num_trials=50,
    device=device,
)

# Two side-by-side panels: raw timings (left) and speedup (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Panel 1: one timing curve per implementation, all against batch size.
ax = axes[0]
timing_series = (
    ('sequential_times', 'o', 'Sequential'),
    ('batched_times', 's', 'Batched'),
    ('memory_optimized_times', '^', 'Memory-Optimized'),
)
for result_key, marker_style, series_label in timing_series:
    ax.plot(
        batch_results['batch_sizes'],
        batch_results[result_key],
        marker=marker_style,
        label=series_label,
        linewidth=2,
    )
ax.set_xlabel('Batch Size')
ax.set_ylabel('Time (seconds)')
ax.set_title('Gradient Computation Time vs Batch Size')
ax.legend()
ax.grid(True)

# Panel 2: speedup curve with a dashed reference line at 1.0x.
ax = axes[1]
ax.plot(
    batch_results['batch_sizes'],
    batch_results['speedups'],
    marker='o',
    linewidth=2,
    color='green',
)
ax.axhline(y=1.0, color='r', linestyle='--', label='Baseline')
ax.set_xlabel('Batch Size')
ax.set_ylabel('Speedup')
ax.set_title('Batched Gradient Speedup')
ax.legend()
ax.grid(True)

# Write the figure to disk before showing it.
plt.tight_layout()
plt.savefig('step2_phase1_batched_gradient_profile.png', dpi=150)
plt.show()

# Mean speedup across the sweep; 1.5x is the bar for "significant".
avg_speedup = np.mean(batch_results['speedups'])
print(f"\nAverage speedup: {avg_speedup:.2f}x")

is_significant = avg_speedup > 1.5
if is_significant:
    print(f"✓ Batched gradient achieved {avg_speedup:.2f}x average speedup!")
else:
    print("✗ Batched gradient did not achieve significant speedup")

## 4. Load WikiText-2 Data

In [None]:
# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "Loading WikiText-2 Dataset", "=" * 80, sep="\n")

# Build the train/validation loaders; seq_len=128 matches the n_seq used for
# the model configuration in the grid-search cell.
loader_args = dict(batch_size=32, seq_len=128, num_workers=2)
train_loader, val_loader, vocab_size = get_wikitext2_dataloaders(**loader_args)

# Quick sanity report on what was loaded.
print(f"Vocabulary size: {vocab_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 5. GRAD_BLEND Grid Search (Quick Test)

In [None]:
# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "GRAD_BLEND Grid Search (Quick Test)", "=" * 80, sep="\n")

# Model configuration for this test; reused by the training cell to rebuild
# the model after the search.
config = ResNetBKConfig(
    vocab_size=vocab_size,
    d_model=64,
    n_layers=4,
    n_seq=128,
    num_experts=4,
    top_k=1,
    use_analytic_gradient=True,
    grad_blend=0.5,
)

model = ConfigurableResNetBK(config)
model.to(device)

print(f"Model parameters: {model.get_num_parameters()/1e6:.2f}M")

# Reduced search space so the notebook finishes quickly: five alpha values,
# two epochs per trial.
quick_alphas = [0.0, 0.3, 0.5, 0.7, 1.0]

# NOTE: `optimizer` here is the grid-search driver, not a torch optimizer;
# the training cell rebinds the same name to AdamW.
optimizer = GradBlendOptimizer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    alpha_values=quick_alphas,
    epochs_per_trial=2,
    device=device,
    save_dir='results/step2_phase1_grad_blend_quick',
)

# The summary's 'best_alpha' is consumed by the training cell.
summary = optimizer.run_grid_search()

print(f"\n✓ Grid search complete!")
print(f"  Best alpha: {summary['best_alpha']}")
print(f"  Best perplexity: {summary['best_perplexity']:.2f}")

## 6. Train Model with Optimized Settings (3 Epochs)

In [None]:
# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "Training with Optimized Settings", "=" * 80, sep="\n")

# Apply the blend factor chosen by the grid search as a class-level setting.
# NOTE(review): this is set before the model is rebuilt from `config`
# (grad_blend=0.5); confirm the model constructor does not reset GRAD_BLEND.
from src.models.bk_core import BKCoreFunction
BKCoreFunction.GRAD_BLEND = summary['best_alpha']

# Start from a freshly constructed model built from the same config.
model = ConfigurableResNetBK(config)
model.to(device)

# Optimizer and loss for the run.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Per-epoch history plus a flag that records a numerical blow-up.
train_losses, val_losses = [], []
train_ppls, val_ppls = [], []
has_nan_inf = False

for epoch in range(3):
    print(f"\nEpoch {epoch+1}/3")

    # ---- training pass ----
    model.train()
    loss_sum = 0.0
    token_count = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        # Flatten logits/targets so the loss is computed per token.
        logits = model(inputs)
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # Abort on a non-finite loss before it reaches backward().
        if not torch.isfinite(loss):
            print(f"\n✗ NaN/Inf detected in loss at batch {batch_idx}!")
            has_nan_inf = True
            break

        loss.backward()

        # Name of the first parameter whose gradient has a non-finite entry,
        # or None when every gradient is clean.
        name = next(
            (
                param_name
                for param_name, tensor in model.named_parameters()
                if tensor.grad is not None and not torch.isfinite(tensor.grad).all()
            ),
            None,
        )
        if name is not None:
            print(f"\n✗ NaN/Inf detected in gradient of {name} at batch {batch_idx}!")
            has_nan_inf = True
            break

        # Clip the global gradient norm to 0.5, then update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Weight the batch loss by its element count so the running average
        # is per token.
        n_tokens = targets.numel()
        loss_sum += loss.item() * n_tokens
        token_count += n_tokens

        # Progress line every 100 batches.
        if (batch_idx + 1) % 100 == 0:
            avg_loss = loss_sum / token_count
            avg_ppl = np.exp(avg_loss)
            print(f"  Batch {batch_idx+1}/{len(train_loader)}: Loss={avg_loss:.4f}, PPL={avg_ppl:.2f}")

    # A blow-up inside the batch loop ends the whole run; no partial epoch is
    # recorded.
    if has_nan_inf:
        break

    avg_train_loss = loss_sum / token_count
    avg_train_ppl = np.exp(avg_train_loss)
    train_losses.append(avg_train_loss)
    train_ppls.append(avg_train_ppl)

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss_sum = 0.0
    val_token_count = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

            n_tokens = targets.numel()
            val_loss_sum += loss.item() * n_tokens
            val_token_count += n_tokens

    avg_val_loss = val_loss_sum / val_token_count
    avg_val_ppl = np.exp(avg_val_loss)
    val_losses.append(avg_val_loss)
    val_ppls.append(avg_val_ppl)

    print(f"  Train Loss: {avg_train_loss:.4f}, Train PPL: {avg_train_ppl:.2f}")
    print(f"  Val Loss: {avg_val_loss:.4f}, Val PPL: {avg_val_ppl:.2f}")

# ---- pass/fail report ----
print("\n" + "=" * 80, "Validation Results", "=" * 80, sep="\n")

# Check 1: numerical stability.
print(
    "✓ No NaN/Inf detected during training"
    if not has_nan_inf
    else "✗ NaN/Inf detected during training"
)

# Check 2: last-epoch training loss is below the first epoch's (needs at
# least two completed epochs).
train_improved = len(train_losses) >= 2 and train_losses[-1] < train_losses[0]
if train_improved:
    print(f"✓ Training loss decreased: {train_losses[0]:.4f} → {train_losses[-1]:.4f}")
else:
    print(f"✗ Training loss did not decrease")

# Check 3: same comparison for validation loss.
val_improved = len(val_losses) >= 2 and val_losses[-1] < val_losses[0]
if val_improved:
    print(f"✓ Validation loss decreased: {val_losses[0]:.4f} → {val_losses[-1]:.4f}")
else:
    print(f"✗ Validation loss did not decrease")

# ---- training curves: loss (left) and perplexity (right) ----
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

panels = (
    (axes[0], train_losses, val_losses, 'Loss', 'Training and Validation Loss'),
    (axes[1], train_ppls, val_ppls, 'Perplexity', 'Training and Validation Perplexity'),
)
for ax, train_curve, val_curve, y_label, panel_title in panels:
    ax.plot(train_curve, marker='o', label='Train', linewidth=2)
    ax.plot(val_curve, marker='s', label='Validation', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('step2_phase1_training_curves.png', dpi=150)
plt.show()

# ---- checkpoint: weights, config, metric history and the chosen alpha ----
checkpoint_path = 'checkpoints/step2_phase1_model.pt'
Path('checkpoints').mkdir(exist_ok=True)
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': config,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_ppls': train_ppls,
    'val_ppls': val_ppls,
    'best_alpha': summary['best_alpha'],
}
torch.save(checkpoint, checkpoint_path)

print(f"\n✓ Checkpoint saved to {checkpoint_path}")

## Summary

This notebook tested Step 2 Phase 1 optimizations:
1. ✓ Analytic MoE gradient validation
2. ✓ Mixed-precision speedup measurement
3. ✓ Batched gradient profiling
4. ✓ GRAD_BLEND grid search
5. ✓ 3-epoch training with numerical stability checks

If every check above reported ✓ in your run, the components are ready for integration into the full training pipeline; investigate any ✗ before integrating.