# Example 1: Basic Fault Tolerance with Streaming Optimizer

This example demonstrates the basic usage of the streaming optimizer with
fault tolerance enabled (default behavior).

Features demonstrated:
- Automatic best parameter tracking
- NaN/Inf detection at three validation points
- Adaptive retry strategies for failed batches
- Success rate validation
- Detailed diagnostics

Run this example:
    python examples/streaming/01_basic_fault_tolerance.py


In [1]:
import jax.numpy as jnp
import numpy as np

from nlsq import StreamingConfig, StreamingOptimizer


def exponential_decay(x, a, b):
    """Evaluate the exponential decay model ``y = a * exp(-b * x)``.

    Args:
        x: Point(s) at which to evaluate the model; forwarded to ``jnp.exp``.
        a: Multiplicative amplitude applied to the exponential term.
        b: Decay rate; it is negated and multiplied by ``x`` inside the
            exponent.

    Returns:
        The model value(s) ``a * exp(-b * x)`` as produced by ``jax.numpy``.
    """
    decay_factor = jnp.exp(-b * x)
    return a * decay_factor


def main():
    """Run the basic fault-tolerance walkthrough and print a report.

    The steps are:

    1. Build a seeded synthetic data set from ``exponential_decay`` plus
       Gaussian noise.
    2. Create a ``StreamingConfig`` with the fault-tolerance and checkpoint
       options spelled out explicitly, and echo the main settings.
    3. Fit the model with ``StreamingOptimizer.fit`` starting from ``p0``.
    4. Print the fitted parameters, the streaming diagnostics, the last few
       per-batch records, and the checkpoint information (when present).

    Everything is written to stdout via ``print``; the function returns
    ``None``.
    """
    banner = "=" * 70
    divider = "-" * 70

    print(banner)
    print("Streaming Optimizer: Basic Fault Tolerance Example")
    print(banner)
    print()

    # Synthetic data: seed NumPy's global RNG so the noise is reproducible.
    np.random.seed(42)
    n_samples = 10000
    true_a, true_b = 2.5, 0.3
    x_data = np.linspace(0, 10, n_samples)
    noise_free = exponential_decay(x_data, true_a, true_b)
    gaussian_noise = 0.1 * np.random.randn(n_samples)
    y_data = noise_free + gaussian_noise

    print(f"Dataset: {n_samples} samples")
    print(f"True parameters: a={true_a}, b={true_b}")
    print()

    # Fault tolerance settings, written out explicitly for illustration.
    fault_tolerance_options = {
        "enable_fault_tolerance": True,  # turn the fault tolerance features on
        "validate_numerics": True,  # check for NaN/Inf
        "min_success_rate": 0.5,  # require 50% of batches to succeed
        "max_retries_per_batch": 2,  # at most 2 retry attempts per batch
    }
    # Checkpoint settings.
    checkpoint_options = {
        "checkpoint_dir": "checkpoints",
        "checkpoint_frequency": 100,  # save every 100 iterations
        "enable_checkpoints": True,
    }
    # NOTE(review): the run recorded alongside this example finished with
    # a~1.19, b~0.13 and "Convergence achieved: False", so these step-size /
    # epoch settings may be too conservative to reach the true values --
    # confirm whether that is intended.
    cfg = StreamingConfig(
        batch_size=100,
        max_epochs=10,
        learning_rate=0.001,
        **fault_tolerance_options,
        **checkpoint_options,
    )

    print("Configuration:")
    print(f"  Batch size: {cfg.batch_size}")
    print(f"  Max epochs: {cfg.max_epochs}")
    print(f"  Learning rate: {cfg.learning_rate}")
    print(f"  Fault tolerance: {cfg.enable_fault_tolerance}")
    print(f"  Validate numerics: {cfg.validate_numerics}")
    print(f"  Min success rate: {cfg.min_success_rate:.0%}")
    print(f"  Max retries per batch: {cfg.max_retries_per_batch}")
    print()

    streamer = StreamingOptimizer(cfg)

    # Start well away from the true parameters on purpose.
    p0 = np.array([1.0, 0.1])
    print(f"Initial guess: a={p0[0]}, b={p0[1]}")
    print()

    # Arguments to fit(): data as an (x, y) tuple, the model function, the
    # initial parameters, and verbose=1 to show progress.
    print("Starting optimization...")
    print(divider)
    training_data = (x_data, y_data)
    fit_result = streamer.fit(training_data, exponential_decay, p0, verbose=1)
    print(divider)
    print()

    # Pull every field needed for the report out of the result up front.
    result_keys = ("x", "success", "message", "best_loss", "streaming_diagnostics")
    fitted, succeeded, status_message, lowest_loss, diag = (
        fit_result[key] for key in result_keys
    )

    print("RESULTS")
    print(banner)
    print(f"Success: {succeeded}")
    print(f"Message: {status_message}")
    print()
    print("Best parameters found:")
    print(f"  a = {fitted[0]:.6f} (true: {true_a})")
    print(f"  b = {fitted[1]:.6f} (true: {true_b})")
    print(f"  Best loss = {lowest_loss:.6e}")
    print()

    print("DIAGNOSTICS")
    print(banner)
    print(f"Batch success rate: {diag['batch_success_rate']:.1%}")
    print(f"Total batches attempted: {diag['total_batches_attempted']}")
    print(f"Total retries: {diag['total_retries']}")
    print(f"Convergence achieved: {diag['convergence_achieved']}")
    print(f"Final epoch: {diag['final_epoch']}")
    print(f"Elapsed time: {diag['elapsed_time']:.2f}s")
    print()

    # Only report failed batches when there are any.
    failed = diag["failed_batches"]
    if failed:
        print(f"Failed batches ({len(failed)}):")
        print(f"  Indices: {failed}")
        print(f"  Error types: {diag['error_types']}")
        print()

    summary = diag["aggregate_stats"]
    print("Aggregate Statistics (from batch buffer):")
    print(f"  Mean loss: {summary['mean_loss']:.6e}")
    print(f"  Std loss: {summary['std_loss']:.6e}")
    print(f"  Mean gradient norm: {summary['mean_grad_norm']:.6f}")
    # The stored batch time is scaled by 1000 for display in milliseconds.
    print(f"  Mean batch time: {summary['mean_batch_time'] * 1000:.2f}ms")
    print()

    recent = diag["recent_batch_stats"]
    if recent:
        print(f"Recent batch statistics (last {len(recent)} batches):")
        # Keep the listing short: only the final five records are printed.
        for entry in recent[-5:]:
            if entry["success"]:
                outcome = "SUCCESS"
            else:
                outcome = "FAILED"
            retries = entry["retry_count"]
            if retries > 0:
                retry_suffix = f" (retries: {retries})"
            else:
                retry_suffix = ""
            print(
                f"  Batch {entry['batch_idx']}: {outcome}, loss={entry['loss']:.6e}{retry_suffix}"
            )
        print()

    # Checkpoint details are printed only when the diagnostics carry them.
    ckpt = diag["checkpoint_info"]
    if ckpt:
        print("Checkpoint Information:")
        print(f"  Path: {ckpt['path']}")
        print(f"  Saved at: {ckpt['saved_at']}")
        print(f"  Batch index: {ckpt['batch_idx']}")
        print()

    print(banner)
    print("Example complete!")
    print()
    print("Key takeaways:")
    takeaway_lines = (
        "  - Fault tolerance enabled by default (no configuration needed)",
        "  - Best parameters always returned (never initial p0)",
        "  - NaN/Inf detection at three validation points",
        "  - Adaptive retry strategies for failed batches",
        "  - Comprehensive diagnostics for analysis",
        "  - Checkpoints saved automatically for recovery",
    )
    for takeaway in takeaway_lines:
        print(takeaway)

