# PlenOctrees for Real-time Rendering of Neural Radiance Fields

- Ref: 
    - https://arxiv.org/abs/2103.14024
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/PlenOctrees_for_Real_time_Rendering_of_Neural_Radiance_Fields

- Conda env : [gsplat](../gsplat/README.md#setup-a-conda-environment)


In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
# Fetch the training and testing ray datasets from Google Drive into ./temp/data.
# The directory is created first so that gdown has somewhere to write.
Path('./temp/data').mkdir(exist_ok=True, parents=True)

# Training set. `url` is re-bound for the second download further down.
url = 'https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J'
train_data_path = './temp/data/training_data.pkl'
# resume=True: judging by the recorded cell output, a file that is already
# present is skipped instead of being downloaded again.
gdown.download(url, train_data_path, quiet=False, resume=True)


# Testing set. This call is the last expression of the cell, so its return
# value (the output path) is what the notebook displays.
url = 'https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q'
test_data_path = './temp/data/testing_data.pkl'
gdown.download(url, test_data_path, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
# Notebook-wide configuration; train() reads these names as globals.
# g_model_path: directory that receives one checkpoint per epoch.
# g_video_path: directory that receives the validation videos.
g_model_path = './temp/03_PlenOctreesNeRF_models/'
g_video_path = './temp/03_PlenOctreesNeRF_novel_views/'
# Frame rate handed to torchvision.io.write_video for the validation videos.
fps = 10
Path(g_model_path).mkdir(exist_ok=True, parents=True)
Path(g_video_path).mkdir(exist_ok=True, parents=True)
# Each dataset row is one ray; train() slices columns 0:3 as the ray origin,
# 3:6 as the ray direction and 6: as the ground-truth pixel values.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code. These files are downloaded from Google
# Drive, so only load them if the source is trusted.
training_dataset = torch.from_numpy(np.load('./temp/data/training_data.pkl', allow_pickle=True))
testing_dataset = torch.from_numpy(np.load('./temp/data/testing_data.pkl', allow_pickle=True))


In [4]:
# Select the compute device used by the rest of the notebook:
# Apple MPS when available, otherwise CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    # Report the MPS allocator counters as a quick sanity check.
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
else:
    g_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(g_device)

Current memory allocated on MPS: 0 bytes
Driver memory allocated on MPS: 393216 bytes
mps


In [5]:
# Model

def to_cartesian(theta_phi):
    """Convert pairs of spherical angles into unit vectors.

    Args:
        theta_phi: 2-D tensor; column 0 is read as theta and column 1 as phi.
            The formula below measures theta from the +z axis and phi as the
            angle within the x-y plane.

    Returns:
        Tensor with one row per input row and three columns:
        (sin(theta) * cos(phi), sin(theta) * sin(phi), cos(theta)).
    """
    theta, phi = theta_phi[:, 0], theta_phi[:, 1]
    sin_theta = torch.sin(theta)
    x = sin_theta * torch.cos(phi)
    y = sin_theta * torch.sin(phi)
    z = torch.cos(theta)
    return torch.stack((x, y, z), dim=1)


class NerfModel(nn.Module):
    """Position-only MLP that predicts a density and a view-dependent colour.

    The network only sees the encoded sample position. It outputs a density
    ``sigma`` and ``num_components`` RGB coefficient triplets ``k``. The
    viewing direction ``d`` enters afterwards through learned exponential
    lobes::

        colour = sigmoid(sum_i k_i * exp(bandwidth_i * axis_i * d))

    with ``axis_i = to_cartesian(p_i)``; every product is element-wise per
    x/y/z channel.

    NOTE(review): a textbook spherical Gaussian would use the dot product
    ``axis_i . d``; the per-channel product is kept exactly as the code had
    it. Confirm that this is intentional.

    Args:
        embedding_dim_pos: Size of the positional encoding per input
            coordinate (sin and cos for ``embedding_dim_pos // 2``
            frequencies). Must be even so the first layer matches the encoder.
        embedding_dim_direction: Stored on the instance but not used by the
            layers or by ``forward``.
        hidden_dim: Width of the hidden layers.
        num_components: Number of exponential lobes. The default of 25 keeps
            the parameter shapes of existing checkpoints.

    Raises:
        ValueError: If ``embedding_dim_pos`` is odd.
    """

    def __init__(self, embedding_dim_pos=20, embedding_dim_direction=8, hidden_dim=128, num_components=25):
        super().__init__()

        # forward() encodes with embedding_dim_pos // 2 frequencies, producing
        # 2 * (embedding_dim_pos // 2) features per coordinate. With an odd
        # value the first Linear layer could never match that width.
        if embedding_dim_pos % 2 != 0:
            raise ValueError(f"embedding_dim_pos must be even, got {embedding_dim_pos}")

        # The layers are created in the same order as before, so a fixed seed
        # still yields the same initial weights.
        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )

        # Skip connection: block2 receives block1's output concatenated with
        # the encoded position. Its last layer emits hidden_dim features plus
        # one extra value that becomes the density.
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 3 + hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )

        self.block3 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), )
        # One RGB coefficient triplet per lobe.
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, num_components * 3), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.num_components = num_components
        self.relu = nn.ReLU()

        # Per-lobe scale (starts at 0, so every lobe initially evaluates to
        # exp(0) = 1) and per-lobe (theta, phi) angles defining its axis.
        self.bandwidth = nn.Parameter(torch.zeros((1, num_components)))
        self.p = nn.Parameter(torch.randn((num_components, 2)))

    @staticmethod
    def positional_encoding(x, L):
        """Encode each column of ``x`` with sines and cosines at ``L`` frequencies.

        Args:
            x: 2-D tensor of shape (N, D).
            L: Number of frequencies; frequency ``j`` multiplies the input by
                ``2 ** j``.

        Returns:
            Tensor of shape (N, D * 2 * L) in the default dtype. For input
            column ``i`` and frequency ``j``, column ``i * 2L + 2j`` holds
            ``sin(2**j * x[:, i])`` and column ``i * 2L + 2j + 1`` the
            matching cosine.
        """
        n_rows, n_cols = x.shape[0], x.shape[1]
        # Powers of two are exact in floating point, so this matches
        # multiplying by the Python integer 2 ** j.
        freqs = torch.tensor([2 ** j for j in range(L)], dtype=x.dtype, device=x.device)
        scaled = x.unsqueeze(-1) * freqs  # (N, D, L)
        # Stacking sin/cos on a trailing axis and flattening yields the
        # interleaved [sin, cos] layout per frequency, grouped per column.
        encoded = torch.stack((torch.sin(scaled), torch.cos(scaled)), dim=-1)  # (N, D, L, 2)
        # Explicit target shape: reshape(..., -1) is ambiguous when L == 0.
        # The cast keeps the output in the default dtype, exactly like the
        # torch.empty buffer this code used to fill column by column.
        return encoded.reshape(n_rows, n_cols * 2 * L).to(torch.get_default_dtype())

    def forward(self, o, d):
        """Evaluate colour and density for a batch of samples.

        Args:
            o: Sample positions, shape (N, 3).
            d: Viewing directions, shape (N, 3).

        Returns:
            Tuple ``(colour, sigma)``: colour has shape (N, 3) with values in
            (0, 1) after the sigmoid, and sigma has shape (N,) and is
            non-negative after the ReLU.
        """
        emb_x = self.positional_encoding(o, self.embedding_dim_pos // 2)
        h = self.block1(emb_x)
        tmp = self.block2(torch.cat((h, emb_x), dim=1))
        # The last output channel is the raw density; the rest are features.
        h, sigma = tmp[:, :-1], self.relu(tmp[:, -1])
        h = self.block3(h)
        k = self.block4(h).reshape(o.shape[0], self.num_components, 3)

        # (1, C, 1) * (1, C, 3) * (N, 1, 3) broadcasts to (N, C, 3); the lobes
        # are then summed over the component axis.
        axes = to_cartesian(self.p).unsqueeze(0)
        c = (k * torch.exp(self.bandwidth.unsqueeze(-1) * axes * d.unsqueeze(1))).sum(1)

        return torch.sigmoid(c), sigma

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    """Return the exclusive cumulative product of ``alphas`` along dim 1.

    Entry ``[:, i]`` of the result is the product of ``alphas[:, :i]``, so
    the first column is all ones and the last input column never contributes.
    render_rays passes ``1 - alpha`` here, which makes the result the
    fraction of light that reaches each sample.

    Args:
        alphas: 2-D tensor of per-sample factors.

    Returns:
        Tensor with the same shape as ``alphas``.
    """
    running_product = torch.cumprod(alphas, dim=1)
    leading_ones = torch.ones((running_product.shape[0], 1), device=alphas.device)
    # Shift right by one: prepend the ones column, drop the final product.
    return torch.cat((leading_ones, running_product[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render one colour per ray by sampling ``nerf_model`` along it.

    Args:
        nerf_model: Callable taking (positions, directions), both of shape
            (n_rays * nb_bins, 3), and returning (colours, densities).
        ray_origins: Ray origins, shape (n_rays, 3).
        ray_directions: Ray directions, shape (n_rays, 3).
        hn: Nearest sampling distance along the ray.
        hf: Farthest sampling distance along the ray.
        nb_bins: Number of samples per ray.

    Returns:
        Tensor of shape (n_rays, 3): the weighted colour sum plus
        ``1 - sum(weights)``, i.e. composited over a white background.
    """
    dev = ray_origins.device
    n_rays = ray_origins.shape[0]

    # Evenly spaced depths, shared by every ray.
    t = torch.linspace(hn, hf, nb_bins, device=dev).expand(n_rays, nb_bins)

    # Jitter each depth uniformly within the interval bounded by the
    # midpoints to its neighbours (the first and last keep the outer bound).
    midpoints = (t[:, :-1] + t[:, 1:]) / 2.
    bin_lo = torch.cat((t[:, :1], midpoints), dim=-1)
    bin_hi = torch.cat((midpoints, t[:, -1:]), dim=-1)
    jitter = torch.rand(t.shape, device=dev)
    t = bin_lo + (bin_hi - bin_lo) * jitter  # [n_rays, nb_bins]

    # Distance to the next sample; the last sample gets a huge interval so
    # any remaining density there becomes fully opaque.
    far_pad = torch.tensor([1e10], device=dev).expand(n_rays, 1)
    delta = torch.cat((t[:, 1:] - t[:, :-1], far_pad), dim=-1)

    # 3-D sample positions, with each ray's direction repeated per sample.
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [n_rays, nb_bins, 3]
    dirs = ray_directions.unsqueeze(1).expand(n_rays, nb_bins, 3)

    colors, sigma = nerf_model(x.reshape(-1, 3), dirs.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    # Per-sample opacity, then weight = light reaching the sample * opacity.
    alpha = 1 - torch.exp(-sigma * delta)  # [n_rays, nb_bins]
    transmittance = compute_accumulated_transmittance(1 - alpha)
    weights = transmittance.unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel values

    # Whatever weight is left over shows the white background (value 1).
    weight_sum = weights.sum(-1).sum(-1)
    return c + 1 - weight_sum.unsqueeze(-1)

In [7]:
@torch.no_grad()
def test(nerf_model, hn, hf, dataset, chunk_size=5, img_index=0, nb_bins=192, H=400, W=400, device=None):
    """Render one full image from ``dataset``, a few pixel rows at a time.

    Args:
        nerf_model: Model passed on to render_rays.
        hn: Near plane distance.
        hf: Far plane distance.
        dataset: 2-D tensor with one ray per row; columns 0:3 are read as the
            ray origin and 3:6 as the ray direction. Image ``img_index``
            occupies rows ``img_index * H * W`` to ``(img_index + 1) * H * W``.
        chunk_size (int, optional): Number of image rows rendered per call to
            render_rays, to bound memory use. Defaults to 5.
        img_index (int, optional): Index of the image to render. Defaults to 0.
        nb_bins (int, optional): Number of samples along each ray.
            Defaults to 192.
        H (int, optional): Image height. Defaults to 400.
        W (int, optional): Image width. Defaults to 400.
        device (optional): Device the rays are moved to before rendering.
            Defaults to the device of the model's first parameter, falling
            back to the notebook-wide ``g_device`` if the model exposes none.

    Returns:
        torch.Tensor: The rendered image on the CPU, shape (H, W, 3).
    """
    if device is None:
        # Render where the model lives, so a model on another device than the
        # notebook default still works.
        try:
            device = next(nerf_model.parameters()).device
        except (AttributeError, StopIteration):
            device = g_device

    first, last = img_index * H * W, (img_index + 1) * H * W
    ray_origins = dataset[first:last, :3]
    ray_directions = dataset[first:last, 3:6]

    rays_per_chunk = W * chunk_size
    data = []
    # ceil covers a final, shorter chunk when chunk_size does not divide H;
    # slicing past the end simply returns the remaining rays.
    for i in range(int(np.ceil(H / chunk_size))):
        chunk = slice(i * rays_per_chunk, (i + 1) * rays_per_chunk)
        ray_origins_ = ray_origins[chunk].to(device)
        ray_directions_ = ray_directions[chunk].to(device)
        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        # Move every chunk to the CPU right away to keep device memory flat.
        data.append(regenerated_px_values.cpu())
    img_t = torch.cat(data).reshape(H, W, 3)
    return img_t

def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, start_epoch = 0, nb_epochs=int(1e5), nb_bins=192, H=400, W=400, eval_steps = 5, nb_val_images=200):
    """Train ``nerf_model`` and periodically render a validation video.

    Each epoch runs over ``data_loader``, steps ``scheduler`` once and saves a
    checkpoint named ``nerf_model_<epoch>.pth`` under the global
    ``g_model_path``. Whenever ``epoch % eval_steps == 0``, and on the last
    epoch, it renders ``nb_val_images`` images from the global
    ``testing_dataset`` and writes them as ``nerf_model_<epoch>.mp4`` under
    the global ``g_video_path`` at the global ``fps``.

    Args:
        nerf_model: Model to optimise.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        data_loader: Yields 2-D batches; columns 0:3 are ray origins, 3:6 ray
            directions and 6: the ground-truth pixel values.
        device: Device the batches are moved to.
        hn: Near plane distance, used for training and validation rendering.
        hf: Far plane distance, used for training and validation rendering.
        start_epoch: First epoch index (for resuming a run).
        nb_epochs: Epoch index to stop at (exclusive).
        nb_bins: Samples per ray, used for training and validation rendering.
        H: Height of the validation images.
        W: Width of the validation images.
        eval_steps: Render a validation video every this many epochs.
        nb_val_images: Number of images rendered per validation video.

    Returns:
        list[float]: The loss of every training batch, in order.
    """
    training_loss = []
    for e in range(start_epoch, nb_epochs):
        print(f"{e:03d}-epoch")
        for batch in tqdm(data_loader, total = len(data_loader), desc="Training"):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            # Sum (not mean) of squared errors over the batch.
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()
        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)
        # Export in TorchScript Format
        # _model_script_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pt')
        # _model_scripted = torch.jit.script(nerf_model)
        # _model_scripted.save(_model_script_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            nerf_model.eval()
            # Validate with the same near/far bounds, sample count and image
            # size this function was called with, not separate hard-coded
            # values.
            imgT_lst = [
                test(nerf_model, hn, hf, testing_dataset, img_index=idx, nb_bins=nb_bins, H=H, W=W).unsqueeze(0)
                for idx in tqdm(range(nb_val_images), desc="Validation")
            ]
            # Clamp first: rounding can push a value marginally outside
            # [0, 1], and casting an out-of-range float to uint8 is not
            # well defined.
            img_t = (torch.cat(imgT_lst).clamp(0, 1) * 255).to(torch.uint8)
            _video_path = os.path.join(g_video_path,  f'nerf_model_{e:03d}.mp4')
            torchvision.io.write_video(_video_path, img_t, fps)
            nerf_model.train()

    return training_loss

In [8]:


# Build the model on the selected device, with Adam and a step-wise
# learning-rate schedule that halves the rate at epochs 2, 4 and 8.
model = NerfModel(hidden_dim=256).to(g_device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    model_optimizer,
    milestones=[2, 4, 8],
    gamma=0.5,
)

# Shuffled batches of 1024 rays; train for 16 epochs with near/far bounds 2 and 6.
data_loader = DataLoader(training_dataset, shuffle=True, batch_size=1024)
train(
    model,
    model_optimizer,
    scheduler,
    data_loader,
    device=g_device,
    hn=2,
    hf=6,
    nb_epochs=16,
    nb_bins=192,
)


000-epoch


Training: 100%|██████████| 15625/15625 [42:33<00:00,  6.12it/s]
Validation: 100%|██████████| 200/200 [29:59<00:00,  9.00s/it]


001-epoch


Training: 100%|██████████| 15625/15625 [42:49<00:00,  6.08it/s]


002-epoch


Training: 100%|██████████| 15625/15625 [42:57<00:00,  6.06it/s]


003-epoch


Training: 100%|██████████| 15625/15625 [42:44<00:00,  6.09it/s]


004-epoch


Training: 100%|██████████| 15625/15625 [42:24<00:00,  6.14it/s] 


005-epoch


Training: 100%|██████████| 15625/15625 [41:57<00:00,  6.21it/s]
Validation: 100%|██████████| 200/200 [29:39<00:00,  8.90s/it]


006-epoch


Training: 100%|██████████| 15625/15625 [42:28<00:00,  6.13it/s]


007-epoch


Training: 100%|██████████| 15625/15625 [42:34<00:00,  6.12it/s]


008-epoch


Training: 100%|██████████| 15625/15625 [42:22<00:00,  6.15it/s]


009-epoch


Training: 100%|██████████| 15625/15625 [42:25<00:00,  6.14it/s]


010-epoch


Training: 100%|██████████| 15625/15625 [41:59<00:00,  6.20it/s]
Validation: 100%|██████████| 200/200 [29:21<00:00,  8.81s/it]


011-epoch


Training: 100%|██████████| 15625/15625 [42:38<00:00,  6.11it/s]


012-epoch


Training: 100%|██████████| 15625/15625 [42:39<00:00,  6.10it/s]


013-epoch


Training: 100%|██████████| 15625/15625 [42:19<00:00,  6.15it/s]


014-epoch


Training: 100%|██████████| 15625/15625 [41:51<00:00,  6.22it/s]


015-epoch


Training: 100%|██████████| 15625/15625 [42:39<00:00,  6.10it/s]
Validation: 100%|██████████| 200/200 [30:16<00:00,  9.08s/it]


[488.8343505859375,
 462.5929870605469,
 425.7779235839844,
 384.8770751953125,
 344.699462890625,
 318.0718994140625,
 298.526123046875,
 276.50531005859375,
 259.2846984863281,
 287.4348449707031,
 274.85003662109375,
 339.58917236328125,
 319.3733825683594,
 286.33575439453125,
 267.37725830078125,
 242.57984924316406,
 281.8811340332031,
 272.3720703125,
 281.3577880859375,
 269.67034912109375,
 289.3138427734375,
 275.3322448730469,
 293.03814697265625,
 273.3975830078125,
 291.84478759765625,
 271.0476989746094,
 267.7910461425781,
 279.28118896484375,
 259.0105285644531,
 289.2955627441406,
 281.78717041015625,
 256.08514404296875,
 269.4617919921875,
 249.5538330078125,
 286.6634521484375,
 260.9109802246094,
 275.8546142578125,
 244.72561645507812,
 243.82485961914062,
 269.955078125,
 270.5818176269531,
 265.4052734375,
 271.59576416015625,
 245.99801635742188,
 257.4112548828125,
 261.13372802734375,
 246.929443359375,
 259.60906982421875,
 264.10546875,
 258.24969482421875,

In [9]:
# from IPython.display import Video

# Video("./temp/03_PlenOctreesNeRF_novel_views/nerf_model_015.mp4")