In [1]:
# bvh_jax.py
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from typing import Any, List

# =============================================================================
# Constants
# =============================================================================
INF = 1e10
EPSILON = 1e-6

# =============================================================================
# AABB and Utility Functions (from your aabb.py)
# =============================================================================
@dataclass(frozen=True)
class AABB:
    min_point: jnp.ndarray  #: shape (3,)
    max_point: jnp.ndarray  #: shape (3,)
    centroid: jnp.ndarray   #: shape (3,)

def update_centroid(aabb: AABB) -> AABB:
    new_centroid = (aabb.min_point + aabb.max_point) * 0.5
    return AABB(aabb.min_point, aabb.max_point, new_centroid)

def union(aabb1: AABB, aabb2: AABB) -> AABB:
    new_min = jnp.minimum(aabb1.min_point, aabb2.min_point)
    new_max = jnp.maximum(aabb1.max_point, aabb2.max_point)
    return AABB(new_min, new_max, (new_min + new_max) * 0.5)

def union_p(aabb: AABB, point: jnp.ndarray) -> AABB:
    new_min = jnp.minimum(aabb.min_point, point)
    new_max = jnp.maximum(aabb.max_point, point)
    return AABB(new_min, new_max, (new_min + new_max) * 0.5)

def get_surface_area(aabb: AABB) -> float:
    diag = aabb.max_point - aabb.min_point
    return 2.0 * (diag[0] * diag[1] + diag[0] * diag[2] + diag[1] * diag[2])

def get_largest_dim(aabb: AABB) -> int:
    extents = aabb.max_point - aabb.min_point
    return int(jnp.argmax(extents))

# A simple (Python) ray–AABB intersection using the “slab” method.
def aabb_intersect(aabb: AABB, ray_origin: jnp.ndarray, ray_direction: jnp.ndarray) -> bool:
    inv_dir = 1.0 / ray_direction
    t1 = (aabb.min_point - ray_origin) * inv_dir
    t2 = (aabb.max_point - ray_origin) * inv_dir
    # Compute the entry and exit times along each axis.
    tmin = jnp.maximum(jnp.minimum(t1, t2)[0],
                       jnp.maximum(jnp.minimum(t1, t2)[1],
                                jnp.minimum(t1, t2)[2]))
    tmax = jnp.minimum(jnp.maximum(t1, t2)[0],
                       jnp.minimum(jnp.maximum(t1, t2)[1],
                                jnp.maximum(t1, t2)[2]))
    return (tmax >= tmin) and (tmax > 0.0)

# =============================================================================
# BVH Data Structures
# =============================================================================
# This structure pairs a triangle’s AABB with its (integer) primitive index.
@dataclass(frozen=True)
class BVHPrimitive:
    bounds: AABB
    prim_index: int

# BVHNode is a binary tree node (either a leaf or interior node)
@dataclass
class BVHNode:
    bounds: AABB
    left: Any = None   # type: BVHNode or None
    right: Any = None  # type: BVHNode or None
    leaf: bool = False
    prim_indices: jnp.ndarray = None  # 1D int32 array of triangle indices
    split_axis: int = -1

# LinearBVHNode is used when “flattening” the tree (optional)
@dataclass
class LinearBVHNode:
    bounds: AABB
    primitives_offset: int  # offset into a separate array of primitive indices
    n_primitives: int
    second_child_offset: int
    axis: int

# =============================================================================
# BVH Building (recursive, median–split)
# =============================================================================
def build_bvh(primitives: List[BVHPrimitive], max_prims_in_node: int = 4) -> BVHNode:
    """
    Recursively builds a BVH from a list of BVHPrimitive.

    Uses a median split along the largest axis of the centroid bounds.
    """
    n = len(primitives)
    if n == 0:
        # Should not occur; return a dummy node.
        dummy = AABB(jnp.array([0.0, 0.0, 0.0]),
                     jnp.array([0.0, 0.0, 0.0]),
                     jnp.array([0.0, 0.0, 0.0]))
        return BVHNode(bounds=dummy, leaf=True, prim_indices=jnp.array([], dtype=jnp.int32))

    # Compute the overall bounding box.
    bounds = primitives[0].bounds
    for prim in primitives[1:]:
        bounds = union(bounds, prim.bounds)

    if n <= max_prims_in_node:
        # Create a leaf node: record the triangle indices.
        prim_indices = jnp.array([prim.prim_index for prim in primitives], dtype=jnp.int32)
        return BVHNode(bounds=bounds, leaf=True, prim_indices=prim_indices)
    else:
        # Compute the bounding box of all centroids.
        centroid_bounds = primitives[0].bounds
        for prim in primitives[1:]:
            centroid_bounds = union_p(centroid_bounds, prim.bounds.centroid)
        axis = get_largest_dim(centroid_bounds)

        # Sort primitives along the chosen axis (using Python’s sorted, which is fine offline)
        sorted_prims = sorted(primitives, key=lambda p: float(p.bounds.centroid[axis]))
        mid = n // 2
        left_prims = sorted_prims[:mid]
        right_prims = sorted_prims[mid:]

        left_node = build_bvh(left_prims, max_prims_in_node)
        right_node = build_bvh(right_prims, max_prims_in_node)
        node_bounds = union(left_node.bounds, right_node.bounds)
        return BVHNode(bounds=node_bounds, left=left_node, right=right_node,
                       leaf=False, split_axis=axis)

