In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch

from common import load_network, predict, crop, export_point_cloud

def disable_ticks(ax):
    """Remove all x- and y-axis ticks (and hence tick labels) from a plot axis object."""
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])

# Use the first GPU when available; fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the network from a checkpoint; load_network returns the network together with
# the loaded checkpoint and its config.
checkpoint_path = Path("denoising.ckpt")
# The path is relative to the kernel's working directory, so fail early with the
# fully resolved location if the file is missing.
if not checkpoint_path.is_file():
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path.resolve()}")
network, checkpoint, config = load_network(checkpoint_path, device=device)

In [None]:
# Camera calibration (hard-coded for now). The intrinsics are scaled by 0.5,
# presumably because the depth maps are at half the calibrated resolution — TODO confirm.
K = np.array([[14700.7978515625, 0.0, 3230.5765901904233],
              [0.0, 14711.431640625, 2422.6457405752153],
              [0.0, 0.0, 1.0]])
K[:2] *= 0.5  # halve fx, fy, cx, cy; the homogeneous row stays [0, 0, 1]

# Identity extrinsics: no rotation and no translation.
R = np.eye(3)
t = np.zeros(3)

In [None]:
# Load the Rook sample: ground-truth depth, colour image, simulated and real scans.
# Depth maps get a trailing channel axis, i.e. shape (H, W, 1).
rook_dir = Path("./data/rook")
gt = np.load(rook_dir / "dm_gt.npy")[:, :, np.newaxis]
color = plt.imread(rook_dir / "rook_match_white.png")
scan_sim = np.load(rook_dir / "dm_sim.npy")[:, :, np.newaxis]
scan_real = np.load(rook_dir / "dm_scan.npy")[:, :, np.newaxis]

