In [None]:
#| default_exp nerf

In [None]:
#| include: false
from fastcore.all import *

In [None]:
#| exporti
from dataclasses import dataclass, field
import json

import numpy as np
import pandas as pd
import PIL
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# | exporti
# Default background colour: an RGB triple of ones (white).
# NOTE: created without a `device=` argument, so it is not placed on DEVICE.
WHITE = torch.full((3,), 1.0, dtype=torch.float)
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)
# Report the selected device whenever this cell runs.
print(f"Using device: {DEVICE}")

Using device: mps


# Neural Radiance Fields

Most of the code is defined in the book as well, but here we more thoroughly test it.


### Sampling from Rays

A point $P$ on the ray at a distance $t$ from the origin $O$, in the direction $D$, is given by

$$
P(t,O,D) = O + t  D
$$

In [None]:
def sample_along_ray(t_values, origins, directions):
    """Sample points `O + t * D` along rays.

    Args:
        t_values: distances along the ray. Either a 1-D tensor of `N` values
            shared by every ray, or a tensor of shape `(..., N)` whose leading
            dimensions broadcast against the ray batch (per-ray distances).
        origins: ray origins, shape `(..., 3)`.
        directions: ray directions, shape `(..., 3)`. Distances are only
            metric if the directions are unit-norm; this is not checked.

    Returns:
        Tensor of shape `(..., N, 3)` holding the sampled points.
    """
    # Insert a "sample" axis before the xyz axis of the rays, and an xyz axis
    # after the sample axis of the distances, then let broadcasting do the rest.
    # Indexing with `...` (rather than `:`) keeps 1-D t-values working exactly
    # as before while also accepting batched, per-ray t-values.
    return origins[..., None, :] + t_values[..., None] * directions[..., None, :]

Notice that the way we implemented `sample_along_ray` takes care to handle *arbitrary* batches of origin/direction pairs, as long as their last dimension is 3:

In [None]:
# Five sample distances shared by both rays (an integer tensor).
t_values = torch.tensor([1, 2, 3, 4, 5])
# Two rays, both starting at the coordinate origin.
origins = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# Unit directions: along +x, and along the diagonal of the xy-plane.
directions = torch.tensor([[1.0, 0.0, 0.0], [1.0/ np.sqrt(2), 1.0/ np.sqrt(2), 0.0]])

samples = sample_along_ray(t_values, origins, directions)

# Expected layout: (rays, samples per ray, xyz).
test_eq(samples.shape, torch.Size([2, 5, 3]))

The last line above asserts that we sampled 2 rays for 5 different $t$-values, each of them being 3-dimensional points as expected.

### Integration along Rays

Assuming that we are given the densities $\sigma_i$ and colors $c_i$ at $N$ sampled points $P_i$ on a ray corresponding to a given pixel, then we can calculate the color for the ray using the equation below,

$$
C = \sum_{i=1}^N T_i \alpha_i c_i
$$

where $T_i$ is the **transmittance**:

$$
T_i \doteq \exp ( - \sum_{j=1}^{i-1} \sigma_j)
$$

The transmittance $T_i$ measures the *lack* of occlusion in the space between the $i^{th}$ sample and the ray origin. The quantity $\alpha_i$, on the other hand, is the alpha value or **opacity** at the $i^{th}$ sample, defined as

$$
\alpha_i \doteq 1 - \exp(-\sigma_i).
$$

