In [None]:
import pathlib as pl
import torch
from torchvision import io
import json

# Dataset layout on disk: <root_data_dir>/<data_source_name>/<dataset_name>/<scene_name>
data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

root_data_dir = pl.Path("./data/")
data_path = root_data_dir.joinpath(data_source_name, dataset_name, scene_name)

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset


# One or more consecutive decimal digits; used to pull frame indices out of file names.
digit_pattern = re.compile(pattern=r"\d+")


def load_transforms(tfs_path: pl.Path) -> tuple[float, list[float], list[torch.FloatTensor]]:
    """Read a ``transforms_<split>.json`` camera file.

    Args:
        tfs_path: Path to the JSON file. It must contain a ``camera_angle_x``
            value and a ``frames`` list whose entries each have ``rotation``
            and ``transform_matrix`` keys.

    Returns:
        ``(cam_angle_x, rotations, transform_matrixes)``. Both lists follow the
        order of ``frames`` in the file. Downstream code pairs them with images
        purely by position, so that order is assumed to match the image order.
    """
    # JSON is UTF-8; don't depend on the platform's default text encoding.
    with open(tfs_path, "r", encoding="utf-8") as f:
        transforms = json.load(f)

    cam_angle_x = float(transforms["camera_angle_x"])
    frames = transforms["frames"]
    rotations = [float(frame["rotation"]) for frame in frames]
    transform_matrixes = [torch.FloatTensor(frame["transform_matrix"]) for frame in frames]

    return cam_angle_x, rotations, transform_matrixes
        

def extract_digit_from_path_name(path: pl.Path) -> "int | None":
    """Return the first run of digits in ``path.name`` as an int, or ``None``.

    Examples: ``r_12.png`` -> 12, ``r_3_depth_0007.png`` -> 3 (first run only),
    ``.DS_Store`` -> None.
    """
    match = digit_pattern.search(path.name)
    return int(match.group(0)) if match else None


def load_img_paths(imgs_path: pl.Path, name_pattern: str = r"\D*(\d+)\.png") -> list[pl.Path]:
    """List the frame images in ``imgs_path`` sorted by their numeric index.

    Only files whose *whole* name matches ``name_pattern`` are kept. The
    pattern's first capturing group must be the frame index. The default
    accepts names such as ``r_12.png`` and rejects auxiliary files such as
    ``r_12_depth_0000.png``, as well as stray entries without digits such as
    ``.DS_Store``. Keeping those would misalign images with the per-frame
    transforms, which are paired by position. A digit-less name would also
    produce a ``None`` sort key, which cannot be compared with ints.

    Args:
        imgs_path: Directory containing the images.
        name_pattern: Regex full-matched against each file name; group 1 is
            parsed as the integer frame index.

    Returns:
        Matching file paths, ordered by frame index (numerically, so r_2 < r_10).
    """
    regex = re.compile(name_pattern)
    indexed = []
    # iterdir() yields entries in arbitrary order, so the index sort below is what orders them.
    for path in imgs_path.iterdir():
        match = regex.fullmatch(path.name)
        if match is None or not path.is_file():
            continue
        indexed.append((int(match.group(1)), path))
    indexed.sort(key=lambda item: item[0])
    return [path for _, path in indexed]


def load_frame(
        imgs_path: list[pl.Path],
        rotations: list[float],
        tf_matrixes: list[torch.FloatTensor],
        idx: int
    ) -> tuple[torch.Tensor, float, torch.FloatTensor]:
    """Load the ``idx``-th image together with its rotation and pose matrix.

    Args:
        imgs_path: Despite the name, the *list* of image paths (indexed by ``idx``).
        rotations: Per-frame rotation values, aligned with ``imgs_path``.
        tf_matrixes: Per-frame transform matrices, aligned with ``imgs_path``.
        idx: Frame index into all three lists.

    Returns:
        ``(img, rotation, tf_matrix)``. ``img`` is read with
        ``ImageReadMode.RGB_ALPHA``, so it carries RGBA channels.
    """
    img_path = imgs_path[idx]
    rotation = rotations[idx]
    tf_matrix = tf_matrixes[idx]
    img = io.read_image(str(img_path), mode=io.ImageReadMode.RGB_ALPHA)

    return img, rotation, tf_matrix


