In [93]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch

from torch_obb.estimation import obb_estimate_pca, obb_estimate_dito
import trimesh



rng = np.random.RandomState(12)

def sample_spherical(npoints, ndim=3):
    vec = np.random.randn(npoints, ndim).astype(np.float32)
    vec /= np.linalg.norm(vec, axis=1)[:, None]
    return vec

rotation_matrix = trimesh.transformations.rotation_matrix(np.pi/4, [1, 0, 0])
# points1 = rng.random((500, 3)).astype(np.float32) * np.array([1, 1, 1], dtype=np.float32)
# points2 = rng.random((14, 3)).astype(np.float32) * np.array([1, 5, 2], dtype=np.float32)
points3 = sample_spherical(15, 3) * np.array([2, 2, 1] + np.array([1.0, 0.0, 0.0], dtype=np.float32), dtype=np.float32)
points3 = trimesh.transform_points(points3, rotation_matrix)
points3 = points3.astype(np.float32)
points = torch.nested.as_nested_tensor([torch.from_numpy(points3)], layout=torch.jagged)
vertices_pca, basis_pca = obb_estimate_pca(points)
vertices_dito, basis_dito = obb_estimate_dito(points)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [94]:
basis_pca

tensor([[[-0.9942, -0.0504,  0.0947],
         [ 0.0404,  0.6421,  0.7655],
         [-0.0994,  0.7649, -0.6364]]])

In [95]:
basis_dito

tensor([[[ 0.9427, -0.3158, -0.1075],
         [-0.0654, -0.4908,  0.8688],
         [-0.3271, -0.8120, -0.4834]]])

In [96]:
import pygfx as gfx
import trimesh
from rendercanvas.glfw import GlfwRenderCanvas

_TRIS = [[1, 3, 0],
         [4, 1, 0],
         [0, 3, 2],
         [2, 4, 0],
         [1, 7, 3],
         [5, 1, 4],
         [5, 7, 1],
         [3, 7, 2],
         [6, 4, 2],
         [2, 7, 6],
         [6, 5, 4],
         [7, 5, 6]]
i = 0
canvas = GlfwRenderCanvas(size=(1024, 768), max_fps=60, update_mode="continuous")   
scene = gfx.Scene()
background = gfx.Background(None, gfx.BackgroundMaterial('white'))
scene.add(background)

vertices_pca_np = vertices_pca[i].cpu().numpy()
print(vertices_pca_np.shape)
vertices_dito_np = vertices_dito[i].cpu().numpy()
points_np = points[i].cpu().numpy()


points_geometry = gfx.Geometry(positions=points_np.astype(np.float32))
pc = gfx.Points(points_geometry, material=gfx.PointsMaterial(size=10, color=(0, 0, 0, 1.0)))
scene.add(pc)

obb_geometry_pca = gfx.Geometry(positions=vertices_pca_np.squeeze().astype(np.float32), indices=np.array(_TRIS, dtype=np.int32))
obb_pca = gfx.Mesh(obb_geometry_pca, gfx.MeshBasicMaterial(wireframe=True, color='cyan', wireframe_thickness=2))
scene.add(obb_pca)

obb_geometry_dito = gfx.Geometry(positions=vertices_dito_np.squeeze().astype(np.float32), indices=np.array(_TRIS, dtype=np.int32))
obb_dito = gfx.Mesh(obb_geometry_dito, gfx.MeshBasicMaterial(wireframe=True, color='red', wireframe_thickness=2))
scene.add(obb_dito)


disp = gfx.Display(canvas=canvas)
disp.show(scene)

canvas.is_closed() is deprecated, use canvas.get_closed() instead.
Your scene does not contain any lights. Some objects may not be visible.
Unrecognized present mode 1000361000
Unrecognized present mode 1000361000


(8, 3)


Unrecognized present mode 1000361000


In [78]:
from torch_obb.util import ensure_warp_available
from torch_obb.estimation import prepare_vertices
from torch_obb.kernels.estimation_utils import compute_obb_vertices
from torch_obb.kernels.estimation_dito import NUM_SLAB_DIRS, SLAB_DIRS, best_obb_axes_from_base_triangle_kernel, BLOCK_SIZE, compute_obb_extents_jagged
from typing import Optional
import warp as wp

