In [1]:
import torch
import os
import sys
import importlib.util
import tqdm

from torch import Tensor
import nerfacc
import imageio
import numpy as np
try:
    import pytorch3d
    from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_matrix
except ModuleNotFoundError:
    pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
    version_str="".join([
        f"py3{sys.version_info.minor}_cu",
        torch.version.cuda.replace(".",""),
        f"_pyt{pyt_version_str}"
    ])
    !pip install fvcore iopath
    !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    

from nerfacc.estimators.prop_net import (
    PropNetEstimator,
    get_proposal_requires_grad_fn,
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment and dataset locations. Either can be overridden through an
# environment variable; the defaults are the paths this notebook was run with.
exp_dir = os.environ.get("NERF_EXP_DIR", "output/exp3")
root_fp = os.environ.get(
    "NERF_DATA_ROOT", "/home/ccl/Datasets/NeRF/aizu-student-hall/output/processed"
)
test_chunk_size = 8192

# Import the field classes from the copy of ngp_appearance.py stored in
# exp_dir (not from an installed package).
ngp_module_path = os.path.join(exp_dir, "ngp_appearance.py")
if not os.path.isfile(ngp_module_path):
    raise FileNotFoundError(f"Model definition not found: {ngp_module_path}")
spec = importlib.util.spec_from_file_location("ngp_appearance", ngp_module_path)
ngp_appearance = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ngp_appearance)
NGPDensityField = ngp_appearance.NGPDensityField
NGPRadianceField = ngp_appearance.NGPRadianceField

from datasets.nerf_colmap import SubjectLoader

device = "cuda:0"

# Scene parameters.
unbounded = True
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.08  # TODO: Try 0.02
far_plane = 1e3

# Dataset parameters.
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 2}
test_dataset_kwargs = {"factor": 4}

# Model parameters: two proposal density fields that differ only in max resolution.
proposal_networks = [
    NGPDensityField(
        aabb=aabb,
        unbounded=unbounded,
        n_levels=5,
        max_resolution=max_resolution,
    ).to(device)
    for max_resolution in (128, 256)
]

estimator = PropNetEstimator().to(device)

# The radiance-field size must match radiance_field.pth, which is restored with
# load_state_dict in the next cell. Other configurations kept for reference:
#   max_resolution=4096*2, n_levels=16, log2_hashmap_size=17
#   max_resolution=4096*4, n_levels=18, log2_hashmap_size=19
#   max_resolution=4096*8, n_levels=19, log2_hashmap_size=20
radiance_field = NGPRadianceField(
    aabb=aabb,
    unbounded=unbounded,
    max_resolution=4096 * 16,
    n_levels=20,
    log2_hashmap_size=21,
).to(device)

In [3]:
# map_location loads every tensor straight onto `device`, regardless of the
# device the checkpoint was saved from.
radiance_field.load_state_dict(
    torch.load(os.path.join(exp_dir, "radiance_field.pth"), map_location=device)
)
# The estimator checkpoint is intentionally not restored here.
# estimator.load_state_dict(torch.load(os.path.join(exp_dir, 'estimator.pth')))

for i, net in enumerate(proposal_networks):
    state_dict = torch.load(
        os.path.join(exp_dir, f"proposal_network_{i}.pth"), map_location=device
    )
    print(state_dict.keys())
    net.load_state_dict(state_dict)

# Inference only: put every module in eval mode.
for module in (radiance_field, *proposal_networks):
    module.eval()
estimator.eval()


odict_keys(['aabb', 'mlp_base.params'])
odict_keys(['aabb', 'mlp_base.params'])


PropNetEstimator()

In [4]:

from datasets.utils import Rays
from utils import (
    render_image_with_propnet,
)
from datasets.nerf_colmap import _load_colmap, similarity_from_cameras

import torch.nn.functional as F


# Rendering hyper-parameters.
num_samples = 128
num_samples_per_prop = [512, 256]
sampling_type = "lindisp"
opaque_bkgd = True
factor = 2

