In [1]:
import meshcat
import meshcat.geometry as g
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
# from sdf_world.robots import *

In [2]:
# Create the SDF world and launch its visualizer inside the notebook; the
# cell output below prints the URL where the viewer can also be opened.
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [3]:
# Scene primitives, both drawn semi-transparent:
#   - "box":    cube with edge lengths 0.5 (moved to y = -0.6 in a later cell)
#   - "sphere": radius 0.2, shifted to y = +0.6
box = Box(
    world.vis,
    "box",
    lengths=[0.5, 0.5, 0.5],
    color="green",
    alpha=0.5,
)
sphere = Sphere(
    world.vis,
    "sphere",
    r=0.2,
    color="blue",
    alpha=0.5,
)
sphere.set_translate([0.0, 0.6, 0])

In [4]:
# Place the box on the opposite side of the origin from the sphere (y = -0.6).
box.set_translate([0.0, -0.6, 0])

In [5]:
import trimesh
def farthest_point_sampling(points, num_samples, rng=None):
    """Greedy farthest-point sampling (FPS) of a point set.

    Starts from one randomly chosen point, then repeatedly adds the point whose
    Euclidean distance to the nearest already-selected point is largest.

    Args:
        points: array-like of shape (N, D). Any dimensionality D is accepted.
        num_samples: number of points to return. Values <= 0 return an empty
            (0, D) array. If num_samples > N, once every distinct point has
            been chosen the remaining picks repeat already-selected points.
        rng: optional ``numpy.random.Generator`` used for the initial pick.
            When None (default) the global ``np.random`` state is used.

    Returns:
        float64 array of shape (num_samples, D) holding the selected points.

    Raises:
        ValueError: if ``points`` is not 2-D, or is empty while
            ``num_samples > 0``.
    """
    points = np.asarray(points)
    if points.ndim != 2:
        raise ValueError(f"points must have shape (N, D); got {points.shape}")
    n_points, dim = points.shape
    if num_samples <= 0:
        return np.zeros((0, dim))
    if n_points == 0:
        raise ValueError("cannot sample from an empty point set")

    # Keep the legacy global-RNG call when no generator is supplied so existing
    # seeding via np.random.seed(...) still controls the start point.
    start = rng.integers(n_points) if rng is not None else np.random.randint(n_points)

    farthest_points = np.zeros((num_samples, dim))
    farthest_points[0] = points[start]
    # distances[j] = distance from points[j] to its nearest selected point so far.
    distances = np.full(n_points, np.inf)
    for i in range(1, num_samples):
        distances = np.minimum(
            distances, np.linalg.norm(points - farthest_points[i - 1], axis=1)
        )
        farthest_points[i] = points[np.argmax(distances)]
    return farthest_points


In [6]:
# Triangle-mesh sphere used (in the visible cells) only to sample surface
# points. Its radius duplicates the rendered `sphere` (r=0.2) — keep in sync.
mesh_sphere = trimesh.primitives.Sphere(radius=0.2)

In [7]:
# Remove any existing "pointcloud" node from the visualizer before it is
# recreated by the next cell.
world.vis["pointcloud"].delete()

In [8]:
# Sample the sphere surface densely, then thin it to well-spread keypoints with
# farthest-point sampling (oversampling gives FPS more candidates to pick from).
NUM_KEYPOINTS = 20
OVERSAMPLE_FACTOR = 5

# Seed the global NumPy RNG so the point set (and every output derived from it)
# is reproducible on Restart & Run All. farthest_point_sampling draws its start
# point from np.random; presumably trimesh's sample() also falls back to the
# global RNG when no seed is given — verify for the installed trimesh version.
np.random.seed(0)
data = mesh_sphere.sample(NUM_KEYPOINTS * OVERSAMPLE_FACTOR)
sampled_points = farthest_point_sampling(data, NUM_KEYPOINTS)

# meshcat expects positions/colors as (3, N) arrays, hence the transposes.
# Assumes Colors.read(..., return_rgb=True) yields 3 components — TODO confirm.
points = np.asarray(sampled_points, dtype=np.float64).T
colors = np.tile(Colors.read("red",return_rgb=True), points.shape[1]).reshape(-1, 3).T
world.vis["pointcloud"].set_object(
    g.PointCloud(position=points, color=colors, size=0.03)
)

In [9]:
# Pure translation as a homogeneous 4x4 matrix (identity rotation).
xyz = np.array([0, -0.2, 0])
T = np.eye(4)
T[:3, 3] = xyz
world.vis["pointcloud"].set_transform(T)

In [10]:
# Map each sampled point through the same pose T given to the point-cloud
# visual, batching over the leading (point) axis.
col_points = jax.vmap(SE3.from_matrix(T).apply, in_axes=0)(sampled_points)

In [12]:
# Per-point penetration of the transformed samples into the box, with the
# second argument held constant across the batch.
# NOTE(review): the meaning of 0.001 is not visible here (margin? tolerance?)
# — consider naming it once confirmed.
jax.vmap(lambda p: box.penetration(p, 0.001))(col_points)

Array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.00656613, 0.        , 0.027682  ,
       0.0156482 , 0.        , 0.        , 0.        , 0.        ],      dtype=float32)

In [16]:
from functools import partial
# Bind the box pose and half extents so `distance` is called with just the
# query point (as done via jax.vmap in the next cell).
distance = partial(
    box.distance,
    box_pose=box.pose,
    half_extents=np.asarray(box.lengths) / 2,
)

In [18]:
# Distance from every transformed sample point to the box (the output shows
# negative values, presumably meaning the point lies inside the box).
# NOTE(review): in the saved outputs, the negative-distance indices (4, 6, 12)
# do not match the nonzero-penetration indices (12, 14, 15) of the penetration
# cell; execution counts skip 13-15 and 17, so outputs may be stale — re-run
# from a fresh kernel and compare.
jax.vmap(distance, in_axes=0)(col_points)

Array([ 0.0420793 ,  0.22093514,  0.17782676,  0.3291595 , -0.00749285,
        0.18056977, -0.01105167,  0.252411  ,  0.27217913,  0.0277524 ,
        0.08217394,  0.29046512, -0.04222019,  0.09515226,  0.04370058,
        0.31922156,  0.1214962 ,  0.1893425 ,  0.2168242 ,  0.08601147],      dtype=float32)

In [None]:
# Single-point sanity check through the pre-bound `distance` partial.
# box.distance itself needs a query point plus box_pose / half_extents, so a
# bare no-argument call would presumably raise a TypeError on Run All.
distance(col_points[0])

In [78]:
# 1000 uniform samples in [0, 1), stored column-wise as a (3, 1000) array.
# NOTE(review): `verts` is unused in the visible cells and unseeded —
# candidate for removal.
verts = np.random.random((3, 1000))

In [14]:
# Inspect the default encoding Colors.read returns for a named color. The shown
# output is a single integer, unlike the return_rgb=True form used for the
# point-cloud colors.
Colors.read("white")

16185078