# Step 2 Phase 2: Koopman Operator Learning

This notebook implements and tests Koopman operator learning for ResNet-BK.

**Objectives:**
- Train with the hybrid Koopman-gradient scheme (4 warmup epochs followed by 6 hybrid epochs)
- Verify Koopman operator updates (K changes over time)
- Verify convergence with Koopman auxiliary loss
- Compare perplexity to Phase 1 baseline

**Expected Results:**
- Koopman loss decreases over time
- Koopman operator K evolves from identity initialization
- Final validation perplexity within 30% of the Phase 1 baseline
- Backward pass cost reduction (not benchmarked in this notebook; listed under "Next Steps" in the summary)

**Optimizations Applied:**
- Conservative Koopman weight (max 0.05) to prevent loss explosion
- Extended warmup period (4 epochs) for stable LM convergence
- Warning frequency control (1 per epoch max)
- Automatic weight decay when Koopman loss is high

## 1. Setup and Installation

In [None]:
# Repo setup: locate a checkout of the project (any nearby directory that
# contains `src/`), clone it when none is found, then make it both the working
# directory and the first entry on sys.path so `src.*` imports resolve.
import os
import pathlib
import subprocess
import sys

REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'

cwd = pathlib.Path.cwd()
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]

# Take the first candidate that has a `src` entry; the `else` branch runs only
# when the loop finishes without finding one.
for root in candidates:
    if (root / 'src').exists():
        break
else:
    root = cwd / REPO_DIR
    if not root.exists():
        # check=True turns a failed clone into an exception instead of
        # letting later imports fail with a confusing error.
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)

if pathlib.Path.cwd() != root:
    os.chdir(root)

root_str = str(pathlib.Path.cwd())
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import time

# Import ResNet-BK components
from src.models.koopman_layer import KoopmanLanguageModel
from src.training.hybrid_koopman_trainer import HybridKoopmanTrainer
from src.utils.data_utils import get_wikitext2_dataloaders
from src.utils.metrics import TrainingMetrics

# Reproducibility: seed every RNG used here with one shared value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(SEED)

# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_memory_gb:.2f} GB")

## 2. Load Data

In [None]:
# Configuration shared by the data pipeline and the model cells below.
BATCH_SIZE = 32
N_SEQ = 128
D_MODEL = 64
N_LAYERS = 4
NUM_EXPERTS = 4
KOOPMAN_DIM = 256

# Load WikiText-2 data
print("Loading WikiText-2 dataset...")
loader_options = {
    'batch_size': BATCH_SIZE,
    'seq_len': N_SEQ,
    'num_workers': 2,
}
train_loader, val_loader, vocab_size = get_wikitext2_dataloaders(**loader_options)

print(f"Vocabulary size: {vocab_size}")
for split_name, split_loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{split_name} batches: {len(split_loader)}")
# Only the train and validation loaders are unpacked above, so there is no
# test-split size to report in this cell.

## 3. Create Koopman Model

In [None]:
# Create Koopman language model. The hyperparameters are gathered in one dict
# so the constructor call and the printed report read from the same values.
model_kwargs = {
    'vocab_size': vocab_size,
    'd_model': D_MODEL,
    'n_layers': N_LAYERS,
    'n_seq': N_SEQ,
    'koopman_dim': KOOPMAN_DIM,
    'num_experts': NUM_EXPERTS,
    'top_k': 1,
    'dropout_p': 0.1,
}
model = KoopmanLanguageModel(**model_kwargs)

# Count parameters in a single pass over the model.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print("\nModel Architecture:")
for field in ('d_model', 'n_layers', 'n_seq', 'koopman_dim', 'num_experts'):
    print(f"  {field}: {model_kwargs[field]}")

print("\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")
# 4 bytes per parameter: the size estimate assumes FP32 storage.
fp32_size_mb = total_params * 4 / 1e6
print(f"  Size: {fp32_size_mb:.2f} MB (FP32)")

## 4. Initialize Trainer

In [None]:
# Training configuration
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10            # raised from 5 to give Koopman learning more time
KOOPMAN_START_EPOCH = 4    # 4 warmup epochs before Koopman learning (extended for stability)
KOOPMAN_WEIGHT_MAX = 0.05  # conservative cap (was 0.5) to prevent loss explosion
FALLBACK_THRESHOLD = 8.0   # high-Koopman-loss detection threshold (was 10.0)

# Criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Create hybrid Koopman trainer
trainer_kwargs = {
    'model': model,
    'optimizer': optimizer,
    'criterion': criterion,
    'koopman_weight_min': 0.0,
    'koopman_weight_max': KOOPMAN_WEIGHT_MAX,
    'koopman_start_epoch': KOOPMAN_START_EPOCH,
    'total_epochs': NUM_EPOCHS,
    'schedule_type': 'linear',
    'enable_koopman_updates': True,
    'fallback_threshold': FALLBACK_THRESHOLD,
    'device': device,
}
trainer = HybridKoopmanTrainer(**trainer_kwargs)

print("\nTraining Configuration:")
for setting_name, setting_value in (
    ("Learning rate", LEARNING_RATE),
    ("Epochs", NUM_EPOCHS),
    ("Koopman start epoch", KOOPMAN_START_EPOCH),
    ("Koopman weight max", KOOPMAN_WEIGHT_MAX),
    ("Fallback threshold", FALLBACK_THRESHOLD),
    ("Batch size", BATCH_SIZE),
):
    print(f"  {setting_name}: {setting_value}")

print("\nOptimizations:")
for optimization in (
    "Warning frequency control (1 per epoch max)",
    "Automatic weight decay for high Koopman loss",
    "Computation skipping when weight is negligible",
    "Extended warmup for stable LM convergence",
):
    print(f"  ✓ {optimization}")

## 5. Training Loop

In [None]:
# Training history: every list gains exactly one entry per completed epoch.
history = {
    'train_loss': [],
    'train_loss_lm': [],
    'train_loss_koopman': [],
    'koopman_weight': [],
    'val_loss': [],
    'val_perplexity': [],
    'epoch_time': [],   # seconds spent in train_epoch only (validation excluded)
}

# Snapshot every block's Koopman operator before training so later cells can
# measure how far K moved. detach().clone() gives an independent copy on the
# same device as the parameter.
initial_K = {}
for i, block in enumerate(model.blocks):
    initial_K[i] = block.bk_layer.K.detach().clone()

print("\nStarting training...\n")
print("=" * 80)

for epoch in range(NUM_EPOCHS):
    epoch_start = time.perf_counter()

    # Train for one epoch (`epoch` is 0-based, matching KOOPMAN_START_EPOCH).
    epoch_metrics = trainer.train_epoch(train_loader, epoch=epoch)

    # Stop the clock before validation: this value is later reported as
    # training time, so it must not include the evaluation pass. CUDA kernels
    # run asynchronously, so wait for queued work before reading the timer.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    epoch_time = time.perf_counter() - epoch_start

    # Evaluate on validation set with the standard (non-Koopman) forward pass.
    val_loss, val_ppl = trainer.evaluate(val_loader, use_koopman=False)

    # Store metrics
    history['train_loss'].append(epoch_metrics['total_loss'])
    history['train_loss_lm'].append(epoch_metrics['loss_lm'])
    history['train_loss_koopman'].append(epoch_metrics['loss_koopman'])
    history['koopman_weight'].append(epoch_metrics['koopman_weight'])
    history['val_loss'].append(val_loss)
    history['val_perplexity'].append(val_ppl)
    history['epoch_time'].append(epoch_time)

    # Print progress
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {epoch_metrics['total_loss']:.4f} "
          f"(LM: {epoch_metrics['loss_lm']:.4f}, "
          f"Koopman: {epoch_metrics['loss_koopman']:.4f})")
    print(f"  Val Loss: {val_loss:.4f}, Val PPL: {val_ppl:.2f}")
    print(f"  Koopman Weight: {epoch_metrics['koopman_weight']:.4f}")
    print(f"  Koopman Enabled: {epoch_metrics['koopman_enabled']}")
    print(f"  Time: {epoch_time:.2f}s")
    print("-" * 80)

print("\nTraining complete!")
print("=" * 80)

## 6. Verify Koopman Operator Updates

In [None]:
# Compute change in Koopman operators and record, per layer, whether the
# operator actually moved away from its pre-training snapshot.
print("\nKoopman Operator Changes:")
print("=" * 80)

K_CHANGE_TOLERANCE = 1e-6  # mean |ΔK| below this counts as "unchanged"
layer_changed = []

for i, block in enumerate(model.blocks):
    final_K = block.bk_layer.K.detach()
    K_diff = (final_K - initial_K[i]).abs().mean().item()
    K_norm = final_K.norm().item()
    layer_changed.append(K_diff > K_CHANGE_TOLERANCE)

    print(f"Layer {i}:")
    print(f"  Mean absolute change: {K_diff:.6f}")
    print(f"  Final operator norm: {K_norm:.4f}")
    if K_norm > 0:
        print(f"  Relative change: {K_diff / K_norm * 100:.2f}%")
    else:
        # An all-zero operator has no meaningful relative change and would
        # otherwise raise ZeroDivisionError.
        print("  Relative change: n/a (operator norm is zero)")

