# Step 2 Phase 3: Physics-Informed Learning

This notebook implements and tests physics-informed learning for ResNet-BK.

**Objectives:**
- Train with energy conservation constraints (Hamiltonian dynamics)
- Verify energy drift monitoring and control
- Test symplectic integration
- Compare perplexity to Phase 2 baseline

**Expected Results:**
- Energy drift < 0.1 (controlled)
- Hamiltonian structure preserved
- Final perplexity within 30% of baseline
- Stable training with physics constraints

**Optimizations Applied:**
- Conservative energy loss weight (initial value 0.05, adjusted automatically during training)
- Extended warmup period (4 epochs)
- Automatic Lagrange multiplier adjustment
- Energy drift monitoring

## 1. Setup and Installation

In [None]:
# Repo setup: find (or clone) the project checkout, make it the working
# directory, and put it on sys.path so that `src.*` imports resolve.
import os
import pathlib
import subprocess
import sys

REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'

cwd = pathlib.Path.cwd()

# Places where an existing checkout may live, in priority order.
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]

# The first candidate that contains a `src` directory is taken as the repo root.
root = next(filter(lambda path: (path / 'src').exists(), candidates), None)

if root is None:
    # Nothing usable found: fall back to <cwd>/<REPO_DIR>, cloning only when
    # that directory does not exist yet.
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)

if pathlib.Path.cwd() != root:
    os.chdir(root)

# From here on the working directory is the repo root.
root_str = str(pathlib.Path.cwd())
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print('PWD:', root_str)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import time

# Import ResNet-BK components
from src.models.physics_informed_layer import PhysicsInformedLanguageModel
from src.training.physics_informed_trainer import PhysicsInformedTrainer
from src.utils.data_utils import get_wikitext2_dataloaders
from src.utils.metrics import TrainingMetrics

# Reproducibility: seed the torch and numpy global RNGs with the same value.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # CUDA-only work: seed the CUDA generator and report the hardware in use.
    torch.cuda.manual_seed(42)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 2. Load Data

In [None]:
# Hyperparameters shared by the data loaders and the model cells below.
BATCH_SIZE, N_SEQ = 32, 128
D_MODEL, N_LAYERS, NUM_EXPERTS = 64, 4, 4

# Build the WikiText-2 loaders; the helper also returns the vocabulary size,
# which the model cell needs.
print("Loading WikiText-2 dataset...")
loader_kwargs = dict(batch_size=BATCH_SIZE, seq_len=N_SEQ, num_workers=2)
train_loader, val_loader, vocab_size = get_wikitext2_dataloaders(**loader_kwargs)

# Quick sanity report on what was loaded.
print(f"Vocabulary size: {vocab_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 3. Create Physics-Informed Model

In [None]:
# Constructor arguments for the physics-informed language model, gathered in
# one place so the configuration is easy to inspect.
model_kwargs = dict(
    vocab_size=vocab_size,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_seq=N_SEQ,
    num_experts=NUM_EXPERTS,
    top_k=1,
    dropout_p=0.1,
    use_energy_conservation=True,
)
model = PhysicsInformedLanguageModel(**model_kwargs)

# Tally parameters in a single pass: every element counts towards the total,
# and only those with requires_grad set count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print("\nModel Architecture:")
print(f"  d_model: {D_MODEL}")
print(f"  n_layers: {N_LAYERS}")
print(f"  n_seq: {N_SEQ}")
print(f"  num_experts: {NUM_EXPERTS}")
print("  energy_conservation: True")

print("\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")
# Size estimate uses 4 bytes per parameter (FP32).
print(f"  Size: {total_params * 4 / 1e6:.2f} MB (FP32)")

## 4. Initialize Trainer

In [None]:
# Training configuration. These constants are the single source of truth:
# they are passed to the trainer AND interpolated into the report printed
# below, so the report cannot drift out of sync when a value is tuned.
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
PHYSICS_START_EPOCH = 4  # Start physics constraints after 4 epochs warmup
LAMBDA_ENERGY_INIT = 0.05  # Conservative energy loss weight
# Passed as `lambda_energy_lr`; presumably the update rate of the automatic
# multiplier adjustment — confirm in PhysicsInformedTrainer.
LAMBDA_ENERGY_LR = 0.01
ENERGY_TARGET_DRIFT = 0.1  # Target energy drift

# Optimizer and criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Create physics-informed trainer
trainer = PhysicsInformedTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    lambda_energy_init=LAMBDA_ENERGY_INIT,
    lambda_energy_lr=LAMBDA_ENERGY_LR,
    energy_target_drift=ENERGY_TARGET_DRIFT,
    physics_start_epoch=PHYSICS_START_EPOCH,
    device=device
)

