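<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/ConstrainedRootSolver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Constrained root finding with `scipy.optimize.root`

`scipy.optimize.root` accepts no bounds, so this notebook solves a small bounded nonlinear system in two ways:

1. Reflect the solver's iterate back into the box before evaluating the residuals (`wrap_reflect`).
2. Re-parameterise bounded variables as offsets `lb + |v|` / `ub - |v|`, with an extra squared-slack residual for variables bounded on both sides (`wrap_constraint`).

Both approaches recover the same root.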

In [3]:
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
from scipy.optimize import root

In [124]:
def f(v):
    """Residuals of the demo system in x = (x1, x2, x3).

    A root satisfies ||x|| = 1, x1 = x2 and x2 = x3 / 2.
    """
    x1, x2, x3 = v
    on_unit_sphere = jnp.linalg.norm(v) - 1
    first_equals_second = x1 / x2 - 1.
    second_is_half_third = x2 / x3 - 0.5
    return jnp.array([on_unit_sphere, first_equals_second, second_is_half_third])

In [333]:
def wrap_constraint(f, guess, bounds):
    """Turn a box-constrained root problem into an unconstrained one.

    Each bounded variable is re-parameterised so a root solver can search in
    unconstrained coordinates ``v``:

    * lower bound only:  ``x = lb + |v|``
    * upper bound only:  ``x = ub - |v|``
    * both bounds:       ``x = lb + |v|``, plus one extra unknown (slack ``s``)
      and one extra residual ``s**2 - (ub - x)``, which forces ``x <= ub``
      at a root.

    Parameters
    ----------
    f : callable
        Residual function of the bounded variables; receives a length-N array.
    guess : array_like, shape (N,)
        Initial guess in the solver's (unconstrained) coordinates.
    bounds : array_like, shape (N, 2)
        ``[lower, upper]`` per variable; a non-finite entry (inf or nan)
        means that side is unbounded.

    Returns
    -------
    constrained : callable
        Maps the augmented vector ``[v, s]`` (length N + number of doubly
        bounded variables) to the residuals ``[f(x), slack residuals]``.
    constrained_guess : jax.Array
        ``guess`` followed by the initial slack values.
    v2x : callable
        Maps a solver vector (augmented or not) back to the bounded variables x.
    """
    bounds = jnp.asarray(bounds)
    lb = bounds[:, 0]
    ub = bounds[:, 1]
    lbs = jnp.isfinite(lb)
    ubs = jnp.isfinite(ub)
    both = lbs & ubs
    ubs_only = ubs & ~lbs
    Nvs = bounds.shape[0]

    def v2x(vec):
        # Only the first Nvs entries are variables; anything after is slack.
        v = jnp.asarray(vec)[:Nvs]
        v = v.at[lbs].set(lb[lbs] + jnp.abs(v[lbs]))
        v = v.at[ubs_only].set(ub[ubs_only] - jnp.abs(v[ubs_only]))
        return v

    def constrained(vec):
        vec = jnp.asarray(vec)
        x = v2x(vec)
        slacks = vec[Nvs:]
        # Zero only when ub - x = s**2 >= 0, i.e. x <= ub (x >= lb holds by construction).
        constraintf = slacks**2 - (ub[both] - x[both])
        return jnp.r_[f(x), constraintf]

    # Start the slacks at sqrt|ub - x0| instead of 0. d(s**2)/ds = 2s vanishes
    # at s = 0, which would make every slack column of the Jacobian zero (a
    # singular Jacobian) at the initial guess. When x0 <= ub this choice also
    # satisfies the slack residuals exactly. Fall back to 1 if the gap is 0.
    guess_arr = jnp.asarray(guess, dtype=float)
    x0 = v2x(guess_arr)
    gap = jnp.abs(ub[both] - x0[both])
    slacks0 = jnp.sqrt(jnp.where(gap > 0, gap, 1.0))
    constrained_guess = jnp.r_[guess_arr, slacks0]

    return constrained, constrained_guess, v2x

In [392]:
def wrap_reflect(f, bounds):
    """Wrap ``f`` so a root solver can roam freely while ``f`` only sees in-bounds points.

    Each variable is mirrored back into its box before ``f`` is evaluated:

    * lower bound only: ``max(v, 2*lb - v)``
    * upper bound only: ``min(v, 2*ub - v)``
    * both bounds: a triangle wave of period ``2*(ub - lb)`` that folds v into
      ``[lb, ub]``. A variable with ``lb == ub`` is pinned to that value.

    The solver's raw solution is not unique (many v reflect to the same x);
    always pass it through ``reflect`` to obtain the bounded point.

    Parameters
    ----------
    f : callable
        Residual function of the bounded variables; receives a length-N array.
    bounds : array_like, shape (N, 2)
        ``[lower, upper]`` per variable; a non-finite entry (inf or nan)
        means that side is unbounded.

    Returns
    -------
    wrapped : callable
        ``v -> f(reflect(v))``.
    reflect : callable
        Maps solver coordinates to the bounded variables.

    Raises
    ------
    ValueError
        If a finite lower bound exceeds its finite upper bound.
    """
    bounds = jnp.asarray(bounds)
    lb = bounds[:, 0]
    ub = bounds[:, 1]
    lbs = jnp.isfinite(lb)
    ubs = jnp.isfinite(ub)
    lbs_only = lbs & ~ubs
    ubs_only = ubs & ~lbs
    both = lbs & ubs

    lb_both = lb[both]
    ub_both = ub[both]
    if bool(jnp.any(lb_both > ub_both)):
        raise ValueError("each lower bound must not exceed its upper bound")
    width = ub_both - lb_both
    fixed = width == 0
    # remainder(x, 0) is NaN. Substitute a dummy width for fixed variables;
    # a NaN even in the unselected jnp.where branch would poison the Jacobian.
    safe_width = jnp.where(fixed, 1.0, width)

    def reflect(v):
        v = jnp.asarray(v)
        v = v.at[lbs_only].set(jnp.maximum(v[lbs_only], 2*lb[lbs_only] - v[lbs_only]))
        v = v.at[ubs_only].set(jnp.minimum(v[ubs_only], 2*ub[ubs_only] - v[ubs_only]))
        t = jnp.remainder(v[both] - lb_both, 2*safe_width)
        folded = lb_both + jnp.minimum(t, 2*safe_width - t)
        v = v.at[both].set(jnp.where(fixed, lb_both, folded))
        return v

    def wrapped(v):
        return f(reflect(v))

    return wrapped, reflect

In [393]:
# Per-variable [lower, upper] bounds. +/-inf marks an unbounded side; the
# wrappers above detect bounded sides with jnp.isfinite. (Python None inside
# jnp.asarray relies on an object-dtype coercion that JAX does not support.)
bounds = jnp.array([[-1., jnp.inf],
                    [-1., 0.],
                    [-jnp.inf, 0.]])
# Deliberately poor starting point: every component lies outside its bounds.
guess = [-3, -2.5, 20.]

In [394]:
wrapped, reflect = wrap_reflect(f, bounds)
res = root(wrapped, guess, jac=jax.jacobian(wrapped))
# Fail loudly rather than displaying a non-converged point as if it were a root.
assert res.success, res.message
# res.x is in the solver's unbounded coordinates; reflect() maps it into the box.
reflect(res.x)

Array([-0.40824829, -0.40824829, -0.81649658], dtype=float64)

In [391]:
# Full SciPy result. res.x is in the solver's unbounded coordinates; pass it
# through reflect() for the in-bounds point.
# NOTE(review): this cell executed as In[391], before the In[392]-In[394] cells
# above it, so the saved output may come from an earlier kernel state.
# Restart & Run All to refresh it.
res

 message: The solution converged.
 success: True
  status: 1
     fun: [ 8.018e-13  7.834e-13  3.281e-12]
       x: [-4.082e-01 -2.408e+00 -8.165e-01]
    nfev: 16
    njev: 2
    fjac: [[-1.565e-01 -9.873e-01  2.811e-02]
           [ 6.775e-01 -8.661e-02  7.304e-01]
           [ 7.187e-01 -1.334e-01 -6.824e-01]]
       r: [ 2.433e+00 -2.411e+00  2.488e-02 -1.056e+00  5.244e-01
           -1.404e+00]
     qtf: [-9.699e-11  6.802e-10 -4.386e-10]

In [339]:
# Slack-variable formulation of the same problem, with f compiled by jax.jit.
jitted_f = jax.jit(f)
constrained, constrained_guess, v2x = wrap_constraint(jitted_f, guess, bounds)

In [340]:
res = root(constrained, constrained_guess, jac=jax.jacobian(constrained))
# Fail loudly rather than displaying a non-converged point as if it were a root.
assert res.success, res.message
# Drop the slack entries and map the solver vector back to the bounded variables.
v2x(res.x)

Array([-0.40824829, -0.40824829, -0.81649658], dtype=float64)

In [326]:
# Residuals of the original system at the recovered point; all should be ~0.
x_slack = v2x(res.x)
f(x_slack)

Array([-4.50750548e-14, -3.54893892e-12,  4.21884749e-15], dtype=float64)

IndexError: ignored