In [None]:
def render_along_ray(density, rgb, background=WHITE):
    """Composite per-sample densities and colors into one color per ray.

    Args:
        density: non-negative densities, shape `(..., N)`.
        rgb: per-sample colors, shape `(..., N, C)`.
        background: color blended in for whatever light is not absorbed along
            the ray; anything `torch.as_tensor` accepts that broadcasts
            against `(..., C)`. It is moved to the device of the inputs if it
            lives elsewhere.

    Returns:
        Tensor of shape `(..., C)` with the rendered colors.
    """
    # Opacity of each sample: alpha_i = 1 - exp(-sigma_i).
    alpha = 1 - torch.exp(-density)

    # Transmittance T_i = exp(-sum_{j<i} sigma_j): take the inclusive cumulative
    # sum, then shift it right by one sample so that T_1 = 1.
    cumulative_density = torch.cumsum(density, dim=-1)
    trans = torch.exp(-cumulative_density)
    trans = torch.cat([torch.ones_like(density[..., :1]), trans[..., :-1]], dim=-1)

    # Contribution of each sample to the final color.
    weights = alpha * trans
    color_acc = torch.einsum('...i,...ij->...j', weights, rgb)
    acc = weights.sum(dim=-1, keepdim=True)

    # The default background is created on the CPU; align it with the inputs so
    # that rendering on CUDA/MPS does not fail with a device mismatch. This is
    # a no-op when the background is already a tensor on the right device.
    background = torch.as_tensor(background, device=color_acc.device)

    return color_acc + (1.0 - acc) * background

Test using randomly generated `density` and `rgb` inputs that have the same shape as our sampled rays from above, asserting that we indeed get *two* RGB colors as the end-result:

In [None]:
# No seed is set in this cell, so the printed values change from run to run.
density = torch.rand(2, 5) # Random density
rgb = torch.rand(2, 5, 3) # Random colors (between 0 and 1)
rendered = render_along_ray(density, rgb)
# One RGB triple per ray.
test_eq( rendered.shape, torch.Size([2, 3]))
print(rendered.detach().numpy())

[[0.7174011  0.6973155  0.4935117 ]
 [0.7510586  0.77856576 0.30422565]]


## A Differentiable Voxel Grid

In [None]:
def interpolate(v0, v1, alpha):
    """Linearly blend `v0` and `v1`: gives `v0` at alpha=0 and `v1` at alpha=1.

    `alpha` has one dimension fewer than the values; it is unsqueezed on the
    last axis so that a single weight applies to every channel of an element.
    """
    w = alpha.unsqueeze(-1)
    return v0 * (1 - w) + v1 * w

class VoxelGrid(nn.Module):
    """A 3D lattice of learnable `d`-dimensional values, queried by trilinear interpolation."""

    def __init__(self, shape, d=1, max=1.0):
        """A 3D voxel grid with given `shape` with learnable values at the corners of the voxels.

        Args:
            shape: number of lattice points along x, y and z.
            d: number of channels stored at every lattice point.
            max: values are initialized uniformly at random in `[0, max)`.
        """
        super().__init__()
        self.grid = nn.Parameter(torch.rand(*shape, d) * max)


    def forward(self, P):
        """Implement trilinear interpolation at the points P.

        Args:
            P: points in grid coordinates, shape `(..., 3)`. Coordinates
                outside `[0, size - 1]` are clamped onto the lattice border.

        Returns:
            Tensor of shape `(..., d)`; the trailing axis is dropped when `d == 1`.
        """
        nx, ny, nz = self.grid.shape[:3]

        # Clamp the coordinates onto the lattice. Without this a negative
        # coordinate yields a negative index, which silently wraps around to
        # the far side of the grid, and a coordinate >= size raises IndexError.
        x = P[..., 0].clamp(0, nx - 1)
        y = P[..., 1].clamp(0, ny - 1)
        z = P[..., 2].clamp(0, nz - 1)

        # Get indices of the corners; the upper corner is clamped so that a
        # point lying exactly on the last lattice plane stays in range:
        X0, Y0, Z0 = torch.floor(x).long(), torch.floor(y).long(), torch.floor(z).long()
        X1 = torch.clamp(X0 + 1, max=nx - 1)
        Y1 = torch.clamp(Y0 + 1, max=ny - 1)
        Z1 = torch.clamp(Z0 + 1, max=nz - 1)

        # Get blending weights along each axis:
        a, b, c = x - X0, y - Y0, z - Z0

        # Interpolate in the x direction:
        y0z0 = interpolate(self.grid[X0, Y0, Z0, :], self.grid[X1, Y0, Z0, :], a)
        y1z0 = interpolate(self.grid[X0, Y1, Z0, :], self.grid[X1, Y1, Z0, :], a)
        y0z1 = interpolate(self.grid[X0, Y0, Z1, :], self.grid[X1, Y0, Z1, :], a)
        y1z1 = interpolate(self.grid[X0, Y1, Z1, :], self.grid[X1, Y1, Z1, :], a)

        # Interpolate in the y direction:
        z0 = interpolate(y0z0, y1z0, b)
        z1 = interpolate(y0z1, y1z1, b)

        # Interpolate in the z direction:
        return interpolate(z0, z1, c).squeeze(-1)

