# Multi-Start Integration with curve_fit() Workflows

This tutorial demonstrates how to integrate multi-start optimization with
`curve_fit()` and `curve_fit_large()` workflows for practical applications.

**Features demonstrated:**
- Integration with `curve_fit()` workflows
- Bounds handling with multi-start
- Combining with `curve_fit_large()` for large datasets
- Practical workflow examples

**Level: Intermediate** | **Duration: 25 minutes**

In [None]:
# Configure matplotlib for inline plotting (MUST come before imports)
%matplotlib inline

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import curve_fit, curve_fit_large, GlobalOptimizationConfig

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

## 1. Basic Multi-Start with curve_fit()

The simplest way to enable multi-start is via the `multistart` parameter.
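
Conceptually, multi-start runs a local optimizer from several starting points and keeps the result with the lowest sum of squared residuals (SSR). The sketch below illustrates that idea with NumPy only, on a one-parameter model `y = cos(c * x)`. The helpers `ssr` and `local_fit` are hypothetical names for this illustration, and the damped Gauss-Newton step is a stand-in, not the solver NLSQ uses.

```python
import numpy as np

def ssr(c, x, y):
    """Sum of squared residuals for the one-parameter model y = cos(c * x)."""
    return float(np.sum((y - np.cos(c * x)) ** 2))

def local_fit(c0, x, y, lo, hi, n_iter=50):
    """Damped Gauss-Newton from one starting value, kept inside [lo, hi]."""
    c = float(c0)
    for _ in range(n_iter):
        r = y - np.cos(c * x)        # residuals at the current estimate
        J = -x * np.sin(c * x)       # derivative of the model with respect to c
        denom = float(J @ J)
        if denom == 0.0:
            break
        step = float(J @ r) / denom  # Gauss-Newton step
        current = ssr(c, x, y)
        t, improved = 1.0, False
        while t > 1e-8:              # halve the step until the SSR decreases
            trial = min(max(c + t * step, lo), hi)
            if ssr(trial, x, y) < current:
                c, improved = trial, True
                break
            t *= 0.5
        if not improved:
            break                    # no downhill step found: stop here
    return c

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 200)
y = np.cos(2.5 * x) + 0.1 * rng.standard_normal(x.size)

lo, hi, n_starts = 0.5, 5.0, 10
# Evenly spaced starts (midpoints of 10 equal strata) keep the demo deterministic
starts = lo + (np.arange(n_starts) + 0.5) * (hi - lo) / n_starts

fits = np.array([local_fit(c0, x, y, lo, hi) for c0 in starts])
ssrs = np.array([ssr(c, x, y) for c in fits])
best_c = fits[np.argmin(ssrs)]

for c0, c, s in zip(starts, fits, ssrs):
    print(f"start c0={c0:.3f} -> c={c:.3f}, SSR={s:.2f}")
print(f"best c = {best_c:.4f}")
```

Most starts end in a poor local minimum with a large SSR; only the start that lands near the true frequency (2.5) reaches the global minimum. That is why the best result is selected by SSR rather than by taking the first converged fit.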

In [None]:
# Define a model with multiple local minima
def damped_oscillation(x, a, b, c, d):
    """Damped oscillation model.
    
    y = a * exp(-b * x) * cos(c * x + d)
    
    The least-squares objective for this model has many local minima
    because the cosine term is periodic.
    """
    return a * jnp.exp(-b * x) * jnp.cos(c * x + d)

In [None]:
# Generate synthetic data
n_samples = 300
x_data = np.linspace(0, 10, n_samples)

# True parameters
true_a, true_b, true_c, true_d = 3.0, 0.3, 2.5, 0.5

y_true = true_a * np.exp(-true_b * x_data) * np.cos(true_c * x_data + true_d)
noise = 0.15 * np.random.randn(n_samples)
y_data = y_true + noise

print(f"True parameters: a={true_a}, b={true_b}, c={true_c}, d={true_d}")
print(f"Dataset: {n_samples} points")

In [None]:
# Define bounds for constrained optimization
bounds = (
    [0.5, 0.01, 0.5, -np.pi],  # Lower bounds
    [10.0, 2.0, 5.0, np.pi],   # Upper bounds
)

# Poor initial guess (far from true values)
p0 = [1.0, 0.1, 1.0, 0.0]

print(f"Initial guess: {p0}")
print(f"Bounds: lower={bounds[0]}, upper={bounds[1]}")

In [None]:
# Single-start fit (may get trapped in local minimum)
popt_single, pcov_single = curve_fit(
    damped_oscillation,
    x_data,
    y_data,
    p0=p0,
    bounds=bounds,
)

print("Single-start result:")
print(f"  a={popt_single[0]:.4f}, b={popt_single[1]:.4f}, c={popt_single[2]:.4f}, d={popt_single[3]:.4f}")

In [None]:
# Multi-start fit with basic parameters
popt_multi, pcov_multi = curve_fit(
    damped_oscillation,
    x_data,
    y_data,
    p0=p0,
    bounds=bounds,
    multistart=True,    # Enable multi-start
    n_starts=10,        # Number of starting points
    sampler="lhs",      # Latin Hypercube Sampling
)

print("Multi-start result:")
print(f"  a={popt_multi[0]:.4f}, b={popt_multi[1]:.4f}, c={popt_multi[2]:.4f}, d={popt_multi[3]:.4f}")
print("\nTrue values:")
print(f"  a={true_a}, b={true_b}, c={true_c}, d={true_d}")

## 2. Bounds Handling with Multi-Start

Multi-start optimization respects parameter bounds and samples starting points
within the bounded region.
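
To make "samples starting points within the bounded region" concrete, here is a minimal NumPy sketch of Latin Hypercube Sampling inside a box. It splits each parameter range into `n_starts` equal strata and places exactly one start in each stratum per parameter. The function name `lhs_within_bounds` is hypothetical, and NLSQ's own sampler may differ in detail.

```python
import numpy as np

def lhs_within_bounds(lower, upper, n_starts, rng):
    """Latin Hypercube sample of n_starts points inside the box [lower, upper]."""
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    n_params = lower.size
    # One uniform draw inside each of n_starts equal-width strata of [0, 1)
    unit = (np.arange(n_starts)[:, None] + rng.random((n_starts, n_params))) / n_starts
    # Shuffle the strata independently for every parameter
    for j in range(n_params):
        unit[:, j] = rng.permutation(unit[:, j])
    # Rescale from the unit cube to the parameter bounds
    return lower + unit * (upper - lower)

lower = [2.0, 0.1, 2.0, -0.5]
upper = [5.0, 0.8, 3.5, 1.5]
starts = lhs_within_bounds(lower, upper, n_starts=10, rng=np.random.default_rng(0))

print(starts.shape)
print(bool(np.all((starts >= lower) & (starts <= upper))))
```

Because every stratum of every parameter is used once, ten starts already cover each parameter's range evenly, which plain uniform sampling does not guarantee.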

In [None]:
# Define tight bounds around expected solution
tight_bounds = (
    [2.0, 0.1, 2.0, -0.5],  # Tighter lower bounds
    [5.0, 0.8, 3.5, 1.5],   # Tighter upper bounds
)

# Multi-start with tight bounds
popt_tight, _ = curve_fit(
    damped_oscillation,
    x_data,
    y_data,
    p0=[3.0, 0.4, 2.5, 0.5],  # Reasonable initial guess
    bounds=tight_bounds,
    multistart=True,
    n_starts=10,
    sampler="lhs",
)

print("Result with tight bounds:")
print(f"  a={popt_tight[0]:.4f}, b={popt_tight[1]:.4f}, c={popt_tight[2]:.4f}, d={popt_tight[3]:.4f}")

# Verify bounds are respected
print("\nBounds verification:")
for name, val, lo, hi in zip(['a', 'b', 'c', 'd'], popt_tight, tight_bounds[0], tight_bounds[1]):
    in_bounds = lo <= val <= hi
    print(f"  {name}: {lo:.2f} <= {val:.4f} <= {hi:.2f} : {'OK' if in_bounds else 'VIOLATION'}")

## 3. Unbounded Optimization

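For unbounded problems there is no box to sample from, so multi-start samples
around the initial guess `p0` instead. The `center_on_p0` and `scale_factor`
settings (shown with `GlobalOptimizationConfig` in Section 4) control this
sampling; the call below does not pass them explicitly.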
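
How wide the sampling region around `p0` is depends on the library's rule for `scale_factor`, which this tutorial does not spell out. The sketch below shows one plausible interpretation, purely as an assumption: each start is drawn uniformly from a box centered on `p0` with half-width `scale_factor * max(|p0|, 1)`. The function name `sample_around_p0` and the fallback width of 1 for zero-valued parameters are hypothetical, not NLSQ behavior.

```python
import numpy as np

def sample_around_p0(p0, n_starts, scale_factor, rng):
    """Uniform starts in a box centered on p0 (assumed rule, not NLSQ's)."""
    p0 = np.asarray(p0, dtype=float)
    # Half-width scales with |p0|; fall back to 1.0 so zero-valued
    # parameters (like c = 0.0 here) still get explored
    half_width = scale_factor * np.maximum(np.abs(p0), 1.0)
    offsets = rng.uniform(-1.0, 1.0, size=(n_starts, p0.size))
    starts = p0 + offsets * half_width
    starts[0] = p0  # keep the user's own guess as one of the starts
    return starts

p0 = [2.0, 1.0, 0.0]
starts = sample_around_p0(p0, n_starts=8, scale_factor=0.8, rng=np.random.default_rng(0))
print(starts.round(3))
```

Whatever the exact rule, the practical consequence is the same: without bounds, the quality of `p0` determines which region gets explored, so a rough but sensible initial guess still matters.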

In [None]:
# Simple exponential model for unbounded optimization
def exponential_model(x, a, b, c):
    """Exponential decay model."""
    return a * jnp.exp(-b * x) + c

# Generate data
x_exp = np.linspace(0, 5, 200)
y_exp_true = 2.5 * np.exp(-1.3 * x_exp) + 0.5
y_exp = y_exp_true + 0.1 * np.random.randn(len(x_exp))

# Unbounded multi-start (samples around p0)
popt_unbound, _ = curve_fit(
    exponential_model,
    x_exp,
    y_exp,
    p0=[2.0, 1.0, 0.0],  # Initial guess
    # No bounds specified - unbounded optimization
    multistart=True,
    n_starts=8,
    sampler="lhs",
)

print("Unbounded multi-start result:")
print(f"  a={popt_unbound[0]:.4f}, b={popt_unbound[1]:.4f}, c={popt_unbound[2]:.4f}")
print("\nTrue values: a=2.5, b=1.3, c=0.5")

## 4. Using GlobalOptimizationConfig

For more control, create a `GlobalOptimizationConfig` object with advanced settings.
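
The `elimination_rounds` and `elimination_fraction` settings suggest a tournament-style scheme in which weak candidates are dropped before the full fit. The tutorial does not document the exact behavior, so the sketch below is only an assumed reading: in each round, every surviving start is scored on a different slice of the data and the worst `fraction` is discarded. `eliminate` and `score_fn` are hypothetical names.

```python
import numpy as np

def eliminate(candidates, score_fn, rounds, fraction):
    """Drop the worst `fraction` of candidates per round (lower score wins)."""
    survivors = list(candidates)
    for round_index in range(rounds):
        if len(survivors) <= 1:
            break
        scores = np.array([score_fn(c, round_index) for c in survivors])
        n_keep = max(1, int(np.ceil(len(survivors) * (1.0 - fraction))))
        keep = np.argsort(scores)[:n_keep]  # indices of the best candidates
        survivors = [survivors[i] for i in keep]
    return survivors

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 400)
y = np.cos(2.5 * x) + 0.1 * rng.standard_normal(x.size)

def score_fn(c, round_index, n_rounds=2):
    # Each round scores on a different interleaved slice of the data
    xs, ys = x[round_index::n_rounds], y[round_index::n_rounds]
    return float(np.sum((ys - np.cos(c * xs)) ** 2))

candidates = np.linspace(0.5, 5.0, 15)  # 15 candidate values of c
survivors = eliminate(candidates, score_fn, rounds=2, fraction=0.5)
print(len(survivors), np.round(survivors, 3))
```

Under this reading, 15 starts with two rounds at a fraction of 0.5 shrink to 8 and then to 4 candidates, so the expensive full fits run on only a quarter of the original starts.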

In [None]:
# Create custom configuration
config = GlobalOptimizationConfig(
    n_starts=15,
    sampler="sobol",           # Sobol sequences for low-discrepancy
    center_on_p0=True,         # Center samples around initial guess
    scale_factor=0.8,          # Exploration scale
    elimination_rounds=2,      # For streaming scenarios
    elimination_fraction=0.5,
)

print("GlobalOptimizationConfig settings:")
print(f"  n_starts: {config.n_starts}")
print(f"  sampler: {config.sampler}")
print(f"  center_on_p0: {config.center_on_p0}")
print(f"  scale_factor: {config.scale_factor}")

In [None]:
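# Forward config values to curve_fit through its multistart keyword arguments.
# Only n_starts and sampler are passed in this call; the remaining config fields
# (center_on_p0, scale_factor, elimination_*) are not forwarded here.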
popt_config, _ = curve_fit(
    damped_oscillation,
    x_data,
    y_data,
    p0=p0,
    bounds=bounds,
    multistart=True,
    n_starts=config.n_starts,
    sampler=config.sampler,
)

print(f"Result with config: a={popt_config[0]:.4f}, b={popt_config[1]:.4f}, c={popt_config[2]:.4f}, d={popt_config[3]:.4f}")

## 5. Integration with curve_fit_large()

For large datasets, `curve_fit_large()` provides automatic chunking and memory
management. Multi-start can be combined with large dataset processing.
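
Chunking works for least squares because the sum of squared residuals is additive over data points, so it can be accumulated one slice at a time without holding every residual in memory. The sketch below illustrates this with NumPy; `chunked_ssr` and `chunk_size` are hypothetical names, and this is not how `curve_fit_large()` is actually implemented.

```python
import numpy as np

def model(x, a, b, c, d):
    """NumPy version of the damped oscillation used in this tutorial."""
    return a * np.exp(-b * x) * np.cos(c * x + d)

def chunked_ssr(params, x, y, chunk_size):
    """Accumulate the SSR over consecutive slices of the data."""
    total = 0.0
    for start in range(0, x.size, chunk_size):
        stop = start + chunk_size
        residuals = y[start:stop] - model(x[start:stop], *params)
        total += float(residuals @ residuals)
    return total

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50_000)
params = (3.0, 0.3, 2.5, 0.5)
y = model(x, *params) + 0.15 * rng.standard_normal(x.size)

full = float(np.sum((y - model(x, *params)) ** 2))
chunked = chunked_ssr(params, x, y, chunk_size=8_192)
print(f"full SSR = {full:.4f}, chunked SSR = {chunked:.4f}")
```

The two numbers agree up to floating-point rounding. The same additivity lets candidate starts be compared on a subsample first, since a start that fits a random subset badly is unlikely to fit the full dataset well.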

In [None]:
# Generate a larger dataset
n_large = 50000  # 50K points (use more for real large dataset scenarios)
x_large = np.linspace(0, 10, n_large)

y_large_true = true_a * np.exp(-true_b * x_large) * np.cos(true_c * x_large + true_d)
y_large = y_large_true + 0.15 * np.random.randn(n_large)

print(f"Large dataset: {n_large:,} points")
print(f"Data memory: {y_large.nbytes / 1024**2:.2f} MB")

In [None]:
# curve_fit_large with multi-start
# For datasets > 1M points, this uses chunked processing with subsample exploration.
# The 50K-point demo dataset here is well below that size.
popt_large, pcov_large = curve_fit_large(
    damped_oscillation,
    x_large,
    y_large,
    p0=p0,
    bounds=bounds,
    multistart=True,
    n_starts=10,
    sampler="lhs",
    memory_limit_gb=1.0,  # Memory limit for chunking
)

print("curve_fit_large with multi-start:")
print(f"  a={popt_large[0]:.4f}, b={popt_large[1]:.4f}, c={popt_large[2]:.4f}, d={popt_large[3]:.4f}")

## 6. Practical Workflow Example: Peak Fitting

A common use case: fitting spectroscopic peaks where multiple local minima exist.
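
One source of ambiguity in this kind of fit is that a sum of two Gaussians is unchanged when the two peaks trade places, so "peak 1" and "peak 2" are only labels. The sketch below checks that symmetry with a NumPy copy of a two-peak model and shows a hypothetical helper, `sort_peaks_by_center`, that orders the peaks by center before results are compared. It assumes the parameter layout `(a1, mu1, sigma1, a2, mu2, sigma2, baseline)` used in this section.

```python
import numpy as np

def double_gaussian_np(x, a1, mu1, sigma1, a2, mu2, sigma2, baseline):
    """NumPy version of the two-peak model."""
    peak1 = a1 * np.exp(-((x - mu1) ** 2) / (2 * sigma1 ** 2))
    peak2 = a2 * np.exp(-((x - mu2) ** 2) / (2 * sigma2 ** 2))
    return peak1 + peak2 + baseline

def sort_peaks_by_center(params):
    """Reorder (a, mu, sigma) triples so that mu1 <= mu2; baseline stays last."""
    params = np.asarray(params, dtype=float)
    peaks = params[:6].reshape(2, 3)
    order = np.argsort(peaks[:, 1])  # column 1 holds the peak centers
    return np.concatenate([peaks[order].ravel(), params[6:]])

x = np.linspace(0, 10, 500)
original = [3.0, 3.5, 0.5, 2.0, 5.0, 0.8, 0.5]
swapped = [2.0, 5.0, 0.8, 3.0, 3.5, 0.5, 0.5]  # same peaks, labels exchanged

same_curve = np.allclose(double_gaussian_np(x, *original), double_gaussian_np(x, *swapped))
print(same_curve)
print(sort_peaks_by_center(swapped))
```

Sorting by center before comparing fitted and true parameters avoids reporting a correct fit as wrong merely because the optimizer returned the peaks in the other order.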

In [None]:
# Multi-peak Gaussian model
def double_gaussian(x, a1, mu1, sigma1, a2, mu2, sigma2, baseline):
    """Two Gaussian peaks on a baseline."""
    peak1 = a1 * jnp.exp(-((x - mu1) ** 2) / (2 * sigma1 ** 2))
    peak2 = a2 * jnp.exp(-((x - mu2) ** 2) / (2 * sigma2 ** 2))
    return peak1 + peak2 + baseline

# Generate synthetic spectroscopy data
n_spec = 500
x_spec = np.linspace(0, 10, n_spec)

# True parameters: two overlapping peaks
true_params_spec = [3.0, 3.5, 0.5, 2.0, 5.0, 0.8, 0.5]

y_spec_true = double_gaussian(x_spec, *true_params_spec)
y_spec = y_spec_true + 0.1 * np.random.randn(n_spec)

print("True peak parameters:")
print(f"  Peak 1: amplitude={true_params_spec[0]}, center={true_params_spec[1]}, width={true_params_spec[2]}")
print(f"  Peak 2: amplitude={true_params_spec[3]}, center={true_params_spec[4]}, width={true_params_spec[5]}")
print(f"  Baseline: {true_params_spec[6]}")

In [None]:
# Define bounds for peak fitting
peak_bounds = (
    [0.1, 0.0, 0.1, 0.1, 0.0, 0.1, 0.0],   # Lower: amplitudes > 0, widths > 0
    [10.0, 10.0, 3.0, 10.0, 10.0, 3.0, 2.0],  # Upper bounds
)

# Poor initial guess (peaks swapped, wrong amplitudes)
p0_spec = [1.5, 5.0, 0.8, 3.5, 3.5, 0.4, 0.3]

# Single-start (may swap peaks or find local minimum)
popt_spec_single, _ = curve_fit(
    double_gaussian,
    x_spec,
    y_spec,
    p0=p0_spec,
    bounds=peak_bounds,
)

print("Single-start result:")
print(f"  Peak 1: a={popt_spec_single[0]:.3f}, mu={popt_spec_single[1]:.3f}, sigma={popt_spec_single[2]:.3f}")
print(f"  Peak 2: a={popt_spec_single[3]:.3f}, mu={popt_spec_single[4]:.3f}, sigma={popt_spec_single[5]:.3f}")

In [None]:
# Multi-start (explores parameter space more thoroughly)
popt_spec_multi, _ = curve_fit(
    double_gaussian,
    x_spec,
    y_spec,
    p0=p0_spec,
    bounds=peak_bounds,
    multistart=True,
    n_starts=20,  # More starts for 7-parameter problem
    sampler="lhs",
)

print("Multi-start result:")
print(f"  Peak 1: a={popt_spec_multi[0]:.3f}, mu={popt_spec_multi[1]:.3f}, sigma={popt_spec_multi[2]:.3f}")
print(f"  Peak 2: a={popt_spec_multi[3]:.3f}, mu={popt_spec_multi[4]:.3f}, sigma={popt_spec_multi[5]:.3f}")
print()
print("True values:")
print(f"  Peak 1: a={true_params_spec[0]:.3f}, mu={true_params_spec[1]:.3f}, sigma={true_params_spec[2]:.3f}")
print(f"  Peak 2: a={true_params_spec[3]:.3f}, mu={true_params_spec[4]:.3f}, sigma={true_params_spec[5]:.3f}")

In [None]:
# Visualize peak fitting results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Full spectrum
ax1 = axes[0]
ax1.scatter(x_spec, y_spec, alpha=0.3, s=10, label='Data')
ax1.plot(x_spec, y_spec_true, 'k--', linewidth=2, label='True')
ax1.plot(x_spec, double_gaussian(x_spec, *popt_spec_single), 'b-', linewidth=2, label='Single-start')
ax1.plot(x_spec, double_gaussian(x_spec, *popt_spec_multi), 'r-', linewidth=2, label='Multi-start')
ax1.set_xlabel('x')
ax1.set_ylabel('Intensity')
ax1.set_title('Double Gaussian Peak Fitting')
ax1.legend()

# Right: Residuals
ax2 = axes[1]
residuals_single = y_spec - double_gaussian(x_spec, *popt_spec_single)
residuals_multi = y_spec - double_gaussian(x_spec, *popt_spec_multi)
ax2.scatter(x_spec, residuals_single, alpha=0.5, s=10, label='Single-start')
ax2.scatter(x_spec, residuals_multi, alpha=0.5, s=10, label='Multi-start')
ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)
ax2.set_xlabel('x')
ax2.set_ylabel('Residual')
ax2.set_title('Fit Residuals')
ax2.legend()

plt.tight_layout()
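# savefig does not create missing folders, so make sure the output directory exists
import os
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/05_peak_fitting.png', dpi=300, bbox_inches='tight')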
plt.show()

## 7. Comparison Visualization
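
SSR values are easier to judge against a reference. For independent Gaussian noise with standard deviation `sigma`, a good fit with `p` parameters on `n` points leaves an SSR of roughly `(n - p) * sigma**2`. The sketch below computes that yardstick for the damped-oscillation data from Section 1 (300 points, 4 parameters, noise level 0.15); `expected_ssr` is a hypothetical helper name, and the formula is exact only for linear models.

```python
import numpy as np

def expected_ssr(n_points, n_params, noise_sigma):
    """Approximate SSR left by a good fit: (n - p) * sigma^2."""
    return (n_points - n_params) * noise_sigma ** 2

reference = expected_ssr(n_points=300, n_params=4, noise_sigma=0.15)
print(f"expected SSR for a good fit: {reference:.2f}")  # 6.66

# Sanity check on pure noise: its own SSR is close to n * sigma^2
rng = np.random.default_rng(0)
noise = 0.15 * rng.standard_normal(300)
noise_ssr = float(noise @ noise)
print(f"SSR of one noise draw: {noise_ssr:.2f}")
```

An SSR near this reference means the fit explains the data down to the noise level, while an SSR many times larger points to a local minimum or a wrong model.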

In [None]:
# Calculate SSR for each method
y_pred_single = damped_oscillation(x_data, *popt_single)
y_pred_multi = damped_oscillation(x_data, *popt_multi)

ssr_single = float(jnp.sum((y_data - y_pred_single) ** 2))
ssr_multi = float(jnp.sum((y_data - y_pred_multi) ** 2))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Data with fits
ax1 = axes[0]
ax1.scatter(x_data, y_data, alpha=0.4, s=10, label='Data', color='gray')
ax1.plot(x_data, y_true, 'k--', linewidth=2, label='True', alpha=0.7)
ax1.plot(x_data, y_pred_single, 'b-', linewidth=2, label=f'Single-start (SSR={ssr_single:.2f})')
ax1.plot(x_data, y_pred_multi, 'r-', linewidth=2, label=f'Multi-start (SSR={ssr_multi:.2f})')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_title('Damped Oscillation: Single vs Multi-Start')
ax1.legend()

# Right: Parameter comparison
ax2 = axes[1]
params_true = np.array([true_a, true_b, true_c, true_d])
params_single = np.array(popt_single)
params_multi = np.array(popt_multi)

x_pos = np.arange(4)
width = 0.25

ax2.bar(x_pos - width, params_true, width, label='True', color='green', alpha=0.7)
ax2.bar(x_pos, params_single, width, label='Single-start', color='blue', alpha=0.7)
ax2.bar(x_pos + width, params_multi, width, label='Multi-start', color='red', alpha=0.7)

ax2.set_xticks(x_pos)
ax2.set_xticklabels(['a', 'b', 'c', 'd'])
ax2.set_xlabel('Parameter')
ax2.set_ylabel('Value')
ax2.set_title('Parameter Comparison')
ax2.legend()

plt.tight_layout()
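# savefig does not create missing folders, so make sure the output directory exists
import os
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/05_multistart_comparison.png', dpi=300, bbox_inches='tight')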
plt.show()

## 8. Key Takeaways

1. **Basic multi-start:** Add `multistart=True`, `n_starts=N`, `sampler="lhs"` to curve_fit()

2. **Bounds handling:** Multi-start respects bounds; samples starting points within bounds

3. **Large datasets:** Use `curve_fit_large()` with multi-start for datasets > 1M points

4. **Practical tips:**
   - Start with `n_starts=5-10` for most problems
   - Increase `n_starts` for high-dimensional or highly multimodal problems
   - Use `sampler="lhs"` for general use, `"sobol"` for low dimensions

5. **Peak fitting:** Multi-start is especially helpful for overlapping peaks where
   peak assignments can be ambiguous
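
The tutorial's rule of thumb (5 starts for 2-3 parameters, 10 for 4-6, 20 or more for 7+) can be captured in a small helper. The function below is a hypothetical convenience, not part of NLSQ, and it assumes that one-parameter problems belong in the "simple" tier.

```python
def suggest_n_starts(n_params):
    """Rule of thumb from this tutorial, keyed on the number of parameters."""
    if n_params < 1:
        raise ValueError("n_params must be at least 1")
    if n_params <= 3:
        return 5    # simple problems
    if n_params <= 6:
        return 10   # medium problems
    return 20       # complex problems: treat 20 as a floor, not a cap

for n in (3, 4, 7):
    print(n, suggest_n_starts(n))
```

These values are starting points only. If repeated runs with different seeds return different optima, that is a sign `n_starts` should be increased.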

In [None]:
# Summary
print("Summary")
print("=" * 50)
print()
print("Multi-start integration patterns:")
print()
print("1. Basic: curve_fit(..., multistart=True, n_starts=10)")
print("2. With sampler: curve_fit(..., multistart=True, sampler='lhs')")
print("3. Large datasets: curve_fit_large(..., multistart=True)")
print()
print("Recommended settings by problem complexity:")
print("  - Simple (2-3 params): n_starts=5")
print("  - Medium (4-6 params): n_starts=10")
print("  - Complex (7+ params): n_starts=20+")