# Train NURBS Transformer Surrogate Model

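This notebook generates training data — either from Meep simulations or from a fast synthetic proxy — and trains a Transformer surrogate model that maps NURBS control points to the optical response (phase and transmittance).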

**Note**: Make sure to activate the parallel Meep Python environment before running.

In [None]:
# Import required libraries
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the working directory first on the import path so the project modules
# (nurbs_atoms_data, transformer_nurbs_model) next to this notebook are importable.
sys.path.insert(0, os.getcwd())

# Report the torch build and whether a GPU is visible to this kernel.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Re-import the project modules and reload them, so edits to the .py files are
# picked up without restarting the kernel. The reloads run before the
# `from ... import` lines, so the names below are bound to the fresh definitions.
import importlib

import nurbs_atoms_data
import transformer_nurbs_model

importlib.reload(nurbs_atoms_data)
importlib.reload(transformer_nurbs_model)

from nurbs_atoms_data import Simulation
from transformer_nurbs_model import (
    NURBSTransformerModel,
    NURBSDataset,
    normalize_control_points,
    denormalize_control_points,
    normalize_targets,
    denormalize_targets,
)

print("Modules imported and reloaded successfully!")

## 1. Generate Training Data

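Choose the data source with the `USE_REAL_SIMULATION` flag in the configuration cell below:
- **Method A** (`USE_REAL_SIMULATION = True`): Meep simulation (slow but realistic)
- **Method B** (`USE_REAL_SIMULATION = False`, the default): synthetic data (fast, for testing the model architecture)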

In [None]:
# Configuration
USE_REAL_SIMULATION = False  # Set to True to use Meep simulation
N_SAMPLES = 500  # Number of training samples
WAVELENGTH = 550e-9  # 550 nm
PERTURBATION_RANGE = 0.04  # Each control-point coordinate is shifted by U(-range, range)
SEED = 42  # Global random seed, so a Restart & Run All reproduces the same results

# Seed every RNG the notebook draws from: numpy (data generation, single-sample
# test) and torch (e.g. DataLoader shuffling and other torch-side randomness).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Base control points (default NURBS shape), one (x, y) row per control point.
# NOTE(review): (0, -0.16) breaks the ±0.18 mirror symmetry of the other on-axis
# points — confirm this asymmetry is intentional.
BASE_CONTROL_POINTS = np.array([
    (0.18, 0), (0.16, 0.16), (0, 0.18), (-0.16, 0.16),
    (-0.18, 0), (-0.16, -0.16), (0, -0.16), (0.16, -0.16)
])

In [None]:
def generate_synthetic_data(n_samples, base_points, perturbation_range=0.04,
                            transmittance_gain=2.0):
    """
    Generate synthetic training data (for testing model architecture).

    Creates a physics-inspired relationship between control points and optical
    response without running Meep:

    * transmittance grows with the enclosed polygon area relative to the
      unperturbed ``base_points`` shape (0.6 at the base area, slope
      ``transmittance_gain`` per unit of relative area change), clipped to
      [0.1, 0.99];
    * phase is the polar angle of the control-point centroid, clipped to
      [-pi, pi].

    Gaussian noise (std 0.1 on phase, 0.05 on transmittance) is added. Random
    numbers come from the global ``np.random`` state, so seed it for
    reproducible data.

    Args:
        n_samples: Number of samples to generate.
        base_points: (n_points, 2) array-like of unperturbed control points.
        perturbation_range: Each coordinate is shifted by U(-range, range),
            then clipped to [-0.22, 0.22].
        transmittance_gain: Slope of transmittance vs. relative area change.

    Returns:
        (control_points, targets): arrays of shape (n_samples, n_points, 2)
        and (n_samples, 2); target columns are [phase, transmittance].

    Raises:
        ValueError: if ``base_points`` encloses zero area.
    """
    print(f"Generating {n_samples} synthetic samples...")

    base_points = np.asarray(base_points, dtype=float)
    base_area = _polygon_area(base_points)
    if base_area <= 0:
        raise ValueError("base_points must enclose a non-zero area")

    control_points_list = []
    targets_list = []

    for _ in tqdm(range(n_samples)):
        # Generate perturbed control points (any number of points, not just 8)
        perturbation = np.random.uniform(-perturbation_range, perturbation_range, base_points.shape)
        cp = np.clip(base_points + perturbation, -0.22, 0.22)

        # Phase: direction of the centroid offset, with noise
        centroid_x, centroid_y = cp.mean(axis=0)
        phase = np.arctan2(centroid_y, centroid_x) + np.random.normal(0, 0.1)
        phase = np.clip(phase, -np.pi, np.pi)

        # Transmittance: driven by the area change relative to the base shape.
        # A fixed 0.05 area scale saturated every sample at the 0.99 clip, because
        # shapes of the default size already enclose an area of ~0.11.
        relative_area = _polygon_area(cp) / base_area - 1.0
        transmittance = 0.6 + transmittance_gain * relative_area + np.random.normal(0, 0.05)
        transmittance = np.clip(transmittance, 0.1, 0.99)

        control_points_list.append(cp)
        targets_list.append([phase, transmittance])

    # reshape keeps well-defined shapes even when n_samples == 0
    control_points = np.array(control_points_list, dtype=float).reshape((-1,) + base_points.shape)
    targets = np.array(targets_list, dtype=float).reshape(-1, 2)
    return control_points, targets


def _polygon_area(points):
    """Shoelace area of the closed polygon through ALL rows of ``points`` (x, y)."""
    x, y = points[:, 0], points[:, 1]
    return 0.5 * np.abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y))


def generate_simulation_data(n_samples, base_points, wavelength, perturbation_range=0.04):
    """
    Generate training data using Meep simulation.

    Each sample perturbs ``base_points`` by U(-perturbation_range,
    perturbation_range) per coordinate (clipped to [-0.22, 0.22]). It then builds a
    ``Simulation`` from those control points and records the
    ``(transmittance, phase)`` pair returned by ``run_forward`` with start and
    stop wavelength both set to ``wavelength``.

    Samples whose simulation raises are skipped and counted, instead of being
    kept with random placeholder labels. Such labels are unrelated to the
    geometry and would only teach the surrogate noise.

    Args:
        n_samples: Number of simulations to attempt.
        base_points: (n_points, 2) array-like of unperturbed control points.
        wavelength: Wavelength passed as both ``wavelength_start`` and
            ``wavelength_stop``.
        perturbation_range: Half-width of the uniform coordinate perturbation.

    Returns:
        (control_points, targets): one row per *successful* simulation
        (at most ``n_samples``); target columns are [phase, transmittance].

    Raises:
        RuntimeError: if ``n_samples > 0`` and every simulation failed.
    """
    print(f"Generating {n_samples} samples using Meep simulation...")
    print("This will take a while...")

    base_points = np.asarray(base_points, dtype=float)
    max_reported_failures = 5  # avoid flooding the output with repeated tracebacks

    control_points_list = []
    targets_list = []
    failed_count = 0

    for i in tqdm(range(n_samples)):
        # Generate perturbed control points (any number of points, not just 8)
        perturbation = np.random.uniform(-perturbation_range, perturbation_range, base_points.shape)
        cp = np.clip(base_points + perturbation, -0.22, 0.22)

        try:
            sim = Simulation(control_points=cp)
            # NOTE(review): assumes run_forward returns (transmittance, phase)
            # in this order for a single-wavelength run — confirm.
            transmittance, phase = sim.run_forward(
                wavelength_start=wavelength,
                wavelength_stop=wavelength
            )
        except Exception as e:  # best-effort: one bad geometry must not abort the run
            failed_count += 1
            if failed_count <= max_reported_failures:
                print(f"\nSimulation {i} failed: {e}")
            continue  # drop the sample rather than inventing labels for it

        control_points_list.append(cp)
        targets_list.append([phase, transmittance])

    if failed_count > 0:
        print(f"\nTotal failed simulations: {failed_count}/{n_samples}")
    if n_samples > 0 and not control_points_list:
        raise RuntimeError(
            f"All {n_samples} Meep simulations failed; no training data was generated."
        )

    return np.array(control_points_list), np.array(targets_list)

In [None]:
# Generate the raw dataset with whichever generator the configuration selects.
# Both generators share these keyword arguments; only Meep also needs the wavelength.
_generator_kwargs = dict(
    n_samples=N_SAMPLES,
    base_points=BASE_CONTROL_POINTS,
    perturbation_range=PERTURBATION_RANGE,
)
if USE_REAL_SIMULATION:
    control_points, targets = generate_simulation_data(wavelength=WAVELENGTH, **_generator_kwargs)
else:
    control_points, targets = generate_synthetic_data(**_generator_kwargs)

_phase_col, _trans_col = targets[:, 0], targets[:, 1]
print(f"\nData shape: control_points={control_points.shape}, targets={targets.shape}")
print(f"Phase range: [{_phase_col.min():.3f}, {_phase_col.max():.3f}]")
print(f"Transmittance range: [{_trans_col.min():.3f}, {_trans_col.max():.3f}]")

In [None]:
# Persist the raw (un-normalized) dataset so training can be repeated without regenerating it.
for _path, _array in (('training_control_points.npy', control_points),
                      ('training_targets.npy', targets)):
    np.save(_path, _array)
print("Raw data saved to training_control_points.npy and training_targets.npy")

## 2. Prepare Data for Training

In [None]:
# Map control points and targets into the model's normalized space.
# The ±0.25 control-point bounds enclose the ±0.22 clip used during data generation.
normalized_control_points = normalize_control_points(control_points, min_val=-0.25, max_val=0.25)
normalized_targets = normalize_targets(targets)

for _label, _values in (("control points", normalized_control_points),
                        ("targets", normalized_targets)):
    print(f"Normalized {_label} range: [{_values.min():.3f}, {_values.max():.3f}]")

In [None]:
# 70 / 15 / 15 split: hold out 30% of the samples, then halve the hold-out into
# validation and test sets. A fixed random_state keeps the split reproducible.
X_train, X_temp, y_train, y_temp = train_test_split(
    normalized_control_points, normalized_targets, test_size=0.3, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

for _split_name, _split_X in (("Training", X_train), ("Validation", X_val), ("Test", X_test)):
    print(f"{_split_name} set: {len(_split_X)} samples")

In [None]:
# Wrap each split in a Dataset/DataLoader; only the training loader is shuffled.
BATCH_SIZE = 32

train_dataset, val_dataset, test_dataset = (
    NURBSDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for _name, _loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{_name} batches: {len(_loader)}")

## 3. Create and Train Model

In [None]:
# Build the Transformer surrogate. `model` is the project wrapper; the underlying
# torch module is `model.model` (its parameters are counted below).
MODEL_CONFIG = dict(
    n_control_points=8,
    d_model=128,
    nhead=8,
    num_layers=4,
    d_ff=256,
    dropout=0.1,
)
model = NURBSTransformerModel(**MODEL_CONFIG)

_n_params = sum(param.numel() for param in model.model.parameters())
print(f"Model created on device: {model.device}")
print(f"Total parameters: {_n_params:,}")

In [None]:
# Smoke-test one forward pass on random input before committing to a full training run.
# Input shape: (batch=2, n_control_points, 2 coordinates), taken from BASE_CONTROL_POINTS.
test_input = torch.randn(2, *BASE_CONTROL_POINTS.shape).to(model.device)

model.model.eval()
with torch.no_grad():
    test_output = model.model(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"Output values: {test_output}")

# The evaluation cells index predictions as [:, 0] (phase) and [:, 1] (transmittance),
# so the model must emit exactly one such pair per sample. Actually check it instead of
# printing "PASSED" unconditionally.
assert tuple(test_output.shape) == (test_input.shape[0], 2), (
    f"unexpected output shape {tuple(test_output.shape)}"
)
assert torch.isfinite(test_output).all(), "forward pass produced NaN/Inf values"
print("\nForward pass test PASSED!")

In [None]:
# Train the surrogate. `model.train(...)` is the project wrapper's training routine
# (it takes the loaders and a checkpoint path). Presumably it also fills
# `model.train_losses` / `model.val_losses`, which are plotted below — confirm.
EPOCHS = 100
MODEL_SAVE_PATH = "nurbs_transformer_model.pth"

print(f"Training for {EPOCHS} epochs...")
print("=" * 50)

model.train(train_loader, val_loader, epochs=EPOCHS, save_path=MODEL_SAVE_PATH)

In [None]:
# Plot training curves: the full history on a log scale (left) and the second
# half of training on a linear scale (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_full, ax_tail = axes
loss_series = (
    ('Training Loss', model.train_losses),
    ('Validation Loss', model.val_losses),
)

# Left panel: every epoch, log-scaled so early and late losses are both readable.
for label, losses in loss_series:
    ax_full.plot(losses, label=label, linewidth=2)
ax_full.set_xlabel('Epoch')
ax_full.set_ylabel('Loss (MSE)')
ax_full.set_title('Training and Validation Loss')
ax_full.legend()
ax_full.grid(True, alpha=0.3)
ax_full.set_yscale('log')

# Right panel: the cut-off is half the training-history length and is applied to both curves.
tail_start = len(model.train_losses) // 2
for label, losses in loss_series:
    ax_tail.plot(range(tail_start, len(losses)), losses[tail_start:], label=label, linewidth=2)
ax_tail.set_xlabel('Epoch')
ax_tail.set_ylabel('Loss (MSE)')
ax_tail.set_title('Loss (Last 50% of Training)')
ax_tail.legend()
ax_tail.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("Training curves saved to training_curves.png")

## 4. Evaluate Model

In [None]:
# Evaluate on the held-out test set. The wrapper reports errors in normalized
# target units (see the labels below), not in radians / raw transmittance.
avg_loss, avg_phase_error, avg_trans_error = model.evaluate(test_loader)

print("\n=== Test Set Evaluation ===")
for _metric_name, _metric_value in (
    ("Average Loss (MSE)", avg_loss),
    ("Average Phase Error (normalized)", avg_phase_error),
    ("Average Transmittance Error (normalized)", avg_trans_error),
):
    print(f"{_metric_name}: {_metric_value:.6f}")

In [None]:
# Run the trained network over the whole test set and gather predictions and
# ground truth (both in normalized space), then map them back to physical units.
model.model.eval()
prediction_batches = []
target_batches = []

with torch.no_grad():
    for batch_points, batch_targets in test_loader:
        batch_output = model.model(batch_points.to(model.device))
        prediction_batches.append(batch_output.cpu().numpy())
        target_batches.append(batch_targets.numpy())

# Flatten per-batch arrays into one row per test sample.
predictions = np.array([row for batch in prediction_batches for row in batch])
actuals = np.array([row for batch in target_batches for row in batch])

# Denormalize
pred_denorm = denormalize_targets(predictions)
actual_denorm = denormalize_targets(actuals)

print(f"Predictions shape: {predictions.shape}")
print(f"Actuals shape: {actuals.shape}")

In [None]:
# Scatter predicted vs actual targets (physical units); the red dashed diagonal
# marks perfect prediction over each quantity's nominal range.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# (target column, diagonal range, axis-label quantity, panel title)
panel_specs = [
    (0, (-np.pi, np.pi), 'Phase (rad)', 'Phase: Predicted vs Actual'),
    (1, (0, 1), 'Transmittance', 'Transmittance: Predicted vs Actual'),
]
for ax, (col, (lo, hi), quantity, title) in zip(axes, panel_specs):
    ax.scatter(actual_denorm[:, col], pred_denorm[:, col], alpha=0.6, s=30)
    ax.plot([lo, hi], [lo, hi], 'r--', linewidth=2, label='Perfect prediction')
    ax.set_xlabel(f'Actual {quantity}')
    ax.set_ylabel(f'Predicted {quantity}')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

plt.tight_layout()
plt.savefig('prediction_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("Prediction comparison saved to prediction_comparison.png")

In [None]:
# Calculate R² scores and mean absolute errors in physical units.
# (r2_score and mean_absolute_error are imported in the first cell.)
phase_actual, phase_pred = actual_denorm[:, 0], pred_denorm[:, 0]
trans_actual, trans_pred = actual_denorm[:, 1], pred_denorm[:, 1]

phase_r2 = r2_score(phase_actual, phase_pred)
trans_r2 = r2_score(trans_actual, trans_pred)

# Phase is an angle: wrap each error into (-pi, pi] before averaging, so that
# predicting -3.1 rad for a true 3.1 rad counts as ~0.08 rad off, not ~6.2 rad.
# (R² above is still computed on the raw, unwrapped values.)
phase_error = np.angle(np.exp(1j * (phase_pred - phase_actual)))
phase_mae = float(np.mean(np.abs(phase_error)))
trans_mae = mean_absolute_error(trans_actual, trans_pred)

print("\n=== Final Model Performance ===")
print(f"Phase R² Score: {phase_r2:.4f}")
print(f"Phase MAE (wrapped): {phase_mae:.4f} rad ({np.degrees(phase_mae):.2f}°)")
print(f"\nTransmittance R² Score: {trans_r2:.4f}")
print(f"Transmittance MAE: {trans_mae:.4f}")

## 5. Test Single Prediction

In [None]:
# Test prediction with a single, mildly perturbed sample.
# Use distinct variable names: the test-set `pred_denorm` from the prediction cell
# is used by the plot and metric cells, and overwriting it here with a (1, 2) array
# would make re-running those cells use the wrong data.
single_cp = BASE_CONTROL_POINTS + np.random.uniform(-0.02, 0.02, BASE_CONTROL_POINTS.shape)
single_cp = np.clip(single_cp, -0.22, 0.22)

# Normalize with the same bounds as the training data; add a leading batch axis.
single_cp_norm = normalize_control_points(single_cp[np.newaxis], min_val=-0.25, max_val=0.25)

# Predict and map back to physical units (columns: phase, transmittance)
single_pred_denorm = denormalize_targets(model.predict(single_cp_norm))
single_phase = single_pred_denorm[0, 0]
single_trans = single_pred_denorm[0, 1]

print(f"Test control points:\n{single_cp}")
print(f"\nPredicted phase: {single_phase:.4f} rad ({np.degrees(single_phase):.2f}°)")
print(f"Predicted transmittance: {single_trans:.4f}")

In [None]:
# Closing summary: where the artifacts were written and how to reload the model.
_banner = "=" * 50
print("\n" + _banner)
print("Training Complete!")
print(_banner)
print(f"\nModel saved to: {MODEL_SAVE_PATH}")
print("Training data saved to: training_control_points.npy, training_targets.npy")
print("\nTo use the model for inference:")
for _usage_line in (
    "  model = NURBSTransformerModel(...)",
    f"  model.load_model('{MODEL_SAVE_PATH}')",
    "  prediction = model.predict(control_points)",
):
    print(_usage_line)