# Instant Neural Graphics Primitives with a Multiresolution Hash Encoding

---

- Conda env: [3dcv_playgrounds](../README.md#setup-a-conda-environment)

----

- References:
    - https://arxiv.org/abs/2201.05989
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/Instant_Neural_Graphics_Primitives_with_a_Multiresolution_Hash_Encoding



In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
# Fetch the training / testing ray datasets from Google Drive into ./temp/data.
# `resume=True` makes gdown skip files that were already fully downloaded.
Path('./temp/data').mkdir(exist_ok=True, parents=True)

train_data_path = './temp/data/training_data.pkl'
test_data_path = './temp/data/testing_data.pkl'

# (Google Drive URL, local destination) pairs. A single loop replaces the
# copy-pasted download calls so both files are handled identically.
for url, _dest_path in (
    ('https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J', train_data_path),
    ('https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q', test_data_path),
):
    gdown.download(url, _dest_path, quiet=False, resume=True)
    # Fail here, with a clear message, rather than later inside np.load when
    # the download did not actually produce the file.
    if not Path(_dest_path).is_file():
        raise FileNotFoundError(f"Download failed, missing file: {_dest_path}")

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
# Output locations and shared settings used by the training / rendering cells.
g_model_path = './temp/06_InstantNeRF_models/'        # per-epoch checkpoints
g_video_path = './temp/06_InstantNeRF_novel_views/'   # rendered validation videos
fps = 10                                              # frame rate of the written videos
for _out_dir in (g_model_path, g_video_path):
    os.makedirs(_out_dir, exist_ok=True)

# NOTE(security): these .pkl files are fetched from the network and loaded with
# allow_pickle=True, which can execute arbitrary code embedded in the file.
# Only run this with data from a source you trust.
# Downstream cells slice each row as [0:3] ray origin, [3:6] ray direction,
# [6:] ground-truth pixel values.
_train_rays = np.load('./temp/data/training_data.pkl', allow_pickle=True)
_test_rays = np.load('./temp/data/testing_data.pkl', allow_pickle=True)
training_dataset = torch.from_numpy(_train_rays)
testing_dataset = torch.from_numpy(_test_rays)


In [4]:
# Select the compute device in priority order: Apple MPS, then CUDA, then CPU.
# The chosen device is stored in the notebook-wide global `g_device`.
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    # Report how much memory is currently in use on the MPS backend.
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
elif torch.cuda.is_available():
    g_device = torch.device("cuda")
    # IPython shell escape: prints the GPU status table as cell output.
    !nvidia-smi
else:
    g_device = torch.device("cpu")
# Uncomment the next line to force CPU execution regardless of what is available.
# g_device = torch.device("cpu")
print(g_device)

Wed Oct  1 09:15:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:01:00.0  On |                  N/A |
| 42%   48C    P3             52W /  250W |    1279MiB /  11264MiB |     23%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
# Model

class NGP(torch.nn.Module):
    """Radiance field with a multiresolution hash encoding and two small MLPs.

    Positions are encoded by looking up learnable feature vectors in one hash
    table per resolution level and trilinearly interpolating them; view
    directions use a sinusoidal positional encoding.

    Args:
        T: number of entries in each hash table.
        Nl: list with the grid resolution of every level.
        L: number of frequency bands used to encode the view direction.
        device: device on which the parameters are created.
        aabb_scale: side length of the axis-aligned cube that contains the
            scene; positions are divided by it so the cube maps to
            [-0.5, 0.5]^3.
        F: number of features stored per hash-table entry.
    """

    def __init__(self, T, Nl, L, device, aabb_scale, F=2):
        super(NGP, self).__init__()
        self.T = T
        self.Nl = Nl
        self.F = F
        self.L = L  # For encoding directions
        self.aabb_scale = aabb_scale
        # One (T, F) table per level, initialised uniformly in [-1e-4, 1e-4].
        # The width follows F (it used to be hard-coded to 2).
        self.lookup_tables = torch.nn.ParameterDict(
            {str(i): torch.nn.Parameter((torch.rand(
                (T, F), device=device) * 2 - 1) * 1e-4) for i in range(len(Nl))})
        # Per-axis multipliers of the spatial hash.
        self.pi1, self.pi2, self.pi3 = 1, 2_654_435_761, 805_459_861
        self.density_MLP = nn.Sequential(nn.Linear(self.F * len(Nl), 64),
                                         nn.ReLU(), nn.Linear(64, 16)).to(device)
        # Encoded direction size: raw xyz + (sin, cos) per band and axis.
        # Equals 27 for L=4; derived from L so other values of L work too.
        direction_dim = 3 + 3 * 2 * L
        self.color_MLP = nn.Sequential(nn.Linear(direction_dim + 16, 64), nn.ReLU(),
                                       nn.Linear(64, 64), nn.ReLU(),
                                       nn.Linear(64, 3), nn.Sigmoid()).to(device)

    def positional_encoding(self, x):
        """Return [x, sin(2^j x), cos(2^j x) for j < L] concatenated on dim 1."""
        out = [x]
        for j in range(self.L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, x, d):
        """Evaluate colour and density at positions `x` seen from directions `d`.

        Args:
            x: [n, 3] positions in scene coordinates (left unmodified).
            d: [n, 3] view directions, one per position.

        Returns:
            (color, sigma): [n, 3] colours and [n] densities. Points outside
            the bounding cube get zero colour and exp(-100000) == 0 density.
        """
        # Out-of-place arithmetic: in-place `/=` and `+=` silently modified the
        # caller's tensor (and any tensor it is a view of).
        x = x / self.aabb_scale
        mask = (x[:, 0].abs() < .5) & (x[:, 1].abs() < .5) & (x[:, 2].abs() < .5)
        x = x + 0.5  # x in [0, 1]^3

        color = torch.zeros((x.shape[0], 3), device=x.device)
        log_sigma = torch.zeros((x.shape[0]), device=x.device) - 100000

        # Gather the in-bounds points once instead of re-indexing per level.
        pts = x[mask]
        # Row k says, per axis (x, y, z), whether corner k takes the ceil
        # coordinate. Bit 0 -> x, bit 1 -> y, bit 2 -> z, which matches the
        # (D, H, W) layout grid_sample expects after the reshape below.
        corner_is_ceil = torch.tensor(
            [[(k >> axis) & 1 for axis in range(3)] for k in range(8)],
            dtype=torch.bool, device=x.device)

        per_level = []
        for i, N in enumerate(self.Nl):
            # Integer coordinates of the 8 corners of the enclosing voxel.
            scaled = pts * N
            floor = torch.floor(scaled)
            ceil = torch.ceil(scaled)
            vertices = torch.where(
                corner_is_ceil, ceil.unsqueeze(1), floor.unsqueeze(1)).to(torch.int64)

            # hashing
            a = vertices[:, :, 0] * self.pi1
            b = vertices[:, :, 1] * self.pi2
            c = vertices[:, :, 2] * self.pi3
            h_x = torch.remainder(torch.bitwise_xor(torch.bitwise_xor(a, b), c), self.T)

            # Lookup: [n, 8, F] -> [n, F, 2, 2, 2] so each point carries its own
            # tiny 2x2x2 feature volume.
            looked_up = self.lookup_tables[str(i)][h_x].transpose(-1, -2)
            volume = looked_up.reshape((looked_up.shape[0], self.F, 2, 2, 2))
            # With align_corners=False the two voxel centres sit at -0.5 and
            # +0.5, so the fractional offset shifted by -0.5 interpolates
            # between the floor and ceil corners.
            grid = ((scaled - floor) - 0.5).reshape(-1, 1, 1, 1, 3)
            per_level.append(torch.nn.functional.grid_sample(
                volume, grid, align_corners=False).reshape(-1, self.F))
        features = torch.cat(per_level, dim=1)

        xi = self.positional_encoding(d[mask])
        h = self.density_MLP(features)
        # The first channel of the density head is the log-density.
        log_sigma[mask] = h[:, 0]
        color[mask] = self.color_MLP(torch.cat((h, xi), dim=1))
        return color, torch.exp(log_sigma)

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    """Exclusive cumulative product of `alphas` along dim 1.

    Column 0 of the result is all ones and column j (j > 0) is the product of
    `alphas[:, :j]`. `render_rays` calls this with `1 - alpha`, so the result
    is the transmittance reaching each sample.
    """
    running_product = torch.cumprod(alphas, 1)
    leading_ones = torch.ones((running_product.shape[0], 1), device=alphas.device)
    # Shift right by one: drop the last product, prepend the column of ones.
    return torch.cat((leading_ones, running_product[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render one colour per ray, composited over a white background.

    Args:
        nerf_model: callable taking ([n, 3] points, [n, 3] directions) and
            returning ([n, 3] colours, [n] densities).
        ray_origins: [batch, 3] ray origins.
        ray_directions: [batch, 3] ray directions.
        hn, hf: near and far bounds of the sampled depth range.
        nb_bins: number of samples per ray.

    Returns:
        [batch, 3] rendered colours.
    """
    n_rays = ray_origins.shape[0]
    dev = ray_origins.device

    # Evenly spaced depths, shared by every ray.
    depths = torch.linspace(hn, hf, nb_bins, device=dev).expand(n_rays, nb_bins)
    # Stratified sampling: draw one uniform sample inside each depth bin.
    centers = (depths[:, :-1] + depths[:, 1:]) / 2.
    bin_lo = torch.cat((depths[:, :1], centers), -1)
    bin_hi = torch.cat((centers, depths[:, -1:]), -1)
    jitter = torch.rand(depths.shape, device=dev)
    depths = bin_lo + (bin_hi - bin_lo) * jitter  # [batch_size, nb_bins]
    # Distance to the next sample; the last interval is a very large constant.
    far_pad = torch.tensor([1e10], device=dev).expand(n_rays, 1)
    delta = torch.cat((depths[:, 1:] - depths[:, :-1], far_pad), -1)

    # 3D sample positions along every ray: [batch_size, nb_bins, 3].
    points = ray_origins.unsqueeze(1) + depths.unsqueeze(2) * ray_directions.unsqueeze(1)
    # Repeat each ray direction for all of its samples.
    dirs_per_sample = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)
    colors, sigma = nerf_model(points.reshape(-1, 3), dirs_per_sample.reshape(-1, 3))

    # Per-sample opacity, then weights = transmittance * opacity.
    alpha = 1 - torch.exp(-sigma.reshape(points.shape[:-1]) * delta)  # [batch_size, nb_bins]
    transmittance = compute_accumulated_transmittance(1 - alpha)
    weights = transmittance.unsqueeze(2) * alpha.unsqueeze(2)
    rgb = (weights * colors.reshape(points.shape)).sum(dim=1)
    # Whatever weight is left over is filled with white.
    opacity = weights.sum(-1).sum(-1)
    return rgb + 1 - opacity.unsqueeze(-1)

In [7]:
@torch.no_grad()
def test(nerf_model, hn, hf, dataset, chunk_size=5, img_index=0, nb_bins=192, H=400, W=400, device=None):
    """Render one full image of `dataset` with `nerf_model`.

    Args:
        nerf_model: model passed through to `render_rays`.
        hn, hf: near and far bounds of the sampled depth range.
        dataset: [num_rays, >=6] tensor; columns 0:3 are ray origins and 3:6
            ray directions, stored image after image with H * W rows each.
        chunk_size: number of image rows rendered per forward pass.
        img_index: index of the image to render.
        nb_bins: number of samples per ray.
        H, W: image height and width in pixels.
        device: device to render on. Defaults to the device of the model's
            first parameter (falling back to the dataset's device), so the
            function no longer relies on a notebook-level global.

    Returns:
        [H, W, 3] CPU tensor with the rendered image.

    Raises:
        IndexError: if `dataset` does not contain H * W rays for `img_index`.
    """
    if device is None:
        params = nerf_model.parameters() if hasattr(nerf_model, 'parameters') else iter(())
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else dataset.device

    rays = dataset[img_index * H * W: (img_index + 1) * H * W]
    if rays.shape[0] != H * W:
        # Otherwise the final reshape fails with an unrelated-looking error.
        raise IndexError(
            f"img_index={img_index} needs {H * W} rays but only {rays.shape[0]} are available")
    ray_origins = rays[:, :3]
    ray_directions = rays[:, 3:6]

    # Render `chunk_size` image rows at a time to bound peak memory.
    rays_per_chunk = W * chunk_size
    data = []
    for start in range(0, H * W, rays_per_chunk):
        ray_origins_ = ray_origins[start: start + rays_per_chunk].to(device)
        ray_directions_ = ray_directions[start: start + rays_per_chunk].to(device)
        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())
    img_t = torch.cat(data).reshape(H, W, 3)
    return img_t

def train(nerf_model, optimizer, data_loader, device='cpu', hn=0, hf=1, start_epoch = 0, nb_epochs=int(1e5), nb_bins=192, H=400, W=400, eval_steps = 5):
    """Train `nerf_model`, checkpoint every epoch and periodically render a video.

    Relies on notebook-level globals: `g_model_path` (checkpoint directory),
    `g_video_path` (video directory), `fps` and `testing_dataset`.

    Args:
        nerf_model: model to optimise.
        optimizer: optimiser that holds the model's parameters.
        data_loader: yields [batch, >=7] tensors; columns 0:3 are ray origins,
            3:6 ray directions and 6: the ground-truth pixel values.
        device: device the batches are moved to.
        hn, hf: near and far bounds of the sampled depth range.
        start_epoch: first epoch index (used for checkpoint/video file names).
        nb_epochs: epoch index at which training stops (exclusive).
        nb_bins: number of samples per ray.
        H, W: resolution of the validation images.
        eval_steps: a validation video is rendered every `eval_steps` epochs
            and after the last epoch.

    Returns:
        List with the summed squared error of every training batch.
    """
    training_loss = []
    for e in (range(start_epoch, nb_epochs)):
        print(f"{e:03d}-epoch")
        for batch in tqdm(data_loader, total = len(data_loader), desc="Training"):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)
        # Export in TorchScript Format
        # _model_script_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pt')
        # _model_scripted = torch.jit.script(nerf_model)
        # _model_scripted.save(_model_script_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            nerf_model.eval()
            try:
                imgT_lst = []
                # Assumes the test set holds 200 views of H x W rays - TODO confirm.
                for idx in tqdm(range(200), desc="Validation"):
                    # Validate with the same depth range and sample count used
                    # for training (previously hard-coded to 2, 6 and 192).
                    imgT_lst.append(test(nerf_model, hn, hf, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=nb_bins, H = H, W = W ).unsqueeze(0))
                # Clamp before the uint8 cast so values that drift just outside
                # [0, 1] through rounding cannot produce out-of-range bytes.
                img_t = (torch.cat(imgT_lst).clamp(0, 1) * 255).to(torch.uint8)
                _video_path = os.path.join(g_video_path,  f'nerf_model_{e:03d}.mp4')
                torchvision.io.write_video(_video_path, img_t, fps)
            finally:
                # Always return to training mode, even if validation fails.
                nerf_model.train()

    return training_loss

In [8]:

# Hash-encoding hyper-parameters.
L = 16        # number of resolution levels
F = 2         # features per hash-table entry
T = 2**19     # entries per hash table
N_min = 16    # coarsest grid resolution
N_max = 2048  # finest grid resolution
# Per-level growth factor so the resolutions go geometrically from N_min to N_max.
b = np.exp((np.log(N_max) - np.log(N_min)) / (L - 1))
Nl = [int(np.floor(N_min * b**level)) for level in range(L)]

# 4 = frequency bands for the view-direction encoding, 3 = bounding-box scale.
# F is passed explicitly so the constant above is actually used.
model = NGP(T, Nl, 4, g_device, 3, F=F)
model.to(g_device)
# No weight decay on the hash tables, a small one on both MLPs.
model_optimizer = torch.optim.Adam(
        [{"params": model.lookup_tables.parameters(), "lr": 1e-2, "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 0.},
         {"params": model.density_MLP.parameters(), "lr": 1e-2,  "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 10**-6},
         {"params": model.color_MLP.parameters(), "lr": 1e-2,  "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 10**-6}])
# scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
# Keep the per-batch loss history instead of discarding it.
training_loss = train(model, model_optimizer, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6, nb_bins=192, H=400, W=400)

# Render all test views with the final model and write them out as one video.
imgT_lst = []
for idx in tqdm(range(200), desc="Validation"):
    imgT_lst.append(test(model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192, H = 400, W = 400 ).unsqueeze(0))
# Clamp before the uint8 cast so rounding noise cannot produce out-of-range bytes.
img_t = (torch.cat(imgT_lst).clamp(0, 1) * 255).to(torch.uint8)
_video_path = os.path.join(g_video_path, 'nerf_model_test.mp4')
torchvision.io.write_video(_video_path, img_t, fps)


000-epoch


Training: 100%|██████████| 15625/15625 [22:18<00:00, 11.68it/s] 
Validation: 100%|██████████| 200/200 [15:21<00:00,  4.61s/it]


001-epoch


Training: 100%|██████████| 15625/15625 [22:51<00:00, 11.39it/s]


002-epoch


Training: 100%|██████████| 15625/15625 [23:05<00:00, 11.28it/s]


003-epoch


Training: 100%|██████████| 15625/15625 [21:06<00:00, 12.34it/s]


004-epoch


Training: 100%|██████████| 15625/15625 [19:17<00:00, 13.50it/s]


005-epoch


Training: 100%|██████████| 15625/15625 [21:15<00:00, 12.25it/s]
Validation: 100%|██████████| 200/200 [15:29<00:00,  4.65s/it]


006-epoch


Training: 100%|██████████| 15625/15625 [21:23<00:00, 12.18it/s]


007-epoch


Training: 100%|██████████| 15625/15625 [18:02<00:00, 14.43it/s]


008-epoch


Training: 100%|██████████| 15625/15625 [18:37<00:00, 13.98it/s]


009-epoch


Training: 100%|██████████| 15625/15625 [19:22<00:00, 13.44it/s]


010-epoch


Training: 100%|██████████| 15625/15625 [20:38<00:00, 12.62it/s]
Validation: 100%|██████████| 200/200 [15:00<00:00,  4.50s/it]


011-epoch


Training: 100%|██████████| 15625/15625 [21:28<00:00, 12.12it/s]


012-epoch


Training: 100%|██████████| 15625/15625 [21:29<00:00, 12.12it/s]


013-epoch


Training: 100%|██████████| 15625/15625 [20:46<00:00, 12.53it/s]


014-epoch


Training: 100%|██████████| 15625/15625 [20:59<00:00, 12.41it/s]


015-epoch


Training: 100%|██████████| 15625/15625 [20:41<00:00, 12.59it/s]
Validation: 100%|██████████| 200/200 [15:04<00:00,  4.52s/it]
Validation: 100%|██████████| 200/200 [15:03<00:00,  4.52s/it]
