In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

import time

In [2]:
class PandaHand:
    """Visualized gripper built from a hand robot model.

    Relies on the notebook-level ``world`` object for the visualizer handle
    (``world.vis``), so ``world`` must exist before an instance is created.
    """

    def __init__(self, hand_model, name="hand"):
        """Create the visual robot and cache the hand surface points.

        Args:
            hand_model: robot model of the hand; must expose ``fk_fn`` and ``links``.
            name: name under which the robot is registered in the visualizer.
        """
        self.model = hand_model
        self.robot = Robot(world.vis, name, hand_model, color="white", alpha=0.5, visualized_mesh="visual")
        # Fixed offset applied on the right of every pose passed to set_pose().
        self.hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
        # Surface points evaluated at joint values [0.04, 0.04]
        # (presumably both fingers fully open -- TODO confirm).
        self.hand_pc = self.robot.get_surface_points_fn(jnp.array([0.04, 0.04]))

    def get_bounding_box(self, name):
        """Return an axis-aligned ``Box`` enclosing all link collision vertices.

        The vertices are transformed by the forward kinematics evaluated at
        joint values [0.04, 0.04].

        Args:
            name: name given to the created (non-visualized) box.

        Returns:
            A ``Box`` whose extents span the min/max of the transformed
            vertices, translated to the centre of that span.
        """
        fks = self.model.fk_fn(jnp.array([0.04, 0.04]))
        hand_points = []
        # Fix: the model is stored as ``self.model``; ``self.hand_model`` was never set.
        for i, link in enumerate(self.model.links.values()):
            pts = jax.vmap(SE3(fks[i]).apply)(link.collision_mesh.vertices)
            hand_points.append(pts)
        # ``astype`` instead of the ``dtype=`` keyword, which older NumPy releases reject.
        hand_points = np.vstack(hand_points).astype(float)
        min_points = hand_points.min(axis=0)
        max_points = hand_points.max(axis=0)
        extents = max_points - min_points
        center = (max_points + min_points) / 2
        box = Box(world.vis, name, extents, alpha=0.5, visualize=False)
        box.set_translate(center)
        return box

    def set_pose(self, pose):
        """Place the hand so that ``pose @ hand_pose_wrt_ee`` is the robot pose."""
        self.robot.set_pose(pose @ self.hand_pose_wrt_ee)

In [3]:
# Shared world instance; its `vis` handle is passed to every visual object created below.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [4]:
# Build the hand model, sample surface points on each of its links,
# then create the two gripper visuals that share this model.
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link in hand_model.links.values():
    link.set_surface_points(10)
hand1, hand2 = (PandaHand(hand_model, label) for label in ("hand1", "hand2"))

In [None]:
# Scene primitives: two identical tables and one obstacle box.
table_lengths = [0.4, 0.4, 0.2]
table_start, table_goal = (
    Box(world.vis, box_name, table_lengths, 'white', 0.5)
    for box_name in ("table_start", "table_goal")
)
obstacle = Box(world.vis, "obstacle", [0.4, 0.2, 0.35], 'white', 0.5)