def obb_estimate_dito_dbg(vertices_t: torch.Tensor,
                          device: Optional[str] = None) -> torch.Tensor:
    ensure_warp_available()
    if device is None:
        device = vertices_t.device

    npoints_t = vertices_t.offsets().diff()
    if (npoints_t < NUM_SLAB_DIRS * 2).any():
        raise ValueError(f"Each batch must have at least {NUM_SLAB_DIRS * 2} vertices.")

    slab_dirs_t = SLAB_DIRS.to(dtype=vertices_t.dtype, device=device)#.unsqueeze(1)

    # operate directly on the non-jagged values and then convert back to jagged representations
    slab_projs_t = torch.nested.nested_tensor_from_jagged(torch.inner(vertices_t.values(), slab_dirs_t), 
                                                          offsets=vertices_t.offsets(), 
                                                          jagged_dim=1)

    min_proj_t, min_proj_arg_t = torch.min(slab_projs_t, dim=1)
    max_proj_t, max_proj_arg_t = torch.max(slab_projs_t, dim=1)

    # correct the indices for the offsets into the jagged array as torch nested does not support torch.gather directly
    min_proj_arg_t += slab_projs_t.offsets()[:-1].unsqueeze(1)
    max_proj_arg_t += slab_projs_t.offsets()[:-1].unsqueeze(1)

    min_vert_t = vertices_t.values()[min_proj_arg_t]
    max_vert_t = vertices_t.values()[max_proj_arg_t]

    many_vertices_mask_t = npoints_t > NUM_SLAB_DIRS * 2
    few_vertices_mask_t = ~many_vertices_mask_t
    aabb_min_t = min_proj_t[:, :3]
    aabb_max_t = max_proj_t[:, :3]

    # device_wp = wp.device_from_torch(device)
    # aabb_min_wp = wp.from_torch(aabb_min_t, dtype=wp.vec3f).to(device_wp)
    # aabb_max_wp = wp.from_torch(aabb_max_t, dtype=wp.vec3f).to(device_wp)
    # min_vert_wp = wp.from_torch(min_vert_t, dtype=wp.vec3f).to(device_wp)
    # max_vert_wp = wp.from_torch(max_vert_t, dtype=wp.vec3f).to(device_wp)
    # selected_vertices_wp = wp.from_torch(selected_vertices_t, dtype=wp.vec3f).to(device_wp)
    # basis_wp = wp.zeros((vertices_t.shape[0], 3), dtype=wp.vec3f, device=device_wp)

    # wp.launch_tiled(best_obb_axes_from_base_triangle_kernel, 
    #                 dim=[min_vert_t.shape[0]], 
    #                 inputs=[aabb_min_wp, aabb_max_wp, min_vert_wp, max_vert_wp, 
    #                         selected_vertices_wp, basis_wp], 
    #                 block_dim=BLOCK_SIZE,
    #                 device=device_wp)

    # basis_t = wp.to_torch(b    vertices_t, device = prepare_vertices(vertices, batch_offsets, device)asis_wp)

    # centroids_t = torch.mean(vertices_t, dim=1)
    # rotated_t_jagged = torch.bmm(vertices_t - centroids_t.unsqueeze(1), basis_t)
    # min_extents_t = rotated_t_jagged.min(dim=1).values
    # max_extents_t = rotated_t_jagged.max(dim=1).values

    # # extent_local_t = max_extent_t - min_extent_t

    # # min_extents_t, max_extents_t = compute_obb_extents_jagged(vertices_t, basis_t)
    # extents_t = max_extents_t - min_extents_t
    # center_local_t = (min_extents_t + max_extents_t) * 0.5
    # centroids_t = centroids_t + torch.bmm(center_local_t.unsqueeze(1), basis_t.mT).squeeze(1)

    # vertices_t = compute_obb_vertices(centroids_t, extents_t, basis_t)
    # return vertices_t, basis_t
    # use just the selected extreme points for large point clouds and fall back to 
    # all input vertices for small point clouds
    selected_vertices_t = torch.empty(vertices_t.shape[0], NUM_SLAB_DIRS * 2, 3, 
                                      device=device, dtype=vertices_t.dtype)
    selected_vertices_t[many_vertices_mask_t] = torch.cat((min_vert_t[many_vertices_mask_t], max_vert_t[many_vertices_mask_t]), dim=1)
    batch_mask = few_vertices_mask_t.repeat_interleave(vertices_t.offsets().diff())
    selected_vertices_t[few_vertices_mask_t] = vertices_t.values()[batch_mask].reshape(-1, NUM_SLAB_DIRS * 2, 3)

    aabb_min_t = min_proj_t[:, :3]
    aabb_max_t = max_proj_t[:, :3]

    device_wp = wp.device_from_torch(device)
    aabb_min_wp = wp.from_torch(aabb_min_t, dtype=wp.vec3f).to(device_wp)
    aabb_max_wp = wp.from_torch(aabb_max_t, dtype=wp.vec3f).to(device_wp)
    min_vert_wp = wp.from_torch(min_vert_t, dtype=wp.vec3f).to(device_wp)
    max_vert_wp = wp.from_torch(max_vert_t, dtype=wp.vec3f).to(device_wp)
    selected_vertices_wp = wp.from_torch(selected_vertices_t, dtype=wp.vec3f).to(device_wp)
    basis_wp = wp.zeros((vertices_t.shape[0], 3), dtype=wp.vec3f, device=device_wp)

    wp.launch_tiled(best_obb_axes_from_base_triangle_kernel, 
                    dim=[min_vert_t.shape[0]], 
                    inputs=[aabb_min_wp, aabb_max_wp, min_vert_wp, max_vert_wp, 
                            selected_vertices_wp, basis_wp], 
                    block_dim=BLOCK_SIZE,
                    device=device_wp)

    basis_t = wp.to_torch(basis_wp)
    wp.synchronize()
    print("basis_t", basis_t)

    min_extents_t, max_extents_t = compute_obb_extents_jagged(vertices_t, basis_t)
    extents_t = max_extents_t - min_extents_t

    print("min_extents_t", min_extents_t)
    print("max_extents_t", max_extents_t)
    print("extents_t", extents_t)

    # extent_local_t = max_extent_t - min_extent_t

    # min_extents_t, max_extents_t = compute_obb_extents_jagged(vertices_t, basis_t)
    center_local_t = (min_extents_t + max_extents_t) * 0.5
    print("center_local_t", center_local_t)
    center_t = torch.bmm(center_local_t.unsqueeze(1), basis_t).squeeze(1)
    #centroids_t = centroids_t + 
    print("center_t", center_t)
    vertices_t = compute_obb_vertices(center_t, extents_t, basis_t.mT)
    return vertices_t, basis_t

