In [None]:
from pathlib import Path
import json
import os

import numpy as np
import open3d as o3d
import cv2
from einops import repeat, rearrange
import torch

from viz_utils import to_pcd, to_lines, to_mesh, spherical_viz, create_octant_planes
from camera import fovx2intrinsic, get_spherical_poses, generate_grid_rays
from graphics import generate_coarse_samples, generate_fine_samples, volume_render, hierarchical_volume_render


__author__ = "__Girish_Hegde__"


In [None]:
bunny_file = './data/bunny_voxels.npy'
bunny_path = Path(bunny_file)

# Voxelising the mesh is slow, so the integer voxel indices are cached as .npy
# and re-used on later runs.
if not bunny_path.is_file():
    print('This may take 2 to 5 minutes ...')
    bunny_path.parent.mkdir(parents=True, exist_ok=True)
    bunny = o3d.data.BunnyMesh()
    mesh = o3d.io.read_triangle_mesh(bunny.path)
    mesh.compute_vertex_normals()
    # rescale about the mesh centre so the longest bounding-box side becomes 1
    extent = mesh.get_max_bound() - mesh.get_min_bound()
    mesh.scale(1 / np.max(extent), center=mesh.get_center())
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=0.02)
    # NOTE(review): indices are shifted by -1; confirm none become negative,
    # since a -1 would silently wrap around when used to index `grid` below.
    voxels = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) - 1
    np.save(bunny_file, voxels)
else:
    voxels = np.load(bunny_file)

# exclusive per-axis upper bounds of the voxel indices; tmax is the edge length
# of a cube large enough to hold all three axes
xmax, ymax, zmax = voxels.max(axis=0) + 1
tmax = max(xmax, ymax, zmax)
centroid = voxels.mean(axis=0)

# dense boolean occupancy grid: True at every occupied voxel index
grid = np.zeros((tmax,) * 3, dtype=bool)
grid[tuple(voxels.T)] = True

# visualisation: voxel indices drawn as a yellow point cloud
voxel_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(voxels.astype(np.float64)))
voxel_pcd.paint_uniform_color([1., 1, 0])
octants = create_octant_planes(tmax)
o3d.visualization.draw_geometries([voxel_pcd], 'voxels')

print(f'{voxels.shape = }, {grid.sum() = }, {(xmax, ymax, zmax) = }, {centroid = }')


In [None]:
# ==========================================
# get camera intrinsics and spherical poses
# ==========================================
w, h = 100, 100  # rendered image width / height in pixels
nviews, radius, vertical_offset = 20, 75, 10  # camera-placement parameters forwarded to get_spherical_poses
up = np.array([0, 1, 0])
# horizontal field of view handed to fovx2intrinsic
# NOTE(review): presumably in radians (0.75 rad ~ 43 deg) - confirm in camera.fovx2intrinsic
fovx = 0.75

K = fovx2intrinsic(fovx, w, h)
eyes, fronts, ups, rights, i, j, k = get_spherical_poses(centroid, nviews, radius, vertical_offset, up)
sperical_scene = spherical_viz(centroid, eyes, fronts, ups, rights, [octants, voxel_pcd])
# second view: same cameras, but with (k, j, i) drawn in the (fronts, ups, rights) slots
spherical_viz(centroid, eyes, k, j, i, sperical_scene);

In [None]:
# ===============
# get grid rays
# ===============
stride = 2  # pixel stride passed to generate_grid_rays (also used to place ray colours back into images)

width, height = w, h
origins, directions, grid_points = generate_grid_rays(K, eyes, i, j, k, width, height, stride)

cam_id = 0
n = grid_points.shape[1]  # rays per camera
# line i joins point i (ray end) to point n + i (ray start) in the stacked vertex array
line_ids = rearrange([np.arange(n), n + np.arange(n)], 't n -> n t')
ends, starts = origins[cam_id] + directions[cam_id]*100, origins[cam_id]
flines = to_lines(np.vstack([ends, starts]), line_ids, [0, 0, 0])
cam_ray_vizs = [flines, ]
spherical_viz(centroid, eyes, k, j, i, sperical_scene + cam_ray_vizs);
print(f'{grid_points.shape = }')

