In [None]:
import sys; sys.path.append('/private/home/ronghanghu/workspace/DATASETS/diode-devkit/')

import diode
import numpy as np
import skimage.transform
import skimage.io
import os
import tqdm

In [None]:
# Which DIODE subset this notebook run operates on; both values are passed to
# the devkit loader and are embedded in the output directory names.
TYPE = 'outdoor'
SPLIT = 'val'

# Output directories (created further down in this cell).
# NOTE(review): none of the visible cells write into them -- confirm whether
# a saving step lives elsewhere.
SAVE_IM_DIR = f'/checkpoint/ronghanghu/neural_rendering_datasets/diode_45fov_256x256_{TYPE}/{SPLIT}/images'
SAVE_DATA_DIR = f'/checkpoint/ronghanghu/neural_rendering_datasets/diode_45fov_256x256_{TYPE}/{SPLIT}/data'

# Arguments for the devkit loader, collected in one place so the call below
# stays short.
diode_kwargs = dict(
    meta_fname='/private/home/ronghanghu/workspace/DATASETS/diode-devkit/diode_meta.json',
    data_root='/checkpoint/ronghanghu/neural_rendering_datasets/diode/',
    splits=[SPLIT],
    scene_types=[TYPE],
)
dataset = diode.DIODE(**diode_kwargs)

# exist_ok=True makes re-running this cell harmless.
for out_dir in (SAVE_IM_DIR, SAVE_DATA_DIR):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
class DiodeProcessor:
    """Center-crop a sample horizontally and resize it to a square.

    Calling the processor with ``(im, de, de_mask)`` returns ``(im, de)`` where
    both arrays were cropped to ``CROP_W`` columns (full height is kept) and
    resized to ``OUT_SIZE x OUT_SIZE`` with ``skimage.transform.resize``.
    Depth pixels that the mask marks as invalid are set to 0.

    Args:
        crop_w: Number of columns kept by the centered horizontal crop.
            NOTE(review): the default 734 presumably corresponds to the
            "45fov" in the output directory names -- confirm.
        out_size: Side length (pixels) of the square output.
    """

    def __init__(self, crop_w=734, out_size=256):
        self.CROP_W = crop_w
        self.OUT_SIZE = out_size

    def crop_and_resize(self, im, order):
        """Crop ``im`` to the central ``CROP_W`` columns and resize it.

        Args:
            im: Array whose first two axes are (height, width); any trailing
                axes are passed through to ``skimage.transform.resize``.
            order: Interpolation order forwarded to ``resize`` (``None`` lets
                skimage pick its default for the input dtype).

        Returns:
            The array returned by ``skimage.transform.resize`` for an output
            shape of ``(OUT_SIZE, OUT_SIZE)``.

        Raises:
            ValueError: If ``im`` is narrower than ``CROP_W``.
        """
        W = im.shape[1]
        if W < self.CROP_W:
            raise ValueError(
                f'input width {W} is smaller than the crop width {self.CROP_W}'
            )
        left = (W - self.CROP_W) // 2
        # Use an explicit end index: a `left:-left` slice is empty when no
        # cropping is needed (left == 0, since -0 == 0) and is one column too
        # wide when W - CROP_W is odd.
        im = im[:, left:left + self.CROP_W]
        return skimage.transform.resize(im, (self.OUT_SIZE, self.OUT_SIZE), order=order)

    def __call__(self, im, de, de_mask):
        """Process one sample.

        Args:
            im: Image array.
            de: Depth array; it is copied, the caller's array is not modified.
            de_mask: Validity mask indexable against ``de``; entries equal to 0
                mark invalid depth.

        Returns:
            Tuple ``(im, de)``: the resized image and the resized depth as
            ``float32`` with invalid pixels set to 0.
        """
        # Work on a copy so the masking below does not touch the caller's data.
        de = de.copy()
        de[de_mask == 0] = 0
        im = self.crop_and_resize(im, order=None)
        de = self.crop_and_resize(de, order=0)  # nearest neighbor sampling on depth map
        de = de.astype(np.float32)

        # Resize the mask as well and keep only output pixels whose resized
        # mask value is (numerically) 1; anything that picked up an invalid
        # input pixel during resizing falls below 1 and is zeroed out.
        de_mask = self.crop_and_resize(de_mask, order=None)
        de_mask = (1 - de_mask < 1e-8)
        de[~de_mask] = 0
        return im, de


processor = DiodeProcessor()

In [None]:
# Running sums of the per-image statistics; the next cell divides them by
# `count` to obtain dataset-level means.
mean_rgb, mean_depth, count = 0, 0, 0

num_samples = len(dataset)
for sample_idx in tqdm.tqdm(range(num_samples)):
    im, de, de_mask = dataset[sample_idx]
    im_out, depth_out = processor(im, de, de_mask)

    # Per-channel mean over the two spatial axes of the processed image.
    rgb_mean_this_image = np.mean(im_out, axis=(0, 1))

    # Mean depth over strictly positive pixels only (the processor writes 0
    # into invalid pixels, so they add nothing to the sum).
    # NOTE(review): if an image has no positive depth pixel this is 0/0 and
    # the NaN propagates into `mean_depth` -- confirm that cannot happen.
    num_valid = np.sum(depth_out > 0)
    depth_mean_this_image = np.sum(depth_out) / num_valid

    mean_rgb += rgb_mean_this_image
    mean_depth += depth_mean_this_image
    count += 1

In [None]:
# Turn the accumulated sums into averages over the processed samples.
# NOTE(review): this overwrites the sums in place, so running this cell a
# second time without re-running the accumulation loop divides by `count`
# again.
mean_rgb, mean_depth = mean_rgb / count, mean_depth / count

for stat in (mean_rgb, mean_depth):
    print(stat)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check on a single sample (index 10).
im, de, de_mask = dataset[10]
im_out, depth_out = processor(im, de, de_mask)

# Processed image.
fig_im, ax_im = plt.subplots()
ax_im.imshow(im_out)

# Processed depth with a colorbar; the title reports the mean over strictly
# positive depth pixels, the same per-image statistic used for `mean_depth`.
fig_de, ax_de = plt.subplots()
depth_artist = ax_de.imshow(depth_out)
fig_de.colorbar(depth_artist, ax=ax_de)
sample_mean_depth = np.sum(depth_out) / np.sum(depth_out > 0)
ax_de.set_title(f'mean depth: {sample_mean_depth}')