# NLSQ Optimization Example

This notebook demonstrates how to use the **NLSQ** package for trust-region nonlinear least squares optimization in Homodyne v2.0+.

## What You'll Learn

1. Generate synthetic XPCS data with known ground truth parameters
2. Set up configuration and optimization parameters
3. Run NLSQ optimization with error recovery
4. Visualize optimization results and fit quality
5. Understand device configuration (CPU vs GPU)

## Requirements

```bash
pip install "homodyne>=2.0" matplotlib numpy
```

In [None]:
# Imports
import matplotlib.pyplot as plt
import numpy as np

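# NOTE: this helper is imported from the repository's tests/ package, so the
# import assumes the notebook is run from a Homodyne source checkout.
from tests.factories.synthetic_data import generate_static_isotropic_dataset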

# Check if NLSQ is available
try:
    import nlsq

    print(f"✓ NLSQ version: {nlsq.__version__}")
except ImportError:
    print("❌ NLSQ not installed. Run: pip install nlsq")

# Check JAX device
try:
    import jax

    devices = jax.devices()
    print(f"✓ JAX devices: {devices}")
    print(f"  Default device: {devices[0].platform}")
except Exception as e:
    print(f"❌ JAX error: {e}")
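
The check above only reports which devices JAX can see. As a minimal sketch of how a computation can be pinned to the CPU even when a GPU is present, the snippet below uses plain JAX (`jax.devices("cpu")` and the `jax.default_device` context manager). Whether Homodyne or NLSQ provides its own device switch is not covered in this notebook, so treat this as a JAX-only illustration.

```python
# Plain-JAX sketch: run a small computation on the CPU backend explicitly.
try:
    import jax
    import jax.numpy as jnp

    cpu_device = jax.devices("cpu")[0]  # the CPU backend is normally available
    with jax.default_device(cpu_device):
        x = jnp.arange(4.0)  # [0., 1., 2., 3.]
        sum_of_squares = float(jnp.sum(x**2))  # 0 + 1 + 4 + 9
    print(f"Computed on {cpu_device.platform}: {sum_of_squares}")
except ImportError:
    sum_of_squares = None
    print("JAX not installed; skipping device example")
```

Leaving out the context manager lets JAX place arrays on its default device, which is typically the GPU when one is detected.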

## Step 1: Generate Synthetic XPCS Data

We'll create a synthetic dataset with known ground truth parameters:
- **D₀** = 1000.0 μm²/s (diffusion coefficient)
- **α** = 0.5 (anomalous diffusion exponent, α=1 is normal diffusion)
- **D_offset** = 10.0 μm²/s (baseline diffusion offset)
- **contrast** = 0.5 (scaling parameter)
- **offset** = 1.0 (baseline intensity)
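
To make the roles of these five parameters concrete, here is a toy model. It assumes a time-dependent diffusion coefficient D(t) = D₀·t^α + D_offset and the scaling g2 = offset + contrast·g1². The wavevector `q`, the prefactor in the exponent, and the function name `toy_g2` are illustrative assumptions; the exact expression used by Homodyne may differ.

```python
import numpy as np


def toy_g2(t1, t2, D0, alpha, D_offset, contrast, offset, q=0.01):
    """Toy g2 for illustration only (q and the exponent prefactor are assumed)."""

    def integrated_D(t):
        # Antiderivative of D(t) = D0 * t**alpha + D_offset
        return D0 * t ** (alpha + 1) / (alpha + 1) + D_offset * t

    decay = q**2 * np.abs(integrated_D(t2) - integrated_D(t1))
    g1 = np.exp(-decay)
    return offset + contrast * g1**2


params = dict(D0=1000.0, alpha=0.5, D_offset=10.0, contrast=0.5, offset=1.0)
g2_zero_delay = toy_g2(0.0, 0.0, **params)  # g1 = 1, so g2 = offset + contrast
g2_long_delay = toy_g2(0.0, 100.0, **params)  # g1 -> 0, so g2 -> offset
print(g2_zero_delay, g2_long_delay)
```

In this toy form, `contrast` sets the height of the decay above the baseline and `offset` sets the baseline itself, while D₀, α, and D_offset control how quickly the decay happens.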

In [None]:
# Generate synthetic data (static isotropic mode)
print("Generating synthetic XPCS data...")

# Ground truth parameters
TRUE_D0 = 1000.0
TRUE_ALPHA = 0.5
TRUE_D_OFFSET = 10.0
TRUE_CONTRAST = 0.5
TRUE_OFFSET = 1.0

# Generate dataset (10 phi angles, 25x25 time grid = 6,250 data points)
synthetic_data = generate_static_isotropic_dataset(
    D0=TRUE_D0,
    alpha=TRUE_ALPHA,
    D_offset=TRUE_D_OFFSET,
    contrast=TRUE_CONTRAST,
    offset=TRUE_OFFSET,
    noise_level=0.03,  # 3% noise
    n_phi=10,
    n_t1=25,
    n_t2=25,
    seed=42,  # For reproducibility
)

print(f"✓ Generated {10 * 25 * 25:,} data points")
print(f"  Phi angles: {synthetic_data['phi_values'].shape}")
print(f"  t1 values: {synthetic_data['t1_values'].shape}")
print(f"  t2 values: {synthetic_data['t2_values'].shape}")
print(f"  g2 values: {synthetic_data['g2'].shape}")
print(f"  g2 errors: {synthetic_data['g2_err'].shape}")

## Step 2: Visualize the Synthetic Data

Let's plot the g2 correlation function to see what we're optimizing:

In [None]:
# Plot g2 for a few phi angles
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

phi_indices = [0, 3, 6, 9]  # Plot 4 phi angles

for idx, phi_idx in enumerate(phi_indices):
    ax = axes[idx]

    # Get g2 for this phi angle (first t1 slice)
    g2_slice = synthetic_data["g2"][phi_idx, 0, :]
    g2_err_slice = synthetic_data["g2_err"][phi_idx, 0, :]
    t2_values = synthetic_data["t2_values"]

    # Plot with error bars
    ax.errorbar(
        t2_values,
        g2_slice,
        yerr=g2_err_slice,
        fmt="o",
        markersize=4,
        alpha=0.6,
        label="Data with errors",
    )

    ax.set_xlabel("t2 (delay time)", fontsize=10)
    ax.set_ylabel("g2(φ, t1=0, t2)", fontsize=10)
    ax.set_title(f'φ = {synthetic_data["phi_values"][phi_idx]:.2f} rad', fontsize=11)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle("Synthetic XPCS Data (g2 correlation)", fontsize=13, y=1.01)
plt.show()

print("Note: g2 decays from ~1.5 (contrast + offset) to ~1.0 (offset) as t2 increases")

## Step 3: Configure Optimization Parameters

Set up the configuration and initial parameter guesses:

In [None]:
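# Create a mock configuration that carries only the optimizer settings
class MockConfig:
    """Minimal stand-in for a full Homodyne configuration object."""

    def __init__(self):
        self.optimization = {"lsq": {"max_iterations": 200, "tolerance": 1e-6}}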


config = MockConfig()

# Initial parameter guesses (intentionally perturbed from ground truth)
# Format: [contrast, offset, D0, alpha, D_offset]
initial_params = np.array(
    [
        0.45,  # contrast (true: 0.5)
        0.95,  # offset (true: 1.0)
        900.0,  # D0 (true: 1000.0)
        0.55,  # alpha (true: 0.5)
        12.0,  # D_offset (true: 10.0)
    ]
)

# Parameter bounds (lower, upper)
bounds = (
    np.array([0.0, 0.8, 100.0, 0.3, 1.0]),  # Lower bounds
    np.array([1.0, 1.2, 1e5, 1.5, 1000.0]),  # Upper bounds
)

print("Configuration:")
print(f"  Max iterations: {config.optimization['lsq']['max_iterations']}")
print(f"  Tolerance: {config.optimization['lsq']['tolerance']}")
print("\nInitial parameters:")
param_names = ["contrast", "offset", "D0", "alpha", "D_offset"]
for name, value, true_value in zip(
    param_names,
    initial_params,
    [TRUE_CONTRAST, TRUE_OFFSET, TRUE_D0, TRUE_ALPHA, TRUE_D_OFFSET],
    strict=False,
):
    error_pct = abs(value - true_value) / true_value * 100
    print(
        f"  {name:12s}: {value:10.4f} (true: {true_value:10.4f}, error: {error_pct:5.1f}%)"
    )
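
Bounded solvers generally expect the starting point to lie inside the bounds; SciPy's `least_squares`, for example, raises an error when it does not. How `NLSQWrapper` treats an infeasible start is not stated in this notebook, so a cheap pre-check can be useful. The helper `check_within_bounds` below is hypothetical and not part of Homodyne.

```python
import numpy as np


def check_within_bounds(params, lower, upper, names):
    """Return the names of parameters that fall outside [lower, upper]."""
    params, lower, upper = map(np.asarray, (params, lower, upper))
    outside = (params < lower) | (params > upper)
    return [name for name, bad in zip(names, outside) if bad]


names = ["contrast", "offset", "D0", "alpha", "D_offset"]
lower = np.array([0.0, 0.8, 100.0, 0.3, 1.0])
upper = np.array([1.0, 1.2, 1e5, 1.5, 1000.0])

good_guess = np.array([0.45, 0.95, 900.0, 0.55, 12.0])
bad_guess = np.array([0.45, 0.95, 50.0, 0.55, 12.0])  # D0 below its lower bound

violations_good = check_within_bounds(good_guess, lower, upper, names)
violations_bad = check_within_bounds(bad_guess, lower, upper, names)
print(violations_good)  # []
print(violations_bad)  # ['D0']
```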

## Step 4: Run NLSQ Optimization

Now let's run the optimization and recover the ground truth parameters:

In [None]:
from homodyne.optimization.nlsq_wrapper import NLSQWrapper

# Create wrapper with error recovery enabled
wrapper = NLSQWrapper(
    enable_large_dataset=False,  # Not needed for 6K points
    enable_recovery=True,  # Enable automatic retry on failure
)

print("Running NLSQ optimization...")
print("(This may take 10-30 seconds depending on your hardware)\n")

# Run optimization
result = wrapper.fit(
    data=synthetic_data,
    config=config,
    initial_params=initial_params,
    bounds=bounds,
    analysis_mode="static_isotropic",
)

print("✓ Optimization complete!\n")

# Display results
print("=" * 70)
print("OPTIMIZATION RESULTS")
print("=" * 70)
print(f"Convergence: {result.convergence_status}")
print(f"Success: {result.success}")
print(f"Iterations: {result.n_iterations}")
print(f"Chi-squared: {result.chi_squared:.6f}")
print(f"Reduced χ²: {result.reduced_chi_squared:.6f}")
print(f"\nDevice: {result.device_info.get('device', 'unknown')}")
print(f"Platform: {result.device_info.get('platform', 'unknown')}")

if result.recovery_actions:
    print(f"\nRecovery actions taken: {len(result.recovery_actions)}")
    for action in result.recovery_actions:
        print(f"  - {action}")

print("\n" + "=" * 70)
print("PARAMETER RECOVERY")
print("=" * 70)
print(f"{'Parameter':<15} {'Initial':>12} {'Optimized':>12} {'True':>12} {'Error':>10}")
print("-" * 70)

true_values = [TRUE_CONTRAST, TRUE_OFFSET, TRUE_D0, TRUE_ALPHA, TRUE_D_OFFSET]
for i, name in enumerate(param_names):
    init_val = initial_params[i]
    opt_val = result.parameters[i]
    true_val = true_values[i]
    error_pct = abs(opt_val - true_val) / true_val * 100

    print(
        f"{name:<15} {init_val:12.6f} {opt_val:12.6f} {true_val:12.6f} {error_pct:9.2f}%"
    )

print("=" * 70)
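
For intuition about what a bounded trust-region least-squares fit does, the sketch below solves a much smaller problem with SciPy's `least_squares` (method `"trf"`). This is a stand-in, not the NLSQ package or the Homodyne wrapper: the single-exponential model, the noise level, and all variable names are assumptions chosen for illustration.

```python
import numpy as np
from scipy.optimize import least_squares

# Synthetic single-exponential "g2-like" curve with Gaussian noise
rng = np.random.default_rng(0)
t = np.linspace(0.0, 5.0, 200)
true_contrast, true_offset, true_rate = 0.5, 1.0, 1.2
sigma = 0.01
y = true_offset + true_contrast * np.exp(-2.0 * true_rate * t)
y = y + rng.normal(0.0, sigma, t.size)


def residuals(p):
    """Error-weighted residuals for parameters (contrast, offset, rate)."""
    contrast, offset, rate = p
    model = offset + contrast * np.exp(-2.0 * rate * t)
    return (y - model) / sigma


fit = least_squares(
    residuals,
    x0=[0.45, 0.95, 1.0],  # perturbed starting point
    bounds=([0.0, 0.8, 0.1], [1.0, 1.2, 10.0]),
    method="trf",  # Trust Region Reflective, supports bounds
)
print(fit.success, fit.x)
```

The structure mirrors the cell above: a residual function, a perturbed starting point, and box bounds handed to a trust-region solver.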

## Step 5: Visualize Fit Quality

Let's compare the optimized fit to the data:

In [None]:
from homodyne.core.jax_backend import compute_g2_scaled

# Compute optimized g2 using recovered parameters
optimized_params_dict = {
    "contrast": result.parameters[0],
    "offset": result.parameters[1],
    "D0": result.parameters[2],
    "alpha": result.parameters[3],
    "D_offset": result.parameters[4],
}

# Compute g2 with optimized parameters
g2_fit = compute_g2_scaled(
    phi=synthetic_data["phi_values"],
    t1=synthetic_data["t1_values"],
    t2=synthetic_data["t2_values"],
    params=optimized_params_dict,
    mode="static_isotropic",
)

# Plot comparison
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for idx, phi_idx in enumerate([0, 3, 6, 9]):
    ax = axes[idx]

    # Data
    g2_data = synthetic_data["g2"][phi_idx, 0, :]
    g2_err = synthetic_data["g2_err"][phi_idx, 0, :]
    t2 = synthetic_data["t2_values"]

    # Fit
    g2_fit_slice = g2_fit[phi_idx, 0, :]

    # Plot
    ax.errorbar(
        t2,
        g2_data,
        yerr=g2_err,
        fmt="o",
        markersize=4,
        alpha=0.5,
        label="Data",
        color="blue",
    )
    ax.plot(t2, g2_fit_slice, "-", linewidth=2, label="NLSQ Fit", color="red")

    # Compute residuals
    residuals = (g2_data - g2_fit_slice) / g2_err
    rms_residual = np.sqrt(np.mean(residuals**2))

    ax.set_xlabel("t2 (delay time)", fontsize=10)
    ax.set_ylabel("g2(φ, t1=0, t2)", fontsize=10)
    ax.set_title(
        f'φ = {synthetic_data["phi_values"][phi_idx]:.2f} rad\nRMS residual: {rms_residual:.3f}',
        fontsize=11,
    )
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle(f"NLSQ Fit Quality (χ² = {result.chi_squared:.2f})", fontsize=13, y=1.01)
plt.show()

print(f"Overall reduced χ² = {result.reduced_chi_squared:.4f}")
print("  (Values near 1.0 indicate residuals consistent with the stated error bars)")
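
As a reference for interpreting these numbers, the sketch below computes χ² and reduced χ² using the conventional definitions (sum of squared error-weighted residuals, divided by the degrees of freedom). It assumes the wrapper follows the same convention, which this notebook does not confirm. The data here are synthetic and the helper name is hypothetical.

```python
import numpy as np


def reduced_chi_squared(data, model, sigma, n_params):
    """Return (chi2, chi2 / dof) with dof = number of points - number of parameters."""
    residuals = (data - model) / sigma
    chi2 = float(np.sum(residuals**2))
    dof = data.size - n_params
    return chi2, chi2 / dof


# Synthetic check: a perfect model plus noise that matches the stated errors
rng = np.random.default_rng(1)
model = np.full(6250, 1.2)
sigma = np.full(6250, 0.03)
data = model + rng.normal(0.0, 0.03, size=6250)

chi2, red_chi2 = reduced_chi_squared(data, model, sigma, n_params=5)
print(f"chi2 = {chi2:.1f}, reduced chi2 = {red_chi2:.3f}")
```

Values well above 1 usually point to a poor model or underestimated errors, and values well below 1 to overestimated errors.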

## Step 6: Analyze Parameter Uncertainties

The covariance matrix provides uncertainty estimates:

In [None]:
# Extract parameter uncertainties from covariance matrix
param_uncertainties = np.sqrt(np.diag(result.covariance))

print("=" * 70)
print("PARAMETER UNCERTAINTIES")
print("=" * 70)
print(f"{'Parameter':<15} {'Value':>12} {'Uncertainty':>12} {'Relative':>10}")
print("-" * 70)

for i, name in enumerate(param_names):
    value = result.parameters[i]
    uncertainty = param_uncertainties[i]
    relative_uncertainty = (uncertainty / value) * 100 if value != 0 else np.inf

    print(f"{name:<15} {value:12.6f} {uncertainty:12.6f} {relative_uncertainty:9.2f}%")

print("=" * 70)
print("\nInterpretation:")
print("  - Small relative uncertainties (<5%) indicate well-constrained parameters")
print("  - Large uncertainties suggest the parameter may be poorly constrained by data")
print("  - For publication-quality uncertainty quantification, use MCMC sampling")
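
The diagonal of the covariance matrix gives the variances used above, while the off-diagonal entries show how strongly pairs of parameters trade off against each other. A minimal sketch of converting a covariance matrix to a correlation matrix, assuming `result.covariance` is a square NumPy-compatible array; the 2×2 matrix below is made up for illustration.

```python
import numpy as np


def correlation_from_covariance(cov):
    """Normalize a covariance matrix so that its diagonal becomes 1."""
    sigma = np.sqrt(np.diag(cov))
    return cov / np.outer(sigma, sigma)


cov = np.array([[4.0, 1.8], [1.8, 1.0]])  # made-up covariance, sigmas are 2 and 1
corr = correlation_from_covariance(cov)
print(corr)
```

Off-diagonal correlations close to ±1 suggest that two parameters are nearly degenerate, in which case their individual uncertainties should be read with care.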

## Summary and Next Steps

### What We've Demonstrated

1. ✅ Generated synthetic XPCS data with known ground truth
2. ✅ Set up NLSQ optimization with parameter bounds
3. ✅ Recovered the ground truth parameters from perturbed initial guesses
4. ✅ Visualized fit quality and residuals
5. ✅ Analyzed parameter uncertainties

### Key Features of NLSQ Optimization

- **Fast**: Converges in seconds for datasets with thousands of points
- **Robust**: Automatic error recovery with parameter perturbation
- **GPU-accelerated**: Transparent GPU usage via JAX (when available)
- **Backward compatible**: Works with existing homodyne configurations

### Next Steps

1. **Real data**: Replace synthetic data with actual XPCS measurements
2. **Advanced analysis**: Try laminar flow mode (7 parameters)
3. **Uncertainty quantification**: Use MCMC sampling for publication-quality error bars
4. **Large datasets**: Set `enable_large_dataset=True` for >1M points

### Resources

- **Migration Guide**: [MIGRATION_OPTIMISTIX_TO_NLSQ.md](../docs/MIGRATION_OPTIMISTIX_TO_NLSQ.md)
- **Documentation**: [README.md](../README.md)
- **NLSQ Package**: [github.com/imewei/NLSQ](https://github.com/imewei/NLSQ)
- **Theory Paper**: [He et al. PNAS 2024](https://doi.org/10.1073/pnas.2401162121)