In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

In [105]:
@dataclass
class Variable:
    name: str
    coord: np.ndarray
    lb: np.ndarray
    ub: np.ndarray
    
    @property
    def dim(self): return len(self.coord)

@dataclass
class Parameter:
    name: str
    coord: np.ndarray
    dim: int

    @property
    def lb(self): return np.full(self.dim, -np.inf)
    @property
    def ub(self): return np.full(self.dim, np.inf)

@dataclass
class Constraint:
    name: str
    coord: np.ndarray
    inputs: List[Variable]
    fn: "Function"
    lb: np.ndarray
    ub: np.ndarray
    no_deriv_names: List[str]
    jac_indices: np.ndarray

    @property
    def dim(self): return len(self.coord)

@dataclass
class Function:
    name: str
    in_dims: List[int]
    out_dim: int
    eval_fn: Callable
    jac_fn: Callable
    custom_jac_indices: Optional[List[np.ndarray]]
    constraints: List[Constraint] = field(default_factory=list)

class SparseIPOPT():
    def __init__(self):
        self.x_info: Dict[str,Variable] = {}
        self.c_info: Dict[str,Constraint] = {}
        self.fn_info: Dict[str,Function] = {}
        self.obj_info: Dict = {}
        self.param_info: Dict[str, Array] = {}

        self.x_idx, self.c_idx = 0, 0
        self.param_info: Dict[str, np.ndarray] = {}

    @property
    def xdim(self):
        return sum([x.dim for x in self.x_info.values()])
    @property
    def cdim(self):
        return sum([c.dim for c in self.c_info.values()])
    @property
    def input_info(self):
        return {**self.x_info, **self.param_info}
    
    def add_variable(self, name, dim, lb=-np.inf, ub=np.inf):
        assert name not in self.x_info
        assert isinstance(lb, float) or len(lb) == dim
        assert isinstance(ub, float) or len(ub) == dim

        if isinstance(lb, float): lb = np.full(dim, lb)
        if isinstance(ub, float): ub = np.full(dim, ub)

        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Variable(
            name, coord, lb, ub)
        
        self.x_idx += dim
    
    def add_parameter(self, name, dim):
        assert name not in self.param_info
        #TODO: name also should not be in "variable"
        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Parameter(name, coord, dim)
        self.x_idx += dim
    
    def set_objective(self, fn_name, input_x_names):
        self.obj_info["fn"] = self.fn_info[fn_name]
        self.obj_info["inputs"] = [self.x_info[name] for name in input_x_names]
    
    def set_debug_callback(self, debug_callback:Callable):
        self.obj_info["debug_cb"] = debug_callback

    def set_constr(self, name, cfn_name, input_x_names, lb, ub, no_deriv_names=[]):
        c_fn = self.fn_info[cfn_name]
        cdim = c_fn.out_dim
        assert name not in self.c_info
        assert isinstance(lb, float) or len(lb) == cdim
        assert isinstance(ub, float) or len(ub) == cdim
        if isinstance(lb, float): lb = np.full(cdim, lb)
        if isinstance(ub, float): ub = np.full(cdim, ub)

        vars = [self.x_info[name] for name in input_x_names]
        c_coord = np.arange(self.c_idx, self.c_idx+cdim)
        
        # if deriv_x_names is not None:
        #     deriv_vars = [self.x_info[name] for name in deriv_x_names]
        # else:
        #     deriv_vars = []
        #     for i, var in enumerate(vars):
        #         if isinstance(var, Parameter): continue
        #         deriv_vars.append(var)

        jac_indices = []
        for i, var in enumerate(vars):
            if var.name in no_deriv_names: continue
            if isinstance(var, Parameter): continue
            
            if c_fn.custom_jac_indices is not None:
                row, col = c_fn.custom_jac_indices[i]
            else:
                row, col = np.indices((cdim, var.dim)).reshape(2, -1)
            row_offset, col_offset = c_coord[0], var.coord[0] # offset
            jac_indices.append(np.vstack([row+row_offset, col+col_offset]))

        self.c_info[name] = Constraint(
            name, c_coord, vars, 
            c_fn, lb, ub,
            no_deriv_names, jac_indices)
        c_fn.constraints.append(self.c_info[name])
        self.c_idx += cdim

    def register_fn(self, name, in_dims, out_dim, eval_fn, jac_fn, custom_jac_indices=None):
        xdummies = [jnp.zeros(dim) for dim in in_dims]
        assert eval_fn(*xdummies).size == out_dim
        assert len(jac_fn(*xdummies)) == len(in_dims)
        if custom_jac_indices is not None:
            assert len(custom_jac_indices) == len(in_dims)
        self.fn_info[name] = Function(
            name, in_dims, out_dim, eval_fn, jac_fn, custom_jac_indices)
    
    def get_objective_fn(self, compile=True):
        no_obj = False
        if "fn" not in self.obj_info: 
            objective = lambda x: 0.
            no_obj = True
        else:
            def objective(x):        
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                fn_input = [xs[var.name] for var in self.obj_info["inputs"]]    
                val = self.obj_info["fn"].eval_fn(*fn_input)
                return val
        
        if "debug_cb" in self.obj_info:
            def objective_debug(x):
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                self.obj_info["debug_cb"](xs)    
                return objective(x)
            return objective_debug
        elif compile and not no_obj:
            return jax.jit(objective)
        return objective
    
    def get_gradient_fn(self, compile=True):
        no_obj = False
        if "fn" not in self.obj_info: 
            gradient = lambda x: np.zeros(self.xdim)
            no_obj = True
        else:
            grad_value_dict = {var.name: np.zeros(var.dim) for var in self.x_info.values()}
            def gradient(x):
                xs = {var.name:x[var.coord] for var in self.x_info.values()}
                fn_input = [xs[var.name] for var in self.obj_info["inputs"]]    
                grads = self.obj_info["fn"].jac_fn(*fn_input)
                for var, grad in zip(self.obj_info['inputs'], grads):
                    grad_value_dict[var.name] = grad
                return jnp.hstack(grad_value_dict.values())
        if compile and not no_obj:
            return jax.jit(gradient)
        return gradient      
    
    def get_constraint_fn(self, compile=True):
        def constraints(x):
            xs = {var.name:x[var.coord] for var in self.x_info.values()}
            result = []
            for constr in self.c_info.values():
                fn_input = [xs[var.name] for var in constr.inputs]    
                out = constr.fn.eval_fn(*fn_input)
                result.append(out)
            return jnp.hstack(result)
        if compile:
            return jax.jit(constraints)
        return constraints
    
    def get_jacobian_fn(self, compile=True):
        def jacobian(x):
            xs = {var.name:x[var.coord] for var in self.x_info.values()}
            result = []
            for constr in self.c_info.values():
                fn_input = [xs[var.name] for var in constr.inputs]    
                jacs = constr.fn.jac_fn(*fn_input)
                for i, var in enumerate(constr.inputs):
                    if var.name in constr.no_deriv_names: continue
                    elif isinstance(var, Parameter): continue
                    result.append(jacs[i].flatten())
            return jnp.hstack(result)
        if compile:
            return jax.jit(jacobian)
        return jacobian
                # indices_var = [i for i, var in enumerate(constr.inputs) 
                #                if isinstance(var, Variable)]
                # idx_jac = 0
                #     if constr.jac_indices[i] is not None:
                # for i, jac in enumerate(jacs):
                    # if constr.jac_indices[i] is not None:
                        # result.append(jac.flatten())
                    #if i in indices_var:
                        # result.append(jac.flatten())
    
    def get_jacobian_structure(self):
        rows, cols = [], []
        rows, cols = [], []
        for constr in self.c_info.values():
            for jac_idx in constr.jac_indices:
                if jac_idx is None: continue
                rows.append(jac_idx[0])
                cols.append(jac_idx[1])
        rows = np.hstack(rows)
        cols = np.hstack(cols)
        return rows, cols

    def print_sparsity(self):
        row, col = self.get_jacobian_structure()
        jac_struct = np.full((self.cdim, self.xdim), -1, dtype=int)
        jac_struct[row, col] = 1
        for row in jac_struct:
            row_str = ""
            for val in row:
                if val == -1: row_str += "-"
                else: row_str += f"o"
            print(row_str)
    
    def build(self, compile=True):
        lb = np.hstack([x.lb for x in self.x_info.values()])
        ub = np.hstack([x.ub for x in self.x_info.values()])
        cl = np.hstack([c.lb for c in self.c_info.values()])
        cu = np.hstack([c.ub for c in self.c_info.values()])
        row, col = self.get_jacobian_structure()
        jac_struct_fn = lambda : (row, col)

        fns = {
            "objective": self.get_objective_fn(compile),
            "gradient": self.get_gradient_fn(compile),
            "constraints": self.get_constraint_fn(compile),
            "jacobian": self.get_jacobian_fn(compile),
        }
        class Prob:
            pass
        prob = Prob()
        xdummy = jnp.zeros(self.xdim)
        for fn_name, fn in fns.items():
            print(f"compiling {fn_name} ...")
            fn(xdummy)
            setattr(prob, fn_name, fn)
        setattr(prob, "jacobianstructure", jac_struct_fn)

        ipopt = cyipopt.Problem(
            n=self.xdim, m=self.cdim,
            problem_obj=prob,
            lb=lb, ub=ub, cl=cl, cu=cu
        )
        # default option
        ipopt.add_option("acceptable_iter", 2)
        ipopt.add_option("acceptable_tol", np.inf) #release
        ipopt.add_option("acceptable_obj_change_tol", 0.1)
        ipopt.add_option("acceptable_constr_viol_tol", 1.)
        #ipopt.add_option("acceptable_dual_inf_tol", 1.) 
        ipopt.add_option('mu_strategy', 'adaptive')
        self.print_sparsity()
        return ipopt

