## Notebook to visualize a dataset


In [None]:
from notebook_imports import *

import torch
from torch.utils.data import DataLoader

from mattport.nerf.dataset.image_dataset import ImageDataset, collate_batch
from mattport.nerf.dataset.collate import CollateIterDataset, collate_batch_size_one
from mattport.nerf.dataset.utils import get_dataset_inputs_dict
from mattport.nerf.field_modules.ray_generator import RayGenerator
from mattport.structures.rays import RayBundle
from mattport.utils.io import get_absolute_path
from mattport.viewer.plotly import get_line_segments_from_lines

from hydra import compose, initialize
from omegaconf import open_dict
import pprint
from tqdm import tqdm

import random
from tqdm import tqdm

In [None]:
# Compose the default Hydra config and override the dataset downscale factor
# for this notebook.
with initialize(config_path="../configs"):
    cfg = compose(config_name="default.yaml")
    cfg.dataset.downscale_factor = 4

cfg.dataset.data_directory = get_absolute_path(cfg.dataset.data_directory)

# Use the helper that the import cell actually brings into scope
# (`get_dataset_inputs_dict`); its result is indexed by split name below.
dataset_inputs = get_dataset_inputs_dict(**cfg.dataset)["train"]
image_dataset = ImageDataset(
    image_filenames=dataset_inputs.image_filenames,
    downscale_factor=dataset_inputs.downscale_factor,
)

# Each item produced by the iterable dataset is already a collated ray batch;
# keep_full_image=True keeps the full images in the batch, which show_batch
# needs for display.
iter_dataset = CollateIterDataset(
    image_dataset,
    collate_fn=lambda batch_list: collate_batch(
        batch_list, cfg.dataloader.num_rays_per_batch, keep_full_image=True
    ),
    num_samples_to_collate=cfg.dataloader.num_images_to_sample_from,
    num_times_to_repeat=cfg.dataloader.num_times_to_repeat_images,
)
# batch_size=1 with collate_batch_size_one: the dataset items are batches
# already, so the DataLoader only has to hand them through.
dataloader = DataLoader(
    iter_dataset,
    batch_size=1,
    num_workers=cfg.dataloader.num_workers,
    collate_fn=collate_batch_size_one,
    pin_memory=True,
)
dataloader_iter = iter(dataloader)

ray_generator = RayGenerator(dataset_inputs.intrinsics, dataset_inputs.camera_to_world)

# Smoke-test the pipeline: pull a batch and generate its rays.
# NOTE(review): the `break` stops after the first batch, so `num_batches` only
# sizes the progress bar — drop the `break` to iterate over all of them.
num_batches = 10
for _ in tqdm(range(num_batches)):
    batch = next(dataloader_iter)
    # Call the generator itself rather than `.forward`, matching the later
    # `ray_generator(ray_indices)` cell.
    ray_bundle = ray_generator(batch["indices"])
    break

In [None]:
# Preview a single image from the dataset (index 10 is hard-coded).
media.show_image(image_dataset.get_image(10))

In [None]:
def show_batch(batch):
    """Display the images of ``batch`` side by side with sampled rays in white.

    Args:
        batch: mapping with an ``"image"`` tensor laid out as
            (num_images, h, w, 3) and a ``"local_indices"`` integer tensor
            whose last dimension holds, per sampled ray, the three indices
            into the first three image dimensions.
            Pixel values are presumably floats in [0, 1] — TODO confirm.

    The highlight is painted onto a copy, so ``batch["image"]`` keeps its
    original pixel values for any code that uses the batch afterwards.
    """
    # Work on a copy: writing into batch["image"] directly would corrupt the
    # batch that sample_and_show_batch hands back to the caller.
    painted = batch["image"].clone()

    # One index tensor per coordinate; assigning 1.0 sets all three channels.
    image_idx, row_idx, col_idx = batch["local_indices"].unbind(dim=-1)
    painted[image_idx, row_idx, col_idx] = 1.0

    # painted is num_images, h, w, 3 -> concatenate the images along the width
    strip = torch.cat(painted.unbind(dim=0), dim=1)

    # the white pixels are rays
    media.show_image((strip * 255).to(torch.uint8))

def sample_and_show_batch():
    """Draw one batch from a fresh iterator over ``dataloader``, display it, and return it."""
    # A new iterator is created on every call; only its first batch is used.
    sampled = next(iter(dataloader))
    show_batch(sampled)
    return sampled

In [None]:
# Draw a batch, display it, and keep it for the ray-visualization cells below.
batch = sample_and_show_batch()

In [None]:
# Turn the batch's sampled indices into a ray bundle; its origins and
# directions are used below to build the line segments to plot.
ray_indices = batch["indices"]
ray_bundle = ray_generator(ray_indices)

In [None]:
# Plot settings: use every `skip`-th camera, drawn with markers of size `size`.
skip = 1
size = 8

# Start the trace list with one marker per camera position, read from the
# last column (index 3) of rows 0..2 of each camera_to_world matrix.
data = [
    go.Scatter3d(
        x=ray_generator.camera_to_world[::skip, 0, 3],
        y=ray_generator.camera_to_world[::skip, 1, 3],
        z=ray_generator.camera_to_world[::skip, 2, 3],
        mode="markers",
        name="origins",
        marker={"color": "rgba(0, 0, 0, 1)", "size": size},
    )
]

In [None]:
# Distances along each ray at which the plotted segment starts and ends.
near_plane = 1.0
far_plane = 6.0

# Segment endpoints for every ray in the bundle: (num_rays, 2, 3).
lines = torch.stack(
    [
        ray_bundle.origins + ray_bundle.directions * near_plane,
        ray_bundle.origins + ray_bundle.directions * far_plane,
    ],
    dim=1,
)

# Plot at most 100 randomly chosen rays. The count is clamped because
# random.sample raises ValueError when asked for more items than exist, and
# row indices are sampled so the tensor never round-trips through a nested
# Python list.
num_rays_to_plot = min(100, len(lines))
sampled_ray_ids = random.sample(range(len(lines)), k=num_rays_to_plot)
# detach + cpu so the plotting helper receives a plain CPU tensor, as the
# original list round-trip produced.
lines = lines[torch.tensor(sampled_ray_ids, dtype=torch.long)].detach().cpu()

In [None]:
# Add the sampled ray segments to the traces collected in `data`
# (the helper presumably returns a list of plotly traces — confirm).
# NOTE: `+=` extends the list in place, so re-running this cell appends the
# same segments again; re-run the cell that creates `data` first.
data += get_line_segments_from_lines(lines)

In [None]:
# Fixed 1000x1000 canvas; aspectmode="data" scales the axes by the data
# ranges, and the scene camera uses +z as the up direction.
layout = go.Layout(
    autosize=False,
    width=1000,
    height=1000,
    margin=go.layout.Margin(l=50, r=50, b=100, t=100, pad=4),
    scene=go.layout.Scene(
        aspectmode="data",
        camera={
            "up": {"x": 0, "y": 0, "z": 1},
            "center": {"x": 0, "y": 0, "z": 0},
            "eye": {"x": 1.25, "y": 1.25, "z": 1.25},
        },
    ),
)

# Render all collected traces (camera origins plus any ray segments).
fig = go.Figure(data=data, layout=layout)
fig.show()

In [None]:
# Export the figure to a PNG and read it back as an image array.
# "temp.png" is written relative to the current working directory and is not
# removed by this cell.
# NOTE(review): static export presumably needs plotly's image-export backend
# (e.g. kaleido) to be installed — confirm.
fig.write_image("temp.png")
plotly_image = imageio.imread("temp.png")