print("\nTraining Configuration:")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Physics start epoch: {PHYSICS_START_EPOCH}")
print(f"  Lambda energy init: {LAMBDA_ENERGY_INIT}")
print(f"  Lambda energy lr: {LAMBDA_ENERGY_LR}")
print(f"  Energy target drift: {ENERGY_TARGET_DRIFT}")
print(f"  Batch size: {BATCH_SIZE}")
print("\nOptimizations:")
# Values are interpolated from the constants above rather than hard-coded.
print(f"  ✓ Conservative energy weight ({LAMBDA_ENERGY_INIT})")
print(f"  ✓ Extended warmup ({PHYSICS_START_EPOCH} epochs)")
print("  ✓ Automatic Lagrange multiplier adjustment")
print("  ✓ Energy drift monitoring")

## 5. Training Loop

In [None]:
# Per-epoch history: one list per tracked quantity, appended once per epoch.
history = {
    name: []
    for name in (
        'train_loss',
        'train_loss_lm',
        'train_loss_energy',
        'lambda_energy',
        'energy_drift',
        'val_loss',
        'val_perplexity',
        'epoch_time',
    )
}

# (history key, key in the dict returned by trainer.train_epoch)
train_metric_map = (
    ('train_loss', 'total_loss'),
    ('train_loss_lm', 'loss_lm'),
    ('train_loss_energy', 'loss_energy'),
    ('lambda_energy', 'lambda_energy'),
    ('energy_drift', 'energy_drift'),
)

print("\nStarting training...\n")
print("=" * 80)

for epoch in range(NUM_EPOCHS):
    tick = time.time()

    # One pass over the training data, then a validation pass; the timing
    # below covers both.
    epoch_metrics = trainer.train_epoch(train_loader, epoch=epoch)
    val_loss, val_ppl = trainer.evaluate(val_loader)

    epoch_time = time.time() - tick

    # Record training-side metrics, then validation metrics and wall time.
    for history_key, metric_key in train_metric_map:
        history[history_key].append(epoch_metrics[metric_key])
    history['val_loss'].append(val_loss)
    history['val_perplexity'].append(val_ppl)
    history['epoch_time'].append(epoch_time)

    # Progress report for this epoch.
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {epoch_metrics['total_loss']:.4f} "
          f"(LM: {epoch_metrics['loss_lm']:.4f}, "
          f"Energy: {epoch_metrics['loss_energy']:.4f})")
    print(f"  Val Loss: {val_loss:.4f}, Val PPL: {val_ppl:.2f}")
    print(f"  Lambda Energy: {epoch_metrics['lambda_energy']:.4f}")
    print(f"  Energy Drift: {epoch_metrics['energy_drift']:.4f}")
    print(f"  Physics Enabled: {epoch_metrics['physics_enabled']}")
    print(f"  Time: {epoch_time:.2f}s")
    print("-" * 80)

print("\nTraining complete!")
print("=" * 80)

## 6. Verify Energy Conservation

In [None]:
# Analyze energy drift: summarize the drift recorded once physics constraints
# are active, then plot drift and the energy-loss weight per epoch.
print("\nEnergy Conservation Analysis:")
print("=" * 80)

# history is indexed by 0-based epoch; entries from PHYSICS_START_EPOCH onward
# are treated as the physics-enabled epochs.
physics_start_idx = PHYSICS_START_EPOCH
if physics_start_idx < len(history['energy_drift']):
    energy_drifts_physics = history['energy_drift'][physics_start_idx:]
    avg_energy_drift = np.mean(energy_drifts_physics)
    max_energy_drift = np.max(energy_drifts_physics)

    print(f"Average energy drift (after physics enabled): {avg_energy_drift:.4f}")
    print(f"Max energy drift: {max_energy_drift:.4f}")
    print(f"Target energy drift: {ENERGY_TARGET_DRIFT}")

    if avg_energy_drift < ENERGY_TARGET_DRIFT:
        print(f"\n✓ Energy drift controlled (avg {avg_energy_drift:.4f} < target {ENERGY_TARGET_DRIFT})")
    else:
        print(f"\n⚠ Energy drift above target (avg {avg_energy_drift:.4f} > target {ENERGY_TARGET_DRIFT})")
