In [None]:
from nerf import utils
from jax import device_put
import jax.numpy as jnp
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from visualhull import *


# Run configuration: Blender "drums" scene; vh_test additionally renders
# test-view previews of the carved hull.
FLAGS.config = "configs/blender"
FLAGS.data_dir = "data/nerf_synthetic/drums"
FLAGS.vh_test = True

if FLAGS.config is not None:
    # Merges the YAML config into FLAGS (presumably overwriting matching fields).
    utils.update_flags(FLAGS)

# Decide the channel count only AFTER the config merge, so an alpha_bkgd /
# num_rgb_channels value coming from the config file cannot silently undo it.
# The alpha channel (4th channel) is what the silhouette masks are built from.
if FLAGS.alpha_bkgd:
    FLAGS.num_rgb_channels = 4

# Voxel grid resolution along each axis.
vsize = FLAGS.vsize

In [None]:
# Frame size of the Blender nerf_synthetic scenes. LLFF scenes would need
# IMG_H, IMG_W = 756, 1008 (and their own channel count) instead.
IMG_H, IMG_W = 800, 800


def _reshape_to_images(ds, height=IMG_H, width=IMG_W):
    """Reshape a PureDataset's flat per-ray arrays into per-image grids (in place).

    After the call, ``ds.images`` is (N, height, width, FLAGS.num_rgb_channels)
    and ray ``origins`` / ``directions`` / ``viewdirs`` are (N, height, width, 3).

    Returns:
      The same ``ds`` object, for convenient chaining.
    """
    ds.images = ds.images.reshape(-1, height, width, FLAGS.num_rgb_channels)
    # ds.rays is a namedtuple: _replace builds a new one with the reshaped fields.
    ds.rays = ds.rays._replace(
        origins=ds.rays.origins.reshape(-1, height, width, 3),
        directions=ds.rays.directions.reshape(-1, height, width, 3),
        viewdirs=ds.rays.viewdirs.reshape(-1, height, width, 3),
    )
    return ds


# Scene name = last path component; normpath tolerates a trailing slash.
target = os.path.basename(os.path.normpath(FLAGS.data_dir))
os.makedirs(os.path.join(FLAGS.voxel_dir, target), exist_ok=True)

dataset = _reshape_to_images(PureDataset("train", FLAGS))

# Always bind test_dataset: later cells pass it to visualhull(), which accepts
# None, so the notebook must not hit a NameError when vh_test is off.
test_dataset = None
if FLAGS.vh_test:
    test_dataset = _reshape_to_images(PureDataset("test", FLAGS))

In [None]:
# Sanity check of the ray sampling bounds: for every other training view, scatter
# a sparse (1-in-16 pixels) grid of points at the lower limit t_n (blue) and the
# upper limit t_f (green), looking straight down the z axis.
# NOTE(review): t_n / t_f are not defined in this notebook; presumably they come
# from `from visualhull import *`.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_limits(-5, 5)
ax.view_init(elev=90, azim=90)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')


def _points_at(origin, direction, t, stride=16):
    """Return stacked x, y, z rows of ``origin + direction * t`` on a strided pixel grid."""
    return (origin + direction * t)[::stride, ::stride].reshape(-1, 3).T


for view in range(0, len(dataset.rays.origins), 2):
    origin = dataset.rays.origins[view]
    direction = dataset.rays.directions[view]
    # Lower limit first (blue), then upper limit (green).
    for t, color in ((t_n, 'blue'), (t_f, 'green')):
        ax.scatter(*_points_at(origin, direction, t), c=color, s=0.1)

In [None]:
def _foreground_mask(img):
    """Boolean (H, W) foreground mask of one image.

    Uses the alpha channel (alpha > 0) when the image has at least 4 channels.
    Otherwise it falls back to "any channel differs from 1.0", which assumes a
    white background.
    """
    if img.shape[-1] >= 4:
        return img[Ellipsis, 3] > 0
    return np.any(img != 1., axis=-1)


def _carve_shape(dataset, dilation, thresh):
    """Accumulate carve_voxel hits over all training views and threshold them.

    Returns the thresholded 0/1 grid multiplied by get_sphere().
    """
    # Integer counter: a boolean grid cannot hold counts above 1 (and np.bool
    # was removed in NumPy 1.24), so `> thresh` could never hold for thresh >= 1.
    hits = np.zeros([vsize, vsize, vsize], dtype=np.int32)
    close_kernel = np.ones((3, 3), np.uint8)
    dilate_kernel = np.ones((dilation, dilation), np.uint8)

    for idx in tqdm(range(dataset.size)):
        o = dataset.rays.origins[idx]
        d = dataset.rays.directions[idx]
        mask = _foreground_mask(dataset.images[idx]).astype(np.uint8)
        # Without an alpha background, close small holes in the silhouette.
        if not FLAGS.alpha_bkgd:
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
        # Dilation makes appearance worse, but is recommended for voxel initialization.
        mask = cv2.dilate(mask, dilate_kernel, iterations=1)
        # Out-of-place add lets NumPy promote to whatever dtype carve_voxel returns.
        hits = hits + np.asarray(carve_voxel(o, d, mask).block_until_ready())

    return (hits > thresh).astype(jnp.uint8) * get_sphere()


