# SC-SINDy Fair Evaluation (No Oracle)

This notebook demonstrates **fair evaluation** of Structure-Constrained SINDy using **learned** structure predictions instead of oracle (ground truth) access.

## Key Differences from Oracle Evaluation

| Aspect | Oracle (Unfair) | Learned (Fair) |
|--------|-----------------|----------------|
| Structure source | Ground truth | Trained network |
| Train/test split | None | System-level split |
| Results | Overly optimistic | Realistic |
| Publishable | No | Yes |

**Run this notebook in Google Colab to evaluate SC-SINDy properly.**

## 1. Setup

In [None]:
# Install SC-SINDy (uncomment for Colab; replace `yourusername` with the actual repository owner)
# !pip install git+https://github.com/yourusername/structure-constrained-sindy.git[torch,viz] -q

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Type
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's global RNG for reproducibility (network training may use its own RNG and need a separate seed)
np.random.seed(42)

print("Setup complete!")

In [None]:
# SC-SINDy imports
from sc_sindy import (
    sindy_stls,
    sindy_structure_constrained,
    build_library_2d,
    compute_derivatives_finite_diff,
    compute_structure_metrics,
    format_equation,
)

# Systems
from sc_sindy.systems import (
    VanDerPol, DuffingOscillator, DampedHarmonicOscillator,
    LotkaVolterra, SelkovGlycolysis, CoupledBrusselator,
    DynamicalSystem,
)

# Network components
from sc_sindy.network import (
    StructureNetwork,
    StructurePredictor,
    train_structure_network,
    train_structure_network_with_split,
    extract_trajectory_features,
)

# Evaluation framework
from sc_sindy.evaluation import (
    get_split,
    print_split_info,
    SCSINDyEvaluator,
    TRAIN_SYSTEMS_2D,
    TEST_SYSTEMS_2D,
)

print("Imports successful!")

## 2. Train/Test Split

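The key to fair evaluation is splitting at the **system level**, not the trajectory level. A trajectory-level split would let the network train on other trajectories of the same system, so the structure it must predict at test time would already appear among its training labels.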

- **Train systems**: Used to train the structure network
- **Test systems**: Held out completely - never seen during training
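
To make the distinction concrete, here is a minimal sketch of a system-level split in plain Python. The helper `split_by_system` and the choice of which system is held out are hypothetical and only for illustration; the actual split used in this notebook comes from `get_split`.

```python
from typing import Dict, List, Sequence, Tuple

Sample = Tuple[str, int]  # (system name, trajectory index)


def split_by_system(
    samples: Sequence[Sample],
    test_systems: Sequence[str],
) -> Dict[str, List[Sample]]:
    """Assign each sample to train or test based on the system that produced it."""
    held_out = set(test_systems)
    split: Dict[str, List[Sample]] = {"train": [], "test": []}
    for system_name, traj_id in samples:
        key = "test" if system_name in held_out else "train"
        split[key].append((system_name, traj_id))
    return split


# Three trajectories per system; the held-out system is chosen arbitrarily here.
samples = [
    (name, i)
    for name in ("VanDerPol", "LotkaVolterra", "DuffingOscillator")
    for i in range(3)
]
split = split_by_system(samples, test_systems=["LotkaVolterra"])

train_names = {name for name, _ in split["train"]}
test_names = {name for name, _ in split["test"]}
print(train_names & test_names)  # set()
```

Because membership is decided by system name, every trajectory of a held-out system lands in the test set, and no system contributes to both sides.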

In [None]:
# Show the train/test split
print_split_info()

# Get the 2D split
train_systems, test_systems = get_split(dimension=2)

print("\n" + "="*50)
print("We will:")
print(f"  1. Train network on: {[s.__name__ for s in train_systems]}")
print(f"  2. Evaluate on: {[s.__name__ for s in test_systems]}")
print("="*50)

## 3. Generate Training Data

Generate trajectories ONLY from training systems.
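
For readers who want to see what a noisy trajectory looks like without the package, the sketch below integrates a damped oscillator with a fixed-step RK4 scheme and adds Gaussian noise. Everything in it is an assumption made for illustration: the function names are hypothetical, and scaling the noise by each variable's standard deviation is one common convention, not necessarily what `generate_trajectory(..., noise_level=...)` does.

```python
import numpy as np


def damped_oscillator(state: np.ndarray, damping: float = 0.1, omega: float = 1.0) -> np.ndarray:
    """Right-hand side of x' = v, v' = -omega^2 x - 2*damping*v."""
    x, v = state
    return np.array([v, -omega**2 * x - 2.0 * damping * v])


def rk4_trajectory(f, x0: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Integrate x' = f(x) on the time grid t with classical fourth-order Runge-Kutta."""
    x = np.empty((len(t), len(x0)))
    x[0] = x0
    for k in range(len(t) - 1):
        h = t[k + 1] - t[k]
        k1 = f(x[k])
        k2 = f(x[k] + 0.5 * h * k1)
        k3 = f(x[k] + 0.5 * h * k2)
        k4 = f(x[k] + h * k3)
        x[k + 1] = x[k] + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


def add_relative_noise(x: np.ndarray, noise_level: float, rng: np.random.Generator) -> np.ndarray:
    """Add zero-mean Gaussian noise whose std is noise_level times each column's std (assumed convention)."""
    scale = noise_level * np.std(x, axis=0)
    return x + rng.normal(size=x.shape) * scale


rng = np.random.default_rng(0)
t = np.linspace(0, 50, 5000)
clean = rk4_trajectory(damped_oscillator, np.array([1.0, 0.0]), t)
noisy = add_relative_noise(clean, noise_level=0.05, rng=rng)
print(clean.shape, noisy.shape)  # (5000, 2) (5000, 2)
```

The time grid matches the one used in the cell below, so the array shapes are the same as those passed to feature extraction before trimming.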

In [None]:
def generate_training_data(
    system_classes: List[Type[DynamicalSystem]],
    n_trajectories: int = 100,
    noise_levels: List[float] = None,
):
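    """Generate (features, flattened structure) pairs from the given system classes.

    `n_trajectories` is the total budget; it is divided evenly across systems.
    """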
    if noise_levels is None:
        noise_levels = [0.0, 0.05, 0.10]
    
    # Get library term names
    dummy_x = np.random.randn(10, 2)
    _, term_names = build_library_2d(dummy_x)
    
    t = np.linspace(0, 50, 5000)
    dt = t[1] - t[0]
    
    all_data = []
    
    for SystemClass in system_classes:
        system = SystemClass()
        print(f"  Generating data for {system.name}...")
        
        # Get true structure for this system
        true_structure = system.get_true_structure(term_names)
        structure_flat = true_structure.flatten().astype(float)
        
        n_per_system = n_trajectories // len(system_classes)
        success_count = 0
        
        for _ in range(n_per_system * 2):  # Try extra to handle failures
            if success_count >= n_per_system:
                break
                
            x0 = np.random.randn(2) * 2
            noise = np.random.choice(noise_levels)
            
            try:
                x = system.generate_trajectory(x0, t, noise_level=noise)
                
                if np.any(np.isnan(x)) or np.any(np.isinf(x)):
                    continue
                
                # Trim edges
                x_trim = x[100:-100]
                
                # Extract features
                features = extract_trajectory_features(x_trim, dt)
                
                if np.any(np.isnan(features)) or np.any(np.isinf(features)):
                    continue
                
                all_data.append((features, structure_flat))
                success_count += 1
                
            except Exception:
                # Integration can fail or diverge for some initial conditions; skip those draws
                continue
        
        print(f"    Generated {success_count} trajectories")
    
    return all_data, term_names

# Generate training data from TRAIN systems only
print("Generating training data...")
print("(This uses ONLY training systems - test systems are held out)\n")

train_data, term_names = generate_training_data(
    TRAIN_SYSTEMS_2D,
    n_trajectories=200,
)

print(f"\nTotal training samples: {len(train_data)}")
print(f"Feature dimension: {train_data[0][0].shape[0]}")
print(f"Output dimension: {train_data[0][1].shape[0]}")

## 4. Train Structure Network

Train the network to predict equation structure from trajectory features.
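
The loss curves plotted later in this section are labeled as binary cross-entropy, which treats each entry of the flattened structure vector as an independent yes/no label. The sketch below shows that loss in NumPy; the function name and the example values are illustrative assumptions, and the loss inside `train_structure_network` may differ in detail (for example in reduction or clipping).

```python
import numpy as np


def binary_cross_entropy(probs: np.ndarray, labels: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels."""
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    losses = labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs)
    return float(-np.mean(losses))


# Flattened structure labels: 1 where a library term is active, 0 elsewhere (made-up example).
labels = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])

uninformed = np.full_like(labels, 0.5)           # no knowledge of the structure
confident = np.where(labels == 1.0, 0.95, 0.05)  # close to the true structure

loss_uninformed = binary_cross_entropy(uninformed, labels)
loss_confident = binary_cross_entropy(confident, labels)
print(f"{loss_uninformed:.3f}")  # 0.693
print(f"{loss_confident:.3f}")   # 0.051
```

A validation loss near 0.693 (that is, ln 2) therefore means the network is doing no better than predicting 0.5 for every term.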

In [None]:
# Compute normalization statistics
all_features = np.array([f for f, _ in train_data])
feature_mean = np.mean(all_features, axis=0)
feature_std = np.std(all_features, axis=0)
feature_std = np.where(feature_std < 1e-10, 1.0, feature_std)

# Normalize training data
normalized_data = [
    ((features - feature_mean) / feature_std, labels)
    for features, labels in train_data
]

print("Training structure network...")
print("(This learns to predict structure WITHOUT seeing test systems)\n")

model, history = train_structure_network(
    normalized_data,
    epochs=100,
    batch_size=32,
    lr=0.001,
    verbose=True,
)

print("\nTraining complete!")
print(f"Final validation loss: {history['val_loss'][-1]:.4f}")

In [None]:
# Plot training history
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train', alpha=0.8)
plt.plot(history['val_loss'], label='Validation', alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Binary Cross-Entropy Loss')
plt.legend()
plt.title('Training History')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(history['val_loss'], 'orange')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Validation Loss (zoom)')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Create predictor with normalization
n_vars = 2
n_terms = len(term_names)

predictor = StructurePredictor(
    model=model,
    n_vars=n_vars,
    n_terms=n_terms,
    feature_mean=feature_mean,
    feature_std=feature_std,
)

print("StructurePredictor created!")
print(f"  n_vars: {n_vars}")
print(f"  n_terms: {n_terms}")
print(f"  term_names: {term_names}")

## 5. Evaluate on Test Systems

Now we evaluate SC-SINDy using the **learned** network predictions on systems that were **never seen during training**.
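
As a rough mental model of how predicted probabilities can constrain the regression, the sketch below keeps only the library columns whose probability reaches a threshold and solves least squares on those columns. This is an assumption about the general idea, not the implementation of `sindy_structure_constrained`; the name `masked_least_squares` and the synthetic data are hypothetical.

```python
import numpy as np


def masked_least_squares(
    Theta: np.ndarray,
    x_dot: np.ndarray,
    probs: np.ndarray,
    structure_threshold: float = 0.3,
) -> np.ndarray:
    """Fit each equation using only library terms with probs >= structure_threshold.

    Theta: (n_samples, n_terms), x_dot: (n_samples, n_vars), probs: (n_vars, n_terms).
    Returns coefficients of shape (n_vars, n_terms); excluded terms stay exactly zero.
    """
    n_vars, n_terms = probs.shape
    xi = np.zeros((n_vars, n_terms))
    for i in range(n_vars):
        active = probs[i] >= structure_threshold
        if not np.any(active):
            continue  # nothing selected for this equation
        coef, *_ = np.linalg.lstsq(Theta[:, active], x_dot[:, i], rcond=None)
        xi[i, active] = coef
    return xi


# Synthetic check: random library, known sparse coefficients, noise-free derivatives.
rng = np.random.default_rng(0)
Theta = rng.normal(size=(200, 6))
true_xi = np.zeros((2, 6))
true_xi[0, 1] = 1.0
true_xi[1, 0] = -1.0
true_xi[1, 3] = 0.5
x_dot = Theta @ true_xi.T

probs = np.where(true_xi != 0, 0.8, 0.1)
probs[0, 4] = 0.4  # a false positive that passes the threshold

xi = masked_least_squares(Theta, x_dot, probs)
print(np.allclose(xi, true_xi, atol=1e-8))  # True
```

With clean data the false positive receives a coefficient near zero, so a moderately wrong prediction is not fatal as long as the true terms pass the threshold. A true term that falls below the threshold, however, can never be recovered by this kind of fit.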

In [None]:
def evaluate_fair(
    test_system_classes: List[Type[DynamicalSystem]],
    predictor: StructurePredictor,
    term_names: List[str],
    n_trials: int = 20,
    noise_levels: List[float] = None,
):
    """Evaluate SC-SINDy with LEARNED predictions on test systems."""
    if noise_levels is None:
        noise_levels = [0.0, 0.05, 0.10]
    
    t = np.linspace(0, 50, 5000)
    dt = t[1] - t[0]
    
    results = []
    
    for SystemClass in test_system_classes:
        system = SystemClass()
        print(f"\nEvaluating on {system.name} (TEST SYSTEM - never seen in training)")
        
        true_xi = system.get_true_coefficients(term_names)
        true_structure = np.abs(true_xi) > 1e-6
        
        for noise in noise_levels:
            std_f1s, sc_f1s, net_f1s = [], [], []
            
            for _ in range(n_trials):
                x0 = np.random.randn(2) * 2
                
                try:
                    x = system.generate_trajectory(x0, t, noise_level=noise)
                    if np.any(np.isnan(x)) or np.any(np.isinf(x)):
                        continue
                    
                    x_trim = x[100:-100]
                    x_dot = compute_derivatives_finite_diff(x_trim, dt)
                    Theta, _ = build_library_2d(x_trim)
                    
                    # Standard SINDy
                    xi_std, _ = sindy_stls(Theta, x_dot, threshold=0.1)
                    
                    # SC-SINDy with LEARNED predictions (NOT oracle!)
                    network_probs = predictor.predict_from_trajectory(x_trim, dt)
                    xi_sc, _ = sindy_structure_constrained(
                        Theta, x_dot, network_probs, structure_threshold=0.3
                    )
                    
                    # Metrics
                    metrics_std = compute_structure_metrics(xi_std, true_xi)
                    metrics_sc = compute_structure_metrics(xi_sc, true_xi)
                    
                    # Network prediction quality
                    pred_structure = network_probs > 0.5
                    net_metrics = compute_structure_metrics(
                        pred_structure.astype(float), 
                        true_structure.astype(float)
                    )
                    
                    std_f1s.append(metrics_std['f1'])
                    sc_f1s.append(metrics_sc['f1'])
                    net_f1s.append(net_metrics['f1'])
                    
                except Exception:
                    continue
            
            if std_f1s:
                results.append({
                    'system': system.name,
                    'noise': noise,
                    'std_f1_mean': np.mean(std_f1s),
                    'std_f1_std': np.std(std_f1s),
                    'sc_f1_mean': np.mean(sc_f1s),
                    'sc_f1_std': np.std(sc_f1s),
                    'net_f1_mean': np.mean(net_f1s),
                    'improvement': np.mean(sc_f1s) - np.mean(std_f1s),
                    'n_trials': len(std_f1s),
                })
                
                print(f"  Noise {noise:.2f}: Std F1={np.mean(std_f1s):.3f}, "
                      f"SC F1={np.mean(sc_f1s):.3f}, "
                      f"Net F1={np.mean(net_f1s):.3f}, "
                      f"Improve={np.mean(sc_f1s)-np.mean(std_f1s):+.3f}")
    
    return results

# Run fair evaluation
print("="*60)
print("FAIR EVALUATION (No Oracle Access)")
print("="*60)
print("\nTest systems have NEVER been seen during training!")

results = evaluate_fair(
    TEST_SYSTEMS_2D,
    predictor,
    term_names,
    n_trials=20,
    noise_levels=[0.0, 0.05, 0.10],
)

## 6. Results Summary

In [None]:
# Display results table
print("\n" + "="*70)
print("FAIR EVALUATION RESULTS SUMMARY")
print("="*70)
print(f"{'System':<25} {'Noise':<8} {'Std F1':<13} {'SC F1':<13} {'Net F1':<10} {'Improve':<10}")
print("-"*70)

for r in results:
    print(f"{r['system']:<25} {r['noise']:<8.2f} "
          f"{r['std_f1_mean']:.3f}+/-{r['std_f1_std']:.2f}  "
          f"{r['sc_f1_mean']:.3f}+/-{r['sc_f1_std']:.2f}  "
          f"{r['net_f1_mean']:.3f}      "
          f"{r['improvement']:+.3f}")

# Overall statistics
all_std = [r['std_f1_mean'] for r in results]
all_sc = [r['sc_f1_mean'] for r in results]
all_improve = [r['improvement'] for r in results]

print("-"*70)
# Column widths match the per-system rows above (14, 14, and 11 characters)
print(f"{'OVERALL':<25} {'---':<8} "
      f"{np.mean(all_std):<14.3f}"
      f"{np.mean(all_sc):<14.3f}"
      f"{'---':<11}"
      f"{np.mean(all_improve):+.3f}")

print("\n" + "="*70)
print("KEY INSIGHT:")
if np.mean(all_improve) > 0:
    print(f"SC-SINDy improves F1 by {np.mean(all_improve):.3f} on average on UNSEEN systems!")
else:
    print(f"SC-SINDy shows {np.mean(all_improve):.3f} change - network may need more training data.")
print("="*70)

In [None]:
# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart by system
systems = list(dict.fromkeys(r['system'] for r in results))  # unique names, in evaluation order
x = np.arange(len(systems))
width = 0.35

std_means = [np.mean([r['std_f1_mean'] for r in results if r['system'] == s]) for s in systems]
sc_means = [np.mean([r['sc_f1_mean'] for r in results if r['system'] == s]) for s in systems]

axes[0].bar(x - width/2, std_means, width, label='Standard SINDy', color='steelblue', alpha=0.8)
axes[0].bar(x + width/2, sc_means, width, label='SC-SINDy (Learned)', color='darkorange', alpha=0.8)
axes[0].set_xlabel('Test System')
axes[0].set_ylabel('F1 Score')
axes[0].set_title('Performance on Held-Out Test Systems')
axes[0].set_xticks(x)
axes[0].set_xticklabels(systems, rotation=15)
axes[0].legend()
axes[0].set_ylim(0, 1.1)
axes[0].grid(True, alpha=0.3)

# Improvement by noise level
noise_levels = sorted(set(r['noise'] for r in results))
improvements = [np.mean([r['improvement'] for r in results if r['noise'] == n]) for n in noise_levels]

colors = ['green' if i > 0 else 'red' for i in improvements]
axes[1].bar(range(len(noise_levels)), improvements, color=colors, alpha=0.8)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=0.5)
axes[1].set_xlabel('Noise Level')
axes[1].set_ylabel('F1 Improvement (SC - Standard)')
axes[1].set_title('SC-SINDy Improvement by Noise Level')
axes[1].set_xticks(range(len(noise_levels)))
axes[1].set_xticklabels([f'{n:.0%}' for n in noise_levels])
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Network Prediction Analysis

Let's analyze what the network learned and how well it predicts structure on test systems.
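
The network's quality is summarized elsewhere in this notebook by an F1 score on the thresholded probabilities. The sketch below shows how such a score can be computed from a predicted and a true support mask. The helper `support_f1` and the example numbers are hypothetical, and `compute_structure_metrics` may define its metrics differently.

```python
import numpy as np


def support_f1(pred_structure: np.ndarray, true_structure: np.ndarray) -> dict:
    """Precision, recall, and F1 of a predicted term mask against the true mask."""
    pred = np.asarray(pred_structure, dtype=bool)
    true = np.asarray(true_structure, dtype=bool)
    tp = int(np.sum(pred & true))    # terms correctly selected
    fp = int(np.sum(pred & ~true))   # spurious terms
    fn = int(np.sum(~pred & true))   # missed terms
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2.0 * precision * recall / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Made-up example: 2 equations, 4 library terms.
true_structure = np.array([[0, 1, 0, 0],
                           [1, 0, 0, 1]], dtype=bool)
probs = np.array([[0.1, 0.9, 0.6, 0.2],
                  [0.7, 0.1, 0.2, 0.4]])

metrics = support_f1(probs > 0.5, true_structure)
print(metrics["f1"])
```

In this example the network selects one spurious term and misses one true term, giving two true positives, one false positive, and one false negative, so precision, recall, and F1 all equal 2/3.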

In [None]:
# Analyze network predictions on test systems
print("Network Prediction Analysis on Test Systems")
print("="*60)

t = np.linspace(0, 50, 5000)
dt = t[1] - t[0]

for SystemClass in TEST_SYSTEMS_2D:
    system = SystemClass()
    true_structure = np.asarray(system.get_true_structure(term_names)).astype(bool)  # boolean mask for the indexing below
    
    # Generate a test trajectory
    x0 = np.array([1.0, 0.0])
    x = system.generate_trajectory(x0, t)
    x_trim = x[100:-100]
    
    # Get network prediction
    probs = predictor.predict_from_trajectory(x_trim, dt)
    
    print(f"\n{system.name}:")
    print(f"  True structure (eq1): {[term_names[i] for i in range(len(term_names)) if true_structure[0, i]]}")
    print(f"  True structure (eq2): {[term_names[i] for i in range(len(term_names)) if true_structure[1, i]]}")
    print(f"  Predicted probs (eq1): {probs[0].round(3)}")
    print(f"  Predicted probs (eq2): {probs[1].round(3)}")
    
    # Highlight predictions for true terms
    print(f"  Probs for TRUE terms (eq1): {probs[0][true_structure[0]].round(3)}")
    print(f"  Probs for TRUE terms (eq2): {probs[1][true_structure[1]].round(3)}")

## 8. Comparison: Oracle vs Learned

Let's compare the results with oracle access (cheating) vs learned predictions (fair).
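
The oracle in the next cell is built by mapping the true structure to probabilities of about 0.95 for active terms and 0.05 for inactive ones. The sketch below shows why this amounts to handing over the answer: any threshold between those two values selects exactly the true terms. The "learned" probabilities here are made-up numbers for contrast, not network output.

```python
import numpy as np

# Made-up true structure: 2 equations, 4 library terms.
true_structure = np.array([[False, True, False, False],
                           [True, False, False, True]])

# Same construction as the oracle in this notebook: ~0.95 for true terms, ~0.05 otherwise.
oracle_probs = true_structure.astype(float) * 0.9 + 0.05
oracle_selected = oracle_probs >= 0.3
print(np.array_equal(oracle_selected, true_structure))  # True

# Hypothetical learned probabilities with one spurious term and one missed term.
learned_probs = np.array([[0.20, 0.70, 0.45, 0.10],
                          [0.25, 0.10, 0.05, 0.90]])
learned_selected = learned_probs >= 0.3
n_mismatches = int(np.sum(learned_selected != true_structure))
print(n_mismatches)  # 2
```

The gap between the oracle and learned F1 scores reported below is therefore a measure of how much the network's structure errors cost, not of how good the regression step is.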

In [None]:
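# Compare oracle vs learned on one test system
# Time grid redefined here so the cell does not rely on variables from the previous cell
t = np.linspace(0, 50, 5000)
dt = t[1] - t[0]

system = TEST_SYSTEMS_2D[0]()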
print(f"Comparing Oracle vs Learned on {system.name}")
print("="*60)

true_xi = system.get_true_coefficients(term_names)
true_structure = np.abs(true_xi) > 1e-6

oracle_f1s, learned_f1s, std_f1s = [], [], []

for _ in range(30):
    x0 = np.random.randn(2) * 2
    x = system.generate_trajectory(x0, t, noise_level=0.05)
    
    if np.any(np.isnan(x)) or np.any(np.isinf(x)):
        continue  # skip diverged trajectories, matching the check in evaluate_fair
    
    x_trim = x[100:-100]
    x_dot = compute_derivatives_finite_diff(x_trim, dt)
    Theta, _ = build_library_2d(x_trim)
    
    # Standard SINDy
    xi_std, _ = sindy_stls(Theta, x_dot, threshold=0.1)
    
    # Oracle (CHEATING - uses ground truth)
    oracle_probs = true_structure.astype(float) * 0.9 + 0.05
    xi_oracle, _ = sindy_structure_constrained(Theta, x_dot, oracle_probs, structure_threshold=0.3)
    
    # Learned (FAIR - uses network prediction)
    learned_probs = predictor.predict_from_trajectory(x_trim, dt)
    xi_learned, _ = sindy_structure_constrained(Theta, x_dot, learned_probs, structure_threshold=0.3)
    
    std_f1s.append(compute_structure_metrics(xi_std, true_xi)['f1'])
    oracle_f1s.append(compute_structure_metrics(xi_oracle, true_xi)['f1'])
    learned_f1s.append(compute_structure_metrics(xi_learned, true_xi)['f1'])

print(f"\nResults ({len(std_f1s)} valid trials, 5% noise):")
print(f"  Standard SINDy:     F1 = {np.mean(std_f1s):.3f} +/- {np.std(std_f1s):.3f}")
print(f"  SC-SINDy (Oracle):  F1 = {np.mean(oracle_f1s):.3f} +/- {np.std(oracle_f1s):.3f}  <- CHEATING!")
print(f"  SC-SINDy (Learned): F1 = {np.mean(learned_f1s):.3f} +/- {np.std(learned_f1s):.3f}  <- FAIR")
print(f"\nOracle improvement: {np.mean(oracle_f1s) - np.mean(std_f1s):+.3f} (not publishable)")
print(f"Learned improvement: {np.mean(learned_f1s) - np.mean(std_f1s):+.3f} (publishable!)")

## Summary

This notebook demonstrated **fair evaluation** of SC-SINDy:

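1. **Train/test split at system level** - Test systems are never seen during training
2. **Learned structure predictions** - The network's output replaces oracle/ground-truth structure
3. **Realistic performance** - Reported F1 reflects generalization to unseen systems, which is the number worth publishing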

### Key Takeaways

- Oracle testing (using ground truth) gives overly optimistic results
- Fair evaluation with learned predictions shows the improvement you can actually expect, which may be small or even negative
- The network must generalize to unseen systems to be useful
- More diverse training data should improve generalization, although this notebook does not test that directly