In [45]:
import jax
from jax import Array
import jax.numpy as jnp
import jax_dataclasses as jdc
from typing import *
import numpy as np

import osqp
from scipy import sparse

In [218]:
from jax.experimental.sparse import CSC
import scipy

In [785]:
class Problem:
    """Smooth constrained NLP with autodiff derivatives (toy SQP version).

    Constraint convention: ``constr_fn(x)`` returns a vector whose first
    ``eq_dim`` entries must equal zero and whose next ``ineq_dim`` entries
    must be >= 0.

    NOTE(review): a second class named ``Problem`` is defined later in this
    notebook and shadows this one once that cell runs.
    """

    def __init__(
        self,
        dim,
        eq_dim,
        ineq_dim,
        obj_fn,
        constr_fn,
    ):
        """
        Args:
            dim: number of decision variables.
            eq_dim: number of leading entries of ``constr_fn(x)`` that are equalities.
            ineq_dim: number of following entries that are ``>= 0`` inequalities.
            obj_fn: scalar objective f(x).
            constr_fn: vector constraint function g(x).
        """
        self.dim = dim
        self.eq_dim = eq_dim
        self.ineq_dim = ineq_dim
        self.f = obj_fn
        self.grad_f = jax.grad(obj_fn)
        self.hess_f = jax.hessian(obj_fn)
        self.g = constr_fn
        self.jac_g = jax.jacrev(constr_fn)

    def eval_merit(self, x, mu):
        """L1 exact-penalty merit: f(x) + mu * (sum|g_eq(x)| + sum max(0, -g_ineq(x)))."""
        val_constr = jnp.atleast_1d(self.g(x))
        fval = self.f(x)
        eq_viol = jnp.abs(val_constr[:self.eq_dim])
        # Explicit start:stop slice. `val_constr[-self.ineq_dim:]` would select the
        # whole vector when ineq_dim == 0 and double-count the equality rows.
        ineq_viol = jnp.maximum(
            0., -val_constr[self.eq_dim:self.eq_dim + self.ineq_dim])
        return fval + mu * (jnp.sum(eq_viol) + jnp.sum(ineq_viol))

    def eval_merit_pred(self, x, sol, mu):
        """Merit value predicted by the quadratic model for QP solution ``sol``.

        The first ``self.dim`` entries of ``sol`` are the step p. The remaining
        entries are presumably the nonnegative slack variables, whose sum models
        the constraint violation.

        Returns f(x) + 1/2 p^T H p + grad^T p + mu * sum(slacks), with H the exact
        Hessian of f.

        NOTE(review): the QP in this notebook may solve with a convexified Hessian.
        In that case this prediction differs from the QP's own model value.
        """
        p = sol[:self.dim]
        viol = sol[self.dim:]
        hess = self.hess_f(x)
        grad = self.grad_f(x)
        # Use self.f rather than a module-level `prob`, so any instance works.
        return self.f(x) + 1/2*p@hess@p + grad@p + jnp.sum(viol)*mu

# Toy problem. NOTE(review): this looks like Hock-Schittkowski #71 without its
# 1 <= x <= 5 box bounds — confirm the bounds are intentionally omitted.
def obj_fn(x):
    """Objective x0 * x3 * (x0 + x1 + x2) + x2."""
    partial_sum = jnp.sum(x[:3])
    return x[0] * x[3] * partial_sum + x[2]


def ceq_fn(x):
    """Equality constraint ||x||^2 - 40 (must be == 0)."""
    return jnp.sum(x**2) - 40


def cineq_fn(x):
    """Inequality constraint prod(x) - 25 (must be >= 0)."""
    return jnp.prod(x) - 25


def constr_fn(x):
    """Stacked constraints [equality; inequality]; equality rows come first."""
    return jnp.hstack([ceq_fn(x), cineq_fn(x)])

# Toy instance: 4 variables, 1 equality row, 1 inequality row.
# NOTE(review): later cells read this global `prob` directly.
prob = Problem(4, 1, 1, obj_fn, constr_fn)

In [415]:
from jax.experimental.sparse import BCOO
from scipy.sparse import csc_array

def bcoo_to_scipy(bcoo:BCOO):
    """Convert a 2-D JAX ``BCOO`` matrix to a SciPy ``csc_matrix``.

    Duplicate indices are summed by SciPy's COO->CSC conversion, which matches
    BCOO's dense semantics. Out-of-bound indices, which JAX uses as padding
    entries in BCOO, are dropped instead of making SciPy raise.

    Raises:
        ValueError: if ``bcoo`` has batch or dense dimensions, or is not 2-D.
    """
    if bcoo.n_batch or bcoo.n_dense or len(bcoo.shape) != 2:
        raise ValueError(
            "bcoo_to_scipy expects a plain 2-D BCOO (no batch/dense dims), "
            f"got shape={bcoo.shape}, n_batch={bcoo.n_batch}, n_dense={bcoo.n_dense}")
    data = np.asarray(bcoo.data)
    rows, cols = np.asarray(bcoo.indices).T
    n_rows, n_cols = bcoo.shape
    in_bounds = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
    return scipy.sparse.csc_matrix(
        (data[in_bounds], (rows[in_bounds], cols[in_bounds])), shape=bcoo.shape)