vertices_t, device = prepare_vertices(points)
obb_vertices_t, basis_t = obb_estimate_dito(vertices_t, device)

p0 -2.989874, 0.037348, 0.121857
p1 2.993481, -0.096555, -0.026632
furthest_index 0
diff_norm_sq 35.840511
i 0, testing cases
i 0: best_val initial 47.688675
i 0: e0 -0.999442, 0.022367, 0.024803
i 0: n_vec -0.009402, 0.524201, -0.851543
i 0: m0 -0.032048, -0.851301, -0.523698
dlen span_e0 5.986695, span_n 1.699457, span_m0 3.695462
i 0: q0 38.578018
i 0: best_val 38.578018


In [79]:
from torch_obb._estimation_old import oriented_bounding_box_dito_14
points_numpy_dbg = points[0].cpu().numpy()
obb_numpy_dbg = oriented_bounding_box_dito_14(points_numpy_dbg)

selected_vertices [[-2.9898741   0.03734799  0.12185741]
 [ 0.03033754 -1.580308   -0.9092275 ]
 [ 0.09764951 -0.91442716 -1.5797487 ]
 [-2.2235081  -0.99000704 -0.90026516]
 [-2.5925367  -0.39307746  0.3175297 ]
 [-2.7081242   0.30593196 -0.3025362 ]
 [-2.0312033   1.061924    1.0176934 ]
 [ 2.9934807  -0.09655452 -0.02663192]
 [-0.21309438  1.5728798   1.0364574 ]
 [-0.26456386  0.96158034  1.5748404 ]
 [ 2.1768355   0.9276872   1.0113583 ]
 [ 2.827919    0.26411942 -0.20710038]
 [ 2.734769   -0.3736506   0.20131332]
 [ 2.1009917  -1.0343412  -0.9819188 ]]
