## Notebook to visualize a dataset: sampled training rays, camera positions, and rays clipped to the scene bounds


In [None]:
from notebook_imports import *

import torch
from torch.utils.data import DataLoader

from pyrad.nerf.dataset.image_dataset import ImageDataset, PanopticImageDataset
from pyrad.nerf.image_sampler import CacheImageSampler
from pyrad.nerf.pixel_sampler import PixelSampler
from pyrad.nerf.dataset.utils import get_dataset_inputs
from pyrad.nerf.ray_generator import RayGenerator
from pyrad.nerf.colliders import SceneBoundsCollider, AABBBoxCollider
from pyrad.structures.rays import RayBundle
from pyrad.utils.io import get_absolute_path
from pyrad.viewer.plotly import get_line_segments_from_lines
from pyrad.structures.cameras import get_camera_model
from pyrad.utils.misc import get_dict_to_torch, instantiate_from_dict_config


from hydra import compose, initialize
from omegaconf import open_dict
import pprint
from tqdm import tqdm
import random

In [None]:
# Compose the default Hydra config from the relative "../configs" directory.
with initialize(version_base=None, config_path="../configs"):
    config = compose(config_name="default_setup.yaml")
# Training-split dataset inputs; later cells read image_filenames, downscale_factor,
# semantics, alpha_color, intrinsics, camera_to_world and scene_bounds from it.
dataset_inputs = get_dataset_inputs(**config.data.dataset, split="train")

In [None]:
device = "cpu"

# Training images described by the dataset inputs.
train_image_dataset = instantiate_from_dict_config(
    config.data.image_dataset,
    image_filenames=dataset_inputs.image_filenames,
    downscale_factor=dataset_inputs.downscale_factor,
    semantics=dataset_inputs.semantics,
    alpha_color=dataset_inputs.alpha_color,
)  # ImageDataset
# A configured num_images_to_sample_from of 0 means "collate every image in the dataset".
train_image_sampler = CacheImageSampler(
    train_image_dataset,
    num_samples_to_collate=len(train_image_dataset)
    if config.data.image_sampler.num_images_to_sample_from == 0
    else config.data.image_sampler.num_images_to_sample_from,
    num_times_to_repeat_images=config.data.image_sampler.num_times_to_repeat_images,
    device=device if config.data.image_sampler.move_to_graph_device else "cpu",
)  # ImageSampler
# keep_full_image=True keeps the whole images in the batch so show_batch can draw them.
train_pixel_sampler = PixelSampler(
    num_rays_per_batch=config.data.pixel_sampler.num_rays_per_batch, keep_full_image=True
)  # PixelSampler

ray_generator = RayGenerator(dataset_inputs.intrinsics, dataset_inputs.camera_to_world)

# Draw a single batch so `image_batch`, `batch` and `ray_bundle` exist for inspection.
# The iterator is kept at module level because later cells keep drawing from it.
iter_train_image_sampler = iter(train_image_sampler)
image_batch = next(iter_train_image_sampler)
batch = train_pixel_sampler.sample(image_batch)
# Call the generator itself rather than `.forward`, matching how the ray cell further down
# invokes it (for an nn.Module, __call__ also runs any registered hooks).
ray_bundle = ray_generator(batch["indices"])

In [None]:
# Preview one training image: index 10, or the last image if the dataset is smaller.
media.show_image(train_image_dataset.get_image(min(10, len(train_image_dataset) - 1)), height=200)

