In [None]:
import os
import sys

def set_root_path():
    """Move to the parent directory when the kernel was started in `figures/`.

    All relative paths used by this notebook ('python/', 'results/...',
    'third_party/...') are resolved from that parent directory. The directory
    name is compared exactly, so a folder that merely ends in "figures"
    (e.g. "myfigures") is left alone.
    """
    if os.path.basename(os.getcwd()) == 'figures':
        os.chdir('../')


set_root_path()
# Make the project's `python/` directory importable. Use an absolute path so the
# entry stays valid if the working directory changes later, and skip it if it is
# already present so re-running this cell does not grow sys.path.
_python_dir = os.path.abspath('python/')
if _python_dir not in sys.path:
    sys.path.append(_python_dir)

In [None]:
from pathlib import Path

import mitsuba as mi
import drjit as dr

# Select the Mitsuba variant (CUDA backend, autodiff-enabled, RGB) before the
# project modules below are imported.
# NOTE(review): presumably these modules must be imported after set_variant
# because they use variant-specific Mitsuba types at import time. Confirm
# before reordering.
mi.set_variant("cuda_ad_rgb")

from practical_reconstruction import scene_configuration
from practical_reconstruction import scene_preparation
from practical_reconstruction import io_utils
from practical_reconstruction import optimization
from practical_reconstruction import figutils

from core import integrators
from core import bsdfs
from core import textures

# Register the project's custom integrators, BSDFs and textures. Presumably this
# makes their plugin type names (e.g. 'prb_nee_volume', used with mi.load_dict
# in the optimization cell) known to Mitsuba. Confirm in core/.
integrators.register()
bsdfs.register()
textures.register()


# Main optimization

In [None]:
# Run the statue optimization twice: first without the diffuse switch
# (rough dielectric only), then with it.
# Always run without diffuse switch for the reference if not already computed!
use_diffuse_switches = [False, True]

# Settings shared by both runs.
lr = 0.005
scene_name = 'statue'
technique = 'mipmap_pyramid'
scene_folder = f'third_party/{scene_name}'

for use_diffuse_switch in use_diffuse_switches:

  # The output folder encodes the variant and the learning rate, e.g.
  # 'results/statue/mipmap_pyramid_diffuse_switch_lr_0_005'.
  result_folder = f'results/{scene_name}/{technique}'
  if use_diffuse_switch:
    result_folder += '_diffuse_switch'
  else:
    result_folder += '_rough_dielectric'
  result_folder += '_lr_' + str(lr).replace('.', '_')

  override_bindings = [
      f"SceneConfig.result_folder='{result_folder}'",
      f"SceneConfig.scene_folder='{scene_folder}'",
      # Bind `lr` itself rather than a literal, so the learning rate actually
      # used always matches the one encoded in `result_folder`.
      f"SceneConfig.base_learning_rate={lr}",
  ]

  # Without the diffuse switch: disable it explicitly and optimize for 512
  # iterations. NOTE(review): these bindings do not change the initial albedo.
  if not use_diffuse_switch:
    override_bindings.append("SceneConfig.sss_diffuse_switch=%DiffuseSwitch.NONE")
    override_bindings.append("SceneConfig.n_iter=512")

  scene_config = scene_configuration.SceneConfig.get_instance(
      f'{scene_name}/{technique}', override_bindings, sss_config=True
  )

  Path(scene_config.result_folder).mkdir(parents=True, exist_ok=True)

  print('Preparing Mitsuba scene for optimization')
  tmp_mitsuba_xml = io_utils.mitsuba_remote_to_local(scene_config)

  scene = scene_preparation.load_mitsuba_scene(scene_config, tmp_mitsuba_xml)
  params = mi.traverse(scene)

  emitter_keys = scene_preparation.get_emitter_keys(scene_config, params)

  print('Preparing references and sensors for optimization')
  # /!\ The reference uses 10'000 samples per pixel so this may take a while!
  sensors, references = (
      scene_preparation.generate_references_and_retrieve_sensors(
          scene, scene_config, emitter_keys
      )
  )

  all_sensors, all_references = scene_preparation.create_intermediate_resolution(
      scene_config, sensors, references
  )

  print('Preparing optimization variables')
  optimized_keys = scene_preparation.get_scene_keys_for_optimization(
      scene, scene_config
  )
  scene_preparation.initialize_optimized_parameter(
      scene_config, params, optimized_keys
  )
  variables = scene_preparation.create_variables(
      params, optimized_keys, scene_config
  )

  # This setup is only meant for subsurface-scattering optimization.
  assert scene_config.sss_optimization
  integrator = mi.load_dict({
      'type': 'prb_nee_volume',
      # Presumably -1 means "no limit" on SSS depth. Confirm in the plugin.
      'max_sss_depth': -1,
      'max_path_depth': scene_config.optimized_path_depth,
  })

  print('Starting optimization')
  mts_variables, loss_values, opt, frames = optimization.optimize(
      scene_config,
      scene,
      all_sensors,
      all_references,
      emitter_keys,
      integrator,
      params,
      variables,
  )

  optimization.save_texture_results(params, optimized_keys, scene_config)
  optimization.save_optimization_videos(scene_config, frames)
  optimization.save_loss_data(scene_config, loss_values, emitter_keys)

