In [None]:
import pathlib as pl
import torch
from torch import nn
from torchvision import io
import json

# Location of the scene on disk: <root>/<data source>/<dataset>/<scene>.
root_data_dir = pl.Path("./data/")

data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

# Directory holding the per-split image folders and transforms_*.json files.
data_path = root_data_dir.joinpath(data_source_name, dataset_name, scene_name)

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset


# Matches the first run of decimal digits in a string; used to pull the
# frame number out of an image file name.
digit_pattern = re.compile(r"\d+")


def load_transforms(tfs_path: pl.Path) -> tuple[float, list[float], list[torch.FloatTensor]]:
    """Read a ``transforms_*.json`` file and unpack its camera metadata.

    Args:
        tfs_path: Path of the JSON file to read.

    Returns:
        A ``(camera_angle_x, rotations, transform_matrixes)`` triple. The two
        lists have one entry per element of the JSON ``"frames"`` array, in
        file order (the frames are assumed to already be ordered).

    Raises:
        KeyError: If an expected key is missing from the JSON document.
    """
    with open(tfs_path, "r") as handle:
        payload = json.loads(handle.read())

    cam_angle_x = float(payload["camera_angle_x"])

    # Parse each frame once, reading "rotation" before "transform_matrix".
    parsed = [
        (float(entry["rotation"]), torch.FloatTensor(entry["transform_matrix"]))
        for entry in payload["frames"]
    ]
    rotations = [angle for angle, _ in parsed]
    transform_matrixes = [matrix for _, matrix in parsed]

    return cam_angle_x, rotations, transform_matrixes
        

def extract_digit_from_path_name(path: pl.Path) -> int:
    """Return the first run of digits in ``path.name`` as an integer.

    Only the final path component is inspected. When the name contains no
    digits at all the function returns ``None`` (despite the ``int``
    annotation), so callers must be prepared for that value.
    """
    found = digit_pattern.search(path.name)
    return int(found.group(0)) if found else None


def load_img_paths(imgs_path: pl.Path) -> list[pl.Path]:
    """List the numbered entries of ``imgs_path`` in numeric order.

    ``Path.iterdir()`` yields entries in arbitrary order, so the entries are
    sorted by the first number found in their file name (``r_2`` before
    ``r_10``). Entries whose name contains no number (for example
    ``.DS_Store``) are skipped: they would otherwise produce a ``None`` sort
    key and make ``sorted`` raise ``TypeError``.

    Args:
        imgs_path: Directory containing the image files.

    Returns:
        The numbered paths, ordered by number; entries sharing the same number
        are ordered by file name so the result is deterministic.
    """
    numbered = [
        (number, path)
        for path in imgs_path.iterdir()
        if (number := extract_digit_from_path_name(path)) is not None
    ]
    numbered.sort(key=lambda pair: (pair[0], pair[1].name))
    return [path for _, path in numbered]


def load_frame(
        imgs_path: list[pl.Path],
        rotations: list[float],
        tf_matrixes: list[torch.FloatTensor],
        idx: int
    ) -> tuple[torch.Tensor, float, torch.FloatTensor]:
    """Load the image and camera metadata of a single frame.

    Args:
        imgs_path: Sequence of image paths, indexed by frame number (despite
            the singular name this is a list, not a single path).
        rotations: Per-frame rotation values, aligned with ``imgs_path``.
        tf_matrixes: Per-frame transform matrices, aligned with ``imgs_path``.
        idx: Index of the frame to load.

    Returns:
        ``(img, rotation, tf_matrix)`` where ``img`` is the image decoded in
        RGBA mode by ``torchvision.io.read_image``.

    Raises:
        IndexError: If ``idx`` is out of range for any of the three sequences.
    """
    # Index all three sequences first so a bad index fails before any file I/O.
    img_path = imgs_path[idx]
    rotation = rotations[idx]
    tf_matrix = tf_matrixes[idx]
    img = io.read_image(str(img_path), mode=io.ImageReadMode.RGB_ALPHA)

    return img, rotation, tf_matrix


