In [1]:
import torch
import matplotlib.pyplot as plt
from src.data import parse_nerf_synthetic, ImagesDataset, NerfData
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
def mip360_contract(coords: torch.Tensor) -> torch.Tensor:
    """Scene contraction from Mip-NeRF 360 https://arxiv.org/abs/2111.12077

    Points with ||x|| <= 1 are left unchanged. Points outside the unit ball are
    mapped to (2 - 1/||x||) * x / ||x||, which lies strictly inside the
    radius-2 ball. The result is then halved, so all of space lands inside the
    unit ball.

    Args:
        coords: points with the coordinate axis last, i.e. shape ``(..., D)``;
            the norm is taken over the last dimension.

    Returns:
        Tensor with the same shape as ``coords``.
    """
    norm = torch.norm(coords, dim=-1, keepdim=True)  # type: ignore
    # Clamping the norm to >= 1 makes the contraction formula collapse exactly
    # to the identity inside the unit ball ((2 - 1/1) * x / 1 == x), so no
    # branch selection is needed. A torch.where over both branches would
    # still evaluate 1/norm at the origin, and the inf/NaN from that unused
    # branch would propagate into the gradient at x == 0.
    safe_norm = norm.clamp(min=1.)
    return (2. - 1. / safe_norm) * coords / safe_norm / 2.

def clipped_exponential_stepping(near: float, far: float, delta_min: float, delta_max: float, device: torch.device, cone_angle: float = 1.0):
    """Clipped exponential stepping function as described in Instant-NGP paper

    Marches from ``near`` toward ``far``. The step taken at position ``t`` is
    ``clamp(cone_angle * t, delta_min, delta_max)``, so steps grow
    geometrically with distance until they saturate at ``delta_max``. Marching
    stops once ``t >= far``; the last interval may therefore extend past ``far``.

    Args:
        near: start of the marching range.
        far: end of the marching range.
        delta_min: smallest allowed step.
        delta_max: largest allowed step.
        device: device on which the returned tensors are created.
        cone_angle: factor applied to ``t`` before clamping. The default 1.0
            reproduces ``step = clamp(t, delta_min, delta_max)``.
            NOTE(review): Instant-NGP uses a small cone angle (reportedly
            1/256 for synthetic scenes) — confirm before relying on it.

    Returns:
        ``(ts, steps)``: two 1-D tensors of equal length. ``ts[i]`` is the
        start of the i-th interval and ``steps[i]`` is its length. Both are
        empty when ``near >= far``.

    Raises:
        ValueError: if a computed step is not positive (e.g. ``near == 0``
            with ``delta_min <= 0``, or ``delta_max <= 0``). Such a step would
            otherwise make the loop run forever.
    """
    starts, steps = [], []
    t = near
    while t < far:
        step = min(delta_max, max(delta_min, cone_angle * t))
        if step <= 0:
            raise ValueError(
                f"step size must be positive, got {step} at t={t}; "
                f"check delta_min={delta_min}, delta_max={delta_max}, cone_angle={cone_angle}"
            )
        starts.append(t)
        steps.append(step)
        t += step
    return torch.tensor(starts, device=device), torch.tensor(steps, device=device)

def uniform_stepping(near: float, far: float, n_samples: int, device: torch.device):
    """Split [near, far] into ``n_samples`` equal-length intervals.

    Args:
        near: start of the range.
        far: end of the range.
        n_samples: number of intervals. Must be an integer, because it sets
            the ``steps`` argument of ``torch.linspace``.
        device: device on which the returned tensors are created.

    Returns:
        ``(ts, steps)``: two 1-D tensors of length ``n_samples``. ``ts[i]`` is
        the start of the i-th interval and ``steps[i]`` is its length
        (≈ ``(far - near) / n_samples``). Both are empty when ``n_samples == 0``.
    """
    # n_samples + 1 evenly spaced edges bound n_samples intervals.
    edges = torch.linspace(near, far, n_samples + 1, device=device)
    steps = edges[1:] - edges[:-1]
    return edges[:-1], steps

In [11]:
# Parse the "test" split of the lego scene from data/lego and wrap it so that
# individual views can be indexed below.
data = parse_nerf_synthetic(Path("data") / "lego", "test")
img_dataset = ImagesDataset(data)

In [26]:
# Take a sparse grid of rays from one test view, march sample points along
# them, and collect the camera positions for comparison in the plot.
factor = 80                 # keep every `factor`-th ray along both leading axes
view_idx = 10               # which image of img_dataset to draw rays from
use_exp_stepping = False    # True: clipped exponential stepping instead of uniform
use_contraction = False     # True: pass samples and cameras through mip360_contract

# Index the dataset once; each access may recompute the rays for the whole view.
view_rays = img_dataset[view_idx]
# NOTE(review): the [::factor, ::factor] slicing presumably subsamples an
# (H, W, 3) ray grid — confirm against ImagesDataset.
rays_o = view_rays['rays_o'][::factor, ::factor].reshape(-1, 3)
rays_d = view_rays['rays_d'][::factor, ::factor].reshape(-1, 3)
if use_exp_stepping:
    ts, _ = clipped_exponential_stepping(0., 25, 1, 1e10, torch.device('cpu'))
else:
    ts, _ = uniform_stepping(0., 25, 100, torch.device('cpu'))
# (n_rays, 1, 3) + (n_rays, 1, 3) * (1, n_ts, 1) -> (n_rays, n_ts, 3), then flatten to points
samples = rays_o[:, None, :] + rays_d[:, None, :] * ts[None, :, None]
samples = samples.view(-1, 3)
# Translation column of each camera matrix (assumed camera-to-world, so this is
# the camera position — TODO confirm the convention used by parse_nerf_synthetic).
positions = torch.tensor([M[:3, 3].tolist() for M in data.cameras])
if use_contraction:
    samples = mip360_contract(samples)
    positions = mip360_contract(positions)

In [27]:
%matplotlib tk
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.scatter(positions[:,0], positions[:,1], positions[:,2], c='b')
ax.scatter(samples[:,0], samples[:,1], samples[:,2], c='r')
plt.show()