### Load learned models

In [3]:
# models 
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

#grasp net
grasp_net = GraspNet(32)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)

grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]

#manip net
manip_net = ManipNet(64)
manip_fn = lambda x: manip_net.apply(restored_manip["params"], x)[0]

### Load world

In [4]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7008/static/


In [5]:
world.show_in_jupyter()

In [6]:
# robot, hand
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link_name, link in hand_model.links.items():
    link.set_surface_points(10)
hand = Robot(world.vis, "hand1", hand_model, color="white", alpha=0.5)



In [7]:
#load sdf meshes
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obstacle = Box(world.vis, "obstacle", [0.4, 0.2, 0.35], 'white', 0.5)

obj_start = Mesh(world.vis, "obj_start", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", 
                "./sdf_world/assets/object/mesh.obj",
                color="green", alpha=0.5)
obj = Mesh(world.vis, "obj", 
                "./sdf_world/assets/object/mesh.obj",
                color="white", alpha=0.8)

In [8]:
table_start.set_translate([0.4, -0.35, 0.2/2])
table_goal.set_translate([0.4, 0.35, 0.2/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.4, -0.4, obj_lengths[-1]/2+table_lengths[-1]])
trans_goal = jnp.array([0.4, 0.4, obj_lengths[-2]/2+table_lengths[-1]])
obstacle.set_translate([0.4, 0., 0.35/2])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)
obj.set_translate([0.4, -0.4, obj_lengths[-1]/2+table_lengths[-1]])

In [9]:
#visualization
pc_hands = PointCloud(world.vis, "pc_hands", np.zeros((10,3)), 0.01, "red")
dotted_line = DottedLine(world.vis, "traj", np.zeros((10,3)), 0.01, "blue")
pc_body = PointCloud(world.vis, "body_pc", np.zeros((10,3)), 0.01, "blue")
pc_obj = PointCloud(world.vis, "pc_obj", np.zeros((10,3)), 0.01, "green")

