In [1]:
from pathlib import Path
import json
import os

import numpy as np
import open3d as o3d
import cv2
from einops import repeat, rearrange
import torch

from viz_utils import to_pcd, to_lines, to_mesh, viz_frustum, spherical_viz, create_octant_planes
from camera import fovx2intrinsic, get_spiral_poses, get_rays
from graphics import (
    generate_coarse_samples, generate_fine_samples, 
    volume_render, hierarchical_volume_render, 
    rays2image
)


__author__ = "__Girish_Hegde__"

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [11]:
bunny_file = './data/bunny_voxels.npy'

if Path(bunny_file).is_file():
    voxels = np.load(bunny_file)
else:
    print('This may take 2 to 5 minutes ...')
    Path(bunny_file).parent.mkdir(exist_ok=True, parents=True)
    bunny = o3d.data.BunnyMesh()
    mesh = o3d.io.read_triangle_mesh(bunny.path)
    mesh.compute_vertex_normals()
    mesh.scale(1 / np.max(mesh.get_max_bound() - mesh.get_min_bound()), center=mesh.get_center())
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=0.02)
    voxels = voxel_grid.get_voxels()
    voxels = np.array([vx.grid_index for vx in voxels]) - 1
    np.save(bunny_file, voxels)

voxels = voxels[:, [0, 2, 1]]
# voxels bounds
xmax, ymax, zmax = voxels.max(0) + 1
tmax = max(xmax, ymax, zmax)
centroid = np.mean(voxels, 0)

