In [1]:
import torch
import mediapy as media
from mattport.nerf.field_modules import encoding


# NeRF Positional Encoding
First introduced in the original NeRF paper. This encoding assumes the inputs are between zero and one and can operate on inputs of any dimensionality.

In [2]:
# Settings forwarded to encoding.NeRFEncoding below.
num_frequencies = 4
min_freq_exp = 0
max_freq_exp = 6
include_input = True

# Number of samples along each side of the square grid that gets encoded.
resolution = 128

encoder = encoding.NeRFEncoding(
    in_dim=2,
    num_frequencies=num_frequencies,
    min_freq_exp=min_freq_exp,
    max_freq_exp=max_freq_exp,
    include_input=include_input,
)

# Regular grid of 2D coordinates covering [0, 1] x [0, 1];
# the coordinate pair lives on the last axis: (resolution, resolution, 2).
x_samples = torch.linspace(0, 1, resolution)
grid = torch.stack(
    torch.meshgrid(x_samples, x_samples, indexing="ij"),
    dim=-1,
)

encoded_values = encoder(grid)

# Move the channel axis (2) to the front before handing each tensor to
# media.show_images, once for the raw inputs and once for the encoder output.
for _caption, _channels_last in (
    ("Input Values:", grid),
    ("Encoded Values:", encoded_values),
):
    print(_caption)
    media.show_images(_channels_last.movedim(2, 0), cmap="plasma", border=True)


Input Values:


Encoded Values:


# Random Fourier Feature (RFF) Encoding
This encoding assumes the inputs are between zero and one and can operate on inputs of any dimensionality.

In [3]:
# Settings forwarded to encoding.RFFEncoding below.
num_frequencies = 8
scale = 10

# Number of samples along each side of the square grid that gets encoded.
resolution = 128

# Seed torch's global RNG so that Restart & Run All reproduces the same figure.
# NOTE(review): this assumes RFFEncoding draws its random features from torch's
# global generator when it is constructed -- confirm against the implementation.
torch.manual_seed(0)
encoder = encoding.RFFEncoding(in_dim=2, num_frequencies=num_frequencies, scale=scale)

# Regular grid of 2D coordinates covering [0, 1] x [0, 1];
# the coordinate pair lives on the last axis: (resolution, resolution, 2).
x_samples = torch.linspace(0, 1, resolution)
grid = torch.stack(torch.meshgrid([x_samples, x_samples], indexing="ij"), dim=-1)

encoded_values = encoder(grid)

# Move the channel axis (2) to the front before handing each tensor to
# media.show_images.
print("Input Values:")
media.show_images(torch.moveaxis(grid, 2, 0), cmap="plasma", border=True)
print("Encoded Values:")
media.show_images(torch.moveaxis(encoded_values, 2, 0), cmap="plasma", border=True)


Input Values:


Encoded Values:


# Hash Encoding
The hash encoding was originally introduced in Instant-NGP. The encoding is optimized during training. This is a visualization of the initialization.

In [6]:
# Settings forwarded to encoding.HashEncoding below.
num_levels = 8
min_res = 2
max_res = 128
hash_table_size = 2**4  # Typically much larger tables are used

# Number of samples along each side of the cubic grid that gets encoded.
resolution = 128
# Index along the first grid axis of the 2D slice to display.
# (Named slice_index so the builtin `slice` is not shadowed in the kernel.)
slice_index = 0

# Fixing features_per_level to 3 for easy RGB visualization. Typical value is 2 in networks
features_per_level = 3

# Seed torch's global RNG so that Restart & Run All reproduces the same figure.
# NOTE(review): this assumes HashEncoding initialises its table from torch's
# global generator when it is constructed -- confirm against the implementation.
torch.manual_seed(0)
encoder = encoding.HashEncoding(
    num_levels=num_levels,
    min_res=min_res,
    max_res=max_res,
    hash_table_size=hash_table_size,
    features_per_level=features_per_level,
    hash_init_scale=0.001,
)

# Regular grid of 3D coordinates covering the unit cube;
# the coordinate triple lives on the last axis: (resolution, resolution, resolution, 3).
x_samples = torch.linspace(0, 1, resolution)
grid = torch.stack(torch.meshgrid([x_samples, x_samples, x_samples], indexing="ij"), dim=-1)

encoded_values = encoder(grid)

# Keep a single 2D slice for display. The encoded slice is detached so the
# normalisation below neither tracks gradients nor writes into `encoded_values`.
grid_slice = grid[slice_index, ...]
encoded_values_slice = encoded_values[slice_index, ...].detach()

print("Input Values:")
media.show_images(torch.moveaxis(grid_slice, 2, 0), cmap="plasma", border=True)

print("Encoded Values:")
# Split the feature axis into (level, feature) so each level renders as one
# image with `features_per_level` colour channels, then put levels first.
encoded_images = encoded_values_slice.reshape(resolution, resolution, num_levels, features_per_level)
encoded_images = torch.moveaxis(encoded_images, 2, 0)
# Rescale jointly across all levels to [0, 1]; done out-of-place so the
# encoder output is left untouched.
encoded_images = encoded_images - torch.min(encoded_images)
encoded_images = encoded_images / torch.max(encoded_images)
media.show_images(encoded_images.numpy(), cmap="plasma", border=True)


Input Values:


Encoded Values:
