# FastNeRF: High-Fidelity Neural Rendering at 200FPS

---

- Conda env: [3dcv_playgrounds](../README.md#setup-a-conda-environment)

----

- References:
    - https://arxiv.org/abs/2103.10380
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/FastNeRF_High_Fidelity_Neural_Rendering_at_200FPS


In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download the pre-generated training / testing ray datasets from Google Drive.
# With resume=True gdown skips files that are already fully present, so re-running this cell is cheap.
Path('./temp/data').mkdir(parents=True, exist_ok=True)

train_data_path = './temp/data/training_data.pkl'
test_data_path = './temp/data/testing_data.pkl'

for url, data_path in (
    ('https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J', train_data_path),
    ('https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q', test_data_path),
):
    gdown.download(url, data_path, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
# Output locations for per-epoch checkpoints and rendered validation videos.
g_model_path = './temp/01_FastNeRF_models/'
g_video_path = './temp/01_FastNeRF_novel_views/'
g_fps = 10  # frame rate of the written validation videos
Path(g_model_path).mkdir(exist_ok=True, parents=True)
Path(g_video_path).mkdir(exist_ok=True, parents=True)

# Load the ray datasets fetched by the download cell, reusing its paths instead of repeating the literals.
# NOTE(review): allow_pickle=True means np.load unpickles the file, which can execute arbitrary code;
# only load these files from a trusted source.
g_training_dataset = torch.from_numpy(np.load(train_data_path, allow_pickle=True))
g_testing_dataset = torch.from_numpy(np.load(test_data_path, allow_pickle=True))

In [4]:
# Pick the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    # Report MPS memory usage at start-up.
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
else:
    g_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(g_device)

Current memory allocated on MPS: 0 bytes
Driver memory allocated on MPS: 393216 bytes
mps


In [5]:
# FastNeRF model: factorised into a position network and a direction network.

class FastNerf(nn.Module):
    """FastNeRF network split into a position branch ``Fpos`` and a direction branch ``Fdir``.

    ``Fpos`` maps an encoded 3-D position to ``1 + 3 * D`` values: one density logit
    followed by ``D`` RGB basis vectors. ``Fdir`` maps an encoded viewing direction to
    ``D`` mixing logits. The output colour is the softmax(beta)-weighted sum of the
    sigmoid-squashed basis vectors.

    Args:
        embedding_dim_pos: number of positional-encoding octaves for positions.
        embedding_dim_direction: number of positional-encoding octaves for directions.
        hidden_dim_pos: width of the position MLP.
        hidden_dim_dir: width of the direction MLP.
        D: number of colour basis vectors / mixing weights.
    """

    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim_pos=384, hidden_dim_dir=128, D=8):
        super().__init__()

        # Encoded input width: the raw 3 coordinates plus a sin and a cos of each of them per octave.
        pos_in = embedding_dim_pos * 6 + 3
        dir_in = embedding_dim_direction * 6 + 3

        # Position MLP: input layer, 6 hidden layers, linear output (ReLU between layers).
        # Modules are created in the same order as a flat nn.Sequential listing, so the
        # initialisation and the state_dict keys (Fpos.0, Fpos.2, ...) are unchanged.
        pos_layers = [nn.Linear(pos_in, hidden_dim_pos), nn.ReLU()]
        for _ in range(6):
            pos_layers += [nn.Linear(hidden_dim_pos, hidden_dim_pos), nn.ReLU()]
        pos_layers.append(nn.Linear(hidden_dim_pos, 3 * D + 1))
        self.Fpos = nn.Sequential(*pos_layers)

        # Direction MLP: input layer, 2 hidden layers, linear output producing D logits.
        dir_layers = [nn.Linear(dir_in, hidden_dim_dir), nn.ReLU()]
        for _ in range(2):
            dir_layers += [nn.Linear(hidden_dim_dir, hidden_dim_dir), nn.ReLU()]
        dir_layers.append(nn.Linear(hidden_dim_dir, D))
        self.Fdir = nn.Sequential(*dir_layers)

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.D = D

    @staticmethod
    def positional_encoding(x, L):
        """Return ``[x, sin(x), cos(x), sin(2x), cos(2x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]`` along dim 1."""
        features = [x]
        for freq in (2 ** j for j in range(L)):
            features += [torch.sin(freq * x), torch.cos(freq * x)]
        return torch.cat(features, dim=1)

    def forward(self, o, d):
        """Evaluate colour and density for positions ``o`` and directions ``d`` (each ``[batch_size, 3]``).

        Returns:
            ``(color [batch_size, 3], sigma [batch_size, 1])``
        """
        pos_out = self.Fpos(self.positional_encoding(o, self.embedding_dim_pos))
        density_logit, basis_logits = pos_out[:, :1], pos_out[:, 1:]
        sigma = torch.nn.functional.softplus(density_logit)  # [batch_size, 1], non-negative density
        uvw = torch.sigmoid(basis_logits.reshape(-1, 3, self.D))  # [batch_size, 3, D]

        dir_logits = self.Fdir(self.positional_encoding(d, self.embedding_dim_direction))
        beta = torch.softmax(dir_logits, dim=-1)  # [batch_size, D], mixing weights summing to 1
        color = torch.sum(beta[:, None, :] * uvw, dim=-1)  # [batch_size, 3]
        return color, sigma

In [6]:
class Cache(nn.Module):
    """Grid-sampled lookup tables that replace a trained FastNerf's MLPs at render time.

    ``model.Fpos`` is evaluated once on an ``Np^3`` grid of nodes
    ``linspace(-scale/2, scale/2, Np)`` per axis. ``model.Fdir`` is evaluated on an
    ``Nd^2`` grid of direction (x, y) components over the same interval, with
    ``z = sqrt(1 - x^2 - y^2)``. Only the z >= 0 hemisphere is sampled, and the
    sign of ``d_z`` is ignored at lookup. ``forward`` then answers queries by
    nearest-grid-node lookup instead of running the networks.

    Args:
        model: FastNerf-like module exposing ``Fpos``, ``Fdir``, ``positional_encoding``,
            ``embedding_dim_pos``, ``embedding_dim_direction`` and ``D``. It is expected to
            live on ``device``.
        scale: side length of the cached cube (also the extent of the direction grid).
        device: device on which the sample grids, and therefore the tables, are built.
        Np: number of grid nodes per spatial axis.
        Nd: number of grid nodes per direction axis.
        chunk_size: maximum number of grid nodes pushed through each network at once
            while building the tables. This bounds peak memory; with Np=192 the
            position grid alone has ~7M nodes. ``None`` or <= 0 disables chunking.
    """

    def __init__(self, model, scale, device, Np, Nd, chunk_size=2 ** 18):
        super().__init__()

        with torch.no_grad():
            # Position table: Fpos at every (x, y, z) node, shaped [Np, Np, Np, 1 + 3 * D].
            axis = torch.linspace(-scale / 2, scale / 2, Np, device=device)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing='ij')
            xyz = torch.stack((x.reshape(-1), y.reshape(-1), z.reshape(-1)), dim=1)
            sigma_uvw = self._batched(
                lambda p: model.Fpos(model.positional_encoding(p, model.embedding_dim_pos)), xyz, chunk_size)
            self.sigma_uvw = sigma_uvw.reshape((Np, Np, Np, -1))

            # Direction table: Fdir at every (x, y) node of the z >= 0 hemisphere, shaped [Nd, Nd, D].
            axis_d = torch.linspace(-scale / 2, scale / 2, Nd, device=device)
            xd, yd = torch.meshgrid(axis_d, axis_d, indexing='ij')
            zd = torch.sqrt((1 - xd ** 2 - yd ** 2).clip(0, 1))
            xyz_d = torch.stack((xd.reshape(-1), yd.reshape(-1), zd.reshape(-1)), dim=1)
            beta = self._batched(
                lambda v: model.Fdir(model.positional_encoding(v, model.embedding_dim_direction)), xyz_d, chunk_size)
            self.beta = beta.reshape((Nd, Nd, -1))

        self.scale = scale
        self.Np = Np
        self.Nd = Nd
        self.D = model.D

    @staticmethod
    def _batched(fn, inputs, chunk_size):
        """Apply ``fn`` to row chunks of ``inputs`` (at most ``chunk_size`` rows each) and concatenate."""
        if chunk_size is None or chunk_size <= 0:
            return fn(inputs)
        return torch.cat([fn(chunk) for chunk in torch.split(inputs, chunk_size)], dim=0)

    @staticmethod
    def _nearest_index(coords, half_extent, n):
        """Index of the nearest ``linspace(-half_extent, half_extent, n)`` node for each coordinate, clipped to [0, n-1]."""
        if n == 1:
            return torch.zeros_like(coords, dtype=torch.long)
        step = 2 * half_extent / (n - 1)
        return torch.round((coords + half_extent) / step).long().clip(0, n - 1)

    def forward(self, x, d):
        """Look up colour and density for positions ``x`` and directions ``d`` (each ``[batch_size, 3]``).

        Points outside the cached cube get zero colour and zero density.

        Returns:
            ``(color [batch_size, 3], sigma [batch_size, 1])``
        """
        color = torch.zeros_like(x)
        sigma = torch.zeros((x.shape[0], 1), device=x.device)

        half = self.scale / 2
        mask = (x[:, :3].abs() < half).all(dim=1)

        # Position: nearest node of the same linspace grid the table was built on.
        pos_idx = self._nearest_index(x[mask], half, self.Np)
        sigma_uvw = self.sigma_uvw[pos_idx[:, 0], pos_idx[:, 1], pos_idx[:, 2]]
        # Direction: nearest (x, y) node of the direction grid. The previous `d * Nd` mapping did
        # not match this grid and collapsed every negative component onto node 0.
        dir_idx = self._nearest_index(d[mask][:, :2], half, self.Nd)
        beta = torch.softmax(self.beta[dir_idx[:, 0], dir_idx[:, 1]], -1)

        sigma[mask] = torch.nn.functional.softplus(sigma_uvw[:, :1])  # [n_inside, 1]
        uvw = torch.sigmoid(sigma_uvw[:, 1:].reshape(-1, 3, self.D))  # [n_inside, 3, D]
        color[mask] = (beta.unsqueeze(1) * uvw).sum(-1)  # [n_inside, 3]
        return color, sigma