In [10]:
del pc_hands
del dotted_line
del pc_body
del pc_obj

### Functions

In [10]:
# kinematics
def skew(v):
    v1, v2, v3 = v
    return jnp.array([[0, -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])

@jax.custom_jvp
def fk(q):
    fks = panda_model.fk_fn(q)
    return fks[-1]

@fk.defjvp
def fk_jvp(primals, tangents):
    q, = primals
    q_dot, = tangents
    fks = panda_model.fk_fn(q)
    qtn, p_ee = fks[-1][:4], fks[-1][-3:]
    w, xyz = qtn[0], qtn[1:]
    geom_jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        geom_jac.append(jnp.hstack([lin_vel, rot_axis]))
    geom_jac = jnp.array(geom_jac).T  #geom_jacobian
    H = jnp.hstack([-xyz[:,None], skew(xyz)+jnp.eye(3)*w])
    rot_jac = 0.5*H.T@geom_jac[3:,:]
    jac = jnp.vstack([rot_jac, geom_jac[:3,:]])
    return fks[-1], jac@q_dot

### Constants, data

In [11]:
# constants, data
safe_dist = 0.05
ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
hand_pc = hand.get_surface_points_fn(jnp.array([0.04, 0.04]))

#links_points_mat
links_points_mat = []
for link in panda_model.links.values():
    if not link.has_mesh: continue
    if not panda_model.is_floating and link == panda_model.root_link: continue
    links_points_mat.append(link.surface_points)
links_points_mat = np.array(links_points_mat)

#object points
obj_points = farthest_point_sampling(obj_start.mesh.sample(200), 20)

# environment sdf
env = SDFContainer([table_start, table_goal, obstacle], safe_dist)

In [12]:
def point_jacobian(point, link_idx, joint_frames):
    def get_lin_vel(target_point, joint_frame):
        joint_to_target = target_point - joint_frame[-3:]
        rot_axis = SO3(joint_frame[:4]).as_matrix()[:3,2]
        return jnp.cross(rot_axis, joint_to_target)
    lin_jac = jax.vmap(get_lin_vel, in_axes=(None,0))(point, joint_frames).T
    masking = np.tile(np.arange(7),3).reshape(-1,7) + 1
    masking = jnp.where(masking > link_idx+1, 0, 1)
    return masking * lin_jac

@jax.custom_jvp
def min_collision_fn(q):
    fk_assign = lambda wxyzxyz, link_points: jax.vmap(SE3(wxyzxyz).apply)(link_points)
    fks = panda_model.fk_fn(jnp.hstack([q, 0.04, 0.04]))
    link_frames = fks[1:-1]
    assigned_points = jax.vmap(fk_assign)(link_frames, links_points_mat)
    distances = env.distances(jnp.vstack(assigned_points))
    min_distance = distances.min()
    return min_distance

@min_collision_fn.defjvp
def min_collision_jvp(primals, tangents):
    fk_assign = lambda wxyzxyz, link_points: jax.vmap(SE3(wxyzxyz).apply)(link_points)
    q, = primals
    q_dot, = tangents
    fks = panda_model.fk_fn(q)
    joint_frames = fks[1:8]
    assigned_points = jax.vmap(fk_assign)(fks[1:-1], links_points_mat)
    distances = env.distances(jnp.vstack(assigned_points)).reshape(10,-1)
    idx_link, idx_point = jnp.unravel_index(distances.argmin(), distances.shape)
    min_point = assigned_points[idx_link, idx_point, :]
    jac_point = point_jacobian(min_point, idx_link, joint_frames)
    repulsive_grad = jax.grad(env.distance)(min_point)
    return distances.min(), repulsive_grad @ jac_point @ q_dot

In [13]:
@jax.custom_jvp
def min_obj_collision_fn(q, q_pick, obj_pose_at_pick):
    obj_pose_at_pick = SE3(obj_pose_at_pick)
    grasp_pose = obj_pose_at_pick.inverse() @ SE3(fk(q_pick))
    obj_pose = SE3(fk(q)) @ grasp_pose.inverse()
    assigned_obj_points = jax.vmap(obj_pose.apply)(obj_points)
    distances = env.distances(jnp.vstack(assigned_obj_points))
    min_distance = distances.min()
    return min_distance

@min_obj_collision_fn.defjvp
def min_obj_collision_jvp(primals, tangents):
    q, q_pick, obj_pose_at_pick = primals
    q_dot, _, _ = tangents

    fks = panda_model.fk_fn(q)
    joint_frames = fks[1:8]

    obj_pose_at_pick = SE3(obj_pose_at_pick)
    grasp_pose = obj_pose_at_pick.inverse() @ SE3(fk(q_pick))
    obj_pose = SE3(fks[-1]) @ grasp_pose.inverse()
    assigned_obj_points = jax.vmap(obj_pose.apply)(obj_points)
    distances = env.distances(jnp.vstack(assigned_obj_points))

    idx_point = distances.argmin()
    min_point = assigned_obj_points[idx_point]
    jac_point = point_jacobian(min_point, 6, joint_frames)
    repulsive_grad = jax.grad(env.distance)(min_point)
    return distances.min(), repulsive_grad @ (jac_point @ q_dot)

In [14]:
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)
def grasp_embedding(grasp_point):
    grasp = grasp_point * restored_grasp["scale_to_norm"]
    return grasp

In [15]:
#utility functions
to_posevec = lambda x: jnp.hstack([x[4:], SO3(x[:4]).log()])
to_wxyzxyz = lambda x: jnp.hstack([SO3.exp(x[3:]).parameters(), x[:3]])

def get_hand_pc(grasp, wxyzxyz):
    grasp_pose = grasp_reconst(grasp)
    hand_base_pose_wrt_world = SE3(wxyzxyz) @ grasp_pose @ hand_pose_wrt_ee
    assigned_hand_pc = jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)
    return assigned_hand_pc