# Set to True to also draw the rays of every camera, from each camera origin
# to its grid points (cluttered, so off by default).
show_all_camera_rays = False
if show_all_camera_rays:
    ray_vizs = [
        to_lines(np.vstack([cam_ends, cam_starts]), line_ids, [0, 0, 0])
        for cam_ends, cam_starts in zip(grid_points, origins)
    ]
    spherical_viz(centroid, eyes, k, j, i, sperical_scene + ray_vizs)

In [None]:
# =======================================================
# function get points on ray and voxel grid intersection
# =======================================================

def intersect_rays(origins, directions, samples, grid):
    """Look up a boolean voxel grid at sample points along a batch of rays.

    Each sample point ``origins + directions * t`` is rounded to the nearest
    integer grid index. Samples that land inside ``grid`` on an occupied voxel
    get density 1 and colour channel 0 set to 1; all other samples stay 0.

    Args:
        origins: (b, 3) ray origins, in grid-index coordinates.
        directions: (b, 3) ray directions.
        samples: (b, s) depths ``t`` along each ray.
        grid: 3-D boolean occupancy grid.

    Returns:
        densities: (b, s) array, 1 where the sample hits an occupied voxel.
        colors: (b, s, 3) array, channel 0 is 1 at hits, everything else 0.
        xpts: (m, 3) integer grid indices of every hit sample.
    """
    densities = np.zeros_like(samples)
    colors = np.zeros((*densities.shape, 3))

    positions = origins[:, None, :] + directions[:, None, :]*samples[:, :, None]  # (b, s, 3)
    positions = np.round(positions).astype(int)

    # Index 0 is a valid grid cell, so the lower bound must be inclusive;
    # negative indices are rejected here rather than wrapping around in numpy.
    valid = ((positions >= 0) & (positions < np.array(grid.shape))).all(-1)  # (b, s)
    query = positions[valid]  # (k, 3)
    intersections = grid[query[:, 0], query[:, 1], query[:, 2]]  # (k,) bool
    xpts = query[intersections]

    # bool -> 1/0 on assignment into the numeric arrays
    densities[valid] = intersections
    colors[valid, 0] = intersections

    return densities, colors, xpts

In [None]:
# ===============================================
# get coarse samples and coarse volume rendering
# ===============================================
Nc = 200  # No. of coarse samples
Nf = 32  # No. of fine samples (consumed by the hierarchical pass)
min_depth, max_depth = 0, 100
cam_id = 3

coarse_samples, coarse_distances, bin_starts, bin_size = generate_coarse_samples(n, Nc, min_depth, max_depth)
coarse_samples = coarse_samples.numpy()
print(f'{coarse_samples.shape = }')

# occupancy lookup of every coarse sample on the chosen camera's rays
densities_c, colors_c, xpts = intersect_rays(origins[cam_id], directions[cam_id], coarse_samples, grid)
render_inputs = [torch.FloatTensor(arr) for arr in (coarse_samples, coarse_distances, densities_c, colors_c)]
coarse_color, pdf = volume_render(*render_inputs, max_depth)
coarse_color = coarse_color.numpy()
print(f'{coarse_color.shape=}')


In [None]:
# ===================================================
# get fine samples and hierarchical volume rendering
# ===================================================
# the small epsilon keeps every pdf bin strictly positive before fine sampling
fine_samples = generate_fine_samples(n, Nf, bin_size, bin_starts, pdf + 1e-6)
densities_f, colors_f, xptsf = intersect_rays(
    origins[cam_id], directions[cam_id], fine_samples.numpy(), grid,
)

hier_inputs = [
    torch.FloatTensor(arr)
    for arr in (coarse_samples, densities_c, colors_c, fine_samples, densities_f, colors_f)
]
ray_color, ray_pdf, (samples, distances, densities, colors) = hierarchical_volume_render(*hier_inputs, max_depth)
ray_color = ray_color.numpy()
print(f'{ray_color.shape=}')

In [None]:
# ===============================
# volume rendering visualization
# ===============================
outdir = Path(bunny_file).parent/'bunny_renderings'
outdir.mkdir(exist_ok=True, parents=True)

