In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

### Load Models

In [2]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

# Grasp network. Outputs are consumed as: [0] logit (grasp_logit_fn),
# [1] grasp_dist_fn, [2:] orientation quaternion (see grasp_reconst).
grasp_net = GraspNet(32)


def grasp_fn(x):
    """Raw GraspNet output for input `x`, using the restored parameters."""
    return grasp_net.apply(restored_grasp["params"], x)


def grasp_logit_fn(g):
    """Output 0 of the grasp network (used as the grasp-feasibility logit)."""
    return grasp_fn(g)[0]


def grasp_dist_fn(g):
    """Output 1 of the grasp network."""
    return grasp_fn(g)[1]


# Manipulability network; only its first output is used.
manip_net = ManipNet(64)


def manip_fn(x):
    """First output of ManipNet for input `x` (a 6-D pose vector, per its callers)."""
    return manip_net.apply(restored_manip["params"], x)[0]

### Load world

In [3]:
# Create the SDF world and attach its visualizer to this notebook
# (the viewer URL is printed below).
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [4]:
# Panda arm and a separate hand model, both drawn semi-transparent.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# Remove joints 7 and 8 (presumably the finger joints) from the active
# set, holding them at 0.04 -- TODO confirm reduce_dim semantics.
panda.reduce_dim([7, 8], [0.04, 0.04])

hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
# Sample surface points on every hand link (10 per link).
for link in hand_model.links.values():
    link.set_surface_points(10)
hand = Robot(world.vis, "hand1", hand_model, color="yellow", alpha=0.5)

In [5]:
# Two identical support tables plus the object drawn at its start and goal.
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)

OBJ_MESH_PATH = "./sdf_world/assets/object/mesh.obj"
obj_start = Mesh(world.vis, "obj_start", OBJ_MESH_PATH, color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", OBJ_MESH_PATH, color="green", alpha=0.5)

In [6]:
# Tables side by side at y = -0.3 / +0.3, centred at half their height.
table_start.set_translate([0.5, -0.3, 0.2/2])
table_goal.set_translate([0.5, 0.3, 0.2/2])
# Object start: upright on the start table (half its z extent above the top).
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.5, -0.3, obj_lengths[-1]/2+table_lengths[-1]])
# Object goal: rolled by pi/2 about x, so its y extent becomes vertical --
# hence obj_lengths[-2] for the height offset.
trans_goal = jnp.array([0.5, 0.3, obj_lengths[-2]/2+table_lengths[-1]])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)

In [7]:
def to_posevec(x):
    """[qw, qx, qy, qz, x, y, z] -> 6-D pose vector [x, y, z, rotvec]."""
    return jnp.hstack([x[4:], SO3(x[:4]).log()])


def to_wxyzxyz(x):
    """Inverse of to_posevec: [x, y, z, rotvec] -> [qw, qx, qy, qz, x, y, z]."""
    return jnp.hstack([SO3.exp(x[3:]).parameters(), x[:3]])


# Collision environment made of the two tables (0.05 is forwarded to
# SDFContainer; its meaning is defined there).
env = SDFContainer([table_start, table_goal], 0.05)
def grasp_reconst(grasp:Array):
    """Rebuild a grasp pose, expressed in the object frame, from a grasp point.

    The translation is `grasp` divided by the checkpoint's `scale_to_norm`.
    The rotation is the quaternion in grasp-network outputs [2:], renormalized.
    """
    position = grasp / restored_grasp["scale_to_norm"]
    orientation = SO3(grasp_fn(grasp)[2:]).normalize()
    return SE3.from_rotation_and_translation(orientation, position)

#visualization
# Hand surface points (N, 3) sampled with both finger joints at 0.04
# (presumably the open gripper -- confirm with get_surface_points_fn).
hand_pc = hand.get_surface_points_fn(jnp.array([0.04, 0.04]))
# Fixed offset from the end-effector/grasp frame back to the hand base frame.
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
@jax.jit
def get_hand_pc(grasp, posevec):
    """Hand surface points in world coordinates.

    Chains the object pose (`posevec`), the grasp pose from `grasp` and the
    hand-base offset, then applies that transform to every point of `hand_pc`.
    """
    hand_base_in_world = (
        SE3(to_wxyzxyz(posevec))
        @ grasp_reconst(grasp)
        @ hand_pose_wrt_ee
    )
    return jax.vmap(hand_base_in_world.apply)(hand_pc)

