In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

import time

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7005/static/


In [3]:
PREDEFINED_ROBOTS

['gen3+hand_e', 'panda+panda_hand']

In [4]:
gen3 = get_predefined_robot(world.vis, "gen3+hand_e")
panda = get_predefined_robot(world.vis, "panda+panda_hand")

In [5]:
#load sdf meshes
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obstacle = Box(world.vis, "obstacle", [0.4, 0.2, 0.35], 'white', 0.5)

obj_start = Mesh(world.vis, "obj_start", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", 
                "./sdf_world/assets/object/mesh.obj",
                color="green", alpha=0.5)
obj_ho = Mesh(world.vis, "obj_ho", 
                "./sdf_world/assets/object/mesh.obj",
                color="yellow", alpha=0.5)
obj = Mesh(world.vis, "obj", 
                "./sdf_world/assets/object/mesh.obj",
                color="white", alpha=0.8)

In [6]:
def SE3_trans(xyz):
    return SE3.from_translation(jnp.array(xyz))

ydev = 0.4
gen3_base_pose = SE3_trans([0, -ydev, 0])
panda_base_pose = SE3_trans([0, ydev, 0])
gen3.set_base_pose(gen3_base_pose)
panda.set_base_pose(panda_base_pose)

table_start.set_translate([0.45, -ydev, 0.2/2])
table_goal.set_translate([0.45, ydev, 0.2/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.45, -0.4, obj_lengths[-1]/2+table_lengths[-1]])
trans_goal = jnp.array([0.45, 0.4, obj_lengths[-2]/2+table_lengths[-1]])
obstacle.set_translate([0.45, 0., 0.35/2])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)
obj.set_translate([0.45, -0.4, obj_lengths[-1]/2+table_lengths[-1]])
obj_ho.set_translate([0.3, 0., 0.6])

In [7]:
hande = Gripper(world.vis, "hande", gen3.gripper.model, 
                gen3.gripper.tool_pose_offset, max_width=0.05, scale=0.001)
panda_hand = Gripper(world.vis, "panda_hand", panda.gripper.model, 
                    panda.gripper.tool_pose_offset, max_width=0.08, is_rev_type=False)

In [11]:
from sdf_world.network import *
# learned model
# models 
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip_gen3 = orbax_checkpointer.restore("model/manip_net_gen3")
restored_manip_panda = orbax_checkpointer.restore("model/manip_net")

#grasp net
grasp_net = GraspNet(32, 6)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)

grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]

#functions
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)
def grasp_embedding(grasp_point):
    grasp = grasp_point * restored_grasp["scale_to_norm"]
    return grasp

#manip net: input: wxyzxyz
manip_net = ManipNet(64)
zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
def manip_gen3(grasp_pose_world:SE3, robot_pose:SE3):
    grasp_pose_robot = robot_pose.inverse() @ grasp_pose_world
    grasp_pose_robot_flip = grasp_pose_robot @ zflip
    inputs = jnp.vstack([
        grasp_pose_robot.parameters(), 
        grasp_pose_robot_flip.parameters() ])
    return jax.vmap(manip_net.apply, (None, 0))(
        restored_manip_gen3["params"], inputs).flatten().max()

def manip_panda(grasp_pose_world:SE3, robot_pose:SE3):
    grasp_pose_robot = robot_pose.inverse() @ grasp_pose_world
    grasp_pose_robot_flip = grasp_pose_robot @ zflip
    inputs = jnp.vstack([
        grasp_pose_robot.parameters(), 
        grasp_pose_robot_flip.parameters() ])
    return jax.vmap(manip_net.apply, (None, 0))(
        restored_manip_panda["params"], inputs).flatten().max()

In [23]:
frame1 = Frame(world.vis, "frame1")
frame2 = Frame(world.vis, "frame2")

In [32]:
#prob
g1 = np.random.uniform(-1, 1, size=3)
g2 = np.random.uniform(-1, 1, size=3)

hande.set_tool_pose(obj_ho.pose@grasp_reconst(g1))
panda_hand.set_tool_pose(obj_ho.pose@grasp_reconst(g2))

In [33]:
gen3_neutral = jnp.array([0,0,0.,np.pi/3,0,np.pi/2,0.2])
panda_neutral = (panda.arm.model.ub + panda.arm.model.lb)/2

In [36]:

gen3.set_joint_angles(q1)
panda.set_joint_angles(q2)

In [38]:
fk_ee_gen3 = gen3.get_fk_ee_fn()
fk_ee_panda = panda.get_fk_ee_fn()

In [60]:
#q1, q2, p
q1, q2 = gen3_neutral, panda_neutral
p = jnp.array([0.3, 0. , 0.6])

