In [None]:
import copy
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch

from common import load_network, predict
from data import SLSDataset

def disable_ticks(ax):
    """Remove all tick marks from the x- and the y-axis of `ax`.

    `ax` must provide `get_xaxis()` and `get_yaxis()`; each returned axis
    object gets its ticks replaced by an empty list.
    """
    for get_axis in (ax.get_xaxis, ax.get_yaxis):
        get_axis().set_ticks([])

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU
# so the notebook still executes on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Restore the denoising network (together with the checkpoint and its config)
# from the checkpoint file and place it on the selected device
checkpoint_path = Path("denoising.ckpt")
network, checkpoint, config = load_network(
    checkpoint_path,
    device=device,
)

In [None]:
# Open the denoising test split; with preload=False the samples are
# presumably read lazily on access (TODO confirm against SLSDataset)
dataset_path = Path("./data/ABC_440/test")
dataset_config = dict(base_dir=dataset_path, transform=None)
dataset = SLSDataset(**dataset_config, preload=False)
print(f"The dataset contains {len(dataset)} samples")

In [None]:
# Calibration used by the Laplace baseline (hard-coded for the time being)
K = np.array([[14700.7978515625, 0.0, 3230.5765901904233],
              [0.0, 14711.431640625, 2422.6457405752153],
              [0.0, 0.0, 1.0]])
# Scale the first two rows by 0.5; the last row stays [0, 0, 1]
K[:2] *= 0.5

# Identity matrix and zero vector (presumably camera rotation and translation)
R = np.identity(3)
t = np.zeros((3,))

In [None]:
from baselines import bilateral_depth_filter, laplace_depth_filter

# Run every sample through the CNN and through both classical baselines.
# The results are stored on the sample itself; this takes a while.
samples = []
for sample in dataset:
    sample['depth_ml'] = predict(network, sample, device)
    sample['depth_bilateral'] = bilateral_depth_filter(
        sample['depth'],
        2.2583333333333333,  # both filter parameters were found with a parameter sweep
        0.005,
    )
    sample['depth_laplace'] = laplace_depth_filter(sample['depth'], K, R, t)
    samples.append(sample)

In [None]:
# Plot a figure with one row per sample:
#   column 0    : error map of the raw scan
#   columns 1-2 : error map + error histogram of the CNN denoising
#   columns 3-4 : error map + error histogram of the bilateral filter
#   columns 5-6 : error map + error histogram of the Laplace filter

scale = (1000.0, "mm")  # (factor applied to the depth errors, unit label shown on the axes)
fontsize = 26

num_cols = 7
num_rows = len(samples)
# squeeze=False keeps `axs` two-dimensional, so axs[i][j] also works for a single sample
fig, axs = plt.subplots(num_rows, num_cols, figsize=(26.1, num_rows*3 + 0.7),
                        constrained_layout=True, squeeze=False)

# Not used by this figure; kept because other cells may rely on them
depth_min = 0.65
depth_max = 1.15

# Errors are displayed (and histogrammed) within +/- error_threshold
error_threshold = 0.001 * scale[0]
error_min = -error_threshold
error_max = error_threshold

bins = np.linspace(-error_threshold, error_threshold, 100)


def _show_error_map(ax, error_map):
    """Draw a signed error map on `ax` without ticks and return the image handle."""
    image = ax.imshow(error_map, vmin=error_min, vmax=error_max, cmap='bwr')
    disable_ticks(ax)
    return image


def _show_error_histograms(ax, error_scan, error_denoised, valid_mask, denoised_label, label_x_axis):
    """Overlay the error histograms of the scan and of one denoising result on `ax`.

    Only pixels inside `valid_mask` whose absolute error is below
    `error_threshold` are counted. The x-axis is labelled with the unit only
    if `label_x_axis` is true; otherwise its ticks are removed. The y-axis
    ticks are always removed.
    """
    ax.axvline(x=0, color='green', linestyle='--')
    for error_map, label in ((error_scan, 'Scan'), (error_denoised, denoised_label)):
        selected = valid_mask & (np.abs(error_map) < error_threshold)
        ax.hist(error_map[selected].ravel(), bins=bins, alpha=0.5, label=label)
    if label_x_axis:
        ax.tick_params(labelsize=fontsize)
        ax.set_xlabel(scale[1], fontsize=fontsize)
    else:
        ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])


for i, sample in enumerate(samples):
    is_last_row = i == num_rows - 1

    depth_groundtruth = copy.deepcopy(sample['target'])
    depth_reconstructed = copy.deepcopy(sample['depth'])
    # Pixels with a positive scan depth; all error maps are zeroed outside this mask
    mask_reconstructed = depth_reconstructed > 0

    # Error map of the reconstruction/scan
    error_reconstructed = (depth_reconstructed - depth_groundtruth) * scale[0]
    error_reconstructed[~mask_reconstructed] = 0
    # Keep the image handle: the colorbar below must be built from an image of
    # this row, which also makes the cell work when there is only one sample
    im = _show_error_map(axs[i][0], error_reconstructed)

    if is_last_row:
        cbar = fig.colorbar(im, ax=axs[i][0], aspect=5, pad=0.025, location='bottom')
        cbar.ax.set_xlabel(scale[1], fontsize=fontsize)
        cbar.ax.tick_params(labelsize=fontsize)

    # One (error map, histogram) column pair per denoising method
    denoised_depths = (
        ('Denoised (ML)', sample['depth_ml'].numpy()),
        ('Denoised (Bilateral)', sample['depth_bilateral']),
        ('Denoised (Laplace)', sample['depth_laplace']),
    )
    for j, (denoised_label, depth_denoised) in enumerate(denoised_depths):
        error_denoised = (depth_denoised - depth_groundtruth) * scale[0]
        error_denoised[~mask_reconstructed] = 0

        ax_idx = 1 + 2 * j
        _show_error_map(axs[i][ax_idx], error_denoised)
        _show_error_histograms(axs[i][ax_idx + 1], error_reconstructed, error_denoised,
                               mask_reconstructed, denoised_label, is_last_row)

plt.savefig("denoising_error_comparison.png", facecolor='white')
plt.show()