# Optionally, you can “flatten” the BVH tree into a linear array (used for iterative traversal)
def flatten_bvh(node: BVHNode, linear_nodes: List[LinearBVHNode] = None, offset: int = 0) -> (List[LinearBVHNode], int):
    """
    Flatten the recursive BVH tree into a list of LinearBVHNodes.

    (Note: The primitives_offset field is not fully implemented here because
    we assume you can store the leaf triangle indices separately.)
    """
    if linear_nodes is None:
        linear_nodes = []
    current_index = offset
    linear_node = LinearBVHNode(bounds=node.bounds,
                                primitives_offset=-1,
                                n_primitives=0,
                                second_child_offset=-1,
                                axis=node.split_axis if not node.leaf else -1)
    linear_nodes.append(linear_node)
    offset += 1
    if node.leaf:
        # In a full implementation you would record the offset into an ordered list.
        linear_nodes[current_index].primitives_offset = 0  # placeholder
        linear_nodes[current_index].n_primitives = int(node.prim_indices.shape[0])
    else:
        # Flatten left child first.
        linear_nodes, offset = flatten_bvh(node.left, linear_nodes, offset)
        # Record the start of the right child.
        second_child_index = offset
        linear_nodes, offset = flatten_bvh(node.right, linear_nodes, offset)
        linear_nodes[current_index].second_child_offset = second_child_index
    return linear_nodes, offset

# =============================================================================
# Ray, Intersection, and Ray–Triangle Intersection
# =============================================================================
@dataclass(frozen=True)
class Ray:
    origin: jnp.ndarray     #: shape (3,)
    direction: jnp.ndarray  #: shape (3,)

@dataclass
class Intersection:
    t: float = INF
    prim_index: int = -1

# A simple ray–triangle intersection routine.
# (This uses a Möller–Trumbore–style test; you can replace it with your watertight method.)
@jax.jit
def ray_triangle_intersect(ray_origin: jnp.ndarray, ray_direction: jnp.ndarray,
                           v0: jnp.ndarray, v1: jnp.ndarray, v2: jnp.ndarray,
                           t_max: float = INF, epsilon: float = EPSILON) -> (bool, float):
    edge1 = v1 - v0
    edge2 = v2 - v0
    h = jnp.cross(ray_direction, edge2)
    a = jnp.dot(edge1, h)

    def no_hit(_):
        return False, t_max

    def potential_hit(_):
        f = 1.0 / a
        s = ray_origin - v0
        u = f * jnp.dot(s, h)
        def reject_u(_):
            return False, t_max
        def accept_u(_):
            q = jnp.cross(s, edge1)
            v = f * jnp.dot(ray_direction, q)
            def reject_v(_):
                return False, t_max
            def accept_v(_):
                t_candidate = f * jnp.dot(edge2, q)
                hit_cond = (t_candidate > epsilon) & (u >= 0.0) & (v >= 0.0) & ((u + v) <= 1.0)
                def hit_true(_):
                    return True, t_candidate
                def hit_false(_):
                    return False, t_max
                return jax.lax.cond(hit_cond, hit_true, hit_false, operand=None)
            return jax.lax.cond((u < 0.0) | (u > 1.0), reject_u, accept_u, operand=None)
        return jax.lax.cond(jnp.abs(a) < epsilon, no_hit, potential_hit, operand=None)
    return potential_hit(None)

# =============================================================================
# BVH Traversal for Ray Intersection
# =============================================================================
def intersect_bvh(ray: Ray, bvh_root: BVHNode, triangles: dict) -> Intersection:
    """
    Traverse the BVH recursively (in Python) and find the closest triangle hit.

    `triangles` is a dictionary containing keys "vertex_1", "vertex_2", "vertex_3"
    (each a JAX array of shape (N,3)).
    """
    closest = Intersection(t=INF, prim_index=-1)

    def traverse(node: BVHNode, ray: Ray, closest: Intersection) -> Intersection:
        if not aabb_intersect(node.bounds, ray.origin, ray.direction):
            return closest
        if node.leaf:
            # For each triangle in the leaf, test intersection.
            for idx in node.prim_indices:
                v0 = triangles["vertex_1"][idx]
                v1 = triangles["vertex_2"][idx]
                v2 = triangles["vertex_3"][idx]
                hit, t_candidate = ray_triangle_intersect(ray.origin, ray.direction, v0, v1, v2)
                if hit and t_candidate < closest.t:
                    closest = Intersection(t=t_candidate, prim_index=int(idx))
            return closest
        else:
            closest = traverse(node.left, ray, closest)
            closest = traverse(node.right, ray, closest)
            return closest

    return traverse(bvh_root, ray, closest)

