In [2]:
from jax import jacobian
import jax
jax.config.update('jax_platform_name', 'cpu')
import numpy as np
import jax.numpy as anp
import cyipopt

def unflatten(x0, order):
    """Split a flat vector into named, reshaped variable arrays.

    Parameters
    ----------
    x0 : array
        Flat vector holding every variable back to back, in ``order``.
    order : sequence of Var
        Variables in the order they appear in ``x0``; each one consumes
        ``np.prod(var.shape)`` consecutive entries.

    Returns
    -------
    dict
        Maps ``var.name`` to that variable's slice of ``x0``, reshaped to
        ``var.shape`` (as a ``jax.numpy`` array).
    """
    pieces = {}
    start = 0
    for var in order:
        stop = start + np.prod(var.shape)
        pieces[var.name] = anp.reshape(x0[start:stop], var.shape)
        start = stop
    return pieces

def compute_residuals_generic(funcs_with_io, order):
    """Build a function that evaluates the residuals of a system of equations.

    Each entry of ``funcs_with_io`` is ``(func, inputs, outputs)``: ``func`` is
    called with the current values of the ``inputs`` variables and must return
    one value per entry of ``outputs``. For an output ``Var`` the residual is
    ``output_value - func_value`` (zero when ``output = func(inputs)`` holds);
    for a ``None`` output the function value itself is the residual, i.e. it
    encodes the constraint ``func_value == 0``.

    Parameters
    ----------
    funcs_with_io : sequence of (callable, tuple of Var, tuple of Var or None)
        The equations of the system.
    order : sequence of Var
        Layout of the flat decision vector (forwarded to ``unflatten``).

    Returns
    -------
    callable
        ``inner(x0)``: maps a flat vector to a 1-D ``jax.numpy`` array holding
        every residual, concatenated in ``funcs_with_io`` / ``outputs`` order,
        each residual flattened row-major. A ``Var`` output contributes
        ``np.prod(shape)`` entries; a ``None`` output is expected to yield a
        scalar (one entry) -- the same row layout ``compute_structure`` uses.

    Raises
    ------
    ValueError
        From ``inner``, if a function returns a different number of values
        than it declares outputs (``zip`` would otherwise silently drop some).
    """
    def inner(x0):
        variables = unflatten(x0, order)
        pieces = []

        for idx, (func, inputs, outputs) in enumerate(funcs_with_io):
            func_output = func(*(variables[var.name] for var in inputs))
            if len(func_output) != len(outputs):
                raise ValueError(
                    f"function #{idx} returned {len(func_output)} values "
                    f"but declares {len(outputs)} outputs"
                )
            for output_var, func_val in zip(outputs, func_output):
                if output_var is not None:
                    residual = variables[output_var.name] - func_val
                else:
                    residual = func_val
                # Flatten so multi-dimensional outputs (e.g. shape (2, 2))
                # contribute one entry per element instead of whole rows.
                pieces.append(anp.ravel(residual))

        if not pieces:
            return anp.zeros(0)
        return anp.concatenate(pieces)

    return inner