obj_frame = SE3.from_translation(p)
tcp_gen3 = gen3_base_pose.apply(fk_ee_gen3(q1)[-3:])
tcp_panda = panda_base_pose.apply(fk_ee_panda(q2)[-3:])
grasp_point1 = obj_frame.apply(grasp_reconst(g1).translation())
grasp_point2 = obj_frame.apply(grasp_reconst(g2).translation())
rot_points = jnp.vstack([grasp_point1, grasp_point2])
ref_points = jnp.vstack([tcp_gen3, tcp_panda])

In [77]:
def residual(ec, rot_points, ref_points):
    rot = SO3.exp(ec)
    d = jax.vmap(rot.apply)(rot_points) - ref_points
    return jnp.hstack(d)

vg_residual_fn = value_and_jacfwd(residual, argnums=0)

In [68]:
from typing import NamedTuple

class Carry(NamedTuple):
    rot_points: jnp.ndarray
    ref_points: jnp.ndarray
    ec: jnp.ndarray = jnp.zeros(3)
    d: jnp.ndarray = jnp.zeros(3)
    i: int = 0
    damping: float = 0.04
    threshold: float = 1e-4
    max_iter: int = 10

    def update(self, d):
        return Carry(
            self.rot_points, self.ref_points, 
            self.ec+d, d, self.i+1, self.damping)

def get_rot_body(carry:Carry):
    res, jac = vg_residual_fn(carry.ec, carry.rot_points, carry.ref_points)
    hess = jac.T@jac
    d = jnp.linalg.solve((hess+carry.damping*jnp.eye(3)), -jac.T@res)
    return carry.update(d)
    
def get_rot_cond(carry:Carry):
    return (jnp.linalg.norm(carry.d) > carry.threshold) | (carry.i < carry.max_iter)

In [69]:
carry = Carry(rot_points, ref_points)
result = jax.lax.while_loop(get_rot_cond, get_rot_body, carry)

In [74]:
pose_ho = SE3.from_rotation_and_translation(SO3.exp(result.ec), p)

In [75]:
obj_ho.set_pose(pose_ho)

In [76]:
hande.set_tool_pose(pose_ho@grasp_reconst(g1))
panda_hand.set_tool_pose(pose_ho@grasp_reconst(g2))

In [66]:
jax.grad(residual_sqrsum)(jnp.zeros(3), rot_points, ref_points)

Array([-0.19138832, -0.3921138 , -0.06205126], dtype=float32)

In [53]:
frame1.set_pose(obj_frame)

In [59]:
point.set_translate(grasp_point2)

In [49]:
point = Sphere(world.vis, "point", 0.02)
point.set_translate(tcp_panda)

In [133]:
def get_rot_ref_points(q1, q2, p_ho):
    obj_frame = SE3.from_translation(p_ho)
    tcp_gen3 = gen3_base_pose.apply(fk_ee_gen3(q1)[-3:])
    tcp_panda = panda_base_pose.apply(fk_ee_panda(q2)[-3:])
    grasp_point1 = obj_frame.apply(grasp_reconst(g1).translation())
    grasp_point2 = obj_frame.apply(grasp_reconst(g2).translation())
    rot_points = jnp.vstack([grasp_point1, grasp_point2])
    ref_points = jnp.vstack([tcp_gen3, tcp_panda])
    return rot_points, ref_points

def rot_objective(ec, param):
    q1, q2 = param[:7], param[7:14]
    p = param[-3:]
    rot_points, ref_points = get_rot_ref_points(q1, q2, p)
    res = residual(ec, rot_points, ref_points)
    return 0.5 * res.T@res

@jax.custom_jvp
def get_handover_rotation(q1, q2, p_ho):
    rot_points, ref_points = get_rot_ref_points(q1, q2, p_ho)    
    carry = Carry(rot_points, ref_points)
    result = jax.lax.while_loop(get_rot_cond, get_rot_body, carry)
    return result.ec

@get_handover_rotation.defjvp
def get_handover_rotation_jvp(primals, tangents):
    q1, q2, p_ho = primals
    #q1_dot, q2_dot, p_ho_dot = tangents
    param_dot = jnp.hstack(tangents)
    rot = get_handover_rotation(q1, q2, p_ho)
    param = jnp.hstack([q1, q2, p_ho])
    hess = jax.jacfwd(jax.jacfwd(rot_objective))(rot, param)
    jacxz = jax.jacfwd(jax.jacfwd(rot_objective, argnums=0), argnums=1)(rot, param)
    dzdx = jnp.linalg.solve(hess, -jacxz)
    return rot, dzdx @ param_dot
    

