In [1]:
import time
import jax
import jax.numpy as jnp
from utils.io import load_obj, create_triangle_arrays

In [2]:
#load object

square_path = "objects/square.obj"
sphere_path = "objects/sphere.obj"
cube_path = "objects/cube.obj"
cylinder_path = "objects/cylinder.obj"
rabbit = "objects/rabbit.obj"
carrot = "objects/carrot.obj"
plane = "objects/plane.obj"
squirrel = "objects/squirrel.obj"
tree = "objects/broad_deciduous_tree_green_leaves.obj"
ant = "objects/ant.obj"
fireball = "objects/fireball.obj"
dragon = "objects/dragon.obj"
crawler = "objects/crawler.obj"
cornellbox = "objects/cornellbox.obj"


file_path = crawler

In [3]:
vertices, faces = load_obj(file_path)
triangles = create_triangle_arrays(vertices, faces)
triangles

{'vertex_1': Array([[-4.91896 ,  5.792907,  2.260844],
        [-4.921619,  5.789971,  2.260844],
        [-4.930886,  5.779743,  2.2332  ],
        ...,
        [-4.726382,  7.067084,  1.800162],
        [-4.736981,  7.067084,  2.323453],
        [-4.726382,  7.067084,  1.800162]], dtype=float32),
 'vertex_2': Array([[-4.941784,  5.813587,  2.260844],
        [-4.944443,  5.810652,  2.260844],
        [-4.953711,  5.800424,  2.2332  ],
        ...,
        [-4.726382,  7.067084,  2.323453],
        [-4.726382,  7.067084,  2.323453],
        [-4.736981,  7.067084,  1.800162]], dtype=float32),
 'vertex_3': Array([[-4.944443,  5.810652,  2.260844],
        [-4.953711,  5.800424,  2.2332  ],
        [-4.953711,  5.800424,  2.096386],
        ...,
        [-4.736981,  7.067084,  2.323453],
        [-4.726382,  7.23843 ,  2.323453],
        [-4.736981,  7.23843 ,  1.800162]], dtype=float32),
 'centroid': Array([[-4.9350624,  5.8057156,  2.260844 ],
        [-4.9399242,  5.8003488,  2.251629

In [4]:
from accelerators.bvh import create_primitives, create_bvh_primitives, pack_primitives

# Create primitives (the geometry) and the BVH primitives (geometry with bounds and index).
primitives = create_primitives(triangles)

bvh_primitives = create_bvh_primitives(triangles)

In [5]:
from accelerators.hlbvh import build_hlbvh
from accelerators.bvh import build_bvh

# Choose a splitting method:
#  0 = Surface Area Heuristic (SAH)
#  1 = Middle split
#  2 = Equal counts (median)
split_method = 0

# Build the BVH. The build_bvh routine expects:
#   - primitives: list of primitives for intersection testing.
#   - bvh_primitives: list of helper objects (with .bounds and .prim_num).
#   - _start, _end: integer indices into the bvh_primitives list.
#   - ordered_prims: an initially empty list to be filled with primitives in BVH order.
#   - split_method: integer specifying the split method.
start_t = time.time()
total_nodes = [0]
ordered_prims = []  # Will be filled in-order.
bvh_root = build_hlbvh(primitives, bvh_primitives, ordered_prims, total_nodes)
packed_prims = pack_primitives(ordered_prims)
end_t = time.time()
print("Elapsed (with compilation) = %s" % (end_t - start_t))

Elapsed (with compilation) = 172.81642079353333


In [6]:
from accelerators.hlbvh import flatten_bvh, pack_linear_bvh

# Flatten the BVH tree into a linear array for traversal.
# The root node is at index 0.
start_t = time.time()
linear_bvh_list = flatten_bvh(bvh_root)
linear_bvh = pack_linear_bvh(linear_bvh_list)
end_t = time.time()
print("Elapsed (with compilation) = %s" % (end_t - start_t))

Elapsed (with compilation) = 0.43255114555358887


In [7]:

print("BVH build complete.")
print("Number of BVH tree nodes:", len(bvh_root))
print("Number of linear BVH nodes:", len(linear_bvh_list))
print("Number of ordered primitives:", len(ordered_prims))

BVH build complete.


TypeError: object of type 'BVHNode' has no len()

In [None]:
# from tests.test_bvh import print_bvh_tree
#
# print_bvh_tree(nodes, 0)

In [None]:
# from tests.test_bvh import print_linear_bvh
#
# print_linear_bvh(linear_bvh_list)

In [None]:
from base.renderer import create_default_camera, render

# Create a default camera that frames the entire object.
width = 300
height = 300
fov = 45.0  # vertical field-of-view in degrees
camera = create_default_camera(triangles, width, height, fov)

start_t = time.time()

# Render the image.
image = render(linear_bvh, packed_prims, camera, batch_size=1024*1)





# Display the image using matplotlib.
import matplotlib.pyplot as plt

plt.imshow(image)
plt.title("Simple Render")
plt.axis("off")
plt.show()

end_t = time.time()
print("Elapsed (with compilation) = %s" % (end_t - start_t))

In [38]:
from primitives.triangle import intersect_triangle
from primitives.ray import Ray

# Define a simple triangle in the XY plane.
v0 = jnp.array([0.0, 0.0, 0.0])
v1 = jnp.array([1.0, 0.0, 0.0])
v2 = jnp.array([0.0, 1.0, 0.0])

# Define a ray that starts at (0.25, 0.25, -1) and points in the +Z direction.
ray_origin = jnp.array([0.25, 0.25, -1.0])
ray_direction = jnp.array([0.0, 0.0, 1.0])
ray = Ray(origin=ray_origin, direction=ray_direction)

t_max = 1e10  # some very large distance
hit, t = intersect_triangle(ray_origin, ray_direction, v0, v1, v2, t_max)

# Print the results.
print("Ray origin:", ray_origin)
print("Ray direction:", ray_direction)
print("Triangle vertices:", v0, v1, v2)
print("Intersection hit:", hit)
print("Intersection t:", t)

Ray origin: [ 0.25  0.25 -1.  ]
Ray direction: [0. 0. 1.]
Triangle vertices: [0. 0. 0.] [1. 0. 0.] [0. 1. 0.]
Intersection hit: True
Intersection t: 1.0
