In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

In [2]:
@dataclass
class Variable:
    name: str
    coord: np.ndarray
    lb: np.ndarray
    ub: np.ndarray
    
    @property
    def dim(self): return len(self.coord)

@dataclass
class Parameter:
    name: str
    coord: np.ndarray
    dim: int

    @property
    def lb(self): return np.full(self.dim, -np.inf)
    @property
    def ub(self): return np.full(self.dim, np.inf)

@dataclass
class Constraint:
    name: str
    coord: np.ndarray
    inputs: List[Variable]
    fn: "Function"
    lb: np.ndarray
    ub: np.ndarray

    @property
    def dim(self): return len(self.coord)

@dataclass
class Function:
    name: str
    in_dims: List[int]
    out_dim: int
    eval_fn: Callable
    jac_fn: Callable
    constraints: List[Constraint] = field(default_factory=list)

class SparseIPOPT():
    def __init__(self):
        self.x_info: Dict[str,Variable] = {}
        self.c_info: Dict[str,Constraint] = {}
        self.fn_info: Dict[str,Function] = {}
        self.obj_info: Dict = {}
        self.param_info: Dict[str, Array] = {}

        self.x_idx, self.c_idx = 0, 0
        self.param_info: Dict[str, np.ndarray] = {}

    @property
    def xdim(self):
        return sum([x.dim for x in self.x_info.values()])
    @property
    def cdim(self):
        return sum([c.dim for c in self.c_info.values()])
    @property
    def input_info(self):
        return {**self.x_info, **self.param_info}
    
    def add_variable(self, name, dim, lb=-np.inf, ub=np.inf):
        assert name not in self.x_info
        assert isinstance(lb, float) or len(lb) == dim
        assert isinstance(ub, float) or len(ub) == dim

        if isinstance(lb, float): lb = np.full(dim, lb)
        if isinstance(ub, float): ub = np.full(dim, ub)

        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Variable(
            name, coord, lb, ub)
        
        self.x_idx += dim
    
    def add_parameter(self, name, dim):
        assert name not in self.param_info
        #TODO: name also should not be in "variable"
        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Parameter(name, coord, dim)
        self.x_idx += dim
    
    def set_objective(self, fn_name, input_x_names):
        self.obj_info["fn"] = self.fn_info[fn_name]
        self.obj_info["inputs"] = [self.x_info[name] for name in input_x_names]
    
    def set_debug_callback(self, debug_callback:Callable):
        self.obj_info["debug_cb"] = debug_callback

    def set_constr(self, name, cfn_name, input_x_names, lb, ub):
        c_fn = self.fn_info[cfn_name]
        dim = c_fn.out_dim
        assert name not in self.c_info
        assert isinstance(lb, float) or len(lb) == dim
        assert isinstance(ub, float) or len(ub) == dim
        if isinstance(lb, float): lb = np.full(dim, lb)
        if isinstance(ub, float): ub = np.full(dim, ub)

        vars = [self.x_info[name] for name in input_x_names]
        self.c_info[name] = Constraint(
            name, np.arange(self.c_idx, self.c_idx+dim), vars, 
            c_fn, lb, ub
        )
        c_fn.constraints.append(self.c_info[name])
        self.c_idx += dim

    def register_fn(self, name, in_dims, out_dim, eval_fn, jac_fn, compile=False):
        if compile:
            fn_inputs = [jnp.zeros(dim) for dim in in_dims]
            eval_fn = jax.jit(eval_fn).lower(*fn_inputs).compile()
            jac_fn = jax.jit(jac_fn).lower(*fn_inputs).compile()
        self.fn_info[name] = Function(
            name, in_dims, out_dim, eval_fn, jac_fn)
    
    def get_objective_fn(self, compile=True):
        no_obj = False
        if "fn" not in self.obj_info: 
            objective = lambda x: 0.
            no_obj = True
        else:
            def objective(x):        
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                fn_input = [xs[var.name] for var in self.obj_info["inputs"]]    
                val = self.obj_info["fn"].eval_fn(*fn_input)
                return val
        
        if "debug_cb" in self.obj_info:
            def objective_debug(x):
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                self.obj_info["debug_cb"](xs)    
                return objective(x)
            return objective_debug
        elif compile and not no_obj:
            return jax.jit(objective)
        return objective
    
    def get_gradient_fn(self, compile=True):
        no_obj = False
        if "fn" not in self.obj_info: 
            gradient = lambda x: np.zeros(self.xdim)
            no_obj = True
        else:
            grad_value_dict = {var.name: np.zeros(var.dim) for var in self.x_info.values()}
            def gradient(x):
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                fn_input = [xs[var.name] for var in self.obj_info["inputs"]]    
                grads = self.obj_info["fn"].jac_fn(*fn_input)
                for var, grad in zip(self.obj_info['inputs'], grads):
                    grad_value_dict[var.name] = grad
                return jnp.hstack(grad_value_dict.values())
        if compile and not no_obj:
            return jax.jit(gradient)
        return gradient      
    
    def get_constraint_fn(self, compile=True):
        def constraints(x):
            xs = {var.name:x[var.coord] for var in self.x_info.values()}
            result = []
            for constr in self.c_info.values():
                fn_input = [xs[var.name] for var in constr.inputs]    
                out = constr.fn.eval_fn(*fn_input)
                result.append(out)
            return jnp.hstack(result)
        if compile:
            return jax.jit(constraints)
        return constraints
    
    def get_jacobian_fn(self, compile=True):
        def jacobian(x):
            xs = {var.name:x[var.coord] for var in self.x_info.values()}
            result = []
            for constr in self.c_info.values():
                fn_input = [xs[var.name] for var in constr.inputs]    
                indices_var = [i for i, var in enumerate(constr.inputs) 
                               if isinstance(var, Variable)]
                jacs = constr.fn.jac_fn(*fn_input)
                for i, jac in enumerate(jacs):
                    if i in indices_var:
                        result.append(jac.flatten())
            return jnp.hstack(result)
        if compile:
            return jax.jit(jacobian)
        return jacobian
    
    def get_jacobian_structure(self):
        rows, cols = [], []
        for constr in self.c_info.values():
            for var in constr.inputs:
                if isinstance(var, Parameter): continue
                row, col = np.indices((constr.dim, var.dim)).reshape(2, -1)
                row += constr.coord[0]
                col += var.coord[0]
                rows.append(row)
                cols.append(col)
        rows = np.hstack(rows)
        cols = np.hstack(cols)
        return rows, cols

    def print_sparsity(self):
        row, col = self.get_jacobian_structure()
        jac_struct = np.full((self.cdim, self.xdim), -1, dtype=int)
        jac_struct[row, col] = 1
        for row in jac_struct:
            row_str = ""
            for val in row:
                if val == -1: row_str += "- "
                else: row_str += f"o "
            print(row_str)
    
    def build(self, compile=True):
        lb = np.hstack([x.lb for x in self.x_info.values()])
        ub = np.hstack([x.ub for x in self.x_info.values()])
        cl = np.hstack([c.lb for c in self.c_info.values()])
        cu = np.hstack([c.ub for c in self.c_info.values()])
        row, col = self.get_jacobian_structure()
        jac_struct_fn = lambda : (row, col)

        fns = {
            "objective": self.get_objective_fn(compile),
            "gradient": self.get_gradient_fn(compile),
            "constraints": self.get_constraint_fn(compile),
            "jacobian": self.get_jacobian_fn(compile),
        }
        class Prob:
            pass
        prob = Prob()
        xdummy = jnp.zeros(self.xdim)
        for fn_name, fn in fns.items():
            print(f"compiling {fn_name} ...")
            fn(xdummy)
            setattr(prob, fn_name, fn)
        setattr(prob, "jacobianstructure", jac_struct_fn)

        ipopt = cyipopt.Problem(
            n=self.xdim, m=self.cdim,
            problem_obj=prob,
            lb=lb, ub=ub, cl=cl, cu=cu
        )
        # default option
        ipopt.add_option("acceptable_iter", 2)
        ipopt.add_option("acceptable_tol", 0.1) #release
        ipopt.add_option("acceptable_obj_change_tol", 0.0001)
        ipopt.add_option("acceptable_dual_inf_tol", 1.) 
        ipopt.add_option('mu_strategy', 'adaptive')
        self.print_sparsity()
        return ipopt

