# 01 Fit Quickstart (v0.6.3)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/01_fit_quickstart.ipynb)

**The Three Workflows (v0.6.3):**
- `workflow="auto"` : Memory-aware local optimization (bounds optional)
- `workflow="auto_global"` : Memory-aware global optimization (bounds required)
- `workflow="hpc"` : auto_global + checkpointing for HPC (bounds required)

Features demonstrated:
- Using fit() with automatic memory-based strategy selection
- Using workflow='auto_global' for global optimization
- Adjusting tolerances directly (not via presets)
- Comparing fit(), curve_fit(), and curve_fit_large()

Run this example:
    python examples/scripts/08_workflow_system/01_fit_quickstart.py

In [None]:
# @title Install NLSQ (run once in Colab)
# Notebook-only cell: the `!pip` line below is an IPython shell escape, so this
# cell is not valid plain Python and is meant to be run inside Jupyter/Colab.
import sys

# The `google.colab` module being present in sys.modules is used as the signal
# that the notebook is running on Google Colab.
if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    # -q keeps pip's output quiet.
    !pip install -q nlsq
    # NOTE(review): this message is printed unconditionally after the shell
    # escape; it presumably does not reflect pip's exit status - confirm.
    print("NLSQ installed successfully!")
else:
    # Outside Colab nothing is installed; the imports in the following cell
    # require NLSQ to be available already.
    print("Not running in Colab - assuming NLSQ is already installed")

In [None]:
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import curve_fit, curve_fit_large, fit

# Output directory for saved figures: "<current working directory>/figures".
# It is created up front (including missing parents) and re-running is
# harmless because an existing directory is accepted.
FIG_DIR = Path.cwd().joinpath("figures")
FIG_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
def exponential_decay(x, a, b, c):
    """Evaluate the exponential-decay model ``a * exp(-b * x) + c``.

    Parameters
    ----------
    x : scalar or array
        Point(s) at which the model is evaluated.
    a : scalar
        Multiplier applied to the exponential term.
    b : scalar
        Decay rate; it is negated and multiplied by ``x`` inside the exponent.
    c : scalar
        Constant offset added to the result.

    Returns
    -------
    The model value(s), computed with ``jax.numpy``.
    """
    decay_term = jnp.exp(-b * x)
    return a * decay_term + c


def _format_params(params):
    """Render the first three entries of *params* as ``a=..., b=..., c=...``.

    Each value is formatted with four decimal places. *params* only needs to
    support integer indexing (a tuple, list, or fitted-parameter array).
    """
    return f"a={params[0]:.4f}, b={params[1]:.4f}, c={params[2]:.4f}"