for window_name, per_ray_color, filename in (
    ('coarse volume rendering', coarse_color, 'coarse.png'),
    ('img', ray_color, 'hierarchical.png'),
):
    img = np.zeros((h, w, 3))
    # per-ray colours -> image grid; flip rows and reverse channel order (OpenCV expects BGR)
    rendering = rearrange(per_ray_color, '(w h) c -> h w c', w=width//stride)[::-1, :, ::-1]
    img[::stride, ::stride] = rendering
    # clip before the uint8 cast so out-of-range values saturate instead of wrapping
    img = (np.clip(img, 0, 1)*255).astype(np.uint8)
    # cv2.resize takes dsize as (width, height)
    cv2.imshow(window_name, cv2.resize(img, (w*4, h*4), interpolation=cv2.INTER_NEAREST))
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite(str(outdir/filename), img)

In [None]:
# ===============================================
# intersections and fine sampling visualizations
# ===============================================
highlight_ray_id = None  # set to a ray index (0 <= id < n) to draw that ray in red

# coarse-sample hits on camera `cam_id`'s rays
xpcd = to_pcd(xpts, [0, 1, 1], viz=False)

n = grid_points.shape[1]
line_clrs = np.zeros((n, 3))
if highlight_ray_id is not None:
    line_clrs[highlight_ray_id] = [1, 0, 0]
# extend each ray by `max_depth` along its direction so it spans the sampling range
ends, starts = origins[cam_id] + directions[cam_id]*max_depth, origins[cam_id]
flines = to_lines(
    np.vstack([ends, starts]),
    rearrange([np.arange(n), n + np.arange(n)], 't n -> n t'),
    line_clrs,
)

# 3-D positions of the fine samples (to_pcd with viz=True presumably opens its own window)
positions = origins[cam_id][:, None, :] + directions[cam_id][:, None, :]*fine_samples.numpy()[:, :, None]
finepcd = to_pcd(positions.reshape(-1, 3), [1, 0.3, 0], viz=True, name='fine samples')
cam_ray_vizs = [flines, ]
spherical_viz(centroid, eyes, k, j, i, sperical_scene + cam_ray_vizs + [xpcd, ]);

In [None]:
# ===================================================================
# multi-view volume rendering: coarse + hierarchical pass per camera
# ===================================================================
Nc = 100  # No. of coarse samples
Nf = 100  # No. of fine samples
min_depth, max_depth = 0, 100

outdir = Path(bunny_file).parent/'bunny_renderings/spherical_views'
outdir.mkdir(exist_ok=True, parents=True)

for cam_id in range(nviews):
    print(f'rendering {cam_id + 1}/{nviews} view ...')
    cam_origins, cam_directions = origins[cam_id], directions[cam_id]

    # coarse pass
    coarse_samples, coarse_distances, bin_starts, bin_size = generate_coarse_samples(n, Nc, min_depth, max_depth)
    coarse_samples = coarse_samples.numpy()
    densities_c, colors_c, xpts = intersect_rays(cam_origins, cam_directions, coarse_samples, grid)
    args = list(map(torch.FloatTensor, (coarse_samples, coarse_distances, densities_c, colors_c))) + [max_depth]
    coarse_color, pdf = volume_render(*args)
    coarse_color = coarse_color.numpy()

    # fine pass: epsilon keeps every pdf bin strictly positive
    fine_samples = generate_fine_samples(n, Nf, bin_size, bin_starts, pdf + 1e-6)
    densities_f, colors_f, xptsf = intersect_rays(cam_origins, cam_directions, fine_samples.numpy(), grid)
    args = list(map(torch.FloatTensor, (
            coarse_samples, densities_c, colors_c,
            fine_samples, densities_f, colors_f
        ))
    ) + [max_depth]
    ray_color, ray_pdf, (samples, distances, densities, colors) = hierarchical_volume_render(*args)
    ray_color = ray_color.numpy()

    # save coarse_<id>.png and <id>.png (hierarchical)
    for prefix, per_ray_color in (('coarse_', coarse_color), ('', ray_color)):
        img = np.zeros((h, w, 3))
        # per-ray colours -> image grid; flip rows and reverse channel order (OpenCV expects BGR)
        rendering = rearrange(per_ray_color, '(w h) c -> h w c', w=width//stride)[::-1, :, ::-1]
        img[::stride, ::stride] = rendering
        # clip before the uint8 cast so out-of-range values saturate instead of wrapping
        img = (np.clip(img, 0, 1)*255).astype(np.uint8)
        cv2.imwrite(str(outdir/f'{prefix}{cam_id}.png'), img)