# =============================================================================
# Utility: Create a BVH from Triangle Data (as loaded via your io.py)
# =============================================================================
def create_bvh_from_triangles(triangles: dict, max_prims_in_node: int = 4) -> BVHNode:
    """
    Given a dictionary of triangle arrays (each of shape (N,3)), build the BVH.
    """
    n = triangles["vertex_1"].shape[0]
    bvh_primitives = []
    for i in range(n):
        v0 = triangles["vertex_1"][i]
        v1 = triangles["vertex_2"][i]
        v2 = triangles["vertex_3"][i]
        # Compute the triangle’s AABB.
        min_point = jnp.minimum(jnp.minimum(v0, v1), v2)
        max_point = jnp.maximum(jnp.maximum(v0, v1), v2)
        centroid = (v0 + v1 + v2) / 3.0
        bounds = AABB(min_point, max_point, centroid)
        bvh_primitives.append(BVHPrimitive(bounds=bounds, prim_index=i))
    return build_bvh(bvh_primitives, max_prims_in_node)

# =============================================================================
# Example Usage
# =============================================================================

# For demonstration, we build a simple scene with two triangles.
# (In practice, you would load your triangle data from an OBJ file via your io.py.)
triangles = {
    "vertex_1": jnp.array([[0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0]]),
    "vertex_2": jnp.array([[0.0, 1.0, 0.0],
                             [1.0, 1.0, 0.0]]),
    "vertex_3": jnp.array([[0.0, 0.0, 1.0],
                             [1.0, 0.0, 1.0]])
}
# Build the BVH from the triangles.
bvh_root = create_bvh_from_triangles(triangles, max_prims_in_node=1)

# Create a ray that will hit one of the triangles.
ray = Ray(origin=jnp.array([0.5, 0.5, -1.0]),
          direction=jnp.array([0.0, 0.0, 1.0]))

# Traverse the BVH to find an intersection.
intersection = intersect_bvh(ray, bvh_root, triangles)
print("Intersection result:", intersection)


Intersection result: Intersection(t=10000000000.0, prim_index=-1)


In [2]:
import jax
import jax.numpy as jnp
from utils.io import load_obj, create_triangle_arrays

In [3]:
square_path = "objects/square.obj"
sphere_path = "objects/sphere.obj"
cube_path = "objects/cube.obj"
cylinder_path = "objects/cylinder.obj"
rabbit = "objects/rabbit.obj"
carrot = "objects/carrot.obj"
plane = "objects/plane.obj"
squirrel = "objects/squirrel.obj"
tree = "objects/broad_deciduous_tree_green_leaves.obj"
ant = "objects/ant.obj"
fireball = "objects/fireball.obj"
dragon = "objects/dragon.obj"
crawler = "objects/crawler.obj"


file_path = cube_path

In [4]:
vertices, faces = load_obj(file_path)
primitives = create_triangle_arrays(vertices, faces)
primitives

{'vertex_1': Array([[ 1.      , -1.      ,  1.      ],
        [-1.      ,  1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 0.999999,  1.      ,  1.000001],
        [-1.      , -1.      ,  1.      ],
        [ 1.      , -1.      , -1.      ],
        [ 1.      , -1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 1.      , -1.      , -1.      ],
        [ 1.      , -1.      ,  1.      ],
        [-1.      , -1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999]], dtype=float32),
 'vertex_2': Array([[-1.      , -1.      ,  1.      ],
        [-1.      ,  1.      ,  1.      ],
        [ 0.999999,  1.      ,  1.000001],
        [-1.      ,  1.      ,  1.      ],
        [-1.      ,  1.      ,  1.      ],
        [-1.      , -1.      , -1.      ],
        [ 1.      , -1.      ,  1.      ],
        [-1.      ,  1.      , -1.      ],
        [ 1.      ,  1.      , -0.999999],
        [ 0.999999,  1.      ,  1.000001],
        [-1. 

In [5]:
# Build the BVH from the triangles.
bvh_root = create_bvh_from_triangles(triangles, max_prims_in_node=1)

# Create a ray that will hit one of the triangles.
ray = Ray(origin=jnp.array([0.5, 0.5, -1.0]),
          direction=jnp.array([0.0, 0.0, 1.0]))

# Traverse the BVH to find an intersection.
intersection = intersect_bvh(ray, bvh_root, triangles)
print("Intersection result:", intersection)

Intersection result: Intersection(t=10000000000.0, prim_index=-1)


In [17]:
bboxes_dict = init_bounded_boxes_from_dict(primitives)

In [18]:
bounded_list = convert_dict_to_list(bboxes_dict)

In [19]:
total_nodes = [0]  # mutable integer
ordered_prims = []
root = build_bvh_tree(primitives, bounded_list, 0, len(bounded_list),
                      ordered_prims, total_nodes, split_method=0)


KeyError: 0