def main():
    """Run the fit() quickstart demo from start to finish.

    Steps, in order:

    1. Seed NumPy's global RNG with 42 and build a noisy 500-point
       exponential-decay dataset.
    2. Fit it with ``fit(..., workflow="auto")`` without bounds.
    3. Fit it with ``workflow="auto"`` and bounds.
    4. Fit it with ``workflow="auto_global"`` (bounds plus ``n_starts=5``).
    5. Fit it twice with explicit ``gtol``/``ftol``/``xtol`` values
       (1e-6 and 1e-10).
    6. Fit it with ``curve_fit()`` and ``curve_fit_large()`` for comparison.
    7. Plot data, true curve, fitted curve and residuals, and save the figure
       to ``FIG_DIR / "01_fit_result.png"``.

    Progress and fitted parameters are printed to stdout; nothing is returned.
    """
    banner = "=" * 70
    print(banner)
    print("Unified fit() Entry Point - Quickstart (v0.6.3)")
    print(banner)
    print()

    # Fixed seed so the synthetic noise (and therefore every fit) is
    # reproducible between runs.
    np.random.seed(42)

    # =========================================================================
    # 1. Generate synthetic data
    # =========================================================================
    print("1. Generating synthetic data...")

    n_samples = 500
    x_data = np.linspace(0, 5, n_samples)

    true_a, true_b, true_c = 3.0, 1.2, 0.5

    y_true = true_a * np.exp(-true_b * x_data) + true_c
    noise = 0.15 * np.random.randn(n_samples)
    y_data = y_true + noise

    print(f"  True parameters: a={true_a}, b={true_b}, c={true_c}")
    print(f"  Dataset size: {n_samples} points")

    # Every fit below starts from the same initial guess. It is kept as an
    # immutable tuple and converted to a fresh list per call, so each call
    # receives its own list exactly as if the literal were written inline.
    initial_guess = (1.0, 1.0, 0.0)

    # =========================================================================
    # 2. workflow='auto' - Local optimization with automatic memory strategy
    # =========================================================================
    print()
    print("2. workflow='auto' - Local optimization (default, bounds optional)...")

    # The second returned value (bound to ``_``) is not used in this demo.
    popt_auto, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        workflow="auto",  # Default: automatic memory-based strategy selection
    )

    print(f"  Fitted: {_format_params(popt_auto)}")
    print(f"  True:   {_format_params((true_a, true_b, true_c))}")

    # =========================================================================
    # 3. workflow='auto' with bounds
    # =========================================================================
    print()
    print("3. workflow='auto' with optional bounds...")

    # (lower, upper) limits for (a, b, c); reused by all bounded fits below.
    bounds = ([0.1, 0.1, -1.0], [10.0, 5.0, 2.0])

    popt_bounded, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
        workflow="auto",  # Bounds are optional for 'auto'
    )
    print(f"  Bounded fit: {_format_params(popt_bounded)}")

    # =========================================================================
    # 4. workflow='auto_global' - Global optimization
    # =========================================================================
    print()
    print("4. workflow='auto_global' - Global optimization (bounds required)...")
    print("   Automatically selects CMA-ES or Multi-Start based on parameter scales")

    popt_global, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
        workflow="auto_global",  # Bounds required for global optimization
        n_starts=5,  # Number of multi-start runs
    )
    print(f"  Global fit: {_format_params(popt_global)}")

    # =========================================================================
    # 5. Adjusting tolerances directly
    # =========================================================================
    print()
    print("5. Adjusting tolerances directly (not via presets)...")

    # Looser tolerances for speed
    popt_fast, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
        workflow="auto",
        gtol=1e-6,
        ftol=1e-6,
        xtol=1e-6,
    )
    print(f"  Fast (gtol=1e-6): {_format_params(popt_fast)}")

    # Tighter tolerances for precision
    popt_precise, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
        workflow="auto",
        gtol=1e-10,
        ftol=1e-10,
        xtol=1e-10,
    )
    print(f"  Precise (gtol=1e-10): {_format_params(popt_precise)}")

    # =========================================================================
    # 6. Comparison with curve_fit() and curve_fit_large()
    # =========================================================================
    print()
    print("6. Comparison with other APIs...")

    popt_cf, _ = curve_fit(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
    )
    print(f"  curve_fit():       {_format_params(popt_cf)}")

    popt_cfl, _ = curve_fit_large(
        exponential_decay,
        x_data,
        y_data,
        p0=list(initial_guess),
        bounds=bounds,
    )
    print(f"  curve_fit_large(): {_format_params(popt_cfl)}")

    # =========================================================================
    # 7. Visualization
    # =========================================================================
    print()
    print("7. Saving visualization...")

    # The plot uses the result of the unbounded workflow='auto' fit (step 2).
    y_pred = exponential_decay(x_data, *popt_auto)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: noisy data, ground-truth curve, fitted curve.
    ax1 = axes[0]
    ax1.scatter(x_data, y_data, alpha=0.4, s=10, label="Data")
    ax1.plot(x_data, y_true, "k--", linewidth=2, label="True function")
    ax1.plot(x_data, y_pred, "r-", linewidth=2, label="fit() result")
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_title("Exponential Decay Fit")
    ax1.legend()

    # Right panel: residuals (data minus fitted curve) around a zero line.
    ax2 = axes[1]
    residuals = y_data - y_pred
    ax2.scatter(x_data, residuals, alpha=0.5, s=10)
    ax2.axhline(y=0, color="k", linestyle="--", alpha=0.5)
    ax2.set_xlabel("x")
    ax2.set_ylabel("Residual")
    ax2.set_title("Residuals")

    # Build the output path once so the saved file and the printed message
    # cannot drift apart.
    fig_path = FIG_DIR / "01_fit_result.png"

    plt.tight_layout()
    plt.savefig(fig_path, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"  Saved: {fig_path}")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print(banner)
    print("Summary - The Three Workflows (v0.6.3)")
    print(banner)
    print()
    print("Workflows:")
    print("  workflow='auto'        : Local optimization, bounds optional")
    print("                           Auto-selects: STANDARD / CHUNKED / STREAMING")
    print()
    print("  workflow='auto_global' : Global optimization, bounds required")
    print("                           Auto-selects: CMA-ES or Multi-Start")
    print()
    print("  workflow='hpc'         : auto_global + checkpointing for HPC")
    print()
    print("Tolerance control (set directly, not via presets):")
    print("  gtol, ftol, xtol=1e-6  : Fast fitting, looser tolerances")
    print("  gtol, ftol, xtol=1e-10 : High precision fitting")

In [None]:
# Run the demo only when this code executes as the main module (which includes
# running the cell in a notebook), not when it is imported from elsewhere.
if __name__ == "__main__":
    main()