class FrameDataset(Dataset):
    """Dataset of frames (image, rotation, transform matrix) for one data split.

    Images are read from ``data_path / data_mode`` and camera metadata from
    ``data_path / f"transforms_{data_mode}.json"``. One example frame is
    loaded eagerly at construction time to discover the image dimensions.
    """

    def __init__(
            self,
            data_path: pl.Path,
            data_mode: str,  # 'train', 'val', 'test'
            ex_idx: int = 5
        ) -> None:
        """Index the split and cache an example frame.

        Args:
            data_path: Scene directory containing the split folders and the
                ``transforms_*.json`` files.
            data_mode: Name of the split ('train', 'val' or 'test').
            ex_idx: Index of the frame loaded as ``self.ex_img``; its shape
                defines ``self.C``, ``self.H`` and ``self.W``.
        """
        super().__init__()

        self.imgs_path = data_path / data_mode
        self.tfs_path = data_path / f"transforms_{data_mode}.json"

        self.cam_angle_x, self.rotations, self.transform_matrixes = load_transforms(self.tfs_path)
        self.img_paths = load_img_paths(self.imgs_path)

        self.ex_img, *_ = self[ex_idx]
        self.C, self.H, self.W = self.ex_img.shape

        # Focal length in pixels, treating camera_angle_x as the horizontal
        # field of view: W / 2 = focal * tan(angle / 2).
        self.focal = 0.5 * self.W / np.tan(0.5 * self.cam_angle_x)


    @property
    def shape(self) -> tuple[int, int]:
        """Image size as ``(H, W)``, taken from the example frame."""
        return self.H, self.W

    def __len__(self) -> int:
        """Number of frames that can be loaded.

        ``__getitem__`` indexes the image paths, rotations and transform
        matrices with the same index, so the usable length is the shortest of
        the three. (Previously this called ``len()`` on the transforms file
        path, which raises ``TypeError``.)
        """
        return min(len(self.img_paths), len(self.rotations), len(self.transform_matrixes))

    def __getitem__(self, idx: int):
        """Return ``(img, rotation, transform_matrix)`` for frame ``idx``."""
        return load_frame(self.img_paths, self.rotations, self.transform_matrixes, idx)


train_dataset = FrameDataset(data_path, "train")

# Sanity-check the example frame: its shape and the maxima of the first
# colour channel and of the alpha channel (index 3).
print(train_dataset.ex_img.shape)
print(train_dataset.ex_img[0].max(), train_dataset.ex_img[3].max())
# imshow expects <H, W, C>; the image is stored as <C, H, W>. permute replaces
# `.T.swapaxes(0, 1)`, since `.T` on a non-2-D tensor is deprecated in PyTorch.
plt.imshow(train_dataset.ex_img.permute(1, 2, 0))

In [None]:
# Number of image paths discovered for the training split.
len(train_dataset.img_paths)