In [7]:
def compute_accumulated_transmittance(alphas):
    """Exclusive cumulative product along dim 1: ``T_i = prod_{j<i} alphas_j`` with ``T_0 = 1``.

    As called from ``render_rays``, ``alphas`` holds per-sample transmittances
    (``1 - alpha``) of shape ``[batch_size, nb_bins]``.
    """
    running = torch.cumprod(alphas, dim=1)
    leading_ones = torch.ones((running.shape[0], 1), device=alphas.device)
    # Shift right by one so each sample only sees the attenuation of the samples before it.
    return torch.cat([leading_ones, running[:, :-1]], dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render a batch of rays with stratified depth sampling over a white background.

    Args:
        nerf_model: callable ``(points [N, 3], directions [N, 3]) -> (colors [N, 3], sigma [N, 1])``.
        ray_origins: ``[batch_size, 3]`` ray origins.
        ray_directions: ``[batch_size, 3]`` ray directions. They are used as-is; the sample
            spacing ``delta`` presumably assumes unit-length directions (TODO confirm).
        hn, hf: near / far bounds of the sampled depth interval.
        nb_bins: number of samples per ray.

    Returns:
        ``[batch_size, 3]`` pixel colours.
    """
    device = ray_origins.device
    n_rays = ray_origins.shape[0]

    # Evenly spaced depths, then each is jittered uniformly inside its bin
    # (bin edges are the midpoints between neighbouring depths).
    t_even = torch.linspace(hn, hf, nb_bins, device=device).expand(n_rays, nb_bins)
    midpoints = (t_even[:, :-1] + t_even[:, 1:]) / 2.
    lower = torch.cat((t_even[:, :1], midpoints), dim=-1)
    upper = torch.cat((midpoints, t_even[:, -1:]), dim=-1)
    jitter = torch.rand(t_even.shape, device=device)
    t = lower + (upper - lower) * jitter  # [batch_size, nb_bins]

    # Distance to the next sample; the last sample gets an effectively infinite interval.
    far_delta = torch.full((n_rays, 1), 1e10, device=device)
    delta = torch.cat((t[:, 1:] - t[:, :-1], far_delta), dim=-1)

    points = ray_origins[:, None, :] + t[:, :, None] * ray_directions[:, None, :]  # [batch_size, nb_bins, 3]
    dirs = ray_directions[:, None, :].expand(n_rays, nb_bins, 3)

    colors, sigma = nerf_model(points.reshape(-1, 3), dirs.reshape(-1, 3))
    colors = colors.reshape(points.shape)
    sigma = sigma.reshape(points.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = (compute_accumulated_transmittance(1 - alpha) * alpha).unsqueeze(2)  # [batch_size, nb_bins, 1]
    pixel = (weights * colors).sum(dim=1)
    opacity = weights[..., 0].sum(dim=-1)  # [batch_size]
    # Whatever opacity the ray did not accumulate is filled with white (1.0).
    return pixel + 1 - opacity.unsqueeze(-1)


@torch.no_grad()
def test(model, hn, hf, dataset, device='cpu', img_index=0, nb_bins=192, H=400, W=400):
    """Render test image ``img_index`` and return it as an ``[H, W, 3]`` CPU tensor.

    ``dataset`` is read as ``H * W`` consecutive rows per image, with ray origins in
    columns 0:3 and ray directions in columns 3:6. Runs without autograd.
    """
    pixel_rows = slice(img_index * H * W, (img_index + 1) * H * W)
    image_rays = dataset[pixel_rows]
    origins = image_rays[:, :3].to(device)
    directions = image_rays[:, 3:6].to(device)
    pixels = render_rays(model, origins, directions, hn=hn, hf=hf, nb_bins=nb_bins)
    return pixels.cpu().reshape(H, W, 3)


def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5), nb_bins=192, eval_steps=5):
    """Optimise ``nerf_model`` on batches of rays, saving checkpoints and validation videos.

    Batch rows hold the ray origin in columns 0:3, the ray direction in columns 3:6 and
    the target pixel values in columns 6:. After every epoch the scheduler is stepped
    and the weights are saved under ``g_model_path``. Every ``eval_steps`` epochs, and
    after the last epoch, the model is baked into a ``Cache`` and the test views are
    rendered to an mp4 under ``g_video_path``.

    Args:
        nerf_model: model passed to ``render_rays``.
        optimizer: optimizer over ``nerf_model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        data_loader: iterable of ray batches.
        device: device the batches are moved to (also used for validation).
        hn, hf: near / far sampling bounds (also used for validation).
        nb_epochs: number of epochs.
        nb_bins: samples per ray (also used for validation).
        eval_steps: render a validation video every this many epochs.

    Returns:
        list of per-batch loss values.
    """
    training_loss = []
    for e in range(nb_epochs):
        print(f"epoch : {e}")
        for batch in tqdm(data_loader):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()

        # Single checkpoint per epoch. The former extra save to './temp/01_models/' targeted a
        # directory that is never created, so it crashed on a fresh run.
        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            _render_validation_video(nerf_model, e, device, hn, hf, nb_bins)
    return training_loss


def _render_validation_video(nerf_model, epoch, device, hn, hf, nb_bins, nb_views=200):
    """Bake ``nerf_model`` into a Cache, render the first ``nb_views`` test views and write them as an mp4."""
    nerf_model.eval()
    try:
        cache = Cache(nerf_model, 2.2, device, 192, 128)
        frames = [test(cache, hn, hf, g_testing_dataset, device=device, img_index=idx, nb_bins=nb_bins)
                  for idx in tqdm(range(nb_views), desc="Validation")]
        # Clamp to [0, 1] before scaling so out-of-range values cannot misbehave in the uint8 cast.
        video = (torch.stack(frames).clamp(0, 1) * 255).to(torch.uint8)  # [T, H, W, 3]
        _video_path = os.path.join(g_video_path, f'nerf_model_{epoch:03d}.mp4')
        torchvision.io.write_video(_video_path, video, g_fps)
    finally:
        # Always return to training mode, even if rendering or writing fails.
        nerf_model.train()

In [8]:


# Train FastNeRF, then bake the trained network into a Cache and render the test views with it.
torch.manual_seed(0)  # reproducible weight init, batch shuffling and ray jitter

model = FastNerf().to(g_device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
# `train` steps the scheduler once per epoch, so the LR halves when the epoch count reaches 2, 4 and 8.
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

data_loader = DataLoader(g_training_dataset, batch_size=1024, shuffle=True)
training_loss = train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6)

model.eval()
cache = Cache(model, 2.2, g_device, 192, 128)
# Keep the rendered frames ([200, 400, 400, 3]) instead of discarding them, so they can be inspected or saved.
novel_views = torch.stack([
    test(cache, 2., 6., g_testing_dataset, g_device, img_index=idx, nb_bins=192, H=400, W=400)
    for idx in tqdm(range(200), desc="Rendering test views")
])

epoch : 0


100%|██████████| 15625/15625 [54:07<00:00,  4.81it/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Validation: 100%|██████████| 200/200 [01:46<00:00,  1.87it/s]


epoch : 1


100%|██████████| 15625/15625 [55:09<00:00,  4.72it/s] 


epoch : 2


100%|██████████| 15625/15625 [1:00:51<00:00,  4.28it/s]


epoch : 3


100%|██████████| 15625/15625 [55:42<00:00,  4.67it/s]


epoch : 4


100%|██████████| 15625/15625 [55:49<00:00,  4.66it/s]


epoch : 5


100%|██████████| 15625/15625 [57:57<00:00,  4.49it/s]
Validation: 100%|██████████| 200/200 [02:14<00:00,  1.49it/s]


epoch : 6


100%|██████████| 15625/15625 [1:01:32<00:00,  4.23it/s]


epoch : 7


100%|██████████| 15625/15625 [1:01:23<00:00,  4.24it/s]


epoch : 8


100%|██████████| 15625/15625 [1:00:28<00:00,  4.31it/s]


epoch : 9


100%|██████████| 15625/15625 [56:37<00:00,  4.60it/s] 


epoch : 10


100%|██████████| 15625/15625 [54:14<00:00,  4.80it/s]
Validation: 100%|██████████| 200/200 [01:46<00:00,  1.87it/s]


epoch : 11


100%|██████████| 15625/15625 [55:54<00:00,  4.66it/s] 


epoch : 12


100%|██████████| 15625/15625 [1:02:01<00:00,  4.20it/s]


epoch : 13


100%|██████████| 15625/15625 [59:34<00:00,  4.37it/s] 


epoch : 14


100%|██████████| 15625/15625 [56:57<00:00,  4.57it/s]  


epoch : 15


100%|██████████| 15625/15625 [55:46<00:00,  4.67it/s] 
Validation: 100%|██████████| 200/200 [01:31<00:00,  2.19it/s]


In [9]:
# Optional quick check: render the first few test views with the cache and write them to a short video.
# cache = Cache(model, 2.2, g_device, 192, 128)
# imgT_lst = []
# for idx in range(3):
#     t =  test(cache, 2., 6., g_testing_dataset, g_device, img_index=idx, nb_bins=192, H=400, W=400)
#     imgT_lst.append(t)
# img_t = (torch.stack(imgT_lst).clamp(0, 1) * 255).to(torch.uint8)  # stack -> [T, H, W, 3]; cat would merge frames along H
# _video_path = os.path.join(g_video_path, 'ntest.mp4')
# torchvision.io.write_video(_video_path, img_t, g_fps)