# Visualize Koopman operator for first layer: before, after, and difference.
initial_K_0 = initial_K[0].cpu().numpy()
final_K_0 = model.blocks[0].bk_layer.K.detach().cpu().numpy()
K_diff_0 = final_K_0 - initial_K_0

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# (matrix, title, symmetric colour limit). Values outside +/-limit saturate.
panels = (
    (initial_K_0, 'Initial Koopman Operator K (Layer 0)', 0.1),
    (final_K_0, 'Final Koopman Operator K (Layer 0)', 0.1),
    (K_diff_0, 'Change in K (Final - Initial)', 0.05),
)
for ax, (matrix, title, limit) in zip(axes, panels):
    image = ax.imshow(matrix, cmap='RdBu', vmin=-limit, vmax=limit)
    ax.set_title(title)
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Dimension')
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.savefig('koopman_operator_evolution.png', dpi=150, bbox_inches='tight')
plt.show()

# Report the verdict from the measurements above instead of assuming success.
if any(layer_changed):
    print("\n✓ Koopman operators have been updated during training")
else:
    print("\n✗ Koopman operators did not change during training")

## 7. Plot Training Curves

In [None]:
# Four-panel overview of the run: loss components, Koopman weight schedule,
# validation perplexity and per-epoch time.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# The x-axis uses 1-based epoch labels. Its length comes from the recorded
# history so a shortened or interrupted run still plots.
epochs = np.arange(1, len(history['train_loss_lm']) + 1)

# KOOPMAN_START_EPOCH is a 0-based loop index, so the first Koopman-enabled
# epoch carries the 1-based label KOOPMAN_START_EPOCH + 1 on this axis.
koopman_start_label = KOOPMAN_START_EPOCH + 1

# Loss curves
axes[0, 0].plot(epochs, history['train_loss_lm'], 'b-', label='LM Loss', linewidth=2)
axes[0, 0].plot(epochs, history['train_loss_koopman'], 'r--', label='Koopman Loss', linewidth=2)
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss Components')

# Koopman weight schedule
axes[0, 1].plot(epochs, history['koopman_weight'], 'g-', linewidth=2)
axes[0, 1].set_ylabel('Weight')
axes[0, 1].set_title('Koopman Loss Weight Schedule')

# Validation perplexity
axes[1, 0].plot(epochs, history['val_perplexity'], 'purple', linewidth=2, marker='o')
axes[1, 0].set_ylabel('Perplexity')
axes[1, 0].set_title('Validation Perplexity')

# Training time per epoch
axes[1, 1].bar(epochs, history['epoch_time'], color='orange', alpha=0.7)
axes[1, 1].set_ylabel('Time (seconds)')
axes[1, 1].set_title('Training Time per Epoch')