# get boolen grid from voxel positions
grid = np.zeros((tmax, tmax, tmax), dtype=bool)
grid[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = True

# visualization
points = o3d.utility.Vector3dVector(voxels.astype(np.float64))
voxel_pcd = o3d.geometry.PointCloud(points)
voxel_pcd.paint_uniform_color([1., 1, 0])
octants = create_octant_planes(tmax)
o3d.visualization.draw_geometries([voxel_pcd, octants], 'voxels', mesh_show_back_face=True)

print(f'{voxels.shape = }, {grid.sum() = }, {(xmax, ymax, zmax) = }, {centroid = }')


voxels.shape = (8527, 3), grid.sum() = 8527, (xmax, ymax, zmax) = (50, 39, 50), centroid = array([21.40096165, 22.12560103, 19.39286971])


In [13]:
# ==========================================
# get camera intrinsics and spherical poses
# ==========================================
h, w = 100, 200
nviews, radius, vertical_offset = 20, 75, 10
up = torch.tensor([0, 1, 0.])

K = fovx2intrinsic(0.75, h, w)
view_mats = get_spiral_poses(
    torch.FloatTensor(centroid), nviews, radius, 
    (vertical_offset, vertical_offset), 1
)
sperical_scene = spherical_viz(view_mats, centroid, [octants, voxel_pcd]);

In [34]:
# ===============
# get grid rays
# ===============
cam_id = 3

origins, directions = get_rays(h, w, K, view_mats[cam_id])
grid_pts = origins + directions*100
frustum_viz = viz_frustum(grid_pts, origins, face_clr=(0, 1, 1), line_clr=(0, 0, 0))
spherical_viz(view_mats, centroid, sperical_scene + list(frustum_viz));

# flines = to_lines(
#     np.vstack([grid_pts, starts]),
#     rearrange([np.arange(n), n + np.arange(n)], 't n -> n t'),
#     [0, 0, 0],
# )

In [32]:
# =======================================================
# function get points on ray and voxel grid intersection
# =======================================================

def intersect_rays(origins, directions, samples, grid):
    densities = torch.zeros_like(samples)
    positions = origins[:, None, :] + directions[:, None, :]*samples[:, :, None]
    positions = torch.round(positions).long()
    xlim, ylim, zlim = grid.shape
    valid = (
        (
              (positions[..., 0] < xlim) 
            & (positions[..., 1] < ylim) 
            & (positions[..., 2] < zlim)
        ) 
        & (positions > 0).all(-1)
    )  # [b, n, 3]
    query = positions[valid]  # [k, 3]
    intersections = grid[query[:, 0], query[:, 1], query[:, 2]]
    xpts = query[intersections]
    # xpcd = to_pcd(xpts, [0, 1, 1], viz=True, name='intersections')

    temp = torch.zeros(len(intersections))
    temp[intersections] = 1
    densities[valid] = temp
    colors = torch.zeros((*densities.shape, 3))
    set_r = torch.zeros((len(intersections), 3))
    set_r[:, 0] = intersections
    colors[valid] = set_r

    return densities, colors, xpts

In [35]:
# ===============================================
# get coarse samples and coarse volume rendering
# ===============================================
Nc = 200  # No. of coarse samples
Nf = 32  # No. of fine samples
min_depth, max_depth = 0, 100

coarse_samples, coarse_distances, bin_starts, bin_size = generate_coarse_samples(n, Nc, min_depth, max_depth)
print(f'{coarse_samples.shape = }')


densities_c, colors_c, xpts = intersect_rays(origins, directions, coarse_samples, torch.BoolTensor(grid))
args = list(map(torch.FloatTensor, (coarse_samples, coarse_distances, densities_c, colors_c))) + [max_depth]
coarse_color, pdf = volume_render(*args)
print(f'{coarse_color.shape=}')


coarse_samples.shape = torch.Size([20000, 200])


: 

: 

In [None]:
# ===================================================
# get fine samples and hierarchical volume rendering
# ===================================================
fine_samples = generate_fine_samples(n, Nf, bin_size, bin_starts, pdf + 1e-6)
densities_f, colors_f, xptsf = intersect_rays(origins[cam_id], directions[cam_id], fine_samples, torch.BoolTensor(grid))

args = list(map(torch.FloatTensor, (
        coarse_samples, densities_c, colors_c,
        fine_samples, densities_f, colors_f
    ))
) + [max_depth]

ray_color, ray_pdf, (samples, distances, densities, colors) = hierarchical_volume_render(*args)
print(f'{ray_color.shape=}')

In [None]:
# ===============================
# volume rendering visualization
# ===============================
rays2image(coarse_color, h, w, stride, scale=4, bgr=True, show=True, filename=None);
rays2image(ray_color, h, w, stride, scale=4, bgr=True, show=True, filename=None);

In [None]:
# ===============================================
# intersections and fine sampling visualizations
# ===============================================
xpcd = to_pcd(xpts, [0, 1, 1], viz=False)
n = origins.shape[1]
line_clrs = np.zeros((n, 3))
# line_clrs[ray_id] = [1, 0, 0]
grid_pts, starts = origins[cam_id] + directions[cam_id]*100, origins[cam_id]
flines = to_lines(
    np.vstack([grid_pts, starts]),
    rearrange([np.arange(n), n + np.arange(n)], 't n -> n t'),
    line_clrs,
)

positions = origins[cam_id][:, None, :] + directions[cam_id][:, None, :]*fine_samples.numpy()[:, :, None]
finepcd = to_pcd(positions.reshape(-1, 3), [1, 0.3, 0], viz=True, name='fine samples')
cam_ray_vizs = [flines, ]
spherical_viz(centroid, eyes, k, j, i, sperical_scene + cam_ray_vizs + [xpcd, ]);

In [None]:
# ============================
# multi-view volume rendering 
# ============================

Nc = 100  # No. of coarse samples
Nf = 100  # No. of fine samples
min_depth, max_depth = 0, 100

outdir = Path(bunny_file).parent/'bunny_renderings/spherical_views'    
for cam_id in range(nviews):
    print(f'rendering {cam_id + 1}/{nviews} view ...')
    coarse_samples, coarse_distances, bin_starts, bin_size = generate_coarse_samples(n, Nc, min_depth, max_depth)

    densities_c, colors_c, xpts = intersect_rays(origins[cam_id], directions[cam_id], coarse_samples, torch.BoolTensor(grid))
    args = list(map(torch.FloatTensor, (coarse_samples, coarse_distances, densities_c, colors_c))) + [max_depth]
    coarse_color, pdf = volume_render(*args)

    fine_samples = generate_fine_samples(n, Nf, bin_size, bin_starts, pdf + 1e-6)
    densities_f, colors_f, xptsf = intersect_rays(origins[cam_id], directions[cam_id], fine_samples, torch.BoolTensor(grid))

    args = list(map(torch.FloatTensor, (
            coarse_samples, densities_c, colors_c,
            fine_samples, densities_f, colors_f
        ))
    ) + [max_depth]

    ray_color, ray_pdf, (samples, distances, densities, colors) = hierarchical_volume_render(*args)
    ray_color = ray_color.numpy()

    rays2image(coarse_color, h, w, stride, scale=4, bgr=True, show=False, filename=str(outdir/f'coarse_{cam_id}.png'))
    rays2image(ray_color, h, w, stride, scale=4, bgr=True, show=False, filename=str(outdir/f'{cam_id}.png'))