The code below initializes a `VoxelGrid` with random values, and then evaluates a scalar function at a 3D point:

In [None]:
# Random scalar field (d=1) stored on a 6x6x6 lattice.
voxel_grid_module = VoxelGrid(shape=(6, 6, 6), d=1)
# Use the `torch.tensor` factory rather than the legacy `torch.Tensor(...)`
# constructor; the dtype (float32) is the same.
point = torch.tensor([1.5, 2.7, 3.4])
output = voxel_grid_module(point)
print(f"Interpolated Output: {output.item():.5f}")
# A single point queried on a d=1 grid gives a 0-dimensional tensor.
test_eq(output.shape, torch.Size([]))

Interpolated Output: 0.61790


Below we create a grid that interpolates a four-dimensional function (`d=4`), and evaluate it at a 2x2 batch `x` of 3D points:

In [None]:
# Random 4-channel field (d=4) stored on a 6x6x6 lattice.
voxel_grid_module = VoxelGrid(shape = (6, 6, 6), d=4)

# A 2x2 batch of 3D query points, built with the `torch.tensor` factory rather
# than the legacy `torch.Tensor(...)` constructor (same float32 dtype).
x = torch.tensor([[[1.5, 2.7, 3.4], [2.3, 4.6, 1.1]], [[2.3, 4.6, 1.1], [2.3, 4.6, 1.1]]])
y = voxel_grid_module(x)
test_eq(x.shape, torch.Size([2, 2, 3]))
# The batch shape is preserved; the last axis now holds the 4 channels.
test_eq(y.shape, torch.Size([2, 2, 4]))
print("Interpolated Output:\n", y.detach().numpy())

Interpolated Output:
 [[[0.50723857 0.5354987  0.6874882  0.47449723]
  [0.38862017 0.43923682 0.55946934 0.5623423 ]]

 [[0.38862017 0.43923682 0.55946934 0.5623423 ]
  [0.38862017 0.43923682 0.55946934 0.5623423 ]]]


## DVGO

In [None]:
@dataclass
class Config:
    """Settings for the ray sampler, the scene bounding box and the voxel grids."""

    # Closest and farthest distance along a ray between which samples are taken.
    near: float = 1.5
    far: float = 3.5
    # Number of samples taken along each ray.
    num_samples: int = 64
    # Opposite corners of the scene bounding box, as (x, y, z) triples.
    min_corner: tuple[float, float, float] = (-1.0, -1.0, 0.0)
    max_corner: tuple[float, float, float] = (1.0, 1.0, 1.0)
    # Number of voxel-grid points along x, y and z.
    shape: tuple[int, int, int] = (128, 128, 128)
    # Background color used when rendering. Declared as a real dataclass field
    # so that it can be passed to the constructor; it is excluded from `==`
    # because comparing tensors element-wise does not yield a single bool.
    background: torch.Tensor = field(default=WHITE, compare=False)

