# PlenOctrees for Real-time Rendering of Neural Radiance Fields

- Ref: 
    - https://arxiv.org/abs/2103.14024
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/PlenOctrees_for_Real_time_Rendering_of_Neural_Radiance_Fields

- Conda env : [gsplat](../gsplat/README.md#setup-a-conda-environment)


In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
Path('./temp/data').mkdir(exist_ok=True, parents=True)

url = 'https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J'
train_data_path = './temp/data/training_data.pkl'
gdown.download(url, train_data_path, quiet=False, resume=True)


url = 'https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q'
test_data_path = './temp/data/testing_data.pkl'
gdown.download(url, test_data_path, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
g_model_path = './temp/04_PlenoxelsNeRF_models/'
g_video_path = './temp/04_PlenoxelsNeRF_novel_views/'
fps = 10
Path(g_model_path).mkdir(exist_ok=True, parents=True)
Path(g_video_path).mkdir(exist_ok=True, parents=True)
training_dataset = torch.from_numpy(np.load('./temp/data/training_data.pkl', allow_pickle=True))
testing_dataset = torch.from_numpy(np.load('./temp/data/testing_data.pkl', allow_pickle=True))


In [4]:
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
elif torch.cuda.is_available():
    g_device = torch.device("cuda")
else:
    g_device = torch.device("cpu")
print(g_device)

cuda


In [5]:
# Model

def eval_spherical_function(k, d):
    x, y, z = d[..., 0:1], d[..., 1:2], d[..., 2:3]

    # Modified from https://github.com/google/spherical-harmonics/blob/master/sh/spherical_harmonics.cc
    return 0.282095 * k[..., 0] + \
        - 0.488603 * y * k[..., 1] + 0.488603 * z * k[..., 2] - 0.488603 * x * k[..., 3] + \
        (1.092548 * x * y * k[..., 4] - 1.092548 * y * z * k[..., 5] + 0.315392 * (2.0 * z * z - x * x - y * y) * k[
               ..., 6] + -1.092548 * x * z * k[..., 7] + 0.546274 * (x * x - y * y) * k[..., 8])


class NerfModel(nn.Module):
    def __init__(self, N=256, scale=1.5):
        """
        :param N
        :param scale: The maximum absolute value among all coordinates for objects in the scene
        """
        super(NerfModel, self).__init__()

        self.voxel_grid = nn.Parameter(torch.ones((N, N, N, 27 + 1)) / 100)
        self.scale = scale
        self.N = N

    def forward(self, x, d):
        color = torch.zeros_like(x)
        sigma = torch.zeros((x.shape[0]), device=x.device)
        mask = (x[:, 0].abs() < self.scale) & (x[:, 1].abs() < self.scale) & (x[:, 2].abs() < self.scale)

        idx = (x[mask] / (2 * self.scale / self.N) + self.N / 2).long().clip(0, self.N - 1)
        tmp = self.voxel_grid[idx[:, 0], idx[:, 1], idx[:, 2]]
        sigma[mask], k = torch.nn.functional.relu(tmp[:, 0]), tmp[:, 1:]
        color[mask] = eval_spherical_function(k.reshape(-1, 3, 9), d[mask])
        return color, sigma

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    _device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=_device).expand(ray_origins.shape[0], nb_bins)
    # Perturb sampling along each ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=_device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=_device).expand(ray_origins.shape[0], 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel values
    # Regularization for white background
    weight_sum = weights.sum(-1).sum(-1)
    return c + 1 - weight_sum.unsqueeze(-1)

In [7]:
@torch.no_grad()
def test(nerf_model, hn, hf, dataset, chunk_size=5, img_index=0, nb_bins=192, H=400, W=400):
    """
    Args:
        hn: near plane distance
        hf: far plane distance
        dataset: dataset to render
        chunk_size (int, optional): chunk size for memory efficiency. Defaults to 10.
        img_index (int, optional): image index to render. Defaults to 0.
        nb_bins (int, optional): number of bins for density estimation. Defaults to 192.
        H (int, optional): image height. Defaults to 400.
        W (int, optional): image width. Defaults to 400.

    Returns:
        None: None
    """
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())
    img_t = torch.cat(data).reshape(H, W, 3)
    return img_t

def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, start_epoch = 0, nb_epochs=int(1e5), nb_bins=192, H=400, W=400, eval_steps = 5):
    training_loss = []
    for e in (range(start_epoch, nb_epochs)):
        print(f"{e:03d}-epoch")
        for ep, batch in enumerate(tqdm(data_loader, total = len(data_loader), desc="Training")):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            # loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
            loss = torch.nn.functional.mse_loss(ground_truth_px_values, regenerated_px_values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()
        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)
        # Export in TorchScript Format
        # _model_script_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pt')
        # _model_scripted = torch.jit.script(nerf_model)
        # _model_scripted.save(_model_script_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            nerf_model.eval()
            imgT_lst = []
            for idx in tqdm(range(200), desc="Validation"):
                imgT_lst.append(test(nerf_model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192).unsqueeze(0))
            img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
            _video_path = os.path.join(g_video_path,  f'nerf_model_{e:03d}.mp4')
            torchvision.io.write_video(_video_path, img_t, fps)
            nerf_model.train()

    return training_loss