# Decorations shared by all panels. The bar chart only gets horizontal grid
# lines, the line plots get both.
for ax in axes.flat:
    ax.axvline(x=koopman_start_label, color='gray', linestyle=':', label='Koopman Start')
    ax.set_xlabel('Epoch')
    ax.legend()
    if ax is axes[1, 1]:
        ax.grid(True, alpha=0.3, axis='y')
    else:
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('koopman_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Compare to Baseline

In [None]:
# Final results
final_val_ppl = history['val_perplexity'][-1]
baseline_ppl = 477  # From Step 2 Phase 1 results (corrected)
relative_diff = (final_val_ppl - baseline_ppl) / baseline_ppl

print("\nFinal Results:")
print("=" * 80)
print(f"Final Validation Perplexity: {final_val_ppl:.2f}")
print(f"Baseline Perplexity (Phase 1): {baseline_ppl:.2f}")
print(f"Relative Difference: {relative_diff * 100:+.1f}%")

# Check the 30% threshold. Lower perplexity is better, so only a regression
# above baseline can fail; a large improvement must not print a warning.
# A NaN perplexity compares False and is therefore reported as a failure.
threshold = 0.30
within_threshold = relative_diff <= threshold

if within_threshold:
    print(f"\n✓ SUCCESS: Perplexity within {threshold*100:.0f}% of baseline")
else:
    print(f"\n✗ WARNING: Perplexity exceeds {threshold*100:.0f}% threshold")

# Koopman loss convergence
# Compare the first and last epochs at or after the Koopman start whose loss
# was actually positive. The weight schedule starts at 0.0, so the very first
# hybrid epoch may record a loss of 0 and would be a misleading reference.
koopman_start_idx = KOOPMAN_START_EPOCH
koopman_losses = history['train_loss_koopman']
active_koopman_losses = [loss for loss in koopman_losses[koopman_start_idx:] if loss > 0]
if len(active_koopman_losses) >= 2 and koopman_losses[-1] > 0:
    if active_koopman_losses[-1] < active_koopman_losses[0]:
        print("✓ Koopman loss decreased during training")
    else:
        print("✗ Koopman loss did not decrease")
else:
    print("⚠ Koopman loss not yet active or insufficient data")

# Operator updates: any layer whose mean |ΔK| exceeds the tolerance counts.
K_changed = any(
    (layer.bk_layer.K.detach() - initial_K[idx]).abs().mean().item() > 1e-6
    for idx, layer in enumerate(model.blocks)
)
if K_changed:
    print("✓ Koopman operators updated during training")
else:
    print("✗ Koopman operators did not change")

## 9. Test Koopman Prediction

In [None]:
# Evaluate the validation set twice, once with the standard forward pass and
# once with Koopman prediction, then report both and their gap.
print("\nTesting Koopman Prediction Mode:")
print("=" * 80)

val_loss_standard, val_ppl_standard = trainer.evaluate(val_loader, use_koopman=False)
val_loss_koopman, val_ppl_koopman = trainer.evaluate(val_loader, use_koopman=True)

mode_reports = (
    ("Standard Forward:", val_loss_standard, val_ppl_standard),
    ("\nKoopman Forward:", val_loss_koopman, val_ppl_koopman),
)
for heading, mode_loss, mode_ppl in mode_reports:
    print(heading)
    print(f"  Loss: {mode_loss:.4f}")
    print(f"  Perplexity: {mode_ppl:.2f}")

# Gap of the Koopman path relative to the standard path, absolute and percent.
ppl_gap = val_ppl_koopman - val_ppl_standard
ppl_gap_pct = ppl_gap / val_ppl_standard * 100
print(f"\nDifference: {ppl_gap:+.2f} ({ppl_gap_pct:+.1f}%)")

## 10. Summary

In [None]:
# Final summary. The operator-update and Koopman-loss lines are derived from
# the recorded run rather than printed unconditionally, so a failed experiment
# cannot be reported as a success.
print("\n" + "=" * 80)
print("STEP 2 PHASE 2: KOOPMAN OPERATOR LEARNING - SUMMARY")
print("=" * 80)

# Did any layer's operator move away from its pre-training snapshot?
operators_updated = any(
    (layer.bk_layer.K.detach() - initial_K[idx]).abs().mean().item() > 1e-6
    for idx, layer in enumerate(model.blocks)
)

# Did the Koopman loss go down between the first and last epochs (at or after
# the Koopman start) in which it was actually positive?
summary_koopman_losses = history['train_loss_koopman']
summary_active_losses = [
    loss for loss in summary_koopman_losses[KOOPMAN_START_EPOCH:] if loss > 0
]
koopman_loss_decreased = (
    len(summary_active_losses) >= 2
    and summary_koopman_losses[-1] > 0
    and summary_active_losses[-1] < summary_active_losses[0]
)

print(f"\n✓ Training completed: {NUM_EPOCHS} epochs")
# KOOPMAN_START_EPOCH is 0-based; report it as a 1-based epoch number.
print(f"✓ Koopman learning started at epoch {KOOPMAN_START_EPOCH + 1}")
print(f"✓ Final validation perplexity: {final_val_ppl:.2f}")
if operators_updated:
    print("✓ Koopman operators updated successfully")
else:
    print("✗ Koopman operators did not change")
if koopman_loss_decreased:
    print("✓ Koopman loss converged")
else:
    print("✗ Koopman loss did not decrease (or was not active long enough to tell)")

print("\nModel Configuration:")
print(f"  Parameters: {total_params:,}")
print(f"  Koopman dimension: {KOOPMAN_DIM}")
print(f"  Layers: {N_LAYERS}")


print("\nTraining Strategy:")
print(f"  Warmup epochs: {KOOPMAN_START_EPOCH} (LM stabilization)")
print(f"  Hybrid epochs: {NUM_EPOCHS - KOOPMAN_START_EPOCH} (LM + Koopman)")
print(f"  Max Koopman weight: {KOOPMAN_WEIGHT_MAX} (conservative)")
print(f"  Fallback threshold: {FALLBACK_THRESHOLD} (automatic decay)")

print("\nNext Steps:")
print("  - Benchmark backward pass cost reduction")
print("  - Analyze Koopman eigenvalues and eigenfunctions")
print("  - Proceed to Step 2 Phase 3: Physics-Informed Learning")

print("\n" + "=" * 80)