In [None]:
def get_rays(
    H: int, W: int, focal: float, c2w: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute one ray per pixel for a pinhole camera.

    Ported from https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L123

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world matrix; only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, -1]`` (translation) are used.

    Returns:
        ``(rays_d, rays_o)``, both of shape ``<H, W, 3>``. Note the order:
        directions first, origins second. Every ray shares the same origin,
        the camera position.
    """
    # Build the pixel grid with c2w's device and (floating) dtype so the
    # matmul below also works for float64 or GPU-resident matrices.
    grid_dtype = c2w.dtype if c2w.is_floating_point() else torch.float32
    i, j = torch.meshgrid(
        torch.arange(W, dtype=grid_dtype, device=c2w.device),
        torch.arange(H, dtype=grid_dtype, device=c2w.device),
        indexing="xy",
    )
    # Directions in camera space: x to the right, y up (hence the sign flip on
    # the row index), looking down the -z axis.
    ds = torch.stack(
        [(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -torch.ones_like(i)], dim=-1
    )
    # Rotate camera-space directions into world space.
    rays_d = ds @ c2w[:3, :3].T
    rays_o = torch.broadcast_to(c2w[:3, -1], rays_d.shape)

    return rays_d, rays_o

In [None]:
# https://github.com/yenchenlin/nerf-pytorch/issues/41
# Image size and focal length come from the training dataset.
H, W = train_dataset.shape
focal = train_dataset.focal
# Camera-to-world matrix of frame 5 (the same index as FrameDataset's default
# example frame).
c2w = train_dataset.transform_matrixes[5]


# get_rays returns directions first, origins second.
rays_d, rays_o = get_rays(H, W, focal, c2w)
rays_d.shape, rays_o.shape

In [None]:
# Toy problem sizes for the shape checks in the cells below. H and W replace
# the values read from the dataset in the earlier cell.
H = W = 8        # image height and width in pixels
B = 7            # batch size
D = 3            # components per point / direction
L1, L2 = 10, 4   # encoding frequencies for the first / second network input

n_bins = 5                    # depth samples per ray
n_rays = W * H                # one ray per pixel
n_samples = n_rays * n_bins   # total samples per batch element

# Near and far bounds handed to the depth sampler.
t_near, t_far = 0.1, 5.0

In [None]:
def positional_encoding(p: torch.Tensor, L: int) -> torch.Tensor:
    """Sinusoidal encoding of the last axis of ``p``.

    Each component ``p_d`` is mapped to ``sin(2^l * pi * p_d)`` and
    ``cos(2^l * pi * p_d)`` for ``l`` in ``0..L-1``.

    Args:
        p: Tensor of shape ``<..., D>``; any number of leading axes is
            accepted (the notebook uses ``<B, N, D>``).
        L: Number of frequencies.

    Returns:
        Tensor of shape ``<..., 2 * D * L>``. The last axis is ordered
        frequency-major: for each ``l`` the values are
        ``sin(x) cos(x) sin(y) cos(y) ...`` over the ``D`` components.
    """
    *lead, D = p.shape

    # Frequencies 2^0 .. 2^(L-1), created on p's device so GPU inputs work.
    freqs = 2 ** torch.arange(L, device=p.device)      # <L>

    # z[..., l, d] = 2^l * pi * p[..., d]
    z = freqs[:, None] * (torch.pi * p[..., None, :])  # <..., L, D>

    # Pair sin/cos on a trailing axis, then flatten (L, D, 2) side by side.
    x = torch.stack((torch.sin(z), torch.cos(z)), dim=-1)  # <..., L, D, 2>
    return x.reshape(*lead, 2 * D * L)



# Random stand-ins for sample positions (tx) and view directions (td).
tx = torch.randn(B, n_samples, D)
td = torch.randn(B, n_samples, D)
tex = positional_encoding(tx, L1)
# Encode the directions, not the positions: this previously passed tx, which
# left td unused.
ted = positional_encoding(td, L2)


In [None]:
class TestNet(nn.Module):
    """Two-layer stand-in network used to check tensor shapes.

    The first linear layer maps the first encoded input to a density value
    plus ``n_hidden`` features; the second maps those features, concatenated
    with the second encoded input, to three colour channels.
    """

    def __init__(self, L1, L2, n_components, n_hidden):
        super().__init__()

        # Each component contributes a sin and a cos term per frequency.
        terms_per_frequency = 2 * n_components
        self.d1 = terms_per_frequency * L1
        self.d2 = terms_per_frequency * L2

        self.lin1 = nn.Linear(self.d1, 1 + n_hidden)
        self.lin2 = nn.Linear(self.d2 + n_hidden, 3)

    def forward(self, o: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Return the pair ``(rgb, sigma)``.

        o: <B, NS, d1> first encoded input
        d: <B, NS, d2> second encoded input
        rgb: <B, NS, 3>, sigma: <B, NS>
        """
        assert o.dim() == 3
        assert d.dim() == 3
        assert o.shape[2] == self.d1
        assert d.shape[2] == self.d2

        hidden = self.lin1(o)

        # Channel 0 is the density; the remaining channels feed the colour head.
        sigma = hidden[..., 0]
        features = hidden[..., 1:]
        rgb = self.lin2(torch.cat((d, features), dim=2))

        return rgb, sigma