In [527]:
# Toy SQP sizes: 4 variables, 1 equality row, 1 inequality row.
dim, eq_dim, ineq_dim = 4, 1, 1
constr_dim = eq_dim + ineq_dim
# QP variable = step p plus slacks: two per equality row (+/-), one per inequality row.
aug_dim = dim + 2 * eq_dim + ineq_dim

In [787]:
# Sparsity structure of the QP matrices that does not change between SQP iterations.
# QP variable (column) layout:
#   [ p (dim) | eq slack + (eq_dim) | eq slack - (eq_dim) | ineq slack (ineq_dim) ]
# Row layout of A: [ eq rows (eq_dim) | ineq rows (ineq_dim) | identity rows (aug_dim) ]
P_shape = (aug_dim, aug_dim)

A_shape = (constr_dim+aug_dim, aug_dim)
A_fix = np.zeros(A_shape)
# Use explicit start:stop slices rather than negative offsets, so these assignments
# stay correct when ineq_dim == 0 (`[..., -0:]` would select every column).
A_fix[:eq_dim, dim:dim+eq_dim] = np.eye(eq_dim)
A_fix[:eq_dim, dim+eq_dim:dim+2*eq_dim] = -np.eye(eq_dim) #r
A_fix[eq_dim:constr_dim, dim+2*eq_dim:aug_dim] = np.eye(ineq_dim) #s
# Identity rows: carry the trust-region bounds on p and the slack >= 0 bounds via l/u.
A_fix[constr_dim:, :] = np.eye(aug_dim)
A_fix_indices = np.nonzero(A_fix)
A_fix_val = A_fix[A_fix_indices]

# Dense row-major (row, col) coordinates matching hess.flatten() / cjac.flatten().
hess_indices = np.indices((dim,dim)).reshape(2,-1)
const_jac_indices = np.indices((constr_dim,dim)).reshape(2,-1)

In [865]:
# Initial L1 penalty weight, trust-region step size (and its cap), and starting point.
mu, step_size, max_step_size = 0.1, 1.0, 100
x = np.array([1.0, 5.0, 5.0, 1.0])

In [790]:
import osqp

# Fixed number of SQP iterations with an L1-penalty merit function and a trust region.
# Updates the notebook globals x and step_size. Also uses prob, mu, max_step_size and
# the QP structure arrays (P_shape, A_shape, A_fix_*, *_indices) from earlier cells.
for j in range(5):
    # --- local model at x ---
    grad = prob.grad_f(x)
    # Evaluated outside `try` so a failure here is not mistaken for "not PD".
    hess_true = prob.hess_f(x)
    try:
        np.linalg.cholesky(hess_true)  # succeeds only for a positive-definite Hessian
        hess = hess_true
    except np.linalg.LinAlgError:
        # Indefinite: shift so the smallest eigenvalue becomes -0.1 * (old min) > 0.
        eigs, _ = np.linalg.eigh(hess_true)
        hess = hess_true - eigs[0] * np.eye(dim) * 1.1
    cval = prob.g(x)
    cjac = prob.jac_g(x)

    # --- QP data that does not depend on the trust-region size ---
    n_slack = eq_dim*2 + ineq_dim
    # NOTE(review): trust-region half-widths are scaled by Hessian row sums. They can be
    # negative if hess has negative off-diagonals (making l > u) — confirm intended.
    scaling = hess/hess[0,0]
    P = csc_array((hess.flatten(), hess_indices), shape=P_shape)
    q = np.hstack([grad, np.full(n_slack, mu)])
    A = csc_array((
        np.hstack([cjac.flatten(), A_fix_val]),
        np.hstack([const_jac_indices, A_fix_indices])),
        shape=A_shape)

    merit_curr = prob.eval_merit(x, mu)
    for i in range(10):
        # Bounds depend on the *current* step_size. They must be rebuilt every pass,
        # otherwise shrinking the region just re-solves the identical QP.
        step_vector = np.asarray(scaling @ jnp.full(dim, step_size))
        l = np.hstack([-cval, -step_vector, np.zeros(n_slack)])
        u = np.hstack([
            -cval[:eq_dim],                # equality rows: l == u
            np.full(ineq_dim, np.inf),     # inequality rows: one-sided
            step_vector,                   # trust region on p
            np.full(n_slack, np.inf),      # slacks >= 0
        ])

        # solve QP
        qp = osqp.OSQP()
        qp.setup(P, q, A, l, u, verbose=False)
        res = qp.solve()
        print(res.info.status, res.x)
        p = res.x[:dim]

        if jnp.linalg.norm(p) < 0.01:
            print("small step")
            break

        # Compare the actual merit reduction with the one predicted by the QP model
        # (convexified Hessian, slack penalty).
        merit_next = prob.eval_merit(x+p, mu)
        merit_pred = prob.f(x) + 1/2*p@hess@p + grad@p + jnp.sum(res.x[dim:])*mu
        ratio = (merit_curr - merit_next)/(merit_curr - merit_pred)
        print(f"ratio:{ratio}")

        if ratio > 0.5:
            print("size up")
            step_size = np.minimum(2*step_size, max_step_size)
            x = x + p
            break  # leave the trust-region loop
        elif ratio >= 0.25:
            # Adequate agreement: accept the step and keep the size. Doing nothing here
            # would re-solve the same QP for the remaining iterations.
            print("keep size")
            x = x + p
            break
        else:
            print("size down")
            step_size /= 4

    print()
