# NLSQ 2D Gaussian Demo (Fixed Version)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Dipolar-Quantum-Gases/nlsq/blob/main/examples/NLSQ_2D_Gaussian_Demo_Fixed.ipynb)

This notebook demonstrates 2D Gaussian fitting with improved GPU error handling.

## Installing and Importing

If a GPU is available, set your runtime type to GPU; the notebook also works on CPU.

In [None]:
# Install NLSQ if not already installed.
# The leading `!` is IPython/Jupyter syntax that runs a shell command;
# in a plain terminal run `pip install nlsq` instead.
!pip install nlsq

## Configure Environment

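Configure JAX's GPU memory allocation to reduce out-of-memory and cuSolver errors. Run this cell before JAX is imported: these environment variables are only read when JAX initializes its backend.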

In [None]:
import os
import warnings

# Configure JAX's GPU memory allocator. JAX reads these variables when it
# initializes its backend, so this cell must run before `import jax`
# (restart the runtime if JAX has already been imported).

# Allocate GPU memory on demand instead of reserving a large block up front,
# leaving room for other GPU libraries (e.g. cuSolver workspaces).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Fraction of total GPU memory for JAX's allocator. Note: JAX reads
# `XLA_PYTHON_CLIENT_MEM_FRACTION`; names such as `JAX_GPU_MEMORY_FRACTION`
# have no effect.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'

# Optional: Force CPU if GPU issues persist
# os.environ['JAX_PLATFORMS'] = 'cpu'

print("Environment configured for optimal performance")

Importing NLSQ switches JAX to 64-bit precision, so import NLSQ before any JAX arrays are created:

In [None]:
# NLSQ enables 64-bit precision in JAX when it is imported, so import it
# first, before any JAX arrays are created.
from nlsq import CurveFit
import jax
import jax.numpy as jnp

# Check which device we're using. Device discovery can fail on a
# misconfigured GPU runtime; report the problem instead of aborting.
try:
    devices = jax.devices()
    print(f"Available JAX devices: {devices}")
    print(f"Using device: {devices[0]}")
    # float64 here confirms that 64-bit mode is active.
    print(f"Default JAX float dtype: {jnp.zeros(1).dtype}")
except Exception as e:
    print(f"Device detection: {e}")
    print("Will use CPU fallback if needed")

## Define the 2D Gaussian Function

In [None]:
def rotate_coordinates2D(coords, theta):
    """Rotate 2D coordinates by angle ``theta``.

    Applies the rotation matrix ``[[cos, -sin], [sin, cos]]`` to every
    point ``(x, y)``, i.e.
    ``x' = cos(theta) * x - sin(theta) * y`` and
    ``y' = sin(theta) * x + cos(theta) * y``.

    Args:
        coords: Pair ``(x, y)`` of array-likes. Any shapes that broadcast
            together are accepted (e.g. two ``(H, W)`` grids or two flat
            arrays).
        theta: Rotation angle in radians.

    Returns:
        List ``[x', y']`` of rotated coordinate arrays with the broadcast
        shape of ``x`` and ``y``.
    """
    x, y = coords[0], coords[1]
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    # Elementwise form of ``R @ [x, y]``: no stack/flatten/reshape round
    # trip, and x and y may broadcast instead of needing identical shapes.
    return [cos_t * x - sin_t * y, sin_t * x + cos_t * y]


def gaussian2d(coords, n0, x0, y0, sigma_x, sigma_y, theta, offset):
    """Evaluate a rotated elliptical 2D Gaussian plus a constant offset.

    The coordinates are shifted to the centre ``(x0, y0)``, rotated by
    ``theta`` with ``rotate_coordinates2D``, and then evaluated as
    ``n0 * exp(-0.5 * (u**2 / sigma_x**2 + v**2 / sigma_y**2)) + offset``.

    Args:
        coords: Pair ``(x, y)`` of coordinate arrays (2D grids or
            flattened 1D arrays).
        n0: Peak amplitude above the offset.
        x0, y0: Centre of the Gaussian.
        sigma_x, sigma_y: Widths along the rotated axes.
        theta: Rotation angle in radians.
        offset: Constant background level.

    Returns:
        Model values with the shape of the (rotated) coordinates.
    """
    shifted_x = coords[0] - x0
    shifted_y = coords[1] - y0
    u, v = rotate_coordinates2D([shifted_x, shifted_y], theta)
    quadratic_form = u**2 / sigma_x**2 + v**2 / sigma_y**2
    return n0 * jnp.exp(-0.5 * quadratic_form) + offset

## Generate Synthetic Data

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np


def get_coordinates(width, height):
    """Build float coordinate grids ``(X, Y)`` of shape ``(height, width)``.

    ``X[r, c] == c`` and ``Y[r, c] == r`` (as floats), following
    ``np.meshgrid``'s default ``'xy'`` indexing.
    """
    xs, ys = (np.linspace(0, n - 1, n) for n in (width, height))
    return tuple(np.meshgrid(xs, ys))


def get_gaussian_parameters(length):
    """Return ground-truth ``gaussian2d`` parameters for a ``length`` x ``length`` image.

    The Gaussian is centred in the image, with widths ``length/6`` and
    ``length/8``, rotated by 60 degrees, on a background of 10% of the peak.
    Order matches ``gaussian2d``:
    ``[n0, x0, y0, sigma_x, sigma_y, theta, offset]``.
    """
    amplitude = 1
    centre = length / 2
    return [
        amplitude,
        centre,
        centre,
        length / 6,
        length / 8,
        np.pi / 3,
        0.1 * amplitude,
    ]


# Start with a moderate size for testing
length = 200  # Reduced from 500 to avoid memory issues
noise_std = 0.1  # standard deviation of the additive Gaussian noise
rng = np.random.default_rng(seed=0)  # fixed seed -> reproducible synthetic data

XY_tuple = get_coordinates(length, length)

params = get_gaussian_parameters(length)
print(f"True parameters: {params}")

# Generate noisy data. The model is evaluated with JAX; convert it to a
# NumPy array so NLSQ and SciPy both receive an ordinary host array.
# np.asarray of a JAX array may be read-only, so build a new array with `+`
# rather than adding the noise in place.
model = np.asarray(gaussian2d(XY_tuple, *params))
zdata = model + rng.normal(0, noise_std, size=(length, length))

# Visualize the data
plt.figure(figsize=(8, 6))
plt.imshow(zdata, cmap='viridis')
plt.colorbar(label='Intensity')
plt.title(f'2D Gaussian Data ({length}x{length})')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

## Perform Curve Fitting

We'll fit the data multiple times, each time starting from a different random initial guess (`p0`) near the true parameters, to test robustness:

In [None]:
from scipy.optimize import curve_fit


def get_random_float(low, high):
    """Return a uniformly distributed float between ``low`` and ``high``.

    Uses NumPy's global legacy RNG (``np.random.random``), so the sequence
    follows ``np.random.seed``.
    """
    return low + (high - low) * np.random.random()


# Flatten data for fitting
flat_data = zdata.flatten()
flat_XY_tuple = [coord.flatten() for coord in XY_tuple]

# Initialize NLSQ CurveFit object
jcf = CurveFit()

# Perform multiple fits
n_fits = 10  # Reduced from 100 for faster testing
times = []
all_results = []

print(f"Performing {n_fits} fits...")

for i in range(n_fits):
    if i % 5 == 0:
        print(f"  Fit {i+1}/{n_fits}")

    # Random initial guess: each true parameter scaled by a factor in [0.9, 1.2)
    seed = [val * get_random_float(0.9, 1.2) for val in params]

    try:
        st = time.perf_counter()
        popt, pcov = jcf.curve_fit(gaussian2d, flat_XY_tuple, flat_data, p0=seed)
        fit_time = time.perf_counter() - st
    except Exception as e:
        # Best effort: report the failed fit (e.g. a GPU solver error) and
        # continue, so one bad fit does not abort the whole benchmark.
        print(f"  Warning: Fit {i+1} failed: {e}")
        continue

    times.append(fit_time)
    all_results.append(popt)

if times:
    print(f"\nCompleted {len(times)}/{n_fits} fits successfully")
    # The first successful fit normally pays the JIT compilation cost, so
    # it is reported separately and excluded from the average.
    print(f"First fit time (includes JIT): {times[0]:.3f} seconds")
    if len(times) > 1:
        print(f"Average fit time: {np.mean(times[1:]):.3f} seconds (excluding JIT compilation)")
    else:
        # np.mean([]) would print nan with a RuntimeWarning.
        print("Only one successful fit; no post-JIT average available.")
else:
    print("No successful fits. Please check your environment.")

## Compare with SciPy

In [None]:
# Compare with a single SciPy fit
if all_results:
    print("Comparing NLSQ with SciPy...")

    # Draw one fresh random initial guess and give it to BOTH fitters so the
    # parameter comparison below starts from identical conditions.
    seed = [val * get_random_float(0.9, 1.2) for val in params]

    # Time SciPy.
    # NOTE(review): gaussian2d is written with jax.numpy and is not jitted
    # here, so every SciPy evaluation pays JAX dispatch overhead; the speedup
    # below may overstate NLSQ's advantage over a pure-NumPy model.
    st = time.perf_counter()
    popt_scipy, pcov_scipy = curve_fit(gaussian2d, flat_XY_tuple, flat_data, p0=seed)
    scipy_time = time.perf_counter() - st

    # NLSQ fit from the same initial guess. If it fails, fall back to the
    # last successful benchmark fit (which used a different initial guess).
    try:
        popt_nlsq, _ = jcf.curve_fit(gaussian2d, flat_XY_tuple, flat_data, p0=seed)
    except Exception as e:
        print(f"  Warning: NLSQ comparison fit failed ({e}); using last benchmark result")
        popt_nlsq = all_results[-1]

    # Representative NLSQ time: mean of the post-JIT fits when available,
    # otherwise the single (JIT-including) fit we have.
    nlsq_time = np.mean(times[1:]) if len(times) > 1 else times[0]

    print("\nFit times:")
    print(f"  NLSQ (after JIT): {nlsq_time:.3f} seconds")
    print(f"  SciPy: {scipy_time:.3f} seconds")

    print(f"\nSpeedup: {scipy_time / nlsq_time:.1f}x")

    print("\nParameter comparison:")
    print(f"  True params:  {params}")
    print(f"  NLSQ params:  {list(popt_nlsq)}")
    print(f"  SciPy params: {list(popt_scipy)}")

    # Relative error per parameter; the last entry (offset) is excluded.
    # NOTE: gaussian2d is unchanged by theta -> theta + pi, so a large theta
    # error can still correspond to an equivalent fit.
    true_params = np.asarray(params, dtype=float)
    nlsq_rel = np.abs((np.asarray(popt_nlsq) - true_params) / true_params)
    scipy_rel = np.abs((np.asarray(popt_scipy) - true_params) / true_params)
    nlsq_error = np.max(nlsq_rel[:-1])
    scipy_error = np.max(scipy_rel[:-1])

    print("\nMax relative errors (excluding offset):")
    print(f"  NLSQ:  {nlsq_error:.4f}")
    print(f"  SciPy: {scipy_error:.4f}")

## Visualize Results

In [None]:
# Plot per-fit NLSQ timings and the residuals of the NLSQ fit.
# Requires the comparison cell above (it defines scipy_time and popt_nlsq).
if all_results and len(times) > 1:
    plt.figure(figsize=(10, 4))

    # Left panel: NLSQ fit times vs the single SciPy fit time
    plt.subplot(1, 2, 1)
    # times[0] is the JIT-compiling first fit; label the remaining points by
    # their position (2..N) among the successful fits.
    fit_numbers = np.arange(2, len(times) + 1)
    plt.plot(fit_numbers, times[1:], 'b-o', label='NLSQ (after JIT)')
    plt.axhline(y=scipy_time, color='r', linestyle='--', label='SciPy')
    plt.xlabel('Fit Number (successful fits)')
    plt.ylabel('Fit Time (seconds)')
    plt.title('Fitting Speed Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: residuals (noisy data minus NLSQ model) on the image grid
    plt.subplot(1, 2, 2)
    fitted_data = np.asarray(gaussian2d(XY_tuple, *popt_nlsq)).reshape(length, length)
    residuals = np.asarray(zdata) - fitted_data

    # Fixed, symmetric colour range so residual maps are comparable across runs
    plt.imshow(residuals, cmap='RdBu', vmin=-0.3, vmax=0.3)
    plt.colorbar(label='Residuals')
    plt.title('Fit Residuals')
    plt.xlabel('X')
    plt.ylabel('Y')

    plt.tight_layout()
    plt.show()

    # For a good fit this should be close to the noise level used to generate zdata
    print(f"\nResiduals RMS: {np.sqrt(np.mean(residuals**2)):.4f}")

## Troubleshooting

If you encounter GPU errors:

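1. **cuSolver Errors**: Each fit is wrapped in error handling, so a failed fit is reported and skipped rather than stopping the notebook; if the errors persist, force CPU execution (see item 3)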
2. **Out of Memory**: Reduce the `length` parameter or restart the runtime
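3. **Force CPU**: Uncomment the `JAX_PLATFORMS='cpu'` line in the configuration cell, then restart the runtime and rerun from the top (the setting only takes effect if it is set before JAX is imported)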
4. **Colab Specific**: Use Runtime → Restart runtime if GPU issues persist

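This version includes:
- Automatic GPU/CPU fallback for SVD operations (handled inside the NLSQ library, if your installed version supports it; this notebook does not implement it)
- Better memory management (GPU memory preallocation is disabled in the configuration cell)
- More robust error handling (a failed fit is reported and skipped instead of stopping the notebook)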