In [None]:
import torch
import torch.nn.functional as F

import raymarching2


# Synthetic inputs for a smoke test of raymarching2.generate_training_samples.
DEVICE = "cuda"
SEED = 0
# rays_d below is random; seed so Restart & Run All reproduces the same samples/plot.
torch.manual_seed(SEED)

# 10 x 100 rays, all starting at (0.1, 0.1, 0.1), with random unit-length directions.
# Tensors are allocated directly on DEVICE instead of being built on CPU and copied over.
rays_o = torch.full((10, 100, 3), 0.1, device=DEVICE)
rays_d = F.normalize(torch.randn((10, 100, 3), device=DEVICE), dim=-1)

# Density bitfield with every bit set.
# NOTE(review): presumably N_CASCADES cascades of a GRID_RES^3 occupancy grid packed
# 8 cells per byte (all cells occupied) — confirm against raymarching2.
GRID_RES = 128
N_CASCADES = 5
density_bitfield = torch.full(
    (N_CASCADES, GRID_RES ** 3 // 8), 255, dtype=torch.uint8, device=DEVICE
)

# Box given as [x_min, y_min, z_min, x_max, y_max, z_max]; here the unit cube.
aabb = torch.tensor([0., 0., 0., 1., 1., 1.], device=DEVICE)

positions, dirs, deltas, nears, fars = raymarching2.generate_training_samples(
    rays_o, rays_d, aabb, density_bitfield
)
# Wait for the kernel to finish so any CUDA error surfaces in this cell.
torch.cuda.synchronize()

In [None]:
from tava.utils.plotly import Trimesh, plot_scene, PointCloud

def aabb_to_mesh(aabb):
    """Build a closed triangle mesh covering all six sides of an axis-aligned box.

    Args:
        aabb: 1-D tensor of six values laid out as
            ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

    Returns:
        vertices: (8, 3) tensor of box corners, same dtype/device as ``aabb``.
        faces: (12, 3) int32 tensor of vertex indices on ``aabb.device``;
            two triangles per box side, all wound consistently so that the
            right-hand-rule normals point into the box.
    """
    # Each corner picks the min (index 0/1/2) or max (index 3/4/5) value per axis.
    vertices = torch.stack([
        aabb[[0, 1, 2]],  # 0: (x_min, y_min, z_min)
        aabb[[3, 1, 2]],  # 1: (x_max, y_min, z_min)
        aabb[[0, 4, 2]],  # 2: (x_min, y_max, z_min)
        aabb[[0, 1, 5]],  # 3: (x_min, y_min, z_max)
        aabb[[3, 4, 2]],  # 4: (x_max, y_max, z_min)
        aabb[[0, 4, 5]],  # 5: (x_min, y_max, z_max)
        aabb[[3, 1, 5]],  # 6: (x_max, y_min, z_max)
        aabb[[3, 4, 5]],  # 7: (x_max, y_max, z_max)
    ])
    faces = torch.tensor([
        [0, 1, 4], [0, 4, 2],  # z = z_min
        [0, 3, 6], [0, 6, 1],  # y = y_min
        [1, 6, 4], [4, 6, 7],  # x = x_max
        [2, 4, 7], [2, 7, 5],  # y = y_max
        [2, 5, 0], [0, 5, 3],  # x = x_min
        [3, 7, 6], [3, 5, 7],  # z = z_max
    ], dtype=torch.int32, device=aabb.device)
    return vertices, faces

vertices, faces = aabb_to_mesh(aabb)

# Keep only samples whose coordinates sum to a positive value.
# NOTE(review): presumably unused sample slots come back zero-filled from raymarching2 — confirm.
sample_mask = positions.sum(dim=-1) > 0

scene = {
    "bbox": {
        "struct": Trimesh(vertices.cpu().numpy(), faces.cpu().numpy()),
        "mesh_opacity": 0.7,
    },
    "samples": {
        "struct": PointCloud(positions[sample_mask].cpu().numpy()),
    },
}
plot_scene(scene)

In [None]:
# Count samples whose coordinates sum to > 0 (the same mask used for the scatter plot).
positions.sum(dim=-1).gt(0).sum()