In [1]:
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *

In [2]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

#grasp net
grasp_net = GraspNet(32)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)

grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]
#manip net
manip_net = ManipNet(64)
manip_fn = lambda x: manip_net.apply(restored_manip["params"], x)[0]

In [3]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/


In [4]:
world.show_in_jupyter()

In [5]:
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link_name, link in hand_model.links.items():
    link.set_surface_points(10)



In [6]:
class PandaHand:
    def __init__(self, hand_model, name="hand"):
        self.model = hand_model
        self.robot = Robot(world.vis, name, hand_model, color="white", alpha=0.5, visualized_mesh="visual")
        self.hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
        self.hand_pc = self.robot.get_surface_points_fn(jnp.array([0.04, 0.04]))
    
    def get_bounding_box(self, name):
        fks = self.model.fk_fn(jnp.array([0.04, 0.04]))
        hand_points = []
        for i, link in enumerate(hand_model.links.values()):
            pts = jax.vmap(SE3(fks[i]).apply)(link.collision_mesh.vertices)
            hand_points.append(pts)
        hand_points = np.vstack(hand_points, dtype=float)
        min_points = hand_points.min(axis=0)
        max_points = hand_points.max(axis=0)
        extents = max_points - min_points
        center = (max_points + min_points) / 2
        box = Box(world.vis, name, extents, alpha=0.5, visualize=False)
        box.set_translate(center)
        return box
    
    def set_pose(self, pose):
        self.robot.set_pose(pose @ self.hand_pose_wrt_ee)

In [7]:
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)

In [22]:
pc_hands = PointCloud(world.vis, "pc_hands", np.zeros((10,3)), size=0.01, color="red")

In [8]:
hand1 = PandaHand(hand_model, "hand1")
hand2 = PandaHand(hand_model, "hand2")

In [37]:
del hand1, hand2

In [11]:
obj = Mesh(world.vis, "obj", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="white", alpha=0.5)

In [51]:
hand_pc = hand1.hand_pc
hand_bb = hand1.get_bounding_box("hand_bb")
def hands_collision(grasp1, grasp2):
    grasp_pose1 = grasp_reconst(grasp1) @ hand1.hand_pose_wrt_ee
    grasp_pose2 = grasp_reconst(grasp2) @ hand1.hand_pose_wrt_ee
    grasp1_wrt_2 = grasp_pose1.inverse() @ grasp_pose2
    points = jax.vmap(grasp1_wrt_2.apply)(hand_pc)
    return jax.vmap(hand_bb.sdf)(points).min()

In [264]:
grasp12 = np.random.normal(size=6)*0.3

print(grasp12)
print(hands_collision(grasp12[:3], grasp12[3:]))
hand1.set_pose(grasp_reconst(grasp12[:3]))
hand2.set_pose(grasp_reconst(grasp12[3:]))

[ 0.12983679 -0.15415074 -0.07238102  0.18586288 -0.34047881  0.11690921]
-0.017185275


In [64]:
jac_hands_col = jax.jacfwd(hands_collision, argnums=[0,1])

In [273]:
grads = jac_hands_col(grasp12[:3], grasp12[3:])
grasp12 += jnp.hstack(grads) * 0.1
print(hands_collision(grasp12[:3], grasp12[3:]))
hand1.set_pose(grasp_reconst(grasp12[:3]))
hand2.set_pose(grasp_reconst(grasp12[3:]))

0.0145794


In [11]:
hand_bb = hand.get_bounding_box("hand_bb")
hand_pc = hand.hand_pc

In [13]:
jax.vmap(hand_bb.sdf)(hand_pc)

Array([-7.99929909e-03, -3.69697646e-03, -4.14564321e-03, -1.71705130e-02,
       -6.71752403e-03, -2.89463834e-03, -1.11606643e-02, -8.74026679e-03,
       -3.25338007e-03, -7.06460746e-03, -2.14846060e-02, -1.51975619e-05,
       -2.13337895e-02, -2.60398295e-02, -2.11889800e-02, -2.12817453e-02,
       -9.90752690e-03, -2.17491519e-02, -2.11603846e-02, -2.41197553e-02,
       -6.41627776e-05, -2.27419790e-02, -2.28510369e-02, -2.17692498e-02,
       -2.12432910e-02, -2.22314205e-02, -1.85975116e-02, -2.12153438e-02,
       -1.75805669e-02, -1.26740765e-02], dtype=float32)

In [22]:
link = hand_model.links['panda_hand']

In [55]:
fks = hand_model.fk_fn(jnp.array([0.04, 0.04]))
hand_points = []
for i, link in enumerate(hand_model.links.values()):
    pts = jax.vmap(SE3(fks[i]).apply)(link.collision_mesh.vertices)
    hand_points.append(pts)

In [56]:
hand_points = np.vstack(hand_points, dtype=float)
min_points = hand_points.min(axis=0)
max_points = hand_points.max(axis=0)
extents = max_points - min_points
center = (max_points + min_points) / 2

In [52]:
del box

NameError: name 'box' is not defined

In [62]:
box = Box(world.vis, "palm", extents, alpha=0.5)
box.set_translate(center)