# Run the example only when this file is executed as a script (see the
# "Run this example" command in the header), not when it is imported.
if __name__ == "__main__":
    main()

INFO:2025-11-17 16:53:23,458:jax._src.xla_bridge:808: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Starting streaming optimization with batch_size=100


Using Adam optimizer



Epoch 1/10


Streaming Optimizer: Basic Fault Tolerance Example

Dataset: 10000 samples
True parameters: a=2.5, b=0.3

Configuration:
  Batch size: 100
  Max epochs: 10
  Learning rate: 0.001
  Fault tolerance: True
  Validate numerics: True
  Min success rate: 50%
  Max retries per batch: 2

Initial guess: a=1.0, b=0.1

Starting optimization...
----------------------------------------------------------------------


Epoch 1 complete: avg_loss=0.002516, samples=10000



Epoch 2/10


Warmup complete. Batch padding enabled (max_shape=100)


Epoch 2 complete: avg_loss=0.002582, samples=10000



Epoch 3/10


Epoch 3 complete: avg_loss=0.002499, samples=10000



Epoch 4/10


Epoch 4 complete: avg_loss=0.002369, samples=10000



Epoch 5/10


Epoch 5 complete: avg_loss=0.002233, samples=10000



Epoch 6/10


Epoch 6 complete: avg_loss=0.002101, samples=10000



Epoch 7/10


Epoch 7 complete: avg_loss=0.001974, samples=10000



Epoch 8/10


Epoch 8 complete: avg_loss=0.001853, samples=10000



Epoch 9/10


Epoch 9 complete: avg_loss=0.001736, samples=10000



Epoch 10/10


Epoch 10 complete: avg_loss=0.001626, samples=10000


Optimization complete: 1000/1000 batches succeeded (100.0%)


----------------------------------------------------------------------

RESULTS
Success: True
Message: Optimization complete: 1000/1000 batches succeeded (100.0%)

Best parameters found:
  a = 1.189291 (true: 2.5)
  b = 0.134749 (true: 0.3)
  Best loss = 7.611412e-03

DIAGNOSTICS
Batch success rate: 100.0%
Total batches attempted: 1000
Total retries: 0
Convergence achieved: False
Final epoch: 9
Elapsed time: 0.56s

Aggregate Statistics (from batch buffer):
  Mean loss: 1.625896e-01
  Std loss: 3.034042e-01
  Mean gradient norm: 0.948855
  Mean batch time: 0.44ms

Recent batch statistics (last 100 batches):
  Batch 95: SUCCESS, loss=1.803398e-02
  Batch 96: SUCCESS, loss=1.725470e-02
  Batch 97: SUCCESS, loss=1.732153e-02
  Batch 98: SUCCESS, loss=1.525652e-02
  Batch 99: SUCCESS, loss=1.644096e-02

Checkpoint Information:
  Path: checkpoints/checkpoint_iter_1000.h5
  Saved at: 2025-11-17T16:53:24.335012
  Batch index: 99

Example complete!

Key takeaways:
  - Fault tolerance enabled by