In [3]:
# models 
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

#grasp net
grasp_net = GraspNet(32)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)

grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]
#manip net
manip_net = ManipNet(64)
manip_fn = lambda x: manip_net.apply(restored_manip["params"], x)[0]

In [4]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/


In [5]:
world.show_in_jupyter()

## Env setting

In [6]:
# robot, hand
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link_name, link in hand_model.links.items():
    link.set_surface_points(10)
hand = Robot(world.vis, "hand1", hand_model, color="white", alpha=0.5)

In [7]:
#load sdf meshes
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obj_start = Mesh(world.vis, "obj_start", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", 
                "./sdf_world/assets/object/mesh.obj",
                color="green", alpha=0.5)

In [8]:
table_start.set_translate([0.5, -0.3, 0.2/2])
table_goal.set_translate([0.5, 0.3, 0.2/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.5, -0.3, obj_lengths[-1]/2+table_lengths[-1]])
trans_goal = jnp.array([0.5, 0.3, obj_lengths[-2]/2+table_lengths[-1]])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)

In [9]:
#visualization
pc_hand = PointCloud(world.vis, "hand_pc", np.zeros((100,3)), color="red", size=0.01)

In [10]:
to_posevec = lambda x: jnp.hstack([x[4:], SO3(x[:4]).log()])
to_wxyzxyz = lambda x: jnp.hstack([SO3.exp(x[3:]).parameters(), x[:3]])

