In [1]:
import cyipopt
import jax
import jax.numpy as jnp
import numpy as np

In [85]:
class SparseIPOPTBuilder:
    def __init__(self):
        self.vars = {}
        self.constrs = {}
        self.params = {}
        self.var_offset = 0
        self.constr_offset = 0
        self.objective_fn = None

    @property
    def var_names(self):
        return list(self.vars.keys())
    @property
    def var_dims(self):
        return [self.vars[varname]["dim"] for varname in self.var_names]
    @property
    def var_offsets(self):
        return [self.vars[varname]["offset"] for varname in self.var_names]
    @property
    def constr_dims(self):
        return [constr["dim"] for constr in self.constrs.values()]
    @property
    def xdim(self):
        return sum(self.var_dims)
    @property
    def cdim(self):
        return sum(self.constr_dims)
    @property
    def xdummy(self):
        return jnp.zeros(self.xdim)
    @property
    def lb(self):
        return np.hstack([var["lb"] for var in self.vars.values()])
    @property
    def ub(self):
        return np.hstack([var["ub"] for var in self.vars.values()])
    @property
    def cl(self):
        if self.cdim != 0:
            return np.hstack([constr["lb"] for constr in self.constrs.values()])
        return None
    @property
    def cu(self):
        if self.cdim != 0:
            return np.hstack([constr["ub"] for constr in self.constrs.values()])
        return None
    
    def as_vector(self, val, dim):
        if isinstance(val, float):
            return np.full(dim, val)
        assert len(val) == dim
        return val
    
    def add_variable(self, name, dim, lb=-np.inf, ub=np.inf):
        assert name not in self.vars
        self.vars[name] = dict(
            name=name, dim=dim, offset=self.var_offset,
            lb=self.as_vector(lb, dim),
            ub=self.as_vector(ub, dim)
        )
        self.var_offset += dim
    
    def add_parameter(self, name, dim, value):
        assert name not in self.params
        self.params[name] = dict(
            name=name, dim=dim, value=np.asarray(value)
        )

    def add_constr(self, name, vars, dim, lb, ub, constr_fn, jac_fn, params=[]):
        assert name not in self.constrs
        self.constrs[name] = dict(
            name=name, vars=vars, dim=dim, offset=self.constr_offset,
            constr_fn=constr_fn, jac_fn=jac_fn, 
            lb=self.as_vector(lb, dim), 
            ub=self.as_vector(ub, dim),
            params=params
        )
        self.constr_offset += dim
    
    def set_objective(self, vars, obj_fn, grad_fn, params=[]):
        self.obj_vars = vars
        self.obj_params = params
        self.objective_fn = obj_fn
        self.gradient_fn = grad_fn

    def build(self, compile=True):
        # sparse jacobian
        if len(self.constr_dims) != 0:
            jac_rows, jac_cols = [], []
            for constr in self.constrs.values():
                #indices = []
                for var in constr["vars"]:
                    var_dim = self.vars[var]["dim"]
                    var_offset = self.vars[var]["offset"]
                    constr_dim = constr["dim"]
                    constr_offset = constr["offset"]
                    row, col = np.indices((constr_dim, var_dim))
                    row, col = (row.flatten()+constr_offset, col.flatten()+var_offset)
                    #indices.append(rowcol)
                    jac_rows.append(row)
                    jac_cols.append(col)
                #constr["indices"] = indices
            jac_rows, jac_cols = np.hstack(jac_rows), np.hstack(jac_cols)
            self.jac_rows = jac_rows
            self.jac_cols = jac_cols

        class Prob:
            pass
        prob = Prob()
        setattr(prob, "objective", self.get_objective_fn(compile))
        setattr(prob, "gradient", self.get_gradient_fn(compile))
            
        if len(self.constr_dims) != 0:
            setattr(prob, "constraints", self.get_constr_fn(compile))
            setattr(prob, "jacobian", self.get_sparse_jacobian_fn(compile))
            setattr(prob, "jacobianstructure", self.get_sparse_jac_indices_fn())
        ipopt = cyipopt.Problem(
            n=self.xdim, m=self.cdim,
            problem_obj=prob,
            lb=self.lb, ub=self.ub,
            cl=self.cl, cu=self.cu
        )
        self.print_summary()
        return ipopt
    
    def get_objective_fn(self, compile=True):
        if self.objective_fn is None:
            return lambda x: 0.
        
        def objective(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            inputs = [xs[var] for var in self.obj_vars] + [self.params[param] for param in self.obj_params]
            return self.objective_fn(*inputs)
        if compile:
            return jax.jit(objective).lower(self.xdummy).compile()
        return objective
    
    def get_gradient_fn(self, compile=True):
        if self.objective_fn is None:
            return lambda x: np.zeros(self.xdim)
        
        out = {name:np.zeros(dim) for name, dim in zip(self.var_names, self.var_dims)}
        def gradient(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            grads = self.gradient_fn(*[xs[var] for var in self.obj_vars])
            for grad, varname in zip(grads, self.obj_vars):
                out[varname] = grad
            return jnp.hstack([out[name] for name in self.var_names])
        if compile:
            return jax.jit(gradient).lower(self.xdummy).compile()
        return gradient
    
    def get_constr_fn(self, compile=True):
        def constraints(x):
            xs = {name:x[offset:offset+dim] 
                  for name, dim, offset in zip(self.var_names, self.var_dims, self.var_offsets)}
            
            val_list= []
            for constr in self.constrs.values():
                constr_fn = constr["constr_fn"]
                inputs = [xs[var] for var in constr["vars"]] + \
                         [self.params[param] for param in constr["params"]]
                val = constr_fn(*inputs)
                val_list.append(val)
            return jnp.hstack(val_list)
        if compile:
            return jax.jit(constraints).lower(self.xdummy).compile()
        return constraints
     
    def get_sparse_jacobian_fn(self, compile=True):
        names = list(self.vars.keys())
        dims = [self.vars[varname]["dim"] for varname in names]
        offsets = [self.vars[varname]["offset"] for varname in names]
        def jacobian(x):
            xs = {name:x[offset:offset+dim] for name, dim, offset in zip(names, dims, offsets)}
            val_list= []
            for constr in self.constrs.values():
                jac_fn = constr["jac_fn"]
                inputs = [xs[var] for var in constr["vars"]] + \
                         [self.params[param] for param in constr["params"]]
                vals = jac_fn(*inputs)
                for val in vals:
                    val_list.append(val.flatten())
            return jnp.hstack(val_list)
        if compile:
            return jax.jit(jacobian).lower(self.xdummy).compile()
        return jacobian
    
    def get_sparse_jac_indices_fn(self):
        def jacobian_structure():
            return (self.jac_rows, self.jac_cols)
        return jacobian_structure

    def print_summary(self):
        jac_structure = np.zeros((self.cdim, self.xdim))
        row, col = self.get_sparse_jac_indices_fn()()
        jac_structure[row, col] = 1
        print("-- Summary --")
        print("x variables (dim)") 
        print([f"{var['name']}({var['dim']})" for var in self.vars.values()])
        print("constraints (dim)")
        print([f"{constr['name']}({constr['dim']})" for constr in self.constrs.values()])
        print("\nSparsity pattern:")
        for row in jac_structure:
            print(" ".join(["o" if val==1. else "-" for val in row]))

In [86]:
#function
ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])

