In [13]:
# bvh_jax.py

import jax
import jax.numpy as jnp
from jax import lax
from dataclasses import dataclass
from typing import Any, Dict, List

# Disable JIT for BVH building (for debugging or initialization)
jax.config.update("jax_disable_jit", True)

from primitives.aabb import AABB, union, enclose_centroids, get_largest_dim, get_surface_area, offset


# =============================================================================
# We define a small "BoundedBox" class to store an AABB and a primitive index,
# just like your Numba code. We'll convert your dictionary-of-arrays into a list
# of these BoundedBox objects (one entry per primitive).
# =============================================================================
@dataclass
class BoundedBox:
    bounds: AABB
    prim_num: int


def convert_dict_to_list(bboxes_dict: Dict[str, jnp.ndarray]) -> List[BoundedBox]:
    """
    Convert the dictionary-of-arrays (keys: "min_point", "max_point", "centroid", "prim_num")
    into a list of BoundedBox objects, each storing an AABB + prim_num.
    """
    # bboxes_dict["bounds"]["min_point"] is shape (N,3)
    # bboxes_dict["bounds"]["max_point"] is shape (N,3)
    # bboxes_dict["bounds"]["centroid"] is shape (N,3)
    # bboxes_dict["prim_num"] is shape (N,)

    min_point = bboxes_dict["bounds"]["min_point"]  # shape (N,3)
    max_point = bboxes_dict["bounds"]["max_point"]  # shape (N,3)
    centroid  = bboxes_dict["bounds"]["centroid"]   # shape (N,3)
    prim_nums = bboxes_dict["prim_num"]             # shape (N,)

    N = min_point.shape[0]
    boxes_list = []
    for i in range(N):
        box = AABB(min_point[i], max_point[i], centroid[i])
        pnum = int(prim_nums[i])
        boxes_list.append(BoundedBox(bounds=box, prim_num=pnum))
    return boxes_list


# =============================================================================
# BVHNode class (mimicking your working Numba code)
# =============================================================================
@dataclass
class BVHNode:
    bounds: AABB = None
    split_axis: int = -1
    first_prim_offset: int = -1
    n_primitives: int = 0
    child_0: Any = None  # left child
    child_1: Any = None  # right child

    def init_leaf(self, first_offset: int, n_prims: int, bounds: AABB):
        self.first_prim_offset = first_offset
        self.n_primitives = n_prims
        self.bounds = bounds
        self.child_0 = None
        self.child_1 = None