print(f"step_size:{step_size}")

solved [-4.20159537e-01 -5.83835721e-01 -6.66164283e-01  6.70159552e-01
 -1.42288300e-08 -9.33079456e-09 -9.42992406e-09]
ratio:0.9265150427818298
size up

solved [ 3.16423270e-01  4.13109128e-03 -9.64888337e-02 -2.92753690e-01
 -2.41369211e-04 -1.17281038e-03 -8.96571162e-04]
ratio:0.7337083220481873
size up

solved [-0.0083267   0.00409842 -0.07333934  0.14720635 -0.00068086 -0.00077998
 -0.00074513]
ratio:1.0331720113754272
size up

solved [-7.74924855e-02  1.27618642e-02 -5.62686598e-02  1.52892702e-01
  2.50649107e-09  5.21556792e-09  4.90848455e-09]
ratio:-0.16906237602233887
size down
solved [-7.74933645e-02  1.27616386e-02 -5.62688412e-02  1.52894361e-01
  4.45795946e-09 -2.30728637e-10 -2.28819771e-10]
ratio:-0.16915658116340637
size down
solved [-7.74933645e-02  1.27616386e-02 -5.62688412e-02  1.52894361e-01
  4.45795946e-09 -2.30728637e-10 -2.28819771e-10]