# constants
hand_pc = hand.get_surface_points_fn(jnp.array([0.04, 0.04]))
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])
obj_start_posevec = to_posevec(obj_start.pose.parameters())
obj_goal_posevec = to_posevec(obj_goal.pose.parameters())

# prepare functions
env = SDFContainer([table_start, table_goal], 0.05)
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)

@jax.jit
def get_hand_pc(grasp, posevec):
    grasp_pose = grasp_reconst(grasp)
    hand_base_pose_wrt_world = SE3(to_wxyzxyz(posevec)) @ grasp_pose @ hand_pose_wrt_ee
    assigned_hand_pc = jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)
    return assigned_hand_pc

#constr fns
def grasp_cfn(grasp):
    return grasp_logit_fn(grasp)
jac_grasp_cfn = jax.grad(grasp_cfn, argnums=[0])

def manip_constr_fn(grasp, posevec):
    obj_pose = SE3(to_wxyzxyz(posevec))
    grasp_pose = obj_pose @ grasp_reconst(grasp)
    zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    grasp_pose_flip = grasp_pose @ zflip
    posevecs = [to_posevec(pose.parameters()) for pose in [grasp_pose, grasp_pose_flip]]
    return jax.vmap(manip_fn)(jnp.vstack(posevecs)).max()
jac_manip_constr_fn = jax.grad(manip_constr_fn, argnums=[0,1])

# manip_cfn_start = partial(manip_constr_fn,
#                           posevec=obj_start_posevec)
# jac_manip_cfn_start = jax.grad(manip_cfn_start, argnums=[0])
# manip_cfn_goal = partial(manip_constr_fn, 
#                             posevec=obj_goal_posevec)
# jac_manip_cfn_goal = jax.grad(manip_cfn_goal, argnums=[0])

#TODO is it necessary to define 4dim distance constr?
def _dist_cfn(g1, posevec1, posevec2):
    obj_poses = jnp.vstack([posevec1, posevec2])
    pcs = jax.vmap(get_hand_pc, (None,0))(g1, obj_poses)
    distances = env.distances(jnp.vstack(pcs)).reshape(2, -1)
    # top4_indices = jnp.argpartition(distances, 1)[:1]
    return distances.min(axis=-1)
dist_cfn = partial(_dist_cfn, posevec1=obj_start_posevec, posevec2=obj_goal_posevec)
jac_dist_cfn = lambda grasp: [jax.jacrev(dist_cfn)(grasp)]


In [11]:
bdr = SparseIPOPT()
bdr.add_variable("g_pick", 3, -1., 1.)
bdr.add_parameter("p_start", 6)
bdr.add_parameter("p_goal", 6)

def debug_callback(x_dict):
    grasp = x_dict["g_pick"]
    p_start = x_dict["p_start"]
    p_goal = x_dict["p_goal"]
    posevecs = jnp.vstack([p_start, p_goal])
    points = jax.vmap(get_hand_pc, in_axes=(None,0))(grasp, posevecs)
    pc_hand.reload(points=np.vstack(points))
bdr.set_debug_callback(debug_callback)

bdr.register_fn("grasp_logit_fn", [3], 1,
                          grasp_cfn, jac_grasp_cfn)
bdr.register_fn("manip_fn", [3, 6], 1,
                          manip_constr_fn, jac_manip_constr_fn)
bdr.register_fn("dist_fn", [3], 2,
                          dist_cfn, jac_dist_cfn)

bdr.set_constr("grasp_prob_pick", "grasp_logit_fn", ["g_pick"], 
                   1., np.inf)
bdr.set_constr("manip_pick", "manip_fn", ["g_pick", "p_start"],
                   0.3, np.inf)
bdr.set_constr("manip_place", "manip_fn", ["g_pick", "p_goal"],
                   0.3, np.inf)
