In [None]:
import os
import sys

def set_root_path():
    """Change to the parent directory when started from a ``figures`` directory.

    If the current working directory path ends with ``'figures'``, move one
    level up. The relative paths used throughout this notebook (``'python/'``,
    ``'results/...'``, ``'third_party/...'``, ``'figures/pdfs'``) are written
    relative to that parent directory (presumably the repository root).
    """
    if os.getcwd().endswith('figures'):
        os.chdir(os.pardir)


set_root_path()
# Re-running this cell must not keep appending the same entry to sys.path.
if 'python/' not in sys.path:
    sys.path.append('python/')

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import gridspec
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import mitsuba as mi
import drjit as dr
import numpy as np

from practical_reconstruction import figutils
from practical_reconstruction import scene_configuration
from practical_reconstruction import io_utils

# Select the Mitsuba variant before importing the project's `core` modules.
# NOTE(review): the imports below are deliberately placed after set_variant;
# presumably `core` defines plugins that need an active variant — confirm.
mi.set_variant("cuda_ad_rgb")

from core import integrators
from core import bsdfs

# Register the project's custom integrators and BSDFs with Mitsuba.
# Presumably this is what makes e.g. the 'prb_path_volume' integrator type
# used below loadable via mi.load_dict — verify in core/.
integrators.register()
bsdfs.register()

In [None]:
def load_scene_config(
    scene_name='dragon',
    technique='mipmap_pyramid',
    base_learning_rate=0.1,
    result_suffix='_test',
):
  """Return the SceneConfig instance for ``scene_name``/``technique``.

  The configuration is looked up with ``SceneConfig.get_instance`` (with
  ``sss_config=True``) and the following bindings are overridden:

  * ``SceneConfig.result_folder`` = ``results/<scene_name>/<technique><result_suffix>``
  * ``SceneConfig.scene_folder``  = ``third_party/<scene_name>``
  * ``SceneConfig.base_learning_rate`` = ``base_learning_rate``

  The defaults reproduce the previously hard-coded setup (dragon scene,
  ``mipmap_pyramid`` technique, learning rate 0.1, ``_test`` result folder).

  Args:
    scene_name: scene directory name under ``third_party/``.
    technique: technique name used for the config lookup and result folder.
    base_learning_rate: value bound to ``SceneConfig.base_learning_rate``.
    result_suffix: suffix appended to the result folder name.

  Returns:
    The object returned by ``scene_configuration.SceneConfig.get_instance``.
  """
  result_folder = f'results/{scene_name}/{technique}{result_suffix}'
  scene_folder = f'third_party/{scene_name}'

  override_bindings = [
      f"SceneConfig.result_folder='{result_folder}'",
      f"SceneConfig.scene_folder='{scene_folder}'",
      f'SceneConfig.base_learning_rate={base_learning_rate}',
  ]

  return scene_configuration.SceneConfig.get_instance(
      f'{scene_name}/{technique}', override_bindings, sss_config=True
  )


def load_modified_scene(scene_config, res=[400, 240]):
  """Load the scene described by ``scene_config`` with a custom film size.

  Makes sure ``scene_config.result_folder`` exists, obtains a local Mitsuba
  XML path via ``io_utils.mitsuba_remote_to_local(..., override=True)``, loads
  it, and assigns ``res`` to both ``Camera.film.size`` and
  ``Camera.film.crop_size`` before returning the scene.

  Args:
    scene_config: configuration object (see ``load_scene_config``).
    res: film size written verbatim to the two film parameters. The default
      list is only read, never mutated.

  Returns:
    The loaded Mitsuba scene.
  """
  os.makedirs(scene_config.result_folder, exist_ok=True)

  xml_path = io_utils.mitsuba_remote_to_local(scene_config, override=True)
  loaded_scene = mi.load_file(xml_path)

  scene_params = mi.traverse(loaded_scene)
  for film_key in ('Camera.film.size', 'Camera.film.crop_size'):
    scene_params[film_key] = res
  scene_params.update()

  return loaded_scene