ratio:-0.16915658116340637
size down
solved [-7.74933645e-02  1.27616386e-02 -5.62688412e-02  1.52894361e-01
  4.4579594

In [866]:
def convexify(x, prob:Problem, step_size=None, mu=None):
    """Build OSQP data ``(P, q, A, l, u)`` for the SQP subproblem at ``x``.

    QP variable layout: [p (dim) | eq slack + | eq slack - | ineq slack].
    The Hessian is made positive definite by an eigenvalue shift when Cholesky fails.

    Args:
        x: current iterate.
        prob: toy ``Problem`` providing f/g and their derivatives.
        step_size: trust-region size. Falls back to the notebook-global
            ``step_size`` when None (the previous behaviour).
        mu: L1 penalty weight on the slacks. Falls back to the global ``mu`` when None.

    Returns:
        Tuple ``(P, q, A, l, u)`` for ``osqp.OSQP.setup``.

    Also relies on module-level structure from earlier cells: dim, eq_dim, ineq_dim,
    P_shape, A_shape, A_fix_val, A_fix_indices, hess_indices, const_jac_indices.
    """
    if step_size is None:
        step_size = globals()["step_size"]
    if mu is None:
        mu = globals()["mu"]

    # eval function
    grad = prob.grad_f(x)
    hess_true = prob.hess_f(x)
    try:
        np.linalg.cholesky(hess_true)  # succeeds only for a positive-definite Hessian
        hess = hess_true
    except np.linalg.LinAlgError:
        # Shift so the smallest eigenvalue becomes -0.1 * (old min) > 0.
        eigs, _ = np.linalg.eigh(hess_true)
        hess = hess_true - eigs[0] * np.eye(dim) * 1.1
    cval = prob.g(x)
    cjac = prob.jac_g(x)

    # NOTE(review): per-coordinate half-widths scaled by Hessian row sums. They can be
    # negative if hess has negative off-diagonals (making l > u) — confirm intended.
    scaling = hess/hess[0,0]
    step_vector = np.asarray(scaling @ jnp.full(dim, step_size))
    n_slack = eq_dim*2 + ineq_dim
    P = csc_array((hess.flatten(), hess_indices), shape=P_shape)
    q = np.hstack([grad, np.full(n_slack, mu)])
    A = csc_array((
        np.hstack([cjac.flatten(), A_fix_val]),
        np.hstack([const_jac_indices, A_fix_indices])),
        shape=A_shape)
    # Row bounds: eq rows (l == u), ineq rows (one-sided), trust region on p, slacks >= 0.
    l = np.hstack([-cval, -step_vector, np.zeros(n_slack)])
    u = np.hstack([-cval[:eq_dim], np.full(ineq_dim, np.inf), step_vector, np.full(n_slack, np.inf)])
    return P, q, A, l, u


def trust_region_loop(x, mu, step_size, xtol=0.01, max_iter=10):
    """Try QP steps from ``x``, shrinking the trust region until one is accepted.

    Args:
        x: current iterate.
        mu: L1 penalty weight, used in both the QP objective and the merit function.
        step_size: initial trust-region size.
        xtol: report "converged" once step_size falls below this.
        max_iter: maximum number of QP solves.

    Returns:
        ``(status, x, step_size)`` where status is "converged", "updated", or
        "max_iter". "max_iter" means no step was accepted and x is unchanged.

    Uses the notebook-global ``prob`` and ``max_step_size``.
    """
    merit_curr = prob.eval_merit(x, mu)  # x is fixed inside the loop
    p = None
    for i in range(max_iter):
        # Rebuild the QP every pass: its bounds depend on the current step_size,
        # so reusing the first QP after a shrink would just return the same step.
        P, q, A, l, u = convexify(x, prob, step_size=step_size, mu=mu)
        qp = osqp.OSQP()
        qp.setup(P, q, A, l, u, verbose=False)
        res = qp.solve()
        print(res.info.status, res.x)
        p = res.x[:dim]

        # eval improve
        merit_next = prob.eval_merit(x + p, mu)
        merit_pred = prob.eval_merit_pred(x, res.x, mu)
        ratio = (merit_curr - merit_next)/(merit_curr - merit_pred)
        print(f"ratio:{ratio}")

        # update
        if step_size < xtol:
            print("small step")
            # NOTE(review): returns 4x the current size, presumably to undo the last
            # /4 shrink — confirm intent.
            return "converged", x, step_size * 4
        elif ratio > 0.4:
            print("size up")
            step_size = np.minimum(2*step_size, max_step_size)
            return "updated", x + p, step_size
        else:
            print("size down")
            step_size /= 4
    print(f"step_size: {step_size}")
    print(f"p: {p}")
    # Always return a 3-tuple so callers can unpack it.
    return "max_iter", x, step_size

In [907]:
# Run one trust-region step. Rebinds the globals x and step_size, so re-running
# this cell advances the solver further.
opt_state, x, step_size = trust_region_loop(x, mu, step_size)

solved [ 0.0212854  -0.00765709 -0.00414677  0.02295921 -0.00188036 -0.00188314
 -0.00029584]
ratio:-0.17543241381645203
size down
solved [ 2.12697068e-02 -7.65140608e-03 -4.14501793e-03  2.29485101e-02
  9.79335212e-09  1.11369923e-08  1.96430145e-09]
ratio:-2.6324422359466553
small step


In [908]:
# Increase the penalty weight 10x.
# NOTE(review): in-place and not idempotent — every re-run multiplies again
# (the next displayed value is 1e8).
mu *= 10

In [910]:
# Show the current penalty weight.
mu

100000000.0

: 

In [909]:
# Show the current iterate.
x

array([1.24832096, 4.41153495, 4.22183547, 1.07527271])

In [812]:
# Show the current trust-region size.
step_size

2.0

In [774]:
# Show the current trust-region size.
step_size

3.0517578125e-05

step_size:0.0078125


In [760]:
# Show the current iterate.
x

array([1.06879893, 4.82043342, 3.73254552, 1.29986486])

In [718]:
# Scratch check: did the last actual/predicted merit ratio exceed 0.75?
ratio > 0.75

Array(True, dtype=bool)

In [713]:
# Scratch: what the trust-region size would be after doubling.
step_size*2

0.03125

In [691]:
# Show the current trust-region size.
step_size

0.015625

In [542]:
# L1-penalty merit at the current iterate.
prob.eval_merit(x, mu)

Array([17.2], dtype=float32)

In [491]:
# NOTE(review): `g` is not defined anywhere in the visible notebook; this cell
# would raise NameError on a fresh kernel. Delete it or define `g`.
prob.f(x) + g

Array(16., dtype=float32)

In [480]:
# Show the last QP solution vector / step.
p

array([-0.4209399 , -0.54875484, -0.6310834 ,  0.65108425,  0.0015923 ,
        0.66060736,  0.14706493])

In [485]:
# Show the last QP solution vector / step.
p

array([-4.19622070e-01, -5.83757353e-01, -6.66085914e-01,  6.69649671e-01,
        5.12204052e-04, -7.07267509e-04, -5.69044454e-05])

In [348]:
# Scratch: build A from only the fixed (slack/identity) entries to inspect its layout.
# NOTE(review): this rebinds the global A used by the QP cells.
A = csc_array((A_fix_val, A_fix_indices), shape=A_shape)

In [349]:
# Dense view of A for visual inspection.
A.todense()

array([[ 0.,  0.,  0.,  0.,  1., -1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.]])

In [338]:
# NOTE(review): `A_fixture_val` is not defined in the visible notebook (probably a stale
# name for A_fix_val); this raises NameError on a fresh kernel.
A_fixture_val

array([ 1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])

In [324]:
# NOTE(review): `jnp.rows` does not appear to be a jax.numpy attribute. This looks like
# a stale scratch cell that would fail on a fresh kernel — delete it or fix the name.
jnp.rows

array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8])

In [305]:
# Show the number of equality rows.
eq_dim

1

In [None]:
# NOTE(review): unfinished, never-executed cell. `for mat` is a syntax error.
# Finish it or delete it.
mats = []
informations =[]
for mat

In [None]:
# NOTE(review): unexecuted draft.
# - jnp.ones(eq_dim, dim) passes `dim` as the dtype argument; jnp.ones((eq_dim, dim)) was
#   probably meant.
# - The empty row lists make jnp.block fail.
jnp.block([[jnp.ones(eq_dim, dim), jnp.ones(eq_dim, eq_dim), jnp.ones(eq_dim, eq_dim), ],
           [],
           []])

In [300]:
# (row, col) pairs for a constr_dim x dim block at the origin.
# NOTE(review): get_mat_indices is defined in a later cell (In [298]); run that first.
get_mat_indices(constr_dim, dim, 0,0)

