In [None]:
import os
import sys

def set_root_path():
    """Moves the working directory from ``figures/`` up to its parent.

    Paths used throughout this notebook (``python/``, ``results/``,
    ``third_party/``, ``figures/pdfs``) are relative, presumably to the
    repository root, while the notebook may be launched from inside its
    ``figures`` directory. If the current directory is named exactly
    ``figures``, step up one level; otherwise leave it untouched, so calling
    this again after the move is a no-op.
    """
    # Compare the final path component exactly: a suffix test such as
    # endswith('figures') would also match e.g. 'myfigures'.
    if os.path.basename(os.path.normpath(os.getcwd())) == 'figures':
        os.chdir(os.pardir)
set_root_path()
# Make the repository's `python/` directory importable for the project
# imports below. Resolve it to an absolute path so a later os.chdir cannot
# break the lookup, and add it only once so re-running this cell does not
# keep growing sys.path.
python_dir = os.path.abspath('python')
if python_dir not in sys.path:
    sys.path.append(python_dir)

In [None]:
from pathlib import Path
import mitsuba as mi

# Select the Mitsuba variant before the project modules below are imported;
# presumably they rely on the active variant at import time (TODO confirm).
mi.set_variant('cuda_ad_rgb')

from practical_reconstruction import optimization_cli
from core import integrators
from core import bsdfs
from core import textures

# Register the project's custom integrators, BSDFs and textures (in that
# order); presumably this makes them available to the scene configs used below.
for plugin_module in (integrators, bsdfs, textures):
    plugin_module.register()

In [None]:
def format_float(f):
  """Returns `f` as a path-friendly string in which every '.' becomes '_'.

  For example 0.1 -> "0_1" and 10.0 -> "10_0". Used to embed hyperparameter
  values in result folder names.
  """
  return '_'.join(str(f).split('.'))

scene_name = 'painting'
techniques = ['gradient_filtering', 'mipmap_pyramid']

# Do not re-run configurations whose result folder already exists on disk.
skip_existing = True

# Hyperparameter sweep per technique. The gradient-filtering entries are the
# settings labelled "Opt. combined", "Opt. close" and "Opt. far" in the figure.
gradient_filtering_params = [
    {'lr': 0.1, 'sigma_d': 0.1},
    {'lr': 0.025, 'sigma_d': 0.01},
    {'lr': 0.25, 'sigma_d': 0.5},
]
mipmap_pyramid_params = [
    {'lr': 0.005},
]

technique_configs = dict(
    gradient_filtering=gradient_filtering_params,
    mipmap_pyramid=mipmap_pyramid_params,
)

for technique in techniques:
    use_gradient_filtering = technique == 'gradient_filtering'
    # Gradient filtering runs on top of the scene's "naive" gin config plus
    # filtering overrides; every other technique has a gin config of its own.
    if use_gradient_filtering:
        gin_config_name = f'{scene_name}/naive'
    else:
        gin_config_name = f'{scene_name}/{technique}'

    for params in technique_configs[technique]:
        base_learning_rate = params['lr']
        if not use_gradient_filtering:
            # Placeholders only; not used by non-filtering techniques.
            sigma_d, filtering_steps = 0.0, 0
        else:
            sigma_d = params['sigma_d']
            # Number of filtering steps recommended by the method's authors.
            filtering_steps = 4

        print(
            f'******** Running {technique} with base learning rate'
            f' {base_learning_rate} ********'
        )

        result_folder = (
            f'results/{scene_name}/{technique}'
            f'_lr_{format_float(base_learning_rate)}'
        )
        override_bindings = [
            f'SceneConfig.base_learning_rate={base_learning_rate}',
            # An empty tmp folder makes the default tmp folder be used.
            "SceneConfig.tmp_folder=''",
            f'SceneConfig.use_gradient_filtering={use_gradient_filtering}',
        ]
        if use_gradient_filtering:
            result_folder += (
                f'_sigma_d_{format_float(sigma_d)}_F{filtering_steps}'
            )
            override_bindings += [
                f'SceneConfig.filtering_sigma_d={sigma_d}',
                f'SceneConfig.a_trous_filtering_steps={filtering_steps}',
                'SceneConfig.log_domain_filtering=True',
            ]
        override_bindings.append(
            f"SceneConfig.result_folder='{result_folder}'"
        )

        print(f'Next result location: {result_folder}')
        if skip_existing and Path(result_folder).exists():
            print('Skipping, already present')
            continue

        optimization_cli.run_config(gin_config_name, override_bindings)