def compute_structure(funcs_with_io, order):
    """Return the dense 0/1 sparsity pattern of the residual Jacobian.

    Rows follow the residual layout: for every declared output of every
    ``(func, inputs, outputs)`` entry, ``np.prod(output.shape)`` rows for a
    ``Var`` output and a single row for a ``None`` output. Columns follow the
    flat decision vector, i.e. ``order`` with each variable taking
    ``np.prod(var.shape)`` columns.

    A residual ``output - func(*inputs)`` can depend on every entry of every
    input (dense block of ones) and on its own output variable through ``-I``
    (identity block). All other blocks are zero.

    Parameters
    ----------
    funcs_with_io : sequence of (callable, tuple of Var, tuple of Var or None)
        The equations of the system; the callables themselves are not used.
    order : sequence of Var
        Layout of the flat decision vector.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_residuals, n_variables)`` with 1 marking a
        structurally nonzero Jacobian entry; ``(0, n_variables)`` when
        ``funcs_with_io`` is empty.
    """
    n_cols = sum(int(np.prod(var.shape)) for var in order)
    row_groups = []

    for _, inputs, outputs in funcs_with_io:
        for output_var in outputs:
            n_rows = int(np.prod(output_var.shape)) if output_var is not None else 1
            blocks = []
            for var in order:
                var_size = int(np.prod(var.shape))
                if var in inputs:
                    block = np.ones((n_rows, var_size))
                elif var is output_var:
                    # The residual depends on its own output variable via -I,
                    # even when that variable is not one of func's inputs.
                    block = np.eye(n_rows, var_size)
                else:
                    block = np.zeros((n_rows, var_size))
                blocks.append(block)
            row_groups.append(np.hstack(blocks))

    if not row_groups:
        return np.zeros((0, n_cols))
    return np.vstack(row_groups)

class Var:
    """A named optimisation variable with a fixed shape.

    ``shape`` may be an int (a 1-D variable of that length) or a tuple; the
    helpers in this notebook use ``np.prod(shape)`` as the number of entries
    the variable occupies in the flat decision vector. Equality is identity:
    two ``Var`` objects with the same name and shape are still distinct.
    """

    def __init__(self, name, shape):
        self.name, self.shape = name, shape

# Example 1 (superseded): a 2-vector x1, a 2-vector x2, a 2x2 matrix x3 and a
# scalar x4 linked by
#   x1 = x3 @ x2,  x4 = x1[0],  x2 = x1,  x3 - I = 0  (h1 pins x3 to the identity).
# NOTE: x0, order, funcs_with_io, x1-x3, f1 and f2 are all reassigned by the next
# example before anything is evaluated, so this system is never solved below.
x1, x2, x3, x4 = Var('x1', 2), Var('x2', 2), Var('x3', (2, 2)), Var('x4', 1)


def f1(x2, x3, x1):
    """Return (x3 @ x2, x1[0]); used as x1 = x3 @ x2 and x4 = x1[0]."""
    return x3 @ x2, x1[0]


def f2(x1):
    """Return (x1,); used as x2 = x1."""
    return (x1,)


def h1(x3):
    """Return the four entries of x3 - I, each constrained to zero."""
    return x3[0, 0] - 1, x3[0, 1], x3[1, 0], x3[1, 1] - 1


# Flat starting vector laid out as (x1, x2, x3, x4): 2 + 2 + 4 + 1 = 9 entries.
order = (x1, x2, x3, x4)
x0 = anp.array([1, 1, 1, 1, 1, 0, 0, 1, 0.5], dtype="float32")
funcs_with_io = [
    (f1, (x2, x3, x1), (x1, x4)),
    (f2, (x1,), (x2,)),
    (h1, (x3,), (None, None, None, None)),
]


# Example 2 (the system evaluated and solved below): three scalar unknowns with
#   x1 = x3**2 - 3,  x2 = x1 * x3,  x3 = x1 * x2
x1, x2, x3 = Var('x1', 1), Var('x2', 1), Var('x3', 1)


def f1(x3):
    return (x3**2 - 3,)


def f2(x1, x3):
    return (x1 * x3,)


def f3(x1, x2):
    return (x1 * x2,)


order = (x1, x2, x3)
x0 = anp.array([1, 1, 1], dtype="float32")  # every unknown starts at 1
funcs_with_io = [
    (f1, (x3,), (x1,)),
    (f2, (x1, x3), (x2,)),
    (f3, (x1, x2), (x3,)),
]


# Sanity check at the starting point x0: the residual vector and its Jacobian.
residuals_func = compute_residuals_generic(funcs_with_io, order)
jacobian_res = jacobian(residuals_func)
residuals, J = residuals_func(x0), jacobian_res(x0)

print(residuals)


[3. 0. 0.]


