## Setup

In [1]:
# Import required modules.
import os
import cv2
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow as tf2
import tabulate
from pathlib import Path
from tqdm import tqdm
from pydnet.data import KITTI
from pydnet.models import Pydnet
from IPython.display import HTML, display


# Root directory of the raw KITTI dataset.
KITTI_PATH = Path("..", "data", "mount", "KITTI", "raw_data")

# Text file indexing the test frames of the dataset slice.
SLICE_PATH = Path("..", "data", "slices", "test_files.txt")

# Archive with the ground truth loaded for the evaluation.
GROUND_PATH = Path("..", "data", "slices", "depths.npz")

# Pretrained PyDnet checkpoint restored into the session.
CHECK_PATH = Path("..", "data", "checkpoint", "pydnet", "pydnet")

# Directory receiving the predictions written during inference.
DEST_PATH = Path("..", "reports", "evaluation", "KITTI")

# Upper bound on the depth values taken into account by the evaluation.
MAX_DEPTH = 80.0


# Suppress TensorFlow warning messages.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Run TensorFlow in graph (non-eager) mode: eager execution is switched off
# whenever it is currently active.
if tf1.executing_eagerly():
    tf1.disable_eager_execution()


# Utility functions.
def read_lines(path: Path) -> list:
    """Return every line of the text file at `path`, whitespace-stripped.

    The file must exist; this is asserted before the file is opened.
    """
    assert path.exists()
    with open(path, "r") as handle:
        # Iterating the handle yields the same lines as readlines().
        return [row.strip() for row in handle]

def compute_errors(
    ground: np.ndarray,
    pred: np.ndarray
) -> [float, float, float, float, float, float, float]:
    """Score predicted depths against ground truth depths.

    From https://github.com/mrharicot/monodepth/blob/master/utils/evaluation_utils.py

    Returns, in this order: abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3.
    a1, a2 and a3 are the fractions of elements whose ratio
    max(ground / pred, pred / ground) lies below 1.25, 1.25 ** 2 and
    1.25 ** 3 respectively.
    """
    # Threshold accuracies on the symmetric ratio between both inputs.
    ratio = np.maximum(ground / pred, pred / ground)
    a1, a2, a3 = [np.mean(ratio < 1.25 ** power) for power in (1, 2, 3)]

    # Difference terms shared by the remaining metrics.
    diff = ground - pred
    sq_diff = diff ** 2
    log_diff = np.log(ground) - np.log(pred)

    rmse = np.sqrt(np.mean(sq_diff))
    rmse_log = np.sqrt(np.mean(log_diff ** 2))

    # Relative errors are normalized by the ground truth.
    abs_rel = np.mean(np.abs(diff) / ground)
    sq_rel = np.mean(sq_diff / ground)
    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3

def compute_scale_and_shift(
    pred: np.ndarray,
    target: np.ndarray,
    mask: np.ndarray
) -> [float, float]:
    """Fit the scale and shift aligning `pred` to `target` where `mask` is set.

    From https://gist.github.com/ranftlr/a1c7a24ebb24ce0e2f2ace5bce917022

    Solves the 2x2 normal equations A x = b of the least-squares problem
    sum(mask * (x_0 * pred + x_1 - target) ** 2). Every sum runs over all
    elements of the inputs, so the result is a pair of 0-d arrays
    (x_0, x_1) = (scale, shift). Both stay zero when det(A) is not positive.
    """
    masked_pred = mask * pred

    # Entries of the symmetric system matrix A = [[a_00, a_01], [a_01, a_11]].
    a_00 = np.sum(masked_pred * pred)
    a_01 = np.sum(masked_pred)
    a_11 = np.sum(mask)

    # Right hand side b = [b_0, b_1].
    b_0 = np.sum(masked_pred * target)
    b_1 = np.sum(mask * target)

    scale = np.zeros_like(b_0)
    shift = np.zeros_like(b_1)

    # Solve by Cramer's rule, only when A is positive definite (det > 0).
    det = a_00 * a_11 - a_01 * a_01
    if det > 0:
        scale[...] = (a_11 * b_0 - a_01 * b_1) / det
        shift[...] = (-a_01 * b_0 + a_00 * b_1) / det
    return scale, shift

## Load KITTI and PydNet

In [2]:
# KITTI dataset configured with a 320x640 (h x w) size, the raw data root
# and the slice index file.
dataset = KITTI(dict(
    h=320,
    w=640,
    path=KITTI_PATH,
    slice=SLICE_PATH))

# PydNet configured with the same size and with training disabled.
network = Pydnet(dict(
    h=320,
    w=640,
    is_training=False))

## Start Tensorflow Session

In [3]:
# Build the prediction op: run the network on the dataset batch and apply
# a ReLU so that the output is never negative.
pred = network.forward(dataset.batch)
pred = tf1.nn.relu(pred)