class FrameDataset(Dataset):
    """Images of one scene split paired (by index) with their camera poses.

    Each item is the ``(img, rotation, transform_matrix)`` tuple returned by
    ``load_frame``. The image at ``ex_idx`` is loaded eagerly to record the
    channel count and image size, and to derive the focal length.
    """

    def __init__(
            self,
            data_path: pl.Path,
            data_mode: str,  # 'train', 'val', 'test'
            ex_idx: int = 5
        ) -> None:
        super().__init__()

        self.imgs_path = data_path / data_mode
        self.tfs_path = data_path / f"transforms_{data_mode}.json"

        self.cam_angle_x, self.rotations, self.transform_matrixes = load_transforms(self.tfs_path)
        self.img_paths = load_img_paths(self.imgs_path)

        # Images and poses are paired purely by position, so the counts must agree.
        if len(self.img_paths) != len(self.transform_matrixes):
            raise ValueError(
                f"{self.imgs_path} has {len(self.img_paths)} images but "
                f"{self.tfs_path} lists {len(self.transform_matrixes)} frames"
            )

        self.ex_img, *_ = self[ex_idx]
        self.C, self.H, self.W = self.ex_img.shape

        # Pinhole focal length in pixels, from the horizontal field of view.
        self.focal = 0.5 * self.W / np.tan(0.5 * self.cam_angle_x)

    @property
    def shape(self) -> tuple[int, int]:
        """Image size as ``(H, W)``, taken from the example image."""
        return self.H, self.W

    def __len__(self) -> int:
        # Number of frames; the original len(self.tfs_path) called len() on a Path.
        return len(self.transform_matrixes)

    def __getitem__(self, idx: int):
        return load_frame(self.img_paths, self.rotations, self.transform_matrixes, idx)


train_dataset = FrameDataset(data_path, "train")

print(train_dataset.ex_img.shape)
# Max of channel 0 (red) and channel 3 (alpha) of the RGBA example image.
print(train_dataset.ex_img[0].max(), train_dataset.ex_img[3].max())
# (C, H, W) -> (H, W, C) for matplotlib; `.T` on a 3-D tensor is deprecated in PyTorch.
plt.imshow(train_dataset.ex_img.permute(1, 2, 0));

In [None]:
def get_rays(
    H: int, W: int, focal: float, c2w: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-pixel ray directions and origins for a pinhole camera.

    Ported from https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L123

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world matrix. Only ``c2w[:3, :3]`` (rotation) and
            ``c2w[:3, -1]`` (translation) are read.

    Returns:
        ``(rays_d, rays_o)``, each of shape ``(H, W, 3)``. Directions are not
        normalized. ``rays_o`` is a broadcast view of the camera translation,
        not a copy. Both live on ``c2w``'s device and use its floating dtype
        (float32 if ``c2w`` is not floating point).
    """
    device = c2w.device
    dtype = c2w.dtype if c2w.is_floating_point() else torch.float32
    c2w = c2w.to(dtype)

    # With indexing="xy" both grids are (H, W): i is the column index, j the row index.
    i, j = torch.meshgrid(
        torch.arange(W, dtype=dtype, device=device),
        torch.arange(H, dtype=dtype, device=device),
        indexing="xy",
    )
    # Camera-space directions: x right, y up (rows grow downward, hence the minus), looking along -z.
    ds = torch.stack(
        [(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -torch.ones_like(i)], dim=-1
    )
    # Row-vector form of R @ d for every pixel.
    rays_d = ds @ c2w[:3, :3].T
    rays_o = torch.broadcast_to(c2w[:3, -1], rays_d.shape)

    return rays_d, rays_o

In [None]:
# https://github.com/yenchenlin/nerf-pytorch/issues/41
# Reference: https://github.com/yenchenlin/nerf-pytorch/issues/41
# Rays for one example view (frame index 5) of the training split.
c2w = train_dataset.transform_matrixes[5]
focal = train_dataset.focal
H, W = train_dataset.H, train_dataset.W

rays_d, rays_o = get_rays(H=H, W=W, focal=focal, c2w=c2w)
(rays_d.shape, rays_o.shape)

In [None]:
# Scratch cell: displays a 3x2 sample from torch.rand. The result is not
# assigned to a name. TODO(review): remove once exploration is done.
torch.rand(3, 2)

In [None]:
def strat_sampling(t_near, t_far, n, generator=None):
    """Draw ``n`` stratified samples from the interval ``[t_near, t_far)``.

    The interval is split into ``n`` equal-width bins and one point is drawn
    uniformly inside each bin. Sample ``k`` therefore lies in
    ``[t_near + k * step, t_near + (k + 1) * step)`` with
    ``step = (t_far - t_near) / n``, so the result is non-decreasing.

    Args:
        t_near: Start of the interval.
        t_far: End of the interval.
        n: Number of bins (and samples).
        generator: Optional ``torch.Generator`` for reproducible jitter.

    Returns:
        1-D float tensor of length ``n``.
    """
    step = (t_far - t_near) / n
    jitter = torch.rand(n, generator=generator)
    # Offset by t_near; without it the samples would land in [0, t_far - t_near).
    return t_near + (torch.arange(n) + jitter) * step

# Five stratified samples over [0, 10) and the spacing between neighbours.
samples = strat_sampling(t_near=0, t_far=10, n=5)
gap = samples.diff()
(samples, gap)