bdr.set_constr("dist", "dist_fn", ["g_pick"], 
                   0.05, np.inf)

In [12]:
ipopt = bdr.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
o o o - - - - - - - - - - - - 
o o o - - - - - - - - - - - - 
o o o - - - - - - - - - - - - 
o o o - - - - - - - - - - - - 
o o o - - - - - - - - - - - - 


In [30]:
obj_start_posevec2 = obj_start_posevec+ np.array([0., 0., 0.0, 0., 0., 0.])
obj_goal_posevec2 = obj_goal_posevec+ np.array([0.0, 0., 0.0, 0., 0., 0.0])

In [31]:
xinit = jnp.hstack([np.random.uniform(-1,1,size=3), obj_start_posevec2, obj_goal_posevec2])
xsol, info = ipopt.solve(xinit)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       15
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       15
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        5
        inequality constraints with only lower bounds:        5
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 2.88e+02 1.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [221]:
@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_pose and geometric jacobian
    fks = panda_model.fk_fn(q)
    ee = SE3(fks[-1])
    p_ee = fks[-1][-3:]
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T #geometric jacobian
    return ee, jac

# Kinematics
def get_rotvec_angvel_map(v):
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    def identity(v, vmag):
        return np.eye(3)
    def angvel_to_ecvel(v, vmag):
        vskew = skew(v)
        term3 = vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))
        return jnp.eye(3) - 1/2*skew(v) + term3
    vmag = jnp.linalg.norm(v)
    return jax.lax.cond(vmag < 1e-3, identity, angvel_to_ecvel, v, vmag)

def pose_error(posevec, posevec_d):
    pose = SE3(to_wxyzxyz(posevec))
    pose_d = SE3(to_wxyzxyz(posevec_d))
    R = pose.rotation().as_matrix()
    err_pos = R.T@(pose_d.translation() - pose.translation())
    err_rot = pose.rotation().inverse() @ pose_d.rotation()
    return jnp.hstack([err_pos, err_rot.log()])

def get_grasp_pose(grasp, obj_posevec):
    obj_pose = SE3(to_wxyzxyz(obj_posevec))
    grasp_pose = obj_pose @ grasp_reconst(grasp)
    return to_posevec(grasp_pose.parameters())

def kin_error_fn(q, grasp, p_obj):
    obj_pose = SE3(to_wxyzxyz(p_obj))
    target_pose = obj_pose @ grasp_reconst(grasp)
    ee_pose, _ = get_ee_fk_jac(q)
    target_posevec = to_posevec(target_pose.parameters())
    ee_posevec = to_posevec(ee_pose.parameters())
    # represented in body
    err = pose_error(ee_posevec, target_posevec)
    # R_ee = ee_pose.rotation().as_matrix()
    # err_pos = R_ee.T@(target_pose.translation() - ee_pose.translation())
    # err_rot = ee_pose.rotation().inverse() @ target_pose.rotation()
    # err = jnp.hstack([err_pos, err_rot.log()])
    return err

def jac_kin_error_fn(q, grasp, p_obj):
    ee_pose, geom_jac = get_ee_fk_jac(q)
    ee_posevec = to_posevec(ee_pose.parameters())
    obj_pose = SE3(to_wxyzxyz(p_obj))
    target_pose = obj_pose @ grasp_reconst(grasp)
    target_posevec = get_grasp_pose(grasp, p_obj) # obj_pose @ grasp_reconst(grasp)

    # ee_jac, pose_jac = jax.jacrev(pose_error)(ee_posevec, target_posevec)
    err_rot = ee_pose.rotation().inverse() @ target_pose.rotation()
    R_ee = ee_pose.rotation().as_matrix()
    B = get_rotvec_angvel_map(err_rot.log())
    jac_pos = - R_ee.T @ geom_jac[:3]
    jac_rot = - B @ R_ee.T @ geom_jac[3:]
    jac_q = jnp.vstack([jac_pos, jac_rot])
    jac_poseerr_gpose = jax.jacrev(pose_error, argnums=1)(ee_posevec, target_posevec)
    jac_grasp, jac_objpose = jax.jacrev(get_grasp_pose, argnums=[0,1])(grasp, p_obj)
    return jac_q, jac_poseerr_gpose@jac_grasp, jac_poseerr_gpose@jac_objpose

def travelled_distance(*qs):
    qs_mat = jnp.vstack([panda.neutral, *qs, panda.neutral])
    qdiff = qs_mat[1:] - qs_mat[:-1]
    return 0.5*jnp.sum(qdiff.flatten() ** 2)