In [None]:
def compute_finite_difference(
    scene,
    integrator,
    param_type,
    mat_key,
    epsilon,
    channel=-1,
    spp=4096,
    q=0.99,
):
  """Central finite-difference estimate of d(image)/d(material parameter).

  Renders the scene with the parameter shifted by ``-epsilon`` and
  ``+epsilon`` (same seed for both renders) and returns
  ``(img(+eps) - img(-eps)) / (2 * epsilon)``. The original parameter value
  is always written back, including when a render raises.

  Args:
    scene: Mitsuba scene whose parameter is perturbed.
    integrator: integrator passed to ``mi.render``.
    param_type: ``'albedo'``, ``'extinction'`` or ``'hg'``; selects the
      ``single_scattering_albedo``, ``extinction_coefficient`` or
      ``hg_coefficient`` ``.value`` parameter of ``mat_key``.
    mat_key: prefix of the material in ``mi.traverse(scene)``.
    epsilon: perturbation step.
    channel: index along the first axis of the parameter to perturb, or -1
      to perturb all entries at once.
    spp: samples per pixel of each of the two renders.
    q: quantile of |finite difference| returned as a color-scale limit.

  Returns:
    ``(gradient_image, vlim)``: the finite-difference image summed over
    axis 2 (NumPy array), and the ``q``-quantile of its absolute value taken
    over all entries before summation.

  Raises:
    ValueError: if ``param_type`` is not one of the supported names.
  """
  param_names = {
      'albedo': 'single_scattering_albedo',
      'extinction': 'extinction_coefficient',
      'hg': 'hg_coefficient',
  }
  if param_type not in param_names:
    raise ValueError(
        f'Unknown param_type {param_type!r}; expected one of '
        f'{sorted(param_names)}'
    )
  key = f'{mat_key}.{param_names[param_type]}.value'

  params = mi.traverse(scene)
  original_param = mi.TensorXf(params[key])
  epsilon_tensor = np.ones_like(original_param) * epsilon

  # Only perturb the requested channel (first axis) when one is given.
  if channel != -1:
    for c in range(3):
      if c != channel:
        epsilon_tensor[c, :] = 0.0

  try:
    # Apply - epsilon
    params[key] = mi.TensorXf(original_param) - epsilon_tensor
    params.update()
    img_minus = mi.render(scene, integrator=integrator, spp=spp, seed=0)

    # Apply + epsilon (same seed, so the noise is correlated between renders)
    params[key] = mi.TensorXf(original_param) + epsilon_tensor
    params.update()
    img_plus = mi.render(scene, integrator=integrator, spp=spp, seed=0)
  finally:
    # Restore the unperturbed value even if rendering failed, so the scene
    # is not left in a modified state for later renders.
    params[key] = mi.TensorXf(original_param)
    params.update()

  img_fd = ((img_plus - img_minus) / (2 * epsilon)).numpy()
  vlim = np.quantile(np.abs(img_fd), q=q)

  gradient_image = np.sum(img_fd, axis=2)
  return gradient_image, vlim

In [None]:
def compute_forward_gradient(
    scene,
    integrator,
    param_type,
    mat_key,
    channel=-1,
    spp=4096,
    compute_mean=False,
    q=0.99
):
  """Forward-mode derivative of the rendered image w.r.t. a material parameter.

  Enables gradient tracking on ``<mat_key>.<param>.value``, renders with
  ``params=params`` and propagates a tangent from the parameter to the image.

  Args:
    scene: Mitsuba scene to render. Gradient tracking is left enabled on the
      parameter after this call.
    integrator: integrator passed to ``mi.render``.
    param_type: ``'albedo'``, ``'extinction'`` or ``'hg'``; selects the
      ``single_scattering_albedo``, ``extinction_coefficient`` or
      ``hg_coefficient`` ``.value`` parameter of ``mat_key``.
    mat_key: prefix of the material in ``mi.traverse(scene)``.
    channel: index along the first axis of the parameter to differentiate
      with respect to, or -1 to use a unit tangent on all entries (i.e. the
      sum of the per-entry derivatives).
    spp: samples per pixel.
    compute_mean: if True, return only the gradient averaged over the first
      two image axes.
    q: quantile of |gradient| returned as a color-scale limit.

  Returns:
    ``(gradient_image, vlim)``: the gradient image summed over axis 2 (NumPy
    array) and the ``q``-quantile of its absolute value taken over all
    entries before summation. When ``compute_mean`` is True, instead returns
    the averaged gradient as a NumPy array.

  Raises:
    ValueError: if ``param_type`` is not one of the supported names.
  """
  param_names = {
      'albedo': 'single_scattering_albedo',
      'extinction': 'extinction_coefficient',
      'hg': 'hg_coefficient',
  }
  if param_type not in param_names:
    raise ValueError(
        f'Unknown param_type {param_type!r}; expected one of '
        f'{sorted(param_names)}'
    )
  key = f'{mat_key}.{param_names[param_type]}.value'

  params = mi.traverse(scene)

  # Track gradients on the whole tensor. Enabling them on a slice such as
  # params[key][channel] only marks a temporary copy that the scene never
  # reads, so no derivative would reach the image.
  dr.enable_grad(params[key])

  image = mi.render(scene, integrator=integrator, params=params, spp=spp)

  if channel == -1:
    # Unit tangent on every entry of the parameter.
    dr.forward(params[key])
  else:
    # One-hot tangent selecting `channel` along the first axis (the same axis
    # compute_finite_difference perturbs per channel).
    tangent = np.zeros(params[key].shape, dtype=np.float32)
    tangent[channel] = 1.0
    dr.set_grad(params[key], mi.TensorXf(tangent))
    dr.forward_to(image)
  grad_image = dr.grad(image)

  if compute_mean:
    return dr.mean(dr.mean(grad_image, axis=0), axis=0).numpy()

  # Convert once so both outputs are plain NumPy values.
  grad_np = grad_image.numpy()
  vlim = np.quantile(np.abs(grad_np), q=q)

  gradient_image = np.sum(grad_np, axis=2)
  return gradient_image, vlim

