In [1]:
from plyfile import PlyData
import jax 
import jax.numpy as jnp 
from jax import grad, jacfwd
import polyscope as ps

from jaxlie import SO3, SE3

import numpy as np
import open3d as o3d 
import trimesh as tm

import meshcat
from meshcat import Visualizer
import meshcat.geometry as mc_geom
import meshcat.transformations as mc_trans

import time

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [10]:
norm = np.array([1., 1., 1.])
norm = norm/np.linalg.norm(norm)
root = np.array([1., 0., 0.])
v1 = np.array([0., 0., 1.])
def loss(params):
    w = params['w']
    v = params['v']
    _R = SO3.exp(w)
    return jnp.sum(jnp.dot(norm, _R@v1)) + np.sum((norm+root-v)**2)

In [3]:
vis = Visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [11]:
body1 = vis['v1']
body1.set_object(mc_geom.triad(scale=0.5))

body2 = vis['v2']
body2.set_object(mc_geom.triad(scale=0.5))


_angle = np.arccos(np.dot(np.array([0.,0.,1.]), norm))
_axis = np.cross(np.array([0.,0.,1.]), norm)
_R = SO3.exp(_axis*_angle)
_T = SE3.from_rotation_and_translation(_R, root).as_matrix()
body2.set_transform(np.array(_T, dtype=np.float64))



In [12]:
w = jnp.array([0.1,0.2,0.3])
v = jnp.array([.0,0.,.01])

params = {'w': w, 'v':v}
for _ in range(100):
    _grad = grad(loss)(params)
    params['w'] = params['w'] - 0.1*_grad['w']
    params['v'] = params['v'] - 0.1*_grad['v']
    _R = SO3.exp(params['w'])
    _T = SE3.from_rotation_and_translation(_R, params['v']).as_matrix()
    body1.set_transform(np.array(_T, dtype=np.float64))
    print(params, loss(params))
    time.sleep(0.1)

{'w': Array([0.15375288, 0.14687362, 0.29206684], dtype=float32), 'v': Array([0.31547007, 0.11547005, 0.12347005], dtype=float32)} 2.5973082
{'w': Array([0.2106922 , 0.09044492, 0.28350365], dtype=float32), 'v': Array([0.5678461, 0.2078461, 0.2142461], dtype=float32)} 1.8062948
{'w': Array([0.27064463, 0.03085877, 0.2743034 ], dtype=float32), 'v': Array([0.76974696, 0.28174692, 0.28686693], dtype=float32)} 1.2687414
{'w': Array([ 0.33334464, -0.03164876,  0.2644717 ], dtype=float32), 'v': Array([0.9312676 , 0.34086758, 0.3449636 ], dtype=float32)} 0.8911159
{'w': Array([ 0.39842916, -0.09674378,  0.25402907], dtype=float32), 'v': Array([1.0604842 , 0.3881641 , 0.39144093], dtype=float32)} 0.61402595
{'w': Array([ 0.4654383 , -0.16399376,  0.24301267], dtype=float32), 'v': Array([1.1638573 , 0.42600134, 0.42862278], dtype=float32)} 0.40007523
{'w': Array([ 0.5338238 , -0.2328741 ,  0.23147723], dtype=float32), 'v': Array([1.2465559 , 0.4562711 , 0.45836827], dtype=float32)} 0.22605816
{

In [None]:
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
vertices = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z']
)).T

# _min_pnt = vertices.min(axis=0)
# _max_pnt = vertices.max(axis=0)
# _mid_pnt = vertices.mean(axis=0)
# vertices = (vertices - _mid_pnt)/(_max_pnt-_min_pnt)

faces = np.array(np.vstack(plydata['face']['vertex_indices']))

mesh = tm.Trimesh(vertices=vertices, faces=faces)

ds_mesh = mesh.simplify_quadric_decimation(int(mesh.vertices.shape[0]/10))

In [None]:
# ps.init()

# ps_mesh = ps.register_surface_mesh("my mesh", ds_mesh.vertices, ds_mesh.faces)
# ps_mesh.add_vector_quantity("rand vecs", ds_mesh.vertex_normals, enabled=True)


# ps.show()