def body_pose_error(posevec, posevec_d):
    pose = SE3(to_wxyzxyz(posevec))
    pose_d = SE3(to_wxyzxyz(posevec_d))
    R = pose.rotation().as_matrix()
    err_pos = R.T@(pose_d.translation() - pose.translation())
    err_rot = pose.rotation().inverse() @ pose_d.rotation()
    return jnp.hstack([err_pos, err_rot.log()])

def get_grasp_pose(grasp, wxyzxyz):
    obj_pose = SE3(wxyzxyz)
    grasp_pose = obj_pose @ grasp_reconst(grasp)
    return grasp_pose.normalize().parameters()

# constr fns
def grasp_constr(grasp):
    return grasp_logit_fn(grasp)
def robot_grasp_constr(q, obj_pose):
    tcp = fk(q)[-3:]
    grasp_point = SE3(obj_pose).inverse().apply(tcp)
    grasp = grasp_embedding(grasp_point)
    return grasp_logit_fn(grasp)
def manip_constr(grasp, wxyzxyz):
    grasp_pose = SE3(get_grasp_pose(grasp, wxyzxyz))
    zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    grasp_pose_flip = grasp_pose @ zflip
    posevecs = [to_posevec(pose.parameters()) for pose in [grasp_pose, grasp_pose_flip]]
    return jax.vmap(manip_fn)(jnp.vstack(posevecs)).max()
def hand_col_constr(g1, wxyzxyz1, wxyzxyz2):
    obj_poses = jnp.vstack([wxyzxyz1, wxyzxyz2])
    pcs = jax.vmap(get_hand_pc, in_axes=(None,0))(g1, obj_poses)
    distances = env.distances(jnp.vstack(pcs)).reshape(2, -1)
    return distances.min(axis=-1)
def kin_constr(q, grasp, p_obj):
    target = get_grasp_pose(grasp, p_obj)
    curr = fk(q)
    pos_err = target[-3:] - curr[-3:]
    rot_err = (SO3(curr[:4]).inverse()@SO3(target[:4])).log()
    return jnp.hstack([pos_err, 0.5*rot_err])

def robot_col_constr(*qs):
    qs = jnp.vstack(qs)
    return jax.vmap(min_collision_fn)(qs)

def jac_robot_col_constr(*qs):
    qs = jnp.vstack(qs)
    return jax.vmap(jax.jacrev(min_collision_fn))(qs)

def obj_col_constr(q_pick, obj_pose_at_pick, *qs):
    qs = jnp.vstack(qs)
    return jax.vmap(min_obj_collision_fn, (0, None, None))(qs, q_pick, obj_pose_at_pick)
def jac_obj_col_constr(q_pick, obj_pose_at_pick, *qs):
    qs = jnp.vstack(qs)
    jacs_traj, jac_q_pick, jac_obj_pose = jax.vmap(
        jax.jacfwd(min_obj_collision_fn, argnums=[0,1,2]), in_axes=(0,None, None)
    )(qs, q_pick, obj_pose_at_pick)
    return [jac_q_pick, jac_obj_pose, *jacs_traj]

def travelled_distance(*qs):
    qs_mat = jnp.vstack([panda.neutral, *qs, panda.neutral])
    qdiff = qs_mat[1:] - qs_mat[:-1]
    return 0.5*jnp.sum(qdiff.flatten() ** 2)


In [41]:
q = panda.neutral
obj_pose = obj_start.pose.parameters()

In [160]:
def get_grasp_from_ik(q_pick, obj_pose_pick):
    ee_pose_pick = SE3(fk(q_pick))
    grasp_pose = SE3(obj_pose_pick).inverse() @ ee_pose_pick
    grasp = grasp_embedding(grasp_pose.translation())
    return grasp

def grasp_logit_q(q_pick, obj_pose_pick):
    ee_pose_pick = SE3(fk(q_pick))
    grasp_pose = SE3(obj_pose_pick).inverse() @ ee_pose_pick
    grasp = grasp_embedding(grasp_pose.translation())
    return grasp_logit_fn(grasp)

zflip = SO3.from_z_radians(np.pi)
def rot_diff(q, q_pick, obj_pose_pick):
    ee_pose = SE3(fk(q))
    grasp = get_grasp_from_ik(q_pick, obj_pose_pick)
    grasp_rot = SO3(grasp_fn(grasp)[2:]).normalize()
    diff1 = ee_pose.rotation().inverse() @ SO3(obj_pose[:4]) @ grasp_rot
    diff2 = ee_pose.rotation().inverse() @ SO3(obj_pose[:4]) @ grasp_rot @ zflip
    diff1, diff2 = diff1.log(), diff2.log()
    compare = jnp.sum(diff1**2) < jnp.sum(diff2**2)
    return jnp.where(compare, diff1, diff2)

In [167]:
bdr1 = SparseIPOPT()
bdr1.add_variable("q_pick", 7, panda.lb, panda.ub)
bdr1.add_variable("q_place", 7, panda.lb, panda.ub)
bdr1.add_parameter("p_start", 7)
bdr1.add_parameter("p_goal", 7)

def debug_callback(x_dict):
    q = x_dict["q_pick"]
    # p_start = x_dict["p_start"]
    # p_goal = x_dict["p_goal"]
    # posevecs = jnp.vstack([p_start, p_goal])
    # points = jax.vmap(get_hand_pc, in_axes=(None,0))(grasp, posevecs)
    # pc_hands.reload(points=np.vstack(points))
    panda.set_joint_angles(q)
    time.sleep(0.1)
bdr1.set_debug_callback(debug_callback)

bdr1.register_fn("grasp_fn", [7, 7], 1,
                grasp_logit_q, jax.jacfwd(grasp_logit_q, argnums=[0, 1]))