In [2]:
# Constraint data for IPOPT: sparsity pattern, residual function and its Jacobian.
constraints_jacobian_structure = compute_structure(funcs_with_io, order)
constraints = compute_residuals_generic(funcs_with_io, order)
constraints_jacobian = jacobian(constraints)
# One structure row per residual (m) and one column per flat decision entry (n).
m, n = constraints_jacobian_structure.shape

In [3]:
class Problem:
    """cyipopt problem object for a pure feasibility problem.

    The objective and its gradient are identically zero, so IPOPT only has to
    satisfy the constraint bounds given to ``cyipopt.Problem``. The constraint
    values and Jacobian come from the callables passed to ``__init__``.
    """

    def __init__(self, n, constraints, constraints_jacobian, constraints_jacobian_structure):
        """Store the problem size and the constraint callables.

        Parameters
        ----------
        n : int
            Number of decision variables (length of ``x``).
        constraints : callable
            ``constraints(x)`` -> 1-D array of constraint values.
        constraints_jacobian : callable
            ``constraints_jacobian(x)`` -> Jacobian of ``constraints``;
            presumably an (m, n) array such as ``jax.jacobian`` produces.
        constraints_jacobian_structure : array
            0/1 sparsity pattern of the Jacobian. Kept for reference only: no
            ``jacobianstructure`` callback is defined, so cyipopt falls back to
            its documented default of a dense Jacobian.
        """
        self.n = n
        # Private names on purpose: assigning to ``self.constraints`` /
        # ``self.jacobian`` would shadow the methods below, making them dead
        # code (and ``self.constraints(x)`` inside the method would recurse
        # forever if the attribute were ever missing).
        self._constraints_fn = constraints
        self._jacobian_fn = constraints_jacobian
        self.constraints_jacobian_structure = constraints_jacobian_structure

    def objective(self, x):
        """Objective value: always 0 (feasibility problem)."""
        return 0

    def gradient(self, x):
        """Objective gradient: a zero vector of length ``n``."""
        return np.zeros(self.n)

    def constraints(self, x):
        """Constraint values at ``x`` as a float64 NumPy array."""
        return np.asarray(self._constraints_fn(x), dtype=float)

    def jacobian(self, x):
        """Dense constraint Jacobian at ``x``, flattened row-major to 1-D."""
        return np.asarray(self._jacobian_fn(x), dtype=float).ravel()

In [4]:
P = Problem(
    n=n,
    constraints=constraints,
    constraints_jacobian=constraints_jacobian,
    constraints_jacobian_structure=constraints_jacobian_structure,
)

In [14]:
# Box bounds -2 <= x_i <= 2 on every variable; cl == cu == 0 makes every
# residual an equality constraint.
problem = cyipopt.Problem(
    n=n,
    m=m,
    problem_obj=P,
    lb=np.full(n, -2.0),
    ub=np.full(n, 2.0),
    cl=np.zeros(m),
    cu=np.zeros(m),
)

In [15]:
# Run IPOPT from x0; the result is an (x_solution, info_dict) pair.
out = problem.solve(x0)

In [16]:
# Solution vector rounded to 3 decimals before re-checking feasibility.
xsol = out[0].round(3)

In [17]:
# Residuals at the rounded solution; all zeros means every equation is satisfied.
constraints(xsol)

DeviceArray([0., 0., 0.], dtype=float32)

In [18]:
# Full solver result: the solution plus IPOPT's info dict (constraint values 'g',
# 'obj_val', multipliers, 'status' and 'status_msg').
out

(array([1., 2., 2.]),
 {'x': array([1., 2., 2.]),
  'g': array([-8.06466005e-13,  0.00000000e+00,  0.00000000e+00]),
  'obj_val': 0.0,
  'mult_g': array([-0., -0., -0.]),
  'mult_x_L': array([0., 0., 0.]),
  'mult_x_U': array([0., 0., 0.]),
  'status': 0,
  'status_msg': b'Algorithm terminated successfully at a locally optimal point, satisfying the convergence tolerances (can be specified by options).'})