# Example 2: Checkpoint Save and Resume

This example demonstrates checkpoint save/resume functionality for recovering
from interruptions during long-running optimizations.

Features demonstrated:
- Automatic checkpoint saving at intervals
- Auto-detection of latest checkpoint
- Resume from specific checkpoint path
- Full optimizer state preservation

Run this example:
    python examples/streaming/02_checkpoint_resume.py


In [1]:
import time
from pathlib import Path

import jax.numpy as jnp
import numpy as np

from nlsq import StreamingConfig, StreamingOptimizer


def gaussian_model(x, amp, center, width):
    """Gaussian model: y = amp * exp(-0.5 * ((x - center) / width)^2)"""
    return amp * jnp.exp(-0.5 * ((x - center) / width) ** 2)


def simulate_interruption(iteration, params, loss):
    """Callback to simulate interruption after 5 iterations"""
    if iteration == 5:
        print(f"\n  [SIMULATED INTERRUPTION at iteration {iteration}]")
        return False  # Stop optimization
    return True


def main():
    print("=" * 70)
    print("Streaming Optimizer: Checkpoint Save/Resume Example")
    print("=" * 70)
    print()

    # Generate synthetic data
    np.random.seed(42)
    n_samples = 5000
    x_data = np.linspace(-5, 5, n_samples)
    true_amp, true_center, true_width = 2.0, 0.5, 1.5
    y_true = gaussian_model(x_data, true_amp, true_center, true_width)
    y_data = y_true + 0.05 * np.random.randn(n_samples)

    print(f"Dataset: {n_samples} samples")
    print(f"True parameters: amp={true_amp}, center={true_center}, width={true_width}")
    print()

    # Clean up old checkpoints
    checkpoint_dir = Path("checkpoints_example")
    if checkpoint_dir.exists():
        for f in checkpoint_dir.glob("checkpoint_*.h5"):
            f.unlink()
        print(f"Cleaned up old checkpoints in {checkpoint_dir}")
        print()

    # Part 1: Initial training with interruption
    print("PART 1: Initial Training (will be interrupted)")
    print("=" * 70)

    config = StreamingConfig(
        batch_size=100,
        max_epochs=10,
        learning_rate=0.001,
        checkpoint_dir=str(checkpoint_dir),
        checkpoint_frequency=2,  # Save every 2 iterations (frequent for demo)
        enable_checkpoints=True,
        resume_from_checkpoint=None,  # Don't resume (start fresh)
    )

    print(f"Checkpoint directory: {config.checkpoint_dir}")
    print(f"Checkpoint frequency: every {config.checkpoint_frequency} iterations")
    print()

    optimizer = StreamingOptimizer(config)
    p0 = np.array([1.0, 0.0, 1.0])
    print(f"Initial guess: amp={p0[0]}, center={p0[1]}, width={p0[2]}")
    print()

    print("Starting training (will interrupt after 5 iterations)...")
    result1 = optimizer.fit(
        (x_data, y_data),
        gaussian_model,
        p0,
        callback=simulate_interruption,  # Simulate interruption
        verbose=1,
    )

    print()
    print("Training interrupted!")
    print(f"Iterations completed: {optimizer.iteration}")
    print(f"Best loss so far: {result1['best_loss']:.6e}")
    print(f"Best params so far: {result1['x']}")
    print()

    # Check saved checkpoints
    checkpoints = list(checkpoint_dir.glob("checkpoint_iter_*.h5"))
    print(f"Checkpoints saved: {len(checkpoints)}")
    for cp in sorted(checkpoints):
        print(f"  - {cp.name}")
    print()

    # Part 2: Resume from checkpoint (auto-detect latest)
    print("PART 2: Resume from Checkpoint (auto-detect)")
    print("=" * 70)

    config_resume = StreamingConfig(
        batch_size=100,
        max_epochs=10,
        learning_rate=0.001,
        checkpoint_dir=str(checkpoint_dir),
        checkpoint_frequency=2,
        enable_checkpoints=True,
        resume_from_checkpoint=True,  # Auto-detect latest checkpoint
    )

    print("Resuming with auto-detection of latest checkpoint...")
    print()

    optimizer2 = StreamingOptimizer(config_resume)
    result2 = optimizer2.fit(
        (x_data, y_data),
        gaussian_model,
        p0,  # Still provide p0 (used if checkpoint load fails)
        verbose=1,
    )

    print()
    print("Training resumed and completed!")
    print()

    # Part 3: Resume from specific checkpoint path
    print("PART 3: Resume from Specific Checkpoint")
    print("=" * 70)

    # Find a specific checkpoint (e.g., iteration 4)
    specific_checkpoint = checkpoint_dir / "checkpoint_iter_4.h5"
    if specific_checkpoint.exists():
        print(f"Resuming from specific checkpoint: {specific_checkpoint.name}")
        print()

        config_specific = StreamingConfig(
            batch_size=100,
            max_epochs=10,
            learning_rate=0.001,
            checkpoint_dir=str(checkpoint_dir),
            checkpoint_frequency=2,
            enable_checkpoints=True,
            resume_from_checkpoint=str(specific_checkpoint),  # Specific path
        )

        optimizer3 = StreamingOptimizer(config_specific)
        result3 = optimizer3.fit(
            (x_data, y_data),
            gaussian_model,
            p0,
            verbose=1,
        )

        print()
        print(
            f"Resumed from iteration 4, completed at iteration {optimizer3.iteration}"
        )
        print()

    # Display final results
    print("FINAL RESULTS")
    print("=" * 70)
    best_params = result2["x"]
    print("Best parameters:")
    print(f"  amp    = {best_params[0]:.6f} (true: {true_amp})")
    print(f"  center = {best_params[1]:.6f} (true: {true_center})")
    print(f"  width  = {best_params[2]:.6f} (true: {true_width})")
    print(f"  Best loss = {result2['best_loss']:.6e}")
    print()

    # Checkpoint diagnostics
    diag = result2["streaming_diagnostics"]
    if diag["checkpoint_info"]:
        cp_info = diag["checkpoint_info"]
        print("Final Checkpoint:")
        print(f"  Path: {cp_info['path']}")
        print(f"  Saved at: {cp_info['saved_at']}")
        print(f"  Batch index: {cp_info['batch_idx']}")
        print()

    print("=" * 70)
    print("Example complete!")
    print()
    print("Key takeaways:")
    print("  - Checkpoints save full optimizer state (params, momentum, etc.)")
    print("  - resume_from_checkpoint=True auto-detects latest checkpoint")
    print("  - resume_from_checkpoint='path' loads specific checkpoint")
    print("  - Seamless resume from any interruption point")
    print("  - No duplicate batch processing on resume")
    print(f"\nCheckpoints saved in: {checkpoint_dir.absolute()}")


