# NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

- Ref: 
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/NeRF_Representing_Scenes_as_Neural_Radiance_Fields_for_View_Synthesis
    - https://arxiv.org/abs/2003.08934

- Conda env : [gsplat](../gsplat/README.md#setup-a-conda-environment)
    

In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
Path('./temp/data').mkdir(exist_ok=True, parents=True)

url = 'https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J'
train_data_path = './temp/data/training_data.pkl'
gdown.download(url, train_data_path, quiet=False, resume=True)


url = 'https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q'
test_data_path = './temp/data/testing_data.pkl'
gdown.download(url, test_data_path, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
g_model_path = './temp/00_NeRF_models/'
g_video_path = './temp/00_NeRF_novel_views/'
g_fps = 10
Path(g_model_path).mkdir(exist_ok=True, parents=True)
Path(g_video_path).mkdir(exist_ok=True, parents=True)
g_training_dataset = torch.from_numpy(np.load('./temp/data/training_data.pkl', allow_pickle=True))
g_testing_dataset = torch.from_numpy(np.load('./temp/data/testing_data.pkl', allow_pickle=True))

In [4]:
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
elif torch.cuda.is_available():
    g_device = torch.device("cuda")
else:
    g_device = torch.device("cpu")
print(g_device)

Current memory allocated on MPS: 0 bytes
Driver memory allocated on MPS: 393216 bytes
mps


In [5]:
# NERF Model

class NerfModel(nn.Module):
    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super(NerfModel, self).__init__()

        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        # density estimation
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + hidden_dim + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        # color estimation
        self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 6 + hidden_dim + 3, hidden_dim // 2), nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, o, d):
        emb_x = self.positional_encoding(o, self.embedding_dim_pos) # emb_x: [batch_size, embedding_dim_pos * 6]
        emb_d = self.positional_encoding(d, self.embedding_dim_direction) # emb_d: [batch_size, embedding_dim_direction * 6]
        h = self.block1(emb_x) # h: [batch_size, hidden_dim]
        tmp = self.block2(torch.cat((h, emb_x), dim=1)) # tmp: [batch_size, hidden_dim + 1]
        h, sigma = tmp[:, :-1], self.relu(tmp[:, -1]) # h: [batch_size, hidden_dim], sigma: [batch_size]
        h = self.block3(torch.cat((h, emb_d), dim=1)) # h: [batch_size, hidden_dim // 2]
        c = self.block4(h) # c: [batch_size, 3]
        return c, sigma

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)

def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    device = ray_origins.device

    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    # Perturb sampling along each ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    # Compute the 3D points along each ray
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)   # [batch_size, nb_bins, 3]
    # Expand the ray_directions tensor to match the shape of x
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    # Compute the pixel values as a weighted sum of colors along each ray
    c = (weights * colors).sum(dim=1)
    weight_sum = weights.sum(-1).sum(-1)  # Regularization for white background
    return c + 1 - weight_sum.unsqueeze(-1)

In [7]:
# test function

@torch.no_grad()
def test(model, hn, hf, dataset, device = 'cpu', chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    """
    Args:
        hn: near plane distance
        hf: far plane distance
        dataset: dataset to render
        chunk_size (int, optional): chunk size for memory efficiency. Defaults to 10.
        img_index (int, optional): image index to render. Defaults to 0.
        nb_bins (int, optional): number of bins for density estimation. Defaults to 192.
        H (int, optional): image height. Defaults to 400.
        W (int, optional): image width. Defaults to 400.

    Returns:
        None: None
    """
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

    data = []   # list of regenerated pixel values
    for i in range(int(np.ceil(H / chunk_size))):   # iterate over chunks
        # Get chunk of rays
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        regenerated_px_values = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())
    img_t = torch.cat(data).reshape(H, W, 3)
    return img_t

