# JAXICNNRegressor: Convex Surrogate Models for Global Optimization

This notebook demonstrates using Input Convex Neural Networks (ICNNs) as surrogate models for global optimization.

## What are ICNNs?

Input Convex Neural Networks are neural networks whose output is guaranteed to be **convex with respect to the inputs**. This is achieved through:

1. **Nonnegative hidden-to-hidden weights** (enforced via softplus parameterization)
2. **Nonnegative output weights** from the final hidden layer
3. **Convex, nondecreasing activations** (softplus or ReLU)
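To make the three constraints concrete, here is a minimal NumPy sketch of a two-hidden-layer ICNN forward pass in the style of Amos et al. (2017). The parameter layout, the names `init_params` and `icnn_forward`, and the direct input connections (`W1`, `c`) are illustrative assumptions, not pycse's internals. Nonnegativity is imposed by passing raw weights through softplus, as described above.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z)): convex and nondecreasing
    return np.logaddexp(0.0, z)

def init_params(rng, d_in, hidden=(16, 16)):
    """Hypothetical parameter layout for a two-hidden-layer ICNN."""
    h1, h2 = hidden
    return {
        "W0": rng.normal(size=(h1, d_in)), "b0": rng.normal(size=h1),
        # Raw hidden-to-hidden weights; softplus makes them nonnegative
        "U1_raw": rng.normal(size=(h2, h1)),
        # Direct input weights: unconstrained, since affine terms keep convexity
        "W1": rng.normal(size=(h2, d_in)), "b1": rng.normal(size=h2),
        # Raw output weights on the last hidden layer; also made nonnegative
        "a_raw": rng.normal(size=h2),
        # Direct input-to-output term (affine)
        "c": rng.normal(size=d_in), "d": 0.0,
    }

def icnn_forward(params, X):
    """Evaluate the ICNN on X of shape (n, d_in); returns shape (n,)."""
    z1 = softplus(X @ params["W0"].T + params["b0"])  # convex activation of an affine map
    U1 = softplus(params["U1_raw"])                    # nonnegative hidden-to-hidden weights
    # Nonnegative combination of convex z1, plus an affine term, passed through
    # a convex nondecreasing activation: still convex in X
    z2 = softplus(z1 @ U1.T + X @ params["W1"].T + params["b1"])
    a = softplus(params["a_raw"])                      # nonnegative output weights
    return z2 @ a + X @ params["c"] + params["d"]

rng = np.random.default_rng(0)
params = init_params(rng, d_in=2)
print(icnn_forward(params, rng.uniform(-3, 3, size=(5, 2))))
```

Only the weights that multiply earlier hidden layers (and the output weights) need constraining; the direct input terms are affine in `X` and cannot break convexity.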

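The convexity guarantee applies to the trained model itself, and it means:
- Any local minimum of the model is a global minimum
- Gradient-based methods converge to a global minimum of the model (with a suitable step size, provided a minimizer exists, e.g. on a bounded box)
- No multistart or global optimization heuristics are needed to minimize the model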
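The second point can be illustrated on the simplest convex model, a least-squares objective $\|Ax - b\|^2$ with full-column-rank `A`, which has a unique minimizer. In this sketch (plain NumPy, not the ICNN), gradient descent with step $1/L$ reaches the same point from every random start, matching `np.linalg.lstsq`.

```python
import numpy as np

def f(x, A, b):
    r = A @ x - b
    return r @ r

def grad_f(x, A, b):
    return 2.0 * A.T @ (A @ x - b)

def gradient_descent(x0, A, b, n_iter=5000):
    # Step size 1/L, where L = 2 * (largest eigenvalue of A^T A) is the
    # Lipschitz constant of the gradient
    L = 2.0 * np.linalg.eigvalsh(A.T @ A).max()
    x = x0.copy()
    for _ in range(n_iter):
        x = x - grad_f(x, A, b) / L
    return x

rng = np.random.default_rng(0)
A = rng.normal(size=(20, 3))
b = rng.normal(size=20)
x_ls = np.linalg.lstsq(A, b, rcond=None)[0]  # reference solution

# Five random starts, all far from the solution
ends = [gradient_descent(rng.uniform(-10, 10, 3), A, b) for _ in range(5)]
for x_end in ends:
    print(np.round(x_end, 6), f(x_end, A, b))
```

If the minimizer is not unique (for example, a flat valley), different starts can stop at different points with the same objective value; Section 6 addresses this with strong convexity.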

## Use Case: Surrogate-Based Optimization

When optimizing an expensive black-box function:
1. **Sample** the function at various points
2. **Train** a convex surrogate model (ICNN)
3. **Optimize** the surrogate using gradient descent
4. **Evaluate** the true function at the surrogate's optimum

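**Advantage**: Even if the true function is non-convex, the surrogate is convex and easy to optimize. The trade-off is that a convex surrogate only captures the overall trend of the data, so its optimum is most useful when the true global minimum sits near the bottom of that trend.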
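The four steps can be sketched end to end in NumPy. To keep the example self-contained, a separable convex quadratic fitted by least squares stands in for the ICNN (a swapped-in surrogate, not what `JAXICNNRegressor` does), and the Rastrigin function used later in this notebook plays the expensive black box.

```python
import numpy as np

def expensive_f(X):
    """Stand-in for an expensive black box: the Rastrigin function."""
    return 10 * X.shape[1] + np.sum(X**2 - 10 * np.cos(2 * np.pi * X), axis=1)

rng = np.random.default_rng(0)
lo, hi = -3.0, 3.0

# 1. Sample the function
X = rng.uniform(lo, hi, size=(1000, 2))
y = expensive_f(X)

# 2. Train a convex surrogate: sum_i a_i x_i^2 + b_i x_i + c by least squares
#    (a quadratic stand-in for the ICNN)
features = np.column_stack([X**2, X, np.ones(len(X))])
coef, *_ = np.linalg.lstsq(features, y, rcond=None)
a = np.maximum(coef[:2], 1e-8)  # clip curvatures so the surrogate is convex
b = coef[2:4]

# 3. Optimize the surrogate: minimizer of each 1-D parabola, clipped to the box
x_star = np.clip(-b / (2 * a), lo, hi)

# 4. Evaluate the true function at the surrogate's optimum
print("surrogate optimum:", x_star)
print("true f there:", expensive_f(x_star[None, :])[0])
```

With a quadratic surrogate, step 3 has a closed form; an ICNN needs an iterative optimizer, but convexity still guarantees that it finds the surrogate's global minimum.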

**Reference**: Amos, B., Xu, L., & Kolter, J. Z. (2017). Input Convex Neural Networks. ICML 2017.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3D projection on older Matplotlib)
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize

import os
# Select the CPU backend; this must run before JAX is imported (via pycse below)
os.environ["JAX_PLATFORM_NAME"] = "cpu"

from pycse.sklearn.jax_icnn import JAXICNNRegressor

# Set random seed for reproducibility
np.random.seed(42)

## 1. Basic Usage: Fitting a Convex Function

First, let's verify that ICNNs work well on convex functions. We'll fit a simple quadratic (convex) function.

In [None]:
# Generate data from a convex quadratic function
X = np.random.randn(200, 2) * 2
y = np.sum(X**2, axis=1) + 0.1 * np.random.randn(200)  # y = x1² + x2² + noise

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train ICNN
model = JAXICNNRegressor(
    hidden_dims=(32, 32),
    epochs=500,
    random_state=42,
    verbose=True,
)
model.fit(X_train, y_train)

# Evaluate
r2 = model.score(X_test, y_test)
print(f"\nR² score on test set: {r2:.4f}")

In [None]:
# Visualize learning curve
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(model.loss_history_)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss')
plt.yscale('log')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
y_pred = model.predict(X_test)
plt.scatter(y_test, y_pred, alpha=0.6)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
plt.xlabel('True')
plt.ylabel('Predicted')
plt.title(f'Predictions (R² = {r2:.3f})')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 2. Gradient Computation

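Because the ICNN is an explicit, differentiable function of its inputs, automatic differentiation gives exact gradients of the surrogate (no finite-difference approximation). The `predict_gradient` method returns ∂f/∂x for each sample, one row per input row.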
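Autodiff gradients can be sanity-checked against central finite differences, $\partial f/\partial x_j \approx [f(x + h e_j) - f(x - h e_j)]/(2h)$. The helper below is a hypothetical utility, not part of pycse; it accepts any function that maps an `(n, d)` array to `(n,)`, such as a fitted model's `predict`.

```python
import numpy as np

def finite_diff_grad(f, X, h=1e-5):
    """Central-difference estimate of df/dx for each row of X.

    f must map an (n, d) array to an (n,) array, like a regressor's predict.
    """
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    G = np.empty((n, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = h  # perturb only coordinate j
        G[:, j] = (f(X + step) - f(X - step)) / (2 * h)
    return G

def sq(X):
    # The convex test function from Section 1: y = x1^2 + x2^2
    return np.sum(X**2, axis=1)

X = np.array([[1.0, -2.0], [0.5, 3.0]])
print(finite_diff_grad(sq, X))  # close to the exact gradient 2 * X
```

For example, `finite_diff_grad(model.predict, X_test[:5], h=1e-3)` should roughly match `model.predict_gradient(X_test[:5])`. The larger `h` is an assumption for the case where the model computes in float32 (JAX's default precision), where very small steps are dominated by rounding error.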

In [None]:
# Compute gradients
grads = model.predict_gradient(X_test[:5])

print("Gradients for first 5 test samples:")
print(f"  Shape: {grads.shape}")
print()
for i in range(5):
    print(f"  Sample {i}: x = [{X_test[i, 0]:6.3f}, {X_test[i, 1]:6.3f}]")
    print(f"            grad = [{grads[i, 0]:6.3f}, {grads[i, 1]:6.3f}]")
    # True gradient for y = x1² + x2² is [2*x1, 2*x2]
    true_grad = 2 * X_test[i]
    print(f"            true = [{true_grad[0]:6.3f}, {true_grad[1]:6.3f}]")
    print()

In [None]:
# Combined prediction and gradient
y_pred, grads = model.predict_with_grad(X_test[:3])
print("predict_with_grad output:")
print(f"  Predictions: {y_pred}")
print(f"  Gradients shape: {grads.shape}")

## 3. Verifying Convexity

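A key property of ICNNs is that the output is guaranteed convex in the inputs. Let's spot-check this with the midpoint convexity condition, which a convex function satisfies for all pairs of points $x$ and $y$:

$$f\left(\frac{x + y}{2}\right) \leq \frac{f(x) + f(y)}{2}$$

Checking 100 random pairs cannot prove convexity, but a violation beyond floating-point tolerance would indicate a bug.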
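The midpoint test is a special case of Jensen's inequality, $f(\lambda x + (1-\lambda) y) \le \lambda f(x) + (1-\lambda) f(y)$ for all $\lambda \in [0, 1]$. This sketch (a hypothetical helper in plain NumPy) samples $\lambda$ as well, and runs on a known convex and a known non-convex function to confirm that the check can actually fail.

```python
import numpy as np

def jensen_violations(f, n_pairs=1000, dim=2, scale=3.0, tol=1e-8, seed=0):
    """Count sampled violations of Jensen's inequality for f.

    f maps an (n, dim) array to an (n,) array. A violation is
    f(lam*x + (1-lam)*y) > lam*f(x) + (1-lam)*f(y) beyond a relative tolerance.
    """
    rng = np.random.default_rng(seed)
    X1 = rng.normal(scale=scale, size=(n_pairs, dim))
    X2 = rng.normal(scale=scale, size=(n_pairs, dim))
    lam = rng.uniform(0.0, 1.0, size=(n_pairs, 1))
    lhs = f(lam * X1 + (1 - lam) * X2)
    rhs = lam[:, 0] * f(X1) + (1 - lam[:, 0]) * f(X2)
    return int(np.sum(lhs > rhs + tol * (1 + np.abs(rhs))))

def convex_fn(X):
    return np.sum(X**2, axis=1)

def nonconvex_fn(X):
    return np.sum(np.cos(X), axis=1)

print("convex:", jensen_violations(convex_fn))
print("non-convex:", jensen_violations(nonconvex_fn))
```

It can be pointed at the ICNN as `jensen_violations(model.predict, tol=1e-5)`; the looser tolerance allows for float32 rounding.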

In [None]:
# Verify convexity: f(midpoint) <= average(f(x), f(y))
np.random.seed(123)
n_pairs = 100
x1 = np.random.randn(n_pairs, 2) * 3
x2 = np.random.randn(n_pairs, 2) * 3
x_mid = (x1 + x2) / 2

f_x1 = model.predict(x1)
f_x2 = model.predict(x2)
f_mid = model.predict(x_mid)

# Check: f(mid) <= (f(x1) + f(x2))/2
avg = (f_x1 + f_x2) / 2
violations = np.sum(f_mid > avg + 1e-5)  # Allow small tolerance

print(f"Convexity check: {n_pairs} random pairs")
print(f"  Violations: {violations}/{n_pairs}")
print(f"  Max difference (f_mid - avg): {np.max(f_mid - avg):.6f}")
print(f"  Convexity verified: {'✓' if violations == 0 else '✗'}")

In [None]:
# Visualize convexity along a line segment
plt.figure(figsize=(10, 4))

# Two fixed endpoints
p1 = np.array([-2, -1.5])
p2 = np.array([2, 1.5])

# Sample points along the line
t = np.linspace(0, 1, 50)
line_points = np.outer(1 - t, p1) + np.outer(t, p2)

# Evaluate ICNN and true function along line
f_line = model.predict(line_points)
f_true = np.sum(line_points**2, axis=1)

# Linear interpolation (chord)
f_chord = (1 - t) * model.predict(p1.reshape(1, -1))[0] + t * model.predict(p2.reshape(1, -1))[0]

plt.subplot(1, 2, 1)
plt.plot(t, f_line, 'b-', linewidth=2, label='ICNN')
plt.plot(t, f_chord, 'r--', linewidth=2, label='Chord')
plt.fill_between(t, f_line, f_chord, alpha=0.3, color='green', 
                 where=f_line <= f_chord, label='Convexity gap')
plt.xlabel('t (interpolation)')
plt.ylabel('f(p1 + t*(p2-p1))')
plt.title('Convexity: Function Below Chord')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(t, f_line, 'b-', linewidth=2, label='ICNN')
plt.plot(t, f_true, 'g--', linewidth=2, label='True (x₁²+x₂²)')
plt.xlabel('t')
plt.ylabel('f')
plt.title('ICNN vs True Function')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Global Optimization with Non-Convex True Function

Now let's demonstrate the main use case: using an ICNN as a convex surrogate for a **non-convex** function.

We'll use the **Rastrigin function**, a famous non-convex test function with many local minima:

$$f(x) = 10n + \sum_{i=1}^{n} \left[ x_i^2 - 10\cos(2\pi x_i) \right]$$

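The ICNN will learn a convex approximation of this function, which makes optimizing the surrogate easy. The surrogate cannot reproduce the oscillations; it captures only the bowl-shaped trend.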
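Splitting the Rastrigin function into a convex bowl plus a bounded oscillation shows what a convex surrogate can and cannot capture:

$$f(x) = \underbrace{\sum_{i=1}^{n} x_i^2}_{\text{convex}} + \underbrace{10\sum_{i=1}^{n}\bigl(1 - \cos(2\pi x_i)\bigr)}_{\in\,[0,\;20n]}$$

so $\sum_i x_i^2 \le f(x) \le \sum_i x_i^2 + 20n$. The data therefore lie in a band of width $20n$ above a bowl centered at the origin, which is also Rastrigin's global minimizer; a convex fit is expected to resemble that bowl shifted by a constant. This alignment is why the experiment below works well. If the true global minimum sat away from the bottom of the convex trend, the surrogate's optimum could point to the wrong region.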

In [None]:
# Define the Rastrigin function (non-convex with many local minima)
def rastrigin(X):
    """Rastrigin function - highly non-convex.
    
    Global minimum: f(0, 0) = 0
    Many local minima throughout the domain.
    """
    A = 10
    n = X.shape[1]
    return A * n + np.sum(X**2 - A * np.cos(2 * np.pi * X), axis=1)


# Visualize the Rastrigin function
x_grid = np.linspace(-3, 3, 100)
y_grid = np.linspace(-3, 3, 100)
X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
XY_flat = np.column_stack([X_grid.ravel(), Y_grid.ravel()])
Z_rastrigin = rastrigin(XY_flat).reshape(X_grid.shape)

fig = plt.figure(figsize=(12, 5))

ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(X_grid, Y_grid, Z_rastrigin, cmap='viridis', alpha=0.8)
ax1.set_xlabel('x₁')
ax1.set_ylabel('x₂')
ax1.set_zlabel('f(x)')
ax1.set_title('Rastrigin Function (Non-Convex)')

ax2 = fig.add_subplot(122)
cs = ax2.contour(X_grid, Y_grid, Z_rastrigin, levels=20, cmap='viridis')
ax2.clabel(cs, inline=True, fontsize=8)
ax2.scatter([0], [0], color='red', marker='*', s=200, label='Global min')
ax2.set_xlabel('x₁')
ax2.set_ylabel('x₂')
ax2.set_title('Rastrigin Contours (Many Local Minima)')
ax2.legend()

plt.tight_layout()
plt.show()

In [None]:
# Generate training data from the non-convex function
np.random.seed(42)
n_samples = 500
X_train_nc = np.random.uniform(-3, 3, (n_samples, 2))
y_train_nc = rastrigin(X_train_nc)

print(f"Training data: {n_samples} samples")
print(f"Input range: [-3, 3] × [-3, 3]")
print(f"Output range: [{y_train_nc.min():.1f}, {y_train_nc.max():.1f}]")
print(f"True global minimum: f(0, 0) = 0")

In [None]:
# Train ICNN as convex surrogate
surrogate = JAXICNNRegressor(
    hidden_dims=(64, 64),
    epochs=1000,
    learning_rate=5e-3,
    random_state=42,
    verbose=True,
)
surrogate.fit(X_train_nc, y_train_nc)

# Test accuracy
X_test_nc = np.random.uniform(-3, 3, (100, 2))
y_test_nc = rastrigin(X_test_nc)
r2 = surrogate.score(X_test_nc, y_test_nc)
print(f"\nSurrogate R² on held-out data: {r2:.4f}")

In [None]:
# Visualize the ICNN surrogate vs true function
Z_surrogate = surrogate.predict(XY_flat).reshape(X_grid.shape)

fig = plt.figure(figsize=(14, 5))

ax1 = fig.add_subplot(131, projection='3d')
ax1.plot_surface(X_grid, Y_grid, Z_rastrigin, cmap='viridis', alpha=0.8)
ax1.set_xlabel('x₁')
ax1.set_ylabel('x₂')
ax1.set_zlabel('f(x)')
ax1.set_title('True Function (Non-Convex)')

ax2 = fig.add_subplot(132, projection='3d')
ax2.plot_surface(X_grid, Y_grid, Z_surrogate, cmap='plasma', alpha=0.8)
ax2.set_xlabel('x₁')
ax2.set_ylabel('x₂')
ax2.set_zlabel('f(x)')
ax2.set_title('ICNN Surrogate (Convex)')

ax3 = fig.add_subplot(133)
cs1 = ax3.contour(X_grid, Y_grid, Z_rastrigin, levels=15, cmap='viridis', alpha=0.5)
cs2 = ax3.contour(X_grid, Y_grid, Z_surrogate, levels=15, cmap='plasma', linestyles='--')
ax3.scatter([0], [0], color='red', marker='*', s=200, zorder=10, label='True global min')
ax3.set_xlabel('x₁')
ax3.set_ylabel('x₂')
ax3.set_title('Contour Comparison')
ax3.legend()

plt.tight_layout()
plt.show()

## 5. Finding the Global Minimum

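Because the ICNN surrogate is convex and the box $[-3, 3]^2$ is convex, any local minimum a gradient-based optimizer finds is a global minimum of the surrogate. Here we use L-BFGS-B with the ICNN's exact gradients. This is much easier than optimizing the original non-convex function!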
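Plain gradient descent also works on a box if each step is followed by a projection back onto the box (projected gradient descent); for a box, the projection is elementwise clipping. The sketch below (plain NumPy, hypothetical helper) minimizes a convex quadratic whose unconstrained minimum $(4, -1)$ lies outside $[-3, 3]^2$; every start reaches the same constrained minimizer, $(3, -1)$.

```python
import numpy as np

def projected_gradient_descent(grad, x0, lower, upper, step, n_iter=2000):
    """Minimize a smooth convex function over the box [lower, upper].

    Each iteration takes a gradient step, then projects back onto the box.
    """
    x = np.clip(np.asarray(x0, dtype=float), lower, upper)
    for _ in range(n_iter):
        x = np.clip(x - step * grad(x), lower, upper)
    return x

center = np.array([4.0, -1.0])  # unconstrained minimizer, outside the box

def grad(x):
    # Gradient of the convex quadratic ||x - center||^2
    return 2.0 * (x - center)

lower, upper = np.array([-3.0, -3.0]), np.array([3.0, 3.0])

for x0 in ([-3.0, -3.0], [0.0, 2.0], [3.0, 3.0]):
    print(x0, "->", projected_gradient_descent(grad, x0, lower, upper, step=0.1))
```

With a step size of at most $1/L$ ($L$ is the gradient's Lipschitz constant, here 2), this converges to the constrained minimum for any smooth convex objective. To apply it to the surrogate, one could pass `grad=lambda x: model.predict_gradient(x[None, :])[0]`; a suitable step size for a trained ICNN is not known in advance and would have to be tuned.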

In [None]:
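def optimize_icnn(model, bounds, n_starts=1):
    """
    Find the global minimum of an ICNN inside box bounds with L-BFGS-B.
    
    Since the ICNN is convex, any local minimum is a global minimum, so
    n_starts > 1 only serves as a sanity check: all starts should reach
    the same minimum value (up to solver tolerance).
    """
    def objective(x):
        x = x.reshape(1, -1)
        # Cast to a Python float: the model may compute in float32
        return float(model.predict(x)[0])
    
    def gradient(x):
        x = x.reshape(1, -1)
        # L-BFGS-B works in float64, so cast the autodiff gradient
        return np.asarray(model.predict_gradient(x)[0], dtype=np.float64)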
    
    best_result = None
    for _ in range(n_starts):
        # Random starting point
        x0 = np.random.uniform(bounds[:, 0], bounds[:, 1])
        
        result = minimize(
            objective,
            x0,
            method='L-BFGS-B',
            jac=gradient,
            bounds=bounds,
        )
        
        if best_result is None or result.fun < best_result.fun:
            best_result = result
    
    return best_result


# Optimize the ICNN surrogate
bounds = np.array([[-3, 3], [-3, 3]])
result = optimize_icnn(surrogate, bounds, n_starts=3)

x_opt_surrogate = result.x
f_opt_surrogate = result.fun

print("Optimization of ICNN Surrogate:")
print(f"  Optimal x: [{x_opt_surrogate[0]:.4f}, {x_opt_surrogate[1]:.4f}]")
print(f"  Surrogate value: {f_opt_surrogate:.4f}")
print(f"  True function at optimal x: {rastrigin(x_opt_surrogate.reshape(1, -1))[0]:.4f}")
print()
print("True Global Minimum:")
print("  Optimal x: [0.0, 0.0]")
print(f"  True value: {rastrigin(np.array([[0, 0]]))[0]:.4f}")

In [None]:
# Compare with L-BFGS-B applied directly to the true (non-convex) function
def rastrigin_single(x):
    return rastrigin(x.reshape(1, -1))[0]

def rastrigin_grad(x):
    """Gradient of Rastrigin function."""
    return 2 * x + 20 * np.pi * np.sin(2 * np.pi * x)

# Try multiple random starts
np.random.seed(123)
n_starts = 10
results_true = []

for _ in range(n_starts):
    x0 = np.random.uniform(-3, 3, 2)
    result = minimize(
        rastrigin_single,
        x0,
        method='L-BFGS-B',
        jac=rastrigin_grad,
        bounds=bounds,
    )
    results_true.append(result)

# Find best result
best_true = min(results_true, key=lambda r: r.fun)

print(f"L-BFGS-B on TRUE Rastrigin function ({n_starts} random starts):")
print()
for i, r in enumerate(results_true):
    print(f"  Start {i+1}: x = [{r.x[0]:6.3f}, {r.x[1]:6.3f}], f = {r.fun:6.3f}")
print()
print(f"Best found: f = {best_true.fun:.4f} at x = [{best_true.x[0]:.4f}, {best_true.x[1]:.4f}]")
print()
print("Note: Many runs get stuck in local minima!")
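Why do so many runs stall? Because the Rastrigin function is a sum of identical one-dimensional terms, its local minima in the box are exactly the pairs of one-dimensional local minima, so they can be counted on a fine 1-D grid. This NumPy sketch uses the same constant $A = 10$ as `rastrigin` above.

```python
import numpy as np

def rastrigin_1d(x, A=10.0):
    # One coordinate of the Rastrigin sum (the n-D function is separable)
    return A + x**2 - A * np.cos(2 * np.pi * x)

x = np.linspace(-3.0, 3.0, 60001)
f = rastrigin_1d(x)

# Interior grid points strictly lower than both neighbours approximate local minima
is_min = (f[1:-1] < f[:-2]) & (f[1:-1] < f[2:])
minima_1d = x[1:-1][is_min]

print("1-D local minima:", np.round(minima_1d, 3))
print("local minima in [-3, 3]^2:", len(minima_1d) ** 2)
```

Seven minima per coordinate give $7^2 = 49$ local minima in $[-3, 3]^2$, only one of which is global, so a handful of local searches from random starts will usually miss it.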

In [None]:
# Visualize optimization paths
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# True function contours
for ax in axes:
    ax.contour(X_grid, Y_grid, Z_rastrigin, levels=20, cmap='viridis', alpha=0.5)

# Left: Local minima from direct optimization
for r in results_true:
    axes[0].scatter(r.x[0], r.x[1], c='blue', s=100, marker='x')
axes[0].scatter(0, 0, c='red', s=200, marker='*', zorder=10, label='Global min')
axes[0].set_xlabel('x₁')
axes[0].set_ylabel('x₂')
axes[0].set_title(f'Direct Optimization: {n_starts} starts\n(Blue X = local minima found)')
axes[0].legend()

# Right: ICNN surrogate optimum
axes[1].contour(X_grid, Y_grid, Z_surrogate, levels=20, cmap='plasma', alpha=0.5, linestyles='--')
axes[1].scatter(x_opt_surrogate[0], x_opt_surrogate[1], c='orange', s=200, marker='D', 
               edgecolor='black', linewidth=2, label='Surrogate optimum')
axes[1].scatter(0, 0, c='red', s=200, marker='*', zorder=10, label='True global min')
axes[1].set_xlabel('x₁')
axes[1].set_ylabel('x₂')
axes[1].set_title('ICNN Surrogate Optimization\n(Convex → any local min is global)')
axes[1].legend()

plt.tight_layout()
plt.show()

## 6. Strong Convexity for Unique Minimum

Setting `strong_convexity_mu` to μ > 0 adds a quadratic term that makes the model μ-strongly convex, so its minimizer is **unique**:

$$f_{\mu}(x) = f(x) + \frac{\mu}{2} \|x\|^2$$

This is useful when:
- You want guaranteed uniqueness of the solution
- The optimization algorithm benefits from strong convexity (for smooth objectives, gradient descent converges linearly instead of sublinearly)
- You want to regularize the solution toward the origin
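The uniqueness claim can be seen on a toy convex function with a flat valley, $f(x) = (x_1 - x_2)^2$, whose minimizers fill the line $x_1 = x_2$. This NumPy sketch (not the ICNN) runs gradient descent on $f_\mu(x) = f(x) + \frac{\mu}{2}\|x\|^2$ from two starts.

```python
import numpy as np

def grad_f(x, mu):
    # Gradient of f_mu(x) = (x1 - x2)^2 + (mu / 2) * ||x||^2.
    # With mu = 0, every point on the line x1 = x2 is a minimizer.
    r = x[0] - x[1]
    return 2.0 * r * np.array([1.0, -1.0]) + mu * x

def descend(x0, mu, step=0.1, n_iter=1000):
    x = np.array(x0, dtype=float)
    for _ in range(n_iter):
        x = x - step * grad_f(x, mu)
    return x

for mu in (0.0, 0.5):
    ends = [descend(x0, mu) for x0 in ([2.0, 0.0], [0.0, -3.0])]
    print(f"mu={mu}:", [np.round(e, 6) for e in ends])
```

With $\mu = 0$ the runs stop at different points of the valley, $(1, 1)$ and $(-1.5, -1.5)$; with $\mu = 0.5$ both reach the unique minimizer at the origin. Along the valley the iterates contract by a factor $1 - 0.1\mu$ per step, which is the linear convergence that strong convexity provides.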

In [None]:
# Train ICNN with strong convexity
mu = 0.5

surrogate_sc = JAXICNNRegressor(
    hidden_dims=(64, 64),
    epochs=1000,
    strong_convexity_mu=mu,
    random_state=42,
)
surrogate_sc.fit(X_train_nc, y_train_nc)

print(f"Trained ICNN with strong convexity μ = {mu}")
print(f"R² score: {surrogate_sc.score(X_test_nc, y_test_nc):.4f}")

In [None]:
# Compare optimization with and without strong convexity
result_base = optimize_icnn(surrogate, bounds, n_starts=1)
result_sc = optimize_icnn(surrogate_sc, bounds, n_starts=1)

print("Optimization Results:")
print()
print("Without strong convexity (μ=0):")
print(f"  Optimal x: [{result_base.x[0]:.4f}, {result_base.x[1]:.4f}]")
print(f"  True function value: {rastrigin(result_base.x.reshape(1,-1))[0]:.4f}")
print()
print(f"With strong convexity (μ={mu}):")
print(f"  Optimal x: [{result_sc.x[0]:.4f}, {result_sc.x[1]:.4f}]")
print(f"  True function value: {rastrigin(result_sc.x.reshape(1,-1))[0]:.4f}")
print()
print("Note: the quadratic term may pull the optimum toward the origin")

In [None]:
# Visualize the effect of strong convexity
Z_base = surrogate.predict(XY_flat).reshape(X_grid.shape)
Z_sc = surrogate_sc.predict(XY_flat).reshape(X_grid.shape)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Without strong convexity
cs1 = axes[0].contour(X_grid, Y_grid, Z_base, levels=20, cmap='plasma')
axes[0].scatter(result_base.x[0], result_base.x[1], c='red', s=200, marker='*', 
               edgecolor='black', linewidth=2, label='Optimum')
axes[0].set_xlabel('x₁')
axes[0].set_ylabel('x₂')
axes[0].set_title('ICNN (μ=0)')
axes[0].legend()

# With strong convexity
cs2 = axes[1].contour(X_grid, Y_grid, Z_sc, levels=20, cmap='plasma')
axes[1].scatter(result_sc.x[0], result_sc.x[1], c='red', s=200, marker='*', 
               edgecolor='black', linewidth=2, label='Optimum')
axes[1].set_xlabel('x₁')
axes[1].set_ylabel('x₂')
axes[1].set_title(f'ICNN (μ={mu})')
axes[1].legend()

plt.tight_layout()
plt.show()

## 7. sklearn Compatibility

JAXICNNRegressor follows sklearn's estimator API, so it works with pipelines, cross-validation, and grid search.
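Convexity of a pipeline in the raw inputs relies on `StandardScaler` being an affine map: if $f$ is convex, so is $x \mapsto f(Ax + b)$. A nonlinear preprocessing step (squared or polynomial features, for instance) can destroy convexity even when the model is convex in its own inputs. The NumPy sketch below checks this with a toy convex function standing in for the ICNN; the helper names are hypothetical.

```python
import numpy as np

def midpoint_violations(g, X1, X2, tol=1e-10):
    """Count sampled violations of g((x + y) / 2) <= (g(x) + g(y)) / 2."""
    lhs = g((X1 + X2) / 2)
    rhs = (g(X1) + g(X2)) / 2
    return int(np.sum(lhs > rhs + tol * (1 + np.abs(rhs))))

def model_on_features(U):
    # Toy convex function of the transformed features (stand-in for the ICNN)
    return np.sum(np.exp(-U), axis=1)

rng = np.random.default_rng(0)
X_raw = rng.uniform(-2, 2, size=(4000, 2))
X1, X2 = X_raw[:2000], X_raw[2000:]
mean, std = X_raw.mean(axis=0), X_raw.std(axis=0)

def standardize(X):
    # Affine map of the same form as StandardScaler: (x - mean) / std
    return (X - mean) / std

def square(X):
    # Nonlinear feature map (like part of a polynomial expansion)
    return X**2

n_affine = midpoint_violations(lambda X: model_on_features(standardize(X)), X1, X2)
n_square = midpoint_violations(lambda X: model_on_features(square(X)), X1, X2)
print("violations with affine features: ", n_affine)
print("violations with squared features:", n_square)
```

The affine version shows no violations, while the squared version does, because $e^{-x^2}$ is concave near $x = 0$.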

In [None]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, GridSearchCV

# Create a pipeline
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('icnn', JAXICNNRegressor(epochs=100, standardize_X=False, random_state=42)),  # scaler already standardizes X
])

# Cross-validation
scores = cross_val_score(pipe, X_train_nc, y_train_nc, cv=3, scoring='r2')
print(f"Cross-validation R² scores: {scores}")
print(f"Mean R²: {scores.mean():.4f} (±{scores.std():.4f})")

In [None]:
# Grid search for hyperparameter tuning
param_grid = {
    'hidden_dims': [(32,), (32, 32), (64, 32)],
    'learning_rate': [1e-3, 5e-3],
}

model_gs = JAXICNNRegressor(epochs=100, random_state=42)
gs = GridSearchCV(model_gs, param_grid, cv=2, scoring='r2', verbose=1)
gs.fit(X_train_nc, y_train_nc)

print(f"\nBest parameters: {gs.best_params_}")
print(f"Best R² score: {gs.best_score_:.4f}")

## Summary

**JAXICNNRegressor** provides a sklearn-compatible implementation of Input Convex Neural Networks with:

### Key Features:
- **Guaranteed convexity**: Output is always convex in inputs
- **Exact gradients**: Efficient gradient computation via JAX autodiff
- **Strong convexity option**: Add (μ/2)||x||² for a unique minimum
- **sklearn API**: Works with pipelines, cross-validation, etc.

### Use Cases:
1. **Surrogate-based optimization**: Learn convex approximation of complex/expensive functions
2. **Constrained optimization**: Convex objectives are easier to handle with constraints
3. **Interpretable models**: Convexity is a hard guarantee on the model's shape (no spurious local minima, convex sublevel sets)

### Parameters:
- `hidden_dims`: Network architecture (default: (32, 32))
- `activation`: "softplus" or "relu" (default: softplus)
- `learning_rate`: Adam learning rate (default: 5e-3)
- `epochs`: Training epochs (default: 500)
- `strong_convexity_mu`: Strong convexity parameter (default: 0.0)
- `standardize_X/standardize_y`: Input/output standardization

### Methods:
- `fit(X, y)`: Train the model
- `predict(X)`: Make predictions
- `predict_gradient(X)`: Compute ∂f/∂x
- `predict_with_grad(X)`: Get (predictions, gradients) together
- `score(X, y)`: Compute R² score