bdr1.register_fn("grasp_fk", [7, 7], 3,
                 rot_diff, jax.jacfwd(rot_diff, argnums=[0, 1]))

bdr1.set_constr("grasp_fn", "grasp_fn", ["q", "p_start"], 
                   1.2, np.inf)
bdr1.set_constr("grasp_rot_diff", "grasp_fk", ["q", "p_start"], 
                   0., 0.)
# bdr1.set_constr("manip_pick", "manip_fn", ["g_pick", "p_start"],
#                    0.2, np.inf)
# bdr1.set_constr("manip_place", "manip_fn", ["g_pick", "p_goal"],
#                    0.2, np.inf)
# bdr1.set_constr("col", "col_fn", ["g_pick", "p_start", "p_goal"], 
#                    0.05, np.inf)

In [168]:
ipopt_qdirect = bdr1.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
ooooooo--------------
ooooooo--------------
ooooooo--------------
ooooooo--------------


In [170]:
x0= np.hstack([panda.neutral, obj_start.pose.parameters(), obj_goal.pose.parameters()])
sol, info = ipopt_qdirect.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:       21
Number of nonzeros in inequality constraint Jacobian.:        7
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       21
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        7
                     variables with only upper bounds:        0
Total number of equality constraints.................:        3
Total number of inequality constraints...............:        1
        inequality constraints with only lower bounds:        1
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 1.02e+03 1.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [134]:
rot_diff(sol[:7], obj_goal.pose.parameters())

Array([ 0.00130882, -0.00376484, -0.00073964], dtype=float32)

In [135]:
frame = Frame(world.vis, "frame")

In [150]:
q = sol[:7]
ee_pose = SE3(fk(q))
grasp_pose = obj_goal.pose.inverse() @ ee_pose
grasp = grasp_embedding(grasp_pose.translation())

grasp_rot = SO3(grasp_fn(grasp)[2:]).normalize()
diff = ee_pose.rotation().inverse() @ obj_goal.pose.rotation() @ grasp_rot
frame.set_pose(obj_goal.pose @ grasp_reconst(grasp))

SO3(wxyz=[-0.70840997 -0.00944    -0.70409995  0.04812   ])

SO3(wxyz=[-0.17557 -0.75187 -0.61258  0.16916])

In [146]:
frame.set_pose(ee_pose)

In [133]:
panda.set_joint_angles(sol[:7])

In [58]:
q = panda.neutral
panda.set_joint_angles(q)

In [73]:
logit, grad = jax.value_and_grad(grasp_logit_q)(q, obj_pose)
grad2 = jax.jacfwd(rot_diff)(q, obj_pose)
d = grad + grad2
q = q + d * 0.0001
panda.set_joint_angles(q)

AssertionError: 

In [57]:
q

Array([-7.3740196e-01, -5.5426568e-01, -7.3740196e-01, -1.9088933e+00,
       -2.1305643e-01,  1.3429551e+00,  4.3669466e-08], dtype=float32)

In [44]:
jax.jacrev(rot_diff)(q, obj_pose)

Array([[ 0.69071823, -0.7266253 ,  0.69071823,  0.74117786, -0.6574468 ,
         0.6273676 , -1.0224965 ],
       [-0.870051  ,  0.34477785, -0.870051  , -0.15474825, -0.9259343 ,
        -0.1754411 ,  0.2891816 ],
       [ 0.52501655,  1.1678255 ,  0.52501655, -1.0234425 ,  0.07265568,
        -0.95543647, -0.59444916]], dtype=float32)

In [32]:
SO3(grasp_fn(jnp.zeros(3))[2:]).normalize()

SO3(wxyz=[-0.63684    -0.41827    -0.63528997  0.12605   ])

In [16]:
def get_jac_indices(cdim, xdim, row_offset=0):
    rowcol = np.indices((cdim, xdim)).reshape(2,-1)
    rowcol[0] += row_offset
    return rowcol

#indexing
num_mid_configs = 4
num_traj = num_mid_configs*3 + 2
idx_pick = num_mid_configs
idx_place = idx_pick + num_mid_configs + 1
idxs_hold = np.arange(idx_pick+1, idx_pick+1+num_mid_configs)
cdim, rdim = num_traj, 7

robot_col_rowcols = []
for i in range(num_traj):
    jac_struct = np.zeros((cdim, rdim), dtype=int)
    jac_struct[i,:] = 1
    rowcol = np.nonzero(jac_struct)
    robot_col_rowcols.append(np.vstack(rowcol))
jac_travelled_distance = jax.grad(travelled_distance, 
                                  argnums=np.arange(num_traj))


# q_pick_grasp_rowcols = np.indices((len(idxs_hold), 7)).reshape(2,-1)
# obj_place_grasp_rowcols = np.indices((len(idxs_hold), 7)).reshape(2,-1)
obj_col_rowcols = [get_jac_indices(len(idxs_hold), 7), get_jac_indices(len(idxs_hold), 7)] #placeholder: no jac for q_pick, obj_pose_pick
for i in range(len(idxs_hold)):
    row_offset = i
    rowcol = np.indices((1, 7)).reshape(2,-1)
    rowcol[0] += row_offset
    obj_col_rowcols.append(rowcol)



In [35]:
bdr1 = SparseIPOPT()
bdr1.add_variable("g_pick", 3, -1., 1.)
bdr1.add_parameter("p_start", 7)
bdr1.add_parameter("p_goal", 7)

def debug_callback(x_dict):
    grasp = x_dict["g_pick"]
    p_start = x_dict["p_start"]
    p_goal = x_dict["p_goal"]
    posevecs = jnp.vstack([p_start, p_goal])
    points = jax.vmap(get_hand_pc, in_axes=(None,0))(grasp, posevecs)
    pc_hands.reload(points=np.vstack(points))
    time.sleep(0.1)
bdr1.set_debug_callback(debug_callback)

