In [1]:
import torch
import os
import sys
import importlib.util
import tqdm

from torch import Tensor
import nerfacc
import imageio
import numpy as np

from nerfacc.estimators.prop_net import (
    PropNetEstimator,
    get_proposal_requires_grad_fn,
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Configuration -----------------------------------------------------------
# Experiment directory (checkpoints + model code) and dataset root. Both can be
# overridden through environment variables so the notebook is not tied to one
# machine's absolute paths; the defaults are the original hard-coded values.
exp_dir = os.environ.get("NERF_EXP_DIR", "output/test9")
root_fp = os.environ.get(
    "NERF_ROOT_FP",
    "/media/ccl/Data/Datasets/NeRF/aizu-student-hall/output/processed",
)
test_chunk_size = 8192

# Import the field definitions from the ngp_appearance.py stored in exp_dir
# (presumably the snapshot used for training, so it matches the checkpoints
# loaded below -- TODO confirm) rather than from the current repo code.
spec = importlib.util.spec_from_file_location(
    "ngp_appearance", os.path.join(exp_dir, "ngp_appearance.py")
)
if spec is None or spec.loader is None:
    raise ImportError(f"cannot load ngp_appearance.py from {exp_dir!r}")
ngp_appearance = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ngp_appearance)
NGPDensityField = ngp_appearance.NGPDensityField
NGPRadianceField = ngp_appearance.NGPRadianceField

# spec = importlib.util.spec_from_file_location('nerf_colmap', f'{exp_dir}/nerf_colmap.py')
# nerf_colmap = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(nerf_colmap)
# SubjectLoader = nerf_colmap.SubjectLoader
from datasets.nerf_colmap import SubjectLoader

device = "cuda:0"

# Scene parameters.
unbounded = True
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2  # TODO: Try 0.02
far_plane = 1e3

# Dataset parameters.
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 2}
test_dataset_kwargs = {"factor": 4}

# Model parameters.
# One proposal density field per entry, in this order; the checkpoint loader
# matches proposal_network_{i}.pth to proposal_networks[i] by position.
proposal_max_resolutions = (128, 256)
proposal_networks = [
    NGPDensityField(
        aabb=aabb,
        unbounded=unbounded,
        n_levels=5,
        max_resolution=max_res,
    ).to(device)
    for max_res in proposal_max_resolutions
]

estimator = PropNetEstimator().to(device)
# Alternative radiance-field sizes; the active one must match the architecture
# the checkpoint was trained with, or load_state_dict will reject it.
# radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded, max_resolution=4096*2, n_levels=16, log2_hashmap_size=17).to(device)
# radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded, max_resolution=4096*4, n_levels=18, log2_hashmap_size=19).to(device)
radiance_field = NGPRadianceField(
    aabb=aabb,
    unbounded=unbounded,
    max_resolution=4096 * 8,
    n_levels=19,
    log2_hashmap_size=20,
).to(device)

In [3]:
# --- Restore trained weights ---------------------------------------------------
# map_location remaps every stored tensor onto `device`, so checkpoints saved
# from a different GPU (or from CPU) still load correctly.
radiance_field.load_state_dict(
    torch.load(os.path.join(exp_dir, "radiance_field.pth"), map_location=device)
)
# Estimator weights are currently not restored:
# estimator.load_state_dict(torch.load(os.path.join(exp_dir, 'estimator.pth')))

# proposal_network_{i}.pth is matched to proposal_networks[i] by position.
for i, net in enumerate(proposal_networks):
    state_dict = torch.load(
        os.path.join(exp_dir, f"proposal_network_{i}.pth"), map_location=device
    )
    print(state_dict.keys())
    net.load_state_dict(state_dict)

# Inference only: put every module in eval mode.
radiance_field.eval()
for net in proposal_networks:
    net.eval()
estimator.eval()


odict_keys(['aabb', 'mlp_base.params'])
odict_keys(['aabb', 'mlp_base.params'])


PropNetEstimator()

In [4]:

from datasets.utils import Rays
from utils import (
    render_image_with_propnet,
)
from datasets.nerf_colmap import _load_colmap, similarity_from_cameras

import torch.nn.functional as F


# --- Render parameters -----------------------------------------------------------
num_samples = 64
num_samples_per_prop = [256, 96]
sampling_type = "lindisp"
opaque_bkgd = True
# Use the same downsampling factor as the test SubjectLoader instead of a
# second copy of the constant, so the two cannot silently drift apart.
factor = test_dataset_kwargs["factor"]

# --- Load COLMAP images / poses / intrinsics directly ---------------------------
images, camtoworlds, K, split_indices = _load_colmap(root_fp, 0, factor)

# Normalize the scene: left-multiply every camera pose by the similarity
# transform T (the einsum computes T @ c2w per camera), then rescale the
# camera positions by sscale.
T, sscale = similarity_from_cameras(camtoworlds, strict_scaling=True)
camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, T)
camtoworlds[:, :3, 3] *= sscale

images = torch.from_numpy(images).to(device=device, dtype=torch.uint8)
camtoworlds = torch.from_numpy(camtoworlds).to(device=device, dtype=torch.float32)
K = torch.tensor(K).to(device=device, dtype=torch.float32)

1
loading images


100%|██████████| 616/616 [00:01<00:00, 385.82it/s]