# Read the COLMAP reconstruction at the chosen downscale factor.
images, camtoworlds, K, split_indices = _load_colmap(root_fp, 0, factor)

# Re-express every pose in the normalized frame: apply the similarity
# transform T to each camera, then rescale the translations by sscale.
T, sscale = similarity_from_cameras(camtoworlds, strict_scaling=True)
camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, T)
camtoworlds[:, :3, 3] *= sscale

# Move everything the renderer needs onto the render device.
images = torch.from_numpy(images).to(device=device, dtype=torch.uint8)
camtoworlds = torch.from_numpy(camtoworlds).to(device=device, dtype=torch.float32)
K = torch.tensor(K).to(device=device, dtype=torch.float32)

1
loading images


100%|██████████| 616/616 [00:05<00:00, 103.86it/s]


In [5]:
def generate_rays(images, c2w, K):
    """Build one ray per pixel for a single pinhole camera.

    Args:
        images: tensor whose dims 1 and 2 are the image height and width
            (e.g. (N, H, W, C)); only its shape and device are used.
        c2w: camera-to-world pose of one camera, either batched (1, 4, 4) /
            (1, 3, 4) or unbatched (4, 4) / (3, 4).
        K: (3, 3) intrinsics; K[0, 0] / K[1, 1] are the x / y focal lengths
            and K[0, 2] / K[1, 2] the principal point.

    Returns:
        Rays with ``origins`` and unit-length ``viewdirs``, both (H, W, 3).

    Raises:
        ValueError: if more than one camera pose is passed.
    """
    height, width = images.shape[1:3]
    device = images.device

    if c2w.dim() == 2:
        c2w = c2w[None]
    if c2w.shape[0] != 1:
        # The per-pixel broadcast below only works for one pose.
        raise ValueError(
            f"generate_rays expects a single camera pose, got {c2w.shape[0]}"
        )

    # Pixel grid; "xy" indexing gives (H, W) arrays, so the flattened
    # row-major order reshapes back to (H, W) at the end.
    x, y = torch.meshgrid(
        torch.arange(width, device=device),
        torch.arange(height, device=device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    # Camera-space directions through pixel centres (+0.5) with z = 1.
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - K[0, 2] + 0.5) / K[0, 0],
                (y - K[1, 2] + 0.5) / K[1, 1],
            ],
            dim=-1,
        ),
        (0, 1),
        value=1.0,
    )
    # Rotate into world space: directions[p] = R @ camera_dirs[p].
    directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
    origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)

    return Rays(
        origins=torch.reshape(origins, (height, width, 3)),
        viewdirs=torch.reshape(viewdirs, (height, width, 3)),
    )


def _to_uint8(values: Tensor) -> np.ndarray:
    """Convert a [0, 1]-valued tensor into a uint8 numpy array for imwrite.

    Values are clamped first so small overshoots (e.g. 1.0001 or -0.001) do not
    wrap around in the uint8 cast; scaling truncates like a plain ``* 255`` cast.
    """
    return (values.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)


def _normalized_log_depth(depth: Tensor) -> Tensor:
    """Log-scale a depth map into [0, 1] for visualization.

    Non-positive (or NaN) depths would make ``log`` return -inf/NaN and poison the
    min/max normalization, so they are excluded and shown as 0. A constant depth
    map also maps to 0 instead of dividing by zero.
    """
    valid = depth > 0
    vis = torch.zeros_like(depth)
    if not bool(valid.any()):
        return vis
    log_depth = torch.log(depth[valid])
    lowest = log_depth.min()
    span = log_depth.max() - lowest
    if span > 0:
        vis[valid] = (log_depth - lowest) / span
    return vis


