In [None]:
import copy
import cv2
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch

from data import SLSDataset

def disable_ticks(ax):
    """Hide all tick marks (and hence tick labels) on the x and y axes of `ax`."""
    for get_axis in (ax.get_xaxis, ax.get_yaxis):
        get_axis().set_ticks([])

# Default compute device for this notebook (first CUDA GPU)
device = torch.device('cuda', 0)

In [None]:
# Load the dataset split to inspect.
# NOTE(review): this cell was labelled as loading the "test" dataset, but the path points
# at the `train` split — confirm which split is intended.
dataset_path = Path("./data/ABC_440/train")
dataset_config = dict(
    base_dir=dataset_path,
    color_dir_name='.',
    color_file_pattern='*_cam_w*',
    ambient_dir_name='.',
    ambient_file_pattern='*_cam_a*',
    transform=None,
)
dataset = SLSDataset(**dataset_config, preload=False)
print(f"The dataset contains {len(dataset)} samples")

In [None]:
# Indices into `dataset` of the samples we would like to inspect
sample_indices = [40]

samples = [dataset[idx] for idx in sample_indices]

In [None]:
# Add the Rook scan to the set of samples
rook_dir = Path("./data/rook")
sample = {}
sample['target'] = np.load(rook_dir / "dm_gt.npy")[:, :, np.newaxis]
sample['color'] = plt.imread(rook_dir / "rook_match_white.png")
sample['depth'] = np.load(rook_dir / "dm_scan.npy")[:, :, np.newaxis]

# Perform two normalization steps
# 1) Make all images/maps the same size as the ground truth (cv2 expects dsize as (width, height)).
#    Depth is resized with nearest-neighbour interpolation so invalid (0) pixels are not blended
#    with valid depths, which would create spurious depth values along hole boundaries.
#    cv2.resize drops the singleton channel of an (H, W, 1) array, hence the re-added axis.
# 2) Scale the depth to meters if it is in millimeters (values > 100 are taken to mean mm)
target_size = sample['target'].shape[:2][::-1]
sample['color'] = cv2.resize(sample['color'], dsize=target_size)
sample['depth'] = cv2.resize(sample['depth'], dsize=target_size,
                             interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]
if sample['depth'].max() > 100:
    # Out-of-place division: an in-place `/=` raises for integer depth arrays (e.g. uint16 mm)
    sample['depth'] = sample['depth'] / 1000.0
# NOTE(review): only the scan is unit-converted; 'target' is assumed to already be in meters — confirm.

# Make the Rook data square by center-cropping every map to its shorter side.
# Explicit start/stop slicing stays correct when no crop is needed and when the
# width/height difference is odd (a `[xcrop:-xcrop]` slice is empty for xcrop == 0).
for key, image in list(sample.items()):
    height, width = image.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    sample[key] = image[top:top + side, left:left + side]

# Appends to `samples` built by the selection cell; re-running only this cell adds the Rook again.
samples.append(sample)

In [None]:
# Plot a figure with one row per sample:
# color image | reconstructed depth | signed depth error | error histogram

scale = (1000.0, "mm")  # (factor, unit label) used to display depth differences in millimeters
fontsize = 26

num_cols = 4
num_rows = len(samples)
# squeeze=False always returns a 2D (num_rows, num_cols) array of axes,
# so row indexing works without special-casing a single row.
fig, axs = plt.subplots(num_rows, num_cols, figsize=(23.5, num_rows * 5.75 + 1),
                        constrained_layout=True, squeeze=False)

# Fixed display range (m) for the depth column
depth_min = 0.8
depth_max = 1.0

# Errors are displayed and histogrammed within +/- error_threshold (display units)
error_threshold = 0.002 * scale[0]
error_min = -error_threshold
error_max = error_threshold
bins = np.linspace(error_min, error_max, 100)

for i, sample in enumerate(samples):
    is_first_row = i == 0
    is_last_row = i == num_rows - 1
    ax_color, ax_depth, ax_error, ax_hist = axs[i]

    depth_groundtruth = sample['target']
    depth_reconstructed = sample['depth']
    # Only compare pixels where both maps hold a depth (> 0); elsewhere the error is
    # not meaningful and is shown as 0 instead of as a saturated error.
    mask_valid = (depth_reconstructed > 0) & (depth_groundtruth > 0)

    # Signed error in display units. The subtraction creates a new array,
    # so zeroing invalid pixels does not modify the sample itself.
    error_reconstructed = (depth_reconstructed - depth_groundtruth) * scale[0]
    error_reconstructed[~mask_valid] = 0

    for ax in (ax_color, ax_depth, ax_error):
        disable_ticks(ax)

    if is_first_row:
        ax_color.set_title("Color Image", fontsize=fontsize)
        ax_depth.set_title("Depth", fontsize=fontsize)
        ax_error.set_title("Depth Error", fontsize=fontsize)
        ax_hist.set_title("Error Distribution", fontsize=fontsize)

    ax_color.imshow(sample['color'])

    # Ground truth goes towards 0 at the borders, so we cannot use its minimum for a reasonable value range
    im_depth = ax_depth.imshow(depth_reconstructed, cmap='RdBu_r', vmin=depth_min, vmax=depth_max)
    im_error = ax_error.imshow(error_reconstructed, vmin=error_min, vmax=error_max, cmap='bwr')

    # All rows use the same color limits, so one colorbar per column below the last row suffices
    if is_last_row:
        for im, ax, unit in ((im_depth, ax_depth, 'm'), (im_error, ax_error, scale[1])):
            cbar = fig.colorbar(im, ax=ax, aspect=10, pad=0.025, location='bottom')
            cbar.ax.set_xlabel(unit, fontsize=fontsize)
            cbar.ax.tick_params(labelsize=fontsize)

    in_range = mask_valid & (np.abs(error_reconstructed) < error_threshold)
    ax_hist.hist(error_reconstructed[in_range].ravel(), bins=bins, alpha=0.5)
    if is_last_row:
        ax_hist.set_xlabel(scale[1], fontsize=fontsize)
    else:
        ax_hist.get_xaxis().set_ticks([])
    ax_hist.axvline(x=0, color='green', linestyle='--')
    ax_hist.tick_params(labelsize=fontsize)
    ax_hist.get_yaxis().set_ticks([])

fig.savefig("analysis_reconstruction_errors.png", facecolor='white')
plt.show()