In [None]:
# Display reference image
# Render the scene at a film size of [800, 480] and tonemap the result; the
# same `integrator` object is reused for the forward-mode gradients below.
scene_config = load_scene_config()
scene = load_modified_scene(scene_config, res=[800, 480])
integrator = mi.load_dict({
    'type': 'prb_path_volume',
    'max_sss_depth': 256,
    'max_path_depth': 5,
    'dwivedi_guiding': False,
})
ref_image = figutils.tonemap(
    mi.render(scene, integrator=integrator, spp=512, seed=0)
)

In [None]:
# Compute all gradient and finite difference images.
#
# Each config is (param_type, epsilon, channel, spp_fd, spp_grad):
#   epsilon  - step of the central finite difference
#   channel  - parameter channel to differentiate, -1 for all channels
#   spp_fd   - samples per pixel of each finite-difference render
#   spp_grad - samples per pixel of the forward-mode gradient render
configs = [
    ('albedo', 0.01, -1, 1024, 1024),
    # NOTE(review): 8096 looks like a typo for 8192; kept unchanged so the
    # results match earlier runs.
    ('extinction', 0.5, -1, 8096*2, 4096),
    ('hg', 0.1, -1, 8096*2, 4096),
]
mat_key = 'mat-Material.001'

# recompute=True starts from empty result lists. recompute=False keeps the
# results of a previous run (useful together with skip_indices); on a fresh
# kernel the lists do not exist yet, so they are created in that case too.
result_names = ('forward_images', 'forward_vlims', 'fd_images', 'fd_vlims')
recompute = True
if recompute or any(name not in globals() for name in result_names):
  forward_images = [None] * len(configs)
  forward_vlims = [None] * len(configs)
  fd_images = [None] * len(configs)
  fd_vlims = [None] * len(configs)
quantile = 0.95

# Same settings as the reference render's integrator; a fresh instance is
# loaded for the finite-difference renders of every config.
fd_integrator_config = {
    'type': 'prb_path_volume',
    'max_sss_depth': 256,
    'max_path_depth': 5,
    'dwivedi_guiding': False,
}

# Indices of configs whose (previously computed) results should be kept.
skip_indices = []
for i, (param_type, epsilon, channel, spp_fd, spp_grad) in enumerate(configs):
  if i in skip_indices:
    continue
  print(f"Computing gradients for {param_type}")

  # Reload the scene so each config starts from freshly loaded parameters.
  scene = load_modified_scene(scene_config)

  forward_grad, forward_vlim = compute_forward_gradient(
      scene,
      integrator=integrator,
      param_type=param_type,
      mat_key=mat_key,
      spp=spp_grad,
      channel=channel,
      q=quantile,
  )
  forward_vlims[i] = forward_vlim
  forward_images[i] = forward_grad

  fd_grad, fd_vlim = compute_finite_difference(
      scene,
      integrator=mi.load_dict(fd_integrator_config),
      param_type=param_type,
      mat_key=mat_key,
      spp=spp_fd,
      epsilon=epsilon,
      channel=channel,
      q=quantile,
  )
  fd_images[i] = fd_grad
  fd_vlims[i] = fd_vlim