grasp_fn = lambda g: 0.
manip_fn = lambda g, pose: 0.
dist_fn = lambda g1, g2, pose, pose_st, pose_ed: jnp.zeros(4)

jac_grasp_fn = lambda g: [jnp.zeros(3)]
jac_manip_fn = lambda g, pose: [jnp.zeros(3), jnp.zeros(6)]
jac_dist_fn = lambda g1, g2, pose, pose_st, pose_ed: [jnp.zeros(3), jnp.zeros(3), jnp.zeros(6)]

In [87]:
builder = SparseIPOPTBuilder()

builder.add_variable("g_pick", 3, -1., 1.)
builder.add_variable("g_place", 3, -1., 1.)
builder.add_variable("p_ho", 6, ws_lb, ws_ub)
builder.add_parameter("p_start", 6, jnp.zeros(6))
builder.add_parameter("p_goal", 6, jnp.zeros(6))

#constr
builder.add_constr("grasp_prob_pick", ["g_pick"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)
builder.add_constr("grasp_prob_place", ["g_place"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)
builder.add_constr("manip_pick", ["g_pick"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn, params=["p_start"])
builder.add_constr("manip_place", ["g_place"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn, params=["p_goal"])
builder.add_constr("manip_ho_left", ["g_pick", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_ho_right", ["g_place", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("dist", ["g_pick", "g_place", "p_ho"], 4, 
                   0.05, np.inf,
                   dist_fn, jac_dist_fn, params=["p_start", "p_goal"])
ipopt = builder.build()

-- Summary --
x variables (dim)
['g_pick(3)', 'g_place(3)', 'p_ho(6)']
constraints (dim)
['grasp_prob_pick(1)', 'grasp_prob_place(1)', 'manip_pick(1)', 'manip_place(1)', 'manip_ho_left(1)', 'manip_ho_right(1)', 'dist(4)']

Sparsity pattern:
o o o - - - - - - - - -
- - - o o o - - - - - -
o o o - - - - - - - - -
- - - o o o - - - - - -
o o o - - - o o o o o o
- - - o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o
o o o o o o o o o o o o


In [89]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

In [100]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

#grasp net
grasp_net = GraspNet(32)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)
def grasp_reconst(g:Array):
    rot = SO3(grasp_fn(g)[2:]).normalize()
    trans = g/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)
grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]
#manip net
manip_net = ManipNet(64)
manip_fn = lambda x: manip_net.apply(restored_manip["params"], x)[0]

In [91]:
world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/




In [92]:
world.show_in_jupyter()

In [93]:
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obj_start = Mesh(world.vis, "obj_start", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="blue",
                 alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", 
                "./sdf_world/assets/object/mesh.obj",
                color="green",
                 alpha=0.5)

In [94]:
table_start.set_translate([0.5, -0.3, 0.2/2])
table_goal.set_translate([0.5, 0.3, 0.2/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.5, -0.3, obj_lengths[-1]/2+table_lengths[-1]])
trans_goal = jnp.array([0.5, 0.3, obj_lengths[-2]/2+table_lengths[-1]])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal
)
obj_goal.set_pose(obj_goal_pose)

In [95]:
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link_name, link in hand_model.links.items():
    link.set_surface_points(10)
hand1 = Robot(world.vis, "hand1", hand_model, color="yellow", alpha=0.5)



In [96]:
hand_pc = hand1.get_surface_points_fn(jnp.array([0.04, 0.04]))
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
env = SDFContainer([table_start, table_goal], 0.05)
# x -> grasps



# @jax.jit
# def distance_constr(x):
#     hand_pcs = get_hand_points(x, obj_start.pose, obj_goal.pose)
#     distances = env.distances(hand_pcs)
#     top4_indices = jnp.argpartition(distances, 4)[:4]
#     return distances[top4_indices]

In [108]:
#functions
to_posevec = lambda x: jnp.hstack([x[4:], SO3(x[:4]).log()])
to_wxyzxyz = lambda x: jnp.hstack([SO3.exp(x[3:]).parameters(), x[:3]])
def get_hand_points(grasp, obj_posevec):
    grasp_pose = grasp_reconst(grasp)
    hand_base_pose_wrt_world = SE3(to_wxyzxyz(obj_posevec)) @ grasp_pose @ hand_pose_wrt_ee
    assigned_hand_pc = jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)
    return assigned_hand_pc