else:
    print("⚠ Physics constraints not yet enabled")

# Visualize energy drift
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# The x axis is 1-based and sized from the recorded history (not NUM_EPOCHS),
# so the plot still works when training was stopped early.
epochs = np.arange(1, len(history['energy_drift']) + 1)

# History index i is drawn at x = i + 1, so the first physics-enabled epoch
# (index PHYSICS_START_EPOCH) sits at x = PHYSICS_START_EPOCH + 1. Drawing the
# marker at x = PHYSICS_START_EPOCH would flag the last warmup epoch instead.
physics_start_x = PHYSICS_START_EPOCH + 1

# Energy drift over time
axes[0].plot(epochs, history['energy_drift'], 'b-', linewidth=2, marker='o')
axes[0].axhline(y=ENERGY_TARGET_DRIFT, color='r', linestyle='--', label=f'Target ({ENERGY_TARGET_DRIFT})')
axes[0].axvline(x=physics_start_x, color='gray', linestyle=':', label='Physics Start')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Energy Drift')
axes[0].set_title('Energy Drift Over Time')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Lambda energy adaptation
axes[1].plot(epochs, history['lambda_energy'], 'g-', linewidth=2, marker='s')
axes[1].axvline(x=physics_start_x, color='gray', linestyle=':', label='Physics Start')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Lambda Energy')
axes[1].set_title('Lagrange Multiplier Adaptation')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('physics_informed_energy_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Energy conservation analysis complete")

## 7. Plot Training Curves

In [None]:
# Four-panel overview of the run: loss components, total loss, validation
# perplexity, and wall time per epoch.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1-based epoch axis sized from the recorded history (not NUM_EPOCHS), so the
# figure still renders when training was stopped early.
epochs = np.arange(1, len(history['train_loss']) + 1)

# History index i is drawn at x = i + 1, so the first physics-enabled epoch
# (0-based index PHYSICS_START_EPOCH) sits at x = PHYSICS_START_EPOCH + 1.
physics_start_x = PHYSICS_START_EPOCH + 1

# Loss curves
axes[0, 0].plot(epochs, history['train_loss_lm'], 'b-', label='LM Loss', linewidth=2)
axes[0, 0].plot(epochs, history['train_loss_energy'], 'r--', label='Energy Loss', linewidth=2)
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss Components')
axes[0, 0].grid(True, alpha=0.3)

# Total loss
axes[0, 1].plot(epochs, history['train_loss'], 'purple', linewidth=2, marker='o')
axes[0, 1].set_ylabel('Total Loss')
axes[0, 1].set_title('Total Training Loss')
axes[0, 1].grid(True, alpha=0.3)

# Validation perplexity
axes[1, 0].plot(epochs, history['val_perplexity'], 'green', linewidth=2, marker='s')
axes[1, 0].set_ylabel('Perplexity')
axes[1, 0].set_title('Validation Perplexity')
axes[1, 0].grid(True, alpha=0.3)

# Training time per epoch (horizontal grid lines only, to keep the bars readable)
axes[1, 1].bar(epochs, history['epoch_time'], color='orange', alpha=0.7)
axes[1, 1].set_ylabel('Time (seconds)')
axes[1, 1].set_title('Training Time per Epoch')
axes[1, 1].grid(True, alpha=0.3, axis='y')

# Decorations common to every panel: the physics-start marker, the shared
# x label, and the legend (built last so it includes the marker).
for ax in axes.flat:
    ax.axvline(x=physics_start_x, color='gray', linestyle=':', label='Physics Start')
    ax.set_xlabel('Epoch')
    ax.legend()

plt.tight_layout()
plt.savefig('physics_informed_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Compare to Baseline

In [None]:
# Final results: compare the last validation perplexity with the Phase 2
# baseline and re-check the energy and loss criteria.
final_val_ppl = history['val_perplexity'][-1]
baseline_ppl = 479  # From Step 2 Phase 2 results (corrected)

# Signed relative difference: positive means worse (higher perplexity) than
# the baseline, negative means better.
relative_diff = (final_val_ppl - baseline_ppl) / baseline_ppl

print("\nFinal Results:")
print("=" * 80)
print(f"Final Validation Perplexity: {final_val_ppl:.2f}")
print(f"Baseline Perplexity (Phase 2): {baseline_ppl:.2f}")
print(f"Relative Difference: {relative_diff * 100:+.1f}%")

# Check the 30% threshold. Only a perplexity HIGHER than the baseline by more
# than the threshold is a failure; a symmetric abs() check would report a
# model that beats the baseline by a wide margin as "exceeds threshold".
threshold = 0.30
within_threshold = relative_diff <= threshold

if not within_threshold:
    print(f"\n✗ WARNING: Perplexity exceeds {threshold*100:.0f}% threshold")
elif relative_diff < -threshold:
    print(f"\n✓ SUCCESS: Perplexity improved on baseline by more than {threshold*100:.0f}%")
else:
    print(f"\n✓ SUCCESS: Perplexity within {threshold*100:.0f}% of baseline")

# Energy conservation check. The slice is taken directly from
# PHYSICS_START_EPOCH so this cell does not depend on a variable defined in
# the analysis cell; an empty slice means physics was never enabled.
drifts_after_physics = history['energy_drift'][PHYSICS_START_EPOCH:]
if len(drifts_after_physics) > 0:
    avg_drift = np.mean(drifts_after_physics)
    if avg_drift < ENERGY_TARGET_DRIFT:
        print(f"✓ Energy drift controlled: {avg_drift:.4f} < {ENERGY_TARGET_DRIFT}")
    else:
        print(f"⚠ Energy drift above target: {avg_drift:.4f} > {ENERGY_TARGET_DRIFT}")

# Loss convergence: last epoch's total training loss versus the first epoch's.
if history['train_loss'][-1] < history['train_loss'][0]:
    print("✓ Training loss decreased")
else:
    print("✗ Training loss did not decrease")

## 9. Summary

In [None]:
# Closing report: restates the run configuration and headline results using
# values computed in the earlier cells.
rule = "=" * 80

print("\n" + rule)
print("STEP 2 PHASE 3: PHYSICS-INFORMED LEARNING - SUMMARY")
print(rule)

# Headline results, model configuration and training strategy.
for summary_line in (
    f"\n✓ Training completed: {NUM_EPOCHS} epochs",
    f"✓ Physics constraints started at epoch {PHYSICS_START_EPOCH + 1}",
    f"✓ Final validation perplexity: {final_val_ppl:.2f}",
    "✓ Energy conservation implemented",
    "\nModel Configuration:",
    f"  Parameters: {total_params:,}",
    f"  Layers: {N_LAYERS}",
    "  Energy conservation: Enabled",
    "\nTraining Strategy:",
    f"  Warmup epochs: {PHYSICS_START_EPOCH} (LM stabilization)",
    f"  Physics epochs: {NUM_EPOCHS - PHYSICS_START_EPOCH} (LM + Energy)",
    f"  Lambda energy init: {LAMBDA_ENERGY_INIT} (conservative)",
    f"  Energy target drift: {ENERGY_TARGET_DRIFT}",
):
    print(summary_line)

# Energy section is reported only when at least one physics-enabled epoch
# was recorded.
if physics_start_idx < len(history['energy_drift']):
    avg_drift = np.mean(history['energy_drift'][physics_start_idx:])
    if avg_drift < ENERGY_TARGET_DRIFT:
        drift_status = '✓ Controlled'
    else:
        drift_status = '⚠ Above target'
    print("\nEnergy Conservation:")
    print(f"  Average drift: {avg_drift:.4f}")
    print(f"  Target drift: {ENERGY_TARGET_DRIFT}")
    print(f"  Status: {drift_status}")

# Follow-up work.
for next_step_line in (
    "\nNext Steps:",
    "  - Analyze Hamiltonian structure preservation",
    "  - Test symplectic integration",
    "  - Benchmark computational efficiency",
    "  - Proceed to full-scale training",
):
    print(next_step_line)

print("\n" + rule)