def render_and_save_image(rays, images, index, a_vec, radiance_field, proposal_networks, estimator, 
                          num_samples, num_samples_per_prop, near_plane, far_plane, 
                          sampling_type, opaque_bkgd, test_chunk_size, exp_dir):
    """Render one view with ``render_image_with_propnet`` and save it as PNGs.

    Writes ``<exp_dir>/renders/rgb_render_{index:08}.png`` (colour) and
    ``<exp_dir>/renders/rgb_depth_{index:08}.png`` (log depth scaled to 0-255),
    creating the directory if needed.

    Args:
        rays: rays to render (built with ``generate_rays`` in this notebook).
        images: only its device is used, for the all-ones background colour.
        index: frame number used in the output file names.
        a_vec: appearance embedding forwarded to the renderer.
        radiance_field, proposal_networks, estimator, num_samples,
        num_samples_per_prop, near_plane, far_plane, sampling_type,
        opaque_bkgd, test_chunk_size: forwarded unchanged to
            ``render_image_with_propnet``.
        exp_dir: experiment directory that receives the ``renders`` folder.
    """
    color_bkgd = torch.ones(3, device=images.device)

    with torch.no_grad():
        rgb, _acc, depth, _ = render_image_with_propnet(
            radiance_field,
            proposal_networks,
            estimator,
            rays,
            num_samples=num_samples,
            num_samples_per_prop=num_samples_per_prop,
            near_plane=near_plane,
            far_plane=far_plane,
            sampling_type=sampling_type,
            opaque_bkgd=opaque_bkgd,
            render_bkgd=color_bkgd,
            test_chunk_size=test_chunk_size,
            img=None,
            a_vec=a_vec,
        )

    renders_dir = os.path.join(exp_dir, "renders")
    os.makedirs(renders_dir, exist_ok=True)

    imageio.imwrite(
        os.path.join(renders_dir, f"rgb_render_{index:08}.png"),
        _to_uint8(rgb),
    )
    imageio.imwrite(
        os.path.join(renders_dir, f"rgb_depth_{index:08}.png"),
        _to_uint8(_normalized_log_depth(depth)),
    )

# for index in tqdm.tqdm(range(images.shape[0])):
#     image_id = [index]
#     c2w = camtoworlds[image_id]  # (1, 4, 4)
#     rays = generate_rays(images, c2w, K)
#     img = images[index]
#     render_and_save_image(rays, images, index, img, radiance_field, proposal_networks, estimator, 
#                         num_samples, num_samples_per_prop, near_plane, far_plane, 
#                         sampling_type, opaque_bkgd, test_chunk_size, exp_dir)


In [6]:
import matplotlib.pyplot as plt

# image_ids = [413, 581, 70, 1]  # (N)

# fig, axs = plt.subplots(1, len(image_ids), figsize=(15,15))

# for i, image_id in enumerate(image_ids):
#     axs[i].imshow(images[image_id].cpu())
#     axs[i].set_title(f"Image {image_id}")
#     axs[i].axis('off')

# plt.show()


In [7]:
# import matplotlib.pyplot as plt
# W = 10
# H = 10
# fig, axs = plt.subplots(H,W, figsize=(15,15))

# base_id=450
# for i in range(H):
#     for j in range(W):
#         image_id = base_id + i*H+j
#         axs[i,j].imshow(images[image_id].cpu())
#         axs[i,j].set_title(f"Image {image_id}")
#         axs[i,j].axis('off')

# plt.show()

In [8]:


# color_bkgd = torch.ones(3, device=images.device)
# c2w = camtoworlds[[37]] 
# image_ids = [453, 457, 593, 199, 95, 154, 37, 457, 385, 335]
# img = images[image_ids]
# img = torch.permute(img, (0,3,1,2))
# img = (img / 255.0).cuda()


# fig, axs = plt.subplots(2, len(image_ids), figsize=(15,4))

# with torch.no_grad():
#     rays = generate_rays(images, c2w, K)
#     print(rays.origins.shape)
#     a_vec = radiance_field.appearance_encoding(img)  # (N, 48)