bdr1.register_fn("grasp_fn", [3], 1,
                          grasp_constr, jax.jacrev(grasp_constr, argnums=[0]))
bdr1.register_fn("manip_fn", [3, 7], 1,
                          manip_constr, jax.jacrev(manip_constr, argnums=[0,1]))
bdr1.register_fn("col_fn", [3, 7, 7], 2,
                          hand_col_constr, jax.jacfwd(hand_col_constr, argnums=[0,1,2]))

bdr1.set_constr("grasp_prob_pick", "grasp_fn", ["g_pick"], 
                   1., np.inf)
bdr1.set_constr("manip_pick", "manip_fn", ["g_pick", "p_start"],
                   0.2, np.inf)
bdr1.set_constr("manip_place", "manip_fn", ["g_pick", "p_goal"],
                   0.2, np.inf)
bdr1.set_constr("col", "col_fn", ["g_pick", "p_start", "p_goal"], 
                   0.05, np.inf)

In [36]:
bdr2 = SparseIPOPT()
bdr2.add_variable("q_pick", 7, panda.lb, panda.ub)
bdr2.add_variable("q_place", 7, panda.lb, panda.ub)
bdr2.add_parameter("g_pick", 3)
bdr2.add_parameter("p_start", 7)
bdr2.add_parameter("p_goal", 7)

def debug_callback(x_dict):
    q_place = x_dict["q_place"]
    panda.set_joint_angles(q_place)
    time.sleep(0.1)
bdr2.set_debug_callback(debug_callback)

bdr2.register_fn("kin_constr", [7, 3, 7], 6, 
                 kin_constr, jax.jacfwd(kin_constr,argnums=[0,1,2]))

bdr2.set_constr("kin_pick", "kin_constr", ["q_pick", "g_pick", "p_start"], 
                   0., 0.)
bdr2.set_constr("kin_place", "kin_constr", ["q_place", "g_pick", "p_goal"], 
                   0., 0.)

In [69]:
bdr3 = SparseIPOPT()

q_names = []
for i in range(num_traj):
    q_name = f"q{i}"
    bdr3.add_variable(q_name, 7, panda.lb, panda.ub)
    q_names.append(q_name)
q_hold_names = [name for i, name in enumerate(q_names) if i in idxs_hold]

# bdr3.add_variable("g_pick", 3, -1., 1.)
bdr3.add_parameter("p_start", 7)
bdr3.add_parameter("p_goal", 7)

# def debug_callback(x_dict):
#     q_place = x_dict[f"q{idx_place}"]
#     grasp = x_dict["g_pick"]
#     p_start = x_dict["p_start"]
#     p_goal = x_dict["p_goal"]
#     qs = [x_dict[q_name] for q_name in q_names]
#     qs = jnp.vstack(qs)
#     p_ees = jax.vmap(fk)(qs)[:,-3:]
#     posevecs = jnp.vstack([p_start, p_goal])
#     points = jax.vmap(get_hand_pc, in_axes=(None,0))(grasp, posevecs)
#     pc_hands.reload(points=np.vstack(points))
#     panda.set_joint_angles(q_place)
#     dotted_line.reload(points=p_ees)
#     time.sleep(0.1)
# bdr3.set_debug_callback(debug_callback)

kin_tol = 1e-4
# bdr3.register_fn("grasp_fn", [3], 1,
#                  grasp_constr, jax.jacrev(grasp_constr, argnums=[0]))
# bdr3.register_fn("kin_constr", [7, 3, 6], 6, 
#                  kin_constr, jax.jacfwd(kin_constr,argnums=[0,1,2]))
bdr3.register_fn("robot_grasp_fn", [7, 7], 1,
                 robot_grasp_constr, jax.jacrev(robot_grasp_constr,argnums=[0, 1]),
                 [get_jac_indices(1,7, row_offset=i) for i in range(2)])
bdr3.register_fn("robot_col", [7]*num_traj, num_traj, 
                 robot_col_constr, jac_robot_col_constr, robot_col_rowcols)
bdr3.register_fn("obj_col", [7,7]+[7]*len(idxs_hold), len(idxs_hold), 
                 obj_col_constr, jac_obj_col_constr, obj_col_rowcols)

bdr3.register_fn("travelled_dist", [7]*num_traj, 1,
                 travelled_distance, jac_travelled_distance)

bdr3.set_objective("travelled_dist", q_names)

bdr3.set_constr("grasp_prob_pick", "robot_grasp_fn", [f"q{idx_pick}", "p_start"], 
                   0.5, np.inf)
bdr3.set_constr("grasp_prob_place", "robot_grasp_fn", [f"q{idx_place}", "p_goal"], 
                   0.5, np.inf)
bdr3.set_constr("robot_col", "robot_col", q_names, safe_dist, np.inf)
bdr3.set_constr("obj_col", "obj_col", [f"q{idx_pick}", "p_start"]+q_hold_names, safe_dist, np.inf, [f"q{idx_pick}"])

In [70]:
jac_fn = bdr3.get_jacobian_fn()

In [66]:
constr = bdr3.c_info['obj_col']

In [67]:
constr.fn.jac_fn

<function __main__.jac_obj_col_constr(q_pick, obj_pose_at_pick, *qs)>

In [72]:
jac_fn(jnp.zeros(bdr3.xdim)).shape

(140,)

In [120]:
xs = {var.name:x0[var.coord] for var in bdr3.x_info.values()}
constr = bdr3.c_info['obj_col']
fn_input = [xs[var.name] for var in constr.inputs]    
indices_var = [i for i, var in enumerate(constr.inputs) 
                if isinstance(var, Variable)]
jacs = constr.fn.jac_fn(*fn_input)

In [136]:
robot_col_rowcols

