In this notebook, we demonstrate the basic usage of *Lightplane Renderer* and *Splatter*.

In [None]:
# import all needed packages
import sys
import os
import torchvision.transforms as transforms
from IPython.display import display, Markdown

import torch
import lightplane
from lightplane import LightplaneRenderer, LightplaneSplatter, LightplaneMLPSplatter, Rays

# make sure we import the correct utils directory by adding the "examples" dir to path
examples_dir = os.path.join(os.path.dirname(lightplane.__file__), "..", "examples")
sys.path.insert(0, examples_dir)
from utils.util.camera_util import generate_random_cameras, get_rays_for_extrinsic, get_sphere_cameras
from utils.util.grid_util import random_grid

## 1. Neural 3D Field Representation

*Lightplane* uses hybrid representations for rendering and splatting, where neural 3D fields are represented by a combination of 3D grids (generalized to 3D hash structures in the paper) and tiny MLPs.

The 3D grid is a list of 5-dim tensors with shapes $[ [ B, D_1, H_1, W_1, C ], [ B, D_2, H_2, W_2, C], \dots, [B, D_S, H_S, W_S, C] ]$, where $S$ is the number of grids, $B$ is the batch size and $C$ is the feature dimension.

For a voxel grid, $S = 1$, and $D_1, H_1, W_1$ are the depth, height and width of the voxel grid.

For TriPlanes, $S = 3$ and $D_1 = H_2 = W_3 = 1$, i.e. the grid consists of three axis-aligned planes.

This design readily supports a mixture of an arbitrary number of voxel grids and planes.
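To make the grid-list convention concrete, the snippet below builds a voxel grid, a TriPlane and a mixture of both as plain Python lists of 5-dim tensors. The sizes are illustrative only; the point is the shape pattern $[B, D_s, H_s, W_s, C]$ with one singleton spatial axis per plane.

```python
import torch

B, C = 2, 16           # batch size and feature dimension (illustrative values)
D, H, W = 32, 48, 64   # spatial resolution (illustrative values)

# Voxel grid: S = 1, a single dense [B, D, H, W, C] tensor.
voxel_grid = [torch.zeros(B, D, H, W, C)]

# TriPlane: S = 3 with D_1 = H_2 = W_3 = 1, i.e. the HW, DW and DH planes.
triplane = [
    torch.zeros(B, 1, H, W, C),
    torch.zeros(B, D, 1, W, C),
    torch.zeros(B, D, H, 1, C),
]

# A mixture (here a coarse 8^3 voxel grid plus the three planes) is just a longer list.
mixed = [torch.zeros(B, 8, 8, 8, C)] + triplane

for name, grid in [("voxel", voxel_grid), ("triplane", triplane), ("mixed", mixed)]:
    n_values = sum(g.numel() for g in grid)
    print(name, [tuple(g.shape) for g in grid], n_values)
```

The printed element counts show why planes are attractive: the TriPlane stores $B \cdot C \cdot (HW + DW + DH)$ values instead of $B \cdot C \cdot DHW$.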

## 2. Lightplane Renderer

We now show how *Lightplane Renderer* works by rendering images from a feature grid corresponding to an RGB sphere.

We first initialize the rendering rays via the `Rays` object, whose fields are `directions` (shape $[N,3]$), `origins` (shape $[N,3]$), `near` (shape $[N]$), `far` (shape $[N]$) and `grid_idx` (shape $[N]$), where $N$ is the number of rays.

In particular, `grid_idx` is a `torch.long` tensor indicating which grid in the batch each ray renders from; its values lie in `[0, B)`.
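The sketch below reproduces, on small illustrative sizes, how the notebook assigns images to scenes and expands that assignment to one `grid_idx` entry per ray. It only uses plain tensor operations and does not touch the `Rays` class.

```python
import torch

B, n_images, image_size = 2, 16, 4   # small illustrative sizes

# The notebook's construction: spread the n_images cameras evenly over the B scenes.
img_to_scene = torch.linspace(0, B - 1, n_images).round().long()
print(img_to_scene.tolist())  # eight 0s followed by eight 1s for these sizes

# For these sizes (n_images divisible by B) this equals a simple block assignment.
same = torch.equal(img_to_scene, torch.arange(B).repeat_interleave(n_images // B))

# One entry per ray: every pixel of image k renders from grid (scene) img_to_scene[k].
grid_idx = img_to_scene[:, None, None].expand(n_images, image_size, image_size).reshape(-1)

# Rays sharing a grid_idx value read from the same batch element of every tensor
# in the grid-list, e.g. rays with grid_idx == 1 use grid[s][1] for each s.
```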

*Lightplane Renderer* samples `num_samples` points along each ray and marches over them with the emission-absorption algorithm to render the ray colors.
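For reference, here is a minimal PyTorch sketch of standard emission-absorption compositing over the samples of each ray. It is written for clarity and is not Lightplane's fused kernel; in particular, how the renderer spaces samples and defines the returned ray length is an assumption here (we use the expected depth). The return order mirrors the renderer's `(ray_length, alpha, feature)` outputs.

```python
import torch

def emission_absorption(sigma, color, t, bg_color=1.0):
    """Standard emission-absorption compositing (a sketch, not Lightplane's kernel).

    sigma: (N, S) non-negative densities at the S samples of N rays
    color: (N, S, K) per-sample colors/features
    t:     (N, S) sorted sample depths along each ray
    """
    # Segment length for each sample; the last sample reuses the previous spacing.
    delta = t[:, 1:] - t[:, :-1]
    delta = torch.cat([delta, delta[:, -1:]], dim=1)
    alpha_i = 1.0 - torch.exp(-sigma * delta)                 # per-sample opacity
    # Transmittance T_i = prod_{j<i} (1 - alpha_j): an exclusive cumulative product.
    trans = torch.cumprod(1.0 - alpha_i, dim=1)
    trans = torch.cat([torch.ones_like(trans[:, :1]), trans[:, :-1]], dim=1)
    weights = alpha_i * trans                                 # (N, S) compositing weights
    alpha = weights.sum(dim=1)                                # accumulated ray opacity
    depth = (weights * t).sum(dim=1)                          # expected ray length
    feat = (weights[..., None] * color).sum(dim=1)
    feat = feat + (1.0 - alpha)[:, None] * bg_color           # composite over the background
    return depth, alpha, feat

N, S = 3, 8
t = torch.linspace(1.0, 5.0, S).expand(N, S)
sigma = torch.zeros(N, S)
sigma[0, 2] = 1e4                  # ray 0 hits an opaque sample at t[:, 2]
color = torch.rand(N, S, 3)
depth, alpha, feat = emission_absorption(sigma, color, t)
```

Ray 0 returns the color of its first opaque sample with alpha 1, while the empty rays return `bg_color` with alpha 0, which is the behavior the rendered alpha masks below reflect.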

In [None]:
n_images = 16 # number of rendered images
image_size = 128 # size of the rendered images
B = 2 # number of scenes to render
near = 1.0 # near rendering plane
far = 5.0 # far rendering plane

device = torch.device("cuda")

R, T = get_sphere_cameras(
    n_cameras=n_images,
    elevation=30,
    distance=2.0,
    device=device,
) # get n_images cameras on a sphere
ray_dir, ray_org = get_rays_for_extrinsic(image_size, R, T) # get ray origins and ray directions

near_t = torch.ones(n_images, image_size, image_size, device=device).view(-1) * near
far_t = torch.ones(n_images, image_size, image_size, device=device).view(-1) * far

# We use grid_idx to indicate which grid (batch element) each ray renders from.
# grid_idx has shape (n_images, image_size, image_size) here and is flattened to
# (n_images * image_size * image_size,) below; its values lie in [0, B).
grid_idx = torch.linspace(
    0, B-1, n_images, device=device
).round().int()[:, None, None].repeat(1, image_size, image_size)

rays = Rays(
    directions=ray_dir.reshape(-1, 3).contiguous(),
    origins=ray_org.reshape(-1, 3).contiguous(),
    grid_idx=grid_idx.reshape(-1).contiguous(),
    near=near_t.contiguous(),
    far=far_t.contiguous(),
)
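The helper `get_rays_for_extrinsic` comes from the examples' utilities and its internals are not shown here. As a rough, hypothetical stand-in, the sketch below casts one ray per pixel of a pinhole camera, assuming the world-to-camera convention $x_{cam} = R\,x_{world} + T$, a camera looking along $+z$, and pixel coordinates spanning $[-1, 1]$. The real helper may use a different convention (e.g. row-vector rotations or a different focal length).

```python
import torch

def pinhole_rays(image_size, R, T, focal=1.0):
    """Hypothetical stand-in for get_rays_for_extrinsic (its real convention may differ).

    Assumes x_cam = R @ x_world + T, a camera looking along +z, and pixel
    coordinates spanning [-1, 1] on both image axes.
    R: (n, 3, 3), T: (n, 3). Returns directions and origins of shape (n, S, S, 3).
    """
    n = R.shape[0]
    coords = torch.linspace(-1.0, 1.0, image_size, device=R.device)
    ys, xs = torch.meshgrid(coords, coords, indexing="ij")        # (S, S) each
    dirs_cam = torch.stack([xs / focal, ys / focal, torch.ones_like(xs)], dim=-1)
    # Rotate camera-space directions to world space: d_world = R^T d_cam.
    dirs = torch.einsum("nji,hwj->nhwi", R, dirs_cam)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    # Camera center in world space: C = -R^T T, shared by every pixel of a camera.
    centers = -torch.einsum("nji,nj->ni", R, T)
    origins = centers[:, None, None, :].expand(n, image_size, image_size, 3)
    return dirs, origins
```

The outputs would then be flattened exactly as in the cell above, e.g. `dirs.reshape(-1, 3).contiguous()`, before being passed to `Rays`.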

*Lightplane Renderer* follows standard volumetric rendering practice.
After sampling features from the provided grid-list `grid`, it regresses the opacity and color of each sampled point from its features and the view direction.

There are three MLPs inside the renderer: `trunk_mlp`, `opacity_mlp` and `color_mlp`:

- `trunk_mlp` is the base MLP applied before color (`color_mlp`) and opacity (`opacity_mlp`) regression.
- `opacity_mlp` takes the outputs of `trunk_mlp` as input and regresses the opacities of the sampled points.
- `color_mlp` takes the outputs of `trunk_mlp` and the view-direction embedding as inputs, and outputs the point colors.

Additionally, the *Renderer* can take a separate color grid for color regression via `color_feature_grid`.
Using a separate color grid requires `mlp_n_layers_trunk = 0`.

The *Renderer* can also take a `scaffold` as extra input: a voxel grid indicating the coarse occupancy of the 3D field.
MLP evaluation is skipped for sampled points whose scaffold entry is empty.
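The trunk/opacity/color split and the scaffold-based skipping can be pictured with the small hypothetical module below. It is not Lightplane's implementation: the activations (ReLU in the trunk, softplus for opacity, sigmoid for color) are assumptions, each MLP is a single linear layer, and the view-direction input of `color_mlp` is omitted because this notebook disables direction-dependent colors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoder(nn.Module):
    """Hypothetical trunk/opacity/color decoder with one linear layer per MLP."""

    def __init__(self, grid_chn=16, hidden_chn=16, color_chn=3, opacity_init_bias=-1.0):
        super().__init__()
        self.trunk_mlp = nn.Linear(grid_chn, hidden_chn)
        self.opacity_mlp = nn.Linear(hidden_chn, 1)
        self.color_mlp = nn.Linear(hidden_chn, color_chn)
        # A negative initial bias lowers the initial densities of a fresh field.
        nn.init.constant_(self.opacity_mlp.bias, opacity_init_bias)

    def forward(self, feats, occupied=None):
        # feats: (P, grid_chn) features sampled at P points
        # occupied: optional (P,) bool mask looked up from a coarse scaffold grid
        P = feats.shape[0]
        if occupied is None:
            occupied = torch.ones(P, dtype=torch.bool, device=feats.device)
        sigma = feats.new_zeros(P)                           # empty samples keep zero density
        rgb = feats.new_zeros(P, self.color_mlp.out_features)
        h = F.relu(self.trunk_mlp(feats[occupied]))          # MLPs run only on occupied samples
        sigma[occupied] = F.softplus(self.opacity_mlp(h)).squeeze(-1)
        rgb[occupied] = torch.sigmoid(self.color_mlp(h))
        return sigma, rgb
```

Skipping unoccupied samples saves MLP evaluations without changing the result for them, since they contribute zero density to the emission-absorption sum.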

In [None]:
# Set up the parameters of the renderer
num_samples = 256  # number of sampled points along each ray
grid_chn = 16  # number of feature channels in the input grid

# configuration of the rendering MLPs; we use a simple linear layer for each MLP
mlp_hidden_chn = 16
mlp_n_layers_opacity = 1  # the number of layers in decoder's opacity mlp
mlp_n_layers_trunk = 1  # the number of layers in decoder's trunk mlp
mlp_n_layers_color = 1  # the number of layers in decoder's color mlp

# we configure the renderer to produce view-independent colors
enable_direction_dependent_colors = False
ray_embedding_num_harmonics = None

renderer = LightplaneRenderer(
    num_samples=num_samples, 
    color_chn=3,
    grid_chn=grid_chn,
    mlp_hidden_chn=mlp_hidden_chn,
    mlp_n_layers_opacity=mlp_n_layers_opacity,
    mlp_n_layers_trunk=mlp_n_layers_trunk,
    mlp_n_layers_color=mlp_n_layers_color,
    ray_embedding_num_harmonics=ray_embedding_num_harmonics,
    opacity_init_bias=-1.0,  # the initial bias of the opacity MLP
    enable_direction_dependent_colors=enable_direction_dependent_colors,
    bg_color=1.0,
).to(device)

Here we initialize the feature grid-list `grid` to render. It corresponds to a single $64^3$ voxel grid representing a 3D sphere with random surface colors.

In [None]:
D, H, W = 64, 64, 64  # voxel grid size
g_xyz = torch.stack(  # 3D coordinates of voxel centers
    torch.meshgrid(
        torch.linspace(-1, 1, D, device=device),
        torch.linspace(-1, 1, H, device=device),
        torch.linspace(-1, 1, W, device=device),
        indexing="ij",  # explicit matrix indexing (the default); avoids a deprecation warning
    ),
    dim=-1,
)
# voxels outside a sphere of radius 0.75 get zero features (nearly empty after decoding); voxels inside get random features
inside_sphere_mask = g_xyz.norm(dim=-1) <= 0.75
g = 20 * torch.randn(B, D, H, W, grid_chn, device=device) * inside_sphere_mask[..., None]

# the grid-list `grid` is a list of tensors, each tensor of shape (B, D, H, W, grid_chn)
grid = [g]
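The renderer reads features at continuous sample points by interpolating this grid. A rough equivalent in plain PyTorch is trilinear sampling with `torch.nn.functional.grid_sample`, sketched below. Note that `torch.linspace(-1, 1, D)` places the first and last voxel centers at $\pm 1$, which matches `align_corners=True`. The $(x, y, z) \rightarrow (W, H, D)$ axis order is `grid_sample`'s convention; Lightplane's own convention may differ.

```python
import torch
import torch.nn.functional as F

def sample_voxel_grid(grid, points):
    """Trilinearly sample a [B, D, H, W, C] grid at points in [-1, 1]^3.

    points: (B, P, 3) ordered (x, y, z), where x indexes W, y indexes H and z indexes D
    (the torch.nn.functional.grid_sample convention; Lightplane's may differ).
    """
    vol = grid.permute(0, 4, 1, 2, 3)                  # -> (B, C, D, H, W)
    coords = points[:, :, None, None, :]               # -> (B, P, 1, 1, 3)
    # For 5-D inputs, mode="bilinear" performs trilinear interpolation.
    out = F.grid_sample(vol, coords, mode="bilinear", align_corners=True)  # (B, C, P, 1, 1)
    return out[:, :, :, 0, 0].permute(0, 2, 1)         # -> (B, P, C)
```

For example, `sample_voxel_grid(g, pts)` with `pts` of shape `(B, P, 3)` would return `(B, P, grid_chn)` features for the sphere grid above.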

In [None]:
# Render and display the results

# Render!
(
    ray_length_render,
    alpha_render,
    feature_render,
) = renderer(rays=rays, feature_grid=grid)

# reshape the rendered colors into PIL images and display them side by side
image_render = feature_render.reshape(n_images, image_size, image_size, 3).permute(0, 3, 1, 2)
display(Markdown("## Rendered colors"))
display(transforms.ToPILImage()(torch.cat(image_render.unbind(), dim=2)))

# convert the rendered alpha masks to a PIL image and display it
mask_render = alpha_render.reshape(n_images, image_size, image_size)
display(Markdown("## Rendered alpha mask"))
display(transforms.ToPILImage()(torch.cat(mask_render.unbind(), dim=1)))

## 3. Splatter

Here we demonstrate how *Splatter* works with `n_images` input feature maps, each of shape $H \times W \times F$, where $F$ is the feature dimension.

Each input feature map has corresponding camera parameters (intrinsics and extrinsics), near/far values, the index of the grid it is splatted into (`grid_idx`), and a pixel-wise mask for splatting.

As in rendering, splatting operates on rays, which are cast from the pixels of the input feature maps.
We use the `Rays` class to store the ray information.
Importantly, the features to be splatted are passed in `rays.encoding`.

In [None]:
# define features that will be splatted to the grid
splatting_features = torch.randn(
    n_images, image_size, image_size, grid_chn, device=device
).view(-1, grid_chn)

rays = Rays(
    directions=(ray_dir.reshape(-1, 3)).contiguous(),
    origins=(ray_org.reshape(-1, 3)).contiguous(),
    grid_idx=(grid_idx).reshape(-1).contiguous(),
    near=near_t.reshape(-1).contiguous(),
    far=far_t.reshape(-1).contiguous(),
    encoding=splatting_features.contiguous() # splatted features are stored inside the encoding field
)

*Lightplane Splatter* samples `num_samples` points along each ray and splats the ray encoding at these points into the target 3D grid.
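Point placement along the rays can be sketched as follows. We assume stratified sampling: the `[near, far]` interval is split into `num_samples` equal bins, bin centers are used by default, and with jitter one uniform random point is drawn per bin. Treating this as the meaning of the `rays_jitter_near_far` flag used below is our assumption, not something this notebook confirms.

```python
import torch

def sample_along_rays(origins, directions, near, far, num_samples, jitter=False):
    """Place num_samples points per ray between near and far (hypothetical sketch).

    origins, directions: (N, 3); near, far: (N,). Returns depths t of shape (N, S)
    and points of shape (N, S, 3).
    """
    edges = torch.linspace(0.0, 1.0, num_samples + 1, device=origins.device)
    width = (far - near)[:, None] / num_samples                   # (N, 1) bin width
    lower = near[:, None] + (far - near)[:, None] * edges[:-1]    # (N, S) lower bin edges
    # Bin centers without jitter, one uniform draw inside each bin with jitter.
    offset = torch.rand_like(lower) if jitter else torch.full_like(lower, 0.5)
    t = lower + offset * width
    points = origins[:, None, :] + t[..., None] * directions[:, None, :]
    return t, points
```

Jittering spreads the samples continuously over each ray across iterations, which avoids always hitting the same discrete depths.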

*Lightplane Splatter* can work in two modes, implemented as two separate `torch.nn.Module`s: (1) without an MLP or input grid (`LightplaneSplatter`); (2) with an MLP and an input grid (`LightplaneMLPSplatter`).

In mode (1), for each point along the ray, `LightplaneSplatter` splats the ray encoding directly into the output grid, without any MLP or input grid.

In mode (2), for each point along the ray, `LightplaneMLPSplatter` samples a feature from the input grid, adds it to the ray encoding, passes the result through an MLP, and splats it into the output grid.
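Conceptually, mode (1) splatting is the transpose of trilinear sampling: each sample point distributes its feature to the 8 surrounding voxels using the same trilinear weights that sampling would use to read from them. The sketch below does this into a single dense voxel grid with `index_add_`. It assumes the `grid_sample` axis order and `align_corners=True`, and ignores any normalization Lightplane's kernel might apply.

```python
import torch

def splat_points(points, feats, grid_size):
    """Trilinearly splat per-point features into a dense (D, H, W, C) voxel grid.

    points: (P, 3) in [-1, 1], ordered (x, y, z) -> (W, H, D) axes, using the
    align_corners=True convention of torch.nn.functional.grid_sample (an assumption).
    feats: (P, C) features to splat, with the same dtype as points.
    """
    D, H, W = grid_size
    C = feats.shape[-1]
    out = feats.new_zeros(D * H * W, C)
    sizes = torch.tensor([W, H, D], device=points.device)
    u = (points + 1) / 2 * (sizes - 1).to(points.dtype)   # continuous voxel coordinates
    u0 = u.floor().long()
    frac = u - u0.to(points.dtype)
    for corner in range(8):  # the 8 voxels surrounding each point
        offs = torch.tensor([(corner >> k) & 1 for k in range(3)], device=points.device)
        idx = u0 + offs
        w = torch.where(offs.bool(), frac, 1 - frac).prod(dim=-1)  # trilinear weight
        valid = ((idx >= 0) & (idx < sizes)).all(dim=-1)           # drop out-of-grid corners
        flat = (idx[:, 2] * H + idx[:, 1]) * W + idx[:, 0]
        out.index_add_(0, flat[valid], feats[valid] * w[valid, None])
    return out.view(D, H, W, C)
```

Because the weights are exactly those of trilinear sampling, the inner product of a splatted grid with any grid $G$ equals the inner product of the point features with the values sampled from $G$, which is what makes splatting and rendering natural adjoints.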

In [None]:
# setting parameters
D, H, W = 64, 64, 64 # grid sizes
C = grid_chn  # feature dimension of the output grids; it must equal F (the input feature dimension) in mode (2),
              # while it may differ from F in mode (1)

grid_sizes = [[B, 1, H, W, C], [B, D, 1, W, C], [B, D, H, 1, C]]  # sizes of the three planes of the output triplane

input_grid = random_grid((B, D, H, W, grid_chn), device, requires_grad=True, is_triplane=True)  # a random triplane input grid with one entry per scene, since grid_idx ranges over [0, B)

num_samples = 128 # number of sampling points along each ray
use_input_grid = True # we use input grid

input_grid_chn = grid_chn  # features sampled from input_grid are summed with the splatting features, so both must have the same dimension
mlp_hidden_chn = 32  # hidden width of the MLP inside the splatter
mlp_n_layers = 2  # number of layers of the MLP inside the splatter

num_samples_inf = 0  # number of additional samples for unbounded regions
contract_coords = False  # whether to use contracted coordinates

mask_out_of_bounds_samples = False  # whether to mask out samples that fall outside the grid
rays_jitter_near_far = True  # jitter the sampling points along each ray


We first demonstrate the usage of *LightplaneSplatter* (mode 1), which takes the splatting features and splats them directly into 3D grids without any MLP.

In [None]:
# initialize splatter
splatter = LightplaneSplatter(
    num_samples=num_samples,
    grid_chn=C,
).to(device)
# output splatting results
output_grid = splatter(
    rays=rays, 
    grid_size=grid_sizes,
    mask_out_of_bounds_samples=mask_out_of_bounds_samples,
    rays_jitter_near_far=rays_jitter_near_far,
    num_samples_inf=num_samples_inf,
    contract_coords=contract_coords
)
print(output_grid[0].shape)
print(output_grid[1].shape)
print(output_grid[2].shape)

We then demonstrate the usage of *LightplaneMLPSplatter* (mode 2), which additionally samples features from `input_grid` and passes them through an MLP before splatting.
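The per-point computation of mode (2) can be sketched in two steps: bilinearly sample the triplane `input_grid` at each point, then sum the result with the ray encoding (following the parameter cell's note that the two are summed) and pass it through a small MLP. Everything here is hypothetical: summing the three plane features, the `(x, y, z) -> (W, H, D)` axis order, and the MLP layout are assumptions. The resulting features would then be splatted into the output triplane as in mode (1).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sample_triplane(planes, points):
    """Bilinearly sample a batch-1 triplane [(1,1,H,W,C), (1,D,1,W,C), (1,D,H,1,C)].

    points: (P, 3) in [-1, 1] ordered (x, y, z) -> (W, H, D) axes (assumed convention).
    The three plane features are summed (an assumed aggregation).
    """
    x, y, z = points.unbind(-1)
    plane_uv = [torch.stack([x, y], -1), torch.stack([x, z], -1), torch.stack([y, z], -1)]
    total = 0.0
    for k, (plane, uv) in enumerate(zip(planes, plane_uv)):
        img = plane[0].squeeze(k).permute(2, 0, 1)[None]     # (1, C, rows, cols)
        s = F.grid_sample(img, uv[None, :, None, :], mode="bilinear", align_corners=True)
        total = total + s[0, :, :, 0].T                      # (P, C)
    return total

class MLPSplatHead(nn.Module):
    """Hypothetical per-point head: (grid feature + ray encoding) -> MLP -> splat feature."""

    def __init__(self, chn=16, hidden_chn=32, n_layers=2):
        super().__init__()
        layers, in_chn = [], chn
        for _ in range(n_layers - 1):
            layers += [nn.Linear(in_chn, hidden_chn), nn.ReLU()]
            in_chn = hidden_chn
        layers.append(nn.Linear(in_chn, chn))
        self.mlp = nn.Sequential(*layers)

    def forward(self, grid_feat, ray_encoding):
        # Summation requires input_grid_chn == grid_chn, as in the parameter cell.
        return self.mlp(grid_feat + ray_encoding)
```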


In [None]:
# initialize splatter
splatter = LightplaneMLPSplatter(
    num_samples=num_samples,
    grid_chn=C,
    input_grid_chn=input_grid_chn,
    mlp_hidden_chn=mlp_hidden_chn,
    mlp_n_layers=mlp_n_layers,
).to(device)

In [None]:
# output splatting results
output_grid = splatter(
    rays=rays, 
    grid_size=grid_sizes,
    input_grid=input_grid, 
    mask_out_of_bounds_samples=mask_out_of_bounds_samples,
    rays_jitter_near_far=rays_jitter_near_far,
    num_samples_inf=num_samples_inf,
    contract_coords=contract_coords
)
print(output_grid[0].shape)
print(output_grid[1].shape)
print(output_grid[2].shape)