# for index, image_id in enumerate(image_ids):
#     with torch.no_grad():
#         rgb, acc, depth, _ = render_image_with_propnet(
#             radiance_field,
#             proposal_networks,
#             estimator,
#             rays,
#             num_samples=num_samples,
#             num_samples_per_prop=num_samples_per_prop,
#             near_plane=near_plane,
#             far_plane=far_plane,
#             sampling_type=sampling_type,
#             opaque_bkgd=opaque_bkgd,
#             render_bkgd=color_bkgd,
#             test_chunk_size=test_chunk_size,
#             img=None,
#             a_vec=a_vec[index],
#         )
#     axs[0,index].imshow(rgb.cpu())
#     axs[0,index].set_title(f"Render {image_id}")
#     axs[0,index].axis('off')
#     axs[1,index].imshow(img[[index]].permute((0,2,3,1)).squeeze().cpu())
#     axs[1,index].set_title(f"Image {image_id}")
#     axs[1,index].axis('off')

In [9]:

def _segment_schedule(num_keys, M, device=None, dtype=torch.float32):
    """Spread M output frames over the num_keys - 1 segments between key frames.

    Each segment gets ``M // segments`` frames and the last one also gets the
    remainder. Only the last segment includes its end point (t = 1); the others
    stop short of it, so no key frame is emitted twice in a row. With a single
    key frame every output frame is that key (seg = 0, t = 0).

    Returns:
        (seg, t): ``seg[k]`` is the index i of the segment (key i -> key i + 1)
        that output frame k falls in, and ``t[k]`` its blend weight in [0, 1].

    Raises:
        ValueError: if there are no key frames.
    """
    if num_keys < 1:
        raise ValueError("at least one key frame is required")
    segments = num_keys - 1
    if segments == 0:
        return (
            torch.zeros(M, dtype=torch.long, device=device),
            torch.zeros(M, dtype=dtype, device=device),
        )

    frames_per_segment = M // segments
    seg_ids, weights = [], []
    for i in range(segments):
        is_last = i == segments - 1
        frames = frames_per_segment + M % segments if is_last else frames_per_segment
        if frames == 0:
            continue
        if is_last:
            t = torch.linspace(0, 1, frames, device=device, dtype=dtype)
        else:
            # End point excluded: this segment's t = 1 is the next one's t = 0.
            t = torch.arange(frames, device=device, dtype=dtype) / frames
        seg_ids.append(torch.full((frames,), i, dtype=torch.long, device=device))
        weights.append(t)

    if not seg_ids:  # M == 0
        return (
            torch.zeros(0, dtype=torch.long, device=device),
            torch.zeros(0, dtype=dtype, device=device),
        )
    return torch.cat(seg_ids), torch.cat(weights)


def interpolate_appearance(a_vec, M):
    """Linearly interpolate appearance embeddings through a sequence of key frames.

    Args:
        a_vec: (N, ...) tensor of key-frame embeddings, N >= 1.
        M: number of output frames.

    Returns:
        (M, ...) tensor on a_vec's device and dtype, following the frame
        schedule of ``_segment_schedule``.
    """
    num_keys = a_vec.shape[0]
    seg, t = _segment_schedule(num_keys, M, device=a_vec.device, dtype=a_vec.dtype)
    nxt = (seg + 1).clamp(max=num_keys - 1)
    # Broadcast the per-frame weight over all trailing embedding dimensions.
    t = t.reshape(-1, *([1] * (a_vec.dim() - 1)))
    return torch.lerp(a_vec[seg], a_vec[nxt], t)


def slerp(q1, q2, t, eps=1e-6):
    """Spherical linear interpolation between two quaternions.

    Both inputs are normalized first, so the result is a unit quaternion. q2 is
    negated when the dot product is negative: q and -q encode the same rotation,
    and without the flip the interpolation would take the long way round. When
    the inputs are (nearly) parallel the spherical formula degenerates to 0/0,
    so a normalized linear interpolation is used instead.

    Args:
        q1, q2: 1-D quaternions of equal length (4 components).
        t: interpolation weight in [0, 1] (float or 0-dim tensor).
        eps: tolerance on the dot product for the near-parallel fallback.
    """
    q1 = q1 / torch.norm(q1)
    q2 = q2 / torch.norm(q2)
    dot = torch.dot(q1, q2)
    if dot < 0:
        q2 = -q2
        dot = -dot
    dot = torch.clamp(dot, max=1.0)  # keep acos in range despite rounding
    if dot > 1.0 - eps:
        blended = q1 + (q2 - q1) * t
        return blended / torch.norm(blended)
    theta = torch.acos(dot) * t
    q3 = q2 - q1 * dot
    q3 = q3 / torch.norm(q3)
    return torch.cos(theta) * q1 + torch.sin(theta) * q3