Array([[0, 0],
       [0, 1],
       [0, 2],
       [0, 3],
       [1, 0],
       [1, 1],
       [1, 2],
       [1, 3]], dtype=int32)

In [289]:
# Row-major (row, col) pairs, one per row, covering an (eq_dim + ineq_dim) x dim block.
ceq_jac_indices = jnp.indices((eq_dim+ineq_dim, dim)).reshape(2,-1).T

In [298]:
def get_mat_indices(dim_row, dim_cols, row_offset, col_offset):
    """(row, col) index pairs of a dim_row x dim_cols block placed at (row_offset, col_offset).

    Pairs are listed in row-major order. The result has shape (dim_row * dim_cols, 2).
    """
    row_grid, col_grid = jnp.indices((dim_row, dim_cols))
    return jnp.stack(
        [row_grid.ravel() + row_offset, col_grid.ravel() + col_offset], axis=1)

Array([[0, 0, 0, 0, 1, 1, 1, 1],
       [0, 1, 2, 3, 0, 1, 2, 3]], dtype=int32)

In [290]:
# Show the Jacobian coordinate pairs.
ceq_jac_indices

Array([[0, 0],
       [0, 1],
       [0, 2],
       [0, 3]], dtype=int32)

In [288]:
# Show the equality-constraint Jacobian (whichever cell last assigned it).
ceq_jac

Array([2., 2., 2., 2.], dtype=float32)

In [280]:
# Dense view of P for visual inspection.
P.todense()

Array([[ 0,  1,  2,  3,  0,  0,  0],
       [ 4,  5,  6,  7,  0,  0,  0],
       [ 8,  9, 10, 11,  0,  0,  0],
       [12, 13, 14, 15,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0]], dtype=int32)

In [254]:
# Scratch: row/column index grids of a 4x4 matrix.
row, col = np.indices((4,4))

In [217]:
# NOTE(review): relies on an earlier P of a sparse type that has as_scipy_coo_matrix.
# That type is not created anywhere in the visible notebook; the current P is a
# scipy csc_array.
P.as_scipy_coo_matrix().tocsc().todense()

