# Constraints

This guide shows how to handle constraints in nonlinear least-squares problems with jaxls.
It covers equality and inequality constraints, using a portfolio optimization problem as the running example.

Features used:
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` with `kind` parameter for constraints
- {class}`~jaxls.AugmentedLagrangianConfig` for solver tuning

In [None]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [None]:
import jax
import jax.numpy as jnp
import jaxls

## Types of constraints

jaxls supports three constraint types, specified via the `kind` parameter in `@jaxls.Cost.factory`:

| Constraint Type | `kind` Parameter | Mathematical Form |
|----------------|------------------|-------------------|
| Equality | `"constraint_eq_zero"` | $h(x) = 0$ |
| Inequality (upper bound) | `"constraint_leq_zero"` | $g(x) \leq 0$ |
| Inequality (lower bound) | `"constraint_geq_zero"` | $g(x) \geq 0$ |

The default `kind="l2_squared"` creates a standard least-squares cost term.
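
The three constraint kinds differ only in which sign of the residual counts as a violation. The NumPy sketch below spells out that convention; `violation` is a hypothetical helper written for illustration and is not part of the jaxls API.

```python
import numpy as np


def violation(residual: np.ndarray, kind: str) -> np.ndarray:
    """Per-component constraint violation (0 where the constraint holds).

    Hypothetical helper for illustration; not a jaxls function.
    """
    if kind == "constraint_eq_zero":
        return np.abs(residual)  # h(x) = 0: any nonzero value violates.
    if kind == "constraint_leq_zero":
        return np.maximum(0.0, residual)  # g(x) <= 0: the positive part violates.
    if kind == "constraint_geq_zero":
        return np.maximum(0.0, -residual)  # g(x) >= 0: the negative part violates.
    raise ValueError(f"unknown constraint kind: {kind}")


residual = np.array([-0.5, 0.0, 0.25])
for kind in ("constraint_eq_zero", "constraint_leq_zero", "constraint_geq_zero"):
    print(kind, violation(residual, kind))
```

A `constraint_geq_zero` residual $g(x)$ describes the same feasible set as a `constraint_leq_zero` residual $-g(x)$, so either kind can express any one-sided bound.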

## Example: portfolio optimization

We'll optimize a portfolio of 4 assets to minimize variance (risk) subject to:
- Budget constraint: weights sum to 1 (equality)
- Return target: expected return >= minimum threshold (inequality)
- No short-selling: all weights >= 0 (inequality)

In [None]:
# Asset data: 4 assets with expected returns and covariance.
n_assets = 4
asset_names = ["Tech", "Healthcare", "Energy", "Bonds"]

# Expected annual returns.
expected_returns = jnp.array([0.12, 0.08, 0.10, 0.04])

# Covariance matrix (annual).
covariance = jnp.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)

print(
    "Expected returns:", {n: f"{r:.1%}" for n, r in zip(asset_names, expected_returns)}
)

In [None]:
# Define the portfolio weights variable.
class WeightsVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(n_assets) / n_assets
):
    """Portfolio weights (n_assets-dimensional vector)."""


weights_var = WeightsVar(id=0)

## Defining costs and constraints

### Objective: minimize variance

Portfolio variance is $w^T \Sigma w$. With the Cholesky factor $L$, where $\Sigma = LL^T$, we have $w^T \Sigma w = w^T L L^T w = \|L^T w\|^2$,
so using $L^T w$ as the residual makes the least-squares cost equal to the variance.
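
The identity can be checked numerically. This is a small NumPy sketch (plain NumPy rather than `jax.numpy`, so it runs on its own) that reuses the covariance matrix defined earlier and the equal-weight default of `WeightsVar`:

```python
import numpy as np

# Covariance matrix from the asset data above.
covariance = np.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
w = np.full(4, 0.25)  # Equal weights.

chol = np.linalg.cholesky(covariance)  # Lower-triangular L with cov = L @ L.T.
residual = chol.T @ w  # Residual vector whose squared norm is the variance.

variance_direct = w @ covariance @ w
variance_from_residual = residual @ residual
print(variance_direct, variance_from_residual)
```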

In [None]:
# Cholesky decomposition for the variance cost.
cov_chol = jnp.linalg.cholesky(covariance)


@jaxls.Cost.factory  # Default kind="l2_squared".
def variance_cost(
    vals: jaxls.VarValues, var: WeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
    return cov_chol.T @ vals[var]

### Budget constraint (equality)

The weights must sum to 1 (fully invested portfolio): $\sum_i w_i = 1$.

We write this as $h(w) = \sum_i w_i - 1 = 0$.

In [None]:
@jaxls.Cost.factory(kind="constraint_eq_zero")
def budget_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """Weights must sum to 1 (fully invested)."""
    weights = vals[var]
    return jnp.array([jnp.sum(weights) - 1.0])

### Return target (inequality >= 0)

Expected portfolio return must meet a minimum target: $w^T \mu \geq r_{\text{target}}$.

We write this as $g(w) = w^T \mu - r_{\text{target}} \geq 0$.

In [None]:
@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: WeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Expected return must meet target: E[r] >= target."""
    weights = vals[var]
    return jnp.array([jnp.dot(weights, exp_ret) - target])

### No short-selling (inequality >= 0)

All weights must be non-negative: $w_i \geq 0$ for all $i$.

The constraint function returns the weight vector itself, so each component is constrained to be non-negative.

In [None]:
@jaxls.Cost.factory(kind="constraint_geq_zero")
def no_short_selling(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """No short-selling: weights >= 0."""
    return vals[var]

## Solving the problem

We'll solve for a target return of 8% (between the lowest-return Bonds at 4% and highest-return Tech at 12%).
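
Before solving, it can help to know whether the starting point is feasible. The sketch below evaluates the three constraint expressions at the equal-weight default of `WeightsVar`, using plain NumPy; the variable names are chosen for this illustration only.

```python
import numpy as np

expected_returns = np.array([0.12, 0.08, 0.10, 0.04])
target_return = 0.08
w0 = np.full(4, 0.25)  # Equal weights, the default value of WeightsVar.

budget_residual = w0.sum() - 1.0  # Must equal 0.
return_slack = w0 @ expected_returns - target_return  # Must be >= 0.
min_weight = w0.min()  # Must be >= 0.

print(f"budget residual: {budget_residual:+.4f}")
print(f"return slack:    {return_slack:+.4f}")
print(f"smallest weight: {min_weight:+.4f}")
```

The equal-weight portfolio has an expected return of 8.5%, so it already satisfies all three constraints for an 8% target; the solver's job is to reduce variance while staying feasible.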

In [None]:
target_return = 0.08

costs = [
    variance_cost(weights_var, cov_chol),
    budget_constraint(weights_var),
    return_constraint(weights_var, expected_returns, target_return),
    no_short_selling(weights_var),
]

problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
solution = problem.solve(
    linear_solver="dense_cholesky",
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)

In [None]:
# Extract results.
optimal_weights = solution[weights_var]

print("\nOptimal allocation:")
for name, w in zip(asset_names, optimal_weights):
    print(f"  {name}: {float(w):.1%}")

portfolio_return = float(jnp.dot(optimal_weights, expected_returns))
portfolio_std = float(jnp.sqrt(optimal_weights @ covariance @ optimal_weights))

print("\nPortfolio metrics:")
print(f"  Expected return: {portfolio_return:.2%}")
print(f"  Std deviation: {portfolio_std:.2%}")
print(f"  Weights sum: {float(jnp.sum(optimal_weights)):.6f}")
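
One way to sanity-check the result is to solve a simplified version of the problem in closed form. The NumPy sketch below drops the no-short-selling constraint and treats the return target as an equality. Both simplifications are assumptions made for this check, and they turn the problem into a linear KKT system.

```python
import numpy as np

expected_returns = np.array([0.12, 0.08, 0.10, 0.04])
covariance = np.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
target_return = 0.08

# Equality constraints A @ w = b: budget (sum to 1) and return target.
A = np.vstack([np.ones(4), expected_returns])
b = np.array([1.0, target_return])

# KKT system for: minimize w.T @ cov @ w  subject to  A @ w = b.
#   [2 * cov  A.T] [w ]   [0]
#   [A        0  ] [nu] = [b]
kkt = np.block([[2.0 * covariance, A.T], [A, np.zeros((2, 2))]])
rhs = np.concatenate([np.zeros(4), b])
sol = np.linalg.solve(kkt, rhs)
w_ref, nu = sol[:4], sol[4:]

print("Reference weights:", w_ref)
print("Weights sum:", w_ref.sum())
print("Expected return:", w_ref @ expected_returns)
print("All weights non-negative:", bool(np.all(w_ref >= 0.0)))
```

The reference weights are comparable to the jaxls result only if they are all non-negative and the return constraint is active at the optimum. Otherwise the two problems differ and so will their solutions.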

## Efficient frontier

By varying the target return, we can trace out the efficient frontier.
Using `jax.vmap`, we solve for all target returns in parallel.

In [None]:
# Range of target returns.
min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return + 0.005, max_return - 0.005, 15)


def solve_for_target(target: jax.Array) -> jax.Array:
    """Solve portfolio optimization for a given target return."""
    costs = [
        variance_cost(weights_var, cov_chol),
        budget_constraint(weights_var),
        return_constraint(weights_var, expected_returns, target),
        no_short_selling(weights_var),
    ]
    problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
    solution = problem.solve(
        verbose=False,
        linear_solver="dense_cholesky",
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution[weights_var]


# Use vmap to solve for all target returns in parallel.
all_weights = jax.vmap(solve_for_target)(target_returns)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)
std_devs = jax.vmap(lambda w: jnp.sqrt(w @ covariance @ w))(all_weights)

print(f"Computed {len(target_returns)} points on the efficient frontier")

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

colors = ["#2196F3", "#4CAF50", "#FF9800", "#9C27B0"]

# Convert JAX arrays to Python floats, scaled to percent, for Plotly.
std_devs_list = [float(s) * 100 for s in std_devs]
returns_list = [float(r) * 100 for r in returns_achieved]

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Efficient Frontier", "Asset Allocation"),
    column_widths=[0.4, 0.6],
)

# Left plot: Efficient frontier
fig.add_trace(
    go.Scatter(
        x=std_devs_list,
        y=returns_list,
        mode="lines+markers",
        marker=dict(size=6, color=returns_list, colorscale="Viridis", showscale=False),
        line=dict(color="steelblue", width=2),
        hovertemplate="Std Dev: %{x:.1f}%<br>Return: %{y:.1f}%<extra></extra>",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Right plot: Asset allocation (use target_returns for x-axis to avoid stacking
# bars at duplicate achieved returns when return constraint isn't binding)
for i, (name, color) in enumerate(zip(asset_names, colors)):
    fig.add_trace(
        go.Bar(
            x=[float(t) * 100 for t in target_returns],
            y=[float(w) * 100 for w in all_weights[:, i]],
            name=name,
            marker_color=color,
            hovertemplate=f"{name}: %{{y:.1f}}%<extra></extra>",
        ),
        row=1,
        col=2,
    )

fig.update_xaxes(title_text="Std Deviation (%)", row=1, col=1)
fig.update_yaxes(title_text="Return (%)", row=1, col=1)
fig.update_xaxes(title_text="Target Return (%)", row=1, col=2)
fig.update_yaxes(title_text="Allocation (%)", range=[0, 105], row=1, col=2)

fig.update_layout(
    barmode="stack",
    height=400,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.75),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

## Why not standard least squares?

Before explaining the augmented Lagrangian method, let's see what happens if we try to solve the portfolio problem using a standard least squares penalty approach. We convert constraints to squared costs and add them with a fixed penalty weight:

In [None]:
# Standard least squares approach: convert constraints to costs with a penalty weight.
penalty_weight = 10.0


@jaxls.Cost.factory  # Regular cost, NOT a constraint.
def budget_cost_penalty(
    vals: jaxls.VarValues, var: WeightsVar, weight: float
) -> jax.Array:
    """Penalize deviation from budget constraint: weight * (sum(w) - 1)."""
    weights = vals[var]
    return jnp.array([weight * (jnp.sum(weights) - 1.0)])


@jaxls.Cost.factory
def return_cost_penalty(
    vals: jaxls.VarValues,
    var: WeightsVar,
    exp_ret: jax.Array,
    target: float,
    weight: float,
) -> jax.Array:
    """Penalize return shortfall: weight * max(0, target - expected_return)."""
    weights = vals[var]
    shortfall = jnp.maximum(0.0, target - jnp.dot(weights, exp_ret))
    return jnp.array([weight * shortfall])


@jaxls.Cost.factory
def no_short_penalty(
    vals: jaxls.VarValues, var: WeightsVar, weight: float
) -> jax.Array:
    """Penalize negative weights: weight * max(0, -w)."""
    weights = vals[var]
    return weight * jnp.maximum(0.0, -weights)


# Solve with penalty approach.
penalty_costs = [
    variance_cost(weights_var, cov_chol),
    budget_cost_penalty(weights_var, penalty_weight),
    return_cost_penalty(weights_var, expected_returns, target_return, penalty_weight),
    no_short_penalty(weights_var, penalty_weight),
]

penalty_problem = jaxls.LeastSquaresProblem(penalty_costs, [weights_var]).analyze()
penalty_solution = penalty_problem.solve(verbose=False)
penalty_weights = penalty_solution[weights_var]

# Check constraint violations.
budget_violation = float(jnp.sum(penalty_weights) - 1.0)
return_achieved = float(jnp.dot(penalty_weights, expected_returns))

print("Penalty method result:")
print(f"  Weights sum: {float(jnp.sum(penalty_weights)):.4f} (should be 1.0)")
print(f"  Return: {return_achieved:.2%} (target: {target_return:.2%})")
print(f"  Budget violation: {abs(budget_violation):.4f}")

In [None]:
# Compare penalty method vs augmented Lagrangian.
fig_compare = go.Figure()

# Penalty method allocation.
fig_compare.add_trace(
    go.Bar(
        name="Penalty method",
        x=asset_names,
        y=[float(w) * 100 for w in penalty_weights],
        marker_color="lightcoral",
        text=[f"{float(w) * 100:.1f}%" for w in penalty_weights],
        textposition="outside",
    )
)

# Augmented Lagrangian allocation.
fig_compare.add_trace(
    go.Bar(
        name="Augmented Lagrangian",
        x=asset_names,
        y=[float(w) * 100 for w in optimal_weights],
        marker_color="steelblue",
        text=[f"{float(w) * 100:.1f}%" for w in optimal_weights],
        textposition="outside",
    )
)

# Weight sums for the title, showing how closely each method meets the budget constraint.
penalty_sum = float(jnp.sum(penalty_weights))
al_sum = float(jnp.sum(optimal_weights))

fig_compare.update_layout(
    title=f"Penalty (weights sum = {penalty_sum:.3f}) vs Augmented Lagrangian (weights sum = {al_sum:.6f})",
    barmode="group",
    yaxis_title="Weight (%)",
    height=350,
    margin=dict(t=60, b=40),
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
)

HTML(fig_compare.to_html(full_html=False, include_plotlyjs="cdn"))

The penalty method does not satisfy the budget constraint exactly: the weights do not sum to 1.0 because, with a finite penalty weight, the solver trades a small constraint violation for lower variance. Increasing the penalty weight shrinks the violation but makes the problem increasingly ill-conditioned.
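
The trade-off is easiest to see on a simplified problem that keeps only the variance objective and the budget penalty, so that the minimizer comes from one linear solve. The sketch below is plain NumPy; `mu` is a name chosen here for the weight on the squared budget residual. Under a sum-of-squares cost, the residual weight of 10 used above corresponds to `mu = 100`.

```python
import numpy as np

covariance = np.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
ones = np.ones(4)

# Budget-only penalty problem: minimize w.T @ cov @ w + mu * (sum(w) - 1)**2.
# Setting the gradient to zero gives the linear system
#   (cov + mu * ones ones.T) w = mu * ones.
results = []
for mu in (1.0, 1e2, 1e4):
    hessian = covariance + mu * np.outer(ones, ones)
    w = np.linalg.solve(hessian, mu * ones)
    budget_violation = abs(w.sum() - 1.0)
    condition_number = np.linalg.cond(hessian)
    results.append((mu, budget_violation, condition_number))
    print(
        f"mu={mu:8.0e}  |sum(w) - 1|={budget_violation:.2e}  "
        f"cond={condition_number:.2e}"
    )
```

The violation shrinks as `mu` grows but stays nonzero for any finite `mu`, while the condition number of the system matrix grows with `mu`.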

## How it works: Augmented Lagrangian

When constraints are present, jaxls solves the problem with an augmented Lagrangian method.

### The augmented Lagrangian

The method adds both a linear term (Lagrange multiplier $\lambda$) and a quadratic penalty to the objective:

$$\mathcal{L}(x, \lambda, \rho) = f(x) + \lambda \cdot h(x) + \frac{\rho}{2} h(x)^2$$

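At convergence, the multiplier term $\lambda$ alone enforces the constraint, so $\rho$ can remain finite; the quadratic penalty pushes the iterates toward feasibility while the multiplier estimate is still inaccurate.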

### Multiplier updates

As the solver runs, it updates the multipliers based on constraint violations:

$$\lambda_{\text{new}} = \lambda + \rho \cdot h(x)$$
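
The update can be seen at work on a problem small enough to solve each subproblem exactly. The NumPy sketch below minimizes portfolio variance subject only to the budget constraint, with a fixed $\rho$. It illustrates the textbook method under these assumptions and is not a description of how jaxls implements it.

```python
import numpy as np

covariance = np.array(
    [
        [0.04, 0.006, 0.010, -0.002],
        [0.006, 0.025, 0.004, 0.001],
        [0.010, 0.004, 0.035, -0.001],
        [-0.002, 0.001, -0.001, 0.005],
    ]
)
ones = np.ones(4)

rho = 1.0  # Fixed penalty; an illustrative choice.
lam = 0.0  # Multiplier estimate for h(w) = sum(w) - 1.

for iteration in range(6):
    # Subproblem: minimize w.T @ cov @ w + lam * h(w) + rho / 2 * h(w)**2.
    # It is quadratic, so its minimizer solves
    #   (2 * cov + rho * ones ones.T) w = (rho - lam) * ones.
    system = 2.0 * covariance + rho * np.outer(ones, ones)
    w = np.linalg.solve(system, (rho - lam) * ones)
    h = w.sum() - 1.0
    print(f"iter {iteration}: h(w) = {h:+.3e}, lambda = {lam:+.5f}")
    lam = lam + rho * h  # Multiplier update.

# Closed-form minimum-variance portfolio for comparison.
w_closed = np.linalg.solve(covariance, ones)
w_closed /= w_closed.sum()
print("max |w - w_closed|:", np.max(np.abs(w - w_closed)))
```

Unlike the pure penalty approach, the constraint violation goes to zero without increasing $\rho$, because the multiplier absorbs the force needed to hold the constraint.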

For inequality constraints $g(x) \leq 0$, the updated multiplier is projected so that it stays non-negative:
$\lambda_{\text{new}} = \max(0, \lambda + \rho \cdot g(x))$.

Updates occur when the cost stabilizes, indicating the current subproblem is solved.
The penalty $\rho$ increases if constraints aren't improving fast enough.
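
The projected update and the penalty increase can be sketched on a one-dimensional problem: minimize $(x - 2)^2$ subject to $x - 1 \leq 0$. Several parts of this sketch are assumptions: the subproblem uses the standard textbook form $f(x) + \frac{1}{2\rho}\left(\max(0, \lambda + \rho\, g(x))^2 - \lambda^2\right)$, the rule "increase $\rho$ unless the violation dropped below 25% of its previous value" is an illustrative threshold, and `penalty_factor` and `penalty_max` only borrow their names from {class}`~jaxls.AugmentedLagrangianConfig`. The exact criteria jaxls uses may differ.

```python
def solve_subproblem(lam: float, rho: float) -> float:
    """Exact minimizer of (x - 2)**2 + (max(0, lam + rho * (x - 1))**2 - lam**2) / (2 * rho)."""
    x_free = 2.0  # Minimizer when the constraint term is inactive.
    if lam + rho * (x_free - 1.0) <= 0.0:
        return x_free
    # Active case: 2 * (x - 2) + lam + rho * (x - 1) = 0.
    return (4.0 - lam + rho) / (2.0 + rho)


lam, rho = 0.0, 1.0
penalty_factor, penalty_max = 10.0, 1e7  # Illustrative values.
prev_violation = float("inf")

for iteration in range(15):
    x = solve_subproblem(lam, rho)
    g = x - 1.0  # Constraint value; feasible when g <= 0.
    violation = max(0.0, g)
    print(f"iter {iteration:2d}: x = {x:.8f}, lambda = {lam:.6f}, rho = {rho:g}")

    lam = max(0.0, lam + rho * g)  # Projected multiplier update.

    # Assumed rule: grow rho when the violation did not shrink by at least 4x.
    if violation > 0.25 * prev_violation:
        rho = min(rho * penalty_factor, penalty_max)
    prev_violation = violation
```

The iterates approach the constrained minimizer $x = 1$ with multiplier $\lambda = 2$, and $\rho$ is raised only once, after the first update fails to shrink the violation enough.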

## Advanced: tuning the solver

For difficult problems, you can tune the augmented Lagrangian solver via {class}`~jaxls.AugmentedLagrangianConfig`.
Its parameters map to the concepts above. The maximum number of iterations is set separately via {class}`~jaxls.TerminationConfig`:

In [None]:
# Custom configuration example.
al_config = jaxls.AugmentedLagrangianConfig(
    penalty_factor=10.0,  # Multiply rho by this when constraints stagnate.
    penalty_max=1e7,  # Cap on rho to prevent ill-conditioning.
    tolerance_absolute=1e-6,  # Constraint violation tolerance for convergence.
)

# Use with solve(). Max iterations is controlled via TerminationConfig.
solution = problem.solve(
    verbose=False,
    linear_solver="dense_cholesky",
    augmented_lagrangian=al_config,
    termination=jaxls.TerminationConfig(max_iterations=150),
)

print("Solution with custom config:")
for name, w in zip(asset_names, solution[weights_var]):
    print(f"  {name}: {float(w):.1%}")

## Summary

Key points for constrained optimization in jaxls:

- Use `kind="constraint_eq_zero"` for equality constraints $h(x) = 0$
- Use `kind="constraint_leq_zero"` for upper bounds $g(x) \leq 0$
- Use `kind="constraint_geq_zero"` for lower bounds $g(x) \geq 0$
- Constraints are handled automatically via an augmented Lagrangian method
- Tune with {class}`~jaxls.AugmentedLagrangianConfig` if needed

For more examples, see the [mean-variance portfolio notebook](../../portfolio/mean_variance.ipynb).