def interpolate_transforms(c2w, M):
    """Interpolate a camera path through key poses.

    Rotations are slerped (via quaternions) and translations linearly
    interpolated, using the same frame schedule as ``interpolate_appearance``.

    Args:
        c2w: (N, 4, 4) or (N, 3, 4) key camera-to-world matrices, N >= 1.
        M: number of output poses.

    Returns:
        (M, 4, 4) tensor on c2w's device and dtype.
    """
    num_keys = c2w.shape[0]
    rotations = c2w[:, :3, :3]  # (N, 3, 3)
    translations = c2w[:, :3, 3]  # (N, 3)

    # Quaternions give a smooth rotation interpolation.
    quaternions = matrix_to_quaternion(rotations)

    seg, t = _segment_schedule(num_keys, M, device=c2w.device, dtype=c2w.dtype)
    nxt = (seg + 1).clamp(max=num_keys - 1)

    if len(seg) > 0:
        new_quaternions = torch.stack([
            slerp(quaternions[i], quaternions[j], w)
            for i, j, w in zip(seg.tolist(), nxt.tolist(), t)
        ])
    else:
        new_quaternions = quaternions.new_zeros((0, 4))
    new_translations = torch.lerp(translations[seg], translations[nxt], t[:, None])
    new_rotations = quaternion_to_matrix(new_quaternions)  # (M, 3, 3)

    new_c2w = torch.zeros((M, 4, 4), device=c2w.device, dtype=c2w.dtype)
    new_c2w[:, :3, :3] = new_rotations
    new_c2w[:, :3, 3] = new_translations
    new_c2w[:, 3, 3] = 1

    return new_c2w


In [10]:

# Number of frames in the rendered sequence.
M = 1000

# Key frames. Appearance is interpolated between the embeddings of these
# training images, and the camera path between these training poses. The two
# lists differ in length; each is resampled to M frames independently.
# Earlier choices: [413, 581, 70, 1] and [37, 95, 70, 1].
image_ids = [453, 457, 593, 199, 95, 154, 37, 457, 385, 335] * 2 + [453]  # (N)
c2w_ids = [1, 513, 95, 51, 596, 488, 483, 522, 546, 1]  # (N)
# Advanced indexing returns a copy, so the edits below leave camtoworlds intact.
c2w = camtoworlds[c2w_ids]  # (N, 4, 4)

# Manual offsets applied to the translation of the first and last key poses
# (both are camera 1): +0.04 in x, +0.03 in y.
c2w[0, 1, 3] += 0.03
c2w[0, 0, 3] += 0.04
c2w[-1, 1, 3] += 0.03
c2w[-1, 0, 3] += 0.04

# Key images as channels-first floats in [0, 1] on the configured device.
img = images[image_ids]
img = torch.permute(img, (0, 3, 1, 2))
img = (img / 255.0).to(device)

with torch.no_grad():
    a_vec = radiance_field.appearance_encoding(img)  # (N, 48)

new_c2w = interpolate_transforms(c2w, M).to(device)
new_a_vec = interpolate_appearance(a_vec, M)  # (M, 48)

for index in tqdm.tqdm(range(M)):
    rays = generate_rays(images, new_c2w[[index]], K)
    render_and_save_image(rays, images, index, new_a_vec[[index]], radiance_field, proposal_networks, estimator, 
                          num_samples, num_samples_per_prop, near_plane, far_plane, 
                          sampling_type, opaque_bkgd, test_chunk_size, exp_dir)

100%|██████████| 1000/1000 [2:43:53<00:00,  9.83s/it] 