matrix([[2., 1., 1., 4., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 0.],
        [4., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [205]:
# Number of Hessian entries (dim * dim for a dense Hessian).
len(hess.flatten())

16

In [None]:
# NOTE(review): `hess_coord` is not defined in the visible notebook (never executed).
hess_coord()

In [204]:
# NOTE(review): failed scratch cell (ValueError shown below) on an older P type.
# Safe to delete.
P.as_scipy_coo_matrix()

ValueError: row, column, and data array must all be the same length

In [199]:
# Show the current Hessian object.
hess

SparseCooMatrix(values=Array([[2., 1., 1., 4.],
       [1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [4., 1., 1., 0.]], dtype=float32), coords=SparseCooCoordinates(rows=array([0, 1, 2, 3]), cols=array([0, 1, 2, 3])), shape=(7, 7))

In [180]:
# NOTE(review): `MyCSC` is not defined in the visible notebook; fails on a fresh kernel.
mat = MyCSC.fromdense(jnp.eye(10)[2:])

In [181]:
# Show the sparse matrix repr.
mat

CSC(float32[8, 10], nse=8)

In [173]:
# NOTE(review): CSC has no `as_scipy` (AttributeError shown below); delete this cell.
mat.as_scipy()

AttributeError: 'CSC' object has no attribute 'as_scipy'

In [144]:
# Scratch: 10x10 identity as a JAX CSC matrix.
mat = CSC.fromdense(jnp.eye(10))

In [146]:
# Round-trip check back to a dense array.
mat.todense()

Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)

In [137]:
# Sanity check: objective Hessian at x = ones(4).
prob.hess_f(jnp.ones(4))

Array([[2., 1., 1., 4.],
       [1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [4., 1., 1., 0.]], dtype=float32)

In [107]:
@jdc.pytree_dataclass
class Param:
    """Per-solve parameters passed to residual/constraint functions."""
    q0: Array  # reference joint configuration (set to panda.neutral below)
    target: Array  # target end-effector pose vector [translation, rotation vector]

@jdc.pytree_dataclass
class Residual:
    """Weighted least-squares residual term."""
    # (x, param) -> (error, d error / dx). Marked static: a Python function is not
    # array data and must not be flattened as a pytree leaf (e.g. under jax.jit).
    error_and_jac_fn: jdc.Static[Callable[[Array, Param], Tuple[Array, Array]]]
    # Per-row weights; concatenated by get_residual_weights into a diagonal weight.
    weights: Array

@jdc.pytree_dataclass
class Constr:
    """Vector constraint lb <= c(x) <= ub. Equality constraints use lb == ub."""
    # (x, param) -> (c(x), dc/dx). Marked static: a Python function is not array data
    # and must not be flattened as a pytree leaf (e.g. under jax.jit).
    value_and_jac_fn: jdc.Static[Callable[[Array, Param], Tuple[Array, Array]]]
    lb: Array  # lower bound on c(x)
    ub: Array  # upper bound on c(x)

    @property
    def dim(self):
        """Number of constraint rows (length of lb)."""
        return len(self.lb)

    def is_eq_constr(self):
        """True when every lower bound equals its upper bound."""
        return jnp.array_equal(self.lb, self.ub)

    def value_and_jac_onesided(self, x, param):
        """Rewrite lb <= c <= ub as two ">= 0" blocks: [c - lb; ub - c].

        Returns 2 * dim values and a Jacobian with 2 * dim rows ([J; -J]).
        """
        cval, cjac = self.value_and_jac_fn(x, param)
        cval_onesided = jnp.hstack([cval-self.lb, self.ub-cval])
        cjac_onesided = jnp.vstack([cjac, -cjac])
        return cval_onesided, cjac_onesided

    def get_bounds_onesided(self):
        """Bounds for the one-sided values: [0, 0] for equalities, [0, inf) otherwise."""
        lb = jnp.zeros(self.dim*2)
        if self.is_eq_constr():
            ub = jnp.zeros(self.dim*2)
        else:
            ub = jnp.full(self.dim*2, jnp.inf)
        return lb, ub


def get_residual_weights(res_fns:Tuple[Residual]):
    """Concatenate the per-row weights of all residual terms into one vector."""
    return jnp.hstack([term.weights for term in res_fns])

# def get_constr_bounds(constr_fns:Tuple[ConstrFn]):
#     lbs, ubs = [], []
#     for constr_fn in constr_fns:
#         zeros = jnp.zeros(constr_fn.dim)
#         if constr_fn.is_eq_constr():
#             ubs.append(zeros)
#         else:
#             ubs.append(np.full(constr_fn.dim, np.inf))
#         lbs.append(zeros)
#         lbs.append(zeros)
#     return jnp.hstack(lbs), jnp.hstack(ubs)

def eval_residual(x:Array, param:Param, res_fns:Tuple[Residual]):
    """Evaluate every residual term at x and stack the results.

    Returns:
        (errors, jac): errors concatenated end to end, Jacobians stacked row-wise.
    """
    outputs = [term.error_and_jac_fn(x, param) for term in res_fns]
    all_errors = [err for err, _ in outputs]
    all_jacs = [jac for _, jac in outputs]
    return jnp.hstack(all_errors), jnp.vstack(all_jacs)

def eval_constr(x, state, constr_fns:Tuple[Constr]):
    """Evaluate all constraints in one-sided form, split into equality and inequality groups.

    Each constraint contributes its ``value_and_jac_onesided`` output (2 rows per
    constraint row).

    Returns:
        (ceq, ceq_jac, cineq, cineq_jac). A group with no constraints gives empty
        arrays of shape (0,) and (0, n) with n = jnp.size(x), instead of the error
        jnp.hstack raises on an empty list.
    """
    dim_x = jnp.size(x)  # assumes x is a flat decision vector — TODO confirm for horizon > 1
    ceq, ceq_jacs = [], []
    cineq, cineq_jacs = [], []
    for constr_fn in constr_fns:
        val, jac = constr_fn.value_and_jac_onesided(x, state)
        if constr_fn.is_eq_constr():
            ceq.append(val)
            ceq_jacs.append(jac)
        else:
            cineq.append(val)
            cineq_jacs.append(jac)

    def stack(vals, jacs):
        # jnp.hstack / jnp.vstack raise on an empty list.
        if not vals:
            return jnp.zeros((0,)), jnp.zeros((0, dim_x))
        return jnp.hstack(vals), jnp.vstack(jacs)

    ceq_val, ceq_jac = stack(ceq, ceq_jacs)
    cineq_val, cineq_jac = stack(cineq, cineq_jacs)
    return ceq_val, ceq_jac, cineq_val, cineq_jac

In [3]:
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

# Scene, Panda robot model, and a semi-transparent robot in the visualizer.
world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# NOTE(review): presumably fixes joints 7 and 8 (gripper fingers?) at 0.04, leaving
# 7 free joints — confirm the reduce_dim semantics.
panda.reduce_dim([7, 8], [0.04, 0.04])

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [108]:
# Decision-vector layout: `horizon` stacked copies of a 7-entry joint vector.
# NOTE(review): this rebinds the global `dim` (4 in the toy-SQP cells), which
# convexify and the toy loop read.
robot_dim, horizon = 7, 1
dim = robot_dim * horizon
dt = 0.1


def to_mat(x):
    """Reshape a flat decision vector into rows of length robot_dim."""
    return x.reshape(-1, robot_dim)


def to_vec(x):
    """Flatten a stacked (rows, robot_dim) array back into a vector."""
    return x.flatten()
def to_posevec(pose:SE3):
    """Pose -> vector [translation, rotation log (rotation vector)]."""
    translation = pose.translation()
    rotvec = pose.rotation().log()
    return jnp.hstack([translation, rotvec])

# Kinematics
def get_rotvec_angvel_map(v):
    """Matrix E(v) = I - 1/2 [v]x + c(|v|) [v]x^2, the inverse SO(3) left Jacobian.

    c(t) = (1 - (t/2) sin t / (1 - cos t)) / t^2. get_ee_fk_jac uses E to map the
    angular rows of its geometric Jacobian to rotation-vector rates.

    c(t) -> 1/12 as t -> 0, but the closed form is 0/0 there (NaN at v = 0) and
    loses precision in float32 for small t. Below the threshold, the Taylor series
    1/12 + t^2/720 + t^4/30240 is used instead.

    NOTE(review): the closed form is also singular at |v| = 2*pi. Rotation vectors
    from SO3.log() are presumably within |v| <= pi — confirm.
    """
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                          [v3, 0., -v1],
                          [-v2, v1, 0.]])
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    small = vmag < 1e-1
    # Substitute a harmless angle in the unused branch so jnp.where never sees NaN/inf.
    safe = jnp.where(small, 1.0, vmag)
    coeff_closed = (1 - safe/2 * jnp.sin(safe)/(1 - jnp.cos(safe))) / safe**2
    coeff_taylor = 1/12 + vmag**2/720 + vmag**4/30240
    coeff = jnp.where(small, coeff_taylor, coeff_closed)
    return jnp.eye(3) - 1/2*vskew + coeff * (vskew @ vskew)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector and its analytical Jacobian at joint config ``q``.

    Returns:
        (posevec, jac). posevec = [p_ee, rotvec_ee] (position, then rotation vector).
        jac has one column per frame in fks[1:8]. Its top rows are the linear
        velocity rot_axis x (p_ee - p_frame); its bottom rows are the joint axes
        mapped by get_rotvec_angvel_map, so they differentiate rotvec_ee.
    """
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    # assumes each fk entry is [quaternion (4), translation (3)] — TODO confirm against fk_fn
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    # NOTE(review): hard-coded slice assumes frames 1..7 are the 7 joint frames — confirm
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        # joint axis = z column of the frame's rotation (revolute-joint assumption)
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    # convert the angular rows from angular velocity to rotation-vector rates
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

In [8]:
frame = Frame(world.vis, "frame")
def make_pose(rng=None):
    """Sample a random target pose.

    Rotation: a normalized 4-D standard-normal vector, which is uniform over unit
    quaternions and hence over all orientations. Sampling from [0, 1)^4 would only
    give non-negative components and cover a fraction of orientations.
    Translation: uniform in the box [-0.3, 0.6] x [-0.5, 0.5] x [0.3, 0.8].

    Args:
        rng: optional ``np.random.Generator`` for reproducible sampling. A fresh,
            unseeded generator is used when None.
    """
    if rng is None:
        rng = np.random.default_rng()
    return SE3.from_rotation_and_translation(
        SO3(rng.normal(size=4)).normalize(),
        rng.uniform([-0.3,-0.5,0.3],[0.6, 0.5, 0.8])
    )

In [109]:
@jax.jit
def vg_pose_error(x:Array, param:Param):
    """Pose residual and its Jacobian w.r.t. the joint vector x.

    error = param.target - [p_ee, rotvec_ee]. The target is constant in x, so
    d error / dx = -jac.

    NOTE(review): subtracting rotation vectors only approximates the orientation error.
    It is not the geodesic SO(3) difference and jumps when the angle wraps near pi —
    confirm this is acceptable.
    """
    ee, jac = get_ee_fk_jac(x)
    error = param.target - ee
    return error, -jac

@jax.jit
def vj_joint_limit(x:Array, param:Param):
    """Joint-limit constraint: value c(x) = x, Jacobian = identity (param unused)."""
    identity = jnp.eye(robot_dim)
    return x, identity

In [13]:
# Sample a target pose, show it in the viewer, and store it as a vector for the residual.
# NOTE(review): unseeded sampling — re-running this cell changes the target.
pose_d = make_pose()
frame.set_pose(pose_d)
posevec_d = to_posevec(pose_d)

In [110]:
# Start the robot problem at the neutral configuration.
# NOTE(review): rebinds the global x (previously the 4-D toy iterate).
x = panda.neutral
param = Param(panda.neutral, posevec_d)

In [111]:
# Cost: end-effector pose error, weight 1 on translation and 0.3 on rotation-vector entries.
pose_error = Residual(vg_pose_error, jnp.array([1, 1, 1, 0.3, 0.3, 0.3]))
joint_limit = Constr(vj_joint_limit, panda.lb, panda.ub)
# NOTE(review): lb == ub == panda.lb makes this an *equality* constraint that pins every
# joint to its lower limit. It looks like a fixture for exercising the equality path — confirm.
joint_limit_ = Constr(vj_joint_limit, panda.lb, panda.lb)
res_fns = [pose_error]
constr_fns = [joint_limit, joint_limit_]


weight = get_residual_weights(res_fns)
W = jnp.diag(weight)
# Sizes derived from the definitions above instead of hard-coded numbers.
# Constraint counts are one-sided rows (2 per constraint row, as eval_constr returns).
dim_x = robot_dim
dim_error = len(weight)
dim_eq = sum(2 * c.dim for c in constr_fns if c.is_eq_constr())
dim_ineq = sum(2 * c.dim for c in constr_fns if not c.is_eq_constr())

In [117]:
class Problem:
    """Residual + constraint container for the robot problem.

    Attributes:
        res_fns, constr_fns: the residual and constraint terms.
        dim_x: number of decision variables.
        dim_error: total residual length (sum of the weight lengths).
        dim_eq / dim_ineq: total equality / inequality constraint rows, counted
            as len(lb) for each constraint.

    NOTE(review): this redefinition shadows the earlier toy-SQP ``Problem`` class.
    NOTE(review): eval_constr returns one-sided values with 2 rows per constraint row
    (Constr.value_and_jac_onesided stacks c - lb and ub - c), i.e. twice these
    counts — confirm which convention the QP assembly expects.
    """

    def __init__(self, dim, res_fns:List[Residual], constr_fns:List[Constr]):
        self.res_fns = res_fns
        self.constr_fns = constr_fns
        self.dim_x = dim
        self.dim_error = sum(len(term.weights) for term in res_fns)
        self.dim_eq, self.dim_ineq = 0, 0
        for constraint in constr_fns:
            n_rows = len(constraint.lb)
            if constraint.is_eq_constr():
                self.dim_eq += n_rows
            else:
                self.dim_ineq += n_rows

    
# Robot problem instance (7 = robot_dim joint variables).
# NOTE(review): rebinds the global `prob` that the toy-SQP cells read.
prob = Problem(7, res_fns, constr_fns)

In [116]:
# Total residual length.
prob.dim_error

6

In [118]:
# Residuals/constraints at x, then Gauss-Newton quantities for 1/2 e^T W e:
# gradient J^T W e and Hessian approximation J^T W J.
error, jac = eval_residual(x, param, res_fns)
ceq, ceq_jac, cineq, cineq_jac = eval_constr(x, param, constr_fns)

grad = jac.T @ W @ error
hess = jac.T @ W @ jac

In [119]:
# Penalty weight on slacks for the robot QP. NOTE(review): rebinds the toy-SQP mu.
mu = 1.

In [126]:
# QP objective 1/2 z^T P z + q^T z over z = [dx (prob.dim_x) | slacks]. The slacks
# presumably mirror the toy QP layout (two per equality row, one per inequality row)
# and are penalized linearly with weight mu.
dim_aug = prob.dim_x + prob.dim_eq *2 + prob.dim_ineq
P = jnp.zeros((dim_aug, dim_aug))
# Size the Hessian block from prob so it tracks the Problem instance, not a separate global.
P = P.at[:prob.dim_x, :prob.dim_x].set(hess)
q = jnp.hstack([grad, jnp.full(prob.dim_eq*2+prob.dim_ineq, mu)])


In [None]:
# Draft of the constraint-matrix rows for equalities: [J_eq | -I | +I | 0].
# NOTE(review): unexecuted, and the row counts do not match. ceq_jac has one row per
# one-sided equality row (14 in the output shown), while eye(prob.dim_eq) and the zeros
# block use prob.dim_eq / prob.dim_ineq (7 each, per dim_aug = 28), and jnp.hstack
# needs equal row counts. Also, this name shadows the toy-SQP A_fix. Fix the slack
# sizing before running.
A_fix = jnp.hstack(
    [ceq_jac, 
     -jnp.eye(prob.dim_eq), 
     jnp.eye(prob.dim_eq), 
     jnp.zeros((prob.dim_ineq, prob.dim_ineq)) ])

In [129]:
# One-sided equality Jacobian: [J; -J].
ceq_jac

Array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [-1., -0., -0., -0., -0., -0., -0.],
       [-0., -1., -0., -0., -0., -0., -0.],
       [-0., -0., -1., -0., -0., -0., -0.],
       [-0., -0., -0., -1., -0., -0., -0.],
       [-0., -0., -0., -0., -1., -0., -0.],
       [-0., -0., -0., -0., -0., -1., -0.],
       [-0., -0., -0., -0., -0., -0., -1.]], dtype=float32)

In [124]:
# Size of the augmented QP variable.
dim_aug

28

In [None]:
# NOTE(review): unexecuted. ceq_jac has dim_x (7) columns but eye(dim_eq) has
# dim_eq (14) columns, so vstack would fail as written.
jnp.vstack([ceq_jac, jnp.eye(dim_eq)])

In [105]:
# Number of one-sided equality rows.
dim_eq

14

In [104]:
# One-sided equality Jacobian: [J; -J].
ceq_jac

Array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [-1., -0., -0., -0., -0., -0., -0.],
       [-0., -1., -0., -0., -0., -0., -0.],
       [-0., -0., -1., -0., -0., -0., -0.],
       [-0., -0., -0., -1., -0., -0., -0.],
       [-0., -0., -0., -0., -1., -0., -0.],
       [-0., -0., -0., -0., -0., -1., -0.],
       [-0., -0., -0., -0., -0., -0., -1.]], dtype=float32)

In [43]:
# Weighted least squares with W = diag(weight):
#   val = e^T W e, grad = J^T W e, hess = J^T W J (Gauss-Newton).
# grad now includes W, consistent with hess and with the earlier `grad = jac.T @ W @ error`.
# NOTE(review): that grad/hess pair corresponds to 1/2 e^T W e, so val is twice that
# objective — confirm the intended scaling.
val = error@jnp.diag(weight)@error
grad = jac.T @ jnp.diag(weight) @ error
hess = jac.T @ jnp.diag(weight) @ jac

In [None]:
# merit fn
# NOTE(review): unfinished cell. `merit = ` has no right-hand side and the `for` body
# is empty, so both are syntax errors. Complete it (e.g. mirror the toy eval_merit)
# or delete it.
merit = 
for constr_fn in constr_fns:


In [None]:
#f

#g

#h