In [1]:
import os

import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import skimage.io

from tqdm import tqdm

In [2]:
# Experiment whose predictions are compared against SynSin; these three values
# form the prediction path used by the plotting loop.
exp_name = 'realestate10k_dscale2_stride4ft_lowerL1_200'
n_iter = 50000  # presumably the training iteration of the checkpoint; used as a path component
split = 'realestate10k_test'

# Output directory for the per-example PDF figures. Set the PSNR_FIG_DIR
# environment variable to write elsewhere; the default is the original
# author's workspace path, which will not exist on other machines.
save_fig_dir = os.environ.get(
    'PSNR_FIG_DIR',
    '/private/home/ronghanghu/workspace/mmf_nr/save/paper_fig/realestate10k_psnr_new',
)
os.makedirs(save_fig_dir, exist_ok=True)

In [3]:
def plot_psnr(input, pred, gt, method="", compare=False):
    """Draw the GT view, a prediction and its squared-error map (with PSNR).

    Panels are drawn into the current pyplot figure on a grid 4 columns wide.

    Args:
        input: source view. Not drawn here; kept so this helper shares its
            leading arguments with ``plot_psnr_compare``.
        pred: predicted target view; array of the same shape as ``gt``.
        gt: ground-truth target view.
        method: label placed under the prediction panel.
        compare: if True, use a 2x4 grid (GT in slot 6, prediction in slot 7,
            error map in slot 3), leaving slots 2, 4 and 8 free for
            ``plot_psnr_compare``. If False, use a single 1x4 row
            (GT in slot 2, prediction in slot 3, error map in slot 4).

    Note:
        PSNR is computed with a peak value of 1.0, so ``pred``/``gt`` are
        assumed to be floats in [0, 1] (the calling cell converts images with
        ``skimage.img_as_float32``).
    """
    error = (pred - gt) ** 2
    mse_err = np.mean(error)
    psnr = 10 * np.log10(1 / mse_err)

    if compare:
        n_row, gt_slot, pred_slot, err_slot = 2, 6, 7, 3
    else:
        # Slots 6 and 7 do not exist on a 1x4 grid (plt.subplot raises
        # ValueError for them), so lay the three panels along the single row.
        n_row, gt_slot, pred_slot, err_slot = 1, 2, 3, 4

    plt.subplot(n_row, 4, gt_slot)
    plt.imshow(gt)
    plt.xlabel(f"GT target view", fontsize=32, fontname='serif')
    turn_off_axis_color()

    plt.subplot(n_row, 4, pred_slot)
    plt.imshow(pred)
    plt.xlabel(f"{method}", fontsize=32, fontname='serif')
    turn_off_axis_color()

    plt.subplot(n_row, 4, err_slot)
    # 1.0 (white) means no error; squared errors of 0.25 or more map to 0
    # because error * 4 is clipped at 1. pad_error_map adds a light frame.
    plt.imshow(pad_error_map(1 - np.minimum(error * 4, 1)))
    plt.xlabel(f"sqr. error\n(PSNR: {psnr:.1f})", fontsize=32, fontname='serif')
    turn_off_axis_color()
    

def plot_psnr_compare(input, pred, gt, method=""):
    """Draw the input view, a competing prediction and its error map.

    Uses a fixed 2x4 grid in the current pyplot figure: the input view goes
    in slot 2, the prediction in slot 8 and the squared-error map (labelled
    with its PSNR, peak value 1.0) in slot 4.

    Args:
        input: source view shown for reference.
        pred: predicted target view from the compared method.
        gt: ground-truth target view, same shape as ``pred``.
        method: label placed under the prediction panel.
    """
    sq_err = (pred - gt) ** 2
    psnr_db = 10 * np.log10(1 / np.mean(sq_err))

    def draw_panel(slot, image, label):
        # One framed, tick-free panel with a serif caption underneath.
        plt.subplot(2, 4, slot)
        plt.imshow(image)
        plt.xlabel(label, fontsize=32, fontname='serif')
        turn_off_axis_color()

    draw_panel(2, input, f"input view")
    draw_panel(8, pred, f"{method}")
    # White means no error; squared errors of 0.25 or more clip to 0.
    draw_panel(
        4,
        pad_error_map(1 - np.minimum(sq_err * 4, 1)),
        f"sqr. error\n(PSNR: {psnr_db:.1f})",
    )

    
def pad_error_map(error, c=230/255):
    """Return a copy of ``error`` whose one-pixel outer border is set to ``c``.

    Despite the name, the array is not enlarged: the first/last rows and the
    first/last columns are overwritten with ``c``. This gives the error-map
    panel a light frame (230/255 by default) against the white figure.

    Args:
        error: image-like array with at least 2 dimensions (rows, cols, ...).
        c: value written into the border pixels.

    Returns:
        A new array with the same shape and dtype as ``error``. The input is
        copied first, so the caller's array is left untouched.
    """
    framed = np.array(error, copy=True)
    framed[0, :] = c
    framed[-1, :] = c
    framed[:, 0] = c
    framed[:, -1] = c
    return framed


def turn_off_axis_color():
    """Strip ticks from the current axes and paint its four spines white.

    Unlike ``plt.axis('off')``, this leaves the axes' x-label visible, so it
    can serve as a caption under the image.
    """
    plt.xticks([])
    plt.yticks([])
    current_axes = plt.gca()
    for side in ('bottom', 'top', 'right', 'left'):
        current_axes.spines[side].set_color('#ffffff')

In [4]:
# Render one ours-vs-SynSin comparison PDF per test example.
plt.close('all')
for idx in tqdm(range(1000)):
    ours_png_file = f'/private/home/ronghanghu/workspace/mmf_nr/save/prediction_synsin_realestate10k/{exp_name}/{n_iter}/{split}/{idx:04d}/output_image_.png'
    synsin_png_file = f"/private/home/ronghanghu/workspace/synsin/results_realestate10K_short/{idx:04d}/output_image_.png"
    input_png_file = f"/private/home/ronghanghu/workspace/synsin/results_realestate10K_short/{idx:04d}/input_image_.png"
    tgt_png_file = f"/private/home/ronghanghu/workspace/synsin/results_realestate10K_short/{idx:04d}/tgt_image_.png"

    # Load all four views as float32 images (read in the same order as listed).
    im_ours, im_synsin, im_input, im_tgt = [
        skimage.img_as_float32(skimage.io.imread(png_path))
        for png_path in (ours_png_file, synsin_png_file, input_png_file, tgt_png_file)
    ]

    plt.figure(figsize=(20, 11.5))
    plt.subplots_adjust(wspace=0.05, hspace=0.2)
    plot_psnr(im_input, im_ours, im_tgt, "ours", compare=True)
    plot_psnr_compare(im_input, im_synsin, im_tgt, "SynSin")
    plt.savefig(f"{save_fig_dir}/{idx:04d}.pdf", bbox_inches="tight")
    # Release the figure each iteration so 1000 figures don't accumulate.
    plt.close('all')

100%|██████████| 1000/1000 [09:16<00:00,  1.80it/s]
