In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from signed_distance import *
from nlp_builder import NLPBuilder
from nlp_solver import SQPSolver
from functools import partial

from pathlib import Path
from pybullet_suite import *
from panda_collision import *

pybullet build time: Dec 14 2022 00:46:04


In [2]:
# --- problem definition ---------------------------------------------------
# Obstacle (passed to Circle below) described by a 3-D centre and a radius.
obs_center = jnp.array([0.6, 0, 0.6])
obs_r = 0.2

safe_dist = 0.1  # margin handed to EnvSDF together with the obstacle
# joints per configuration, configurations (waypoints) per trajectory
dim_robot, num_steps = 7, 8

In [3]:
# Forward kinematics of the Panda arm, built from its URDF.
fk = get_fk_fn(PANDA_URDF)

# Collision meshes live next to the URDF; the two fingers share one mesh name.
robot_path = Path(PANDA_URDF).parent / "meshes/collision"
mesh_names = [
    'link0', 'link1', 'link2', 'link3', 'link4',
    'link5', 'link6', 'link7', 'hand', "finger", "finger",
]
# Point cloud with a fixed number of samples (10) per collision mesh.
robotpc = get_pointclouds(robot_path, mesh_names, 10)

Robot name: panda


concatenating texture: may result in visual artifacts


In [4]:
# --- trajectory layout helpers ---------------------------------------------
# A trajectory is stored as one flat vector: configurations of dim_robot joints
# concatenated one after another.
def to_mat(x):
    """Reshape a flat trajectory into a (steps, dim_robot) matrix."""
    return x.reshape(-1, dim_robot)

def to_vec(mat):
    """Flatten a (steps, dim_robot) matrix back into a flat trajectory."""
    return mat.flatten()

def at_timestep(i, x):
    """Return configuration ``i`` of the flat trajectory ``x``."""
    rows = to_mat(x)
    return rows[i]

def to_vel(x):
    """Differences between consecutive configurations, flattened."""
    rows = to_mat(x)
    deltas = rows[1:] - rows[:-1]
    return deltas.flatten()

In [5]:
# --- collision model --------------------------------------------------------
obs = Circle(obs_center, obs_r)
env = EnvSDF((obs,), safe_dist)

def assign_points(q):
    """Robot point cloud transformed by the link poses fk(q) of one configuration."""
    return robotpc.apply_transforms(fk(q))

def assign_points_path(x):
    """Point clouds of every configuration in trajectory ``x``, stacked row-wise."""
    per_step = jax.vmap(assign_points)(to_mat(x))
    return jnp.vstack(per_step)

# First / last configuration of a flat trajectory.
state_init = partial(at_timestep, 0)
state_goal = partial(at_timestep, -1)

def penetration(x):
    """Environment penetration of the whole trajectory's point cloud."""
    return env.penetrations(assign_points_path(x))

def min_dist_cost(x):
    """Joint-space path cost plus penetration.

    Returns the sum of squared differences between consecutive configurations
    of ``x`` plus ``penetration(x)``; the two terms are not weighted.
    """
    step_deltas = to_vel(x)
    path_term = step_deltas @ step_deltas
    return path_term + penetration(x)

def get_ee_point(q):
    """Last three entries (position) of the last pose returned by fk(q)."""
    return fk(q)[-1, -3:]

def get_ee_points(x):
    """End-effector positions for every configuration of trajectory ``x``."""
    return jax.vmap(get_ee_point)(to_mat(x))

def min_ee_dist_cost(x):
    """Squared Cartesian path length of the end effector along ``x``."""
    positions = get_ee_points(x)
    displacement = (positions[1:] - positions[:-1]).flatten()
    return displacement @ displacement