# =============================================================================
# The iterative build function, mimicking your Numba approach
# =============================================================================
def build_bvh_tree(primitives: Any,
                   bounded_list: List[BoundedBox],
                   start: int, end: int,
                   ordered_prims: List[Any],
                   total_nodes: List[int],
                   split_method: int = 0) -> BVHNode:
    """
    iterative BVH build that uses a list of BoundedBox objects and a stack:
      (start, end, parent_idx, is_second_child)
    """
    max_prims_in_node = max(4, int(0.1 * len(bounded_list)))
    stack = [(start, end, -1, False)]
    nodes: List[BVHNode] = []

    while stack:
        start_, end_, parent_idx, is_second_child = stack.pop()
        node = BVHNode()
        current_node_idx = len(nodes)
        total_nodes[0] += 1

        # Link the newly created node to the parent
        if parent_idx != -1:
            parent = nodes[parent_idx]
            if is_second_child:
                parent.child_1 = node
            else:
                parent.child_0 = node

        # Union of bounds over [start_, end_)
        bounds = None
        for i in range(start_, end_):
            b = bounded_list[i].bounds
            bounds = union(bounds, b) if bounds is not None else b

        n_prims = end_ - start_

        if n_prims == 1:
            first_offset = len(ordered_prims)
            for i in range(start_, end_):
                prim_idx = bounded_list[i].prim_num
                ordered_prims.append(primitives[prim_idx])
            node.init_leaf(first_offset, n_prims, bounds)

        else:
            # Compute centroid bounds
            centroid_bounds = None
            for i in range(start_, end_):
                c = bounded_list[i].bounds.centroid
                centroid_bounds = enclose_centroids(centroid_bounds, c)
            dim = get_largest_dim(centroid_bounds)

            # For simplicity, do a midpoint partition along 'dim'.
            pmid = 0.5 * (centroid_bounds.min_point[dim] + centroid_bounds.max_point[dim])
            left_indices = []
            right_indices = []
            for i in range(start_, end_):
                cdim = bounded_list[i].bounds.centroid[dim]
                if cdim < pmid:
                    left_indices.append(i)
                else:
                    right_indices.append(i)
            if len(left_indices) == 0 or len(right_indices) == 0:
                mid = (start_ + end_) // 2
            else:
                # The partition boundary is between left_indices[-1] and right_indices[0].
                # We can place mid = start_ + len(left_indices)
                mid = start_ + len(left_indices)

            # If partition is degenerate, make a leaf.
            if mid == start_ or mid == end_:
                first_offset = len(ordered_prims)
                for i in range(start_, end_):
                    prim_idx = bounded_list[i].prim_num
                    ordered_prims.append(primitives[prim_idx])
                node.init_leaf(first_offset, n_prims, bounds)
            else:
                node.split_axis = int(dim)
                node.bounds = bounds
                node.n_primitives = 0

                # push right child, then left child
                stack.append((mid, end_, current_node_idx, True))
                stack.append((start_, mid, current_node_idx, False))

        nodes.append(node)

    return nodes[0]


# =============================================================================
# Flatten the resulting BVH
# =============================================================================
def flatten_bvh_tree(root: BVHNode) -> Dict[str, jnp.ndarray]:
    # We'll do a pre‐order traversal.
    bounds_min_list = []
    bounds_max_list = []
    first_offset_list = []
    n_prims_list = []
    split_axis_list = []
    second_child_list = []

    def traverse(node: BVHNode) -> int:
        idx = len(bounds_min_list)
        bounds_min_list.append(node.bounds.min_point)
        bounds_max_list.append(node.bounds.max_point)
        first_offset_list.append(node.first_prim_offset)
        n_prims_list.append(node.n_primitives)
        split_axis_list.append(node.split_axis)
        second_child_list.append(-1)
        if node.n_primitives == 0:
            left_idx = traverse(node.child_0)
            right_idx = traverse(node.child_1)
            second_child_list[idx] = right_idx
        return idx

    traverse(root)
    return {
        "bounds_min": jnp.array(bounds_min_list),
        "bounds_max": jnp.array(bounds_max_list),
        "first_offset": jnp.array(first_offset_list, dtype=jnp.int32),
        "n_primitives": jnp.array(n_prims_list, dtype=jnp.int32),
        "split_axis": jnp.array(split_axis_list, dtype=jnp.int32),
        "second_child_offset": jnp.array(second_child_list, dtype=jnp.int32),
        "ordered_prims": jnp.array([], dtype=jnp.int32),  # empty for demonstration
    }