[array([[0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[1, 1, 1, 1, 1, 1, 1],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[2, 2, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[3, 3, 3, 3, 3, 3, 3],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[4, 4, 4, 4, 4, 4, 4],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[5, 5, 5, 5, 5, 5, 5],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[6, 6, 6, 6, 6, 6, 6],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[7, 7, 7, 7, 7, 7, 7],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[8, 8, 8, 8, 8, 8, 8],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[9, 9, 9, 9, 9, 9, 9],
        [0, 1, 2, 3, 4, 5, 6]]),
 array([[10, 10, 10, 10, 10, 10, 10],
        [ 0,  1,  2,  3,  4,  5,  6]]),
 array([[11, 11, 11, 11, 11, 11, 11],
        [ 0,  1,  2,  3,  4,  5,  6]]),
 array([[12, 12, 12, 12, 12, 12, 12],
        [ 0,  1,  2,  3,  4,  5,  6]]),
 array([[13, 13, 13, 13, 13, 13, 13],
        [ 0,  1,  2,  3,  4,  5,  6]])]

In [135]:
for i, var in enumerate(constr.inputs):
    

[Variable(name='q4', coord=array([28, 29, 30, 31, 32, 33, 34]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])),
 Parameter(name='p_start', coord=array([ 98,  99, 100, 101, 102, 103, 104]), dim=7),
 Variable(name='q5', coord=array([35, 36, 37, 38, 39, 40, 41]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])),
 Variable(name='q6', coord=array([42, 43, 44, 45, 46, 47, 48]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])),
 Variable(name='q7', coord=array([49, 50, 51, 52, 53, 54, 55]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])),
 Variable(name='q8', coord=array([56, 57, 58, 59, 60, 61, 62]), lb=array([-2.9671, 

In [112]:
print(jacs)
print(constr.jac_indices)

[[ 0.         -0.35800323 -0.00346175  0.27734634  0.12793577 -0.11073308
   0.02693846]
 [ 0.0835028  -0.04488284  0.0821012  -0.27521536 -0.09057803 -0.35462093
  -0.25156733]
 [ 0.09012468 -0.00871277  0.         -0.          0.         -0.
   0.        ]
 [ 0.08718382 -0.01029391  0.         -0.          0.         -0.
   0.        ]
 [ 0.08409116 -0.01180849  0.         -0.          0.         -0.
   0.        ]
 [-0.08048337  0.          0.          0.          0.         -0.
   0.        ]
 [-0.07493431 -0.30673409  0.05257253  0.45860153  0.00611077  0.0604114
   0.01441582]
 [ 0.         -0.269728    0.17018993  0.41250554  0.06623141  0.00349894
   0.04944601]
 [ 0.         -0.1913654   0.22639126  0.36471352  0.04119299 -0.03383972
   0.04983898]
 [ 0.3966969  -0.01367688  0.28387904 -0.10078335  0.01262022 -0.11999142
   0.        ]
 [ 0.00133374  0.          0.          0.          0.          0.
   0.        ]
 [-0.03660003  0.          0.          0.          0.         

In [130]:
jac_fn = bdr3.get_jacobian_fn()

In [132]:
jac_fn(x0).shape

(126,)

In [134]:
bdr3.get_jacobian_structure()[0].shape

(140,)

In [88]:
bdr3.c_info

{'grasp_prob_pick': Constraint(name='grasp_prob_pick', coord=array([0]), inputs=[Variable(name='q4', coord=array([28, 29, 30, 31, 32, 33, 34]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])), Parameter(name='p_start', coord=array([ 98,  99, 100, 101, 102, 103, 104]), dim=7)], fn=Function(name='robot_grasp_fn', in_dims=[7, 7], out_dim=1, eval_fn=<function robot_grasp_constr at 0x7fb457910af0>, jac_fn=<function robot_grasp_constr at 0x7fb43a887ca0>, custom_jac_indices=None, constraints=[..., Constraint(name='grasp_prob_place', coord=array([1]), inputs=[Variable(name='q9', coord=array([63, 64, 65, 66, 67, 68, 69]), lb=array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]), ub=array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])), Parameter(name='p_goal', coord=array([105, 106, 107, 108, 109, 110, 111]), dim=7)], fn=..., lb=array([0.5]), ub=array([inf]), jac_indices=[a

In [84]:
jac_obj_col_constr(jnp.zeros(7), jnp.zeros(7), *jnp.zeros((4,7)))

Array([[ 1.6489993e-09, -1.9850735e-01,  4.3310208e-10,  4.5141514e-02,
         0.0000000e+00,  3.6568429e-02,  0.0000000e+00],
       [ 1.6489993e-09, -1.9850735e-01,  4.3310208e-10,  4.5141514e-02,
         0.0000000e+00,  3.6568429e-02,  0.0000000e+00],
       [ 1.6489993e-09, -1.9850735e-01,  4.3310208e-10,  4.5141514e-02,
         0.0000000e+00,  3.6568429e-02,  0.0000000e+00],
       [ 1.6489993e-09, -1.9850735e-01,  4.3310208e-10,  4.5141514e-02,
         0.0000000e+00,  3.6568429e-02,  0.0000000e+00]], dtype=float32)

In [333]:
jac_fn = bdr3.get_jacobian_fn()

In [337]:
hess_fn = jax.hessian(travelled_distance)

In [64]:
ipopt_full = bdr3.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
----------------------------ooooooo-----------------------------------------------------------------------------
---------------------------------------------------------------ooooooo------------------------------------------
ooooooo---------------------------------------------------------------------------------------------------------
-------ooooooo--------------------------------------------------------------------------------------------------
--------------ooooooo-------------------------------------------------------------------------------------------
---------------------ooooooo------------------------------------------------------------------------------------
----------------------------ooooooo-----------------------------------------------------------------------------
-----------------------------------ooooooo----------------------------------------------------------------------


In [62]:
ipopt_grasp = bdr1.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
ooo--------------
ooo--------------
ooo--------------
ooo--------------
ooo--------------


In [63]:
ipopt_ik = bdr2.build()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
ooooooo------------------------
ooooooo------------------------
ooooooo------------------------
ooooooo------------------------
ooooooo------------------------
ooooooo------------------------
-------ooooooo-----------------
-------ooooooo-----------------
-------ooooooo-----------------
-------ooooooo-----------------
-------ooooooo-----------------
-------ooooooo-----------------


## Solve

In [66]:
grasp_init = np.random.uniform(-1, 1, size=3) #np.ones(3) #
p_obj_start = obj_start.pose.parameters()
p_obj_goal = obj_goal.pose.parameters()
x0 = np.hstack([grasp_init, p_obj_start, p_obj_goal])
grasp_sol, info = ipopt_grasp.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       15
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       17
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        5
        inequality constraints with only lower bounds:        5
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 2.80e+02 9.99e-01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [67]:
grasp_init = grasp_sol[:3] #np.random.uniform(-1, 1, size=3)
qinit = panda.neutral
p_obj_start = obj_start.pose.parameters()
p_obj_goal = obj_goal.pose.parameters()
x0 = np.hstack([qinit, qinit, grasp_init, p_obj_start, p_obj_goal])
ik_sol, info = ipopt_ik.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:       84
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       31
                     variables with only lower bounds:        0
                variables with lower and upper bounds:       14
                     variables with only upper bounds:        0
Total number of equality constraints.................:       12
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 6.50e-01 0.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [68]:
#initialize
q0 = panda.neutral
q_pick, q_place  = ik_sol[:7], ik_sol[7:14]
waypoints = [q0, q_pick, q_place, q0]
traj_init = []
for i, q in enumerate(waypoints[:-1]):
    qs_ = np.linspace(q, waypoints[i+1], num_mid_configs+1, endpoint=False)
    traj_init.extend(qs_)
traj_init = traj_init[1:]

# grasp_init = grasp_sol[:3] #np.random.normal(size=3)
p_obj_start = obj_start.pose.parameters()
p_obj_goal = obj_goal.pose.parameters()
x0 = np.hstack([*traj_init, p_obj_start, p_obj_goal])

In [76]:
x0.shape

(112,)

In [82]:
ipopt_full.__jacobian(x0).shape

(133,)

In [81]:
ipopt_full.__jacobianstructure()[0].shape

(140,)

In [69]:
ipopt_full.add_option("acceptable_obj_change_tol", 0.01)
ipopt_full.add_option('mu_strategy', 'adaptive')
full_sol, info = ipopt_full.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:      140
Number of nonzeros in Lagrangian Hessian.............:        0



IndexError: Out of bounds on buffer access (axis 0)

Exception ignored in: 'ipopt_wrapper.jacobian_cb'
Traceback (most recent call last):
  File "/var/folders/_y/jwsg7ft16tq6wks_bwh9tkvh0000gn/T/ipykernel_28090/942793123.py", line 3, in <module>
IndexError: Out of bounds on buffer access (axis 0)


Error evaluating Jacobian of inequality constraints at user provided starting point.
  No scaling factors for inequality constraints computed!


IndexError: Out of bounds on buffer access (axis 0)

Exception ignored in: 'ipopt_wrapper.jacobian_cb'
Traceback (most recent call last):
  File "/var/folders/_y/jwsg7ft16tq6wks_bwh9tkvh0000gn/T/ipykernel_28090/942793123.py", line 3, in <module>
IndexError: Out of bounds on buffer access (axis 0)



Number of Iterations....: 0

Number of objective function evaluations             = 0
Number of objective gradient evaluations             = 0
Number of equality constraint evaluations            = 0
Number of inequality constraint evaluations          = 1
Number of equality constraint Jacobian evaluations   = 0
Number of inequality constraint Jacobian evaluations = 1
Number of Lagrangian Hessian evaluations             = 0
Total seconds in IPOPT                               = 0.022

EXIT: Invalid number in NLP function or derivative detected.


In [303]:
traj = full_sol[:7*num_traj].reshape(-1, 7)
traj = [panda.neutral, *traj, panda.neutral]
grasp_pose = grasp_reconst(full_sol[bdr3.x_info["g_pick"].coord])
obj_pose = obj_start.pose
obj.set_pose(obj_pose)
panda.set_joint_angles(panda.neutral)
i = 0

In [319]:
if i-1 in [4, *idxs_hold, 9]:
    q = traj[i]
    obj_pose = SE3(fk(q)) @ grasp_pose.inverse()
    obj.set_pose(obj_pose)
panda.set_joint_angles(traj[i])
i += 1

In [151]:
q = full_sol[10:17]

In [160]:
fks = panda_model.fk_fn(jnp.hstack([q, 0.04, 0.04]))
link_frames = fks[1:-1]
assigned_points = jax.vmap(fk_assign)(link_frames, links_points_mat)
assigned_points = jnp.vstack(assigned_points)

In [99]:
grasp_init = grasp_sol[:3] #np.random.uniform(-1, 1, size=3)
qinit = panda.neutral
p_obj_start = obj_start.pose.parameters()
p_obj_goal = obj_goal.pose.parameters()
x0 = np.hstack([qinit, qinit, grasp_init, p_obj_start, p_obj_goal])
ik_sol, info = ipopt_ik.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:       84
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       31
                     variables with only lower bounds:        0
                variables with lower and upper bounds:       14
                     variables with only upper bounds:        0
Total number of equality constraints.................:       12
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 6.85e-01 0.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [84]:
q_pick = ik[:7]
q_place = ik[7:14]

In [86]:
panda.set_joint_angles(q_place)

In [38]:
obj_poses = jnp.vstack([jnp.zeros(6), jnp.zeros(6)])
pcs = jax.vmap(get_hand_pc, in_axes=(None,0))(jnp.zeros(3), obj_poses)
distances = env.distances(jnp.vstack(pcs)).reshape(2, -1)
distances.min(axis=-1)

Array([0.2933615, 0.2933615], dtype=float32)