In [8]:
# train function
def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5), nb_bins=192, H=400, W=400, eval_steps = 5):
    training_loss = []
    for e in (range(nb_epochs)):
        print(f"{e:03d}-epoch")
        for ep, batch in enumerate(tqdm(data_loader, total = len(data_loader), desc="Training")):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())

        scheduler.step()

        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            nerf_model.eval()
            imgT_lst = []
            for idx in tqdm(range(200), desc="Validation"):
                imgT_lst.append(test(nerf_model, 2, 6, g_testing_dataset, device = g_device, img_index=idx, nb_bins=192).unsqueeze(0))
            img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
            _video_path = os.path.join(g_video_path,  f'nerf_model_{e:03d}.mp4')
            torchvision.io.write_video(_video_path, img_t, g_fps)
            nerf_model.train()

    return training_loss

In [9]:

model = NerfModel(hidden_dim=256).to(g_device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

data_loader = DataLoader(g_training_dataset, batch_size=1024, shuffle=True)
train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6, nb_bins=192)



000-epoch


Training: 100%|██████████| 15625/15625 [44:52<00:00,  5.80it/s]
Validation: 100%|██████████| 200/200 [29:20<00:00,  8.80s/it]


001-epoch


Training: 100%|██████████| 15625/15625 [43:20<00:00,  6.01it/s]


002-epoch


Training: 100%|██████████| 15625/15625 [44:33<00:00,  5.84it/s]


003-epoch


Training: 100%|██████████| 15625/15625 [42:59<00:00,  6.06it/s]


004-epoch


Training: 100%|██████████| 15625/15625 [41:41<00:00,  6.25it/s]


005-epoch


Training: 100%|██████████| 15625/15625 [40:02<00:00,  6.50it/s]
Validation: 100%|██████████| 200/200 [26:54<00:00,  8.07s/it]


006-epoch


Training: 100%|██████████| 15625/15625 [40:33<00:00,  6.42it/s]


007-epoch


Training: 100%|██████████| 15625/15625 [42:05<00:00,  6.19it/s]


008-epoch


Training: 100%|██████████| 15625/15625 [39:49<00:00,  6.54it/s]


009-epoch


Training: 100%|██████████| 15625/15625 [39:37<00:00,  6.57it/s]


010-epoch


Training: 100%|██████████| 15625/15625 [39:45<00:00,  6.55it/s]
Validation: 100%|██████████| 200/200 [26:56<00:00,  8.08s/it]


011-epoch


Training: 100%|██████████| 15625/15625 [40:14<00:00,  6.47it/s]


012-epoch


Training: 100%|██████████| 15625/15625 [39:55<00:00,  6.52it/s]


013-epoch


Training: 100%|██████████| 15625/15625 [39:43<00:00,  6.56it/s]


014-epoch


Training: 100%|██████████| 15625/15625 [39:40<00:00,  6.57it/s]


015-epoch


Training: 100%|██████████| 15625/15625 [39:40<00:00,  6.56it/s]
Validation: 100%|██████████| 200/200 [26:52<00:00,  8.06s/it]


[363.03564453125,
 394.89068603515625,
 346.5334167480469,
 365.901611328125,
 370.9471740722656,
 379.0926818847656,
 354.70904541015625,
 363.7463684082031,
 371.90594482421875,
 398.1034240722656,
 392.65472412109375,
 381.3522644042969,
 362.8458251953125,
 354.01397705078125,
 353.34075927734375,
 342.874755859375,
 337.07769775390625,
 371.3186340332031,
 358.42730712890625,
 341.6618347167969,
 386.69903564453125,
 380.487060546875,
 345.80718994140625,
 357.4534912109375,
 392.09051513671875,
 536.1439819335938,
 619.8753662109375,
 610.9814453125,
 595.3414306640625,
 578.9378662109375,
 557.20263671875,
 550.591796875,
 506.681640625,
 484.7054443359375,
 441.85345458984375,
 380.54998779296875,
 299.2256774902344,
 279.8492126464844,
 295.8323974609375,
 335.7489013671875,
 327.62933349609375,
 330.376953125,
 317.08349609375,
 294.1617736816406,
 292.7742004394531,
 300.8305969238281,
 283.602294921875,
 278.0166931152344,
 269.1240234375,
 285.2232666015625,
 275.969055175