# input: g(3), output: logit(1)
grasp_logit_fn 
# input: g(3),p(6) output: manip(1)
def manip_fn(grasp, obj_posevec):
    obj_pose = SE3(to_wxyzxyz(obj_posevec))
    posevec = to_posevec(grasp_reconst(grasp, obj_pose).parameters())
    zflip = (SO3.exp(posevec[3:]) @ SO3.from_z_radians(jnp.pi)).log()
    posevec_flip = jnp.hstack([posevec[:3], zflip])
    return jnp.maximum(manip_fn(posevec), manip_fn(posevec_flip))
# input: g1, g2(3), p_ho, p_st, p_ed(6) output: dist(4)
def dist_fn(g1, g2, obj_pose_ho, obj_pose_st, obj_pose_ed):
    grasps = jnp.vstack([g1, g1, g2, g2])
    obj_poses = jnp.vstack([obj_pose_st, obj_pose_ho, obj_pose_ho, obj_pose_ed])
    pcs = jax.vmap(get_hand_points, (0,0))(grasps, obj_poses)
    distances = env.distances(jnp.vstack(pcs))
    top4_indices = jnp.argpartition(distances, 4)[:4]
    return distances[top4_indices]

In [None]:
builder = SparseIPOPTBuilder()

builder.add_variable("g_pick", 3, -1., 1.)
builder.add_variable("g_place", 3, -1., 1.)
builder.add_variable("p_ho", 6, ws_lb, ws_ub)
builder.add_parameter("p_start", 6, jnp.zeros(6))
builder.add_parameter("p_goal", 6, jnp.zeros(6))

