In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import cyipopt
from pybullet_suite import *
from panda_collision import *
from signed_distance import *
from pathlib import Path

pybullet build time: May 20 2022 19:44:17


In [2]:
mesh_path = Path(PANDA_URDF).parent/"meshes/collision"
link_names = link_names = [
    'link1', 'link2', 'link3', 
    'link4', 'link5', 'link6', 'link7', 'hand', 
    'finger', 'finger'
]

In [114]:
fk = get_fk_fn(PANDA_URDF)
pc = get_pointclouds(mesh_path, link_names)

only got 89/100 samples!


Robot name: panda


In [115]:
bw = BulletWorld(gui=True)
sm = BulletSceneMaker(bw)

pandas = [bw.load_robot(Panda, f"panda{i}") for i in range(8)]

error: Only one local in-process GUI/GUI_SERVER connection allowed. Use DIRECT connection mode or start a separate GUI physics server (ExampleBrowser, App_SharedMemoryPhysics_GUI, App_SharedMemoryPhysics_VR) and connect over SHARED_MEMORY, UDP or TCP instead.

In [5]:
obs_center = jnp.array([0.6, 0, 0.6])
obs_r = 0.2

In [6]:
obs = sm.create_sphere("obstacle", obs_r, 0.1, Pose.identity(), [1,0,0,0.3])
obs.set_base_pose(Pose(trans=np.asarray(obs_center)))

In [7]:
# env
offset = SE3.from_translation(jnp.array([0., 0, -1]))
half_extents = jnp.array([10., 10, 1])
ground = box(offset, half_extents)
circ = circle(obs_center, obs_r)
env = EnvSDF((circ, ground))

In [40]:
def robot_penet(q):
    tfs = fk(q)
    points = pc.apply_transforms(tfs)
    return env.penetration_sum(points, pc.num_points, 0.02)
def path_penet(qs):
    q_batch = qs.reshape(-1, 7)    
    robot_penet_path = jax.vmap(robot_penet)(q_batch).sum()
    return robot_penet_path
#robot_penet_path = jax.vmap(robot_penet)
robot_penet_path = jax.jit(path_penet)

In [82]:
# util functions
dim_robot = 7
num_steps = 8
dim = dim_robot*num_steps
def to_mat(x):
    return x.reshape(-1, dim_robot)
def at_timestep(i, x):
    return to_mat(x)[i]
def to_vel(x):
    return (to_mat(x)[1:] - to_mat(x)[:-1]).flatten()

In [112]:
from functools import partial

# functions
def min_dist_cost(x):
    v = to_vel(x)
    return v @ v
state_init = partial(at_timestep, 0)
state_goal = partial(at_timestep,-1)

In [None]:
q_init = jnp.array([-0.9   ,  0.    ,  0.    , -1.7708,  0.    ,  1.8675,  0.    ])
q_goal = jnp.array([0.9   ,  0.    ,  0.    , -1.7708,  0.    ,  1.8675,  0.    ])
qs = jnp.linspace(q_init, q_goal, 8).flatten()


In [None]:
g, cl, cu = [], [], []
f = min_dist_cost
g.append(state_init)
cl.extend(q_init)
cu.extend(q_init)
g.append(state_goal)
cl.extend(q_goal)
cu.extend(q_goal)
g.append(ssdf_obs_path)
cl.extend(jnp.zeros(num_steps))
cu.extend(jnp.full(num_steps,jnp.inf))

In [37]:
def show_path(qs):
    qs = qs.reshape(-1, 7)
    for i, q in enumerate(qs):
        pandas[i].set_joint_angles(q)
show_path(qs)

In [110]:
def total_cost(qs):
    return min_dist_cost(qs) + 0.0001 * path_penet(qs)

In [111]:
for i in range(10):
    qs_grad = jax.grad(total_cost)(qs)
    if jnp.abs(qs_grad).sum() < 1e-3:
        break
    grad_desc_step = -0.2 * qs_grad
    qs += grad_desc_step
    print(total_cost(qs))
    show_path(qs)

0.4022482
0.35180604
0.31216303
0.27848944
0.24921426
0.22361854
0.20120019
0.18155658
0.16436628
0.14938079


In [71]:
%timeit jax.jit(path_penet)(qs)

273 µs ± 39.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
