In [1]:
import torch
import numpy as np

In [2]:
grid_dims = (100, 100, 1000)
x_lim, y_lim, z_lim = (-5, 5), (-5, 5), (0, 40)

In [None]:
def get_sample_arb8s():
    arb8s = []

    arb8s.append(np.array([[-0.56630238, -1.85304077, 18.1],
              [-0., -2.41934315, 18.1],
              [-1.66126057, -2.41934315, 18.1],
              [-1.09495819, -1.85304077, 18.1],
              [-0.56630238, -1.85304077, 18.3],
              [-0., -2.41934315, 18.3],
              [-1.66126057, -2.41934315, 18.3],
              [-1.09495819, -1.85304077, 18.3]]))
    arb8s.append(np.array([[-0.56630238, -1.85304077, 18.1],
              [-0., -2.41934315, 18.1],
              [-1.66126057, -2.41934315, 18.1],
              [-1.09495819, -1.85304077, 18.1],
              [-0.56630238, -1.85304077, 18.3],
              [-0., -2.41934315, 18.3],
              [-1.66126057, -2.41934315, 18.3],
              [-1.09495819, -1.85304077, 18.3]]))
    arb8s.append(np.array([[-0.56630238, -1.85304077, 18.1],
              [-0., -2.41934315, 18.1],
              [-1.66126057, -2.41934315, 18.1],
              [-1.09495819, -1.85304077, 18.1],
              [-0.56630238, -1.85304077, 18.3],
              [-0., -2.41934315, 18.3],
              [-1.66126057, -2.41934315, 18.3],
              [-1.09495819, -1.85304077, 18.3]]))
    arb8s.append(np.array([[0., -1.799, 0.0],
              [0., 1.799, 0.0],
              [0.5, 1.799, 0.0],
              [0.5, -1.799, 0.0],
              [0., -1.799, 10.0],
              [0., 1.799, 10.0],
              [0.5, 1.799, 10.0],
              [0.5, -1.799, 10.0]]))
    arb8s.append(np.array([[0., -1.799, 0.0],
              [0., 1.799, 0.0],
              [0.5, 1.799, 0.0],
              [0.5, -1.799, 0.0],
              [0., -1.799, 10.0],
              [0., 1.799, 10.0],
              [0.5, 1.799, 10.0],
              [0.5, -1.799, 10.0]]))
    arb8s.append(np.array([[0.5, 1.3, -10.0640477],
              [0., 1.8, -10.0640477],
              [1.02, 1.8, -10.0640477],
              [0.52, 1.3, -10.0640477],
              [0.5, 1.3, -5.5440477],
              [0., 1.8, -5.5440477],
              [1.02, 1.8, -5.5440477],
              [0.52, 1.3, -5.5440477]]))

    arb8s.append(np.array([[-0.52, 1.3, -10.0640477],
              [-1.02, 1.8, -10.0640477],
              [-0., 1.8, -10.0640477],
              [-0.5, 1.3, -10.0640477],
              [-0.52, 1.3, -5.5440477],
              [-1.02, 1.8, -5.5440477],
              [-0., 1.8, -5.5440477],
              [-0.5, 1.3, -5.5440477]]))

    arb8s.append(np.array([[0., -1.95744757, -5.3440477],
              [0., 1.95744757, -5.3440477],
              [0.70216476, 1.25528281, -5.3440477],
              [0.70216476, -1.25528281, -5.3440477],
              [0., -2.09849295, -4.89307861],
              [0., 2.09849295, -4.89307861],
              [0.96872421, 1.12976874, -4.89307861],
              [0.96872421, -1.12976874, -4.89307861]]))

    arb8s.append(np.array([[-0., 1.95744757, -5.3440477],
              [-0., -1.95744757, -5.3440477],
              [-0.70216476, -1.25528281, -5.3440477],
              [-0.70216476, 1.25528281, -5.3440477],
              [-0., 2.09849295, -4.89307861],
              [-0., -2.09849295, -4.89307861],
              [-0.96872421, -1.12976874, -4.89307861],
              [-0.96872421, 1.12976874, -4.89307861]]))

    

    x = np.array(arb8s)
    x = torch.from_numpy(x).float()

    return x

geo = get_sample_arb8s()
geo.shape

torch.Size([8, 8, 3])

