In [None]:
#| default_exp nerf

In [None]:
#| include: false
from fastcore.all import *

In [None]:
#| exporti
import json
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# | exporti
# Constant background colours used when compositing rays (RGB, float32):
WHITE = torch.ones(3, dtype=torch.float)
BLACK = torch.zeros(3, dtype=torch.float)

# Preferred compute backend: CUDA if present, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Neural Radiance Fields

Most of this code also appears in the book; here we test it more thoroughly.


### Sampling from Rays

A point $P$ on the ray with origin $O$ and unit direction $D$, at distance $t$ from the origin, is given by

$$
P(t,O,D) = O + t  D
$$

In [None]:
def sample_along_ray(t_values, origins, directions):
    """Sample points P = O + t * D along rays defined by origins and (unit-norm) directions.

    Args:
        t_values: sample distances, either of shape (T,) and shared by every ray, or of
            shape (..., T) giving per-ray distances that broadcast against the ray batch.
        origins: ray origins O, shape (..., 3).
        directions: ray directions D, shape (..., 3). With unit-norm directions, t is the
            Euclidean distance of the sample from the origin.

    Returns:
        Sampled points of shape (..., T, 3).
    """
    # t gets a trailing coordinate axis (-> (..., T, 1)); O and D get a sample axis (-> (..., 1, 3)).
    return origins[..., None, :] + t_values[..., None] * directions[..., None, :]

Notice that `sample_along_ray` is implemented to handle *arbitrary* batches of origin/direction pairs, as long as their last dimension is 3:

In [None]:
# Two rays from the origin (along +x and along the xy-diagonal), five distances each:
t_values = torch.arange(1, 6)
origins = torch.zeros(2, 3)
diag = 1.0 / np.sqrt(2)
directions = torch.tensor([[1.0, 0.0, 0.0], [diag, diag, 0.0]])

samples = sample_along_ray(t_values, origins, directions)

# rays x t-values x coordinates
test_eq(samples.shape, torch.Size([2, 5, 3]))

The last line above asserts that we sampled 2 rays at 5 different $t$-values each, and that every sample is a 3-dimensional point, as expected.

### Integration along Rays

Assume we are given the densities $\sigma_i$ and colors $c_i$ at $N$ sampled points $P_i$ on the ray corresponding to a given pixel. We can then calculate the color of that ray using the equation below,

$$
C = \sum_{i=1}^N T_i \alpha_i c_i
$$

where $T_i$ is the **transmittance**:

$$
T_i \doteq \exp ( - \sum_{j=1}^{i-1} \sigma_j)
$$

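The transmittance $T_i$ measures the *lack* of occlusion between the ray origin and the $i^{\text{th}}$ sample. For simplicity, the spacing $\delta_i$ between consecutive samples that appears in the original NeRF formulation is folded into $\sigma_i$ here. The quantity $\alpha_i$, on the other hand, is the alpha value or **opacity** at the $i^{\text{th}}$ sample, defined as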

$$
\alpha_i \doteq 1 - \exp(-\sigma_i).
$$

In [None]:
def render_along_ray(density, rgb, background=WHITE):
    """Alpha-composite the samples along each ray into a final color.

    Computes C = sum_i T_i * alpha_i * c_i + (1 - sum_i T_i * alpha_i) * background, with
    alpha_i = 1 - exp(-sigma_i) and T_i = exp(-sum_{j<i} sigma_j), so that T_1 = 1.

    Args:
        density: densities sigma_i (expected non-negative), shape (..., N) with the samples
            along the last axis.
        rgb: sample colors c_i, shape (..., N, C).
        background: color shown where the ray is not fully opaque. Anything accepted by
            `torch.as_tensor` that broadcasts against (..., C) works, e.g. a (C,) tensor, a
            tuple, or a scalar. It is moved to the device of the result, so the default
            CPU `WHITE` also works for inputs on another device.

    Returns:
        Rendered colors of shape (..., C).
    """
    alpha = 1 - torch.exp(-density)
    # Exclusive cumulative sum: shift by one sample so the first sample is unoccluded (T_1 = 1).
    trans = torch.exp(-torch.cumsum(density, dim=-1))
    trans = torch.cat([torch.ones_like(density[..., :1]), trans[..., :-1]], dim=-1)

    weights = alpha * trans
    color_acc = torch.einsum('...i,...ij->...j', weights, rgb)
    acc = weights.sum(dim=-1, keepdim=True)

    # Keep the background on the same device as the accumulated color (dtype promotion unchanged).
    background = torch.as_tensor(background, device=color_acc.device)
    return color_acc + (1.0 - acc) * background

Test using randomly generated `density` and `rgb` inputs that have the same shape as our sampled rays from above, asserting that we indeed get *two* RGB colors as the end result:

In [None]:
# Random inputs shaped like the sampled rays above: 2 rays x 5 samples.
density = torch.rand(2, 5)     # densities in [0, 1)
rgb = torch.rand(2, 5, 3)      # colours in [0, 1)
rendered = render_along_ray(density, rgb)

test_eq(rendered.shape, torch.Size([2, 3]))  # one RGB colour per ray
rendered_np = rendered.detach().numpy()
print(rendered_np)

## A Differentiable Voxel Grid

In [None]:
def bracket(x, n):
    """Locate the lattice cell that contains each coordinate in `x`.

    Returns `(lower, upper, frac)`:
      * `lower`, `upper`: integer indices of the grid points just below and just above `x`,
        both clamped to the valid range [0, n - 1];
      * `frac`: offset of `x` from the (unclamped) lower point, clamped to [0, 1].
    Outside the grid both indices collapse onto the nearest boundary point, so the
    weight no longer affects the interpolated value.
    """
    floor_idx = x.floor().long()
    last = n - 1
    lower = floor_idx.clamp(min=0, max=last)
    upper = (floor_idx + 1).clamp(min=0, max=last)
    frac = (x - floor_idx.float()).clamp(min=0.0, max=1.0)
    return lower, upper, frac

In [None]:
def interpolate(v0, v1, alpha):
    """Linearly blend `v0` (weight at alpha = 0) and `v1` (weight at alpha = 1).

    `alpha` is expected to have one dimension fewer than `v0`/`v1`. A trailing singleton
    axis is appended so that a single weight is broadcast across all channels of each value.
    """
    w = alpha.unsqueeze(-1)
    return v0 * (1 - w) + v1 * w

class VoxelGrid(nn.Module):
    """A learnable `d`-channel field on a regular 3D lattice, queried by trilinear interpolation.

    A grid with `shape = (nx, ny, nz)` voxels stores its values at the voxel *corners*, i.e. on
    the (nx + 1) x (ny + 1) x (nz + 1) integer lattice points covering grid coordinates
    [0, nx] x [0, ny] x [0, nz]. Queries outside that box use the nearest boundary values.
    """

    def __init__(self, shape, d=1, max=1.0):
        """Allocate the corner values, initialised uniformly at random in [0, max).

        Args:
            shape: number of voxels along (x, y, z).
            d: number of channels stored at each lattice point.
            max: scale of the random initialisation (name kept for backward compatibility,
                although it shadows the builtin inside this method).
        """
        super().__init__()
        # Values live on the corners, so there is one more lattice point than voxels per axis:
        storage_shape = tuple(s + 1 for s in shape)
        self.grid = nn.Parameter(torch.rand(*storage_shape, d) * max)

    def forward(self, P):
        """Trilinearly interpolate the grid at points `P`.

        Args:
            P: tensor of shape (..., 3) holding (x, y, z) in grid coordinates.

        Returns:
            Tensor of shape (..., d); the channel axis is squeezed away when d == 1.
        """
        x, y, z = P[..., 0], P[..., 1], P[..., 2]

        # Lower/upper corner indices (clamped to the lattice) and fractional weights per axis:
        X0, X1, a = bracket(x, self.grid.shape[0])
        Y0, Y1, b = bracket(y, self.grid.shape[1])
        Z0, Z1, c = bracket(z, self.grid.shape[2])

        # Interpolate in the x direction along the four cell edges parallel to x:
        y0z0 = interpolate(self.grid[X0, Y0, Z0, :], self.grid[X1, Y0, Z0, :], a)
        y1z0 = interpolate(self.grid[X0, Y1, Z0, :], self.grid[X1, Y1, Z0, :], a)
        y0z1 = interpolate(self.grid[X0, Y0, Z1, :], self.grid[X1, Y0, Z1, :], a)
        y1z1 = interpolate(self.grid[X0, Y1, Z1, :], self.grid[X1, Y1, Z1, :], a)

        # Interpolate in the y direction:
        z0 = interpolate(y0z0, y1z0, b)
        z1 = interpolate(y0z1, y1z1, b)

        # Interpolate in the z direction; drop the channel axis for scalar (d == 1) grids:
        return interpolate(z0, z1, c).squeeze(-1)

The code below initializes a `VoxelGrid` with random values and then evaluates the resulting scalar function at a 3D point:

In [None]:
# A 6x6x6-voxel scalar field (d=1), evaluated at one point given in grid coordinates:
voxel_grid_module = VoxelGrid(shape=(6, 6, 6), d=1)
point = torch.Tensor([1.5, 2.7, 3.4])
output = voxel_grid_module(point)

print(f"Interpolated Output: {output.item():.5f}")
test_eq(output.shape, torch.Size([]))  # the single channel is squeezed to a 0-d tensor

Below we create a grid that interpolates a four-dimensional function (`d=4`), and evaluate it at a 2x2 batch `x` of 3D points:

In [None]:
voxel_grid_module = VoxelGrid(shape=(6, 6, 6), d=4)

# A 2 x 2 batch of 3D query points (grid coordinates):
x = torch.Tensor([
    [[1.5, 2.7, 3.4], [2.3, 4.6, 1.1]],
    [[2.3, 4.6, 1.1], [2.3, 4.6, 1.1]],
])
y = voxel_grid_module(x)
test_eq(x.shape, torch.Size([2, 2, 3]))
test_eq(y.shape, torch.Size([2, 2, 4]))  # one 4-vector per query point
print("Interpolated Output:\n", y.detach().numpy())

## DVGO

In [None]:
@dataclass
class Config:
    """Scene and sampling settings for `SimpleDVGO`.

    near, far: depth range along each ray in which samples are placed.
    num_samples: number of samples per ray.
    min_corner, max_corner: axis-aligned bounding box of the voxel grid in scene coordinates.
    shape: number of voxels along (x, y, z).
    background: RGB colour composited behind the accumulated radiance.
    """

    near: float = 0.5
    far: float = 3.5
    num_samples: int = 64
    min_corner: tuple[float, float, float] = (-1.0, -1.0, -1.0)
    max_corner: tuple[float, float, float] = (1.0, 1.0, 1.0)
    shape: tuple[int, int, int] = (16, 16, 16)
    # Annotated so it is a real dataclass field (settable via `Config(background=...)`).
    # Excluded from `==` because comparing two different 3-element tensors has no single truth value.
    background: torch.Tensor = field(default=WHITE, compare=False)

In [None]:
def sample_rays(t_values, rays, training=True):
    """Sample 3D points along a batch of rays at (optionally jittered) depths.

    Args:
        t_values: 1D tensor of n sample depths shared by all rays.
        rays: tensor whose last axis holds the ray origin (first 3 entries) followed by the
            direction (remaining entries, expected 3); both are cast to float32.
        training: if True, shift every depth by an independent offset drawn uniformly from
            [-0.5/n, 0.5/n). The same offsets are used for all rays in this call. This keeps
            training from overfitting to the discrete sampling locations.

    Returns:
        Sampled points of shape (..., n, 3) for rays of shape (..., 6).
    """
    origins = rays[..., :3].to(dtype=torch.float32)
    directions = rays[..., 3:].to(dtype=torch.float32)

    actual_ts = t_values
    if training:
        n = t_values.size(0)
        # Draw the jitter on t_values' device so this also works after moving the model to a GPU.
        jitter = (torch.rand(n, device=t_values.device) - 0.5) / n
        # Jittered depths are not differentiated through (matches the original no_grad block).
        actual_ts = t_values.detach() + jitter

    return sample_along_ray(actual_ts, origins, directions)

In [None]:
class SimpleDVGO(nn.Module):
    """Minimal DVGO-style model: an RGB and a density voxel grid rendered by ray marching.

    Rays are sampled at fixed depths between `config.near` and `config.far`. The samples are
    mapped from the scene box [min_corner, max_corner] into grid coordinates, both grids are
    queried by trilinear interpolation, and the results are alpha-composited over
    `config.background`.
    """

    def __init__(self, config: Config = Config()):
        """Initialize the voxel grids, sampling depths and bounding-box transform from `config`."""
        super().__init__()

        # Derived (non-learned) tensors are registered as non-persistent buffers. This way
        # `.to(device)` moves them together with the grids, while the state_dict keys stay unchanged.

        # Sampler: midpoints of `num_samples` equal depth intervals in [near, far].
        depths = torch.linspace(
            config.near, config.far, config.num_samples + 1, dtype=torch.float32
        )
        self.register_buffer("depths", depths, persistent=False)
        self.register_buffer("t_values", 0.5 * (depths[1:] + depths[:-1]), persistent=False)

        # Conversion from scene coordinates to grid coordinates:
        self.register_buffer(
            "min", torch.tensor(config.min_corner, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "max", torch.tensor(config.max_corner, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "shape", torch.tensor(config.shape, dtype=torch.float32), persistent=False
        )

        # Differentiable voxel grids (density starts in [0, 0.001), i.e. nearly transparent):
        self.rgb_voxel_grid = VoxelGrid(config.shape, d=3, max=1.0)
        self.density_voxel_grid = VoxelGrid(config.shape, d=1, max=0.001)

        # Background colour for rendering, cloned so in-place edits never touch config's tensor:
        self.register_buffer(
            "background", torch.as_tensor(config.background).clone(), persistent=False
        )

    def forward(self, rays, training=True):
        """Render one colour per ray.

        Args:
            rays: tensor of shape (..., 6) with [origin (3), direction (3)] per ray.
            training: whether to jitter the sample depths (see `sample_rays`).

        Returns:
            Colours of shape (..., 3).
        """
        samples = sample_rays(self.t_values, rays, training=training)

        # Rescale from the scene box to grid coordinates [0, shape]:
        rescaled = self.shape * (samples - self.min) / (self.max - self.min)

        # The density grid already drops its d=1 channel, giving (..., num_samples). A bare
        # torch.squeeze() would additionally drop size-1 batch or sample axes (e.g. num_samples=1).
        density = F.relu(self.density_voxel_grid(rescaled))

        # Query the RGB grid, keeping colours in [0, 1]:
        rgb = torch.clamp(self.rgb_voxel_grid(rescaled), 0.0, 1.0)

        return render_along_ray(density, rgb, self.background)

    def alpha(self):
        """Return the opacity 1 - exp(-relu(density)) at every density lattice point."""
        density = F.relu(self.density_voxel_grid.grid)
        return 1 - torch.exp(-density)


Below we calculate the colors of 32 random rays, each with its origin and direction stacked into a 6-vector. The input batch therefore has shape $32 \times 6$, and we expect an output batch of RGB colors with shape $32 \times 3$:

In [None]:
# A default-configured model (16^3 voxels, 64 samples per ray):
dvgo = SimpleDVGO()

rays = torch.rand(32, 6)   # 32 random rays, each [origin | direction]
colors = dvgo(rays)
test_eq(colors.shape, torch.Size([32, 3]))  # one RGB colour per ray

## Some simple test setups

In [None]:
# Fill the RGB grid with linear ramps: red along x, green along y, blue along z.
X, Y, Z, _ = dvgo.rgb_voxel_grid.grid.shape

# Each 1D ramp is viewed as a 3D tensor with singleton axes. On assignment, broadcasting
# then spreads it over the other two axes of the (X, Y, Z) channel slice.
red_ramp = torch.linspace(0, 1, X).view(X, 1, 1)
green_ramp = torch.linspace(0, 1, Y).view(1, Y, 1)
blue_ramp = torch.linspace(0, 1, Z).view(1, 1, Z)

for channel, ramp in enumerate((red_ramp, green_ramp, blue_ramp)):
    dvgo.rgb_voxel_grid.grid.data[..., channel] = ramp


In [None]:
# Slice at x-index 4: a (Y, Z) image with green varying down the rows and blue across the columns.
plt.imshow(dvgo.rgb_voxel_grid.grid.detach()[4].numpy());

In [None]:
# Slice at y-index 12: an (X, Z) image with red varying down the rows and blue across the columns.
plt.imshow(dvgo.rgb_voxel_grid.grid.detach()[:, 12].numpy());

In [None]:
# Slice at z-index 15: an (X, Y) image with red varying down the rows and green across the columns.
plt.imshow(dvgo.rgb_voxel_grid.grid.detach()[:, :, 15].numpy());

In [None]:
# Query the RGB grid exactly at lattice point (4, 12, 15), batched twice, to check interpolation:
with torch.no_grad():
    P = torch.Tensor([[4, 12, 15]] * 2)
    print(dvgo.rgb_voxel_grid(P))

In [None]:
# Zero the density everywhere, then set a dense (100.0) cube covering lattice indices
# [N // 4, 3 * N // 4] along each axis, written as one slice assignment instead of a triple loop.
density_grid = dvgo.density_voxel_grid.grid.data
density_grid.zero_()
density_grid[
    X // 4 : 1 + 3 * X // 4,
    Y // 4 : 1 + 3 * Y // 4,
    Z // 4 : 1 + 3 * Z // 4,
] = 100.0


In [None]:
# Average opacity along z for each (x, y) column of the density grid:
mean_alpha = dvgo.alpha().sum(dim=2).detach().numpy() / Z
plt.imshow(mean_alpha)
plt.colorbar();

## Some orthographic renders

In [None]:
def create_rays(config: "Config", face, off=1.0):
    """Create rays for an orthographic camera on one of the grid faces.

    One ray is generated per voxel on the chosen face, passing through the voxel centres.
    Its origin sits `off` units outside the bounding box, and its direction is the unit
    face normal pointing into the box.

    Args:
        config: object with `shape`, `min_corner` and `max_corner` (e.g. a `Config`).
        face: one of "x", "-x", "y", "-y", "z", "-z". "x" places the camera on the min-x side
            looking towards +x; "-x" places it on the max-x side looking towards -x, etc.
        off: distance between the ray origins and the bounding-box face.

    Returns:
        Tensor of shape (a, b, 6) holding [origin (3), direction (3)] per ray, where a and b
        are the voxel counts along the two remaining axes, in (x, y, z) order.

    Raises:
        ValueError: if `face` is not one of the six face ids.
    """
    # face id -> (axis the camera looks along, sign of the viewing direction)
    face_axes = {
        "x": (0, 1.0), "-x": (0, -1.0),
        "y": (1, 1.0), "-y": (1, -1.0),
        "z": (2, 1.0), "-z": (2, -1.0),
    }
    if not isinstance(face, str) or face not in face_axes:
        raise ValueError(f"Invalid face id {face!r}; expected one of {list(face_axes)}")
    axis, sign = face_axes[face]
    u_axis, v_axis = (a for a in range(3) if a != axis)

    def voxel_centres(a):
        # min + (i + 0.5) * spacing in Python floats, then one float32 conversion per value.
        lo, hi, count = config.min_corner[a], config.max_corner[a], config.shape[a]
        return torch.tensor(
            [lo + (i + 0.5) * ((hi - lo) / count) for i in range(count)], dtype=torch.float32
        )

    u, v = torch.meshgrid(voxel_centres(u_axis), voxel_centres(v_axis), indexing="ij")
    rays = torch.zeros((*u.shape, 6))
    rays[..., u_axis] = u
    rays[..., v_axis] = v
    # Origins lie `off` below the min face when looking in +axis, `off` above the max face otherwise.
    rays[..., axis] = config.min_corner[axis] - off if sign > 0 else config.max_corner[axis] + off
    rays[..., 3 + axis] = sign
    return rays

In [None]:
# Rays for the "x" face of the default 16^3 grid:
x_rays = create_rays(Config(), "x")
test_eq(x_rays.shape, torch.Size([16, 16, 6]))

# Pixel centres sit half a voxel (2 / 16 / 2) inside the box edges; origins are 1 unit outside:
half = 2.0 / 32
first_ray = torch.tensor([-2, -1 + half, -1 + half, 1, 0, 0])
last_ray = torch.tensor([-2, 1 - half, 1 - half, 1, 0, 0])
test_close(x_rays[0, 0], first_ray, 1e-3)
test_close(x_rays[-1, -1], last_ray, 1e-3)

In [None]:
# Deterministic (training=False) samples along the x-face rays:
x_ray_samples = sample_rays(dvgo.t_values, x_rays, training=False)
test_eq(x_ray_samples.shape, torch.Size([16, 16, 64, 3]))

corner_ray = x_ray_samples[0, 0]
# The first sample lies about 0.5 in front of the x = -1 face, the last about 0.5 behind x = +1:
test_close(corner_ray[0], np.array([-1.5, -0.9375, -0.9375]), 0.1)
test_close(corner_ray[-1], np.array([1.5, -0.9375, -0.9375]), 0.1)

In [None]:
# Check scaled and bracketed coordinates of the central ray:
rescaled = dvgo.shape * (x_ray_samples - dvgo.min) / (dvgo.max - dvgo.min)
test_eq(rescaled.shape, torch.Size([16, 16, 64, 3]))
middle = rescaled[8, 8]
test_eq(middle.shape, torch.Size([64, 3]))
print(middle[:, 0])
# Bracket against the storage size (voxels + 1 corner points), exactly as VoxelGrid.forward does:
bracket(middle[:, 0], dvgo.rgb_voxel_grid.grid.shape[0])

Now, check the density and RGB values along this middle ray:

In [None]:
# Split the central ray's grid coordinates per axis and bracket each against the storage size:
x, y, z = middle.unbind(dim=-1)
n_x, n_y, n_z = dvgo.rgb_voxel_grid.grid.shape[:3]
X0, X1, a = bracket(x, n_x)
Y0, Y1, b = bracket(y, n_y)
Z0, Z1, c = bracket(z, n_z)


In [None]:
# Re-derive VoxelGrid.forward by hand for the central ray, using the brackets computed above.
rgb_grid = dvgo.rgb_voxel_grid.grid

# Step 1: blend along x on the four cell edges parallel to x.
y0z0 = interpolate(rgb_grid[X0, Y0, Z0, :], rgb_grid[X1, Y0, Z0, :], a)
y1z0 = interpolate(rgb_grid[X0, Y1, Z0, :], rgb_grid[X1, Y1, Z0, :], a)
y0z1 = interpolate(rgb_grid[X0, Y0, Z1, :], rgb_grid[X1, Y0, Z1, :], a)
y1z1 = interpolate(rgb_grid[X0, Y1, Z1, :], rgb_grid[X1, Y1, Z1, :], a)

# Step 2: blend along y.
z0, z1 = interpolate(y0z0, y1z0, b), interpolate(y0z1, y1z1, b)

# Step 3: blend along z (the squeeze is a no-op here because d = 3).
predicted_rgb = interpolate(z0, z1, c).squeeze(-1)

test_eq(predicted_rgb.shape, torch.Size([64, 3]))
for idx, want in ((0, [0.0, 0.5, 0.5]), (32, [0.5, 0.5, 0.5]), (-1, [1.0, 0.5, 0.5])):
    test_close(predicted_rgb[idx], torch.tensor(want), 0.1)

In [None]:
# Query both grids along the central ray, post-processed exactly as in SimpleDVGO.forward:
density = F.relu(dvgo.density_voxel_grid(middle).squeeze())
rgb = dvgo.rgb_voxel_grid(middle).clamp(0, 1)

In [None]:
# Check shapes, then the colour at the start, middle and end of the central ray
# (red ramps 0 -> 1 along x while green and blue stay near 0.5):
test_eq(density.shape, torch.Size([64]))
test_eq(rgb.shape, torch.Size([64, 3]))
expected = {0: [0.0, 0.5, 0.5], 32: [0.5, 0.5, 0.5], -1: [1.0, 0.5, 0.5]}
for idx, want in expected.items():
    test_close(rgb[idx], torch.tensor(want), 0.1)

In [None]:
# Plot density along the central ray with lines and markers (`markers` is a boolean flag):
px.line(x=dvgo.t_values, y=density.detach().numpy(), title="Density", markers=True)

In [None]:
# Plot alpha = 1 - exp(-density) along the central ray with lines and markers:
px.line(x=dvgo.t_values, y=1 - np.exp(-density.detach().numpy()), title="Alpha", markers=True)

In [None]:
# Plot transmittance T_i = exp(-sum_{j<i} sigma_j) along the central ray. The inclusive cumulative
# sum is shifted by one sample (T_1 = 1), exactly as render_along_ray does, so the curve
# matches the transmittance actually used for rendering.
inclusive_trans = torch.exp(-torch.cumsum(density, dim=-1))
transmittance = torch.cat([torch.ones_like(density[..., :1]), inclusive_trans[..., :-1]], dim=-1)
px.line(x=dvgo.t_values, y=transmittance.detach().numpy(), title="Transmittance", markers=True)

In [None]:
# Plot the interpolated R, G and B values along the central ray, one trace per channel.
# (A dedicated name keeps the rendered `colors` tensor from the smoke test intact.)
fig = go.Figure()
channel_colors = ["red", "green", "blue"]
for i, (channel_name, trace_color) in enumerate(zip("RGB", channel_colors)):
    fig.add_trace(
        go.Scatter(
            x=dvgo.t_values,
            y=rgb[:, i].detach().numpy(),
            mode="lines+markers",
            name=channel_name,
            marker=dict(size=5, color=trace_color),
        )
    )
fig.update_layout(title="RGB")

In [None]:
# Recompute render_along_ray step by step (white background) so that the intermediate
# weights and accumulated opacity can be inspected and plotted below:
with torch.no_grad():
    alpha = 1 - torch.exp(-density)
    cumulative_density = density.cumsum(dim=-1)
    # Exclusive transmittance: the first sample is unoccluded.
    trans = torch.cat(
        [torch.ones_like(density[..., :1]), torch.exp(-cumulative_density)[..., :-1]], dim=-1
    )

    weights = alpha * trans
    color_acc = torch.einsum("...i,...ij->...j", weights, rgb)
    acc = weights.sum(dim=-1, keepdim=True)
    color = color_acc + (1.0 - acc) * WHITE

print(acc, color_acc, color)

In [None]:
# Plot the compositing weights T_i * alpha_i along the central ray (`markers` is a boolean flag):
px.line(x=dvgo.t_values, y=weights.detach().numpy(), title="Weights", markers=True)

Finally, render:

In [None]:
# Orthographic render of the "x" face: one pixel per ray in x_rays.
rendered_x = dvgo(x_rays, training=False)
x_render = rendered_x.detach().numpy()
plt.imshow(x_render);