#constr
builder.add_constr("grasp_prob_pick", ["g_pick"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)
builder.add_constr("grasp_prob_place", ["g_place"], 1, 
                   1., np.inf,
                   grasp_fn, jac_grasp_fn)
builder.add_constr("manip_pick", ["g_pick", "p_start"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_place", ["g_place", "p_goal"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn, params=["p_goal"])
builder.add_constr("manip_ho_left", ["g_pick", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("manip_ho_right", ["g_place", "p_ho"], 1,
                   0.3, np.inf,
                   manip_fn, jac_manip_fn)
builder.add_constr("dist", ["g_pick", "g_place", "p_ho","p_start", "p_goal"], 4, 
                   0.05, np.inf,
                   dist_fn, jac_dist_fn)
ipopt = builder.build()

In [112]:
{**{1:1, 2:2}, **{3:3, 4:4}}

{1: 1, 2: 2, 3: 3, 4: 4}

In [111]:
builder.params

{'p_start': {'name': 'p_start',
  'dim': 6,
  'value': array([0., 0., 0., 0., 0., 0.], dtype=float32)},
 'p_goal': {'name': 'p_goal',
  'dim': 6,
  'value': array([0., 0., 0., 0., 0., 0.], dtype=float32)}}

In [11]:
x = jnp.array([0,0,-0.5])
hand_pcs = get_hand_points(x, obj_start.pose, obj_goal.pose)

In [13]:
def grasp_to_manip(x, obj_pose):
    posevec = to_posevec(grasp_reconst(x, obj_pose).parameters())
    posevec = to_posevec(grasp_reconst(x, obj_start.pose).parameters())
    zflip = (SO3.exp(posevec[3:]) @ SO3.from_z_radians(jnp.pi)).log()
    posevec_flip = jnp.hstack([posevec[:3], zflip])
    return jnp.maximum(manip_fn(posevec), manip_fn(posevec_flip))

In [14]:
x = jnp.zeros(3)

In [15]:
frame = Frame(world.vis, "frame")

In [16]:
pc = PointCloud(world.vis, "hand_pc", hand_pcs, color="red")

In [17]:
def objective(x):
    frame.set_pose(grasp_reconst(x, obj_start.pose))
    hand_pcs = get_hand_points(x, obj_start.pose, obj_goal.pose)
    pc.reload(points=hand_pcs)
    time.sleep(0.05)
    return 0.

gradient = lambda x: [jnp.zeros(3, dtype=float)]

In [18]:
builder = SparseIPOPTBuilder()

builder.add_variable("g", 3, -1., 1.)
builder.set_objective(["g"], objective, gradient)
builder.add_constr("logit", ["g"], 1, 1., np.inf,
                   grasp_logit_fn, jax.grad(grasp_logit_fn))
grasp_to_manip_init = partial(grasp_to_manip, obj_pose=obj_start.pose)
builder.add_constr("manip", ["g"], 1, 0.4, np.inf,
                   grasp_to_manip_init, jax.grad(grasp_to_manip_init))
builder.add_constr("dist", ["g"], 4, 0.05, np.inf,
                   distance_constr, jax.jacrev(distance_constr))



In [19]:
ipopt = builder.build(False)

In [20]:
ipopt.solve(jnp.zeros(3))


******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       18
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        3
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality c

(array([-0.01110441,  0.42939996,  0.75746162]),
 {'x': array([-0.01110441,  0.42939996,  0.75746162]),
  'g': array([2.06466198, 0.77494359, 0.11732963, 0.11914807, 0.12019696,
         0.12429338]),
  'obj_val': 0.0,
  'mult_g': array([-1.34700831e-12, -2.75378836e-11, -1.44729489e-10, -1.27496547e-10,
         -1.26756736e-10, -1.32878068e-10]),
  'mult_x_L': array([1.01616863e-11, 6.87109409e-12, 5.63088396e-12]),
  'mult_x_U': array([9.84278619e-12, 1.86238217e-11, 4.48245431e-11]),
  'status': 0,
  'status_msg': b'Algorithm terminated successfully at a locally optimal point, satisfying the convergence tolerances (can be specified by options).'})


Number of Iterations....: 18

                                   (scaled)                 (unscaled)
Objective...............:   0.0000000000000000e+00    0.0000000000000000e+00
Dual infeasibility......:   1.7944592282149640e-10    1.7944592282149640e-10
Constraint violation....:   0.0000000000000000e+00    0.0000000000000000e+00
Variable bound violation:   0.0000000000000000e+00    0.0000000000000000e+00
Complementarity.........:   1.0871672419843804e-11    1.0871672419843804e-11
Overall NLP error.......:   1.7944592282149640e-10    1.7944592282149640e-10


Number of objective function evaluations             = 19
Number of objective gradient evaluations             = 19
Number of equality constraint evaluations            = 0
Number of inequality constraint evaluations          = 19
Number of equality constraint Jacobian evaluations   = 0
Number of inequality constraint Jacobian evaluations = 19
Number of Lagrangian Hessian evaluations             = 0
Total seconds in IPOPT         

In [91]:
cu

array([inf, inf, inf, inf, inf, inf])

In [47]:
builder.get_gradient_fn(compile=False)(x)

Array([0., 0., 0.], dtype=float32)

In [65]:
jac = builder.get_sparse_jacobian_fn()

In [68]:
jac(x).shape

(18,)

In [70]:
builder.get_sparse_jac_indices_fn()()

(18,)

In [41]:
x = jnp.zeros(3)
xs = {name:x[offset:offset+dim] 
    for name, dim, offset in zip(builder.var_names, builder.var_dims, builder.var_offsets)}

In [43]:
builder.gradient_fn(*[xs[var] for var in builder.obj_vars])

Array([0., 0., 0.], dtype=float32)

In [38]:

    for grad, varname in zip(grads, self.obj_vars):
        out[varname] = grad
    return jnp.hstack([out[name] for name in self.varnames])
if compile:
    return jax.jit(gradient).lower(self.xdummy).compile()
return gradient

[0]

In [79]:
builder = SparseIPOPTBuilder()

builder.add_variable("g_pick", 3)
builder.add_variable("g_place", 3)
builder.add_variable("p_handover", 6)

builder.set_objective(["g_pick", "g_place"], 
                      lambda g1, g2: jnp.array(0., dtype=float),
                      lambda g1, g2: [jnp.arange(3), jnp.arange(3)*10])
#constr
builder.add_constr("grasp_prob_pick", ["g_pick"], 1, 
                   lambda g: 0., lambda g: [jnp.zeros(3)])
builder.add_constr("grasp_prob_place", ["g_place"], 1, 
                   lambda g: 0., lambda g: [jnp.zeros(3)])

builder.add_constr("manip_pick", ["g_pick"], 1,
                   lambda g: 0., lambda g: [jnp.zeros(3)])
builder.add_constr("manip_place", ["g_place"], 1,
                   lambda g: 0., lambda g: [jnp.zeros(3)])

builder.add_constr("manip_ho_left", ["g_pick", "p_handover"], 1,
                   lambda g, p: 0., lambda g, p: [jnp.zeros(3), jnp.zeros(6)])
builder.add_constr("manip_ho_right", ["g_place", "p_handover"], 1,
                   lambda g, p: 0., lambda g, p: [jnp.zeros(3), jnp.zeros(6)])

builder.add_constr("col", ["g_pick", "g_place", "p_handover"], 4, 
                   lambda g1, g2, p: jnp.zeros(4),
                   lambda g1, g2, p: [jnp.zeros((4,3)), jnp.zeros((4,3)), jnp.zeros((4,6))])

TypeError: add_constr() missing 2 required positional arguments: 'constr_fn' and 'jac_fn'

In [335]:
builder.build()

In [341]:
x = jnp.zeros(12)
obj_fn = builder.get_objective_fn()
grad_fn = builder.get_gradient_fn()
constr_fn = builder.get_constr_fn()
jac_fn = builder.get_sparse_jacobian_fn()
obj_fn(x)
grad_fn(x)
jac_fn(x)
constr_fn(x)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [251]:
x = jnp.arange(12, dtype=float)

In [252]:
jac_fn(x).shape

(78,)

In [253]:
jac_struct = builder.get_sparse_jac_indices_fn()

In [260]:
%timeit jac_fn(x)

4.72 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [185]:
jnp.hstack(row_list).shape

(78,)

In [182]:
builder.constrs["grasp_prob_place"]

{'name': 'grasp_prob_place',
 'vars': ['g_place'],
 'dim': 1,
 'offset': 1,
 'constr_fn': <function __main__.<lambda>(g)>,
 'jac_fn': <function __main__.<lambda>(g)>,
 'indices': [(array([1, 1, 1]), array([3, 4, 5]))]}

In [179]:
constr["indices"][0]

(array([6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]),
 array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]))

In [None]:
vals

In [156]:
dummy_fn = lambda g, p : [jnp.zeros(3),jnp.zeros(6)]

In [157]:
dummy_jit = jax.jit(dummy_fn)

In [158]:
dummy_jit(0., 1.)

[Array([0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0.], dtype=float32)]

In [137]:
constr["name"]

'manip_ho_left'

In [135]:
[xs[var] for var in constr["vars"]]

[Array([0., 1., 2.], dtype=float32),
 Array([ 6.,  7.,  8.,  9., 10., 11.], dtype=float32)]

In [None]:
vals

In [102]:
jac_fn = constr["jac_fn"]

In [103]:
jacs = jac_fn(*[xs[var] for var in constr["vars"]])

In [104]:
constr["indices"]

[(array([6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]),
  array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])),
 (array([6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]),
  array([3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5])),
 (array([6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9,
         9, 9]),
  array([ 6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10,
         11,  6,  7,  8,  9, 10, 11]))]

In [54]:
offsets

[0, 3, 6]

array([0, 0, 0])