# Figure starts here

In [None]:
from core import image_util
from core import mitsuba_io

from practical_reconstruction import figutils

import drjit as dr
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
# @title The real figure starts here

# Output directory and base file name of the exported figure (relative paths).
FIGURE_DIR, FIGURE_NAME = "figures/pdfs", "gradient_filtering_limitation"

def l2_error(ref, img):
  """Mean squared error between two images, reduced over every element.

  Args:
    ref: Reference image (e.g. a ``mi.TensorXf`` or anything NumPy can turn
      into an array).
    img: Image to compare against ``ref``; must have the same shape.

  Returns:
    The mean of ``(ref - img) ** 2`` over all pixels and channels, as a
    Python float.

  Raises:
    ValueError: if the two images do not have the same shape.
  """
  # Reduce explicitly over all elements instead of `dr.mean(...).array[0]`:
  # Dr.Jit's default reduction axis for tensors depends on the version (recent
  # releases reduce only the outermost axis unless `axis=None`), in which case
  # `.array[0]` yields a single entry rather than the image-wide mean.
  # float64 accumulation keeps the mean precise for large images.
  ref_arr = np.asarray(ref, dtype=np.float64)
  img_arr = np.asarray(img, dtype=np.float64)
  if ref_arr.shape != img_arr.shape:
    # Refuse to broadcast silently: mismatched shapes mean mismatched images.
    raise ValueError(
        f"shape mismatch: {ref_arr.shape} vs {img_arr.shape}"
    )
  return float(np.mean(np.square(ref_arr - img_arr)))

def load_images_and_errors():
  """Loads the reference and reconstructed renderings and their L2 errors.

  Rows are the two camera views (close-up, then far). Columns are the
  reference, then gradient filtering with the hyperparameters tuned for the
  combined, close and far views, and finally the mipmap-pyramid result.

  Returns:
    A tuple ``(images_close, images_far, errors)``. The first two are lists
    with one tonemapped ``np.ndarray`` per column. ``errors[row][k]`` is the
    L2 error of column ``k + 1`` against that row's reference, computed on the
    raw EXR data (before the display boost and tonemapping).

  Raises:
    ValueError: if the number of exposure boosts does not match the number
      of rows.
  """
  image_exr_paths = [
      [
          "third_party/painting/references/ref_view_000.exr",
          "results/painting/gradient_filtering_lr_0_1_sigma_d_0_1_F4/frames/CameraClose_iter_127.exr",
          "results/painting/gradient_filtering_lr_0_025_sigma_d_0_01_F4/frames/CameraClose_iter_127.exr",
          "results/painting/gradient_filtering_lr_0_25_sigma_d_0_5_F4/frames/CameraClose_iter_127.exr",
          "results/painting/mipmap_pyramid_lr_0_005/frames/CameraClose_iter_127.exr",
      ],
      [
          "third_party/painting/references/ref_view_001.exr",
          "results/painting/gradient_filtering_lr_0_1_sigma_d_0_1_F4/frames/CameraFar_iter_127.exr",
          "results/painting/gradient_filtering_lr_0_025_sigma_d_0_01_F4/frames/CameraFar_iter_127.exr",
          "results/painting/gradient_filtering_lr_0_25_sigma_d_0_5_F4/frames/CameraFar_iter_127.exr",
          "results/painting/mipmap_pyramid_lr_0_005/frames/CameraFar_iter_127.exr",
      ],
  ]
  # Exposure multiplier applied before tonemapping, one per row. It affects
  # only the displayed images; errors use the raw data.
  boosts = [float(np.sqrt(2.0)), 2.0]
  if len(boosts) != len(image_exr_paths):
    raise ValueError(
        f"expected one boost per row: {len(boosts)} boosts for "
        f"{len(image_exr_paths)} rows"
    )

  images_exr = []
  images = []
  for row_paths, boost in zip(image_exr_paths, boosts):
    # Convert each EXR to a tensor once and reuse it for display and error.
    row_exr = [mi.TensorXf(mitsuba_io.read_bitmap(path)) for path in row_paths]
    images_exr.append(row_exr)
    images.append(
        [np.array(image_util.tonemap(boost * exr)) for exr in row_exr]
    )

  # Column 0 of each row is the reference; score every other column on it.
  errors = [
      [l2_error(row_exr[0], img) for img in row_exr[1:]]
      for row_exr in images_exr
  ]

  images_close, images_far = images
  return images_close, images_far, errors