# Straight-line initial guess in joint space between two configurations that
# differ only in the first joint (-0.9 -> +0.9).
# TODO: joint limits (panda.joint_lower_limit / joint_upper_limit, tiled per
# step) are not applied to the trajectory yet.
q_init = jnp.array([-0.9, 0., 0., -1.7708, 0., 1.8675, 0.])
q_goal = jnp.array([0.9, 0., 0., -1.7708, 0., 1.8675, 0.])
# Use num_steps (not a hard-coded 8) so x0 always has the problem's number of
# configurations; to_mat reshapes it into (num_steps, dim_robot).
x0 = jnp.linspace(q_init, q_goal, num_steps).flatten()

In [66]:
def get_jac_axis(wxyz_xyz, point):
    """Cross product of a frame's z axis with ``point``.

    Only the rotation of ``wxyz_xyz`` is used; its translation is ignored.
    NOTE(review): for a revolute joint about that z axis this is the linear
    Jacobian column only when ``point`` is given relative to the joint origin
    (p - o). Passing absolute or link-local points yields a different quantity.
    """
    transform = SE3(wxyz_xyz).as_matrix()
    z_axis = transform[0:3, 2]
    return jnp.cross(z_axis, point)

# Batched over frames (axis 0 of wxyz_xyz); every frame shares the same point.
get_jacs = jax.vmap(get_jac_axis, in_axes=(0, None))

In [100]:
# Point cloud of 'link3' (index 3 in mesh_names; presumably pcs follows that order).
pc_link3 = robotpc.pcs[mesh_names.index('link3')]

In [107]:
# First sample of link3's cloud (reads the private _points attribute),
# presumably in link-local coordinates.
first_sample = pc_link3._points[0]
point = jnp.array(first_sample)

In [129]:
# Penetration gradient at `point` projected through a (3, 7) point Jacobian.
# NOTE(review): `jacobian` is only assigned in a later cell (In [128], via
# get_fk_and_jac); this cell fails on a fresh kernel under Restart & Run All.
jax.grad(env.penetration)(point) @ jacobian

Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [130]:
def get_penetration_value_and_grad(env, point, link, q):
    """Penetration of one link-attached point and its gradient w.r.t. the joints.

    Args:
        env: environment whose ``penetration`` maps a single point to a scalar.
        point: (3,) point attached to link ``link``, e.g. a sample of
            ``robotpc.pcs[link]`` (presumably link-local coordinates).
        link: row of ``fk(q)`` / entry of ``mesh_names`` that owns ``point``.
        q: joint configuration of length dim_robot.

    Returns:
        ``(value, grad)``. ``value`` is ``env.penetration`` at the transformed
        point. ``grad`` is its (7,) gradient w.r.t. ``q``, obtained by the
        chain rule dC/dq = dC/dp @ J_p(q).
    """
    frames = fk(q)
    # NOTE(review): fk(q) rows appear to be ordered like mesh_names (in the
    # fk(0) output, row 3 sits at z=0.649 = link3). The point is therefore
    # placed with frames[link]. The previous frames[1:8][link] picked the
    # next link's pose.
    fk_point = SE3(frames[link]).apply(point)

    # Rows 1..7 are the frames whose z axis / origin give the axis and pivot
    # of joints 1..7 (see the get_rot_axes output).
    joint_frames = frames[1:8]
    joint_axes = jax.vmap(lambda f: SE3(f).as_matrix()[:3, 2])(joint_frames)  # (7, 3)
    joint_origins = joint_frames[:, 4:]  # (7, 3) translation part of wxyz_xyz
    # Revolute joint column: z_j x (p - o_j), with p the transformed point.
    jac = jnp.cross(joint_axes, fk_point - joint_origins).T  # (3, 7)
    # Only joints 0..link-1 move the frame of link `link`.
    moves_link = jnp.arange(joint_frames.shape[0]) < link
    jac_point = jnp.where(moves_link[None, :], jac, 0.)

    value, penet_grad = jax.value_and_grad(env.penetration)(fk_point)
    return value, penet_grad @ jac_point

In [133]:
# q must be a full configuration of dim_robot joints. A shorter vector such as
# jnp.zeros(3) can run silently, because JAX clamps out-of-range indices
# instead of raising.
jax.jit(get_penetration_value_and_grad)(env, point, 3, jnp.zeros(dim_robot))

