# 04 Workflow Presets

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/04_workflow_presets.ipynb)

Features demonstrated:
- All entries in the WORKFLOW_PRESETS dictionary
- Using presets for common fitting scenarios
- Inspecting preset configurations
- Comparing preset performance

Run this example:
    python examples/scripts/08_workflow_system/04_workflow_presets.py

In [1]:
# @title Install NLSQ (run once in Colab)
import sys

if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    !pip install -q nlsq
    print("NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

Not running in Colab - assuming NLSQ is already installed


In [2]:
import time
from pathlib import Path
from pprint import pprint

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import HybridStreamingConfig, fit
from nlsq.core.minpack import WORKFLOW_PRESETS

FIG_DIR = Path.cwd() / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
def exponential_model(x, a, b, c):
    """Exponential decay: y = a * exp(-b * x) + c"""
    return a * jnp.exp(-b * x) + c


def main():
    print("=" * 70)
    print("WORKFLOW_PRESETS Guide")
    print("=" * 70)
    print()

    np.random.seed(42)

    # =========================================================================
    # 1. Available Presets
    # =========================================================================
    print("1. Available WORKFLOW_PRESETS:")
    print("-" * 60)

    for preset_name in WORKFLOW_PRESETS:
        description = WORKFLOW_PRESETS[preset_name].get("description", "No description")
        print(f"  {preset_name:<20} - {description}")

    # =========================================================================
    # 2. Inspecting Presets
    # =========================================================================
    print()
    print("2. Inspecting Presets:")
    print("-" * 60)

    presets_to_show = ["standard", "quality", "fast", "streaming"]

    for preset_name in presets_to_show:
        if preset_name in WORKFLOW_PRESETS:
            print(f"\n  '{preset_name}' preset:")
            for key, value in WORKFLOW_PRESETS[preset_name].items():
                print(f"    {key}: {value}")

    # =========================================================================
    # 3. Testing Presets
    # =========================================================================
    print()
    print("3. Testing Presets on Exponential Decay:")
    print("-" * 70)

    n_samples = 500
    x_data = np.linspace(0, 5, n_samples)

    true_a, true_b, true_c = 3.0, 1.2, 0.5
    y_true = true_a * np.exp(-true_b * x_data) + true_c
    noise = 0.15 * np.random.randn(n_samples)
    y_data = y_true + noise

    p0 = [1.0, 0.5, 0.0]
    bounds = ([0.1, 0.1, -1.0], [10.0, 5.0, 2.0])

    print(f"  True parameters: a={true_a}, b={true_b}, c={true_c}")

    presets_to_test = ["fast", "standard", "quality"]
    results = {}

    for preset_name in presets_to_test:
        start_time = time.time()

        popt, pcov = fit(
            exponential_model,
            x_data,
            y_data,
            p0=p0,
            bounds=bounds,
            workflow=preset_name,
        )

        elapsed = time.time() - start_time

        y_pred = exponential_model(x_data, *popt)
        ssr = float(jnp.sum((y_data - y_pred) ** 2))

        results[preset_name] = {
            "popt": popt,
            "ssr": ssr,
            "time": elapsed,
        }

        print(f"\n  {preset_name.upper()}:")
        print(f"    Time:       {elapsed:.4f}s")
        print(f"    SSR:        {ssr:.6f}")
        print(f"    Parameters: a={popt[0]:.4f}, b={popt[1]:.4f}, c={popt[2]:.4f}")

    # =========================================================================
    # 4. Using fit() with Workflow Presets
    # =========================================================================
    print()
    print("4. Using fit() with Workflow Presets:")
    print("-" * 50)
    print()
    print("  Basic usage:")
    print("    popt, pcov = fit(model, x, y, workflow='quality')")
    print()
    print("  With additional parameters:")
    print("    popt, pcov = fit(")
    print("        model, x, y,")
    print("        workflow='standard',")
    print("        multistart=True,      # Override preset setting")
    print("        n_starts=20,          # Custom number of starts")
    print("        sampler='sobol',      # Different sampler")
    print("    )")

    # =========================================================================
    # 5. Defense Layer Presets (Streaming)
    # =========================================================================
    print()
    print("5. Defense Layer Presets for Streaming:")
    print("-" * 70)
    print()
    print("For streaming workflows, HybridStreamingConfig provides defense presets")
    print("that protect against L-BFGS warmup divergence:")
    print()

    defense_presets = {
        "defense_strict": {
            "method": "HybridStreamingConfig.defense_strict()",
            "use_case": "Warm-start refinement (checkpoint resume)",
        },
        "defense_relaxed": {
            "method": "HybridStreamingConfig.defense_relaxed()",
            "use_case": "Exploration (rough initial guesses)",
        },
        "scientific_default": {
            "method": "HybridStreamingConfig.scientific_default()",
            "use_case": "Production scientific computing",
        },
    }

    for name, info in defense_presets.items():
        print(f"  {name.upper()}:")
        print(f"    Method:   {info['method']}")
        print(f"    Use case: {info['use_case']}")
        print()

    # =========================================================================
    # 6. Visualization
    # =========================================================================
    print("6. Saving visualizations...")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    preset_names = list(results.keys())
    colors = {"fast": "blue", "standard": "green", "quality": "red"}

    ax1 = axes[0]
    ssrs = [results[p]["ssr"] for p in preset_names]
    bars = ax1.bar(preset_names, ssrs, color=[colors[p] for p in preset_names])
    ax1.set_xlabel("Preset")
    ax1.set_ylabel("Sum of Squared Residuals")
    ax1.set_title("Fit Quality by Preset")
    for bar, ssr in zip(bars, ssrs, strict=False):
        ax1.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{ssr:.4f}",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    ax2 = axes[1]
    times = [results[p]["time"] for p in preset_names]
    bars = ax2.bar(preset_names, times, color=[colors[p] for p in preset_names])
    ax2.set_xlabel("Preset")
    ax2.set_ylabel("Time (seconds)")
    ax2.set_title("Computation Time by Preset")
    for bar, t in zip(bars, times, strict=False):
        ax2.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{t:.3f}s",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    ax3 = axes[2]
    tols = [WORKFLOW_PRESETS[p].get("gtol", 1e-8) for p in preset_names]
    bars = ax3.bar(preset_names, tols, color=[colors[p] for p in preset_names])
    ax3.set_xlabel("Preset")
    ax3.set_ylabel("gtol")
    ax3.set_title("Tolerance (gtol) by Preset")
    ax3.set_yscale("log")
    for bar, t in zip(bars, tols, strict=False):
        ax3.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{t:.0e}",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    plt.tight_layout()
    plt.savefig(FIG_DIR / "04_preset_comparison.png", dpi=300, bbox_inches="tight")
    plt.close()
    print(f"  Saved: {FIG_DIR / '04_preset_comparison.png'}")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print("=" * 70)
    print("Summary")
    print("=" * 70)
    print()
    print("Available presets:")
    for name in WORKFLOW_PRESETS:
        desc = WORKFLOW_PRESETS[name].get("description", "")
        print(f"  - {name}: {desc}")
    print()
    print("Quick usage:")
    print("  fit(model, x, y, workflow='quality')")
    print("  fit(model, x, y, workflow='standard', multistart=True, n_starts=20)")
    print()
    print("Defense presets for streaming:")
    print("  HybridStreamingConfig.defense_strict()     # Checkpoint resume")
    print("  HybridStreamingConfig.defense_relaxed()    # Exploration")
    print("  HybridStreamingConfig.scientific_default() # Production")

In [4]:
if __name__ == "__main__":
    main()

INFO:nlsq.curve_fit:Starting curve fit n_params=3 | n_data_points=500 | method=trf | solver=auto | batch_size=None | has_bounds=True | dynamic_sizing=False


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-08 | xtol=1.0000e-08 | gtol=1.0000e-08


WORKFLOW_PRESETS Guide

1. Available WORKFLOW_PRESETS:
------------------------------------------------------------
  standard             - Standard curve_fit() with default tolerances
  quality              - Highest precision with multi-start and tighter tolerances
  fast                 - Speed-optimized with looser tolerances
  large_robust         - Chunked processing with multi-start for large datasets
  streaming            - AdaptiveHybridStreamingOptimizer for huge datasets
  hpc_distributed      - Multi-GPU/node configuration for HPC clusters

2. Inspecting Presets:
------------------------------------------------------------

  'standard' preset:
    description: Standard curve_fit() with default tolerances
    tier: STANDARD
    enable_multistart: False
    gtol: 1e-08
    ftol: 1e-08
    xtol: 1e-08

  'quality' preset:
    description: Highest precision with multi-start and tighter tolerances
    tier: STANDARD
    enable_multistart: True
    n_starts: 20
    gtol: 1e-10

PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=1.992605s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=6 | final_cost=5.3888 | elapsed=1.993s | final_gradient_norm=1.2824e-05


PERFORMANCE:nlsq.curve_fit:Timer: curve_fit elapsed=2.785487s




INFO:nlsq.curve_fit:Starting curve fit n_params=3 | n_data_points=500 | method=trf | solver=auto | batch_size=None | has_bounds=True | dynamic_sizing=False


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-08 | xtol=1.0000e-08 | gtol=1.0000e-08



  FAST:
    Time:       2.9602s
    SSR:        10.777621
    Parameters: a=2.9620, b=1.1921, c=0.5042


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.202000s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=6 | final_cost=5.3888 | elapsed=0.202s | final_gradient_norm=1.2824e-05


PERFORMANCE:nlsq.curve_fit:Timer: curve_fit elapsed=0.460101s




INFO:nlsq.curve_fit:Starting curve fit n_params=3 | n_data_points=500 | method=trf | solver=auto | batch_size=None | has_bounds=True | dynamic_sizing=False


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-08 | xtol=1.0000e-08 | gtol=1.0000e-08



  STANDARD:
    Time:       0.5319s
    SSR:        10.777621
    Parameters: a=2.9620, b=1.1921, c=0.5042


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.254399s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=6 | final_cost=5.3888 | elapsed=0.254s | final_gradient_norm=1.2824e-05


PERFORMANCE:nlsq.curve_fit:Timer: curve_fit elapsed=0.519984s





  QUALITY:
    Time:       0.5899s
    SSR:        10.777621
    Parameters: a=2.9620, b=1.1921, c=0.5042

4. Using fit() with Workflow Presets:
--------------------------------------------------

  Basic usage:
    popt, pcov = fit(model, x, y, workflow='quality')

  With additional parameters:
    popt, pcov = fit(
        model, x, y,
        workflow='standard',
        multistart=True,      # Override preset setting
        n_starts=20,          # Custom number of starts
        sampler='sobol',      # Different sampler
    )

5. Defense Layer Presets for Streaming:
----------------------------------------------------------------------

For streaming workflows, HybridStreamingConfig provides defense presets
that protect against L-BFGS warmup divergence:

  DEFENSE_STRICT:
    Method:   HybridStreamingConfig.defense_strict()
    Use case: Warm-start refinement (checkpoint resume)

  DEFENSE_RELAXED:
    Method:   HybridStreamingConfig.defense_relaxed()
    Use case: Exploration (r

  Saved: /home/wei/Documents/GitHub/NLSQ/examples/notebooks/08_workflow_system/figures/04_preset_comparison.png

Summary

Available presets:
  - standard: Standard curve_fit() with default tolerances
  - quality: Highest precision with multi-start and tighter tolerances
  - fast: Speed-optimized with looser tolerances
  - large_robust: Chunked processing with multi-start for large datasets
  - streaming: AdaptiveHybridStreamingOptimizer for huge datasets
  - hpc_distributed: Multi-GPU/node configuration for HPC clusters

Quick usage:
  fit(model, x, y, workflow='quality')
  fit(model, x, y, workflow='standard', multistart=True, n_starts=20)

Defense presets for streaming:
  HybridStreamingConfig.defense_strict()     # Checkpoint resume
  HybridStreamingConfig.defense_relaxed()    # Exploration
  HybridStreamingConfig.scientific_default() # Production