furthest_index 0
p0 [-2.9898741   0.03734799  0.12185741]
p1 [ 2.9934807  -0.09655452 -0.02663192]
diff_norm 35.84051
bestval initial 47.688675
e0 [-0.99944216  0.02236668  0.02480322]
n [-0.00940161  0.5242009  -0.8515428 ]
m0 [-0.03204806 -0.8513009  -0.5236982 ]
dlen0 [5.9866953 1.6994574 3.6954618]
quality0 38.57802
bestval 38.57802
basis [[-0.99944216  0.02236668  0.02480322]
 [-0.00940161  0.5242009  -0.8515428 ]
 [-0.0320480

In [80]:
basis_t

tensor([[[-0.9994,  0.0224,  0.0248],
         [-0.0094,  0.5242, -0.8515],
         [-0.0320, -0.8513, -0.5237]]])

In [81]:
obb_numpy_dbg

array([[ 3.0657372 ,  1.0358657 ,  1.8637326 ],
       [ 2.9404721 , -2.2915783 , -0.18322468],
       [ 3.0455232 ,  2.1629295 ,  0.03286362],
       [ 2.920258  , -1.1645144 , -2.0140936 ],
       [-2.9176183 ,  1.1697681 ,  2.0122218 ],
       [-3.0428834 , -2.157676  , -0.03473532],
       [-2.9378324 ,  2.296832  ,  0.18135297],
       [-3.0630975 , -1.030612  , -1.8656043 ]], dtype=float32)

In [82]:
obb_vertices_t

tensor([[[ 3.0657,  1.0359,  1.8637],
         [ 2.9405, -2.2916, -0.1832],
         [ 3.0455,  2.1629,  0.0329],
         [ 2.9203, -1.1645, -2.0141],
         [-2.9176,  1.1698,  2.0122],
         [-3.0429, -2.1577, -0.0347],
         [-2.9378,  2.2968,  0.1814],
         [-3.0631, -1.0306, -1.8656]]])

In [83]:
import pygfx as gfx
import trimesh
from rendercanvas.glfw import GlfwRenderCanvas

_TRIS = [[1, 3, 0],
         [4, 1, 0],
         [0, 3, 2],
         [2, 4, 0],
         [1, 7, 3],
         [5, 1, 4],
         [5, 7, 1],
         [3, 7, 2],
         [6, 4, 2],
         [2, 7, 6],
         [6, 5, 4],
         [7, 5, 6]]
i = 0
canvas = GlfwRenderCanvas(size=(1024, 768), max_fps=60, update_mode="continuous")   
scene = gfx.Scene()
background = gfx.Background(None, gfx.BackgroundMaterial('white'))
scene.add(background)

#selected_vertices_np = selected_vertices_t[i].cpu().numpy()
points_np = points[i].cpu().numpy()


points_geometry = gfx.Geometry(positions=points_np.astype(np.float32))
pc = gfx.Points(points_geometry, material=gfx.PointsMaterial(size=10, color=(0, 0, 0, 1.0)))
scene.add(pc)


# selected_vertices_geometry = gfx.Geometry(positions=selected_vertices_np.astype(np.float32))
# pc_selected = gfx.Points(selected_vertices_geometry, material=gfx.PointsMaterial(size=20, color=(1.0, 0, 0, 1.0)))
# scene.add(pc_selected)
obb_vertices_np = obb_vertices_t[i].cpu().numpy()

obb_geometry_dbg = gfx.Geometry(positions=obb_numpy_dbg.squeeze().astype(np.float32), indices=np.array(_TRIS, dtype=np.int32))
obb_dbg = gfx.Mesh(obb_geometry_dbg, gfx.MeshBasicMaterial(wireframe=True, color='cyan', wireframe_thickness=2))
scene.add(obb_dbg)

obb_geometry = gfx.Geometry(positions=obb_vertices_np.squeeze().astype(np.float32), indices=np.array(_TRIS, dtype=np.int32))
obb = gfx.Mesh(obb_geometry, gfx.MeshBasicMaterial(wireframe=True, color='red', wireframe_thickness=2))
scene.add(obb)

disp = gfx.Display(canvas=canvas)
disp.show(scene)

canvas.is_closed() is deprecated, use canvas.get_closed() instead.
Your scene does not contain any lights. Some objects may not be visible.
Unrecognized present mode 1000361000
Unrecognized present mode 1000361000


Unrecognized present mode 1000361000