In [None]:
# Figure setup

def figure_grid_setup(
    fig_width,
    image_shape,
    image_crop_shape,
    inner_hspace=0.0,
    inner_wspace=0.0,
    outer_space=0.1,
):
  """Create the figure and nested gridspecs for the gradient comparison.

  The outer gridspec is a single column with two rows: a 2x3 inner grid of
  cropped gradient images on top, and a row sized for ``image_shape`` below.
  Row heights follow the image aspect ratios, and the figure height is
  derived from ``fig_width``.

  Args:
    fig_width: figure width in inches.
    image_shape: ``(h, w)`` of the image shown in the bottom row.
    image_crop_shape: ``(h, w)`` of each image in the inner grid.
    inner_hspace: spacing of the inner grid, used both vertically and
      horizontally.
    inner_wspace: overridden by ``inner_hspace``; kept for backward
      compatibility of the signature.
    outer_space: vertical spacing between the inner grid and the bottom row.

  Returns:
    ``(fig, outer_gs, inner_gs, inner_rows, inner_cols)``.
  """
  # Image aspect ratios
  h, w = image_shape
  h_crop, w_crop = image_crop_shape
  r = w / h

  # The inner grid uses one spacing value in both directions, taken from
  # inner_hspace (the inner_wspace argument is intentionally overridden).
  inner_wspace = inner_hspace
  inner_rows = 2
  inner_cols = 3
  inner_height_ratios = [h_crop] * inner_rows
  inner_width_ratios = [w_crop] * inner_cols
  inner_aspect = figutils.gridspec_aspect(
      n_rows=inner_rows,
      n_cols=inner_cols,
      w=inner_width_ratios,
      h=inner_height_ratios,
      wspace=inner_wspace,
      hspace=inner_hspace,
  )

  # Spacing in the outer gridspec (single column). With a unit column width,
  # each row's relative height is the inverse of its aspect ratio
  # (presumably width / height, as returned by figutils.gridspec_aspect).
  outer_rows = 2
  outer_cols = 1
  outer_wspace = 0
  outer_hspace = outer_space
  outer_height_ratios = [1 / inner_aspect, 1 / r]
  outer_aspect = figutils.gridspec_aspect(
      n_rows=outer_rows,
      n_cols=outer_cols,
      w=[1],
      h=outer_height_ratios,
      wspace=outer_wspace,
      hspace=outer_hspace,
  )
  # NOTE(review): empirical correction factor from the original tuning.
  outer_aspect *= 0.98

  fig_size = (fig_width, fig_width / outer_aspect)
  # plt.figure(1) returns an already-open figure 1 (e.g. with a non-inline
  # backend) and then ignores figsize; clear it and set the size explicitly
  # so re-running the cell does not draw on top of stale axes.
  fig = plt.figure(1, figsize=fig_size, clear=True)
  fig.set_size_inches(fig_size)

  outer_gs = fig.add_gridspec(
      outer_rows,
      outer_cols,
      hspace=outer_hspace,
      wspace=outer_wspace,
      height_ratios=outer_height_ratios,
      width_ratios=[1],
  )

  inner_gs = gridspec.GridSpecFromSubplotSpec(
      inner_rows,
      inner_cols,
      subplot_spec=outer_gs[0],
      wspace=inner_wspace,
      hspace=inner_hspace,
      width_ratios=inner_width_ratios,
      height_ratios=inner_height_ratios,
  )
  return (
      fig,
      outer_gs,
      inner_gs,
      inner_rows,
      inner_cols,
  )

In [None]:
# Output location of the exported figure.
FIGURE_DIR = "figures/pdfs"
FIGURE_NAME = "volume_gradients"

# Crop windows (offset, size) passed to figutils.crop_image: one for the
# reference render and one shared by every gradient image.
ref_crop_offset, ref_crop_size = (26, 22), (739, 432)
grad_crop_offset, grad_crop_size = (18, 13), (363, 215)

