# Multi-Start Optimization Basics

This tutorial demonstrates how multi-start optimization helps find global optima
in curve fitting problems with multiple local minima.

**Features demonstrated:**
- Local minima trap problem in nonlinear optimization
- `GlobalOptimizationConfig` configuration
- `curve_fit()` with multi-start enabled via the `multistart`, `n_starts`, and `sampler` parameters
- Comparison of single-start vs multi-start results
- Visualization of loss landscape and starting point distribution

In [None]:
# Configure matplotlib for inline plotting (MUST come before imports)
%matplotlib inline

In [None]:
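import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from pathlib import Path

from nlsq import curve_fit, GlobalOptimizationConfig

# Create the output directory used by the plt.savefig() calls below
Path("figures").mkdir(exist_ok=True)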

## 1. The Problem: Local Minima Traps

Many real-world curve fitting problems have multiple local minima. A standard
single-start optimizer may converge to a suboptimal local minimum instead of
the global optimum, depending on the initial parameter guess.
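To make the trap concrete, here is a minimal NumPy sketch, independent of nlsq, that scans the SSR of the sinusoidal model used in this notebook along the frequency `b` alone. It assumes noise-free data and holds `a`, `c`, and `d` at the true values chosen below; the helper `ssr_along_b` is a hypothetical name for this illustration only. Every interior grid point lower than both of its neighbours is a local minimum in which a gradient-based optimizer can settle.

```python
import numpy as np

def ssr_along_b(b_grid, x, y, a=2.0, c=0.5, d=1.0):
    """SSR of y = a*sin(b*x + c) + d for each frequency in b_grid (a, c, d held fixed)."""
    # Broadcast: rows are candidate frequencies, columns are data points
    y_pred = a * np.sin(b_grid[:, None] * x[None, :] + c) + d
    return np.sum((y[None, :] - y_pred) ** 2, axis=1)

# Noise-free data generated from the same true parameters used in this notebook
x = np.linspace(0, 4 * np.pi, 200)
y = 2.0 * np.sin(1.5 * x + 0.5) + 1.0

b_grid = np.linspace(0.5, 3.0, 501)
ssr = ssr_along_b(b_grid, x, y)

# Interior local minima: strictly lower than both neighbours on the grid
is_local_min = (ssr[1:-1] < ssr[:-2]) & (ssr[1:-1] < ssr[2:])
local_min_b = b_grid[1:-1][is_local_min]

print("local minima at b =", np.round(local_min_b, 3))
print("global minimum at b =", b_grid[np.argmin(ssr)])
```

Only one of these minima sits at the true frequency `b = 1.5`. From inside any of the other dips there is no downhill path along `b` to reach it.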

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

In [None]:
# Define a multimodal sinusoidal model that has multiple local minima
def multimodal_model(x, a, b, c, d):
    """Multimodal model: y = a * sin(b * x + c) + d

    The periodicity of sin() makes the sum of squared residuals non-convex
    in the frequency b and phase c, so the loss has multiple local minima
    at distinct (b, c) combinations.
    """
    return a * jnp.sin(b * x + c) + d

In [None]:
# Generate synthetic data with known true parameters
n_samples = 200
x_data = np.linspace(0, 4 * np.pi, n_samples)

# True parameters
true_a, true_b, true_c, true_d = 2.0, 1.5, 0.5, 1.0

# Generate noisy observations
y_true = true_a * np.sin(true_b * x_data + true_c) + true_d
noise = 0.2 * np.random.randn(n_samples)
y_data = y_true + noise

print(f"True parameters: a={true_a}, b={true_b}, c={true_c}, d={true_d}")
print(f"Dataset: {n_samples} points")

In [None]:
# Visualize the data
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(x_data, y_data, alpha=0.5, s=10, label="Noisy data")
ax.plot(x_data, y_true, "r-", linewidth=2, label="True function")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Synthetic Data: Multimodal Sinusoidal Model")
ax.legend()
plt.tight_layout()
plt.savefig("figures/01_data_visualization.png", dpi=300, bbox_inches="tight")
plt.show()

## 2. Single-Start Optimization: The Local Minima Problem

Let's fit from several poor initial guesses. Depending on where it starts, a single-start
optimizer may get trapped in a different local minimum.

In [None]:
# Define bounds for parameters
# a: [0.5, 5], b: [0.5, 3], c: [-pi, pi], d: [-2, 5]
bounds = ([0.5, 0.5, -np.pi, -2.0], [5.0, 3.0, np.pi, 5.0])

# Try several different initial guesses with single-start optimization
initial_guesses = [
    [1.0, 0.8, 0.0, 0.5],   # Poor guess 1
    [3.0, 2.5, 2.0, 2.0],   # Poor guess 2
    [1.5, 1.2, -1.0, 0.0],  # Poor guess 3
]

single_start_results = []

print("Single-start optimization results:")
print("=" * 60)

for i, p0 in enumerate(initial_guesses):
    try:
        popt, pcov = curve_fit(
            multimodal_model,
            x_data,
            y_data,
            p0=p0,
            bounds=bounds,
        )
        # Calculate sum of squared residuals
        y_pred = multimodal_model(x_data, *popt)
        ssr = float(jnp.sum((y_data - y_pred) ** 2))
        single_start_results.append({"p0": p0, "popt": popt, "ssr": ssr})
        print(f"Guess {i+1}: p0={p0}")
        print(f"  Result: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}, d={popt[3]:.3f}")
        print(f"  SSR: {ssr:.4f}")
    except Exception as e:
        print(f"Guess {i+1}: Failed - {e}")
        single_start_results.append({"p0": p0, "popt": None, "ssr": float("inf")})

In [None]:
# Find best and worst results from single-start
valid_results = [r for r in single_start_results if r["popt"] is not None]
if valid_results:
    best_single = min(valid_results, key=lambda r: r["ssr"])
    worst_single = max(valid_results, key=lambda r: r["ssr"])
    print(f"\nBest single-start SSR: {best_single['ssr']:.4f}")
    print(f"Worst single-start SSR: {worst_single['ssr']:.4f}")
    print("\nA large spread between these values indicates sensitivity to the initial guess.")

## 3. Multi-Start Optimization: Finding the Global Optimum

Multi-start optimization runs the local optimizer from multiple starting points spread over the
parameter space and keeps the best fit. This substantially increases the chance of finding the
global optimum, although it cannot guarantee it.
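The strategy can be sketched without nlsq. The NumPy code below is a hypothetical illustration, not nlsq's implementation. `local_fit` is a bare-bones Levenberg-Marquardt loop that enforces bounds by simple clipping. `multi_start_fit` runs it from the user's `p0` plus uniformly random starts (not LHS) and keeps the lowest-SSR result. Because `p0` is always one of the starts, the multi-start SSR can never be worse than the single-start SSR from the same `p0`.

```python
import numpy as np

def sine_model(x, p):
    a, b, c, d = p
    return a * np.sin(b * x + c) + d

def sine_jacobian(x, p):
    # Analytic partial derivatives w.r.t. (a, b, c, d), one column each
    a, b, c, d = p
    s, co = np.sin(b * x + c), np.cos(b * x + c)
    return np.stack([s, a * x * co, a * co, np.ones_like(x)], axis=1)

def local_fit(x, y, p0, lb, ub, n_iter=200):
    """Damped Gauss-Newton (Levenberg-Marquardt) from one start; bounds enforced by clipping."""
    p = np.clip(np.asarray(p0, dtype=float), lb, ub)
    r = y - sine_model(x, p)
    cost, lam = r @ r, 1e-2
    for _ in range(n_iter):
        J = sine_jacobian(x, p)
        A, g = J.T @ J, J.T @ r
        step = np.linalg.solve(A + lam * np.diag(np.diag(A) + 1e-12), g)
        p_new = np.clip(p + step, lb, ub)
        r_new = y - sine_model(x, p_new)
        cost_new = r_new @ r_new
        if cost_new < cost:   # accept step, move toward Gauss-Newton
            p, r, cost = p_new, r_new, cost_new
            lam = max(lam * 0.3, 1e-9)
        else:                 # reject step, damp more (toward gradient descent)
            lam = min(lam * 10.0, 1e12)
    return p, cost

def multi_start_fit(x, y, p0, lb, ub, n_starts=10, seed=0):
    """Run local_fit from p0 plus (n_starts - 1) uniform random starts; keep the lowest SSR."""
    rng = np.random.default_rng(seed)
    random_starts = lb + rng.random((n_starts - 1, lb.size)) * (ub - lb)
    starts = np.vstack([np.asarray(p0, dtype=float), random_starts])
    fits = [local_fit(x, y, s, lb, ub) for s in starts]
    return min(fits, key=lambda fit: fit[1])

# Same synthetic setup as this notebook (noise drawn with NumPy's Generator API)
rng = np.random.default_rng(42)
x = np.linspace(0, 4 * np.pi, 200)
y = 2.0 * np.sin(1.5 * x + 0.5) + 1.0 + 0.2 * rng.standard_normal(200)
lb = np.array([0.5, 0.5, -np.pi, -2.0])
ub = np.array([5.0, 3.0, np.pi, 5.0])
p0 = [1.0, 0.8, 0.0, 0.5]

p_one, ssr_one = local_fit(x, y, p0, lb, ub)
p_many, ssr_many = multi_start_fit(x, y, p0, lb, ub, n_starts=10)
print(f"single-start SSR: {ssr_one:.4f}")
print(f"multi-start SSR:  {ssr_many:.4f}")
```

Clipping is a crude way to respect bounds. Production solvers usually handle bounds inside the step computation (for example with trust-region reflective methods). Clipping keeps the sketch short.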

In [None]:
# Configure multi-start optimization
global_config = GlobalOptimizationConfig(
    n_starts=10,              # Number of starting points
    sampler="lhs",            # Latin Hypercube Sampling for even coverage
    center_on_p0=True,        # Center samples around initial guess
    scale_factor=1.0,         # Exploration scale
)

print("GlobalOptimizationConfig:")
print(f"  n_starts: {global_config.n_starts}")
print(f"  sampler: {global_config.sampler}")
print(f"  center_on_p0: {global_config.center_on_p0}")
print(f"  scale_factor: {global_config.scale_factor}")

In [None]:
# Use the first (poor) initial guess to show multi-start improvement
p0_poor = [1.0, 0.8, 0.0, 0.5]

# Fit with multi-start optimization, reusing the settings from global_config
popt_multi, pcov_multi = curve_fit(
    multimodal_model,
    x_data,
    y_data,
    p0=p0_poor,
    bounds=bounds,
    multistart=True,
    n_starts=global_config.n_starts,
    sampler=global_config.sampler,
)

# Calculate SSR for multi-start result
y_pred_multi = multimodal_model(x_data, *popt_multi)
ssr_multi = float(jnp.sum((y_data - y_pred_multi) ** 2))

print("\nMulti-start optimization result:")
print(f"  Result: a={popt_multi[0]:.3f}, b={popt_multi[1]:.3f}, c={popt_multi[2]:.3f}, d={popt_multi[3]:.3f}")
print(f"  SSR: {ssr_multi:.4f}")
print(f"\nTrue params: a={true_a}, b={true_b}, c={true_c}, d={true_d}")

In [None]:
# Compare single-start (from same initial guess) vs multi-start
popt_single, _ = curve_fit(
    multimodal_model,
    x_data,
    y_data,
    p0=p0_poor,
    bounds=bounds,
)
y_pred_single = multimodal_model(x_data, *popt_single)
ssr_single = float(jnp.sum((y_data - y_pred_single) ** 2))

print("\nComparison (same initial guess):")
print(f"  Single-start SSR: {ssr_single:.4f}")
print(f"  Multi-start SSR:  {ssr_multi:.4f}")
if ssr_multi < ssr_single:
    improvement = (1 - ssr_multi / ssr_single) * 100
    print(f"  Improvement: {improvement:.1f}% lower SSR")
else:
    print("  No improvement over the single-start fit from this initial guess.")

## 4. Visualization: Comparing Results

In [None]:
# Create comparison visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left plot: Data with both fits
ax1 = axes[0]
ax1.scatter(x_data, y_data, alpha=0.4, s=15, label="Data", color="gray")
ax1.plot(x_data, y_true, "k--", linewidth=2, label="True function", alpha=0.7)
ax1.plot(x_data, y_pred_single, "b-", linewidth=2, label=f"Single-start (SSR={ssr_single:.2f})")
ax1.plot(x_data, y_pred_multi, "r-", linewidth=2, label=f"Multi-start (SSR={ssr_multi:.2f})")
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.set_title("Single-Start vs Multi-Start Comparison")
ax1.legend()

# Right plot: Residuals comparison
ax2 = axes[1]
residuals_single = y_data - y_pred_single
residuals_multi = y_data - y_pred_multi
ax2.scatter(x_data, residuals_single, alpha=0.5, s=15, label="Single-start", color="blue")
ax2.scatter(x_data, residuals_multi, alpha=0.5, s=15, label="Multi-start", color="red")
ax2.axhline(y=0, color="k", linestyle="--", alpha=0.5)
ax2.set_xlabel("x")
ax2.set_ylabel("Residual")
ax2.set_title("Residuals Comparison")
ax2.legend()

plt.tight_layout()
plt.savefig("figures/01_comparison.png", dpi=300, bbox_inches="tight")
plt.show()

## 5. Loss Landscape Visualization

To see why multi-start helps, we visualize the loss landscape: with `a` and `d` fixed at their
true values, we evaluate the SSR on a grid over the frequency `b` and the phase `c`.
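The double loop in the next cell evaluates the model once per grid point. As an alternative, NumPy broadcasting can evaluate the whole grid in one call. The helper `ssr_grid` below is hypothetical, and it assumes the same layout as `np.meshgrid(b_range, c_range)`: rows index `c`, columns index `b`.

```python
import numpy as np

def ssr_grid(x, y, b_values, c_values, a, d):
    """SSR of y = a*sin(b*x + c) + d on a grid; result[i, j] uses c_values[i], b_values[j]."""
    # Shapes: (n_c, 1, 1) + (1, n_b, 1) * (1, 1, n_x) -> (n_c, n_b, n_x)
    phase = b_values[None, :, None] * x[None, None, :] + c_values[:, None, None]
    y_pred = a * np.sin(phase) + d
    return np.sum((y[None, None, :] - y_pred) ** 2, axis=-1)

rng = np.random.default_rng(0)
x = np.linspace(0, 4 * np.pi, 200)
y = 2.0 * np.sin(1.5 * x + 0.5) + 1.0 + 0.2 * rng.standard_normal(200)
b_values = np.linspace(0.5, 3.0, 50)
c_values = np.linspace(-np.pi, np.pi, 50)
landscape = ssr_grid(x, y, b_values, c_values, a=2.0, d=1.0)

# Cross-check one entry against the per-point formula used in the double loop
i, j = 10, 20
expected = np.sum((y - (2.0 * np.sin(b_values[j] * x + c_values[i]) + 1.0)) ** 2)
print(landscape.shape, np.isclose(landscape[i, j], expected))  # (50, 50) True
```

The trade-off is memory: the intermediate array has `n_c * n_b * n_x` entries (here 50 x 50 x 200 floats, about 4 MB), which is fine for a plot but grows quickly with finer grids.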

In [None]:
# Create loss landscape by scanning over (b, c) while fixing (a, d) at true values
b_range = np.linspace(0.5, 3.0, 50)
c_range = np.linspace(-np.pi, np.pi, 50)
B, C = np.meshgrid(b_range, c_range)

# Calculate SSR for each (b, c) combination
loss_landscape = np.zeros_like(B)
for i in range(len(c_range)):
    for j in range(len(b_range)):
        y_pred = true_a * np.sin(B[i, j] * x_data + C[i, j]) + true_d
        loss_landscape[i, j] = np.sum((y_data - y_pred) ** 2)

# Log-transform for better visualization
loss_log = np.log10(loss_landscape + 1)

In [None]:
# Visualize loss landscape
fig, ax = plt.subplots(figsize=(10, 8))

# Contour plot
contour = ax.contourf(B, C, loss_log, levels=30, cmap="viridis")
plt.colorbar(contour, ax=ax, label="log10(SSR + 1)")

# Mark true parameters
ax.scatter([true_b], [true_c], color="white", marker="*", s=200, 
           label="True parameters", edgecolors="black", linewidths=1)

# Mark single-start result
ax.scatter([popt_single[1]], [popt_single[2]], color="blue", marker="o", s=100,
           label="Single-start result", edgecolors="white", linewidths=1)

# Mark multi-start result
ax.scatter([popt_multi[1]], [popt_multi[2]], color="red", marker="s", s=100,
           label="Multi-start result", edgecolors="white", linewidths=1)

ax.set_xlabel("b (frequency)")
ax.set_ylabel("c (phase)")
ax.set_title("Loss Landscape (a, d fixed at true values)\nMultiple local minima visible")
ax.legend(loc="upper right")

plt.tight_layout()
plt.savefig("figures/01_loss_landscape.png", dpi=300, bbox_inches="tight")
plt.show()

## 6. Starting Point Distribution

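Latin Hypercube Sampling (LHS) divides each parameter's range into `n` equal strata, where `n` is
the number of samples, and places exactly one sample in each stratum. Every one-dimensional
projection is therefore evenly covered, whereas plain uniform random sampling can leave some
strata empty and cluster points in others.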
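A minimal NumPy sketch of the stratification idea follows. `latin_hypercube` and `strata_counts` are hypothetical helpers written for illustration; nlsq's `latin_hypercube_sample` may differ in details such as its random-number API. The bound scaling at the end mirrors what `scale_samples_to_bounds` is used for in the next cell.

```python
import numpy as np

def latin_hypercube(n_samples, n_dims, rng):
    """Latin Hypercube sample on the unit hypercube [0, 1)^n_dims."""
    # Each column is an independent permutation of the stratum indices 0..n_samples-1
    strata = np.stack([rng.permutation(n_samples) for _ in range(n_dims)], axis=1)
    # Jitter uniformly inside each stratum of width 1/n_samples
    return (strata + rng.random((n_samples, n_dims))) / n_samples

def strata_counts(unit_samples):
    """Points per stratum for each dimension (shape: n_dims x n_samples)."""
    n = unit_samples.shape[0]
    idx = np.floor(unit_samples * n).astype(int)
    return np.array([np.bincount(idx[:, k], minlength=n) for k in range(unit_samples.shape[1])])

rng = np.random.default_rng(42)
lhs_unit = latin_hypercube(20, 4, rng)
random_unit = rng.random((20, 4))  # plain uniform sampling, for comparison

lhs_counts = strata_counts(lhs_unit)
print("LHS points per stratum (min, max):", lhs_counts.min(), lhs_counts.max())
print("Empty strata with uniform sampling:", int((strata_counts(random_unit) == 0).sum()))

# Map unit-cube samples to the parameter bounds used in this notebook
lb = np.array([0.5, 0.5, -np.pi, -2.0])
ub = np.array([5.0, 3.0, np.pi, 5.0])
lhs_scaled = lb + lhs_unit * (ub - lb)
```

By construction every LHS stratum holds exactly one point per dimension. The uniform sample typically leaves several strata empty. Note that LHS guarantees even coverage of each axis, not of the full joint space.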

In [None]:
# Generate samples using LHS to visualize the distribution
from nlsq.global_optimization import latin_hypercube_sample, scale_samples_to_bounds
import jax

# Generate 20 starting points for visualization
n_samples_viz = 20
n_params = 4

# Generate LHS samples in [0, 1]^d
key = jax.random.PRNGKey(42)
lhs_samples = latin_hypercube_sample(n_samples_viz, n_params, rng_key=key)

# Scale to the same bounds used in the fits above
lb = np.asarray(bounds[0])
ub = np.asarray(bounds[1])
scaled_samples = scale_samples_to_bounds(lhs_samples, lb, ub)

print(f"Generated {n_samples_viz} LHS samples in 4D parameter space")
print(f"Sample shape: {scaled_samples.shape}")

In [None]:
# Visualize 2D projection of starting points (b vs c)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: LHS samples on loss landscape
ax1 = axes[0]
contour = ax1.contourf(B, C, loss_log, levels=30, cmap="viridis", alpha=0.7)
ax1.scatter(scaled_samples[:, 1], scaled_samples[:, 2], color="yellow", 
            marker="o", s=80, label="LHS starting points", edgecolors="black", linewidths=1)
ax1.scatter([true_b], [true_c], color="white", marker="*", s=200, 
            label="True parameters", edgecolors="black", linewidths=1)
ax1.set_xlabel("b (frequency)")
ax1.set_ylabel("c (phase)")
ax1.set_title("LHS Starting Points on Loss Landscape")
ax1.legend()

# Right: all 2D projections in the unit hypercube (before scaling to bounds),
# so parameters with different ranges share the same normalized axes
ax2 = axes[1]
unit_samples = np.asarray(lhs_samples)
param_names = ["a", "b", "c", "d"]
colors = plt.cm.tab10(np.arange(6))

plot_idx = 0
for i in range(n_params):
    for j in range(i + 1, n_params):
        ax2.scatter(unit_samples[:, i], unit_samples[:, j],
                    alpha=0.6, s=30, color=colors[plot_idx],
                    label=f"{param_names[i]} vs {param_names[j]}")
        plot_idx += 1

ax2.set_xlabel("Parameter value (normalized)")
ax2.set_ylabel("Parameter value (normalized)")
ax2.set_title("LHS Coverage: All 2D Projections")
ax2.legend(loc="upper right", fontsize=8)

plt.tight_layout()
plt.savefig("figures/01_starting_points.png", dpi=300, bbox_inches="tight")
plt.show()

## 7. Key Takeaways

1. **Local minima are common** in nonlinear optimization, especially with periodic functions.

2. **Single-start optimization** is sensitive to initial guess and may converge to local minima.

3. **Multi-start optimization** explores multiple starting points, significantly increasing the
   chance of finding the global optimum, although it cannot guarantee it.

4. **GlobalOptimizationConfig** controls:
   - `n_starts`: Number of starting points (more = better coverage, slower)
   - `sampler`: Sampling strategy ('lhs', 'sobol', 'halton')
   - `center_on_p0`: Whether to center around initial guess
   - `scale_factor`: Exploration range

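5. **Latin Hypercube Sampling (LHS)** provides stratified coverage: each parameter's range is split
   into as many equal strata as there are starting points, with exactly one point per stratum, so
   the starting points are evenly spread along every parameter axis.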
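The cost-versus-coverage trade-off behind `n_starts` can be estimated with a back-of-the-envelope model. It assumes that the local optimizer reaches the global minimum if and only if it starts inside that minimum's basin of attraction, that this basin occupies a fraction `p` of the bounded parameter box, and that starts are independent and uniform. LHS starts are not independent, so treat this only as a rough guide. The function names are hypothetical.

```python
import math

def success_probability(basin_fraction, n_starts):
    """P(at least one of n_starts independent uniform starts lands in the global basin)."""
    return 1.0 - (1.0 - basin_fraction) ** n_starts

def starts_needed(basin_fraction, target=0.99):
    """Smallest n_starts whose success_probability reaches target."""
    return math.ceil(math.log(1.0 - target) / math.log(1.0 - basin_fraction))

for n in (1, 5, 10, 20):
    print(f"n_starts={n:2d}: P(success) = {success_probability(0.2, n):.3f}")
print("starts needed for 99% with a 20% basin:", starts_needed(0.2))
```

With a basin covering 20% of the box, 10 starts succeed about 89% of the time, and 21 starts are needed to reach 99%. Each doubling of `n_starts` roughly doubles the runtime but gives diminishing gains in success probability.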

In [None]:
# Summary statistics
print("Summary")
print("=" * 50)
print(f"True parameters: a={true_a}, b={true_b}, c={true_c}, d={true_d}")
print("\nSingle-start result (poor initial guess):")
print(f"  Parameters: a={popt_single[0]:.3f}, b={popt_single[1]:.3f}, c={popt_single[2]:.3f}, d={popt_single[3]:.3f}")
print(f"  SSR: {ssr_single:.4f}")
print("\nMulti-start result (10 starts, LHS):")
print(f"  Parameters: a={popt_multi[0]:.3f}, b={popt_multi[1]:.3f}, c={popt_multi[2]:.3f}, d={popt_multi[3]:.3f}")
print(f"  SSR: {ssr_multi:.4f}")