# =============================================================================
# Example: init_bounded_boxes from your primitives dictionary
# (which is produced by create_triangle_arrays)
# =============================================================================
def init_bounded_boxes_from_dict(primitives: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
    """
    If you specifically want a dictionary of arrays with keys:
      "bounds": { "min_point", "max_point", "centroid" }
      "prim_num"
    then you can do this:
    """
    v1 = primitives["vertex_1"]
    v2 = primitives["vertex_2"]
    v3 = primitives["vertex_3"]
    centroids = primitives["centroid"]
    bounds_min = jnp.minimum(jnp.minimum(v1, v2), v3)
    bounds_max = jnp.maximum(jnp.maximum(v1, v2), v3)
    n = int(centroids.shape[0])
    return {
        "bounds": {
            "min_point": bounds_min,   # shape (n,3)
            "max_point": bounds_max,   # shape (n,3)
            "centroid": centroids      # shape (n,3)
        },
        "prim_num": jnp.arange(n, dtype=jnp.int32)
    }


# =============================================================================
# Additional: print a flattened BVH
# =============================================================================
def print_bvh_tree(bvh: Dict[str, jnp.ndarray], node_idx: int = 0, level: int = 0):
    indent = "  " * level
    bmin = bvh["bounds_min"][node_idx]
    bmax = bvh["bounds_max"][node_idx]
    centroid = 0.5 * (bmin + bmax)
    nprims = int(bvh["n_primitives"][node_idx])
    if nprims > 0:
        foffset = int(bvh["first_offset"][node_idx])
        print(f"{indent}Leaf Node {node_idx}:")
        print(f"{indent}  n_primitives = {nprims}, first_offset = {foffset}")
        print(f"{indent}  bounds_min = {bmin}, bounds_max = {bmax}")
        print(f"{indent}  centroid = {centroid}")
    else:
        axis = int(bvh["split_axis"][node_idx])
        second_child = int(bvh["second_child_offset"][node_idx])
        print(f"{indent}Interior Node {node_idx}:")
        print(f"{indent}  split_axis = {axis}")
        print(f"{indent}  bounds_min = {bmin}, bounds_max = {bmax}")
        print(f"{indent}  centroid = {centroid}")
        print(f"{indent}  Left Child Index: {node_idx + 1}")
        print(f"{indent}  Right Child Index: {second_child}")
        print_bvh_tree(bvh, node_idx + 1, level + 1)
        print_bvh_tree(bvh, second_child, level + 1)


In [14]:
import jax
import jax.numpy as jnp
from utils.io import load_obj, create_triangle_arrays

In [15]:
square_path = "objects/square.obj"
sphere_path = "objects/sphere.obj"
cube_path = "objects/cube.obj"
cylinder_path = "objects/cylinder.obj"
rabbit = "objects/rabbit.obj"
carrot = "objects/carrot.obj"
plane = "objects/plane.obj"
squirrel = "objects/squirrel.obj"
tree = "objects/broad_deciduous_tree_green_leaves.obj"
ant = "objects/ant.obj"
fireball = "objects/fireball.obj"
dragon = "objects/dragon.obj"
crawler = "objects/crawler.obj"


file_path = cube_path

In [16]:
vertices, faces = load_obj(file_path)
primitives = create_triangle_arrays(vertices, faces)
primitives

{'vertex_1': Array([[ 1.      , -1.      ,  1.      ],
        [-1.      ,  1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 0.999999,  1.      ,  1.000001],
        [-1.      , -1.      ,  1.      ],
        [ 1.      , -1.      , -1.      ],
        [ 1.      , -1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 1.      , -1.      , -1.      ],
        [ 1.      , -1.      ,  1.      ],
        [-1.      , -1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999]], dtype=float32),
 'vertex_2': Array([[-1.      , -1.      ,  1.      ],
        [-1.      ,  1.      ,  1.      ],
        [ 0.999999,  1.      ,  1.000001],
        [-1.      ,  1.      ,  1.      ],
        [-1.      ,  1.      ,  1.      ],
        [-1.      , -1.      , -1.      ],
        [ 1.      , -1.      ,  1.      ],
        [-1.      ,  1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 0.999999,  1.      ,  1.000001],
        [-1. 

In [17]:
bboxes_dict = init_bounded_boxes_from_dict(primitives)

In [18]:
bounded_list = convert_dict_to_list(bboxes_dict)

In [19]:
total_nodes = [0]  # mutable integer
ordered_prims = []
root = build_bvh_tree(primitives, bounded_list, 0, len(bounded_list),
                      ordered_prims, total_nodes, split_method=0)


KeyError: 0