cropped_ref_image = figutils.crop_image(
    ref_image, ref_crop_offset, ref_crop_size
)
cropped_forward_images = [
    figutils.crop_image(grad_img, grad_crop_offset, grad_crop_size)
    for grad_img in forward_images
]
cropped_fd_images = [
    figutils.crop_image(fd_img, grad_crop_offset, grad_crop_size)
    for fd_img in fd_images
]

# Column titles of the gradient grid; the last entry labels the reference row.
titles = [
    "Albedo",
    "Extinction",
    "Phase function (HG)",
    r"\textsc{Dragon}",
]
row_titles = ["Our forward", "Finite diff."]

# Parameter annotations shown under the reference image. These are fixed
# strings, not read from the scene.
params_label = [
    r"$\rho=[0.2, 0.4, 0.95]$",
    r"$\sigma_t=[15,7,5]$",
    r"$g=[-0.5, -0.2, 0.1]$",
]

# NOTE(review): the layout uses the uncropped reference shape although the
# cropped reference image is what gets displayed; confirm this is intended.
fig, outer_gs, inner_gs, inner_rows, inner_cols = figure_grid_setup(
    fig_width=figutils.COLUMN_WIDTH,
    image_shape=ref_image.shape[:2],
    image_crop_shape=cropped_forward_images[0].shape[:2],
    inner_hspace=0.02,
    inner_wspace=0.02,
    outer_space=0.0,
)

title_pad = 3
label_pad = 2
line_width = 0.75

strengths = [4.5, 4.5, 4.5]

# Gradient grid: row 0 shows our forward-mode gradients, row 1 the finite
# differences. Both rows use the forward-gradient limit of their column
# (scaled by `strengths`), presumably so the two estimates share one scale.
row_sources = {0: cropped_forward_images, 1: cropped_fd_images}
for row in range(inner_rows):
  for col in range(inner_cols):
    ax = fig.add_subplot(inner_gs[row, col])
    figutils.disable_ticks(ax)
    if col == 0:
      ax.set_ylabel(row_titles[row], labelpad=label_pad)
    if row == 0:
      ax.set_title(titles[col], pad=title_pad)
    if row in row_sources:
      limit = forward_vlims[col] * strengths[col]
      im = ax.imshow(
          row_sources[row][col],
          cmap=cm.coolwarm,
          vmin=-limit,
          vmax=limit,
      )

# Reference render in the bottom row, labelled with the parameter annotations.
ax = fig.add_subplot(outer_gs[1])
im = ax.imshow(cropped_ref_image)
ax.set_ylabel(titles[-1], labelpad=label_pad)
ax.set_xlabel(
    r"\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;".join(params_label),
    labelpad=label_pad,
)
figutils.disable_ticks(ax)

# Hidden axes spanning the gradient grid; they anchor a thin inset on their
# right-hand side that holds the colorbar.
ax = fig.add_subplot(outer_gs[0])
ax.axis('off')
inset = inset_axes(
    ax,
    width="1%",
    height="99.19%",
    loc="lower left",
    bbox_to_anchor=(1.005, 0.005, 1, 1),
    bbox_transform=ax.transAxes,
    borderpad=0,
)

# The colorbar spans +/- the largest forward-gradient limit and only labels
# the sign ("neg"/"pos") at its ends.
colorbar_limit = np.max(forward_vlims)
colorbar_mappable = cm.ScalarMappable(
    norm=Normalize(vmin=-colorbar_limit, vmax=colorbar_limit),
    cmap=cm.coolwarm,
)
cbar = fig.colorbar(colorbar_mappable, cax=inset, orientation="vertical")
cbar.ax.yaxis.set_label_position("left")
cbar.set_label("Parameter gradient", labelpad=-9)
cbar.set_ticks([cbar.vmin, cbar.vmax])
cbar.set_ticklabels(["neg", "pos"])
cbar.ax.tick_params(pad=1.8, length=0)

figutils.force_post_crop_size(fig, figutils.COLUMN_WIDTH)

In [None]:
# Export the finished figure at column width to FIGURE_DIR/FIGURE_NAME.
figutils.savefig(
    fig,
    name=Path(FIGURE_NAME),
    fig_directory=Path(FIGURE_DIR),
    target_width=figutils.COLUMN_WIDTH,
    dpi=300,
    bbox_inches="tight",
    pad_inches=0.005,
    compress=False,
    backend=None,
)