(Array(0., dtype=float32, weak_type=True),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32))

In [119]:
# Step-by-step replica of the penetration-gradient computation for link 3 at q = 0.
link, q = 3, jnp.zeros(7)
frames = fk(q)
frame_vecs = frames[1:8]  # frames giving axis / pivot of joints 1..7
# NOTE(review): fk(q) rows appear to follow mesh_names order (row 3 = link3 at
# z=0.649 in the fk(0) output). Link `link`'s point therefore uses frames[link];
# frame_vecs[link] would be the following link's frame.
fk_point = SE3(frames[link]).apply(point)
joint_axes = jax.vmap(lambda f: SE3(f).as_matrix()[:3, 2])(frame_vecs)
# Revolute joint column z_j x (p - o_j), with p the transformed point.
jac1 = jnp.cross(joint_axes, fk_point - frame_vecs[:, 4:]).T  # (3, 7)
boolidx = jnp.broadcast_to(jnp.arange(frame_vecs.shape[0]) < link, jac1.shape)
jac_point = jnp.where(boolidx, jac1, 0.)

In [128]:
# NOTE(review): get_fk_and_jac is not defined in any visible cell (likely a
# deleted cell); this line raises NameError on a fresh kernel.
_, jacobian = jax.jit(get_fk_and_jac)(point, 3, jnp.zeros(7))

In [87]:
# NOTE(review): get_jacobian is not defined in any visible cell; this raises
# NameError on a fresh kernel.
jax.jit(get_jacobian)(jnp.ones(3), 1)