# Setup Tensorflow session and restore checkpoint.
# The saver is created after the network so that it sees the network's
# variables. Global variables and the dataset are initialized first
# (dataset.initializer is presumably the input iterator's init op — confirm
# in pydnet.data.KITTI), then the checkpoint at CHECK_PATH is restored.
save = tf1.train.Saver()
sess = tf1.Session()
sess.run(tf1.global_variables_initializer())
sess.run(dataset.initializer)
save.restore(sess, str(CHECK_PATH))

INFO:tensorflow:Restoring parameters from ../data/checkpoint/pydnet/pydnet


## Run inference on KITTI dataset

In [4]:
# Create the output directory together with any missing parents.
DEST_PATH.mkdir(parents=True, exist_ok=True)

# Read test file indices.
tests = read_lines(SLICE_PATH)

# Run inference on KITTI dataset: one session run per test entry. Each
# prediction is min-max normalized, resized to the resolution of its source
# frame and stored as a 16-bit PNG named after its index.
for i, name in enumerate(tqdm(tests)):
    dep = np.squeeze(sess.run(pred))

    # Min-max normalize to [0, 255]. A constant prediction has a zero range
    # and would divide by zero, so it is stored as an all-zero map instead.
    min_dep = dep.min()
    dep_range = dep.max() - min_dep
    if dep_range > 0:
        norm_dep = (dep - min_dep) / dep_range * 255.0
    else:
        norm_dep = np.zeros_like(dep)

    # The source frame is read only to obtain its height and width.
    # cv2.imread returns None instead of raising when the file is unreadable.
    frame_path = KITTI_PATH / f"{name}.png"
    target = cv2.imread(str(frame_path))
    if target is None:
        raise FileNotFoundError(f"Cannot read KITTI frame: {frame_path}")
    h, w = target.shape[:2]
    norm_dep = cv2.resize(norm_dep, (w, h))

    # Multiplying by 256 spreads the values over the uint16 range
    # (255 * 256 = 65280); the evaluation divides by 256 again when reading.
    # cv2.imwrite reports failure through its return value.
    out_path = DEST_PATH / f"{i:04d}.png"
    if not cv2.imwrite(str(out_path), (norm_dep * 256.0).astype(np.uint16)):
        raise IOError(f"Cannot write prediction: {out_path}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 697/697 [03:36<00:00,  3.22it/s]


## Evaluate model

In [5]:
# Per-image error metrics, one tuple per test entry.
errors = []

# Read test file indices.
tests = read_lines(SLICE_PATH)

# Read ground truth file. The context manager closes the archive as soon as
# the "data" entry has been loaded into memory.
# NOTE(review): allow_pickle=True unpickles the archive content, which can
# execute arbitrary code — only load ground truth files from a trusted source.
with np.load(
        GROUND_PATH,
        fix_imports=True,
        encoding="latin1",
        allow_pickle=True) as archive:
    ground = archive["data"]

# Lower bound applied to the aligned prediction before it is inverted.
disparity_cap = 1.0 / MAX_DEPTH

# Evaluate the stored predictions against the ground truth.
# The prediction is kept in `disp` rather than `pred`, so that the `pred`
# tensor used by the inference cell is not overwritten and that cell can be
# re-run after this one.
for i in tqdm(range(len(tests))):
    target = ground[i]

    # Read the 16-bit PNG unchanged and undo the factor of 256 applied when
    # it was written. cv2.imread returns None when the file is unreadable.
    disp_path = DEST_PATH / f"{i:04d}.png"
    raw = cv2.imread(str(disp_path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Cannot read prediction: {disp_path}")
    disp = raw / 256.0

    # Only ground truth values inside (1e-3, MAX_DEPTH) are evaluated.
    mask = (target > 1e-3) & (target < MAX_DEPTH)

    # Inverse of the ground truth on valid pixels, zero elsewhere.
    target_disp = np.zeros_like(target)
    target_disp[mask] = 1.0 / target[mask]

    # Align the prediction to the inverted ground truth with a least-squares
    # scale and shift, clamp it from below, then invert it back.
    scale, shift = compute_scale_and_shift(disp, target_disp, mask)
    aligned = scale * disp + shift
    aligned[aligned < disparity_cap] = disparity_cap
    aligned_depth = 1.0 / aligned

    errors.append(compute_errors(target[mask], aligned_depth[mask]))

# Print result: every metric averaged over all test entries.
mean = np.array(errors).mean(0)
names = ["abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"]
table = [[name, value] for name, value in zip(names, mean)]
display(HTML(tabulate.tabulate(table, tablefmt='html')))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 697/697 [00:22<00:00, 30.93it/s]


0,1
abs_rel,0.161777
sq_rel,1.28003
rmse,6.1434
rmse_log,0.238976
a1,0.759639
a2,0.926482
a3,0.973382