In [None]:
# Crop windows in pixels: offsets are the (x, y) corner and sizes are
# (width, height), the same values used for the outline rectangles below.
crop_size_close = (500, 250)
crop_size_far = (200, 100)
crop_offset_close = (450, 745)
crop_offset_far = (474, 755)

images_close, images_far, errors = load_images_and_errors()

# Cut the same window out of every column of each view.
crop_images_close = [
    figutils.crop_image(image, crop_offset_close, crop_size_close)
    for image in images_close
]
crop_images_far = [
    figutils.crop_image(image, crop_offset_far, crop_size_far)
    for image in images_far
]

In [None]:
cols = 5
rows = 6
fig = plt.figure(figsize=(figutils.COLUMN_WIDTH, figutils.COLUMN_WIDTH * 0.67))
# For each of the two views: a full-image row, a half-height crop row, and a
# thin filler row (left empty, presumably to leave room for the error labels).
gs = fig.add_gridspec(
    rows,
    cols,
    figure=fig,
    height_ratios=[1, 0.5, 0.12] * 2,
    width_ratios=[1] * cols,
    hspace=0.025,
    wspace=0.025,
)

# Styling shared by the crop outlines on the references and the crop frames.
line_width = 0.5
crop_close_color = crop_far_color = "orange"

subtitle_fontsize = r"\fontsize{6}{12}\selectfont"
# Titles for the columns: the reference, gradient filtering with the
# hyperparameters tuned for the combined / close / far views, and ours.
# The lr / sigma values shown must match gradient_filtering_params and
# mipmap_pyramid_params used for the optimization runs.
# Every \textsl{...} group must be closed, or the LaTeX label is malformed.
titles = [
    "Reference\n"
    + subtitle_fontsize + r"\textsl{Hyperparameters}",
    subtitle_fontsize
    + " Opt. combined\n"
    + figutils.math_label(r"\textsl{lr= 0.1 / $\sigma$= 0.1}"),
    subtitle_fontsize
    + " Opt. close\n"
    + figutils.math_label(r"\textsl{lr= 0.025 / $\sigma$= 0.01}"),
    subtitle_fontsize
    + " Opt. far\n"
    + figutils.math_label(r"\textsl{lr= 0.25 / $\sigma$= 0.5}"),
    "Ours\n" + figutils.math_label(r"\textsl{lr= 0.005}"),
]

# Every image row is laid out left to right as:
# Reference, Combined, Close, Far, Ours.
def _grid_row_axes(row):
  """Creates one subplot per grid column of `row`, from left to right."""
  return [fig.add_subplot(gs[row, col]) for col in range(cols)]


image_close_axes = _grid_row_axes(0)
crop_image_close_axes = _grid_row_axes(1)

# Text filler rows (2 and 5) hold no content: create them and hide them.
dummy_axes = _grid_row_axes(2) + _grid_row_axes(5)
for ax in dummy_axes:
  ax.axis("off")

image_far_axes = _grid_row_axes(3)
crop_image_far_axes = _grid_row_axes(4)

# Errors are displayed multiplied by `scale`; the legend states the factor.
scale = 10 ** 3
scale_txt = figutils.math_label(r"\text{$(\times 10^3)$}")


def error_format(error, scale):
  """Returns `error * scale` as text with exactly three decimals."""
  scaled = error * scale
  return "{:.3f}".format(scaled)