# Figure: comparison with and without the diffuse switch

In [None]:
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.patheffects as pe

from core import image_util
from core import mitsuba_io


def figure_grid_setup(image_shape,image_crop_shape,inner_space=0.0,outer_space=0.1):
  """Create figure 1 with a 1x3 row of equally sized image panels.

  Args:
    image_shape: (height, width) of the images shown in each panel.
    image_crop_shape: Unused; kept so existing calls keep working.
    inner_space: Horizontal spacing between the panels (gridspec `wspace`).
      The vertical spacing is derived from it.
    outer_space: Vertical spacing (`hspace`) of the outer gridspec.

  Returns:
    (fig, (top_inner_gs, top_inner_rows, top_inner_cols)), where
    `top_inner_gs[col]` is the subplot spec of panel `col`.
  """
  h, w = image_shape

  top_inner_rows = 1
  top_inner_cols = 3
  # Spacing in the inner gridspec
  top_inner_wspace = inner_space
  # Same vertical spacing as horizontal spacing.
  # NOTE(review): n_cols=1 while `w` has top_inner_cols entries. Confirm this is
  # the intended call to figutils.gridspec_aspect.
  top_inner_hspace = top_inner_wspace * figutils.gridspec_aspect(
      n_rows=1, n_cols=1, w=[w]*top_inner_cols, h=[h]
  )
  top_height_ratios = [h]
  top_inner_aspect = figutils.gridspec_aspect(
      n_rows=top_inner_rows,
      n_cols=top_inner_cols,
      w=[w] * top_inner_cols,
      h=top_height_ratios,
      wspace=top_inner_wspace,
      hspace=top_inner_hspace,
  )
  # Spacing in the main gridspec
  outer_rows = 1
  outer_cols = 1
  outer_wspace = 0.0
  outer_hspace = outer_space
  outer_aspect = figutils.gridspec_aspect(
      n_rows=outer_rows,
      n_cols=outer_cols,
      w=1,
      h=[1 / top_inner_aspect],
      wspace=outer_wspace,
      hspace=outer_hspace,
  )

  # Close figure 1 if it already exists (e.g. this cell ran before). Otherwise
  # plt.figure(1, ...) returns the old figure, ignores the new figsize, and new
  # axes are drawn on top of the old ones.
  plt.close(1)
  fig = plt.figure(
      1, figsize=(figutils.COLUMN_WIDTH, figutils.COLUMN_WIDTH / outer_aspect)
  )

  outer_gs = fig.add_gridspec(
      outer_rows,
      outer_cols,
      hspace=outer_hspace,
      wspace=outer_wspace,
      height_ratios=[1 / top_inner_aspect],
      width_ratios=[1]*outer_cols,
  )

  # Use the same spacing and width ratios that were used to compute
  # top_inner_aspect above, so the figure size matches the layout.
  top_inner_gs = gridspec.GridSpecFromSubplotSpec(
      top_inner_rows,
      top_inner_cols,
      subplot_spec=outer_gs[0],
      wspace=top_inner_wspace,
      hspace=top_inner_hspace,
      width_ratios=[w] * top_inner_cols,
      height_ratios=top_height_ratios,
  )

  return (
      fig,
      (top_inner_gs, top_inner_rows, top_inner_cols),
  )

In [None]:
scene_path = 'third_party/statue'
result_path = 'results/statue'

# Result folders written by the optimization cell (lr = 0.005).
diffuse_folder = 'mipmap_pyramid_diffuse_switch_lr_0_005'
roughdielectric_folder = 'mipmap_pyramid_rough_dielectric_lr_0_005'

# Frame indices as they appear in the saved frame file names.
start_iter = 15
end_iter_diffuse = 127
end_iter_roughdielectric = 511

# Exposure multiplier applied to every image before tonemapping.
boost = np.sqrt(2)


def load_tonemapped(path, scale=None):
  """Read an image with Mitsuba, scale it, tonemap it and return a NumPy array.

  Args:
    path: Path of the image file to read.
    scale: Multiplier applied before tonemapping. Defaults to the current
      value of `boost`.

  Returns:
    The tonemapped image as a NumPy array.
  """
  if scale is None:
    scale = boost
  return np.array(image_util.tonemap(scale * mi.TensorXf(mitsuba_io.read_bitmap(path))))


def frame_path(folder, iteration, suffix=''):
  """Return the path of `Camera_iter_<iteration:03d><suffix>.exr` in `folder`'s frames."""
  return f"{result_path}/{folder}/frames/Camera_iter_{iteration:03d}{suffix}.exr"


ref_image = load_tonemapped(f"{scene_path}/references/ref_view_000.exr")