Array([[-1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [63]:
# An assignment displays nothing, so the array printed below appears to be
# stale output from an earlier version of this cell.
link = 1

Array([[-1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

In [35]:
# NOTE(review): rot_axis exists only as a local inside get_jac_axis / pos_jac;
# at module level it is undefined unless set by a deleted cell.
rot_axis

Array([0., 0., 1.], dtype=float32)

In [33]:
# NOTE(review): rebinds `point`, replacing the link3 sample assigned earlier in
# the file; under Restart & Run All every later use of `point` sees ones(3).
point = jnp.ones(3)

In [None]:
# NOTE(review): unfinished scratch: jnp.cross needs two vector arguments and
# get_rot_axis is not defined in any visible cell.
jnp.cross(get_rot_axis)

In [31]:
# NOTE(review): get_rot_axes and joint_frames are not defined in any visible
# cell; this fails on a fresh kernel.
get_rot_axes(joint_frames)

Array([[ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00, -2.5376105e-08,  1.0000000e+00],
       [ 0.0000000e+00, -1.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  6.8705230e-10,  1.0000000e+00],
       [ 0.0000000e+00, -1.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00, -6.8705230e-10, -1.0000002e+00]], dtype=float32)

In [18]:
# Link poses (wxyz_xyz rows) at the zero configuration.
jax.jit(fk)(jnp.zeros((dim_robot,)))

Array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         3.33000004e-01],
       [ 7.07106769e-01, -7.07106769e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         3.33000004e-01],
       [ 9.99999940e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  5.25585797e-09,
         6.48999989e-01],
       [ 7.07106709e-01,  7.07106709e-01,  0.00000000e+00,
         0.00000000e+00,  8.24999884e-02,  5.25585797e-09,
         6.48999989e-01],
       [ 9.99999881e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  7.12799242e-09,
         1.03299999e+00],
       [ 7.07106709e-01,  7.07106709e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  7.12799242e-09,
         1.0329999

In [11]:
# Forward-mode Jacobian of all pose entries w.r.t. q at the zero configuration.
jax.jit(jax.jacfwd(fk))(jnp.zeros((dim_robot,)))

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00

In [12]:
def fk_pos(q):
    """Translation part (entries 4:) of every fk(q) row, flattened to 1-D."""
    return fk(q)[:, 4:].flatten()

In [16]:
# Reverse-mode Jacobian of the stacked frame positions at q = 0.
jax.jit(jax.jacrev(fk_pos))(jnp.zeros((dim_robot,)))

(36, 7)

In [16]:
# Forward-mode Jacobian of the stacked frame positions at q = 0.
jax.jit(jax.jacfwd(fk_pos))(jnp.zeros((dim_robot,)))

Array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000

In [18]:
# Reverse-over-reverse second derivative of fk_pos. XLA reported a very slow
# compile here (see output). Forward-over-reverse (jacfwd(jacrev), i.e.
# jax.hessian) is the usual cheaper composition.
jax.jit(jax.jacrev(jax.jacrev(fk_pos)))(jnp.zeros(7))

2023-05-20 16:41:35.085977: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit__lambda_] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


In [9]:
# Second derivative of the frame positions; jax.hessian is jacfwd(jacrev(f)).
jax.jit(jax.hessian(fk_pos))(jnp.zeros((dim_robot,)))

Array([[[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         ...,
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

        [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e

In [11]:
# (meshes, points per mesh, 3)
jnp.shape(assign_points(jnp.zeros((dim_robot,))))

(11, 10, 3)

In [61]:
# NOTE(review): redefines fk_pos from an earlier cell; this version keeps one
# row per frame instead of flattening.
def fk_pos(q):
    """Last three entries (translation) of every fk(q) row, shape (frames, 3)."""
    return fk(q)[:, -3:]

# Compiled position Jacobian, shape (frames, 3, dim_robot).
jac_pos = jax.jit(jax.jacrev(fk_pos))

In [73]:
# Gradient of the sum of all fk entries (quick autodiff sanity check).
jax.jit(jax.grad(lambda q: jnp.sum(fk(q))))(jnp.zeros((dim_robot,)))

Array([ 5.4017887 ,  3.8775601 ,  4.901788  , -0.48645374,  3.6121817 ,
        3.2790458 , -1.9650754 ], dtype=float32)

In [62]:
jnp.shape(jac_pos(jnp.zeros((dim_robot,))))

(12, 3, 7)

In [67]:
# Gradient of the trajectory penetration at the initial guess.
jax.jit(jax.grad(penetration, argnums=0))(x0)

Array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.32375202, -0.04457903,  0.32375202,  0.06050486,
        0.02544595,  0.01748032, -0.00756898, -0.3881803 , -0.11117046,
       -0.38818032,  0.11441814, -0.03633801,  0.02956284,  0.01050684,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ], dtype=float32)

In [14]:
jnp.shape(fk(jnp.zeros((dim_robot,))))

(12, 7)

In [28]:
# Gradient of env.penetration w.r.t. its (point) argument.
c_grad = jax.grad(env.penetration, argnums=0)

In [44]:
# Position Jacobians without the last fk row, transposed to (frames, dim_robot, 3).
JT = jnp.swapaxes(jac_pos(jnp.zeros((dim_robot,)))[:-1], 1, 2)

In [51]:
# Recompute the robot points here instead of relying on `points`, which is only
# assigned in a later cell; this keeps the cell valid under Restart & Run All.
robot_points = assign_points(jnp.zeros(dim_robot))  # (meshes, points, 3)
# Per-point penetration gradients regrouped per mesh, shape (meshes, 3, points).
# Use the mesh count from the data rather than a hard-coded 11.
gC = (
    jax.vmap(c_grad)(jnp.vstack(robot_points))
    .reshape(robot_points.shape[0], -1, 3)
    .swapaxes(1, 2)
)

In [56]:
# Batched (dim_robot x 3) @ (3 x points) per mesh, then points before joints.
jnp.swapaxes(JT @ gC, 1, 2).shape

(11, 10, 7)

In [52]:
tuple(arr.shape for arr in (JT, gC))

((11, 7, 3), (11, 3, 10))

In [25]:
# Robot point cloud at the zero configuration, (meshes, points, 3).
points = assign_points(jnp.zeros((dim_robot,)))

In [32]:
# Points of the first mesh ('link0').
points[0, ...]

Array([[-4.1505657e-02,  2.8067309e-02,  1.4000000e-01],
       [-9.6079096e-02, -6.3804947e-02, -9.8495675e-06],
       [ 6.6316165e-02, -5.8328751e-03,  1.2496690e-03],
       [-1.1220833e-02, -5.6149375e-02,  1.3390321e-01],
       [-3.5568796e-02,  4.8435576e-02,  1.3071261e-01],
       [-2.0031877e-02,  8.6079203e-02,  2.2062847e-02],
       [ 4.2301998e-03,  2.6623788e-03,  7.7071461e-05],
       [-3.0249219e-02, -4.6237665e-03,  1.4000012e-01],
       [-4.1346554e-02, -8.6493276e-02,  2.0070581e-03],
       [-6.9459669e-02, -5.0699402e-02,  9.5153749e-02]], dtype=float32)

In [31]:
# A gradient function must be called with a point. points[0] (the first mesh's
# samples) gives the (10, 3) result shown below, matching the per-mesh output
# of jax.vmap(c_grad)(points).
c_grad(points[0])

Array([[-0.,  0., -0.],
       [-0., -0., -0.],
       [-0., -0., -0.],
       [-0., -0., -0.],
       [-0.,  0., -0.],
       [-0.,  0., -0.],
       [-0.,  0., -0.],
       [-0., -0., -0.],
       [-0., -0., -0.],
       [-0., -0., -0.]], dtype=float32)

In [30]:
jnp.shape(jax.vmap(c_grad, in_axes=0)(points))

(11, 10, 3)

In [None]:
# NOTE(review): unfinished scratch (never executed): `jnp.zer` is a typo for
# jnp.zeros and `x` is undefined here.
jac_pos(jnp.zer) @ c_grad(x)

In [119]:
# Hessian of the penetration at a fixed point. Evaluated directly: a bare
# `lambda x: ...` only creates a function object and never computes the matrix.
jax.jacrev(c_grad)(jnp.ones(3))

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [123]:
# Same as robotpc.apply_transforms(fk(q)) at q = 0.
assign_points(jnp.zeros((dim_robot,)))

Array([[-8.34387019e-02,  2.76514888e-02,  1.01273052e-01],
       [-2.83865705e-02,  3.80076133e-02, -1.13298056e-05],
       [-2.47727372e-02,  9.51951742e-03,  1.40000001e-01],
       [ 2.42441520e-02, -6.90249503e-02,  4.49906401e-02],
       [-5.73467277e-02,  6.88171461e-02,  5.63896000e-02],
       [-1.38195395e-01,  5.38827032e-02,  1.48725705e-02],
       [-8.21067691e-02, -6.06489927e-02,  6.86731189e-02],
       [ 4.53395545e-02,  2.80613713e-02,  8.55118968e-04],
       [-4.97161262e-02, -7.96922147e-02, -2.08754755e-05],
       [ 4.41156104e-02,  3.81585285e-02,  1.10976882e-01],
       [ 1.58967935e-02,  1.98786128e-02,  3.17514211e-01],
       [-5.39738238e-02, -4.15718257e-02,  3.43412369e-01],
       [ 3.21753621e-02,  4.45433147e-02,  1.73581332e-01],
       [-1.44155864e-02,  3.20091727e-03,  3.73275369e-01],
       [-2.64641605e-02, -9.98812616e-02,  2.65972853e-01],
       [ 5.21448143e-02, -7.15728551e-02,  3.44092339e-01],
       [-3.91985588e-02, -1.12060934e-01

In [None]:
# NOTE(review): unfinished scratch: jax.jacrev requires the function to
# differentiate; this raises TypeError.
jax.jacrev()

In [59]:
# assign_points_path needs a flat trajectory; use the initial guess x0.
assign_points_path(x0)

In [86]:
# Shape of the fk rows after dropping the base (row 0).
jnp.shape(fk(jnp.zeros((dim_robot,)))[1:])

(11, 7)

In [106]:
def pos_jac(wxyz_xyz):
    """Cross product of a frame's z axis with that frame's own position.

    NOTE(review): this is z x o, the velocity of the frame origin for a unit
    rotation about an axis through the world origin. It is not the joint
    Jacobian of some other point.
    """
    quat_wxyz = wxyz_xyz[:4]
    position = wxyz_xyz[4:]
    z_axis = SO3(quat_wxyz).as_matrix()[:, 2]
    return jnp.cross(z_axis, position)

def jac_fn(q):
    """pos_jac evaluated for every row of fk(q), shape (frames, 3)."""
    return jax.vmap(pos_jac)(fk(q))

In [101]:
# Position Jacobian (3 x dim_robot) of fk row 8 at q = 0.
jac_pos(jnp.zeros((dim_robot,)))[8]

Array([[ 1.02421183e-10,  5.92999995e-01,  5.15343634e-09,
        -2.76999980e-01, -1.34298472e-09,  1.06999971e-01,
         0.00000000e+00],
       [ 8.79999921e-02, -1.39878353e-09,  8.79999846e-02,
         1.01636033e-10,  8.79999846e-02, -1.38241552e-09,
         0.00000000e+00],
       [ 1.39878353e-09, -8.79999921e-02,  1.00612844e-10,
         5.49999438e-03, -1.39878331e-09,  8.79999846e-02,
         0.00000000e+00]], dtype=float32)

In [109]:
jnp.transpose(jac_fn(jnp.zeros((dim_robot,)))).shape

(3, 12)

In [81]:
# Rotations (wxyz quaternion columns) of every fk row at q = 0.
SO3_batch(fk(jnp.zeros((dim_robot,)))[:, :4])

SO3(wxyz=[[ 1.       0.       0.       0.     ]
 [ 1.       0.       0.       0.     ]
 [ 0.70711 -0.70711  0.       0.     ]
 [ 1.       0.       0.       0.     ]
 [ 0.70711  0.70711  0.       0.     ]
 [ 1.      -0.       0.       0.     ]
 [ 0.70711  0.70711  0.       0.     ]
 [ 0.       1.       0.       0.     ]
 [ 0.       0.92388  0.38268 -0.     ]
 [ 0.       0.92388  0.38268 -0.     ]
 [ 0.       0.38268 -0.92388  0.     ]
 [ 0.       0.92388  0.38268 -0.     ]])

In [50]:
# NOTE(review): `fn` is not defined in any visible cell; this fails on a fresh
# kernel.
jax.jit(jax.hessian(fn))(jnp.ones(7))

Array([[-5.9313650e+00,  8.1971443e-01,  9.9841571e-01,  1.5811131e+00,
        -1.6894269e+00,  5.8244032e-01,  1.9261215e+00],
       [ 8.1971431e-01, -8.7663193e+00, -2.3235323e+00,  4.0758805e+00,
        -6.3946247e-03,  1.6378136e+00, -3.6888069e-01],
       [ 9.9841523e-01, -2.3235321e+00, -1.9082432e+00, -4.2301780e-01,
        -5.5890155e-01, -3.0133522e-01,  7.1926659e-01],
       [ 1.5811131e+00,  4.0758810e+00, -4.2301774e-01, -4.0118780e+00,
         1.3909061e+00, -1.6301670e+00, -1.3381933e+00],
       [-1.6894267e+00, -6.3948333e-03, -5.5890143e-01,  1.3909063e+00,
        -1.0470909e+00, -9.1984290e-01,  1.4805307e+00],
       [ 5.8244038e-01,  1.6378136e+00, -3.0133477e-01, -1.6301670e+00,
        -9.1984314e-01, -6.8512458e-01, -8.0369651e-01],
       [ 1.9261215e+00, -3.6888057e-01,  7.1926647e-01, -1.3381932e+00,
         1.4805307e+00, -8.0369651e-01,  1.0402023e-01]], dtype=float32)

In [14]:
# Hessian of the joint-space cost at the initial guess.
cost_hessian = jax.jit(jax.hessian(min_dist_cost))
cost_hessian(x0)