# Build the test network for 3 components and 128 hidden features, then run it
# on the two encoded inputs.
tn = TestNet(L1, L2, 3, 128)
tc, tsigma = tn(tex, ted)
# Display the shapes of the colour and density outputs.
tc.shape, tsigma.shape

In [None]:
def strat_sampling(N: int, t_near: float, t_far: float) -> torch.Tensor:
    """Draw ``N`` stratified samples from the interval ``[t_near, t_far)``.

    The interval is split into ``N`` equal-width bins and one sample is drawn
    uniformly at random inside each bin, so the result is increasing.

    Args:
        N: Number of bins (and therefore samples).
        t_near: Lower bound of the sampled interval.
        t_far: Upper bound of the sampled interval.

    Returns:
        Tensor of shape ``<N>`` with sample ``i`` inside bin ``i``.
    """
    # The t_near offset was previously missing, which shifted every sample
    # into [0, t_far - t_near) instead of [t_near, t_far).
    samples = t_near + (torch.arange(N) + torch.rand(N)) * (t_far - t_near) / N  # <N>
    return samples


def get_delta(samples: torch.Tensor, max_delta: float = 1E10) -> torch.Tensor:
    """Distances between consecutive samples along the last axis.

    Args:
        samples: Tensor of shape ``<..., NB>`` (the notebook uses
            ``<B, NR, NB>``); any number of leading axes is accepted.
        max_delta: Value appended after the last sample before differencing.

    Returns:
        Tensor with the same shape as ``samples``. Entry ``k`` is
        ``samples[..., k + 1] - samples[..., k]``; the final entry is
        ``max_delta - samples[..., -1]``.
    """
    # Create the padding on the samples' device (and floating dtype) so that
    # non-CPU and non-float32 inputs work.
    pad_dtype = samples.dtype if samples.is_floating_point() else torch.get_default_dtype()
    pad = torch.full(
        (*samples.shape[:-1], 1), max_delta, dtype=pad_dtype, device=samples.device
    )
    delta = torch.diff(samples, append=pad, dim=-1)  # <..., NB>
    return delta



# Draw an independent stratified set of n_bins depths for every ray. Drawing
# one long stratified sequence of B * n_samples values and reshaping it (as
# was done before) hands each ray only a thin, consecutive slice of the depth
# range instead of samples spread over the whole range.
samples = torch.stack(
    [strat_sampling(n_bins, t_near, t_far) for _ in range(B * n_rays)]
).reshape(B, n_rays, n_bins)
delta = get_delta(samples)

samples.shape, delta.shape

In [None]:
def expected_color(c, sigma, delta):
    """Alpha-composite per-sample colours into one colour per batch element.

    Args:
        c: <B, N, C> colour of every sample (C colour components).
        sigma: <B, N> density of every sample.
        delta: <B, NR, NB> distance to the next sample, with NR * NB == N.

    Returns:
        <B, C> tensor: the weighted sum of ``c`` over all N samples, with
        weights ``T_n * (1 - exp(-sigma_n * delta_n))`` where ``T_n`` is the
        transmittance accumulated over the samples before ``n``.

    NOTE(review): ``delta`` is flattened to <B, NR * NB>, so the transmittance
    accumulates across ray boundaries and the result is a single colour per
    batch element rather than one per ray. Confirm this is intended before
    using the function for per-ray rendering.
    """
    assert len(sigma.shape) == 2
    assert len(delta.shape) == 3
    assert len(c.shape) == 3

    B = delta.size(0)

    delta = delta.reshape(B, -1)
    mul = delta * sigma

    # Exclusive cumulative product of exp(-mul): prepend T = 1 for the first
    # sample and drop the last entry. new_ones keeps the device and dtype of
    # the inputs (a plain torch.ones is CPU-only).
    T = torch.exp(-torch.cumsum(mul, dim=1))
    T = torch.cat((T.new_ones(B, 1), T), dim=1)[..., :-1]

    w = T * (1 - torch.exp(-mul))

    c_hat = torch.einsum("bn,bnc->bc", w, c)
    return c_hat

# Composite the test network's colours and densities with the sampled deltas
# and display the result.
expected_color(tc, tsigma, delta)