Trains a NeRF using perspective projection, then renders images from the trained model using both perspective and orthographic projection for comparison.

In [None]:
import os

import torch
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

from dataset import diff_rendering_dataset
from rendering import VolumeRenderer
from fields import RadianceField
from training import fit_inverse_graphics_representation
from utils import to_gpu

In [None]:
# Set device

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the chosen GPU the current CUDA device as well.
    torch.cuda.set_device(device)
print(f"Device: {device}")

In [None]:
# Load data

data_path = "./data/bunny"

# Read both arrays from disk and move them to the selected device.
# NOTE: torch.Tensor(...) always yields float32, while torch.tensor(...) keeps
# the dtype stored in the .npy file.
cam2world = torch.Tensor(np.load(os.path.join(data_path, "cam2world.npy"))).to(device)
images = torch.tensor(np.load(os.path.join(data_path, "images.npy"))).to(device)

# 3x3 matrix; presumably normalized pinhole intrinsics (focal length 0.7,
# principal point at 0.5) -- TODO confirm. It is not referenced by any other
# cell visible in this notebook.
intrinsics = torch.tensor(
    [
        [0.7, 0.0, 0.5],
        [0.0, 0.7, 0.5],
        [0.0, 0.0, 1.0],
    ]
).to(device)
print(cam2world.shape, images.shape)

In [None]:
# Generate dataset

# The dataset is consumed with next(), one (model_input, ground_truth) pair per call.
bunny_dataset = diff_rendering_dataset(images, cam2world, device=device)
model_input, gt = next(bunny_dataset)

# Preview the first ground-truth sample, reshaped to (height, width, 3) using
# the spatial dimensions of the loaded image stack.
plt.imshow(gt.view(*images.shape[1:3], 3).detach().cpu())
plt.show()

In [None]:
# Train

radiance_field = RadianceField(scene_rep_name="HybridVoxelNeuralField", device=device).to(device)
# No `orthographic` flag is passed here; the rest of the notebook treats this
# renderer as the perspective one.
renderer = VolumeRenderer(near=1.5, far=4.5, n_samples=128, white_back=True, rand=False).to(device)

# (height, width, channels). Taken from the loaded images rather than
# hard-coded, so it always agrees with the shape used to preview the ground
# truth in the dataset cell (this is (128, 128, 3) for 128x128 input images).
img_resolution = (images.shape[1], images.shape[2], 3)

fit_inverse_graphics_representation(
    representation=radiance_field,
    renderer=renderer,
    data_iter=bunny_dataset,
    img_resolution=img_resolution,
    lr=1e-3,
    total_steps=500,
)

In [None]:
# Render both perspective and orthographic images

# Same near/far/sample settings as the training renderer; only the
# `orthographic` flag differs.
orthographic_renderer = VolumeRenderer(
    near=1.5, far=4.5, n_samples=128, white_back=True, rand=False, orthographic=True
).to(device)

num_images = 3

# One column per view: perspective on the top row, orthographic on the bottom.
# The grid is sized from num_images so changing it cannot index out of bounds.
fig, axes = plt.subplots(2, num_images, figsize=(6 * num_images, 12), squeeze=False)

for i in range(num_images):
    # Get next camera params (the ground-truth image is not needed here)
    cam_params, _ = next(bunny_dataset)
    cam_params = to_gpu(cam_params)

    # Render image with both perspective and orthographic projection.
    # This is inference only, so skip autograd bookkeeping to save memory.
    with torch.no_grad():
        rgb, _ = renderer(cam_params, radiance_field)
        orthographic_rgb, _ = orthographic_renderer(cam_params, radiance_field)

    renders = (("Perspective", rgb), ("Orthographic", orthographic_rgb))
    for row, (label, render) in enumerate(renders):
        ax = axes[row, i]
        ax.imshow(render.detach().cpu().view(*img_resolution).numpy())
        ax.set_axis_off()
        ax.set_title(f"{label} {i}")

plt.show()
