In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import json
import cv2
import matplotlib.cm as cm
import math
from PIL import Image


# Load the (sorted) list of frame file paths from a transforms JSON file.
def load_test_frame_files(file):
    """Return every frame's `file_path` from a transforms JSON file, in ascending order.

    Args:
      file: path to a JSON file with a top-level 'frames' list whose entries
        each carry a 'file_path' key.

    Returns:
      List of the frames' `file_path` values, sorted.
    """
    with open(file, 'r') as handle:
        frame_entries = json.load(handle)['frames']
    # Sorting the extracted paths gives the same order as sorting the entries by path.
    return sorted(entry['file_path'] for entry in frame_entries)

def format_axes(axes):
    """Remove x and y ticks from every Axes in a (possibly nested) collection.

    Elements that are exactly numpy arrays (e.g. a grid of Axes) are walked
    recursively; every other element is treated as an Axes-like object.
    """
    for item in axes:
        if type(item) is not np.ndarray:
            item.set_xticks([])
            item.set_yticks([])
            continue
        format_axes(item)
            
def weighted_percentile(x, w, ps, assume_sorted=False):
    """Compute the weighted percentile(s) of a single vector.

    Args:
      x: values; flattened to 1D.
      w: per-value weights with as many elements as `x`; flattened to 1D.
      ps: a percentile or sequence of percentiles on a 0-100 scale.
      assume_sorted: if True, `x` is already in ascending order (with `w`
        aligned to it) and the sort is skipped.

    Returns:
      The value(s) of `x` linearly interpolated at the requested weighted
      percentile(s) of the cumulative weight.
    """
    x = np.reshape(x, [-1])
    w = np.reshape(w, [-1])
    if not assume_sorted:
        # Reorder both arrays together; this must stay inside the branch, because
        # `sortidx` only exists when a sort was actually performed.
        sortidx = np.argsort(x)
        x, w = x[sortidx], w[sortidx]
    acc_w = np.cumsum(w)
    # Map each percentile onto the cumulative-weight axis and interpolate values.
    return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x)
            
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a single-channel image and a weighting according to some colormap.

    Args:
    value: A single-channel image (any rank when `colormap` is given).
    weight: A weight map, in [0, 1], with as many elements as `value`.
    colormap: A colormap function returning RGBA in the last axis, or a falsy
      value, in which case `value` must already be an (H, W, 3) image.
    lo: The lower bound to use when rendering; if None, use a percentile.
      An explicit 0 is honored.
    hi: The upper bound to use when rendering; if None, use a percentile.
      An explicit 0 is honored.
    percentile: What percentile of the value map to crop to when automatically
      generating `lo` and `hi`. Depends on `weight` as well as `value`.
    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
    matte_background: Accepted for interface compatibility; this implementation
      does not use it.

    Returns:
    A colormap rendering: `colormap(value)` with the alpha channel dropped, or
    `value` itself when `colormap` is falsy.
    """
    # Derive missing bounds from the weighted percentiles of `value`. Use explicit
    # `is None` checks: `lo or ...` would silently replace a caller's lo=0.
    # The percentiles are only computed when at least one bound needs them.
    if lo is None or hi is None:
        lo_auto, hi_auto = weighted_percentile(
            value, weight, [50 - percentile / 2, 50 + percentile / 2])
        eps = np.finfo(np.float32).eps
        if lo is None:
            lo = lo_auto - eps
        if hi is None:
            hi = hi_auto + eps

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = np.mod(value, modulus) / modulus
    else:
        # Otherwise scale to [0, 1]. `np.minimum`/`np.abs` keep this correct when
        # the curve reverses the order of `lo` and `hi` (e.g. a negative log).
        value = np.nan_to_num(
            np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1))

    if colormap:
        # Drop the alpha channel; `...` works regardless of the input's rank.
        colorized = colormap(value)[..., :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return colorized

def depth_curve_fn(x):
    """Curve depth values to negative log-depth; the float32 epsilon guards against log(0)."""
    offset = np.finfo(np.float32).eps
    return -np.log(x + offset)


In [None]:
# Set the paths to the experiments and groundtruth data here. Each can also be
# overridden with an environment variable of the same name, so the notebook runs
# on other machines without editing this cell.
EXPERIMENT_DIR = os.environ.get('EXPERIMENT_DIR', '/mnt/res_nas/silvanweder/experiments')  # where you saved your experiments
GROUNDTRUTH_DIR = os.environ.get('GROUNDTRUTH_DIR', '/mnt/res_nas/silvanweder/datasets/object-removal-custom-clean')  # where you saved your data

In [None]:
# List the experiments under EXPERIMENT_DIR in sorted order (matching the
# sequence listing below), skipping hidden entries such as '.ipynb_checkpoints'.
print('Available Experiments:')
for exp in sorted(os.listdir(EXPERIMENT_DIR)):
    if not exp.startswith('.'):
        print('\t -', exp)

In [None]:
# Pick one of the experiments listed above.
experiment = "final_tests_real"

In [None]:
suffix = "_real"  # mask variant: '_real' or '_synthetic', depending on which masks you used

In [None]:
# Show the sequences available for the chosen experiment, in sorted order.
print('Available Sequences:')
for seq_name in sorted(os.listdir(os.path.join(EXPERIMENT_DIR, experiment))):
    print('\t- ' + str(seq_name))

In [None]:
# Choose one of the sequences listed above.
sequence = "002"

In [None]:
# Visualization options.
rotate = True                  # rotate every panel 90 degrees clockwise before plotting
eval_run = "train_test_preds"  # prediction run to show: 'test_preds' or 'train_test_preds'

In [None]:
# Per-sequence directories: experiment outputs and the matching groundtruth.
experiment_path, groundtruth_path = (
    os.path.join(EXPERIMENT_DIR, experiment, sequence),
    os.path.join(GROUNDTRUTH_DIR, sequence),
)

In [None]:
# Map each evaluation run to the transforms file that lists its frames.
TRANSFORMS_FILES = {
    'test_preds': 'transforms_test.json',
    'train_test_preds': 'transforms_train.json',
}
if eval_run not in TRANSFORMS_FILES:
    raise ValueError(f'Invalid eval run {eval_run}')
test_frame_files = load_test_frame_files(os.path.join(experiment_path, TRANSFORMS_FILES[eval_run]))

PANEL_TITLES = ['Groundtruth Image', 'Masked Image', 'Inpainted Image', 'Rendered Image', 'Rendered Depth']
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported lookup. Fetch it once rather than per frame.
DEPTH_CMAP = plt.get_cmap('turbo')


def _read_image(path, mode=None):
    """Load an image file as a numpy array, optionally converting its PIL mode first.

    The file is opened in a `with` block so its handle is released immediately.
    """
    with Image.open(path) as img:
        if mode is not None:
            img = img.convert(mode)
        return np.asarray(img)


def _swap_ext(path, new_ending):
    """Replace the file extension of `path` (e.g. '.jpg') with `new_ending`."""
    return os.path.splitext(path)[0] + new_ending


for i, frame in enumerate(test_frame_files):
    # Predictions are indexed by position in the sorted frame list.
    frame_idx = str(i).zfill(3)

    # Groundtruth image; force RGB so the 3-channel mask fill below also works
    # for grayscale or RGBA files.
    image_gt = _read_image(os.path.join(groundtruth_path, 'images', os.path.basename(frame)), mode='RGB')
    image_gt = cv2.resize(image_gt, (256, 192))

    # Masked input: pixels where the mask equals 1 are painted white.
    mask_file = _swap_ext(frame.replace('images', f'masks{suffix}'), '.npy')
    input_mask = np.load(os.path.join(groundtruth_path, mask_file))
    image_input = image_gt.copy()
    image_input[input_mask == 1] = (255, 255, 255)

    # LaMa inpainting output; only the last two path components are used.
    inpainted_file = _swap_ext(frame.replace('images', f'lama_images_output{suffix}'), '_mask001.png')
    image_inpainted = _read_image(os.path.join(experiment_path, *inpainted_file.split('/')[-2:]))

    # Rendered color and raw depth for this frame.
    image_est = _read_image(os.path.join(experiment_path, eval_run, f'color_{frame_idx}.png'))
    depth_raw = _read_image(os.path.join(experiment_path, eval_run, f'distance_mean_{frame_idx}.tiff'))

    # Colorize the depth map (negative log-depth through the turbo colormap).
    depth_est = (visualize_cmap(depth_raw, np.ones_like(depth_raw), DEPTH_CMAP,
                                curve_fn=depth_curve_fn) * 255).astype(np.uint8)

    # Rotate all images.
    if rotate:
        image_gt, image_input, image_inpainted, image_est, depth_est = [
            cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
            for img in (image_gt, image_input, image_inpainted, image_est, depth_est)]

    panels = (image_gt, image_input, image_inpainted, image_est, depth_est)
    fig, axes = plt.subplots(1, len(panels), figsize=(5 * 12, 16))
    for ax, panel, title in zip(axes, panels, PANEL_TITLES):
        ax.imshow(panel)
        ax.set_title(title, fontsize=60)
    plt.tight_layout()
    plt.show()
    # Close this figure explicitly so figures don't accumulate across iterations.
    plt.close(fig)