# def debug_callback(x_dict):
#     q = x_dict["q_pick"]
#     panda.set_joint_angles(q)
#     time.sleep(0.5)

In [222]:
num_robot_points = 200
num_link_points = 20
safe_dist = 0.05

#get links_points_mat
links_points_mat = []
for link in panda_model.links.values():
    if not link.has_mesh: continue
    if not panda_model.is_floating and link == panda_model.root_link: continue
    links_points_mat.append(link.surface_points)
links_points_mat = np.array(links_points_mat)
fk_assign = lambda wxyzxyz, link_points: jax.vmap(SE3(wxyzxyz).apply)(link_points)

In [223]:
def point_jacobian(point, link_idx, joint_frames):
    def get_lin_vel(target_point, joint_frame):
        joint_to_target = target_point - joint_frame[-3:]
        rot_axis = SO3(joint_frame[:4]).as_matrix()[:3,2]
        return jnp.cross(rot_axis, joint_to_target)
    lin_jac = jax.vmap(get_lin_vel, in_axes=(None,0))(point, joint_frames).T
    masking = np.tile(np.arange(7),3).reshape(-1,7) + 1
    masking = jnp.where(masking > link_idx, 0, 1)
    return masking * lin_jac

def distance_fn(q):
    fks = panda_model.fk_fn(q)
    link_frames = fks[1:-1]
    assigned_points = jax.vmap(fk_assign)(link_frames, links_points_mat)
    distances = env.distances(jnp.vstack(assigned_points))
    min_distance = distances.min() - safe_dist
    return jnp.where(min_distance < 0., min_distance, 0.)

def jac_distance_fn(q):
    fks = panda_model.fk_fn(q)
    joint_frames = fks[1:8]
    assigned_points = jax.vmap(fk_assign)(fks[1:-1], links_points_mat)
    distances = env.distances(jnp.vstack(assigned_points)).reshape(10,-1)
    idx_link, idx_point = jnp.unravel_index(distances.argmin(), distances.shape)
    min_point = assigned_points[idx_link, idx_point, :]
    jac_point = jax.lax.cond(distances.min() < safe_dist, 
                point_jacobian, lambda x,y,z :np.zeros((3,7)), min_point, idx_link, joint_frames)
    repulsive_grad = jax.grad(env.distance)(min_point)
    return repulsive_grad @ jac_point

In [224]:
num_mid_configs = 3
num_traj = 2 + num_mid_configs*3
idx_pick = num_mid_configs
idx_place = 2*num_mid_configs + 1
qs = [panda.get_random_config() for i in range(num_traj)]
qs = jnp.vstack(qs)
def jac_travelled_distance(*qs):
    return jax.grad(travelled_distance, argnums=np.arange(num_traj))(*qs)

In [225]:
# #functions
# kin_error_fn_start = partial(kin_error_fn, grasp=xsol, p_obj=obj_start_posevec)
# jac_kin_error_fn_start = partial(jac_kin_error_fn, grasp=xsol, p_obj=obj_start_posevec)
# kin_error_fn_goal = partial(kin_error_fn, grasp=xsol, p_obj=obj_goal_posevec)
# jac_kin_error_fn_goal = partial(jac_kin_error_fn, grasp=xsol, p_obj=obj_goal_posevec)

In [226]:
grasp = xsol[:3]

bdr = SparseIPOPT()
q_names = []
for i in range(num_traj):
    q_name = f"q{i}"
    bdr.add_variable(q_name, 7, panda.lb, panda.ub)
    q_names.append(q_name)
bdr.add_parameter("grasp", 3)
bdr.add_parameter("p_start", 6)
bdr.add_parameter("p_goal", 6)

bdr.register_fn("kin_err", [7], 6,
                kin_error_fn, jac_kin_error_fn)
# bdr.register_fn("kin_err_goal", [7], 6,
#                           kin_error_fn_goal, jac_kin_error_fn_goal)
bdr.register_fn("travelled_distance", [7]*num_traj, 1,
                travelled_distance, jac_travelled_distance)
# bdr.register_fn("distance_fn", [7]*num_traj, num_traj,
#                 jax.vmap(distance_fn), jax.vmap(jac_distance_fn))
#bdr.set_debug_callback(debug_callback)

bdr.set_objective("travelled_distance", q_names)
bdr.set_constr("kin_err_start", "kin_err", [f"q{idx_pick}", "grasp", "p_start"], 
                   0., 0.)
bdr.set_constr("kin_err_goal", "kin_err", [f"q{idx_place}", "grasp", "p_goal"], 
                   0., 0.)