def _paint_color(dataset, voxel_s):
    """Paint per-voxel RGB from the training images via paint_voxel.

    Returns the color grid masked by get_sphere(True).
    """
    import jax  # local import: only jax.numpy / device_put are imported at top

    voxel_c = device_put(jnp.ones([vsize, vsize, vsize, 3]).astype(jnp.float32))
    # paint_voxel is given (and returns) this int16 grid, initialised to vsize + 1.
    voxel_t = device_put(jnp.zeros([vsize, vsize, vsize]).astype(jnp.int16) + (vsize + 1))

    for idx in tqdm(range(dataset.size)):
        o = dataset.rays.origins[idx]
        d = dataset.rays.directions[idx]
        img = dataset.images[idx, Ellipsis, :3]
        output = paint_voxel(o, d, img, voxel_c, voxel_t, voxel_s)
        # Wait for every array in the output pytree before the next view.
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), output)
        voxel_t, voxel_c = output

    return voxel_c * get_sphere(True)


def _plot_test_preview(test_dataset, voxel_s, voxel_c, save_path, num_views=5):
    """Render the hull from the first test views and save a 6-row comparison figure.

    Rows, one column per view:
      0: render with painted colors        1: ground-truth RGB
      2: render with all-red voxels        3: ground-truth mask in red
      4: clip(ground-truth mask - red render, 0, 1)
      5: clip(red render - ground-truth mask, 0, 1)
    """
    n = min(num_views, test_dataset.size)
    if n == 0:
        return
    red = np.array([1., 0., 0.])
    voxel_c_red = (voxel_c * 0. + jnp.array([1., 0., 0.])) * get_sphere(True)

    fig, axes = plt.subplots(6, n, figsize=(40, 40), squeeze=False)
    for i in range(n):
        o = test_dataset.rays.origins[i]
        d = test_dataset.rays.directions[i]
        img = test_dataset.images[i]
        gt_mask = _foreground_mask(img).astype(np.float32)[..., None] * red
        pred_mask = np.asarray(render_voxel(voxel_s, voxel_c_red, o, d))

        axes[0, i].imshow(render_voxel(voxel_s, voxel_c, o, d))
        axes[1, i].imshow(img[Ellipsis, :3])
        axes[2, i].imshow(pred_mask)
        axes[3, i].imshow(gt_mask)
        axes[4, i].imshow(np.clip(gt_mask - pred_mask, 0, 1))
        axes[5, i].imshow(np.clip(pred_mask - gt_mask, 0, 1))

    fig.savefig(save_path)
    plt.close(fig)  # free the (large) figure; nothing is shown inline


def visualhull(dataset, test_dataset=None, target="", dilation=5, thresh=100):
    """Carve a visual hull from training silhouettes, then optionally paint and preview it.

    Args:
      dataset: training PureDataset, already reshaped to per-image grids.
      test_dataset: optional test PureDataset; used only for the preview figure.
      target: scene name; outputs go to ``<FLAGS.voxel_dir>_dil<dilation>/<target>/``.
      dilation: side length of the square kernel used to dilate each mask.
      thresh: a voxel is kept when its accumulated carve count exceeds this.

    Returns:
      (voxel_s, voxel_c): the thresholded shape grid masked by get_sphere(), and
      the painted color grid. voxel_c is None when color was not computed, i.e.
      when FLAGS.vh_save != "color" and FLAGS.vh_test is false.

    Side effects:
      Saves ``voxel.npy`` when FLAGS.vh_save is "shape" or "color". Saves
      ``voxel.png`` when a test preview is rendered.
    """
    out_dir = os.path.join(FLAGS.voxel_dir + "_dil{}".format(dilation), target)
    os.makedirs(out_dir, exist_ok=True)

    ### shape
    voxel_s = _carve_shape(dataset, dilation, thresh)
    if FLAGS.vh_save == "shape":
        print(voxel_s.dtype, voxel_s.shape)
        np.save(os.path.join(out_dir, "voxel.npy"), voxel_s)
        print("done!")

    ### color
    voxel_c = None
    if FLAGS.vh_save == "color" or FLAGS.vh_test:
        voxel_c = _paint_color(dataset, voxel_s)

        if FLAGS.vh_save == "color":
            print(voxel_c.dtype, voxel_c.shape)
            np.save(os.path.join(out_dir, "voxel.npy"), voxel_c)
            print("done!")

        # The preview needs test views; skip it rather than crash without them.
        if test_dataset is not None:
            _plot_test_preview(test_dataset, voxel_s, voxel_c,
                               os.path.join(out_dir, "voxel.png"))
            print("test done!")

    return voxel_s, voxel_c

In [None]:
# Carve with a 5x5 mask-dilation kernel; a voxel is kept when its accumulated
# carve count exceeds thresh=100. Outputs go under `<voxel_dir>_dil5/<target>/`.
# NOTE(review): this unpacking requires visualhull to return (voxel_s, voxel_c).
voxel_s, voxel_c = visualhull(dataset, test_dataset, target, dilation=5, thresh=100)

In [None]:
# Same carving with a much larger 27x27 dilation kernel (looser hull); results
# go to a separate `<voxel_dir>_dil27/<target>/` directory and overwrite
# voxel_s / voxel_c from the previous cell.
# NOTE(review): this unpacking requires visualhull to return (voxel_s, voxel_c).
voxel_s, voxel_c = visualhull(dataset, test_dataset, target, dilation=27, thresh=100)