# Four visual copies of the same object mesh, distinguished by colour.
_mesh_file = "./sdf_world/assets/object/mesh.obj"
obj_start = Mesh(world.vis, "obj_start", _mesh_file, color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", _mesh_file, color="green", alpha=0.5)
obj_ho = Mesh(world.vis, "obj_ho", _mesh_file, color="yellow", alpha=0.5)
obj = Mesh(world.vis, "obj", _mesh_file, color="white", alpha=0.8)

In [None]:
# Scene layout: the start/goal tables sit at y = -ydev / +ydev and the
# obstacle at y = 0. Every placement below is derived from `ydev` and
# `table_lengths` so that changing either keeps the scene consistent.
ydev = 0.4
# panda1.set_translate([0,-ydev,0])
# panda2.set_translate([0,ydev,0])

table_height = table_lengths[-1]
table_start.set_translate([0.45, -ydev, table_height/2])
table_goal.set_translate([0.45, ydev, table_height/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
# Object centre: half of its last extent above the table top.
obj_start_z = obj_lengths[-1]/2+table_height
obj_start.set_translate([0.45, -ydev, obj_start_z])
# The goal uses the second-to-last extent; presumably because the goal pose
# is rolled by pi/2, which turns that axis vertical -- TODO confirm.
trans_goal = jnp.array([0.45, ydev, obj_lengths[-2]/2+table_height])
obstacle.set_translate([0.45, 0., 0.35/2])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)
# `obj` starts at the same place as `obj_start`.
obj.set_translate([0.45, -ydev, obj_start_z])
obj_ho.set_translate([0.3, 0., 0.6])

In [None]:
# NOTE(review): `hand1` is already created in the robot-loading cell; this line
# rebinds it to a second PandaHand registered under the same "hand1" name --
# confirm the re-creation is intended.
hand1 = PandaHand(hand_model, "hand1")
# Hand surface points; read as a global by get_hand_pc().
hand_pc = hand1.hand_pc
# Passed to SDFContainer with the static boxes (presumably a safety margin
# in metres -- TODO confirm against SDFContainer).
safe_dist = 0.05
env = SDFContainer([table_start, table_goal, obstacle], safe_dist)

In [5]:
# Restore the trained network checkpoints.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net")

# Grasp network and accessors for its first two outputs.
grasp_net = GraspNet(32)
def grasp_fn(x):
    """Apply the grasp network with the restored parameters."""
    return grasp_net.apply(restored_grasp["params"], x)

def grasp_logit_fn(g):
    """First output of the grasp network."""
    return grasp_fn(g)[0]

def grasp_dist_fn(g):
    """Second output of the grasp network."""
    return grasp_fn(g)[1]

# Manipulability network; input layout: wxyzxyz
manip_net = ManipNet(64)
# Rotation by pi about z, used to build the flipped grasp candidate.
zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
def manip_fn(pose_like:Array, robot_pose:Array):
    """Evaluate the manipulability network for a grasp pose and its z-flip.

    Args:
        pose_like: SE3 parameter vector of the grasp pose in the world frame;
            it is normalized before use.
        robot_pose: SE3 parameter vector of the robot base in the world frame.

    Returns:
        Flattened network outputs: the grasp expressed in the robot frame
        first, then the same grasp composed with ``zflip``.
    """
    world_from_grasp = SE3(pose_like).normalize()
    robot_from_grasp = SE3(robot_pose).inverse() @ world_from_grasp
    candidates = (robot_from_grasp, robot_from_grasp @ zflip)
    inputs = jnp.vstack([candidate.parameters() for candidate in candidates])
    # Share the parameters and map over the two stacked inputs.
    batched_apply = jax.vmap(manip_net.apply, (None, 0))
    return batched_apply(restored_manip["params"], inputs).flatten()

2023-08-09 14:24:07.369374: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-08-09 14:24:07.369415: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 515.105.1


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [8]:
#functions
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)
def grasp_embedding(grasp_point):
    """Scale a grasp point by the checkpoint's ``scale_to_norm`` (inverse of the division in grasp_reconst)."""
    return grasp_point * restored_grasp["scale_to_norm"]

In [9]:
# Coordinate-frame visuals; later posed at the two robot base poses.
frame1 = Frame(world.vis, "frame1")
frame2 = Frame(world.vis, "frame2")

In [10]:
def show_grasps(pose_ho, grasp1, grasp2):
    """Pose ``hand1``/``hand2`` at the two grasps expressed relative to ``pose_ho``.

    Args:
        pose_ho: SE3 parameter vector of the object pose.
        grasp1: grasp point shown with ``hand1``.
        grasp2: grasp point shown with ``hand2``.
    """
    world_from_obj = SE3(pose_ho)
    for hand, grasp in ((hand1, grasp1), (hand2, grasp2)):
        hand.set_pose(world_from_obj @ grasp_reconst(grasp))

In [11]:
# Sample the two initial grasp points from a seeded generator so that a
# fresh "Restart & Run All" reproduces the same starting point.
GRASP_SEED = 0
_grasp_rng = np.random.default_rng(GRASP_SEED)
grasp1 = _grasp_rng.normal(size=3)*0.5
grasp2 = _grasp_rng.normal(size=3)*0.5

#show_grasps(obj_pose)

In [12]:
# Robot base poses kept as raw SE3 parameter vectors (translation only, y = -0.4 / +0.4).
# manip_constr() reads them as globals.
pose_r1 = SE3.from_translation(jnp.array([0, -0.4, 0])).parameters()
pose_r2 = SE3.from_translation(jnp.array([0, 0.4, 0])).parameters()
# NOTE(review): set_pose receives a parameter array here, while other set_pose
# calls in this notebook pass SE3 objects -- confirm Frame.set_pose accepts both.
frame1.set_pose(pose_r1)
frame2.set_pose(pose_r2)

In [13]:
# World-frame grasp poses (as parameter vectors): current obj_ho pose composed with each reconstructed grasp.
# manip_constr() computes its own local values under the same names.
pose_g1w = (obj_ho.pose @ grasp_reconst(grasp1)).parameters()
pose_g2w = (obj_ho.pose @ grasp_reconst(grasp2)).parameters()

In [163]:
def manip_constr(qtn_like, pos, g1, g2):
    """Sum of the best manipulability outputs of both robots.

    Args:
        qtn_like: 4-vector used as the rotation part of the object pose
            (the assembled pose is normalized).
        pos: translation part of the object pose.
        g1: grasp point evaluated against ``pose_r1``.
        g2: grasp point evaluated against ``pose_r2``.

    Returns:
        ``max(manip_fn)`` for grasp 1 plus ``max(manip_fn)`` for grasp 2.
    """
    pose_ho = SE3(jnp.hstack([qtn_like, pos])).normalize()
    best_scores = [
        manip_fn((pose_ho @ grasp_reconst(grasp)).parameters(), robot_pose).max()
        for grasp, robot_pose in ((g1, pose_r1), (g2, pose_r2))
    ]
    return best_scores[0] + best_scores[1]
# Gradient with respect to the quaternion argument only (returned as a 1-tuple).
jac_manip_constr = jax.grad(manip_constr, argnums=[0])

In [164]:
def qtn_constr(qtn_like):
    """Unit-norm residual: squared norm of ``qtn_like`` minus one."""
    squared_norm = jnp.sum(qtn_like * qtn_like)
    return squared_norm - 1
# Gradient of the residual; list-valued argnums makes it return a 1-tuple.
jac_qtn_constr = jax.grad(qtn_constr, argnums=[0])

In [16]:
def split_qtn_pos(pose):
    """Split ``pose.parameters()`` into its first four and last three entries."""
    params = pose.parameters()
    quaternion, translation = params[:4], params[-3:]
    return quaternion, translation

In [17]:
# Quaternion and position of the current obj_ho pose, used as inputs to the solver call below.
qtn, pos = split_qtn_pos(obj_ho.pose)

In [165]:
def hash_array(x: Union[Array, np.ndarray]):
    """Return a Python hash built from the array's shape and its flattened values."""
    flat_values = tuple(np.asarray(x).flat)
    return hash((x.shape, flat_values))

class VGWrapper:
    """Serve value and gradient separately from one value-and-grad function.

    ``get_value`` stores the gradient keyed by ``hash_array(x)`` so that a
    following ``get_grad`` at the same point can reuse it.
    """

    def __init__(self, eval_fn:Callable):
        """Store ``eval_fn``, a callable returning ``(value, grad)`` for an input."""
        self.grads = {}
        self.eval_fn = eval_fn

    def reset(self):
        """Drop all stored gradients by rebinding a fresh dict."""
        self.grads = {}

    def get_value(self, x):
        """Evaluate at ``x``, remember the gradient, and return the value."""
        value, gradient = self.eval_fn(x)
        self.grads[hash_array(x)] = gradient
        return value

    def get_grad(self, x):
        """Return the stored gradient for ``x``; recompute it (without storing) on a miss."""
        key = hash_array(x)
        if key not in self.grads:
            _, gradient = self.eval_fn(x)
            return gradient
        return self.grads[key]

In [166]:
class QtnOpt:
    """Optimise the hand-over quaternion and differentiate through the solution.

    Builds an IPOPT problem whose only decision variable is the 4-vector
    ``qtn_ho`` (bounded in [-1, 1]); ``pos``, ``grasp1`` and ``grasp2`` are
    registered as parameters. The objective is ``manip_constr`` and the single
    equality constraint is ``qtn_constr == 0``.
    """

    def __init__(self):
        # Not used anywhere in this class; kept in case external code reads it.
        self.a = 1
        bdr_qtn = SparseIPOPT()
        bdr_qtn.add_variable("qtn_ho", 4, -1., 1.)
        bdr_qtn.add_parameter("pos", 3)
        bdr_qtn.add_parameter("grasp1", 3)
        bdr_qtn.add_parameter("grasp2", 3)

        bdr_qtn.register_fn("manip_qtn", [4, 3, 3, 3], 1,
            manip_constr, jac_manip_constr, jac_out_argnums=[0])
        bdr_qtn.register_fn("qtn_constr", [4], 1, qtn_constr, jac_qtn_constr)

        bdr_qtn.set_objective("manip_qtn", ["qtn_ho", "pos", "grasp1", "grasp2"])
        bdr_qtn.set_constr("qtn_constr", "qtn_constr", ["qtn_ho"], 0., 0.)
        self.ipopt_qtn = bdr_qtn.build()
        self.ipopt_qtn.add_option("acceptable_obj_change_tol", 0.01)

        # Derivative functions are built once. They are left un-jitted because
        # get_sensitivity as a whole is jitted below.
        grad_x_fn = jax.grad(manip_constr, argnums=[1,2,3])
        grad_z_fn = jax.grad(manip_constr)
        grad_zz_fn = jax.hessian(manip_constr)
        grad_zx_fn = jax.jacfwd(grad_z_fn, argnums=[1,2,3])

        def get_sensitivity(qtn_star, pos, grasp1, grasp2):
            """Return (df/dx, df/dz, dz/dx) at the solution, with z = quaternion
            and x = [pos, grasp1, grasp2].

            dz/dx solves ``grad_zz @ dzdx = -grad_zx``.
            NOTE(review): only ``manip_constr`` is differentiated here; the
            unit-norm equality constraint does not enter this linear system --
            confirm that is intended.
            """
            grad_x = jnp.hstack(grad_x_fn(qtn_star, pos, grasp1, grasp2))
            grad_z = grad_z_fn(qtn_star, pos, grasp1, grasp2)
            grad_zz = grad_zz_fn(qtn_star, pos, grasp1, grasp2)
            grad_zx = jnp.hstack(grad_zx_fn(qtn_star, pos, grasp1, grasp2))
            dzdx = jnp.linalg.solve(grad_zz, -grad_zx)
            return grad_x, grad_z, dzdx
        self.get_sensitivity_fn = jax.jit(get_sensitivity)

    def __call__(self, qtn, pos, grasp1, grasp2):
        """Solve for the quaternion and return it with the total derivative.

        Args:
            qtn: initial guess for the quaternion (4 values).
            pos: object position (3 values).
            grasp1: first grasp point (3 values).
            grasp2: second grasp point (3 values).

        Returns:
            ``(qtn_star, total_grad)`` where ``qtn_star`` is the first four
            entries of the solver solution and ``total_grad`` is
            ``grad_x + grad_z @ dzdx`` with respect to [pos, grasp1, grasp2].
        """
        # Solver vector layout: [qtn(4), pos(3), grasp1(3), grasp2(3)].
        x0 = np.hstack([qtn, pos, grasp1, grasp2])
        sol, _ = self.ipopt_qtn.solve(x0)
        qtn_star = sol[:4]

        grad_x, grad_z, dzdx = self.get_sensitivity_fn(
            qtn_star, pos, grasp1, grasp2)
        # The unused eager evaluation of manip_constr has been removed.
        return qtn_star, grad_x + grad_z @ dzdx

In [167]:
# Build the solver; construction triggers the compile steps listed in the cell output.
qtn_opt = QtnOpt()

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
oooo---------


In [177]:
# Start pose of the object as a raw SE3 parameter vector.
pose_o0 = obj_start.pose.parameters()

In [175]:
# NOTE(review): same expression as `pose_o0`; consider keeping only one of the two names.
obj_start_pose = obj_start.pose.parameters()


SE3(wxyz=[1. 0. 0. 0.], xyz=[ 0.45       -0.39999998  0.30668998])

In [187]:
# Same fixed offset that PandaHand stores as `hand_pose_wrt_ee`.
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
def get_hand_pc(grasp, obj_pose):
    """Transform the global ``hand_pc`` points into the world frame for a grasp.

    Args:
        grasp: grasp point, expanded to a pose with ``grasp_reconst``.
        obj_pose: SE3 parameter vector of the object pose.

    Returns:
        ``hand_pc`` mapped through ``SE3(obj_pose) @ grasp_pose @ hand_pose_wrt_ee``.
    """
    world_from_hand_base = SE3(obj_pose) @ grasp_reconst(grasp) @ hand_pose_wrt_ee
    return jax.vmap(world_from_hand_base.apply)(hand_pc)

In [197]:
def total_grasp_constr(grasp, obj_pose, robot_pose):
    """Stack the three grasp constraint values for one grasp.

    Returns:
        ``[logit, manip, min_dist]``: the grasp logit, the maximum
        manipulability output for the world-frame grasp pose, and the smallest
        ``env`` distance over the transformed hand points.
    """
    hand_points = get_hand_pc(grasp, obj_pose)
    world_from_grasp = SE3(obj_pose) @ grasp_reconst(grasp)
    constraint_terms = (
        grasp_logit_fn(grasp),
        manip_fn(world_from_grasp.parameters(), robot_pose).max(),
        env.distances(hand_points).min(),
    )
    return jnp.hstack(constraint_terms)
# Forward-mode Jacobian with respect to the first argument (the grasp).
jac_total_grasp_constr = jax.jacfwd(total_grasp_constr)

In [198]:
# Spot-check the constraint Jacobian (w.r.t. the grasp) at the start object pose.
# NOTE(review): grasp2 is paired with pose_r1 here, whereas manip_constr pairs g2 with pose_r2 -- confirm.
jac_total_grasp_constr(grasp2, pose_o0, pose_r1)

Array([[ 1.0604657e+02, -1.4969180e+02,  1.2857162e+01],
       [ 1.2652088e+00,  7.4033326e-01,  5.9606463e-02],
       [ 9.6904531e-02,  6.4486697e-02,  2.1396719e-01]], dtype=float32)

In [193]:
# Point-cloud visual initialised with 10 placeholder points at the origin; real points are loaded via reload().
pc_hand = PointCloud(world.vis, "hand_pc", np.zeros((10,3)), 0.01, "red")

In [196]:
# Show the hand points for grasp2 at the object's start pose.
assigned_hand_pc = get_hand_pc(grasp2, pose_o0)
pc_hand.reload(points=assigned_hand_pc)

In [172]:
# Gradient of the grasp logit with respect to the grasp point, evaluated at grasp1.
jax.grad(grasp_logit_fn)(grasp1)

Array([ 177.14746   ,    0.27306747, -255.30017   ], dtype=float32)

In [170]:
# Solve from the current guess; returns (qtn_star, total derivative w.r.t. [pos, grasp1, grasp2]).
qtn_opt(qtn, pos, grasp1, grasp2)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        4
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       13
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        4
                     variables with only upper bounds:        0
Total number of equality constraints.................:        1
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  7.7611899e-01 1.99e-02 9.32e-01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

(array([ 0.90431294,  0.08678928, -0.33545722,  0.25237443]),
 Array([ 1.7293658 ,  0.88006276,  1.3888981 ,  0.49285924, -0.24766336,
        -0.46177542,  0.4950693 ,  0.19820543,  0.3275724 ], dtype=float32))

In [None]:
# NOTE(review): unexecuted copy of the `hess_fn` / `grad_tx_fn` cell that appears again further down; keep only one.
# Hessian of manip_constr with respect to its first argument (the quaternion).
hess_fn = jax.hessian(manip_constr)
def grad_tx_fn(qtn_star, pos, grasp1, grasp2):
    """Mixed second derivative: d/d[pos, grasp1, grasp2] of d(manip_constr)/d(quaternion), blocks stacked horizontally."""
    fn = jax.jacfwd(jax.jacfwd(manip_constr), argnums=[1,2,3])
    return jnp.hstack(fn(qtn_star, pos, grasp1, grasp2))

In [None]:
# NOTE(review): unexecuted cell that uses `qtn_star`, `xgrads` and `zgrad`, which are
# only assigned in cells further down -- it cannot run top-to-bottom in this position.
hess_zz = hess_fn(qtn_star, pos, grasp1, grasp2)
jac_zx = grad_tx_fn(qtn_star, pos, grasp1, grasp2)
# dz/dx from the linear system hess_zz @ dzdx = -jac_zx.
dzdx = jnp.linalg.solve(hess_zz, -jac_zx)
# Total derivative = direct partials + quaternion partial chained through dz/dx.
x_total_deriv = xgrads + zgrad @ dzdx

In [None]:
# NOTE(review): unexecuted copy of the partial-gradient cell that appears again further down;
# `qtn_star` is not assigned before this point.
# Partials of manip_constr w.r.t. pos, grasp1, grasp2, concatenated into one vector.
xgrads = jax.grad(manip_constr, argnums=[1,2,3])(qtn_star, pos, grasp1, grasp2)
xgrads = jnp.hstack(xgrads)
# Partial of manip_constr w.r.t. the quaternion.
zgrad = jax.grad(manip_constr, argnums=0)(qtn_star, pos, grasp1, grasp2)

In [130]:
# NOTE(review): second construction of QtnOpt (already built above), with a lower execution count -- likely a stale cell.
qtn_opt = QtnOpt()

In [132]:
# NOTE(review): stale call. QtnOpt.__call__ as defined in this notebook takes
# (qtn, pos, grasp1, grasp2), so a single argument no longer matches its signature.
qtn_opt(3.)

1


In [18]:


# NOTE(review): no module-level `ipopt_qtn` is assigned in the visible cells (the solver
# lives on QtnOpt as `self.ipopt_qtn`), and `x0` is assigned in a cell further down --
# this cell relies on hidden kernel state.
qtn_sol, info = ipopt_qtn.solve(x0)

In [119]:
# Value-and-gradient version of the grasp logit.
vg_grasp_logit_fn = jax.value_and_grad(grasp_logit_fn)
# NOTE(review): `FnValueAndJac` is not defined in the visible cells -- presumably it
# comes from one of the star imports; confirm. This rebinds `grasp_logit_fn_` only,
# not `grasp_logit_fn`.
grasp_logit_fn_ = FnValueAndJac(vg_grasp_logit_fn)

In [124]:
# Zero 3-vector test input (not used by any other visible cell).
x = np.zeros(3)

In [26]:
# NOTE(review): `bdr_qtn` only exists as a local inside QtnOpt.__init__ in the visible
# cells, so this cell relies on hidden kernel state.
# Pull the named pieces out of the flat solver solution.
qtn = bdr_qtn.split_solution(qtn_sol)['qtn_ho']
pos = bdr_qtn.split_solution(qtn_sol)['pos']

In [37]:
# Visualise the solution: pose both hands and the hand-over object at the solved pose.
pose_ho_params = jnp.hstack([qtn_sol[:4], pos])
show_grasps(pose_ho_params, grasp1, grasp2)
obj_ho.set_pose(SE3(pose_ho_params))

In [36]:
# Reset the initial guess: identity rotation, and the current obj_ho position shifted by -0.01 in y.
qtn = SO3.identity().parameters() # qtn_sol[:4]
pos = obj_ho.pose.parameters()[-3:] + np.array([0., -0.01, 0])
# Flat solver vector: [qtn(4), pos(3), grasp1(3), grasp2(3)].
x0 = np.hstack([qtn, pos, grasp1, grasp2])


This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        4
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       13
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        4
                     variables with only upper bounds:        0
Total number of equality constraints.................:        1
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  7.7611899e-01 1.99e-02 9.32e-01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [33]:
# The first four entries of the solver solution are the optimised quaternion.
qtn_star = qtn_sol[:4]

In [38]:
# Partial derivatives of manip_constr at the solution:
# xgrads w.r.t. [pos, grasp1, grasp2] (concatenated), zgrad w.r.t. the quaternion.
_eval_point = (qtn_star, pos, grasp1, grasp2)
_grad_wrt_params = jax.grad(manip_constr, argnums=[1,2,3])
_grad_wrt_qtn = jax.grad(manip_constr, argnums=0)
xgrads = jnp.hstack(_grad_wrt_params(*_eval_point))
zgrad = _grad_wrt_qtn(*_eval_point)

In [48]:
# Hessian of manip_constr with respect to its first argument (the quaternion).
hess_fn = jax.hessian(manip_constr)
def grad_tx_fn(qtn_star, pos, grasp1, grasp2):
    """Mixed second derivative of ``manip_constr``.

    Differentiates the quaternion Jacobian with respect to ``pos``, ``grasp1``
    and ``grasp2`` and stacks the three blocks horizontally (the notebook
    shows a (4, 9) result).
    """
    jac_wrt_qtn = jax.jacfwd(manip_constr)
    mixed_jac = jax.jacfwd(jac_wrt_qtn, argnums=[1,2,3])
    blocks = mixed_jac(qtn_star, pos, grasp1, grasp2)
    return jnp.hstack(blocks)

In [83]:
# Sensitivity of the optimised quaternion z to the parameters x = [pos, grasp1, grasp2].
hess_zz = hess_fn(qtn_star, pos, grasp1, grasp2)
jac_zx = grad_tx_fn(qtn_star, pos, grasp1, grasp2)
# dz/dx from the linear system hess_zz @ dzdx = -jac_zx.
dzdx = jnp.linalg.solve(hess_zz, -jac_zx)
# Total derivative = direct partials + quaternion partial chained through dz/dx.
x_total_deriv = xgrads + zgrad @ dzdx

In [86]:
# NOTE(review): repeats the last statement of the previous cell.
x_total_deriv = xgrads + zgrad @ dzdx

Array([-3.4831464e-07, -4.0116720e-07, -1.6950071e-07, -1.4435500e-07,
       -1.2572855e-08, -3.8184226e-08,  1.8626451e-07, -9.5926225e-08,
       -2.0256266e-08], dtype=float32)

In [55]:
# Shape check: 4 quaternion components by 9 parameters (pos, grasp1, grasp2).
jac_zx.shape

(4, 9)