# Red point cloud that the debug callbacks reload with the hand points; the
# 100 zero points are only a placeholder until the first reload.
pc = PointCloud(world.vis, "hand_pc", np.zeros((100,3)), color="red")

In [8]:
# Kinematics
def get_rotvec_angvel_map(v, small_angle_sq=1e-4):
    """Matrix mapping angular velocity to the rate of a rotation vector.

    Returns E(v) = J_l(v)^{-1}, the inverse SO(3) left Jacobian:

        E(v) = I - 1/2 [v]x + c(theta) [v]x^2,   theta = |v|,
        c(theta) = (1 - (theta/2) * cot(theta/2)) / theta^2

    This is algebraically the same as the previous form, because
    sin(theta) / (1 - cos(theta)) = cot(theta/2).

    Args:
        v: rotation vector, shape (3,).
        small_angle_sq: threshold on theta^2 below which c(theta) uses its
            Taylor series 1/12 + theta^2/720 + theta^4/30240. The closed form is
            0/0 at theta = 0 and loses precision for small angles in float32.

    Returns:
        (3, 3) array.
    """
    def skew(w):
        w1, w2, w3 = w
        return jnp.array([[0., -w3, w2],
                          [w3, 0., -w1],
                          [-w2, w1, 0.]])

    v = jnp.asarray(v)
    vskew = skew(v)
    theta_sq = jnp.dot(v, v)
    is_small = theta_sq < small_angle_sq
    # Double-where: the closed-form branch sees a harmless theta when the
    # Taylor branch is selected, so neither values nor gradients become NaN.
    theta_sq_safe = jnp.where(is_small, 1.0, theta_sq)
    half_theta = 0.5 * jnp.sqrt(theta_sq_safe)
    # cot(theta/2) avoids the cancellation in 1 - cos(theta); finite up to theta = pi.
    coef_closed = (1.0 - half_theta * jnp.cos(half_theta) / jnp.sin(half_theta)) / theta_sq_safe
    coef_taylor = 1.0 / 12.0 + theta_sq / 720.0 + theta_sq ** 2 / 30240.0
    coef = jnp.where(is_small, coef_taylor, coef_closed)
    return jnp.eye(3) - 0.5 * vskew + coef * (vskew @ vskew)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector and analytical Jacobian at joint vector `q`.

    Returns:
        posevec: [x, y, z, rotvec] of the last frame returned by
            `panda_model.fk_fn`.
        jac: 6 x k array, one column per frame in `fks[1:8]`. The top rows
            are the linear-velocity Jacobian. The bottom rows are the geometric
            angular rows mapped through `get_rotvec_angvel_map`, so they give
            rotation-vector rates.
    """
    frames = panda_model.fk_fn(q)
    ee = frames[-1]
    ee_pos = ee[-3:]
    ee_rotvec = SO3(ee[:4]).log()

    # Revolute columns: each joint rotates about the z axis of its frame.
    # NOTE(review): assumes fks[1:8] are the arm joint frames.
    columns = []
    for frame_vec in frames[1:8]:
        axis = SE3(frame_vec).as_matrix()[:3, 2]
        linear = jnp.cross(axis, ee_pos - frame_vec[-3:])
        columns.append(jnp.concatenate([linear, axis]))
    geom_jac = jnp.stack(columns, axis=1)

    rate_map = get_rotvec_angvel_map(ee_rotvec)
    analytic_jac = jnp.concatenate([geom_jac[:3], rate_map @ geom_jac[3:]], axis=0)
    return jnp.concatenate([ee_pos, ee_rotvec]), analytic_jac

### Prepare functions

In [9]:
#constr fns
def grasp_constr_fn(grasp):
    """Grasp-network logit for `grasp`, used directly as a constraint value."""
    logit = grasp_logit_fn(grasp)
    return logit

def manip_constr_fn(grasp, posevec):
    """Best manipulability score of a grasp on an object at `posevec`.

    The world-frame grasp pose is scored as-is and again rotated by pi about
    its own z axis (presumably the symmetric gripper alternative). The larger
    `manip_fn` score is returned.
    """
    grasp_in_world = SE3(to_wxyzxyz(posevec)) @ grasp_reconst(grasp)
    flipped = grasp_in_world @ SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    candidates = jnp.vstack(
        [to_posevec(pose.parameters()) for pose in (grasp_in_world, flipped)])
    return jax.vmap(manip_fn)(candidates).max()

def dist_constr_fn(g1, posevec_st, posevec_ed):
    """Four smallest environment distances of the hand point cloud.

    The hand is placed with grasp `g1` on the object at both `posevec_st` and
    `posevec_ed`. The two clouds are merged, and the four smallest
    `env.distances` values are returned in argpartition order (not sorted).
    """
    clouds = jax.vmap(get_hand_pc, (0, 0))(
        jnp.vstack([g1, g1]), jnp.vstack([posevec_st, posevec_ed]))
    dists = env.distances(jnp.vstack(clouds))
    nearest = jnp.argpartition(dists, 4)[:4]
    return dists[nearest]

# def dist_constr_fn2(g1, g2, posevec_ho, posevec_st, posevec_ed):
#     grasps = jnp.vstack([g1, g1, g2, g2])
#     obj_poses = jnp.vstack([posevec_st, posevec_ho, posevec_ho, posevec_ed])
#     pcs = jax.vmap(get_hand_pc, (0,0))(grasps, obj_poses)
#     distances = env.distances(jnp.vstack(pcs))
#     top4_indices = jnp.argpartition(distances, 4)[:4]
#     return distances[top4_indices]

# Jacobian callables, one per registered input (None presumably means no
# derivative is supplied for that input -- confirm with SparseIPOPTBuilder).
grasp_constr_jac_fns = [jax.jacrev(grasp_constr_fn)]
manip_constr_jac_fns = [jax.jacrev(manip_constr_fn, argnums=i) for i in (0, 1)]
dist_constr_jac_fns = [jax.jacrev(dist_constr_fn), None, None]

# Workspace bounds for a pose vector: xyz position, then rotation vector.
ws_lb = np.array([-1, -1, -0.5] + [-np.pi] * 3)
ws_ub = np.array([1, 1, 1.5] + [np.pi] * 3)

In [10]:
import time
builder = SparseIPOPTBuilder()

# Single decision variable: the pick grasp point, bounded to [-1, 1]^3.
# The object start/goal pose vectors are fixed parameters.
builder.add_variable("g_pick", 3, -1., 1.)
builder.add_parameter("p_start", to_posevec(obj_start.pose.parameters()))
builder.add_parameter("p_goal", to_posevec(obj_goal.pose.parameters()))

# register_fn(name, input dims, output dim, fn, per-input jacobian fns) --
# argument roles inferred from usage here; confirm against SparseIPOPTBuilder.
builder.register_fn("grasp_constr", [3], 1,
                          grasp_constr_fn, grasp_constr_jac_fns)
builder.register_fn("manip_constr", [3, 6], 1,
                          manip_constr_fn, manip_constr_jac_fns)
builder.register_fn("dist_constr", [3, 6, 6], 4,
                          dist_constr_fn, dist_constr_jac_fns)

def debug_obj(g_pick, posevec_st, posevec_ed):
    """Dummy objective (always 0.) that draws the hand at the start and goal poses.

    Exists only for visualization during the solve; sleeps 0.1 s per call.
    """
    obj_poses = jnp.vstack([posevec_st, posevec_ed])
    grasps = jnp.vstack([g_pick, g_pick])
    clouds = jax.vmap(get_hand_pc, (0, 0))(grasps, obj_poses)
    pc.reload(points=np.vstack(clouds))
    time.sleep(0.1)
    return 0.
builder.register_fn("debug_obj", [3, 6, 6], 1,
                    debug_obj, [None, None, None])

# Feasibility problem: the objective is the constant-zero visual debug callback.
builder.add_objective(["g_pick", "p_start", "p_goal"], "debug_obj")

# Lower-bounded constraints: grasp logit >= 1, manipulability >= 0.3 at both
# object poses, and the 4 smallest hand-to-environment distances >= 0.05.
builder.add_constr("grasp_prob", ["g_pick"], "grasp_constr",
                   1., np.inf)
builder.add_constr("manip_pick", ["g_pick", "p_start"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("manip_place", ["g_pick", "p_goal"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("dist", ["g_pick", "p_start", "p_goal"], "dist_constr",
                   0.05, np.inf)

builder.freeze()

In [11]:
# Build the IPOPT problem without compiling the (side-effecting) debug objective.
ipopt = builder.build(compile_obj=False)


Sparsity pattern:
00 01 02 
03 04 05 
06 07 08 
09 10 11 
12 13 14 
15 16 17 
18 19 20 


In [393]:
# Solve the single-grasp feasibility problem starting from g_pick = (-1, -1, -1).
xsol, info = ipopt.solve(-np.ones(3))

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       21
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        3
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        7
        inequality constraints with only lower bounds:        7
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 2.70e+02 1.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [394]:
# Keep the solved grasp point for the IK cells below.
grasp = xsol

In [14]:
# Frame marker; debug_obj2 moves it to the (z-flipped) grasp target pose.
frame = Frame(world.vis, "frame")

In [15]:
# Frame marker; debug_obj2 moves it to the FK end-effector pose.
frame2 = Frame(world.vis, "frame2")

In [395]:
# Reset the displayed arm to its neutral configuration.
panda.set_joint_angles(panda.neutral)

In [17]:
zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
def kin_error_fn(q, grasp, p_obj):
    """Pose-vector error between the z-flipped grasp target and the end effector.

    Returns target_posevec - fk_posevec(q), shape (6,).
    NOTE(review): a plain rotation-vector difference is not a proper
    orientation error near |rotvec| = pi.
    """
    ee_posevec = get_ee_fk_jac(q)[0]
    target = SE3(to_wxyzxyz(p_obj)) @ grasp_reconst(grasp) @ zflip
    return to_posevec(target.parameters()) - ee_posevec

def kin_error_jac_fn(q, grasp, p_obj):
    """Jacobian of kin_error_fn w.r.t. q: the negated analytical EE Jacobian.

    `grasp` and `p_obj` are accepted for the callback signature but unused.
    """
    ee_jac = get_ee_fk_jac(q)[1]
    return -ee_jac

def debug_obj2(q, grasp, p_obj):
    """Dummy objective (always 0.) that visualizes the IK state.

    Shows the arm at `q`, puts `frame2` at the FK end-effector pose and
    `frame` at the z-flipped grasp target, then sleeps 0.1 s.
    """
    panda.set_joint_angles(q)
    ee_posevec = get_ee_fk_jac(q)[0]
    frame2.set_pose(SE3(to_wxyzxyz(ee_posevec)))
    target_pose = SE3(to_wxyzxyz(p_obj)) @ grasp_reconst(grasp) @ zflip
    frame.set_pose(target_pose)
    time.sleep(0.1)
    return 0.

# IK problem: find q_pick whose end-effector pose matches the solved grasp
# on the object at its start pose.
builder2 = SparseIPOPTBuilder()

builder2.add_variable("q_pick", 7, panda.lb, panda.ub)
#builder2.add_variable("q_place", 7, panda.lb, panda.ub)
builder2.add_parameter("p_start", to_posevec(obj_start.pose.parameters()))
#builder2.add_parameter("p_goal", to_posevec(obj_goal.pose.parameters()))
builder2.add_parameter("grasp", xsol)

# Only q_pick is a decision variable, so only its jacobian is supplied.
builder2.register_fn("kin_constr", [7, 3, 6], 6,
                          kin_error_fn, [kin_error_jac_fn, None, None])
builder2.register_fn("obj_debug", [7, 3, 6], 1,
                          debug_obj2, [None, None, None])
builder2.add_objective(["q_pick", "grasp", "p_start"], "obj_debug")
# Equality constraint (bounds 0..0): pose error must vanish.
builder2.add_constr("kin", ["q_pick", "grasp", "p_start"], "kin_constr",
                   0., 0.)

In [18]:
# Finalize and build the IK problem (7 variables, 6 equality constraints).
builder2.freeze()
ipopt2 = builder2.build(compile_obj=False)


Sparsity pattern:
00 01 02 03 04 05 06 
07 08 09 10 11 12 13 
14 15 16 17 18 19 20 
21 22 23 24 25 26 27 
28 29 30 31 32 33 34 
35 36 37 38 39 40 41 


In [456]:
import PyCeres

In [457]:
class IK(PyCeres.CostFunction):
    """Ceres cost: 6-D pose-vector residual between an IK target and the Panda EE.

    The target is ``obj_start.pose @ grasp_reconst(grasp)``, stored as the
    pose vector ``self.ik_target``.
    NOTE(review): unlike kin_error_fn, no ``zflip`` is applied -- confirm
    which convention is intended.
    """

    def __init__(self, grasp):
        super().__init__()
        self.set_num_residuals(6)
        self.set_parameter_block_sizes([7])
        target_pose = obj_start.pose @ grasp_reconst(grasp)
        self.ik_target = to_posevec(target_pose.parameters())

    def Evaluate(self, parameters, residuals, jacobians):
        """Fill residuals and, if requested, the Jacobian for joint vector q.

        parameters: [q], with q of length 7.
        residuals:  flat buffer of 6 values, receiving ik_target - fk(q).
        jacobians:  None, or [J_q] with J_q a flat row-major 6x7 buffer.
                    Ceres may pass None for an individual block.
        Also shows the arm at q and sleeps 0.1 s (debug visualization).
        """
        q = parameters[0]
        ee_posevec, ee_jac = get_ee_fk_jac(q)
        # The residual buffer is flat (one slot per residual), not nested.
        residuals[:] = np.asarray(self.ik_target - ee_posevec)
        if jacobians is not None and jacobians[0] is not None:
            # d(target - ee)/dq = -J, same convention as kin_error_jac_fn.
            jacobians[0][:] = -np.asarray(ee_jac).ravel()
        panda.set_joint_angles(q)
        time.sleep(0.1)
        return True

In [459]:
# Ceres IK cost for the solved grasp; its target pose vector is `ik_target`.
feat_ik = IK(grasp)

In [460]:
# Inspect the IK target pose vector [x, y, z, rotvec].
feat_ik.ik_target

Array([ 0.496085  , -0.25153726,  0.4085922 ,  2.095177  ,  2.2025557 ,
        0.19674964], dtype=float32)

In [396]:
# Start the manual Gauss-Newton IK below from the neutral configuration.
q = panda.neutral.copy()


panda.set_joint_angles(panda.neutral)

In [454]:
# One damped Gauss-Newton step towards feat_ik's IK target; re-run the cell
# to iterate. The target is read from the IK cost object rather than from a
# free global, so the cell works on a fresh kernel.
GN_STEP = 0.1
ee, jac = get_ee_fk_jac(q)
err = feat_ik.ik_target - ee
grad = - jac.T @ err        # gradient of 0.5 * |err|^2 w.r.t. q
hess = jac.T@jac            # Gauss-Newton approximation of the Hessian
d = - jnp.linalg.pinv(hess) @ grad

q = q + d*GN_STEP

panda.set_joint_angles(q)

In [19]:
# Solve the IK problem from the neutral configuration (the saved run below
# was interrupted manually).
qsol, info = ipopt2.solve(panda.neutral)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:       42
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        7
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        7
                     variables with only upper bounds:        0
Total number of equality constraints.................:        6
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 2.39e+00 0.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

KeyboardInterrupt: 

SE3(wxyz=[ 0.00842    -0.68869996 -0.72375    -0.04238   ], xyz=[-0.00169     0.05138     0.10241999])

In [19]:
# Manually assemble the objective's inputs from a flat decision vector `x`,
# sized from the builder's variable dimensions.
x = jnp.zeros(sum(builder.xdims))
# Materialize: a bare zip() iterator would be exhausted by the comprehension,
# leaving a later `list(zip_x_ndo)` silently empty.
zip_x_ndo = list(zip(builder.xnames, builder.xdims, builder.xoffsets))
input_dict = {name: x[offset:offset+dim]
              for name, dim, offset in zip_x_ndo}
input_dict.update(builder.param_dict)
input_names = [inp.name for inp in builder.obj_info.inputs]
inputs = [input_dict[name] for name in input_names]

In [16]:
# (name, dim, offset) of each decision variable; rebuilt here so the result
# does not depend on whether an earlier zip iterator was already consumed.
list(zip(builder.xnames, builder.xdims, builder.xoffsets))

[('g_pick', 3, 0)]

In [17]:
# Names of the inputs the objective is wired to, in call order.
[inp.name for inp in builder.obj_info.inputs]

['g_pick', 'p_start', 'p_goal']

In [43]:
# Pick / handover / place problem: grasp g_pick at the start pose, grasp
# g_place at the goal pose, and a free handover object pose p_ho at which
# both grasps must be valid.


def dist_constr_ho_fn(g1, g2, posevec_ho, posevec_st, posevec_ed):
    """Four smallest env distances of the hand over all four grasp placements.

    Hand with g1 at the start and handover poses, and with g2 at the handover
    and goal poses. Returns argpartition order (not sorted).
    """
    grasps = jnp.vstack([g1, g1, g2, g2])
    obj_poses = jnp.vstack([posevec_st, posevec_ho, posevec_ho, posevec_ed])
    pcs = jax.vmap(get_hand_pc, (0, 0))(grasps, obj_poses)
    distances = env.distances(jnp.vstack(pcs))
    top4_indices = jnp.argpartition(distances, 4)[:4]
    return distances[top4_indices]


# Derivatives w.r.t. the decision variables g_pick, g_place and p_ho;
# p_start / p_goal are fixed parameters.
dist_constr_ho_jac_fns = (
    [jax.jacrev(dist_constr_ho_fn, argnums=i) for i in (0, 1, 2)] + [None] * 2)

builder = SparseIPOPTBuilder()

builder.add_variable("g_pick", 3, -1., 1.)
builder.add_variable("g_place", 3, -1., 1.)
builder.add_variable("p_ho", 6, ws_lb, ws_ub)
builder.add_parameter("p_start", to_posevec(obj_start.pose.parameters()))
builder.add_parameter("p_goal", to_posevec(obj_goal.pose.parameters()))

builder.register_fn("grasp_constr", [3], 1,
                    grasp_constr_fn, grasp_constr_jac_fns)
builder.register_fn("manip_constr", [3, 6], 1,
                    manip_constr_fn, manip_constr_jac_fns)
# The registered callable must take all five inputs ([3, 3, 6, 6, 6]) and
# supply a jacobian for each of the three variables.
builder.register_fn("dist_constr", [3, 3, 6, 6, 6], 4,
                    dist_constr_ho_fn, dist_constr_ho_jac_fns)

builder.add_constr("grasp_prob_pick", ["g_pick"], "grasp_constr",
                   1., np.inf)
builder.add_constr("grasp_prob_place", ["g_place"], "grasp_constr",
                   1., np.inf)
builder.add_constr("manip_pick", ["g_pick", "p_start"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("manip_place", ["g_place", "p_goal"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("manip_ho_1", ["g_pick", "p_ho"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("manip_ho_2", ["g_place", "p_ho"], "manip_constr",
                   0.3, np.inf)
builder.add_constr("dist", ["g_pick", "g_place", "p_ho", "p_start", "p_goal"], "dist_constr",
                   0.05, np.inf)
builder.freeze()

In [45]:
# Build the 12-variable pick/handover/place problem (first positional build
# argument set to True -- see SparseIPOPTBuilder.build for its meaning).
ipopt = builder.build(True)


Sparsity pattern:
00 01 02 -- -- -- -- -- -- -- -- -- 
-- -- -- 03 04 05 -- -- -- -- -- -- 
06 07 08 -- -- -- -- -- -- -- -- -- 
-- -- -- 09 10 11 -- -- -- -- -- -- 
12 13 14 -- -- -- 18 19 20 21 22 23 
-- -- -- 15 16 17 24 25 26 27 28 29 
30 31 32 42 43 44 54 55 56 57 58 59 
33 34 35 45 46 47 60 61 62 63 64 65 
36 37 38 48 49 50 66 67 68 69 70 71 
39 40 41 51 52 53 72 73 74 75 76 77 


In [None]:
# Flat initial guess laid out as [g_pick (3), g_place (3), p_ho (6)], in the
# order the variables were added. p_ho = 0 lies inside [ws_lb, ws_ub];
# a better handover initialization may be needed.
x0 = np.hstack([np.zeros(3), np.zeros(3), np.zeros(6)])
xsol_ho, info_ho = ipopt.solve(x0)

In [17]:
# Zero probe inputs for the distance check below.
# NOTE: this rebinds the global `grasp` previously set from `xsol`.
grasp = jnp.zeros(3)
posevec = jnp.zeros(6)



Array([0.4810451, 0.5339005], dtype=float32)

In [23]:
# Probe the distance constraint with zero inputs; dist_constr_fn takes
# (g1, posevec_st, posevec_ed).
dist_constr_fn(grasp, posevec, posevec)

Array([0.26522437, 0.26522437, 0.26522437, 0.26522437], dtype=float32)