In [6]:


# --- Render the test split ---------------------------------------------------
# For every test view, write four PNGs to <exp_dir>/renders: the rendered RGB,
# the ground truth, the per-pixel RGB error, and a normalized log-depth map.


def _to_uint8(x: Tensor) -> np.ndarray:
    """Clamp a float tensor to [0, 1] and convert it to a uint8 numpy image.

    Clamping prevents uint8 wrap-around for out-of-range values, e.g. the L2
    norm of an RGB difference, which can reach sqrt(3).
    """
    return (x.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)


def _normalized_log_depth(depth: Tensor, eps: float = 1e-6) -> Tensor:
    """Map a depth tensor to [0, 1] via log followed by min-max normalization.

    Depth is clamped to `eps` before the log so zero-depth pixels do not become
    -inf, and the range is clamped to `eps` so a constant map does not divide by
    zero. Either case would otherwise produce NaNs in the saved image.
    """
    log_depth = torch.log(depth.clamp_min(eps))
    log_depth = log_depth - log_depth.min()
    return log_depth / log_depth.max().clamp_min(eps)


# Output directory is loop-invariant: create it once.
renders_dir = os.path.join(exp_dir, "renders")
os.makedirs(renders_dir, exist_ok=True)

test_dataset = SubjectLoader(
    subject_id=0,
    root_fp=root_fp,
    split="test",
    num_rays=None,
    device=device,
    **test_dataset_kwargs,
)
with torch.no_grad():
    for i in tqdm.tqdm(range(len(test_dataset))):
        data = test_dataset[i]
        render_bkgd = data["color_bkgd"]
        rays = data["rays"]
        pixels = data["pixels"]
        img = data["img"]

        # Rendering; the accumulated opacity and the last output are unused here.
        rgb, _acc, depth, _ = render_image_with_propnet(
            radiance_field,
            proposal_networks,
            estimator,
            rays,
            # rendering options
            num_samples=num_samples,
            num_samples_per_prop=num_samples_per_prop,
            near_plane=near_plane,
            far_plane=far_plane,
            sampling_type=sampling_type,
            opaque_bkgd=opaque_bkgd,
            render_bkgd=render_bkgd,
            # test options
            test_chunk_size=test_chunk_size,
            img=img,
        )

        imageio.imwrite(
            os.path.join(renders_dir, f"rgb_{i:08}_render.png"),
            _to_uint8(rgb),
        )
        imageio.imwrite(
            os.path.join(renders_dir, f"rgb_{i:08}_ground_truth.png"),
            _to_uint8(pixels),
        )
        imageio.imwrite(
            os.path.join(renders_dir, f"rgb_{i:08}_error.png"),
            _to_uint8((rgb - pixels).norm(dim=-1)),
        )
        imageio.imwrite(
            os.path.join(renders_dir, f"rgb_{i:08}_depth.png"),
            _to_uint8(_normalized_log_depth(depth)),
        )

1
loading images


100%|██████████| 616/616 [00:01<00:00, 411.02it/s]
 75%|███████▌  | 58/77 [00:34<00:11,  1.67it/s]


KeyboardInterrupt: 

In [None]:
# for index in range(50):
#     image_id = [index]
#     height, width = images.shape[1:3]

#     
#     color_bkgd = torch.ones(3, device=images.device)

#     x, y = torch.meshgrid(
#         torch.arange(width, device=images.device),
#         torch.arange(height, device=images.device),
#         indexing="xy",
#     )
#     x = x.flatten()
#     y = y.flatten()

#     c2w = camtoworlds[image_id]  # (num_rays, 3, 4)
#     camera_dirs = F.pad(
#         torch.stack(
#             [
#                 (x - K[0, 2] + 0.5) / K[0, 0],
#                 (y - K[1, 2] + 0.5)
#                 / K[1, 1]
#                 * (1.0),
#             ],
#             dim=-1,
#         ),
#         (0, 1),
#         value=(1.0),
#     ) 
#     directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
#     origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
#     viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
#     origins = torch.reshape(origins, (height, width, 3))
#     viewdirs = torch.reshape(viewdirs, (height, width, 3))


#     rays = Rays(origins=origins, viewdirs=viewdirs)

#     img = torch.transpose(images[index], 0, 2)
#     img = torch.transpose(img, 1, 2)
#     img = (img / 255.0).unsqueeze(0).cuda()

#     with torch.no_grad():
#         rgb, acc, depth, _, = render_image_with_propnet(
#             radiance_field,
#             proposal_networks,
#             estimator,
#             rays,
#             # rendering options
#             num_samples=num_samples,
#             num_samples_per_prop=num_samples_per_prop,
#             near_plane=near_plane,
#             far_plane=far_plane,
#             sampling_type=sampling_type,
#             opaque_bkgd=opaque_bkgd,
#             render_bkgd=color_bkgd,
#             # test options
#             test_chunk_size=test_chunk_size,
#             img=img,
#         )
#         renders_dir = os.path.join(exp_dir, "renders")
#         os.makedirs(renders_dir, exist_ok=True)
#         imageio.imwrite(
#             os.path.join(renders_dir, f"rgb_{index:08}_render.png"),
#             (rgb.cpu().numpy() * 255).astype(np.uint8),
#         )