In [4]:
def get_blocks_hash(geo: torch.tensor, grid_dims=(100, 100, 1000), grid_lims=((-5, 5), (-5, 5), (0, 40))):
    """
    Calculates a memory-efficient voxel hash grid for a given geometry.

    This function avoids creating a meshgrid for the entire space. Instead, it
    calculates the voxel index range for each block's bounding box and generates
    indices only within that sub-volume.

    Returns:
        voxel_offsets (torch.Tensor): The CSR offset array.
        block_ids (torch.Tensor): The CSR content array (block IDs).
    """
    x_lim, y_lim, z_lim = grid_lims
    nx, ny, nz = grid_dims
    numel_grid = nx * ny * nz
    device = geo.device

    # Calculate the size of a single voxel cell
    grid_size_x = (x_lim[1] - x_lim[0]) / nx
    grid_size_y = (y_lim[1] - y_lim[0]) / ny
    grid_size_z = (z_lim[1] - z_lim[0]) / nz

    # --- Calculate block bounding boxes ---
    # We add a small padding to handle floating point precision at the edges.
    min_x = geo[:, :, 0].min(-1).values - grid_size_x
    max_x = geo[:, :, 0].max(-1).values + grid_size_x
    min_y = geo[:, :, 1].min(-1).values - grid_size_y
    max_y = geo[:, :, 1].max(-1).values + grid_size_y
    min_z = geo[:, :, 2].min(-1).values - grid_size_z
    max_z = geo[:, :, 2].max(-1).values + grid_size_z

    # --- Efficiently build the COO format without materializing the full grid ---
    voxel_indices = []
    block_ids = []

    for i in range(len(geo)):
        # Convert world-space AABB to integer index-space AABB
        ix_min = torch.clamp(torch.floor((min_x[i] - x_lim[0]) / grid_size_x), 0, nx - 1).long()
        ix_max = torch.clamp(torch.ceil((max_x[i] - x_lim[0]) / grid_size_x), 0, nx - 1).long()
        iy_min = torch.clamp(torch.floor((min_y[i] - y_lim[0]) / grid_size_y), 0, ny - 1).long()
        iy_max = torch.clamp(torch.ceil((max_y[i] - y_lim[0]) / grid_size_y), 0, ny - 1).long()
        iz_min = torch.clamp(torch.floor((min_z[i] - z_lim[0]) / grid_size_z), 0, nz - 1).long()
        iz_max = torch.clamp(torch.ceil((max_z[i] - z_lim[0]) / grid_size_z), 0, nz - 1).long()

        # If the index range is valid, create voxel indices for this block
        if ix_min <= ix_max and iy_min <= iy_max and iz_min <= iz_max:
            # Create a small, local meshgrid only for the sub-volume of this block
            ix_range = torch.arange(ix_min, ix_max + 1, device=device)
            iy_range = torch.arange(iy_min, iy_max + 1, device=device)
            iz_range = torch.arange(iz_min, iz_max + 1, device=device)
            
            grid_x, grid_y, grid_z = torch.meshgrid(ix_range, iy_range, iz_range, indexing='ij')

            # Convert 3D indices to 1D linear indices and flatten
            linear_indices = (grid_z.flatten() * (nx * ny) + 
                              grid_y.flatten() * nx + 
                              grid_x.flatten())
            
            voxel_indices.append(linear_indices)
            block_ids.append(torch.full_like(linear_indices, fill_value=i))

    # --- Convert from COO to CSR format (same as before) ---
    if len(voxel_indices) > 0:
        voxel_indices = torch.cat(voxel_indices)
        block_ids = torch.cat(block_ids)
    else:
        voxel_indices = torch.tensor([], dtype=torch.long, device=device)
        block_ids = torch.tensor([], dtype=torch.long, device=device)

    # Sort by voxel index to group blocks
    voxel_indices, sorted_indices = torch.sort(voxel_indices)
    block_ids = block_ids[sorted_indices]
    
    # Create the final offset array
    voxel_offsets = torch.zeros(numel_grid + 1, dtype=torch.int32, device=device)
    unique_voxel_indices, counts = torch.unique_consecutive(voxel_indices, return_counts=True)
    
    # Place the counts into the correct positions and compute the cumulative sum (prefix sum)
    voxel_offsets.scatter_add_(0, unique_voxel_indices + 1, counts.to(torch.int32))
    voxel_offsets = torch.cumsum(voxel_offsets, dim=0)
    
    return voxel_offsets, block_ids

In [5]:
%time
voxel_offsets, voxel_contents = get_blocks_hash(geo,grid_dims, (x_lim, y_lim, z_lim))

CPU times: user 3 μs, sys: 0 ns, total: 3 μs
Wall time: 6.44 μs