# Perform three normalization steps
# 1) Older Rook ground truth is slightly too wide (usually by two pixels), so crop it
#    symmetrically. Newer data has no excess (margin == 0) and is kept whole; an explicit
#    end index is used because a `[m//2:-m//2]` slice would be empty for m == 0.
margin = gt.shape[1] % 3232
assert margin % 2 == 0, f"cannot symmetrically crop an odd margin of {margin} px"
gt = gt[:, margin // 2 : gt.shape[1] - margin // 2]

# 2) Make all images/maps the same size as the ground truth. cv2 expects
#    dsize=(width, height) and drops a trailing singleton channel, hence the re-added
#    axis. Depth maps use nearest-neighbour interpolation so invalid (zero) pixels are
#    not blended with valid depths into spurious in-between values.
target_size = (gt.shape[1], gt.shape[0])
color = cv2.resize(color, dsize=target_size)
scan_sim = cv2.resize(scan_sim, dsize=target_size, interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]
scan_real = cv2.resize(scan_real, dsize=target_size, interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]

# 3) Scale depth to meters if it looks like millimeters (max above 100).
#    The division is out-of-place so integer maps (e.g. uint16 millimeters) become
#    floats instead of raising on in-place true division. The ground truth gets the
#    same check so every map ends up in the same unit.
def _depth_to_meters(depth):
    """Return `depth` in meters: divided by 1000 if its maximum exceeds 100, else unchanged."""
    return depth / 1000.0 if depth.max() > 100 else depth

gt = _depth_to_meters(gt)
scan_sim = _depth_to_meters(scan_sim)
scan_real = _depth_to_meters(scan_real)

In [None]:
# Denoise the real scan. Autograd is disabled: inference needs no gradients, this
# saves memory, and later cells call `.numpy()` on the result, which raises for
# tensors that require grad.
with torch.no_grad():
    denoised_real = predict(network, {'depth': scan_real}, device)

# Later cells subtract the prediction from the (H, W, 1) ground truth and mask it
# with a scan_real-shaped mask, so the shapes must agree.
assert tuple(denoised_real.shape) == scan_real.shape, (
    f"unexpected prediction shape {tuple(denoised_real.shape)}, expected {scan_real.shape}")

In [None]:
# Plot a figure: signed depth error of the raw ("Reconstructed") and denoised scans
# against the ground truth, where denoising reduced the error, and error histograms.

scale = (1000.0, "mm")  # display unit: (factor from meters, axis label)
fontsize = 46
roi = [0, 2000, 1000, 2250]  # crop window passed to crop() — layout defined in common.crop

# Symmetric colour / histogram range of +/- 1.5 mm.
error_threshold = 0.0015 * scale[0]
error_min = -error_threshold
error_max = error_threshold

# Only pixels where the real scan has a non-zero depth are evaluated.
mask_reconstructed = scan_real > 0
denoised_depth = denoised_real.numpy()

error_reconstructed = (scan_real - gt) * scale[0]
error_reconstructed[~mask_reconstructed] = 0

error_denoised = (denoised_depth - gt) * scale[0]
error_denoised[~mask_reconstructed] = 0

abs_error_reconstructed = np.abs(error_reconstructed)
abs_error_denoised = np.abs(error_denoised)

# |denoised error| / |reconstructed error|, defined as 0 where the reconstructed error
# is 0. np.divide(where=...) skips those pixels instead of emitting divide-by-zero /
# invalid-value RuntimeWarnings.
error_ratio = np.divide(
    abs_error_denoised,
    abs_error_reconstructed,
    out=np.zeros_like(abs_error_reconstructed,
                      dtype=np.result_type(abs_error_denoised, abs_error_reconstructed)),
    where=abs_error_reconstructed != 0,
)

fig, axs = plt.subplots(1, 4, figsize=(32, 11), constrained_layout=True)
for image_ax in axs[:3]:
    disable_ticks(image_ax)  # image panels need no pixel-coordinate ticks

axs[0].set_title("Reconstructed", fontsize=fontsize)
im = axs[0].imshow(crop(error_reconstructed, roi), vmin=error_min, vmax=error_max, cmap='bwr')

# Panels 0 and 1 share vmin/vmax/cmap, so this single left-hand colour bar covers both.
cbar = fig.colorbar(im, ax=[axs[0]], location='left', aspect=30, pad=0.04)
cbar.ax.set_ylabel(scale[1], fontsize=fontsize)
cbar.ax.tick_params(labelsize=fontsize)

axs[1].set_title("Denoised", fontsize=fontsize)
im = axs[1].imshow(crop(error_denoised, roi), vmin=error_min, vmax=error_max, cmap='bwr')

# True where denoising reduced the absolute error.
axs[2].set_title("Improved Regions", fontsize=fontsize)
im = axs[2].imshow(crop(abs_error_denoised < abs_error_reconstructed, roi), cmap='viridis')

# Histograms of the in-range errors over the evaluated pixels.
bins = np.linspace(-error_threshold, error_threshold, 100)
axs[3].hist(error_reconstructed[mask_reconstructed & (abs_error_reconstructed < error_threshold)].ravel(),
            bins=bins, alpha=0.5, label='Reconstruction')
axs[3].hist(error_denoised[mask_reconstructed & (abs_error_denoised < error_threshold)].ravel(),
            bins=bins, alpha=0.5, label='Denoised')
axs[3].set_xlabel(scale[1], fontsize=fontsize)
axs[3].axvline(x=0, color='green', linestyle='--')
axs[3].tick_params(labelsize=fontsize)
axs[3].get_yaxis().set_ticks([])
axs[3].legend(fontsize=fontsize-10, loc='upper left', bbox_to_anchor=(0.5, 1.0))

fig.savefig("generalization_rook.png", facecolor='white')
plt.show()

In [None]:
# Export ground truth, raw scan and denoised prediction as PLY point clouds, all with
# the same intrinsics K and pose (R, t).
output_dir = Path("./out/generalization/rook")
output_dir.mkdir(parents=True, exist_ok=True)

def _export_ply(file_name, depth_map):
    """Write `depth_map` to output_dir/file_name via export_point_cloud using K, R, t."""
    target = output_dir / file_name
    export_point_cloud(str(target), depth_map, K, R, t)

_export_ply("gt.ply", gt)
_export_ply("scan.ply", scan_real)
_export_ply("denoised.ply", denoised_real.numpy())