In [159]:
p = jnp.array([0.4, -0.1, 0.5])
ec = get_handover_rotation(q1, q2, p)
pose_ho = SE3.from_rotation_and_translation(SO3.exp(ec), p)
obj_ho.set_pose(pose_ho)
hande.set_tool_pose(pose_ho@grasp_reconst(g1))
panda_hand.set_tool_pose(pose_ho@grasp_reconst(g2))

: 

In [144]:
jac = jax.jit(jax.jacfwd(get_handover_rotation, argnums=[0,1,2]))

In [145]:
jac(q1, q2, p)

2023-08-23 18:02:19.855065: W external/xla/xla/service/gpu/ir_emitter_triton.cc:761] Shared memory size limit exceeded.


(Array([[ 1.3208008e-01,  7.1240234e-01,  1.4392090e-01,  2.9077148e-01,
          1.1291504e-01, -1.2249756e-01, -3.4706318e-09],
        [-3.4228516e-01,  6.8481445e-02, -3.3911133e-01, -3.9459229e-02,
         -2.4243164e-01, -7.2448730e-02,  2.0445441e-08],
        [-3.1219482e-02,  2.5585938e-01, -2.6565552e-02,  8.9477539e-02,
         -1.5602112e-02, -5.7434082e-02,  3.3578544e-09]], dtype=float32),
 Array([[-3.66821289e-02, -4.46777344e-02, -3.66821289e-02,
          2.46582031e-02, -1.05972290e-02, -5.84602356e-04,
         -2.23008101e-09],
        [ 3.59863281e-01,  2.63427734e-01,  3.59863281e-01,
         -1.52954102e-01,  1.04003906e-01, -2.63023376e-03,
          2.18860805e-08],
        [ 2.37426758e-01, -7.49023438e-01,  2.37426758e-01,
          3.68408203e-01,  6.86035156e-02, -4.58679199e-02,
          1.44354999e-08]], dtype=float32),
 Array([[ 0.75439453,  0.37329102, -0.45581055],
        [-0.37451172,  1.6113281 ,  0.36791992],
        [-1.4179688 , -0.74853516,

In [125]:
ec = result.ec

In [128]:
A = jax.jacfwd(jax.jacfwd(rot_objective))(ec, param)
B = jax.jacfwd(jax.jacfwd(rot_objective, argnums=0), argnums=1)(ec, param)

In [129]:
dzdx = jnp.linalg.solve(A, -B)

In [132]:
dzdx[:,:7] @ q1_dot + 

(3, 7)

In [130]:
dzdx.shape

(3, 17)

In [98]:
rot = SO3.exp(result.ec).parameters()

In [99]:
g_fn = jax.jacfwd(rot_objective)

In [100]:
param = jnp.hstack([q1, q2, p])

In [102]:
g_fn(rot, param)

Array([ 5.0952658e-06,  2.0228326e-06,  8.6799264e-07, -4.1760504e-06],      dtype=float32)

In [None]:
g_fn()

In [None]:
jax.jacfwd(rot_objective,)

In [90]:
g_fn(rot, q1, q2, p)

(Array([ 1.05992913e-01, -2.12560333e-02,  1.05240606e-01,  5.68179926e-03,
         7.53667131e-02,  1.65768452e-02, -6.12352480e-09], dtype=float32),
 Array([ 2.0381449e-01, -2.0628087e-02,  2.0381449e-01,  6.9511801e-02,
         5.8887810e-02,  4.6412937e-02,  1.2393445e-08], dtype=float32),
 Array([-0.12258818,  0.08616162, -0.09589572], dtype=float32))

In [91]:
dgdq1, dgdq2, dgdp = (result.T for result in jax.jacfwd(g_fn)(rot, q1, q2, p))

In [92]:
jax.jacfwd(g_fn,)

Array([[-3.53284292e-02,  1.62174702e-01, -3.23006138e-02,
         5.37582636e-02, -2.09804997e-02, -3.90634611e-02,
         3.03836067e-09],
       [-1.71000063e-01, -6.41315341e-01, -1.81325123e-01,
        -2.71607786e-01, -1.38769731e-01,  1.01346970e-01,
         6.17364027e-09],
       [ 6.20234907e-01, -2.53408492e-01,  6.12506151e-01,
         2.33344324e-02,  4.36021417e-01,  1.57656103e-01,
        -3.78247300e-08],
       [ 1.51959017e-01,  3.92617822e-01,  1.58029303e-01,
         1.75260052e-01,  1.18691415e-01, -5.39622232e-02,
        -6.54683374e-09]], dtype=float32)

In [None]:
def get_handover_rotation(q1, q2):
