<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/ConstrainedRootSolver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
from scipy.optimize import root

In [124]:
def f(v):
    """Residuals of the demo system; every entry is zero at a root.

    Equations: ||v|| = 1, x1 / x2 = 1, x2 / x3 = 0.5.
    """
    x1, x2, x3 = v
    residuals = [
        jnp.linalg.norm(v) - 1,  # unit norm
        x1 / x2 - 1.,            # x1 equals x2
        x2 / x3 - 0.5,           # x2 is half of x3
    ]
    return jnp.array(residuals)

In [265]:
def wrap(f, guess, bounds):
    """Recast a bound-constrained root problem as an unconstrained one for ``scipy.optimize.root``.

    Every variable x_i is generated from an unconstrained internal variable v_i:

    * lower bound only   ``(lb, nan)``: ``x_i = lb + |v_i|``
    * upper bound only   ``(nan, ub)``: ``x_i = ub - |v_i|``
    * both bounds        ``(lb, ub)`` : ``x_i = lb + |v_i|``, plus one slack
      unknown ``s`` and one extra residual ``s**2 - (ub - x_i)``; at a root
      this forces ``x_i <= ub``.
    * no bounds          ``(nan, nan)``: ``x_i = v_i``

    Parameters
    ----------
    f : callable
        Residual function ``f(x)`` of the bounded variables.  It must be
        traceable by JAX if ``jax.jacobian(constrained)`` is to be used.
    guess : array_like, shape (N,)
        Starting values for the internal variables ``v`` (not for ``x``).
    bounds : array_like, shape (N, 2)
        ``(lower, upper)`` per variable; ``nan`` or ``None`` marks a missing bound.

    Returns
    -------
    constrained : callable
        ``vec -> residuals`` where ``vec = [v (N entries), slacks (one per
        doubly-bounded variable)]``; returns ``f(x)`` followed by the slack
        residuals.
    constrained_guess : jax.Array
        ``guess`` with zero slacks appended (starting point for ``root``).
    v2x : callable
        Maps a solver vector (e.g. ``res.x``) back to the bounded variables ``x``.

    Raises
    ------
    ValueError
        If ``bounds`` is not of shape (N, 2) or ``guess`` does not have N entries.
    """
    import numpy as np

    # A float dtype turns None into nan, so plain nested lists are accepted too.
    bounds_np = np.asarray(bounds, dtype=float)
    if bounds_np.ndim != 2 or bounds_np.shape[1] != 2:
        raise ValueError(f"bounds must have shape (N, 2), got {bounds_np.shape}")
    Nvs = bounds_np.shape[0]

    guess_arr = jnp.asarray(guess)
    # A wrong length would otherwise silently shift slack entries into v.
    if guess_arr.shape != (Nvs,):
        raise ValueError(f"guess must have shape ({Nvs},), got {guess_arr.shape}")

    has_lb = ~np.isnan(bounds_np[:, 0])
    has_ub = ~np.isnan(bounds_np[:, 1])
    # Static integer index sets and bound values, computed once.
    lb_idx = np.flatnonzero(has_lb)
    ub_only_idx = np.flatnonzero(has_ub & ~has_lb)
    both_idx = np.flatnonzero(has_lb & has_ub)
    lb_vals = bounds_np[lb_idx, 0]
    ub_only_vals = bounds_np[ub_only_idx, 1]
    both_ub_vals = bounds_np[both_idx, 1]

    # NOTE(review): slacks start at 0, where d(s**2)/ds = 0, so the slack
    # columns of the initial Jacobian are zero; a non-zero start may help root.
    constrained_guess = jnp.r_[guess_arr, jnp.zeros(both_idx.size)]

    def _to_x(v):
        """Map the N internal variables to the bounded variables x."""
        v = v.at[lb_idx].set(lb_vals + jnp.abs(v[lb_idx]))
        return v.at[ub_only_idx].set(ub_only_vals - jnp.abs(v[ub_only_idx]))

    def constrained(vec):
        """Residuals of the augmented system: f(x) followed by the slack equations."""
        vec = jnp.asarray(vec)
        x = _to_x(vec[:Nvs])
        slacks = vec[Nvs:]
        # s**2 == ub - x has a real solution only when x <= ub.
        slack_res = slacks**2 - (both_ub_vals - x[both_idx])
        return jnp.r_[f(x), slack_res]

    def v2x(vec):
        """Map a solver vector (internal variables + slacks) to the bounded x."""
        return _to_x(jnp.asarray(vec)[:Nvs])

    return constrained, constrained_guess, v2x

In [289]:
# Per-variable (lower, upper) bounds; NaN marks a missing bound.
# jnp.nan is written explicitly: older JAX silently turned None into NaN here,
# but current JAX raises TypeError for None inside jnp.asarray.
bounds = jnp.asarray([[-1., jnp.nan],   # x1 >= -1
                      [-1., 0.],        # -1 <= x2 <= 0
                      [jnp.nan, 0.]])   # x3 <= 0
# Starting values for the solver's internal (unconstrained) variables, not for x itself.
guess = [-0.5, 0.5, 0.5]

In [290]:
# Build the augmented (slack) system around the jitted residual function.
constrained, constrained_guess, v2x = wrap(f=jax.jit(f), guess=guess, bounds=bounds)

In [291]:
# Solve the augmented system with an exact JAX Jacobian, then map the solver
# vector (internal variables + slacks) back to the bounded variables.
res = root(constrained, constrained_guess, jac=jax.jacobian(constrained))
# Fail loudly rather than display a meaningless point if the solver did not converge.
assert res.success, res.message
v2x(res.x)

Array([-0.40824829, -0.40824829, -0.81649658], dtype=float64)

In [292]:
# Residuals of the original, unwrapped system at the recovered bounded solution.
f(v=v2x(vec=res.x))

Array([-4.50750548e-14, -3.54893892e-12,  4.21884749e-15], dtype=float64)

In [293]:
# Full scipy OptimizeResult: convergence flag/message, final residuals (fun) and
# the solver-space vector x (internal variables followed by the slack).
res

 message: The solution converged.
 success: True
  status: 1
     fun: [-4.508e-14 -3.549e-12  4.219e-15  1.457e-11]
       x: [-5.918e-01  5.918e-01  8.165e-01  6.389e-01]
    nfev: 17
    njev: 2
    fjac: [[-1.580e-01 -9.752e-01 -5.396e-04  1.553e-01]
           [ 3.183e-01 -1.752e-01  5.177e-01 -7.746e-01]
           [ 8.311e-01 -1.326e-01 -5.399e-01  1.071e-02]
           [-4.277e-01 -2.797e-02 -6.637e-01 -6.130e-01]]
       r: [-2.571e+00 -2.366e+00 -1.231e-01  1.846e-01 -2.379e+00
           -1.272e-01 -9.683e-01  1.007e+00  1.239e-02 -7.654e-01]
     qtf: [ 6.153e-10 -1.043e-09  6.366e-11 -8.648e-10]