In [6]:
print("\n--- Final Tensors for CUDA ---")
print(f"Voxel Contents Shape: {voxel_contents.shape}")
print(f"Voxel Offsets Shape:  {voxel_offsets.shape}")
print(f"Total overlaps stored: {len(voxel_contents)}")


--- Final Tensors for CUDA ---
Voxel Contents Shape: torch.Size([85446])
Voxel Offsets Shape:  torch.Size([10000001])
Total overlaps stored: 85446


In [7]:
%time
# --- EXAMPLE USAGE ---
# Let's see which blocks are in voxel 5050500
voxel_idx_to_check = 5050500
start_index = voxel_offsets[voxel_idx_to_check]
end_index = voxel_offsets[voxel_idx_to_check + 1]
candidate_blocks = voxel_contents[start_index:end_index]

print(f"\nExample lookup for voxel {voxel_idx_to_check}:")
print(f"  -> Start index from offset array: {start_index}")
print(f"  -> End index from offset array:   {end_index}")
print(f"  -> Candidate Block IDs from content array: {candidate_blocks}")

CPU times: user 5 μs, sys: 0 ns, total: 5 μs
Wall time: 11.7 μs

Example lookup for voxel 5050500:
  -> Start index from offset array: 85446
  -> End index from offset array:   85446
  -> Candidate Block IDs from content array: tensor([], dtype=torch.int64)


In [8]:
import torch

def get_candidates_from_hash(
    point_xyz: torch.Tensor,
    voxel_offsets: torch.Tensor,
    voxel_contents: torch.Tensor,
    grid_dims=(100, 100, 1000),
    grid_lims=((-5, 5), (-5, 5), (0, 40))
):
    """
    Finds candidate block IDs for a given point using the pre-computed voxel hash.

    Args:
        point_xyz (torch.Tensor): A tensor of shape [3] with the (x, y, z) coordinates.
        voxel_offsets (torch.Tensor): The CSR offset array from get_blocks_hash.
        voxel_contents (torch.Tensor): The CSR content array from get_blocks_hash.
        grid_dims (tuple): The dimensions of the grid (nx, ny, nz).
        grid_lims (tuple): The world-space limits of the grid ((x_min, x_max), ...).

    Returns:
        torch.Tensor: A 1D tensor containing the candidate block IDs.
    """
    x_lim, y_lim, z_lim = grid_lims
    nx, ny, nz = grid_dims
    device = voxel_offsets.device

    # --- 1. Calculate the 1D Voxel Index for the point ---

    # Check if the point is outside the grid bounds first
    if (point_xyz[0] < x_lim[0] or point_xyz[0] > x_lim[1] or
        point_xyz[1] < y_lim[0] or point_xyz[1] > y_lim[1] or
        point_xyz[2] < z_lim[0] or point_xyz[2] > z_lim[1]):
        return torch.tensor([], dtype=torch.long, device=device) # Outside the grid

    # Calculate the size of a single voxel cell
    grid_size_x = (x_lim[1] - x_lim[0]) / nx
    grid_size_y = (y_lim[1] - y_lim[0]) / ny
    grid_size_z = (z_lim[1] - z_lim[0]) / nz

    # Convert world coordinates to 3D integer indices
    ix = torch.clamp(torch.floor((point_xyz[0] - x_lim[0]) / grid_size_x), 0, nx - 1).long()
    iy = torch.clamp(torch.floor((point_xyz[1] - y_lim[0]) / grid_size_y), 0, ny - 1).long()
    iz = torch.clamp(torch.floor((point_xyz[2] - z_lim[0]) / grid_size_z), 0, nz - 1).long()

    # Convert 3D index to 1D linear index
    linear_voxel_index = iz * (nx * ny) + iy * nx + ix

    # --- 2. Find the Range using the offsets tensor ---
    start_index = voxel_offsets[linear_voxel_index]
    end_index = voxel_offsets[linear_voxel_index + 1]

    # --- 3. Get the Block IDs from the contents tensor ---
    candidate_ids = voxel_contents[start_index:end_index]

    return candidate_ids

In [9]:
point = torch.tensor([0.0, 0.0, 2.0], device=voxel_offsets.device)
%time
get_candidates_from_hash(point, voxel_offsets, voxel_contents, grid_dims, (x_lim, y_lim, z_lim))

CPU times: user 5 μs, sys: 0 ns, total: 5 μs
Wall time: 10.7 μs


tensor([3])