In [5]:
import cyipopt
import jax
import jax.numpy as jnp
import numpy as np

In [33]:
class SparseIPOPTBuilder:
    def __init__(self):
        self.obj_info = {}
        self.var_info = {}
        self.constr_info = {}
        self.params = {}
        self.var_offset = 0
        self.constr_offset = 0
        self.objective_fn = None

    @property
    def var_names(self):
        return list(self.var_info.keys())
    @property
    def var_dims(self):
        return [self.var_info[varname]["dim"] for varname in self.var_names]
    @property
    def var_offsets(self):
        return [self.var_info[varname]["offset"] for varname in self.var_names]
    @property
    def constr_dims(self):
        return [constr["dim"] for constr in self.constr_info.values()]
    @property
    def xdim(self):
        return sum(self.var_dims)
    @property
    def cdim(self):
        return sum(self.constr_dims)
    @property
    def xdummy(self):
        return jnp.zeros(self.xdim)
    @property
    def lb(self):
        return np.hstack([var["lb"] for var in self.var_info.values()])
    @property
    def ub(self):
        return np.hstack([var["ub"] for var in self.var_info.values()])
    @property
    def cl(self):
        if self.cdim != 0:
            return np.hstack([constr["lb"] for constr in self.constr_info.values()])
        return None
    @property
    def cu(self):
        if self.cdim != 0:
            return np.hstack([constr["ub"] for constr in self.constr_info.values()])
        return None
    
    def as_vector(self, val, dim):
        if isinstance(val, float):
            return np.full(dim, val)
        assert len(val) == dim
        return val
    
    def add_variable(self, name, dim, lb=-np.inf, ub=np.inf):
        assert name not in self.var_info
        self.var_info[name] = dict(
            name=name, dim=dim, offset=self.var_offset,
            lb=self.as_vector(lb, dim),
            ub=self.as_vector(ub, dim)
        )
        self.var_offset += dim
    
    def add_parameter(self, name, value):
        assert name not in self.params
        self.params[name] = np.asarray(value)

    def add_constr(self, name, inputs, dim, lb, ub, constr_fn, jac_fn):
        assert name not in self.constr_info
        self.constr_info[name] = dict(
            name=name, inputs=inputs, dim=dim, 
            offset=self.constr_offset,
            constr_fn=constr_fn, jac_fn=jac_fn, 
            lb=self.as_vector(lb, dim), 
            ub=self.as_vector(ub, dim)
        )
        self.constr_offset += dim
    
    def set_objective(self, inputs, obj_fn, grad_fn):
        self.obj_info["inputs"] = inputs
        self.obj_info["obj_fn"] = obj_fn
        self.obj_info["grad_fn"] = grad_fn

    def build(self, compile=True):
        if "obj_fn" in self.obj_info:
            # indexing variables in objective, constraints
            self.obj_info["vars"] = {name:i 
                                    for i, name in enumerate(self.obj_info["inputs"]) 
                                    if name in self.var_names}

        if len(self.constr_dims) != 0:
            # indexing variables in constraints
            for constr in self.constr_info.values():
                constr["vars"] = {name:i 
                                for i, name in enumerate(constr["inputs"]) 
                                if name in self.var_names}
            
            # obtain sparsity of the jacobian
            jac_rows, jac_cols = [], []
            for constr in self.constr_info.values():
                vars = [name for name in constr["inputs"] if name in self.var_info]
                for var in vars:
                    var_dim = self.var_info[var]["dim"]
                    var_offset = self.var_info[var]["offset"]
                    constr_dim = constr["dim"]
                    constr_offset = constr["offset"]
                    row, col = np.indices((constr_dim, var_dim))
                    row, col = (row.flatten()+constr_offset, col.flatten()+var_offset)                    
                    jac_rows.append(row)
                    jac_cols.append(col)
            jac_rows, jac_cols = np.hstack(jac_rows), np.hstack(jac_cols)
            self.jac_rows = jac_rows
            self.jac_cols = jac_cols

        class Prob:
            pass
        prob = Prob()
        setattr(prob, "objective", self.get_objective_fn(compile))
        setattr(prob, "gradient", self.get_gradient_fn(compile))
            
        if len(self.constr_dims) != 0:
            setattr(prob, "constraints", self.get_constr_fn(compile))
            setattr(prob, "jacobian", self.get_sparse_jacobian_fn(compile))
            setattr(prob, "jacobianstructure", self.get_sparse_jac_indices_fn())
        ipopt = cyipopt.Problem(
            n=self.xdim, m=self.cdim,
            problem_obj=prob,
            lb=self.lb, ub=self.ub,
            cl=self.cl, cu=self.cu
        )
        self.print_summary()
        return ipopt
    
    def get_objective_fn(self, compile=True):
        def objective(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            input_dict = {**xs, **self.params}
            inputs = [input_dict[var] for var in self.obj_info["inputs"]]
            return self.obj_info["obj_fn"](*inputs)
        
        if "obj_fn" not in self.obj_info:
            return lambda x: 0.
        elif compile:
            return jax.jit(objective).lower(self.xdummy).compile()
        else:
            return objective
    
    def get_gradient_fn(self, compile=True):
        out_dict = {name:np.zeros(dim) for name, dim in zip(self.var_names, self.var_dims)}
                
        def gradient(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            input_dict = {**xs, **self.params}
            inputs = [input_dict[var] for var in self.obj_vars]
            grads = self.gradient_fn(*inputs)
            for name, idx in self.obj_info["vars"].items():
                out_dict[name] = grads[idx]
            return jnp.hstack([out_dict[name] for name in self.var_names])
        
        if "obj_fn" not in self.obj_info:
            return lambda x: np.zeros(self.xdim)
        elif compile:
            return jax.jit(gradient).lower(self.xdummy).compile()
        else:
            return gradient
    
    def get_constr_fn(self, compile=True):
        def constraints(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            input_dict = {**xs, **self.params}
            val_list= []
            for constr in self.constr_info.values():
                constr_fn = constr["constr_fn"]
                inputs = [input_dict[var] for var in constr["inputs"]]
                val = constr_fn(*inputs)
                val_list.append(val)
            return jnp.hstack(val_list)
        if compile:
            return jax.jit(constraints).lower(self.xdummy).compile()
        return constraints
     
    def get_sparse_jacobian_fn(self, compile=True):
        names = list(self.var_info.keys())
        dims = [self.var_info[varname]["dim"] for varname in names]
        offsets = [self.var_info[varname]["offset"] for varname in names]

        def jacobian(x):
            xs = {name:x[offset:offset+dim] for name, dim, offset in zip(names, dims, offsets)}
            input_dict = {**xs, **self.params}
            val_list= []
            for constr in self.constr_info.values():
                jac_fn = constr["jac_fn"]
                inputs = [input_dict[var] for var in constr["inputs"]]
                vals = jac_fn(*inputs)
                for name, idx in constr["vars"].items():
                    val_list.append(vals[idx].flatten())
                # for val in vals:
                #     val_list.append(val.flatten())
            return jnp.hstack(val_list)
        if compile:
            return jax.jit(jacobian).lower(self.xdummy).compile()
        return jacobian
    
    def get_sparse_jac_indices_fn(self):
        def jacobian_structure():
            return (self.jac_rows, self.jac_cols)
        return jacobian_structure

    def print_summary(self):
        jac_structure = np.zeros((self.cdim, self.xdim))
        row, col = self.get_sparse_jac_indices_fn()()
        jac_structure[row, col] = 1
        print("-- Summary --")
        print("x variables (dim)") 
        print([f"{var['name']}({var['dim']})" for var in self.var_info.values()])
        print("constraints (dim)")
        print([f"{constr['name']}({constr['dim']})" for constr in self.constr_info.values()])
        print("\nSparsity pattern:")
        for row in jac_structure:
            print(" ".join(["o" if val==1. else "-" for val in row]))

In [43]:
builder.constr_info["dist"]

{'name': 'dist',
 'inputs': ['g_pick', 'g_place', 'p_ho', 'p_start', 'p_goal'],
 'dim': 4,
 'offset': 6,
 'constr_fn': <function __main__.<lambda>(g1, g2, pose, pose_st, pose_ed)>,
 'jac_fn': <function __main__.<lambda>(g1, g2, pose, pose_st, pose_ed)>,
 'lb': array([0.05, 0.05, 0.05, 0.05]),
 'ub': array([inf, inf, inf, inf]),
 'vars': {'g_pick': 0, 'g_place': 1, 'p_ho': 2}}

In [38]:
#function
ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])

grasp_fn = lambda g: 0.
manip_fn = lambda g, pose: 0.
dist_fn = lambda g1, g2, pose, pose_st, pose_ed: jnp.zeros(4)

jac_grasp_fn = lambda g: [jnp.zeros(3)]
jac_manip_fn = lambda g, pose: [jnp.zeros(3), jnp.zeros(6)]
jac_dist_fn = lambda g1, g2, pose, pose_st, pose_ed: [jnp.zeros(3), jnp.zeros(3), jnp.zeros(6), jnp.zeros(6), jnp.zeros(6)]

In [39]:
builder = SparseIPOPTBuilder()

builder.add_variable("g_pick", 3, -1., 1.)
builder.add_variable("g_place", 3, -1., 1.)
builder.add_variable("p_ho", 6, ws_lb, ws_ub)
builder.add_parameter("p_start", jnp.zeros(6))
builder.add_parameter("p_goal", jnp.zeros(6))

#objec
#constr
builder.add_constr("grasp_prob_pick", ["g_pick"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)
builder.add_constr("grasp_prob_place", ["g_place"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)

builder.add_constr("manip_pick", ["g_pick", "p_start"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_place", ["g_place", "p_goal"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_ho_left", ["g_pick", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_ho_right", ["g_place", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)

builder.add_constr("dist", ["g_pick", "g_place", "p_ho","p_start", "p_goal"], 4, 
                   0.05, np.inf,
                   dist_fn, jac_dist_fn)
ipopt = builder.build()

-- Summary --
x variables (dim)
['g_pick(3)', 'g_place(3)', 'p_ho(6)']
constraints (dim)
['grasp_prob_pick(1)', 'grasp_prob_place(1)', 'manip_pick(1)', 'manip_place(1)', 'manip_ho_left(1)', 'manip_ho_right(1)', 'dist(4)']

Sparsity pattern:
o o o - - - - - - - - -
- - - o o o - - - - - -
o o o - - - - - - - - -
- - - o o o - - - - - -
o o o - - - o o o o o o
- - - o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o


In [40]:
builder.get_sparse_jacobian_fn()(jnp.zeros(builder.xdim)).shape

(42,)

In [37]:
builder.get_sparse_jac_indices_fn()()[0].shape

(78,)