In [None]:
def make_batch_image(batch):
    """Tile a batch's images side by side with the sampled ray pixels set to black.

    Args:
        batch: dict with "image", a float tensor of shape (num_images, h, w, 3)
            (values assumed to be in [0, 1] — TODO confirm), and "local_indices",
            an integer tensor whose last dimension holds (image, y, x) indices.

    Returns:
        A uint8 tensor of shape (h, num_images * w, 3). ``batch`` is not modified.
    """
    # Draw on a copy: batch["image"] may share storage with the image sampler's cache,
    # so writing into it directly could leave ray marks on images in later batches.
    images = batch["image"].clone()
    c, y, x = [i.flatten() for i in torch.split(batch["local_indices"], 1, dim=-1)]
    images[c, y, x] = 0.0
    # Iterating over dim 0 yields (h, w, 3) images; concatenate them along the width.
    tiled = torch.cat(list(images), dim=1)
    # Clamp so out-of-range values saturate instead of wrapping around in uint8.
    return (tiled * 255).clamp(0, 255).to(torch.uint8)


def show_batch(batch):
    """Print the batch keys and display its images, with sampled ray pixels drawn in black."""
    print(batch.keys())
    # the black pixels are rays
    media.show_image(make_batch_image(batch))


def sample_and_show_batch():
    """Sample rays from the next cached image batch, display them, and return the batch.

    Reads the module-level ``iter_train_image_sampler`` and ``train_pixel_sampler``.
    """
    sampled = train_pixel_sampler.sample(next(iter_train_image_sampler))
    show_batch(sampled)
    return sampled

In [None]:
# Draw and display a fresh batch; it is kept as `batch` for the ray-generation cell below.
batch = sample_and_show_batch()

In [None]:
# Generate rays for the pixels sampled in `batch`; this replaces the `ray_bundle`
# created in the setup cell.
ray_indices = batch["indices"]
ray_bundle = ray_generator(ray_indices)

In [None]:
skip = 1  # plot every `skip`-th camera
size = 8  # marker size for the camera positions
# Camera positions are the translation column (index 3) of each camera-to-world matrix.
data = [
    go.Scatter3d(
        x=ray_generator.camera_to_world[::skip, 0, 3],
        y=ray_generator.camera_to_world[::skip, 1, 3],
        z=ray_generator.camera_to_world[::skip, 2, 3],
        mode="markers",
        name="origins",
        marker={"color": "rgba(0, 0, 0, 1)", "size": size},
    ),
]

In [None]:
# Collide the rays with the dataset's scene bounds. The result carries the `nears`/`fars`
# used to build the ray segments below. NOTE(review): this uses AABBBoxCollider even though
# SceneBoundsCollider is also imported; confirm which collider is intended.
scene_bounds_collider = AABBBoxCollider(dataset_inputs.scene_bounds)
intersected_ray_bundle = scene_bounds_collider(ray_bundle)

In [None]:
# Segment endpoints for each ray: origin + near * direction and origin + far * direction.
all_lines = torch.stack(
    [
        intersected_ray_bundle.origins + intersected_ray_bundle.directions * intersected_ray_bundle.nears[..., None],
        intersected_ray_bundle.origins + intersected_ray_bundle.directions * intersected_ray_bundle.fars[..., None],
    ],
    dim=1,
)  # (num_rays, 2, 3)
# Plot at most 100 random segments. random.sample raises ValueError when k exceeds the
# population, so cap k at the number of rays actually sampled.
num_lines_to_plot = min(100, len(all_lines))
lines = all_lines[random.sample(range(len(all_lines)), k=num_lines_to_plot)]

In [None]:
# Append the Plotly traces that get_line_segments_from_lines builds for the sampled ray segments.
data += get_line_segments_from_lines(lines)

In [None]:
# Fixed 1000x1000 figure. The scene keeps the data's aspect ratio, puts +z up,
# and looks at the origin from (1.25, 1.25, 1.25).
layout = go.Layout(
    autosize=False,
    width=1000,
    height=1000,
    margin=go.layout.Margin(l=50, r=50, b=100, t=100, pad=4),
    scene=go.layout.Scene(
        aspectmode="data",
        camera={
            "up": {"x": 0, "y": 0, "z": 1},
            "center": {"x": 0, "y": 0, "z": 0},
            "eye": {"x": 1.25, "y": 1.25, "z": 1.25},
        },
    ),
)
fig = go.Figure(data=data, layout=layout)
fig.show()