In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import json
import cv2
import matplotlib.cm as cm
import math
from PIL import Image


# Load the list of test-frame image paths from a JSON metadata file.
def load_test_frame_files(file):
    """Return the ``file_path`` of every entry in ``meta['frames']``, sorted.

    Args:
        file: Path to a JSON file whose top-level object has a ``'frames'``
            list; each element must be a dict with a ``'file_path'`` key.

    Returns:
        A list of the ``file_path`` values in ascending sort order.
    """
    with open(file, 'r') as handle:
        metadata = json.load(handle)
    # Sorting the paths themselves yields the same order as sorting the frame
    # dicts by their 'file_path' key and then extracting it.
    return sorted(entry['file_path'] for entry in metadata['frames'])

def format_axes(axes):
    """Strip the x and y ticks from every Axes in a (possibly nested) collection.

    Args:
        axes: An iterable of matplotlib Axes. Elements that are themselves
            ``np.ndarray`` (e.g. rows of the grid returned by
            ``plt.subplots``) are descended into recursively.
    """
    for item in axes:
        if type(item) is not np.ndarray:
            item.set_xticks([])
            item.set_yticks([])
            continue
        # Nested array of Axes: recurse into it.
        format_axes(item)
            
def weighted_percentile(x, w, ps, assume_sorted=False):
    """Compute the weighted percentile(s) of a single vector.

    Args:
        x: Values (array-like of any shape; flattened).
        w: Non-negative weights, same number of elements as ``x`` (flattened).
        ps: Percentile or sequence of percentiles in [0, 100].
        assume_sorted: If True, ``x`` is taken to be sorted ascending already
            and no reordering of ``x``/``w`` is performed.

    Returns:
        The interpolated value(s) of ``x`` at the requested weighted
        percentile(s), as returned by ``np.interp``.
    """
    x = np.asarray(x).reshape([-1])
    w = np.asarray(w).reshape([-1])
    # Only reorder when needed; previously `sortidx` was referenced even when
    # assume_sorted=True, which raised NameError.
    if not assume_sorted:
        sortidx = np.argsort(x)
        x, w = x[sortidx], w[sortidx]
    acc_w = np.cumsum(w)
    # Map percentiles onto the cumulative-weight axis and interpolate x.
    return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x)
            
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a single-channel image and a weighting according to a colormap.

    Args:
        value: A single-channel image (e.g. a 2D array).
        weight: A weight map, in [0, 1]; only used to compute automatic bounds.
        colormap: A colormap function returning RGBA in the last axis. If
            falsy, ``value`` must already have 3 channels in its last axis.
        lo: The lower bound to use when rendering; if None, use a percentile.
        hi: The upper bound to use when rendering; if None, use a percentile.
        percentile: What percentile of the value map to crop to when
            automatically generating `lo` and `hi`. Depends on `weight` as well
            as `value`.
        curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
            before the rest of visualization. Good choices: x, 1/(x+eps),
            log(x+eps).
        modulus: If not None, mod the normalized value by `modulus`. Use (0, 1].
            If `modulus` is not None, `lo`, `hi` and `percentile` have no effect.
        matte_background: Accepted for interface compatibility but currently
            ignored: this implementation does not matte over a checkerboard.

    Returns:
        A colormap rendering (the first three channels of the colormap output).
    """
    # Only compute the weighted percentiles when at least one bound is missing.
    if lo is None or hi is None:
        # Identify the values that bound the middle of `value` according to `weight`.
        lo_auto, hi_auto = weighted_percentile(
            value, weight, [50 - percentile / 2, 50 + percentile / 2])
        eps = np.finfo(np.float32).eps
        # Compare against None explicitly: `lo or ...` used to discard an
        # explicit bound of 0.
        if lo is None:
            lo = lo_auto - eps
        if hi is None:
            hi = hi_auto + eps

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    if modulus:
        # Wrap the values around if requested.
        value = np.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1] (works for either ordering of lo/hi).
        value = np.nan_to_num(
            np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1))

    if colormap:
        # Drop the alpha channel, whatever the rank of `value`.
        colorized = colormap(value)[..., :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return colorized

def depth_curve_fn(x):
    """Return ``-log(x + eps)`` with eps = float32 machine epsilon.

    Suitable as the ``curve_fn`` argument of ``visualize_cmap``; the epsilon
    keeps the log finite at ``x == 0``.
    """
    return -np.log(x + np.finfo(np.float32).eps)

In [None]:
# Root directories for the experiment outputs and the ground-truth data.
# The defaults are the original absolute paths; override them with the
# EXPERIMENT_DIR / GROUNDTRUTH_DIR environment variables instead of editing
# the notebook.
EXPERIMENT_DIR = os.environ.get('EXPERIMENT_DIR', '/mnt/res_nas/silvanweder/experiments')
GROUNDTRUTH_DIR = os.environ.get(
    'GROUNDTRUTH_DIR', '/mnt/res_nas/silvanweder/datasets/object-removal-custom-clean')

In [None]:
# Show every non-hidden entry of the experiment root so one can be picked below.
print('Available Experiments:')
visible_experiments = [name for name in os.listdir(EXPERIMENT_DIR) if not name.startswith('.')]
for name in visible_experiments:
    print('\t -', name)

In [None]:
# Select one of the experiments listed above.
experiment = 'final_tests_real'
# Mask variant: either '_real' or '_synthetic', matching the masks the
# experiment was run with. It selects the 'lama_images_output<suffix>' and
# 'masks<suffix>' folders used further down.
suffix = '_real'
# Fail fast on a typo instead of a later FileNotFoundError.
assert suffix in ('_real', '_synthetic'), f"suffix must be '_real' or '_synthetic', got {suffix!r}"

In [None]:
# List the sequence directories of the selected experiment, skipping hidden
# entries (names starting with '.') so they are not offered as sequences.
print('Available Sequences:')
for sc in os.listdir(os.path.join(EXPERIMENT_DIR, experiment)):
    if sc.startswith('.'):
        continue
    print(f'\t- {sc}')

In [None]:
# Pick one of the sequences listed above. It names a subdirectory of both the
# selected experiment directory and the ground-truth directory.
sequence = '003'

In [None]:
# Per-sequence folders: the experiment outputs and the matching ground truth.
experiment_path, groundtruth_path = (
    os.path.join(EXPERIMENT_DIR, experiment, sequence),
    os.path.join(GROUNDTRUTH_DIR, sequence),
)

In [None]:
# When True, every displayed image is rotated with np.rot90(..., k=-1),
# i.e. 90 degrees clockwise, before plotting.
rotate = True

In [None]:
# For every inpainted frame, show side by side: the ground-truth image, the
# ground truth with the masked region painted white, and the LaMa output.
# Output files are expected to be named '<frame_id>_...'; the frame id locates
# the matching ground-truth image and mask.


def _read_image(path):
    """Read an image file into a NumPy array, closing the file handle afterwards."""
    with Image.open(path) as im:
        # np.array copies the pixels, so the array stays valid after close.
        return np.array(im)


lama_dir = os.path.join(experiment_path, f'lama_images_output{suffix}')
image_files = sorted(os.listdir(lama_dir))
for imf in image_files:
    frame_id = imf.split('_')[0]

    img_path = os.path.join(groundtruth_path, 'images', f'{frame_id}.jpg')
    img_inp_path = os.path.join(lama_dir, imf)
    mask_path = os.path.join(groundtruth_path, f'masks{suffix}', f'{frame_id}.npy')

    # cv2.resize takes (width, height): the ground truth becomes 192 rows x 256 cols.
    img = cv2.resize(_read_image(img_path), (256, 192))
    # NOTE(review): assumes the stored mask is already 192x256 and that the
    # LaMa output needs no resizing — confirm against the data.
    mask = np.load(mask_path)
    img_inp = _read_image(img_inp_path)

    img_masked = img.copy()
    img_masked[mask == 1] = (255, 255, 255)

    if rotate:
        img, img_masked, img_inp = [np.rot90(a, k=-1) for a in (img, img_masked, img_inp)]

    fig, ax = plt.subplots(1, 3)
    panels = (img, img_masked, img_inp)
    titles = ('ground truth', 'masked', 'inpainted (LaMa)')
    for axis, panel, title in zip(ax, panels, titles):
        axis.imshow(panel)
        axis.set_title(title)
        # plt.axis('off') only affected the last subplot; hide every axis.
        axis.axis('off')
    fig.suptitle(frame_id)
    plt.show()
    # Close this loop iteration's figure so figures do not accumulate.
    plt.close(fig)