# bdr.set_constr("manip_pick", "manip_fn_start", ["g_pick"],
#                    0.3, np.inf)
# bdr.set_constr("dist", "dist_fn", ["g_pick"], 
#                    0.05, np.inf)

In [227]:
ipopt = bdr.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
- - - - - - - - - - - - - - - - - - - - - o o o o o o o - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - o o o o o o o - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - o o o o o o o - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - o o o o o o o - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - o o o o o o o - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

In [231]:
jac_fn = bdr.get_jacobian_fn()

In [157]:
   
indices_var = [i for i, var in enumerate(constr.inputs) 
                if isinstance(var, Variable)]

In [163]:
constr.fn.jac_fn(*fn_input).shape

(6, 7)

In [158]:
indices_var

[0]

In [238]:
bdr.get_jacobian_fn()(x0)

Array([-4.33199584e-01, -2.31209863e-02, -4.33199584e-01, -1.73506126e-01,
       -1.25163734e-01, -1.49906665e-01, -5.92689986e-09,  4.33199495e-01,
       -2.31209658e-02,  4.33199495e-01, -1.73506051e-01,  1.25163704e-01,
       -1.49906650e-01,  4.60980987e-09, -4.42835635e-09, -6.50624454e-01,
       -4.42835635e-09,  4.79341596e-01, -8.73500050e-10,  8.79999846e-02,
       -7.86546060e-17, -1.26841128e-01, -9.51816201e-01, -1.26841128e-01,
        9.51816201e-01, -3.64949912e-01,  9.51816320e-01,  1.45999491e-02,
       -1.67321891e-01,  3.82988185e-01, -1.67321891e-01, -3.82988185e-01,
       -9.40804362e-01, -3.82988244e-01, -1.15048304e-01,  9.79172051e-01,
       -6.95779398e-02,  9.79172051e-01,  6.95779994e-02, -2.02340573e-01,
        6.95778728e-02, -9.95544910e-01, -4.33199584e-01, -2.31209863e-02,
       -4.33199584e-01, -1.73506126e-01, -1.25163734e-01, -1.49906665e-01,
       -5.92689986e-09,  4.33199495e-01, -2.31209658e-02,  4.33199495e-01,
       -1.73506051e-01,  

In [239]:
bdr.get_objective_fn()(x0)

Array(0., dtype=float32)

In [240]:
bdr.get_gradient_fn()(x0)

(Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32))

In [250]:
xs = {var.name:x[var.coord] for var in bdr.x_info.values()}
fn_input = [xs[var.name] for var in bdr.obj_info["inputs"]] 

Array([ 0.        ,  0.        ,  0.        , -1.5708    ,  0.        ,
        1.8675    ,  0.        ,  0.        ,  0.        ,  0.        ,
       -1.5708    ,  0.        ,  1.8675    ,  0.        ,  0.        ,
        0.        ,  0.        , -1.5708    ,  0.        ,  1.8675    ,
        0.        ,  0.        ,  0.        ,  0.        , -1.5708    ,
        0.        ,  1.8675    ,  0.        ,  0.        ,  0.        ,
        0.        , -1.5708    ,  0.        ,  1.8675    ,  0.        ,
        0.        ,  0.        ,  0.        , -1.5708    ,  0.        ,
        1.8675    ,  0.        ,  0.        ,  0.        ,  0.        ,
       -1.5708    ,  0.        ,  1.8675    ,  0.        ,  0.        ,
        0.        ,  0.        , -1.5708    ,  0.        ,  1.8675    ,
        0.        ,  0.        ,  0.        ,  0.        , -1.5708    ,
        0.        ,  1.8675    ,  0.        ,  0.        ,  0.        ,
        0.        , -1.5708    ,  0.        ,  1.8675    ,  0.  

In [251]:
grads = bdr.obj_info["fn"].jac_fn(*fn_input)
for 

(Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32))

In [237]:
jac_kin_error_fn(panda.neutral, xsol[:3], obj_start_posevec)

(Array([[-4.33199584e-01, -2.31209844e-02, -4.33199584e-01,
         -1.73506126e-01, -1.25163734e-01, -1.49906665e-01,
         -5.92689986e-09],
        [ 4.33199495e-01, -2.31209695e-02,  4.33199495e-01,
         -1.73506051e-01,  1.25163704e-01, -1.49906650e-01,
          4.60980987e-09],
        [-4.42835635e-09, -6.50624454e-01, -4.42835635e-09,
          4.79341567e-01, -8.73500050e-10,  8.79999846e-02,
         -8.32667268e-17],
        [-1.26841098e-01, -9.51816320e-01, -1.26841098e-01,
          9.51816320e-01, -3.64949971e-01,  9.51816440e-01,
          1.45998970e-02],
        [-1.67321920e-01,  3.82988244e-01, -1.67321920e-01,
         -3.82988244e-01, -9.40804482e-01, -3.82988304e-01,
         -1.15048304e-01],
        [ 9.79172051e-01, -6.95779622e-02,  9.79172051e-01,
          6.95780218e-02, -2.02340573e-01,  6.95778951e-02,
         -9.95544910e-01]], dtype=float32),
 Array([[ 9.0177409e-02,  9.4297603e-02,  2.7569398e-02],
        [ 9.0177417e-02, -9.4297610e-02,  2

In [236]:
qinit = np.tile(panda.neutral, num_traj)
x0 = jnp.hstack([qinit, xsol[:3], obj_start_posevec, obj_goal_posevec])
ipopt.add_option("nlp_scaling_method", "none")
traj, info = ipopt.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:       84
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0



IndexError: Out of bounds on buffer access (axis 0)

Exception ignored in: 'ipopt_wrapper.gradient_cb'
Traceback (most recent call last):
  File "/tmp/ipykernel_340429/2521261114.py", line 4, in <module>
IndexError: Out of bounds on buffer access (axis 0)



Number of Iterations....: 0

Number of objective function evaluations             = 0
Number of objective gradient evaluations             = 1
Number of equality constraint evaluations            = 0
Number of inequality constraint evaluations          = 0
Number of equality constraint Jacobian evaluations   = 1
Number of inequality constraint Jacobian evaluations = 0
Number of Lagrangian Hessian evaluations             = 0
Total seconds in IPOPT                               = 0.004

EXIT: Invalid number in NLP function or derivative detected.


In [688]:
i = 0

In [702]:
panda.set_joint_angles(traj.reshape(-1, 7)[i])
i+= 1

In [291]:

ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])

grasp_fn = lambda g: 0.
manip_fn = lambda g, pose: 0.
_dist_fn = lambda g1, g2, pose, pose_st, pose_ed: jnp.zeros(4)
manip_fn_start = partial(manip_fn, pose=np.zeros(6))
manip_fn_goal = partial(manip_fn, pose=np.zeros(6))
dist_fn = partial(_dist_fn, pose_st=np.zeros(6), pose_ed=np.zeros(6))

jac_grasp_fn = lambda g: [jnp.zeros(3)]
jac_manip_fn = lambda g, pose: [jnp.zeros(3), jnp.zeros(6)]
jac_manip_fn_start = lambda g: [jnp.zeros(3)]
jac_manip_fn_goal = lambda g: [jnp.zeros(3)]
jac_dist_fn = lambda g1, g2, pose: [jnp.zeros((4,3)), jnp.zeros((4,3)), jnp.zeros((4,6))]

bdr = SparseIPOPT()
bdr.add_variable("g_pick", 3, -1., 1.)
bdr.add_variable("g_place", 3, -1., 1.)
bdr.add_variable("p_ho", 6, ws_lb, ws_ub)

bdr.register_fn("grasp_logit_fn", [3], 1,
                          grasp_fn, jac_grasp_fn)
bdr.register_fn("manip_fn", [3, 6], 1,
                          manip_fn, jac_manip_fn)
bdr.register_fn("manip_fn_start", [3], 1,
                          manip_fn_start, jac_manip_fn_start)
bdr.register_fn("manip_fn_goal", [3], 1,
                          manip_fn_goal, jac_manip_fn_goal)
bdr.register_fn("dist_fn", [3, 3, 6], 4,
                          dist_fn, jac_dist_fn)

bdr.set_constr("grasp_prob_pick", "grasp_logit_fn", ["g_pick"], 
                   1., np.inf)
bdr.set_constr("grasp_prob_place", "grasp_logit_fn", ["g_place"], 
                   1., np.inf)
bdr.set_constr("manip_pick", "manip_fn_start", ["g_pick"],
                   0.3, np.inf)
bdr.set_constr("manip_place", "manip_fn_goal",["g_place"], 
                   0.3, np.inf)
bdr.set_constr("manip_ho_1", "manip_fn", ["g_pick", "p_ho"], 
                   0.3, np.inf)
bdr.set_constr("manip_ho_2", "manip_fn", ["g_place", "p_ho"], 
                   0.3, np.inf)
bdr.set_constr("dist", "dist_fn", ["g_pick", "g_place", "p_ho"], 
                   0.05, np.inf)

In [292]:
bdr.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
o o o - - - - - - - - - 
- - - o o o - - - - - - 
o o o - - - - - - - - - 
- - - o o o - - - - - - 
o o o - - - o o o o o o 
- - - o o o o o o o o o 
o o o o o o o o o o o o 
o o o o o o o o o o o o 
o o o o o o o o o o o o 
o o o o o o o o o o o o 


<ipopt_wrapper.Problem at 0x7f6ecc6ee1c0>

In [283]:
obj = bdr.get_objective_fn()
grad = bdr.get_gradient_fn()
constr = bdr.get_constraint_fn()
jac = bdr.get_jacobian_fn()

In [284]:
%timeit constr(x)

17.1 µs ± 606 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [255]:
obj(x)

Array(0., dtype=float32)

In [257]:
%timeit obj(x)

208 µs ± 1.82 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [217]:
jac = bdr.get_jacobian_fn()

In [218]:
%timeit jac(x)

6.09 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [220]:
jac = jax.jit(jac)

In [222]:
%timeit jac(x)

16.9 µs ± 59.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [185]:
x = np.arange(bdr.xdim)

In [120]:
def constraints(x):
    xs = {var.name:x[var.coord] for var in bdr.x_info.values()}
    result = []
    for constr in bdr.c_info.values():
        fn_input = [xs[var.name] for var in constr.inputs]    
        out = constr.fn.eval_fn(*fn_input)
        result.append(out)
    return np.hstack(result)

In [153]:
def jacobian(x):
    xs = {var.name:x[var.coord] for var in bdr.x_info.values()}
    result = []
    for constr in bdr.c_info.values():
        fn_input = [xs[var.name] for var in constr.inputs]    
        jacs = constr.fn.jac_fn(*fn_input)
        for jac in jacs:
            result.append(jac.flatten())
    return np.hstack(result)

In [164]:
def get_jacobian_structure():
    rows, cols = [], []
    for constr in bdr.c_info.values():
        for var in constr.inputs:
            row, col = np.indices((constr.dim, var.dim)).reshape(2, -1)
            row += constr.coord[0]
            col += var.coord[0]
            rows.append(row)
            cols.append(col)
    rows = np.hstack(rows)
    cols = np.hstack(cols)
    return rows, cols

In [181]:
def objective(x):
    if "fn" in bdr.obj_info: 
        val = 0. #no objective
    else:
        xs = {var.name:x[var.coord] for var in bdr.x_info.values()}
        fn_input = [xs[var.name] for var in bdr.obj_info["inputs"]]    
        val = bdr.obj_info["fn"].eval_fn(*fn_input)
    if "debug_cb" in bdr.obj_info:
            bdr.obj_info["debug_cb"]
    return val

In [180]:
def gradient(x):
    if "fn" in bdr.obj_info: 
        grad = np.zeros(bdr.xdim) #no objective
    else:
        xs = {var.name:x[var.coord] for var in bdr.x_info.values()}
        fn_input = [xs[var.name] for var in bdr.obj_info["inputs"]]    
        grad = bdr.obj_info["fn"].jac_fn(*fn_input)
    return grad

In [179]:
gradient(x)

[]

In [138]:
xdim = constr.inputs[0].dim
cdim = constr.dim

In [143]:
row, col = np.indices((cdim, xdim)).reshape(2, -1)

In [147]:
row = row + constr.coord[0]
col = col + constr.inputs[0].coord[0]

In [148]:
row, col

(array([12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15]),
 array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]))

[Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float32),
 Array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], dtype=float32)]

In [47]:
# jacobian sparsity
x_idx = {}
x_idx["x1"] = np.arange(0, 3)
x_idx["x2"] = np.arange(3, 6)
x_idx["x3"] = np.arange(6, 9)

constr_idx = {}
constr_idx["c1"] = np.arange(0, 3)
constr_idx["c2"] = np.arange(3, 6)

3

In [45]:
rowcol = np.indices((len(x_idx["x1"]), len(constr_idx["c1"]))) 

In [46]:
rowcol.reshape(2, -1)

array([[0, 0, 0, 1, 1, 1, 2, 2, 2],
       [0, 1, 2, 0, 1, 2, 0, 1, 2]])

In [44]:
rowcol

array([[[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2]],

       [[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]]])

In [26]:
%timeit x[idx["x3"]]

139 ns ± 0.334 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [16]:
x

array([ 0,  4,  8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64,
       68, 72, 76, 80])

In [8]:
x[x_idx["x1"][0]:]

1

In [13]:
x = np.arange(21) * 4