In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/


In [5]:
p1 = jnp.array([0,0,0.])
capsule = Capsule(world.vis, "cap", p1, jnp.array([0, 0, 0.2]), 0.05, alpha=0.5)

In [4]:
del capsule

In [4]:
Cylinder(world.vis, "cyl", 0.3, 0.01)

<sdf_world.sdf_world.Cylinder at 0x7fb0fdec4280>

In [4]:
def point_to_T(p):
    return np.array(SE3.from_translation(p).as_matrix())

In [4]:
point_to_T(jnp.ones(3))

array([[1., 0., 0., 1.],
       [0., 1., 0., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 1.]], dtype=float32)

In [5]:
capsule.handle

NameError: name 'capsule' is not defined

In [8]:
p1 = jnp.array([0,0,0.])
p2 = jnp.array([0, 0, 0.2])

In [None]:
cylinder = Cylinder(world.vis, "cyl", 0.5, 0.05)

In [18]:
normalize = lambda v: v/jnp.linalg.norm(v)
zaxis = normalize(p2 - p1)
xaxis = jnp.array([1, 0, 0.1])  #random
yaxis = normalize(jnp.cross(zaxis, xaxis))
xaxis = normalize(jnp.cross(yaxis, zaxis))
R = jnp.vstack([xaxis, yaxis, zaxis]).T
center = (p1 + p2) / 2
cylinder_pose = SE3.from_rotation_and_translation(SO3.from_matrix(R), center)


SE3(wxyz=[1. 0. 0. 0.], xyz=[0.         0.         0.09999999])

In [17]:
SE3.from_translation()

Array([0., 0., 1.], dtype=float32)

In [10]:
center

Array([0. , 0. , 0.1], dtype=float32)

In [6]:
del capsule

In [16]:
p_circle = fibonacci_sphere(100) * 0.05
p_center = jnp.zeros(3)

In [29]:
p = jnp.array([0, 1, 0.])

In [41]:
def signed_distance(p):
    distances = jax.vmap(safe_2norm)(p_circle - p)
    p_surface_idx = distances.argmin()
    p_surface = p_circle[p_surface_idx]
    sign = jnp.sign((p - p_surface) @ (p_surface - p_center))
    return sign*distances[p_surface_idx]
signed_distance_j = jax.jit(signed_distance)
signed_distance_batch = jax.jit(signed_distance)

In [33]:
signed_distance_vg = jax.jit(jax.value_and_grad(signed_distance))

In [46]:
%timeit jax.vmap(signed_distance_batch)(jnp.vstack([p_obj, p_obj, p_obj]))

1.25 ms ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Array([[1.        , 1.05      , 1.        ],
       [0.99262667, 1.0489899 , 1.0067545 ],
       [1.00123   , 1.0479798 , 0.9859848 ],
       [1.0104299 , 1.0469697 , 1.0136039 ],
       [0.9806105 , 1.0459596 , 0.9965702 ],
       [1.018477  , 1.0449495 , 0.9882465 ],
       [0.9938056 , 1.0439394 , 1.0230426 ],
       [0.98818547, 1.0429293 , 0.9772515 ],
       [1.0256002 , 1.0419192 , 1.0093493 ],
       [0.9734269 , 1.040909  , 1.0109688 ],
       [1.0127724 , 1.039899  , 0.9727065 ],
       [1.0094054 , 1.0388889 , 1.0299865 ],
       [0.97176194, 1.0378788 , 0.9836352 ],
       [1.0329865 , 1.0368687 , 0.99274826],
       [0.9799594 , 1.0358586 , 1.0285051 ],
       [0.9953926 , 1.0348485 , 0.96444225],
       [1.0281464 , 1.0338384 , 1.0237223 ],
       [0.9623187 , 1.0328283 , 1.0015578 ],
       [1.0273395 , 1.0318182 , 0.9727942 ],
       [0.9981805 , 1.0308081 , 1.039339  ],
       [0.9742754 , 1.029798  , 0.96917266],
       [1.0405159 , 1.0287879 , 1.0054519 ],
       [0.

In [39]:
p_obj = p_circle + 1.

In [None]:
sign @ 

In [12]:
p_circle

Array([0.95      , 0.9510626 , 0.9521241 , 0.9531844 , 0.95424354,
       0.9553016 , 0.95635825, 0.9574139 , 0.9584684 , 0.95952165,
       0.9605738 , 0.9616248 , 0.96267456, 0.96372324, 0.96477085,
       0.9658172 , 0.96686256, 0.9679068 , 0.9689498 , 0.9699917 ,
       0.97103244, 0.9720722 , 0.9731106 , 0.97414815, 0.97518456,
       0.9762198 , 0.9772539 , 0.97828704, 0.979319  , 0.9803499 ,
       0.9813797 , 0.9824084 , 0.98343605, 0.9844626 , 0.9854881 ,
       0.98651254, 0.98753595, 0.98855823, 0.9895795 , 0.99059975,
       0.9916189 , 0.99263704, 0.99365413, 0.99467015, 0.99568516,
       0.9966991 , 0.9977121 , 0.9987239 , 0.9997348 , 1.0007446 ,
       1.0017534 , 1.0027612 , 1.0037681 , 1.004774  , 1.0057787 ,
       1.0067825 , 1.0077854 , 1.008787  , 1.0097878 , 1.0107877 ,
       1.0117866 , 1.0127844 , 1.0137813 , 1.0147772 , 1.015772  ,
       1.016766  , 1.0177588 , 1.0187509 , 1.019742  , 1.020732  ,
       1.0217212 , 1.0227093 , 1.0236965 , 1.0246828 , 1.02566