init_state = load_tonemapped(frame_path(diffuse_folder, 0))
start_diffuse = load_tonemapped(frame_path(diffuse_folder, start_iter))
end_diffuse = load_tonemapped(
    frame_path(diffuse_folder, end_iter_diffuse, '_spp_2048'))
start_roughdielectric = load_tonemapped(
    frame_path(roughdielectric_folder, start_iter))
end_roughdielectric = load_tonemapped(
    frame_path(roughdielectric_folder, end_iter_roughdielectric, '_spp_2048'))

column_images = [ref_image, start_diffuse, end_diffuse, start_roughdielectric, end_roughdielectric]

# Index 0: without diffuse switch (rough dielectric); index 1: with diffuse switch.
images_start = [start_roughdielectric, start_diffuse]
images_end = [end_roughdielectric, end_diffuse]

In [None]:
FIGURE_DIR = "figures/pdfs"
FIGURE_NAME = "diffuse_switch_comparison"

# Only index 0 (the reference) is drawn, in the left-most panel.
top_row_images = [ref_image, init_state]

# Panel titles, left to right.
titles = ["Reference", "Without diffuse exit", "With diffuse exit"]

# Time labels for the two comparison panels (without / with diffuse exit).
labels = ["24m 16s", "2m 47s"]

# The reference image's (height, width) serves as both the image and crop shape.
fig, (top_inner_gs, top_inner_rows, top_inner_cols) = figure_grid_setup(
    ref_image.shape[:2],
    ref_image.shape[:2],
    inner_space=0.01,
    outer_space=0.1,
)

# Width of the diagonal split line drawn on each comparison panel.
line_width = 0.5
# Width / height of the reference image. NOTE(review): not used by the visible cells.
aspect = ref_image.shape[1] / ref_image.shape[0]

def add_iter_texts(ax, img_shape, end_iter, time, linewidth=0.75, start_iter=32,
                   margin=10):
  """Annotate a start/end comparison panel with iteration labels.

  Draws "Iter. <end_iter>" with `time` on the next line in the top-right
  corner, and "Iter. <start_iter>" in the bottom-left corner. Positions are in
  image (pixel) coordinates, as set up by `imshow`. The text is white with a
  black outline.

  Args:
    ax: Matplotlib axes showing the image.
    img_shape: Shape of the displayed image; only img_shape[0] (height) and
      img_shape[1] (width) are used.
    end_iter: Iteration number shown in the top-right corner.
    time: Text shown below the end iteration.
    linewidth: Width of the black outline around the text.
    start_iter: Iteration number shown in the bottom-left corner (default 32).
      NOTE(review): the image-loading cell reads frame index 15
      (`start_iter = 15`) for the start images. Confirm 32 is the intended label.
    margin: Distance of both labels from the image corners, in pixels.
  """
  height, width = img_shape[0], img_shape[1]

  def outlined():
    # A fresh outline effect list for each text artist.
    return [pe.withStroke(linewidth=linewidth, foreground="black")]

  # Top-right corner: final iteration and time.
  ax.text(
      width - margin,
      margin,
      f"Iter. {end_iter}\n{time}",
      color="white",
      ha="right",
      va="top",
      fontsize=8,
      path_effects=outlined(),
  )
  # Bottom-left corner: starting iteration.
  ax.text(
      margin,
      height - margin,
      f"Iter. {start_iter}",
      color="white",
      ha="left",
      va="bottom",
      fontsize=8,
      path_effects=outlined(),
  )

# Left panel: the reference. Panels 1 and 2: a diagonal split between the start
# and end states of the run without / with the diffuse exit, labelled with
# iteration counts and times.
for row in range(top_inner_rows):
  # Single-row layout: every row uses the same inner gridspec.
  for col in range(top_inner_cols):
    ax = fig.add_subplot(top_inner_gs[col])
    figutils.disable_ticks(ax)
    ax.set_title(titles[col], pad=3)

    if col == 0:
      ax.imshow(top_row_images[col], aspect="equal")
      continue

    method_idx = col - 1
    img_combined, xline, yline = figutils.diagonal_split_image(
        images_start[method_idx],
        images_end[method_idx],
        offset=0,
        angle=-20,
    )
    ax.plot(xline, yline, color="black", linewidth=line_width)
    ax.imshow(img_combined, aspect="equal")

    if col == 2:
      final_iter = end_iter_diffuse + 1
    else:
      final_iter = end_iter_roughdielectric + 1
    add_iter_texts(
        ax,
        img_combined.shape,
        end_iter=final_iter,
        time=labels[method_idx],
    )

figutils.force_post_crop_size(fig, figutils.COLUMN_WIDTH)

In [None]:
# Export the finished figure as FIGURE_NAME inside FIGURE_DIR, cropped tightly
# and sized to the column width.
savefig_kwargs = dict(
    name=Path(FIGURE_NAME),
    fig_directory=Path(FIGURE_DIR),
    dpi=300,
    pad_inches=0.005,
    bbox_inches="tight",
    compress=False,
    target_width=figutils.COLUMN_WIDTH,
    backend=None,
)
figutils.savefig(fig, **savefig_kwargs)