In [None]:
class SimpleDVGO(nn.Module):
    """Volume renderer backed by two learnable voxel grids (density and RGB)."""

    def __init__(self, config: Config = Config()):
        """Initialize voxel grids and bounding box corners.

        All constant tensors are registered as non-persistent buffers: they
        follow the module through `.to(device)` / `.cuda()` but are kept out of
        `state_dict()`, so saved checkpoints keep the same set of keys.
        """
        super().__init__()  # Calling the superclass's __init__ method

        # Initialize sampler parameters: `num_samples` equal-width depth bins
        # between near and far, sampled at their midpoints.
        depths = torch.linspace(
            config.near, config.far, config.num_samples + 1, dtype=torch.float32
        )
        self.register_buffer("depths", depths, persistent=False)
        self.register_buffer(
            "t_values", 0.5 * (depths[1:] + depths[:-1]), persistent=False
        )

        # Set up conversion from scene coordinates to grid coordinates:
        min_corner = torch.tensor(config.min_corner, dtype=torch.float32)
        max_corner = torch.tensor(config.max_corner, dtype=torch.float32)
        self.register_buffer("min_corner", min_corner, persistent=False)
        self.register_buffer("max_corner", max_corner, persistent=False)
        self.register_buffer(
            "scale", 1.0 / (max_corner - min_corner), persistent=False
        )
        self.register_buffer(
            "float_shape",
            torch.tensor(config.shape, dtype=torch.float32),
            persistent=False,
        )

        # Initialize differentiable voxel grids:
        self.rgb_voxel_grid = VoxelGrid(config.shape, d=3, max=1.0)
        self.density_voxel_grid = VoxelGrid(config.shape, d=1, max=0.1)

        # Finally, record background color for rendering:
        self.register_buffer(
            "background", torch.as_tensor(config.background), persistent=False
        )

    def forward(self, x_samples):
        """Perform volume rendering using the provided ray information.

        Args:
            x_samples: rays of shape `(..., 6)`; the first three entries of the
                last axis are the ray origin, the remaining ones its direction.

        Returns:
            A pair `(colors, sparsity_penalty)`: the rendered colors of shape
            `(..., 3)` and a scalar equal to the sum of all sampled densities.
        """
        # Extract ray origins and directions from x_samples
        origins = x_samples[..., :3].to(dtype=torch.float32)
        directions = x_samples[..., 3:].to(dtype=torch.float32)

        # Sample along the ray
        samples = sample_along_ray(self.t_values, origins, directions)

        # Rescale to fit within the grid: normalize to the bounding box, clamp
        # samples that fall outside it onto its border (keeping the normalized
        # value strictly below 1), then scale up to grid coordinates.
        unclamped = (samples - self.min_corner) * self.scale
        rescaled = torch.clamp(unclamped, 0.0, 0.9999999) * self.float_shape

        # Query Density Voxel Grid. The grid already drops its single channel
        # axis, so no further squeeze is applied here: a bare `torch.squeeze`
        # would also remove a batch or sample axis of size 1.
        density = F.softplus(self.density_voxel_grid(rescaled))
        sparsity_penalty = torch.sum(torch.abs(density))

        # Query RGB Voxel Grid
        rgb = torch.sigmoid(self.rgb_voxel_grid(rescaled))

        # Render
        return render_along_ray(density, rgb, self.background), sparsity_penalty

Below we calculate the colors for 32 random rays, each with its origin and direction stacked into a 6-vector, so the input batch has shape $32 \times 6$, and we expect an output batch of RGB colors with shape $32 \times 3$:

In [None]:
# Initialize renderer
dvgo = SimpleDVGO()

x_samples = torch.rand((32, 6))
y_samples, sparsity_penalty = dvgo(x_samples)
# Verify shape of the output
test_eq(y_samples.shape, torch.Size([32, 3]))
test_eq(sparsity_penalty.shape, torch.Size([]))
print("sparsity_penalty:", sparsity_penalty.item())

sparsity_penalty: 131.33258056640625
