In [None]:
import pathlib as pl
import torch
from torchvision import io
import json

# Dataset location on disk: <root>/<data source>/<dataset>/<scene>.
data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

root_data_dir = pl.Path('./data/')
data_path = root_data_dir / data_source_name / dataset_name / scene_name

# Image directories, one per split.
imgs_path_train = data_path / "train"
imgs_path_val = data_path / "val"
imgs_path_test = data_path / "test"

# Transform JSON files, one per split.
tfs_path_train = data_path / "transforms_train.json"
tfs_path_val = data_path / "transforms_val.json"
tfs_path_test = data_path / "transforms_test.json"

In [None]:
import re
import matplotlib.pyplot as plt
from torch.utils.data import Dataset


# Matches one or more consecutive digits; used with `.search` to pull the first
# number out of a file name (see extract_digit_from_path_name).
digit_pattern = re.compile(r"\d+")


def load_transforms(tfs_path: pl.Path) -> tuple[float, list[float], list[torch.FloatTensor]]:
    """Read a transforms JSON file and unpack its per-frame camera data.

    Args:
        tfs_path: Path to a JSON file with a ``camera_angle_x`` entry and a
            ``frames`` list whose items each carry ``rotation`` and
            ``transform_matrix`` entries.

    Returns:
        A ``(camera_angle_x, rotations, transform_matrices)`` tuple. Both lists
        hold one entry per frame, in the order the frames appear in the file.
    """
    with open(tfs_path, "r") as handle:
        payload = json.load(handle)

    camera_angle = float(payload["camera_angle_x"])

    # Frames are consumed in file order; no reordering is applied here.
    per_frame = [
        (float(entry["rotation"]), torch.FloatTensor(entry["transform_matrix"]))
        for entry in payload["frames"]
    ]
    frame_rotations = [rot for rot, _ in per_frame]
    frame_matrices = [mat for _, mat in per_frame]

    return camera_angle, frame_rotations, frame_matrices
        

def extract_digit_from_path_name(path: pl.Path) -> "int | None":
    """Return the first run of digits in ``path.name`` as an integer.

    Only the final path component is inspected, so digits in parent
    directories are ignored. When the name contains several digit runs, only
    the first one is used.

    Args:
        path: Path whose file name is searched.

    Returns:
        The first digit sequence of the name converted to ``int``, or ``None``
        when the name contains no digits.
    """
    match = digit_pattern.search(path.name)

    if match is None:
        return None

    return int(match.group(0))


def load_img_paths(imgs_path: pl.Path) -> list[pl.Path]:
    """List the entries of ``imgs_path`` ordered by the number in their name.

    ``Path.iterdir`` yields entries in arbitrary order, so they are sorted by
    the first integer found in each name (numeric order, so ``r_2`` comes
    before ``r_10``). Entries whose name contains no digits are skipped: they
    cannot be ordered numerically, and their ``None`` sort key would make the
    sort raise ``TypeError``.

    Args:
        imgs_path: Directory to list.

    Returns:
        The numbered entries of the directory, sorted by their number.
    """
    numbered = []
    for path in imgs_path.iterdir():
        number = extract_digit_from_path_name(path)
        if number is not None:
            numbered.append((number, path))

    # Sort on the number only; the sort is stable, so ties keep iteration order.
    numbered.sort(key=lambda pair: pair[0])
    return [path for _, path in numbered]


def load_frame(
        imgs_path: list[pl.Path],
        rotations: list[float],
        tf_matrixes: list[torch.FloatTensor],
        idx: int
    ) -> tuple[torch.Tensor, float, torch.FloatTensor]:
    """Load the image, rotation and transform matrix of a single frame.

    The three sequences are parallel: position ``idx`` in each one is taken to
    describe the same frame.

    Args:
        imgs_path: Sequence of image file paths (one per frame), despite the
            singular name.
        rotations: Per-frame rotation values.
        tf_matrixes: Per-frame transform matrices.
        idx: Position of the frame to load.

    Returns:
        An ``(image, rotation, transform_matrix)`` tuple, where ``image`` is the
        tensor produced by ``torchvision.io.read_image``.

    Raises:
        IndexError: If ``idx`` is out of range for any of the sequences.
    """
    img_path = imgs_path[idx]
    rotation = rotations[idx]
    tf_matrix = tf_matrixes[idx]
    # read_image takes a string path, hence the explicit str() conversion.
    img = io.read_image(str(img_path))

    return img, rotation, tf_matrix


class FrameDataset(Dataset):
    """Dataset of ``(image, rotation, transform_matrix)`` frames for one split.

    Expects ``data_path`` to contain an image directory named after the split
    (``<data_mode>/``) and a ``transforms_<data_mode>.json`` file. Images and
    transform entries are paired by position: the ``i``-th image path (in
    numeric file-name order) goes with the ``i``-th frame of the JSON file.
    """

    def __init__(
            self,
            data_path: pl.Path,
            data_mode: str  # 'train', 'val', 'test'
        ) -> None:
        """Load the transforms and collect the image paths of one split.

        Args:
            data_path: Root directory of the scene.
            data_mode: Split name; selects both the image sub-directory and
                the transforms file.
        """
        super().__init__()

        self.imgs_path = data_path / data_mode
        self.tfs_path = data_path / f"transforms_{data_mode}.json"

        self.cam_angle_x, self.rotations, self.transform_matrixes = load_transforms(self.tfs_path)
        self.img_paths = load_img_paths(self.imgs_path)

    def __len__(self) -> int:
        """Return the number of frames that can be indexed."""
        # Count frames, not the transforms-file path (len() of a Path raises
        # TypeError). Using the smaller of the two lists keeps every index in
        # range(len(self)) valid even if the image and transform counts differ.
        return min(len(self.img_paths), len(self.transform_matrixes))

    def __getitem__(self, idx: int):
        """Return the ``(image, rotation, transform_matrix)`` tuple at ``idx``."""
        return load_frame(self.img_paths, self.rotations, self.transform_matrixes, idx)


train_dataset = FrameDataset(data_path, "train")
ex_img, ex_rot, ex_tf = train_dataset[5]

print(ex_img.shape)
# torchvision's read_image returns a channels-first (C, H, W) tensor, while
# imshow expects channels-last (H, W, C). permute moves the axes; reshape would
# only reinterpret the flat buffer and scramble the pixels. [:3] keeps the
# first three channels (drops a fourth/alpha channel when present).
plt.imshow(ex_img[:3].permute(1, 2, 0))