if __name__ == "__main__":
    main()

INFO:2025-11-17 16:53:28,622:jax._src.xla_bridge:808: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Starting streaming optimization with batch_size=100


Using Adam optimizer



Epoch 1/10


Streaming Optimizer: Checkpoint Save/Resume Example

Dataset: 5000 samples
True parameters: amp=2.0, center=0.5, width=1.5

PART 1: Initial Training (will be interrupted)
Checkpoint directory: checkpoints_example
Checkpoint frequency: every 100 iterations

Initial guess: amp=1.0, center=0.0, width=1.0

Starting training (will interrupt after 5 iterations)...


Optimization stopped by callback


Epoch 1 complete: avg_loss=0.000025, samples=500



Epoch 2/10


Epoch 2 complete: avg_loss=0.004370, samples=5000



Epoch 3/10


Warmup complete. Batch padding enabled (max_shape=100)


Epoch 3 complete: avg_loss=0.004230, samples=5000



Epoch 4/10


Epoch 4 complete: avg_loss=0.003978, samples=5000



Epoch 5/10


Epoch 5 complete: avg_loss=0.003712, samples=5000



Epoch 6/10


Epoch 6 complete: avg_loss=0.003451, samples=5000



Epoch 7/10


Epoch 7 complete: avg_loss=0.003195, samples=5000



Epoch 8/10


Epoch 8 complete: avg_loss=0.002947, samples=5000



Epoch 9/10



  [SIMULATED INTERRUPTION at iteration 5]


Epoch 9 complete: avg_loss=0.002706, samples=5000



Epoch 10/10


Epoch 10 complete: avg_loss=0.002475, samples=5000


Optimization complete: 455/455 batches succeeded (100.0%)


Loaded checkpoint from checkpoints_example/checkpoint_iter_455.h5 (iteration 455, epoch 9)


Starting streaming optimization with batch_size=100


Using Adam optimizer


Resuming from iteration 455



Epoch 10/10


Warmup complete. Batch padding enabled (max_shape=None)


Epoch 10 complete: avg_loss=0.000036, samples=100


Optimization complete: 1/1 batches succeeded (100.0%)



Training interrupted!
Iterations completed: 455
Best loss so far: 2.047768e-03
Best params so far: [9.99993671e-01 8.96047293e-06 9.99990227e-01]

Checkpoints saved: 5
  - checkpoint_iter_100.h5
  - checkpoint_iter_200.h5
  - checkpoint_iter_300.h5
  - checkpoint_iter_400.h5
  - checkpoint_iter_455.h5

PART 2: Resume from Checkpoint (auto-detect)
Resuming with auto-detection of latest checkpoint...


Training resumed and completed!

PART 3: Resume from Specific Checkpoint
FINAL RESULTS
Best parameters:
  amp    = 0.999994 (true: 2.0)
  center = 0.000009 (true: 0.5)
  width  = 0.999990 (true: 1.5)
  Best loss = 2.047768e-03

Final Checkpoint:
  Path: checkpoints_example/checkpoint_iter_456.h5
  Saved at: 2025-11-17T16:53:29.429060
  Batch index: 49

Example complete!

Key takeaways:
  - Checkpoints save full optimizer state (params, momentum, etc.)
  - resume_from_checkpoint=True auto-detects latest checkpoint
  - resume_from_checkpoint='path' loads specific checkpoint
  - Seamless res