# Close images
# Top row: full close-up renderings under the column titles. The reference
# (column 0) also gets the crop outline and the row label.
for i, ax_close in enumerate(image_close_axes):
  # Same title padding for every column; one imshow per panel.
  ax_close.set_title(titles[i], pad=3.0)
  ax_close.imshow(images_close[i], aspect="equal")
  figutils.disable_ticks(ax_close)
  # Close crop on reference
  if i == 0:
    # Same offset/size as used to cut crop_images_close.
    rect_close = Rectangle(
        crop_offset_close,
        crop_size_close[0],
        crop_size_close[1],
        linewidth=line_width,
        edgecolor=crop_close_color,
        facecolor="none",
    )
    ax_close.add_patch(rect_close)
    ax_close.set_ylabel("Close-up view", labelpad=1.5)


# Close crops
# Framed like the outline on the reference. Each reconstruction shows its
# scaled L2 error underneath, and the reference column shows the legend.
for col, ax_crop in enumerate(crop_image_close_axes):
  ax_crop.imshow(crop_images_close[col], aspect="equal")
  figutils.disable_ticks(ax_crop)
  ax_crop.spines[:].set_color(crop_close_color)
  ax_crop.spines[:].set_linewidth(line_width)
  if col == 0:
    xlabel = r"$L_2$ error " + scale_txt
  else:
    value = error_format(errors[0][col - 1], scale)
    # Our method (last column, "Ours") is highlighted in bold.
    xlabel = r"\textbf{" + value + r"}" if col == 4 else value
  ax_crop.set_xlabel(xlabel, labelpad=1.5)

# Far images
# Wide-angle renderings; only the reference gets the crop outline and the
# row label.
for col, ax_view in enumerate(image_far_axes):
  ax_view.imshow(images_far[col], aspect="equal")
  figutils.disable_ticks(ax_view)
  if col != 0:
    continue
  # Same offset/size as used to cut crop_images_far.
  rect_far = Rectangle(
      crop_offset_far,
      crop_size_far[0],
      crop_size_far[1],
      linewidth=line_width,
      edgecolor=crop_far_color,
      facecolor="none",
  )
  ax_view.add_patch(rect_far)
  ax_view.set_ylabel("Wide-angle view", labelpad=1.5)

# Far crops
# Framed like the outline on the reference. Each reconstruction shows its
# scaled L2 error underneath, and the reference column shows the legend.
for col, ax_crop in enumerate(crop_image_far_axes):
  ax_crop.imshow(crop_images_far[col], aspect="equal")
  figutils.disable_ticks(ax_crop)
  ax_crop.spines[:].set_color(crop_far_color)
  ax_crop.spines[:].set_linewidth(line_width)
  if col == 0:
    xlabel = r"$L_2$ error " + scale_txt
  else:
    value = error_format(errors[1][col - 1], scale)
    # Our method (last column, "Ours") is highlighted in bold.
    xlabel = r"\textbf{" + value + r"}" if col == 4 else value
  ax_crop.set_xlabel(xlabel, labelpad=1.5)

# Header: a horizontal rule with the gradient-filtering method's long name
# centred above it (presumably spanning the gradient-filtering columns). The
# rule uses figure coordinates via fig.transFigure.
header_rule = plt.Line2D(
    [0.28, 0.745],
    [0.97, 0.97],
    color="black",
    solid_capstyle="round",
    transform=fig.transFigure,
    linewidth=0.75,
)
fig.add_artist(header_rule)

header_text = (
    r"\fontsize{8}{12}\selectfont " + f"{figutils.GRAD_FILTERING_NAME_LONG}"
)
fig.text(0.5, 0.985, header_text, ha="center", color="black")

# Presumably sizes the figure so it is one column wide after cropping.
figutils.force_post_crop_size(fig, figutils.COLUMN_WIDTH)

In [None]:
# Export at column width. The bbox_inches / pad_inches options are presumably
# forwarded to matplotlib's savefig to trim the surrounding whitespace.
figutils.savefig(
    fig,
    name=Path(FIGURE_NAME),
    fig_directory=Path(FIGURE_DIR),
    target_width=figutils.COLUMN_WIDTH,
    dpi=300,
    bbox_inches="tight",
    pad_inches=0.005,
    compress=False,
    backend=None,
)