In [8]:


# model = NerfModel(N=256).to(g_device)

model = NerfModel(N=256)
# prev_mode_path = "/Users/hyunjae.k/110_HyunJae_Git/2025_Drills/DL_Drills/NERF/temp/05_PlenoxelsNeRF_models/nerf_model_006.pth"
# model.load_state_dict(torch.load(prev_mode_path))
model.to(g_device)


model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
# data_loader = DataLoader(training_dataset, batch_size=2048, shuffle=True)
train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6, nb_bins=192)


000-epoch


Training: 100%|██████████| 15625/15625 [23:38<00:00, 11.01it/s]
Validation: 100%|██████████| 200/200 [00:31<00:00,  6.42it/s]


001-epoch


Training: 100%|██████████| 15625/15625 [23:43<00:00, 10.98it/s]


002-epoch


Training: 100%|██████████| 15625/15625 [23:31<00:00, 11.07it/s]


003-epoch


Training: 100%|██████████| 15625/15625 [23:35<00:00, 11.04it/s]


004-epoch


Training: 100%|██████████| 15625/15625 [23:33<00:00, 11.05it/s]


005-epoch


Training: 100%|██████████| 15625/15625 [23:14<00:00, 11.21it/s]
Validation: 100%|██████████| 200/200 [00:31<00:00,  6.43it/s]


006-epoch


Training: 100%|██████████| 15625/15625 [23:24<00:00, 11.13it/s]


007-epoch


Training: 100%|██████████| 15625/15625 [23:36<00:00, 11.03it/s]


008-epoch


Training: 100%|██████████| 15625/15625 [23:46<00:00, 10.95it/s]


009-epoch


Training: 100%|██████████| 15625/15625 [23:03<00:00, 11.29it/s]


010-epoch


Training: 100%|██████████| 15625/15625 [23:25<00:00, 11.12it/s]
Validation: 100%|██████████| 200/200 [00:30<00:00,  6.46it/s]


011-epoch


Training: 100%|██████████| 15625/15625 [22:58<00:00, 11.33it/s]


012-epoch


Training: 100%|██████████| 15625/15625 [23:15<00:00, 11.20it/s]


013-epoch


Training: 100%|██████████| 15625/15625 [23:32<00:00, 11.06it/s]


014-epoch


Training: 100%|██████████| 15625/15625 [23:13<00:00, 11.21it/s]


015-epoch


Training: 100%|██████████| 15625/15625 [23:16<00:00, 11.19it/s]
Validation: 100%|██████████| 200/200 [00:30<00:00,  6.49it/s]


[0.11428478360176086,
 0.12294070422649384,
 0.11439039558172226,
 0.11697125434875488,
 0.11787595599889755,
 0.1270960569381714,
 0.10222254693508148,
 0.11826828122138977,
 0.10894956439733505,
 0.11328887939453125,
 0.118632011115551,
 0.1247558444738388,
 0.1132109984755516,
 0.11685481667518616,
 0.10832562297582626,
 0.11279235780239105,
 0.10979463905096054,
 0.11694979667663574,
 0.1235223263502121,
 0.10752766579389572,
 0.10198225826025009,
 0.1278604120016098,
 0.11903528869152069,
 0.10467179864645004,
 0.1050359457731247,
 0.11044903099536896,
 0.11408694833517075,
 0.1238025650382042,
 0.12612739205360413,
 0.12009876221418381,
 0.10604840517044067,
 0.12086611241102219,
 0.10699339956045151,
 0.1180793046951294,
 0.10782232135534286,
 0.11351277679204941,
 0.09696772694587708,
 0.10897120088338852,
 0.1154244989156723,
 0.11695188283920288,
 0.11191485822200775,
 0.1065572053194046,
 0.11813395470380783,
 0.11952485889196396,
 0.11418783664703369,
 0.10835298895835876,


In [9]:
# from IPython.display import Video

# Video("./temp/02_kiloNeRF_novel_views/nerf_model_003.mp4")

In [10]:
# test(model, 2, 6, testing_dataset, chunk_size = 10000, img_index=1, nb_bins=192)

In [11]:
# model = NerfModel(N=256)
# prev_mode_path = "/Users/hyunjae.k/110_HyunJae_Git/2025_Drills/DL_Drills/NERF/temp/05_PlenoxelsNeRF_models/nerf_model_015.pth"
# model.load_state_dict(torch.load(prev_mode_path))
# model.to(g_device)

# model.eval()
# imgT_lst = []
# for idx in tqdm(range(200), desc="Validation"):
#     imgT_lst.append(test(model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192).unsqueeze(0))
# img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
# _video_path = os.path.join(g_video_path,  f'nerf_model_test.mp4')
# torchvision.io.write_video(_video_path, img_t, fps)
