Copyright 2018 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import functools
import jax.experimental.optimizers
import time
import flax
import flax.linen as nn
from typing import Sequence, Callable
from IPython.display import clear_output
import cv2
import imageio
import mediapy as media

import os



In [None]:
def linear_to_srgb(linear):
  """Apply the sRGB transfer curve to linear values (assumed in [0, 1]).

  Piecewise: 12.92 * x at or below 0.0031308, else 1.055 * x**(1/2.4) - 0.055,
  written with exact rational constants. See https://en.wikipedia.org/wiki/SRGB.
  """
  tiny = jnp.finfo(jnp.float32).eps
  linear_segment = 323 / 25 * linear
  # Clamp before the fractional power so the unused branch never sees x <= 0.
  gamma_segment = (211 * jnp.maximum(tiny, linear)**(5 / 12) - 11) / 200
  return jnp.where(linear <= 0.0031308, linear_segment, gamma_segment)

def read_envmap(filename):
  """Read an OpenEXR environment map from `filename` using imageio's 'exr' format."""
  with open(filename, 'rb') as stream:
    envmap = imageio.imread(stream, 'exr')
  return envmap

# Load the ground-truth environment map (linear radiance, OpenEXR).
# NOTE(review): DIRECTORY is not defined in the visible notebook; presumably set elsewhere.
#envmap_linear = read_envmap(f'{DIRECTORY}/ninomaru_teien_4k.exr')
#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_4k.exr')
#envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_4k.exr')
#envmap_linear = read_envmap(f'{DIRECTORY}/spruit_sunrise_50x99.exr')
envmap_linear = read_envmap(f'{DIRECTORY}/hotel_room_50x99.exr')

envmap_linear = np.fliplr(envmap_linear)  # Blender flips this for some reason
#envmap_linear = np.roll(envmap_linear, envmap_linear.shape[1]//2, axis=1)
# sRGB-encoded copy; in this cell it is only used for display.
envmap_srgb = linear_to_srgb(envmap_linear)

plt.imshow(envmap_srgb)

In [None]:
# Envmap grid used for rendering/optimization: envmap_H rows x (2 * envmap_H - 1) columns.
envmap_H = 50
envmap_W = envmap_H * 2 - 1
#envmap_H, envmap_W = envmap_linear.shape[:2]

# The loaded envmap is used as-is, so it must already be envmap_H x envmap_W
# (the file loaded above is hotel_room_50x99.exr). The commented lines resize instead.
envmap_gt = envmap_linear
#envmap_gt = linear_to_srgb(cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA))
#envmap_gt = cv2.resize(envmap_linear, dsize=(envmap_W, envmap_H), interpolation=cv2.INTER_AREA)
plt.imshow(envmap_gt)
plt.axis('off')

In [None]:
#with open(f'{DIRECTORY}/hotel_room_{envmap_H}x{envmap_W}.exr', 'wb') as f:
#  imageio.imsave(f, envmap_gt, 'exr')


In [None]:
#envmap_H, envmap_W = envmap_linear.shape[:2]
#omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1],
#                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1])
# Quadrature grid at texel centers: phi spans (-pi, pi) along envmap columns and
# theta spans (0, pi) along rows. The half-cell offsets move samples from texel
# corners to texel centers.
omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),
                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1] +       jnp.pi / (2.0 * envmap_H))

# Area of one (theta, phi) cell; sum(sin(theta)) * dtheta_dphi approximates 4*pi.
dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])

omega_theta = omega_theta.flatten()
omega_phi = omega_phi.flatten()

# Unit direction per texel, flattened row-major to [envmap_H * envmap_W, 3].
omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)
omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)
omega_z = jnp.cos(omega_theta)
omega_xyz = jnp.stack([omega_x,
                       omega_y,
                       omega_z], axis=-1)



In [None]:
def mse_to_psnr(mse):
  """Return PSNR in dB for a mean squared error, taking the peak pixel value as 1."""
  db_per_nat = -10. / jnp.log(10.)
  return db_per_nat * jnp.log(mse)
  
def get_rays(H, W, focal, c2w, rand_ort=False, key=None):
  """Generate one pinhole-camera ray per pixel.

  Pixel (row i, column j) maps to the camera-space direction
  ((j - W/2) / focal, -(i - H/2) / focal, -1): the camera looks down its local
  -z axis with +y up. Directions are rotated into world space by c2w[:3, :3] and
  are NOT normalized.

  Args:
    H, W: image height and width in pixels.
    focal: focal length in pixels.
    c2w: 4x4 camera-to-world matrix.
    rand_ort: if True, jitter every pixel position uniformly by [-0.5, 0.5).
    key: JAX PRNG key; required when rand_ort is True.

  Returns:
    (rays_o, rays_d): two arrays of shape [H, W, 3].

  Raises:
    ValueError: if rand_ort is True and no key is given.
  """
  j, i = jnp.meshgrid(jnp.arange(W, dtype=jnp.float32),
                      jnp.arange(H, dtype=jnp.float32))

  if rand_ort:
    if key is None:
      raise ValueError('get_rays: `key` is required when rand_ort=True.')
    k1, k2 = jax.random.split(key)
    i += jax.random.uniform(k1, shape=(H, W)) - 0.5
    j += jax.random.uniform(k2, shape=(H, W)) - 0.5

  dirs = jnp.stack([ (j.flatten()-0.5*W)/focal,
                    -(i.flatten()-0.5*H)/focal,
                    -jnp.ones((H*W,), dtype=jnp.float32)], -1)  # shape [HW, 3]

  rays_d = dirs @ c2w[:3, :3].T  # shape [HW, 3]
  # Every ray starts at the camera center (translation column of c2w).
  rays_o = c2w[:3,-1:].T.repeat(H*W, 0)
  return rays_o.reshape(H, W, 3), rays_d.reshape(H, W, 3)


In [None]:
def parse_bin(s):
  """Parse a binary-fraction string such as '.011' into a float (here 0.375)."""
  return int(s[1:], 2) / 2.**(len(s) - 1)


def phi2(i):
  """Base-2 radical inverse (van der Corput sequence) of the non-negative integer i."""
  return parse_bin('.' + f'{i:b}'[::-1])

def nice_uniform(N):
  """2D Hammersley point set: u_k = k / N and v_k = phi2(k) for k in [0, N).

  Returns:
    Two Python lists (u, v), each of length N.
  """
  u = [k / float(N) for k in range(N)]
  v = [phi2(k) for k in range(N)]
  return u, v

def nice_uniform_spherical(N, hemisphere=True):
  """Low-discrepancy, area-uniform directions on the unit hemisphere or sphere.

  Implementation of http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html
  For uniform area density, cos(theta) must be uniform on [0, 1] for the
  hemisphere (theta = arccos(1 - u)) or on [-1, 1] for the full sphere
  (theta = arccos(1 - 2u)). Doubling the hemisphere angle instead, as an earlier
  version did, over-samples the poles.

  Args:
    N: number of directions.
    hemisphere: True for the z >= 0 hemisphere, False for the full sphere.

  Returns:
    (theta, phi): NumPy arrays of polar and azimuthal angles, each of length N.
  """
  u, v = nice_uniform(N)

  z_span = 1.0 if hemisphere else 2.0
  theta = np.arccos(1.0 - z_span * np.array(u))
  phi   = 2.0 * np.pi * np.array(v)

  return theta, phi
    
# Camera placement for get_all_camera_rays: upper hemisphere only (True) or the full sphere (False).
hemisphere = True
def get_all_camera_rays(N_cameras, camera_dist, H, W, focal):
  """Generate rays for N_cameras cameras placed on a sphere of radius camera_dist.

  Camera centers come from nice_uniform_spherical(N_cameras, hemisphere), using
  the module-level `hemisphere` flag. Each camera's local +z axis points away
  from the origin; because get_rays shoots rays along local -z, every camera
  looks at the origin. The local +y (up) axis is world +z projected onto the
  image plane, with a small nudge so cameras on the z axis still get a valid frame.

  Returns:
    (rays_o, rays_d): arrays of shape [N_cameras, H, W, 3].
  """
  theta, phi = nice_uniform_spherical(N_cameras, hemisphere)

  camera_x_vec = np.sin(theta) * np.cos(phi)
  camera_y_vec = np.sin(theta) * np.sin(phi)
  camera_z_vec = np.cos(theta)

  rays_o_vec = []
  rays_d_vec = []
  for i in range(N_cameras):
    position = np.array([camera_x_vec[i], camera_y_vec[i], camera_z_vec[i]])

    zdir = position / np.linalg.norm(position)

    # Up vector: world +z with its component along zdir removed.
    ydir = np.array([0.0, 0.0, 1.0])
    ydir -= zdir * zdir.dot(ydir)
    ydir[0] += 1e-10  # make sure that cameras pointing straight down/up have a defined ydir
    ydir /= np.linalg.norm(ydir)

    # Rebuild an exactly orthonormal right-handed frame. Without this, the nudge
    # above can leave ydir noticeably non-orthogonal to zdir near the poles.
    xdir = np.cross(ydir, zdir)
    xdir /= np.linalg.norm(xdir)
    ydir = np.cross(zdir, xdir)

    camera = np.eye(4)
    camera[:3, 0] = xdir
    camera[:3, 1] = ydir
    camera[:3, 2] = zdir
    camera[:3, 3] = position * camera_dist

    rays_o, rays_d = get_rays(H, W, focal, camera)

    rays_o_vec.append(rays_o)
    rays_d_vec.append(rays_d)

  rays_o_vec = jnp.stack(rays_o_vec, 0)
  rays_d_vec = jnp.stack(rays_d_vec, 0)

  return rays_o_vec, rays_d_vec

In [None]:
def render_pixel(normal, lobe, envmap, mask):
  """Integrate masked incident radiance against one pixel's reflectance lobe.

  Discretizes (1/pi) * sum over texels of envmap * mask * lobe * sin(theta) * dtheta * dphi,
  using the module-level omega_theta grid and dtheta_dphi cell area.
  `normal` is not used here; the caller is expected to have folded it into `lobe`.

  Returns:
    RGB radiance, shape [3].
  """
  visible_radiance = envmap * mask[:, :, None]
  solid_angle_weight = jnp.sin(omega_theta).reshape(envmap_H, envmap_W, 1)
  integrand = visible_radiance * lobe * solid_angle_weight
  return integrand.sum(0).sum(0) * dtheta_dphi / jnp.pi


def render(envmap, mask, materials, normals, rays_d, alpha, shading='lambertian'):
  """Render pixel colors by integrating the masked envmap against per-pixel BRDF lobes.

  Uses the module-level omega_xyz / envmap_H / envmap_W grid, and render_pixel for
  the spherical quadrature.

  envmap:     shape [h, w, 3]
  mask:       shape [h, w] (shared by all N pixels in this call)
  materials:  dictionary of per-pixel arrays with leading dim N; 'albedo' ([N, 3]) is
              always used, 'specular_albedo' / 'specular_exponent' for 'phong' and
              'blinnphong'
  normals:    shape [N, 3]
  rays_d:     shape [N, 3]
  alpha:      shape [N, 1]
  shading:    'lambertian', 'phong' or 'blinnphong'

  output: rendered colors, shape [N, 3]
  """
  
  assert shading in ['lambertian', 'phong', 'blinnphong']

  if shading in ['lambertian', 'phong', 'blinnphong']:
    # Diffuse lobe: albedo * max(0, n . omega) for every envmap direction.
    # TODO: Feed in only pixels where alpha = 1
    lobes = jnp.maximum(0.0, (omega_xyz.reshape(1, envmap_H, envmap_W, 3) * normals[:, None, None, :]).sum(-1, keepdims=True)) * materials['albedo'][:, None, None, :]  # [N, envmap_H, envmap_W, 3]

  if shading == 'blinnphong':
    assert 'specular_albedo' in materials.keys()
    specular_albedo = materials['specular_albedo'][:, None, None, :]
    exponent = materials['specular_exponent'][:, None, None, :]

    # Unit vector from the surface toward the camera (negated ray direction).
    d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)
    rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10)

    # Half vector between each light direction and the view direction.
    halfvectors = omega_xyz.reshape(1, envmap_H, envmap_W, 3) + rays_d_norm[:, None, None, :]
    halfvectors /= (jnp.linalg.norm(halfvectors, axis=-1, keepdims=True) + 1e-10)  # [N, envmap_H, envmap_W, 3]

    lobes += jnp.maximum(0.0, (halfvectors * normals[:, None, None, :]).sum(-1, keepdims=True)) ** exponent * specular_albedo

  if shading == 'phong':
    assert 'specular_albedo' in materials.keys()
    specular_albedo = materials['specular_albedo'][:, None, None, :]
    exponent = materials['specular_exponent'][:, None, None, :]

    d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)
    rays_d_norm = -rays_d / jnp.sqrt(d_norm_sq + 1e-10)  # [N, 3]

    # Mirror reflection of the view direction about the normal: 2(n.v)n - v.
    refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm   # [N, 3]
    refdirs = refdirs[:, None, None, :]

    # No need to normalize because ||n|| = 1 and ||d|| = 1, so ||2(n.d)n - d|| = 1.
    print("Not normalizing here (because unnecessary, at least theoretically).")
    #refdirs /= (jnp.linalg.norm(refdirs, axis=-1, keepdims=True) + 1e-10)  # [N, HW, envmap_H, envmap_W, 3]

    lobes += jnp.maximum(0.0, (refdirs * omega_xyz.reshape(1, envmap_H, envmap_W, 3)).sum(-1, keepdims=True)) ** exponent * specular_albedo
     
  # Integrate each pixel's lobe against the (shared) masked envmap.
  colors = jax.vmap(render_pixel, in_axes=(0, 0, None, None))(normals, lobes, envmap, mask)     
  
  # Pixels with alpha = 0 render black.
  return colors * alpha

In [None]:
def load_img(pth: str, is_16bit: bool=False) -> np.ndarray:
  """Read the image at `pth` (opened through utils.open_file) as a float32 array.

  Args:
    pth: path accepted by utils.open_file.
    is_16bit: decode with OpenCV using IMREAD_UNCHANGED (keeping bit depth and
      channel count as stored) instead of PIL.

  Returns:
    The decoded pixel values cast to float32, without rescaling.
  """
  with utils.open_file(pth, 'rb') as handle:
    if not is_16bit:
      return np.array(Image.open(handle), dtype=np.float32)
    encoded = np.asarray(bytearray(handle.read()), dtype=np.uint8)
    decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    return np.array(decoded, dtype=np.float32)

#disp = load_img(os.path.join(data_dir, 'test', 'r_0_disp.tiff'), is_16bit=True)[:, :, :1] / 255.0
#plt.imshow(1/disp-1)

In [None]:
"""
disp = load_img('{DIRECTORY}/r_4_disp_0029.tif', True) / 65535.0
depth = 1.0 / disp[:, :, 0] - 1.0
plt.imshow(depth)
print(np.nanmin(depth), np.sqrt(4.0 ** 2 + 0.5 ** 2) - 1.0)
""";

In [None]:
# Blender dataset configuration and loader.
# NOTE(review): `configs` and `Blender` are not imported in the visible notebook; presumably defined elsewhere.
Config = configs.Config()
Config.dataset_loader = 'Blender'
# NOTE(review): near (6) > far (2) looks swapped (NeRF Blender setups typically use
# near=2, far=6). Confirm how the loader uses these before changing them.
Config.near = 6
Config.far = 2
Config.factor = 1
Config.disp_tiff = True

# Force loading disparities and normals
Config.compute_disp_metrics = True
Config.compute_normal_metrics = True
Config.semantic_dir = None
# NOTE(review): queue and json are not used in the visible code; jax and os are already imported above.
import queue
import jax
import json
import os

LOCAL_COLMAP_DIR = '/tmp/colmap/'

# NOTE(review): data_dir is a plain string, not an f-string, so '{DATA_DIRECTORY}' is
# used literally. It looks like a placeholder to be filled in.
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_srgb_128x128'
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_srgb_128x128'; Config.disp_tiff = False
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_linear_128x128'
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_occlusions_uniform_linear_128x128'
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_farther_occlusions_uniform_linear_128x128'
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_farfield_occluder_uniform_linear_128x128'
#data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_occlusions_lambertian_new_uniform_linear_128x128'
data_dir = '{DATA_DIRECTORY}/nerf/nerf_synthetic/hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'
#materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.38}

#data_loader = Blender('test', data_dir, Config)
data_loader = Blender('occlusions', data_dir, Config)
#data_loader._next_fn = data_loader._next_test
#data_loader._next_fn = data_loader._next_test

In [None]:
# Pull every camera from the loader once and collect per-pixel ground truth.
# NOTE(review): assumes each next(data_loader) yields exactly one full image, in
# camera order. The reshape to [N_cameras, H*W, C] in the next cell relies on this.
imgs_gt = []
normals_gt = []
disps_gt = []
alpha_gt = []
rays_o_ = []
rays_d_ = []

N_cameras = data_loader.size
for i in range(N_cameras):
  batch = next(data_loader)
  imgs_gt.append(batch.rgb)
  normals_gt.append(batch.normals)
  alpha_gt.append(batch.alphas)
  disps_gt.append(batch.disps)
  rays_o_.append(batch.rays.origins)
  rays_d_.append(batch.rays.directions)


In [None]:
# Stack per-camera arrays and flatten pixels to [N_cameras, H*W, C].
imgs_gt = jnp.stack(imgs_gt, axis=0).reshape(N_cameras, -1, 3)
normals_gt = jnp.stack(normals_gt, axis=0).reshape(N_cameras, -1, 3)
# Binarize alpha: only (near-)opaque pixels count as surface.
alpha_gt = jnp.float32(jnp.stack(alpha_gt, axis=0).reshape(N_cameras, -1, 1) > 0.99)
disps_gt = jnp.stack(disps_gt, axis=0)[..., :1].reshape(N_cameras, -1, 1)  # keep one disparity channel
rays_o_vec = jnp.stack(rays_o_, axis=0).reshape(N_cameras, -1, 3)
rays_d_vec = jnp.stack(rays_d_, axis=0).reshape(N_cameras, -1, 3)

# NOTE(review): ray distance to the surface recovered as 1/disp - 1; the -1 offset
# presumably matches how the disparity TIFFs were encoded. Confirm.
t_surface_gt = 1.0 / disps_gt - 1.0
# Ground-truth materials: constant albedo 0.15 for every pixel and channel.
materials_gt = {'albedo': jnp.ones_like(imgs_gt)*0.15}


In [None]:
#pts = rays_o_vec + rays_d_vec * t_surface_gt

#normals_gt = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)



In [None]:
# Image resolution, taken from the loader's first image.
H, W = data_loader.images[0].shape[:2]
print(f"There are {imgs_gt.shape[0]} images of size {H}x{W}")

In [None]:
# Diagnostic views for camera `ind`: RGB, normals, surface depth on opaque pixels, alpha.
ind = 10
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()
plt.imshow(normals_gt[ind].reshape(H, W, 3) * 0.5 + 0.5)
plt.figure()
# Depth and alpha must come from the same camera (this previously used t_surface_gt[0]).
plt.imshow(np.where(alpha_gt[ind] < 0.99, np.nan, t_surface_gt[ind]).reshape(H, W, 1))
plt.figure()
plt.imshow(alpha_gt[ind].reshape(H, W, 1))


In [None]:
# Diagnostic views for camera 0: RGB, normals (mapped to [0, 1]), surface depth, alpha.
plt.imshow(imgs_gt[0].reshape(H, W, 3))
plt.figure()
plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)
plt.figure()
plt.imshow(t_surface_gt[0].reshape(H, W, 1))
plt.figure()
plt.imshow(alpha_gt[0].reshape(H, W, 1))


In [None]:
# Notebook display of the stacked ray-direction array's shape.
rays_d_vec.shape

In [None]:

rays_d_r = rays_d_vec.reshape(-1, H, W, 3)

# Fraction of the sphere's solid angle covered by each circular occluder: the cap
# {omega : omega . c >= th} has solid angle 2*pi*(1 - th) = 4*pi*occluder_relative_size.
occluder_relative_size = 0.1 #0.03
th = 1.0 - 2 * occluder_relative_size
# Make masks. Per camera, sdf = th + omega . d, where d is the center-pixel ray
# direction. The mask (sdf > 0) is 0 inside the cap around -d, i.e. toward the
# camera as seen from the object (assuming d is unit length; TODO confirm).
mask_shape = 'circle'
if mask_shape == 'circle' or mask_shape == 'two_circles' or mask_shape == 'three_circles':
  sdfs_gt = th - jnp.sum(-omega_xyz[None, :, :] * rays_d_r[:, H//2, W//2, :][:, None, :], axis=-1)
if mask_shape == 'two_circles' or mask_shape == 'three_circles':
  camera_dirs = rays_d_r[:, H//2, W//2, :]
  up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs
  up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)
  rot_dirs = jnp.cross(up_dirs, camera_dirs, axis=-1) # rotation axis = up direction x camera direction

  # Rodrigues rotation of camera_dirs by rotation_angle about rot_dirs.
  rotation_angle = jnp.pi / 7
  dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))
  
  sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1))  # Second occluder: same cap, centered on -dir2 (rotated away from the camera)
if mask_shape == 'three_circles':
  camera_dirs = rays_d_r[:, H//2, W//2, :]
  up_dirs = rays_d_r[:, H//2+1, W//2, :] - camera_dirs
  up_dirs = up_dirs / jnp.linalg.norm(up_dirs, axis=-1, keepdims=True)
  rot_dirs = up_dirs # direction of rotation is up direction

  rotation_angle = jnp.pi / 6
  dir2 = camera_dirs * jnp.cos(rotation_angle) + jnp.cross(rot_dirs, camera_dirs, axis=-1) * jnp.sin(rotation_angle) + rot_dirs * (rot_dirs * camera_dirs).sum(-1, keepdims=True) * (1 - jnp.cos(rotation_angle))
  sdfs_gt = jnp.minimum(sdfs_gt, th - jnp.sum(-omega_xyz[None, :, :] * dir2[:, None, :], axis=-1))  # Third occluder, centered on -dir2 (rotated about the up direction)

# Union of occluders via min(); mask = 1 where the envmap is visible.
sdfs_gt = sdfs_gt.reshape(-1, envmap_H, envmap_W)
masks_gt = jnp.float32(sdfs_gt > 0.0)



In [None]:
shading = 'lambertian'
exposure = 1.0

# Re-render the whole dataset from the ground-truth envmap, occluder masks and
# materials. The sRGB result replaces the loaded images as training targets.
num_devices = jax.local_device_count()
imgs_gt = []

render_gt_partial = functools.partial(render, shading=shading)
# Cameras are split across local devices; the envmap is broadcast (in_axes=None).
render_gt_pmapped = jax.pmap(render_gt_partial, in_axes=(None, 0, 0, 0, 0, 0))  # here pmap is faster
rays_d_flat = rays_d_vec.reshape(-1, H*W, 3)
# Step through cameras in chunks of num_devices. The last chunk may be smaller,
# since pmap accepts any leading size up to the device count. Previously the
# trailing N_cameras % num_devices cameras were silently dropped.
for i0 in range(0, N_cameras, num_devices):
  i1 = min(i0 + num_devices, N_cameras)
  imgs_gt_ = render_gt_pmapped(
      envmap_gt * exposure,
      masks_gt[i0:i1],
      jax.tree.map(lambda x: x[i0:i1], materials_gt),
      normals_gt[i0:i1],
      rays_d_flat[i0:i1],
      alpha_gt[i0:i1],
      )
  imgs_gt.append(imgs_gt_)

imgs_gt = linear_to_srgb(jnp.concatenate(imgs_gt, axis=0))

print(jnp.nanmax(imgs_gt))


In [None]:
# Inspect one re-rendered (sRGB) training image.
ind = 50
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))

In [None]:
def mfbrdf_map(viewdir, normal, albedo, roughness, eps=1e-15):
  """Evaluate a microfacet BRDF for every envmap direction in the module-level omega_xyz.

  Specular term D * F * V: a GGX-shaped distribution D, Schlick Fresnel F with
  F_0 = 0.04, and a height-correlated Smith visibility V. It is summed with a
  diffuse term (1 - F) * albedo / pi. The cosine factor n.l is NOT multiplied in.

  Args:
    viewdir: direction toward the viewer, shape [3] (assumed unit length; TODO confirm).
    normal: surface normal, shape [3] (assumed unit length).
    albedo: diffuse albedo (must broadcast against [envmap_H * envmap_W]).
    roughness: roughness parameter; `a = roughness**2` is used inside D and V.
    eps: guards the half-vector normalization against division by zero.

  Returns:
    BRDF values per envmap direction, shape [envmap_H * envmap_W] for scalar albedo.
  """
  half_vecs = omega_xyz + viewdir[None, :]
  half_vecs /= (jnp.linalg.norm(half_vecs, axis=-1, keepdims=True) + eps)

  # abs() and +1e-5 keep n.v strictly positive for the visibility term.
  n_dot_v = jnp.abs(jnp.sum(viewdir * normal)) + 1e-5
  n_dot_l = jnp.maximum(jnp.sum(omega_xyz * normal[None, :], axis=-1), 0.0)
  n_dot_h = jnp.maximum(jnp.sum(normal[None, :] * half_vecs, axis=-1), 0.0)
  l_dot_h = jnp.maximum(jnp.sum(omega_xyz * half_vecs, axis=-1), 0.0)

  print(n_dot_v.shape, n_dot_l.shape)

  F_0 = 0.04
  a = roughness**2

  D = a / (jnp.pi * ((a - 1.0) * n_dot_h ** 2 + 1.)**2)
  F = F_0 + (1. - F_0) * jnp.power(1. - l_dot_h, 5)
  #V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((n_dot_v * (1 - a) * n_dot_v + a))))
  V = 0.5 / ((n_dot_v * jnp.sqrt((-1. * n_dot_l * a + n_dot_l) * n_dot_l + a)) + (n_dot_l * jnp.sqrt((-1. * n_dot_v * a + n_dot_v) * n_dot_v + a)))
  brdf = D * F * V

  print(brdf.shape)
  # Diffuse energy is what the Fresnel term does not reflect specularly.
  brdf = brdf + (1. - F) * albedo / jnp.pi
  #brdf = jnp.reshape(brdf, [mapres[0], mapres[1], 3])
  return brdf

# Preview the BRDF over the envmap grid: view +z, normal tilted 45 degrees toward +y,
# albedo 0.7, roughness 0.7.
brdf = mfbrdf_map(jnp.array([0.0, 0.0, 1.0]), jnp.array([0.0, 1.0, 1.0])/jnp.sqrt(2), 0.7, 0.7)
plt.imshow(brdf.reshape(envmap_H, envmap_W))
plt.colorbar()

In [None]:
# Optimization setup: shading model for the optimized renderer, and which quantities
# get_loss takes from ground truth instead of optimizing.
num_devices = jax.local_device_count()

shading = 'lambertian'
exposure = 1.0

render_partial = functools.partial(render, shading=shading)

# get_loss checks for the entries 'envmap', 'materials', 'masks' and 'sdfs'.
gt_list = []
#gt_list = ['materials']
#gt_list = ['masks', 'materials']
#gt_list = ['sdfs', 'materials']
#gt_list = ['envmap', 'materials']
#gt_list = ['envmap', 'masks']
#gt_list = ['masks']
#gt_list = ['masks', 'materials', 'envmap']

class MLP(nn.Module):
  """Fully connected network: ReLU after every layer except the last, which is linear.

  `features` lists the width of each Dense layer; its last entry is the output size.
  When `y` is given, it is concatenated to `x` along the last axis before the first layer.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x, y=None):
    h = x if y is None else jnp.concatenate([x, y], axis=-1)
    *hidden_widths, out_width = self.features
    for width in hidden_widths:
      h = nn.relu(nn.Dense(width)(h))
    return nn.Dense(out_width)(h)

def get_pyramid_params(rng, num_pyramids, height, width, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias):
  """Randomly initialize multi-scale (pyramid) parameters.

  Level k has shape [num_pyramids, height // s**k, width // s**k] with
  s = pyramid_resize_scale. Entries are Normal(global_bias, global_std**2), and
  each level draws from a fresh split of `rng`.

  Returns:
    A list of pyramid_num_scales arrays, finest level first.
  """
  levels = []
  for level in range(pyramid_num_scales):
    factor = pyramid_resize_scale ** level
    level_shape = (num_pyramids, height // factor, width // factor)
    sample_key, rng = jax.random.split(rng)
    noise = jax.random.normal(sample_key, level_shape)
    levels.append(noise * global_std + global_bias)
  return levels


def pyramids_to_imgs(pyramid_params, pyramid_mult, img_inds):
  """Collapse the selected pyramids into full-resolution images.

  Starting from the coarsest level, the running image is scaled by pyramid_mult,
  bilinearly resized to the next finer level, and added to it.

  Args:
    pyramid_params: list of levels, finest first, each [num_pyramids, h_k, w_k].
    pyramid_mult: multiplier applied before every upsampling step.
    img_inds: indices of the pyramids to collapse.

  Returns:
    Array of shape [len(img_inds), h_0, w_0].
  """
  def collapse(levels, mult):
    coarse_to_fine = levels[::-1]
    img = coarse_to_fine[0]
    for finer in coarse_to_fine[1:]:
      img = jax.image.resize(img * mult, shape=finer.shape, method='linear') + finer
    return img

  selected = [level[img_inds] for level in pyramid_params]
  return jax.vmap(collapse, in_axes=(0, None))(selected, pyramid_mult)


def grad_norm_spherical(dirs, grad):
  """Norm of the part of a Cartesian gradient that is tangent to the unit sphere.

  The radial component (grad . dirs) * dirs is removed, leaving the gradient
  restricted to the sphere. A 1e-5 term inside the square root keeps the result
  differentiable at zero.

  dirs: (..., 3) unit directions on the sphere
  grad: (..., 3) Cartesian gradients at those directions (broadcastable with dirs)
  returns: (...,) tangential gradient norms
  """
  radial = (dirs * grad).sum(axis=-1, keepdims=True)
  tangential = grad - radial * dirs
  return jnp.sqrt((tangential ** 2).sum(-1) + 1e-5)



# Per-pixel rays flattened to [N_cameras, H*W, 3], gathered by get_loss.
# (This rebinds rays_d_r, which the mask cell used with shape [N_cameras, H, W, 3].)
rays_o_r = rays_o_vec.reshape(-1, H*W, 3)
rays_d_r = rays_d_vec.reshape(-1, H*W, 3)



append_identity = True  # if True, posenc prepends the raw input to its features
def posenc(x, L_encoding):
  """Lipschitz-normalized sinusoidal positional encoding.

  For each frequency 2**l (l < L_encoding), emits sin(2**l x) / 2**l and
  cos(2**l x) / 2**l. Dividing by the frequency bounds each feature's slope by 1.
  Features are ordered [..., l, d, (sin, cos)] and flattened into the last axis,
  giving 2 * L_encoding * D values (plus D more when append_identity is set).
  With L_encoding <= 0, `x` is returned unchanged.
  """
  if L_encoding <= 0:
    return x
  freqs = 2**jnp.arange(L_encoding)
  out_shape = x.shape[:-1] + (-1,)
  phases = x[..., None, :] * freqs[:, None]  # [..., L, D]
  sin_cos = jnp.sin(
      jnp.stack([phases, phases + 0.5 * jnp.pi], axis=-1))  # [..., L, D, 2]
  features = jnp.reshape(sin_cos / freqs[:, None, None], out_shape)
  print("Using Lipschitz posenc")
  return jnp.concatenate([x, features], axis=-1) if append_identity else features


def params_to_sdf(params_sdf, img_inds, return_grad=False):
  """Evaluate the occluder SDF on the envmap grid for the cameras in `img_inds`.

  The module-level `sdf_representation` selects the parameterization:
    'mlp':  sdf_mlp evaluated at posenc(direction), conditioned on the camera
            origin; returns a flat [len(img_inds) * envmap_H * envmap_W] array.
    'grid': per-camera grids indexed directly -> [B, 1, envmap_H, envmap_W].
    other ('pyramid'): pyramids_to_imgs -> [B, envmap_H, envmap_W].
  Callers reshape the result to [B, envmap_H * envmap_W].

  Args:
    params_sdf: SDF parameters for the active representation.
    img_inds: camera indices, shape [B].
    return_grad: if True, return (sdf, sdf_grad). sdf_grad is the gradient of the
      MLP output w.r.t. its direction features ([B * envmap_H * envmap_W, F]) for
      'mlp', and None otherwise. The default keeps the original single return value.
  """
  sdf_grad = None
  if sdf_representation == 'mlp':
    num_imgs = img_inds.shape[0]
    num_dirs = envmap_H * envmap_W
    # The first pixel's ray origin serves as the camera origin. Flattening handles
    # rays_o_vec stored as [N, H*W, 3] (dataset cell) or [N, H, W, 3] (get_all_camera_rays).
    # The previous [img_inds, 0, 0, :] indexing failed on the 3-D layout.
    cam_origins = rays_o_vec[img_inds].reshape(num_imgs, -1, 3)[:, 0, :]
    dir_features = posenc(omega_xyz[None, :, :], L_encoding_sdf).repeat(num_imgs, 0)
    # Use the encoding's own width. The global mlp_input_features is overwritten
    # later by the material-MLP setup.
    dir_features = dir_features.reshape(-1, dir_features.shape[-1])
    origins = cam_origins[:, None, :].repeat(num_dirs, 1).reshape(-1, 3)
    sdf, sdf_grad = jax.vmap(jax.value_and_grad(lambda x, y: sdf_mlp.apply(params_sdf, x, y)[0]))(
        dir_features,
        origins,
        )
  elif sdf_representation == 'grid':
    sdf = params_sdf[img_inds[:, None], ...]
  else:
    sdf = pyramids_to_imgs(params_sdf, pyramid_mult, img_inds)
  if return_grad:
    return sdf, sdf_grad
  return sdf


def sdf_to_mask(x, width, curve='sigmoid'):
  """Map signed distances to soft masks in (0, 1): ~1 where x > 0, ~0 where x < 0.

  Args:
    x: signed distance values.
    width: sharpness; larger values give a steeper transition at x = 0.
    curve: 'sigmoid' (logistic CDF) or 'laplace_cdf' (Laplace CDF).

  Raises:
    NotImplementedError: for any other `curve`.
  """
  if curve not in ('sigmoid', 'laplace_cdf'):
    raise NotImplementedError('Only sigmoid and laplace_cdf for now.')
  if curve == 'sigmoid':
    return jax.nn.sigmoid(x * width)
  decay = jnp.exp(-jnp.abs(x) * width)
  return 0.5 + 0.5 * jnp.sign(x) * (1.0 - decay)
  
def params_to_materials(params_materials, pts):
  """Decode material-MLP outputs at 3D points into a dictionary of BRDF parameters.

  'albedo' (3 channels) is a sigmoid of outputs 0:3. For 'phong' / 'blinnphong'
  shading, softplus-activated 'specular_albedo' and 'specular_exponent' follow.
  The specular albedo has 1 channel when the module-level is_dielectric is set,
  and 3 channels otherwise.
  """
  raw = material_mlp.apply(params_materials, posenc(pts, L_encoding_materials))
  materials = {'albedo': jax.nn.sigmoid(raw[..., 0:3])}
  if shading not in ['phong', 'blinnphong']:
    return materials
  spec_end = 4 if is_dielectric else 6
  materials['specular_albedo'] = jax.nn.softplus(raw[..., 3:spec_end])
  materials['specular_exponent'] = jax.nn.softplus(raw[..., spec_end:spec_end + 1])
  return materials

# Rotation by -90 degrees about the z axis (maps +x to -y and +y to +x).
# NOTE(review): not referenced anywhere in the visible code.
R = jnp.array([[ 0.0, 1.0, 0.0],
               [-1.0, 0.0, 0.0],
               [ 0.0, 0.0, 1.0]])

@jax.jit
def get_loss(params_envmap, params_sdf, params_materials, gt, spatial_inds, img_inds, i, rng):
  """Training loss for one device's batch of sampled (image, pixel) pairs.

  Quantities named in the module-level gt_list ('envmap', 'materials', 'masks',
  'sdfs') are read from ground truth; everything else is decoded from the
  parameters. The batch is rendered with render_partial and compared in sRGB
  against `gt`.

  Args:
    params_envmap, params_sdf, params_materials: optimizable parameters.
    gt: target sRGB images, [N_cameras, H*W, 3].
    spatial_inds: sampled pixel indices, [B, S]; row b belongs to image img_inds[b].
    img_inds: camera indices, [B].
    i: iteration number (only read by the disabled annealing branch).
    rng: PRNG key (currently unused).

  Returns:
    (loss, (data_loss, eikonal_loss)).
  """
  if 'envmap' in gt_list:
    envmap = envmap_gt * exposure
  else:
    envmap = params_to_envmap(params_envmap)

  # Gather per-pixel geometry for every sampled (image, pixel) pair: [B, S, C].
  normals = normals_gt[img_inds[:, None], spatial_inds, ...]
  t_surface = t_surface_gt[img_inds[:, None], spatial_inds, ...]
  alpha = alpha_gt[img_inds[:, None], spatial_inds, ...]
  rays_d = rays_d_r[img_inds[:, None], spatial_inds, ...]
  rays_o = rays_o_r[img_inds[:, None], spatial_inds, ...]

  if 'materials' in gt_list:
    # Same (image, pixel) gather as above. Indexing spatial_inds by camera id
    # (spatial_inds[img_inds]) picks the wrong rows, and JAX clamps out-of-range
    # gather indices silently instead of raising.
    materials = jax.tree.map(lambda x: x[img_inds[:, None], spatial_inds, ...], materials_gt)
  else:
    # Query the material MLP at the 3D surface points.
    pts = rays_o + rays_d * t_surface
    materials = params_to_materials(params_materials, pts)

  sdf = params_to_sdf(params_sdf, img_inds)
  sdf = sdf.reshape(img_inds.shape[0], envmap_H * envmap_W)

  sdf_curve = 'sigmoid'

  if 'masks' in gt_list or 'sdfs' in gt_list:
    if 'masks' in gt_list:
      masks = masks_gt[img_inds]
    else:
      masks = sdf_to_mask(sdfs_gt[img_inds], mask_width, sdf_curve)
    # Occluders are not being optimized, so their regularizers are zero.
    eikonal_loss = 0.0
    length_loss = 0.0
    mask_area_loss = 0.0
  else:
    if straight_through_mode == 'hard':
      # Forward value: binary mask (sdf > 0); gradient: sigmoid's.
      masks = sdf_to_mask(sdf, mask_width, sdf_curve)
      masks = (masks + jax.lax.stop_gradient(jnp.float32(sdf > 0.0) - masks))
    elif straight_through_mode == 'soft':
      print("I think this might actually be 'none' straight_through_mode instead of 'soft' because mask_width is 0.1")
      soft_masks = sdf_to_mask(sdf, 0.1, sdf_curve)
      hard_masks = sdf_to_mask(sdf, mask_width, sdf_curve)
      # Define masks with value of `hard_masks` but gradients of `soft_masks`
      masks = (soft_masks + jax.lax.stop_gradient(hard_masks - soft_masks))
    elif straight_through_mode == 'none':
      masks = sdf_to_mask(sdf, mask_width, sdf_curve)

    if sdf_representation == 'mlp':
      # NOTE(review): sdf_grad is never assigned before this read. params_to_sdf
      # computes it but returns only the SDF, so this branch raises
      # UnboundLocalError when traced with sdf_representation == 'mlp'.
      # TODO: Only compute grad w.r.t. x, not w.r.t. other posenc components
      sdf_grad = sdf_grad[..., :3].reshape(img_inds.shape[0], envmap_H * envmap_W, 3)
      sdf_grad_norm = grad_norm_spherical(omega_xyz, sdf_grad.reshape(img_inds.shape[0], envmap_H * envmap_W, 3))
      eikonal_loss = (jnp.sin(omega_theta) * (sdf_grad_norm - 1) ** 2).sum() * dtheta_dphi
    else:
      eikonal_loss = 0.0

    # Binary entropy of sigmoid(sdf); only used by the disabled regularizer below.
    entropy_loss = jax.nn.softplus(-sdf) + sdf * (1.0 - jax.nn.sigmoid(sdf))
    mask_area_loss = 1.0 - masks  # Try to make the occluder as small as possible

  res = jax.vmap(render_partial, in_axes=(None, 0, 0, 0, 0, 0))(
      envmap,
      masks.reshape(img_inds.shape[0], envmap_H, envmap_W),
      materials,
      normals,
      rays_d,
      alpha,
      )

  # Compare in sRGB space.
  diff = gt[img_inds[:, None], spatial_inds, :] - linear_to_srgb(res)
  if False:
    # Disabled: graduated non-convexity, with p annealed from 2 down to 0.5.
    p = 2 - 1.5 * i / num_iters
    data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)
    print("Using graduated nonconvexity in the loss")
  else:
    data_loss = (jnp.abs(diff) ** 2).sum()
    print("Using L2 loss")

  # Average over sampled pixels (summed over color channels).
  data_loss = data_loss / img_inds.shape[0] / spatial_inds.shape[-1]

  loss = data_loss
  loss += 0.1 * eikonal_loss / img_inds.shape[0]
  # Mask-area prior: solid-angle-weighted occluded fraction of the sphere, per image.
  loss += 1e-5 * (mask_area_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi
  #loss += 1e-1 * ((jnp.abs(params_envmap) ** 2) * ell.flatten()[:, None, None]).mean()
  #loss += 0.01 * (entropy_loss * jnp.sin(omega_theta)).sum() * dtheta_dphi / img_inds.shape[0] / 4.0 / jnp.pi
  #loss += 0.01 * length_loss / img_inds.shape[0]

  # Environment map TV loss
  loss += 1e-7 * (((envmap[:, 1:] - envmap[:, :-1]) ** 2).sum() + ((envmap[1:, :] - envmap[:-1, :]) ** 2).sum())  * dtheta_dphi / 4.0 / jnp.pi

  return loss, (data_loss, eikonal_loss)

def safe_exp(x):
  """Exponential with the input clamped from above at 80, so float32 results stay finite."""
  clamped = jnp.minimum(x, 80.0)
  return jnp.exp(clamped)

def tonemap_and_clip(x):
  """Encode linear radiance as sRGB and clamp it to the displayable [0, 1] range."""
  srgb = linear_to_srgb(x)
  return np.clip(srgb, 0.0, 1.0)

def params_to_envmap(params_envmap):
  """Map raw envmap parameters to non-negative radiance.

  'SH' representation: per-channel inverse spherical-harmonic transform (isht),
  then softplus. Any other representation: the parameters are log-radiance and go
  through safe_exp.
  NOTE(review): isht is not defined in the visible code.
  """
  #envmap = jax.nn.sigmoid(params_envmap)
  #envmap = jax.nn.softplus(params_envmap)
  if envmap_representation != 'SH':
    return safe_exp(params_envmap)
  #params_envmap = jnp.where(ell.flatten()[:, None, None] < 5, params_envmap, 0.0)
  return jax.nn.softplus(jax.vmap(isht, in_axes=-1, out_axes=-1)(params_envmap))


@jax.jit
def update_params(i, rng, state_envmap, state_sdf, state_materials, gt, spatial_inds, img_inds):
  """Run one optimizer step on the envmap, SDF and material parameters.

  Must run inside jax.pmap(..., axis_name='batch'). Gradients and loss values are
  averaged across devices with pmean before the update_* optimizer functions apply them.

  Returns:
    (new_state_envmap, new_state_sdf, new_state_materials, loss, data_loss, eikonal_loss)
  """
  loss_and_grads = jax.value_and_grad(get_loss, argnums=(0, 1, 2), has_aux=True)
  (loss, (data_loss, eikonal_loss)), grads = loss_and_grads(
      get_params_envmap(state_envmap),
      get_params_sdf(state_sdf),
      get_params_materials(state_materials),
      gt, spatial_inds, img_inds, i, rng)

  def device_mean(value):
    return jax.lax.pmean(value, axis_name='batch')

  grad_envmap, grad_sdf, grad_materials = (device_mean(g) for g in grads)
  loss, data_loss, eikonal_loss = (device_mean(v) for v in (loss, data_loss, eikonal_loss))

  return (update_envmap(i, grad_envmap, state_envmap),
          update_sdf(i, grad_sdf, state_sdf),
          update_materials(i, grad_materials, state_materials),
          loss, data_loss, eikonal_loss)


#slow_optimization_mode = ('masks' not in gt_list and 'sdfs' not in gt_list) or 'materials' not in gt_list
#if slow_optimization_mode:
#  print("Slow optimization")
#else:
#  print("Fast optimization")

num_iters = 150000 #50000 if slow_optimization_mode else 10000
#num_iters = 50000
# How get_loss builds optimized masks:
#   'hard': forward value is the binary mask (sdf > 0), gradient from the width-mask_width sigmoid;
#   'soft': forward value from the width-mask_width sigmoid, gradient from a width-0.1 sigmoid;
#   'none': plain width-mask_width sigmoid.
straight_through_mode = 'soft'
assert straight_through_mode in ['none', 'hard', 'soft']

# TODO:
# 1. It looks like using straight-through on the masks (with constant width 10) makes them be a little off,
#    but improves the envmap (making it a little noisier because of the bad masks). Why?
# 2. 

# Envmap parameterization ('SH' or 'direct') and its Adam optimizer.
# NOTE(review): jax.experimental.optimizers moved to jax.example_libraries.optimizers
# in later JAX releases, while jax.tree.map (used elsewhere here) only exists in newer
# releases. Confirm which JAX version this notebook targets.
envmap_representation = 'direct'
if envmap_representation == 'SH':
  # NOTE(review): `ell` is not defined in the visible code.
  #params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3)) - 0.5) * 0.01
  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_H, 3, 2)) - 0.5) / (1 + ell.flatten()[:, None, None, None]) * 0.01
  params_envmap = params_envmap[..., 0] + 1j * params_envmap[..., 1]
  init_lr_envmap = 0.0003 #if slow_optimization_mode else 0.001

elif envmap_representation == 'direct':
  # Per-texel log-radiance in [-0.05, 0.05), so the initial radiance (safe_exp) is ~1.
  params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_W, 3)) - 0.5) * 0.1
  #print("TODO: Get rid of this annoying -4.0")

  init_lr_envmap = 0.003 #if slow_optimization_mode else 0.03
else:
  raise ValueError('')
  
init_envmap, update_envmap, get_params_envmap = jax.experimental.optimizers.adam(init_lr_envmap)
state_envmap = init_envmap(params_envmap)


sdf_representation = 'pyramid'  # one of 'pyramid', 'mlp', 'grid'

if sdf_representation == 'mlp':
  L_encoding_sdf = 0 # 4
  mlp_input_features = 3 + 6 * L_encoding_sdf

  # SDF MLP: input = posenc(direction) concatenated with the camera origin (3 values).
  sdf_mlp = MLP([128]*4 + [1])

  params_sdf = sdf_mlp.init(jax.random.PRNGKey(0),
                            np.zeros([1, mlp_input_features]),
                            np.zeros([1, 3]))
  init_lr_sdf = 0.0001
elif sdf_representation == 'pyramid':
  # One pyramid per camera: 5 levels, each half the resolution of the previous.
  # Coarser levels are scaled by pyramid_mult when collapsed (pyramids_to_imgs).
  pyramid_num_scales = 5
  pyramid_resize_scale = 2
  pyramid_mult = 2.0
  global_std = 1.0 #0.1
  # Positive bias: the initial SDF is mostly positive, i.e. mostly unoccluded masks.
  global_bias = 4.0

  rng = jax.random.PRNGKey(0)
  params_sdf = get_pyramid_params(rng, N_cameras, envmap_H, envmap_W, pyramid_num_scales, pyramid_resize_scale, global_std, global_bias)
  init_lr_sdf = 0.1 #* 100
  # NOTE(review): mask_width is only set in this branch, but get_loss reads it for
  # every representation. 'mlp' and 'grid' would hit a NameError.
  mask_width = 0.1

elif sdf_representation == 'grid':
  params_sdf = jax.random.normal(jax.random.PRNGKey(111), shape=(N_cameras, envmap_H, envmap_W))
  init_lr_sdf = 0.01
else:
  raise ValueError('')

#for _ in range(20):
#  print("TODO: Use xmanager to optimize: learning rates, biases, global std for params_sdf in pyramid mode, number of iterations (longer!), etc.")
#  print("TODO: Find out what happens if for masks we use the ground truth SDFs passed through a 0.1 sigmoid, instead of the GT masks. This is the best case scenario when using such a soft sigmoid!")
#  print("TODO: Replace initialization as soft ~0.5ish masks and dark envmap with good init. of envmap and ~1 masks (no occluders). Currently things just go there anyway...")

init_sdf, update_sdf, get_params_sdf = jax.experimental.optimizers.adam(init_lr_sdf)
state_sdf = init_sdf(params_sdf)

# Initialize material MLP
# Material MLP: posenc(surface point) -> albedo, plus specular params for phong/blinnphong.
L_encoding_materials = 4
# NOTE(review): this rebinds the global mlp_input_features, which the 'mlp' SDF setup
# also defines.
mlp_input_features = 3 + 6 * L_encoding_materials

is_dielectric = True
num_components = 3  # 3 for diffuse
if shading in ['phong', 'blinnphong']:
  if is_dielectric:
    num_components += 2  # 1 for specular albedo, 1 for exponent
  else:
    num_components += 4  # 3 for specular albedo, 1 for exponent
material_mlp = MLP([128]*4 + [num_components])

params_materials = material_mlp.init(jax.random.PRNGKey(0),
                          np.zeros([1, mlp_input_features]))
init_lr_materials = 0.003
init_materials, update_materials, get_params_materials = jax.experimental.optimizers.adam(init_lr_materials)
state_materials = init_materials(params_materials)


# Training-loop state: RNGs, batch sizes, metric logs, and device-replicated state.
np_rng = np.random.default_rng(12345)
jax_rng = jax.random.PRNGKey(3948)

spatial_batch_size = 64  # pixels sampled per image per step

# NOTE(review): the training loop reshapes img_inds to (num_devices, -1), so
# img_batch_size must be a multiple of jax.local_device_count().
img_batch_size = N_cameras #1024
#img_batch_size = 64
losses = []
data_losses = []
eikonal_losses = []
envmap_psnrs = []
tonemapped_envmap_psnrs = []
envmaps = []

t = 0.0
# NOTE(review): ProgressBar is not defined or imported in the visible code.
training_progress_bar = ProgressBar()
training_progress_bar.Publish()

# Copy optimizer states and target images to every local device for pmap.
replicated_state_envmap = flax.jax_utils.replicate(state_envmap)
replicated_state_sdf = flax.jax_utils.replicate(state_sdf)
replicated_state_materials = flax.jax_utils.replicate(state_materials)
replicated_imgs_gt = flax.jax_utils.replicate(imgs_gt)


for iteration in range(num_iters):
  #t0 = time.time()
  # Generate B1 image indices
  img_inds = np_rng.choice(imgs_gt.shape[0], size=img_batch_size, replace=False) #* 0
  # Now generate B2 pixel indices for each image. The total batch size is B1 * B2.

  # Split off one key per image; keys[0] becomes the carried PRNG state.
  keys = jax.random.split(jax_rng, num=img_batch_size+1)
  jax_rng, keys = keys[0], keys[1:]
  # Pixels are drawn without replacement with probability proportional to
  # alpha, so only foreground pixels are sampled (asserted below).
  spatial_inds = jax.vmap(jax.random.choice, in_axes=(0, None, None, None, 0))(keys, H*W, (spatial_batch_size,), False, alpha_gt[img_inds, :, 0])

  assert jnp.all(alpha_gt[img_inds[:, None], spatial_inds, :] > 0.99)

  # Data-parallel update: optimizer states, images and index batches are
  # sharded over the leading device axis; iteration and key are broadcast.
  replicated_state_envmap, replicated_state_sdf, replicated_state_materials, loss, data_loss, eikonal_loss = jax.pmap(update_params, in_axes=(None, None, 0, 0, 0, 0, 0, 0), axis_name='batch')(
      iteration,
      jax_rng,
      replicated_state_envmap,
      replicated_state_sdf,
      replicated_state_materials,
      replicated_imgs_gt,
      spatial_inds.reshape(num_devices, -1, spatial_batch_size),
      img_inds.reshape(num_devices, -1)
      )

  # Every 100 iterations (and on the last one), snapshot the learned envmap
  # and its PSNR against the exposure-scaled GT. The MSE is weighted by the
  # solid angle sin(theta) dtheta dphi, normalized by 4 pi and averaged over RGB.
  if iteration % 100 == 0 or iteration == num_iters - 1:
    envmap = params_to_envmap(get_params_envmap(replicated_state_envmap)[0])
    envmaps.append(envmap)
    mse = (jnp.sin(omega_theta)[:, None] * (exposure * envmap_gt - envmap).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0
    envmap_psnrs.append(mse_to_psnr(mse))
    mse = (jnp.sin(omega_theta)[:, None] * (tonemap_and_clip(exposure * envmap_gt) - tonemap_and_clip(envmap)).reshape(-1, 3) ** 2).sum() * dtheta_dphi / 4.0 / jnp.pi / 3.0
    tonemapped_envmap_psnrs.append(mse_to_psnr(mse))


  # Record device 0's losses (presumably reduced across devices inside
  # update_params via axis_name='batch' -- confirm).
  losses.append(loss[0])
  data_losses.append(data_loss[0])
  eikonal_losses.append(eikonal_loss[0])

  if iteration in [500] or iteration % 1000 == 0 and iteration > 0:
    clear_output(wait=True)
    training_progress_bar.Publish()
    training_progress_bar.SetProgress(100.0 * (iteration + 1) / num_iters)

    plt.figure()
    plt.semilogy(data_losses)
    plt.semilogy(eikonal_losses)

    # Plot envmaps
    plt.figure(figsize=[16, 8])
    plt.subplot(221)
    envmap = params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))
    plt.imshow(envmap)
    plt.axis('off')
    plt.subplot(222)
    plt.imshow(exposure * envmap_gt)
    plt.axis('off')
    plt.subplot(223)
    plt.imshow(linear_to_srgb(params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))
    plt.axis('off')
    plt.subplot(224)
    plt.imshow(linear_to_srgb(exposure * envmap_gt))
    plt.axis('off')

    # Plot PSNRs
    plt.figure(figsize=[8, 8])
    plt.subplot(121)
    plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), envmap_psnrs)
    plt.subplot(122)
    plt.plot(np.linspace(0, iteration, len(envmap_psnrs)), tonemapped_envmap_psnrs)

    # Plot masks: raw SDF, soft mask, and hard mask (sdf > 0) for a few cameras.
    num_rows = 6
    cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]
    sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))
    sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))

    plt.figure(figsize=[15, 12])
    for row, i in enumerate(cameras_to_plot):
      sdf = sdfs[row].reshape(envmap_H, envmap_W)
      for col, img_to_plot in enumerate([sdf, sdf_to_mask(sdf, mask_width), sdf > 0]):
        plt.subplot(num_rows, 3, row * 3 + col + 1)
        if img_to_plot.shape[-1] != 3:
          if col == 0:
            plt.imshow(img_to_plot, cmap='gray')
          else:
            plt.imshow(img_to_plot, cmap='gray', vmin=0.0, vmax=1.0)
        else:
          plt.imshow(img_to_plot)
        plt.axis('off')
        plt.title(f'{i}')


    # Plot materials predicted at the GT surface points of view `ind`.
    ind = 10
    pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]
    materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)
    for k in materials.keys():
      plt.figure()
      plt.imshow(materials[k].reshape(H, W, -1))
      plt.axis('off')
    
    # Plot rendered image and GT image
    sdf = params_to_sdf(sdf_params, jnp.array([ind]))
    mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)
    if 'envmap' in gt_list:
      envmap = exposure * envmap_gt
    if 'mask' in gt_list:
      # Keep a leading [1, ...] axis like the learned mask above, so mask[0]
      # below is the whole GT mask rather than its first row. masks_gt[ind] is
      # presumably [envmap_H, envmap_W] (it is plotted as a 2D image elsewhere).
      mask = masks_gt[ind][None]
    if 'materials' in gt_list:
      materials = jax.tree.map(lambda x: x[ind], materials_gt)
    res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))
    plt.figure()
    plt.subplot(121)
    plt.imshow(res)
    plt.axis('off')
    plt.subplot(122)
    plt.imshow(imgs_gt[ind].reshape(H, W, 3))
    plt.axis('off')
    plt.show()


In [None]:
# Learned vs. exposure-scaled GT envmap on a log scale (1e-5 avoids log(0)).
plt.imshow(jnp.log(1e-5 + params_to_envmap(get_params_envmap(flax.jax_utils.unreplicate(replicated_state_envmap)))))
plt.figure()
plt.imshow(jnp.log(1e-5 + exposure * envmap_gt))


In [None]:
# Render view `ind` with the learned mask, the (un-exposed) GT envmap and a
# constant 0.15 albedo; the learned materials only supply the albedo shape.
ind = 3
pts = rays_o_r[ind] + t_surface_gt[ind] * rays_d_r[ind]
materials = params_to_materials(get_params_materials(flax.jax_utils.unreplicate(replicated_state_materials)), pts)

sdf = params_to_sdf(sdf_params, jnp.array([ind]))
mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)

res = linear_to_srgb(render_partial(envmap_gt, mask[0], {'albedo': materials['albedo'] * 0.0 + 0.15}, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))
plt.imshow(res)
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))

In [None]:
# Debug render of view `ind`: occlusion disabled (mask of ones), constant
# 0.15 albedo, and *negated* view directions -- presumably to test the sign
# convention of rays_d_r against the GT image.
ind = 0

sdf = params_to_sdf(sdf_params, jnp.array([ind]))
mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)

#envmap_top = np.zeros_like(envmap_gt)
#envmap_top[0, :, :] = 1000
#envmap_top[25, 49, :] = 10000
res = render_partial(envmap_gt, mask[0]*0.0+1.0,
                                    {'albedo': materials['albedo'] * 0.0 + 0.15},
                                    normals_gt[ind], -rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3)
plt.imshow(linear_to_srgb(res), interpolation='nearest')
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))

In [None]:
# Azimuth (omega_phi) of the envmap texel at row 0, column 49.
omega_phi.reshape(envmap_H, envmap_W)[0, 49]

In [None]:
# x component of the envmap direction grid.
plt.imshow(omega_x.reshape(envmap_H, envmap_W))

In [None]:
# Row 64 of the last render and its dynamic range. The third print compares
# the max scaled by an empirical constant (origin not shown here) with
# pixel (18, 64, 0).
plt.plot(res[64])
print(jnp.nanmin(res) / jnp.nanmax(res))
print(jnp.nanmax(res) * 0.27823895, res[18, 64, 0])

In [None]:
# Alpha sum over row 18 of view 0, and alpha at pixel (18, 64).
print(alpha_gt[0].reshape(H, W)[18, :].sum())
print(alpha_gt[0].reshape(H, W)[18, 64])

In [None]:
# GT normal at pixel (18, 64) of view 0.
normals_gt[0].reshape(H, W, 3)[18, 64, :]

In [None]:
def jon(x, eps=1e-7):
  """Elementwise x / |x| that is forced to exactly 0 wherever x**2 < eps."""
  sq = jnp.square(x)
  ratio = x / jnp.sqrt(jnp.maximum(sq, eps))
  return jnp.where(sq < eps, jnp.zeros_like(ratio), ratio)

def dor(x, eps=1e-7):
  """Smooth sign: x / |x|, with |x| floored at sqrt(eps) so 0 maps to 0."""
  safe_abs = jnp.sqrt(jnp.maximum(x**2, eps))
  return x / safe_abs


# Compare near 0: jon snaps |x| < sqrt(eps) to 0, dor ramps linearly there.
x = jnp.linspace(-0.001, 0.001, 10000)
plt.plot(x, jon(x), x, dor(x))

In [None]:
ind = 0

# Check the image-plane axes from the ray directions of view `ind`.
print(rays_d_vec[ind].reshape(H, W, 3)[0, W//2, :])  # Top is x
print(rays_d_vec[ind].reshape(H, W, 3)[H//2, 1, :])  # Left is y

# Normals rotated by R.T, shifted by * 0.5 + 0.5 for display.
n = (normals_gt[ind].reshape(H, W, 3) @ R.T) * 0.5 + 0.5
plt.imshow(n)

In [None]:
# GT envmap, then the learned soft mask of camera 0.
plt.imshow(envmap_gt)
plt.figure()
sdf = params_to_sdf(sdf_params, jnp.array([0]))
mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)

plt.imshow(mask[0])

In [None]:
# GT mask of the last camera.
plt.imshow(masks_gt[-1])

In [None]:
# Per-row sum of (1 - mask) for the last camera's GT mask, i.e. presumably
# the number of occluded texels per envmap row.
#plt.figure(figsize=[12, 12])
#plt.plot(masks_gt[-1][34, :], '.')
plt.plot((1-masks_gt[-1]).sum(1), '.')

In [None]:
# Plot rendered image and GT image
sdf = params_to_sdf(sdf_params, jnp.array([ind]))
mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)
if 'envmap' in gt_list:
  envmap = exposure * envmap_gt
if 'mask' in gt_list:
  # Keep a leading [1, ...] axis like the learned mask above, so mask[0]
  # below is the whole GT mask rather than its first row. masks_gt[ind] is
  # presumably [envmap_H, envmap_W] (it is plotted as a 2D image elsewhere).
  mask = masks_gt[ind][None]
if 'materials' in gt_list:
  materials = jax.tree.map(lambda x: x[ind], materials_gt)
res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))
plt.figure()
plt.subplot(121)
plt.imshow(res)
plt.axis('off')
plt.subplot(122)
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.axis('off')
plt.show()

# Recompute this view's data loss by hand: summed squared error normalized
# by a 128 x 128 pixel count.
diff = imgs_gt[ind].reshape(H, W, 3) - res
#loss_per_element = (diff ** 2).sum(-1)
#loss_per_element = jnp.abs(diff).sum(-1)
if False:
  # Disabled graduated-nonconvexity variant; note `i` is not set by this cell.
  p = 2 - 1.5 * i / num_iters
  data_loss = jnp.power((jnp.abs(diff + 1e-10) ** p).sum(), 1/p)
  print("Using graduated nonconvexity in the loss")
else:
  data_loss = (jnp.abs(diff) ** 2).sum()
  print("Using L2 loss")

data_loss = data_loss / 1 / 128 / 128
print(data_loss)

In [None]:
# Signed error range of the last render vs. the GT image.
diff = res - imgs_gt[ind].reshape(H, W, 3)
diff.min(), diff.max()

In [None]:
# Mask used by the most recent render.
plt.imshow(mask[0])

In [None]:
 media.show_video(envmaps, height=300)

In [None]:
# Difference between the hard masks (sdf > 0) of the 3rd and 2nd plotted cameras.
num_rows = 6
cameras_to_plot = [int(x) for x in np.linspace(0, N_cameras-1, num_rows)]
sdf_params = get_params_sdf(flax.jax_utils.unreplicate(replicated_state_sdf))
sdfs = params_to_sdf(sdf_params, jnp.array(cameras_to_plot))
plt.imshow(jnp.float32(sdfs[2] > 0) - jnp.float32(sdfs[1] > 0))

In [None]:
# Hard mask (sdf > 0) of the 4th plotted camera.
plt.imshow(sdfs[3] > 0)

In [None]:
# Max GT value over fully opaque (alpha == 1) pixels.
jnp.where(alpha_gt == 1.0, imgs_gt, 0.0).max()

In [None]:
# Per-channel map of view 5 pixels whose GT value is exactly 1.
plt.imshow(np.float32(imgs_gt[5].reshape(H, W, 3) == 1))

In [None]:
# Animate the recorded envmap snapshots (smaller).
media.show_video(envmaps, height=200)

In [None]:
# Render view 11 with its learned mask and the un-exposed GT envmap, reusing
# whatever `materials` a previous cell left, next to the GT image.
ind = 11
sdf = params_to_sdf(sdf_params, jnp.array([ind]))
mask = sdf_to_mask(sdf, mask_width).reshape(1, envmap_H, envmap_W)
envmap = envmap_gt
res = linear_to_srgb(render_partial(envmap, mask[0], materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))
plt.figure()
plt.subplot(121)
plt.imshow(res)
plt.axis('off')
plt.subplot(122)
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.axis('off')
plt.show()


In [None]:
# Shape of the envmap direction grid.
omega_xyz.shape

In [None]:
ind  # Bare expression; not displayed because it is not the cell's last line.

# Synthetic occluder test with GT materials and the GT envmap. mask is False
# for directions w with (-w) . d_center >= 0.76649692, where d_center is the
# central pixel's ray, i.e. a cone around the direction back toward the camera.
materials = jax.tree.map(lambda x: x[ind], materials_gt)
envmap = envmap_gt

mask = jnp.sum(-omega_xyz * rays_d_vec[ind, H//2*W+W//2, :][None, :], axis=-1) < 0.76649692  # Occluder is aligned with the camera

res = linear_to_srgb(render_partial(envmap, mask.reshape(envmap_H, envmap_W), materials, normals_gt[ind], rays_d_r[ind], alpha_gt[ind]).reshape(H, W, 3))
#diff = res - imgs_gt[ind].reshape(H, W, 3)
plt.figure()
plt.imshow(res)
# The string below is disabled plotting code kept for reference.
"""
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()
err = np.abs(diff).sum(-1)
plt.imshow(err, cmap='gray')
plt.colorbar()
""";

In [None]:
# GT normals of view `ind`, shown raw as RGB.
plt.imshow(normals_gt[ind].reshape(H, W, 3))

In [None]:
# Render view `ind` one image row (W pixels) per call with GT materials and
# no occlusion (mask of ones), then show render, GT and the L1 error map.
envmap = envmap_gt

rows = []
for i in range(H):
  inds = np.arange(W) + i * W
  materials = jax.tree.map(lambda x: x[ind, inds], materials_gt)
  res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))
  rows.append(res)

res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)
diff = res - imgs_gt[ind].reshape(H, W, 3)
plt.figure()
plt.imshow(res)
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()
err = np.abs(diff).sum(-1)
plt.imshow(err, cmap='gray')
plt.colorbar()


In [None]:
# Same comparison as the row-wise cell, but one pixel per call and using
# envmap_linear (the loaded EXR) instead of envmap_gt.
envmap = envmap_linear

rows = []
for i in range(H):
  cols = []
  for j in range(W):
    inds = np.arange(1) + i * W + j
    materials = jax.tree.map(lambda x: x[ind, inds], materials_gt)
    res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, materials, normals_gt[ind, inds], rays_d_r[ind, inds], alpha_gt[ind, inds]))
    cols.append(res)
  row = jnp.concatenate(cols, axis=0)
  rows.append(row)

res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)
diff = res - imgs_gt[ind].reshape(H, W, 3)
plt.figure()
plt.imshow(res)
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()
err = np.abs(diff).sum(-1)
plt.imshow(err, cmap='gray')
plt.colorbar()


In [None]:
# Re-plot whatever rows have been rendered so far (reshape(-1, ...) tolerates
# an interrupted run). NOTE: `err` uses the `diff` left by an earlier cell,
# since its recomputation is commented out.
res = jnp.concatenate(rows, axis=0).reshape(-1, W, 3)
#diff = res - imgs_gt[ind].reshape(H, W, 3)
plt.figure()
plt.imshow(res)
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()
err = np.abs(diff).sum(-1)
plt.imshow(err, cmap='gray')
plt.colorbar()


In [None]:
# Albedo shape of `materials` as last set by a previous cell.
materials['albedo'].shape

In [None]:
# Error profile along image column 64.
plt.plot(err[:, 64])

In [None]:
def interp2d(grids, inds):
  """Bilinearly sample every channel of a 2D grid with wrap-around borders.

  Args:
    grids: [H, W, d] array.
    inds: [2, ...] coordinates; entry 0 is elevation (row), entry 1 azimuth
      (column).

  Returns:
    [..., d] array of sampled values.
  """
  def sample_channel(channel):
    return jax.scipy.ndimage.map_coordinates(channel, inds, order=1, mode='wrap')

  channels = [sample_channel(grids[:, :, c]) for c in range(grids.shape[-1])]
  return jnp.stack(channels, axis=-1)


def render_pixel_mirror(refdir, envmap, mask):
  """Nearest-texel lookup of the masked envmap in direction `refdir`.

  refdir: a 3-vector (x, y, z); theta is the angle from +z, phi = atan2(y, x).
  envmap: [envmap_H, envmap_W, 3]; mask: [envmap_H, envmap_W].
  Returns the [3] color of envmap * mask at the quantized (theta, phi).
  """
  x, y, z = refdir
  theta = jnp.arctan2(jnp.sqrt(x ** 2 + y ** 2), z)
  phi   = jnp.arctan2(y, x)
  # Quantize to get index
  # NOTE(review): theta uses floor while phi uses round. phi in [-pi, pi]
  # gives negative column indices, which wrap. theta == pi gives row
  # envmap_H, which JAX clamps. Confirm this matches envmap_gt's texel layout.
  theta_ind = jnp.floor(envmap_H * theta / jnp.pi).astype(jnp.int32)
  phi_ind = jnp.round(envmap_W * phi / 2.0 / jnp.pi).astype(jnp.int32)
  return (envmap * mask[:, :, None])[theta_ind, phi_ind]
  #return interp2d(envmap * mask[:, :, None], [theta*(envmap_H-1)/jnp.pi, phi*(envmap_W-1)/2/jnp.pi])

def render_mirror(envmap, mask, normals, rays_d, alpha, oxyz, rad, shape='sphere'):
  """Render a perfect mirror by looking up reflected directions in an envmap.

  envmap:     shape [h, w, 3]
  mask:       shape [h, w]      (multiplied into the envmap before lookup)
  normals:    shape [N, 3]
  rays_d:     shape [N, 3]      (normalized here; need not be unit length)
  alpha:      shape [N, 1]
  oxyz:       shape [1, 3]  (shape center; currently unused)
  rad:        float         (shape radius; currently unused)
  shape:      str           (currently unused)

  output: rendered colors, shape [N, 3]
  """
  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)
  rays_d_norm = rays_d / jnp.sqrt(d_norm_sq + 1e-10)  # [N, 3]

  # Reflect d about the normal axis: 2 (n . d) n - d.
  refdirs = 2.0 * (normals * rays_d_norm).sum(-1, keepdims=True) * normals - rays_d_norm   # [N, 3]
  colors = jax.vmap(render_pixel_mirror, in_axes=(0, None, None))(refdirs, envmap, mask)

  return colors * alpha

# Render view img_ind as a perfect mirror: view directions are negated and GT
# normals are rotated by R (a 90-degree rotation about z) before the lookup.
img_ind = 0
d = -rays_d_vec[img_ind] #* jnp.array([1.0, -1.0, 1.0])
R = jnp.array([[ 0.0, 1.0, 0.0],
               [-1.0, 0.0, 0.0],
               [ 0.0, 0.0, 1.0]])
n = normals_gt[img_ind] @ R

res = render_mirror(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0,
                    n, d, alpha_gt[img_ind], jnp.zeros((3,)), 1.0)
res_srgb = linear_to_srgb(res.reshape(H, W, 3))
plt.figure(figsize=[12, 12])
plt.imshow(res_srgb, interpolation='nearest')
plt.axis('off')

In [None]:
# GT normals of view 0, shifted by * 0.5 + 0.5 for display.
plt.imshow(normals_gt[0].reshape(H, W, 3) * 0.5 + 0.5)

In [None]:
# Direction of view 0's top-center pixel ray.
rays_d_vec[0].reshape(H, W, 3)[0, W//2, :]

In [None]:
# GT envmap.
plt.imshow(envmap_gt)

In [None]:
# Load the Lambertian reference render of view 0 as RGB in [0, 1]. The path
# must be an f-string; without the prefix it opened the literal '{DIRECTORY}'.
with open(f'{DIRECTORY}/r_0_lamb.png', 'rb') as f:
  img_lamb = np.array(Image.open(f))[:, :, :3] / 255.0


In [None]:
# Alpha of view `ind`.
plt.imshow(alpha_gt[ind].reshape(H, W))

In [None]:
ind = 20

# Load the Blender reference render of view `ind`, premultiplying RGB by alpha.
with open(f'{DIRECTORY}/r_{ind}_larger.png', 'rb') as f:
  res_gt = np.float32(Image.open(f)) / 255.0
  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :3] * alpha

#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)
#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5 # + 4.0
#inds = jnp.stack([jnp.mod(elevation, envmap_H), jnp.mod(azimuth, envmap_W)], axis=0)
#envmap = interp2d(envmap_gt, inds)
#print(envmap.shape)
envmap = envmap_gt

# Render row by row with no occlusion, constant 0.15 albedo and the
# reference's alpha.
rows = []
for i in range(H):
  inds = np.arange(W) + i * W
  materials = jax.tree.map(lambda x: x[ind, inds], materials_gt)  # NOTE(review): unused; albedo is overridden below.
  res = linear_to_srgb(render_partial(envmap, envmap[..., 0]*0.0+1.0, {'albedo': jnp.ones((W, 3))*0.15}, normals_gt[ind, inds], rays_d_r[ind, inds], alpha.reshape(-1, 1)[inds]))
  rows.append(res)

res = jnp.concatenate(rows, axis=0).reshape(H, W, 3)
plt.imshow(res)
plt.figure()

plt.imshow(res_gt)

# Channel-averaged absolute error map.
plt.figure()
diff = res - res_gt
plt.imshow(jnp.abs(diff).sum(-1)/3, cmap='gray')
plt.colorbar()

# Max absolute error inside a hand-picked window.
print(jnp.abs(diff[30:50, 80:90]).max())

In [None]:
# Full-image render of view 0 with normals rotated by R and negated view
# directions; the commented variant tries a flipped envmap with R.T.
ind = 0
res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)
#res = render(jnp.fliplr(envmap_gt), envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind] @ R.T, -rays_d_vec[ind], alpha_gt[ind]).reshape(H, W, 3)
# ???????????
plt.imshow(linear_to_srgb(res))
plt.figure()
plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.figure()

In [None]:
# Compare our render of view 20 (GT envmap, 0.15 albedo, reference alpha)
# with the '_envmap_nn' Blender render (presumably nearest-neighbor envmap
# filtering -- confirm).
ind = 20
#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:
with open(f'{DIRECTORY}/r_{ind}_envmap_nn.png', 'rb') as f:
  res_gt = np.float32(Image.open(f)) / 255.0
  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :3] * alpha
res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)
#print(res[60:70, 60:70, :])
res = linear_to_srgb(res)
plt.subplot(121)
plt.imshow(res)
plt.axis('off')
plt.subplot(122)
#plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.imshow(res_gt)
plt.axis('off')


In [None]:
# Reference / ours in linear space, halved for display.
#plt.imshow(3*(res - res_gt))
plt.imshow(0.5 * srgb_to_linear(res_gt) / srgb_to_linear(res))
plt.axis('off')

In [None]:
# Red-vs-green scatter of ours, then of the reference.
plt.scatter(res[:, :, 0], res[:, :, 1])
plt.figure()
plt.scatter(res_gt[:, :, 0], res_gt[:, :, 1])

In [None]:
# Signed RGB difference (cmap has no effect on 3-channel input).
plt.imshow(res - res_gt, cmap='gray')


In [None]:
# Channel-averaged signed difference.
plt.imshow((res - res_gt).sum(-1) / 3.0, cmap='gray')
plt.colorbar()

In [None]:
# Half-space lighting test: envmap is 1 where cos(phi) < 0 and 0 elsewhere,
# albedo 1, compared against the view-70 reference render.
ind = 70
#with open(f'{DATA_DIRECTORY}/nerf/nerf_synthetic/sphere_lowres_envmap_uniform_linear_128x128/test/r_{ind}.png', 'rb') as f:
with open(f'{DIRECTORY}/r_{ind}.png', 'rb') as f:
  res_gt = np.float32(Image.open(f)) / 255.0
  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :3] * alpha
res = render(envmap_gt*0.0 + jnp.where(jnp.cos(omega_phi.reshape(envmap_H, envmap_W, 1) + 0.0) < 0, 1.0, 0.0),
             envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*1.0}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)
#print(res[60:70, 60:70, :])
res = linear_to_srgb(res)
plt.subplot(121)
plt.imshow(res, vmin=0.0, vmax=1.0)
plt.axis('off')
plt.subplot(122)
#plt.imshow(imgs_gt[ind].reshape(H, W, 3))
plt.imshow(res_gt, vmin=0.0, vmax=1.0)
plt.axis('off')
plt.figure()
diff = res - res_gt
plt.imshow(-diff[:, :, 0], cmap='gray')
plt.colorbar()

In [None]:
# Minimum value in a central window, ours vs. reference.
res[70:90, 70:90, :].min(), res_gt[70:90, 70:90, :].min()

In [None]:
# Difference at pixel (60, 60).
diff[60, 60, :]

In [None]:
# Mirror render composited over a white background, next to `img` (loaded elsewhere).
a = alpha_gt[img_ind].reshape(H, W, 1)
plt.figure(figsize=[12, 12])
plt.subplot(121)
plt.imshow(res_srgb * a + (1.0 - a))
plt.axis('off')
plt.subplot(122)
plt.imshow(img)
plt.axis('off')

In [None]:
# Flip between the GGX-mirror reference and our mirror render.
media.show_video([img_ggx, res_srgb], fps=2, height=200)

In [None]:
# Load the true-mirror and GGX-mirror reference renders of view 0 as RGB in
# [0, 1]. Paths must be f-strings; without the prefix they opened the
# literal '{DIRECTORY}'.
#with open(f'{DIRECTORY}/r_0.png', 'rb') as f:
#  img = np.array(Image.open(f))[:, :, :3] / 255.0

with open(f'{DIRECTORY}/r_0_true_mirror.png', 'rb') as f:
  img_tm = np.array(Image.open(f))[:, :, :3] / 255.0

with open(f'{DIRECTORY}/r_0_ggx_mirror.png', 'rb') as f:
  img_ggx = np.array(Image.open(f))[:, :, :3] / 255.0


In [None]:
# log10 of the channel-averaged absolute error vs. `img`.
plt.imshow(jnp.log10(jnp.abs(img - res_srgb).sum(-1)/3), cmap='gray')
plt.colorbar()

In [None]:
# Strength: 0.5 -- center pixel of img_lamb for a render made with this
# light strength.
img_lamb[64, 64]

In [None]:
# Strength: 0.25 -- center pixel of img_lamb for a render made with this
# light strength.
img_lamb[64, 64]

In [None]:
# Strength: 0.2 -- center pixel of img_lamb for a render made with this
# light strength.
img_lamb[64, 64]

In [None]:
def linear_to_srgb(linear):
  """Encode linear values to sRGB, see https://en.wikipedia.org/wiki/SRGB.

  NumPy version. Like the JAX linear_to_srgb, the power branch is evaluated on
  max(eps, linear): zero/negative inputs select the linear branch anyway, and
  the clamp stops them from producing NaN RuntimeWarnings.
  """
  eps = np.finfo(np.float32).eps
  srgb0 = 323 / 25 * linear
  srgb1 = (211 * np.maximum(eps, linear)**(5 / 12) - 11) / 200
  return np.where(linear <= 0.0031308, srgb0, srgb1)


def srgb_to_linear(srgb):
  """Decode sRGB values to linear; inverse of linear_to_srgb.

  The branch threshold must be the sRGB image of the linear threshold,
  0.0031308 * 323 / 25 (about 0.04045). The old threshold used the inverse
  factor (about 0.00024), which sent small sRGB values through the power
  branch. The power base is clamped at eps so very negative inputs (which
  take the linear branch) do not emit NaN warnings.
  """
  eps = np.finfo(np.float32).eps
  linear0 = srgb * 25 / 323
  linear1 = np.maximum(eps, (200 * srgb + 11) / 211) ** (12 / 5)
  return np.where(srgb <= 0.0031308 * 323 / 25, linear0, linear1)

# Linear values of three sRGB-encoded pixel values.
srgb_to_linear(0.66666667), srgb_to_linear(0.48235294), srgb_to_linear(0.43529412)


In [None]:
def linear_to_srgb(linear):
  """Encode linear values to sRGB, see https://en.wikipedia.org/wiki/SRGB.

  NumPy version; the power branch is evaluated on max(eps, linear) so
  zero/negative inputs (which select the linear branch) do not emit NaN
  RuntimeWarnings.
  """
  eps = np.finfo(np.float32).eps
  srgb0 = 323 / 25 * linear
  srgb1 = (211 * np.maximum(eps, linear)**(5 / 12) - 11) / 200
  return np.where(linear <= 0.0031308, srgb0, srgb1)

# sRGB encoding of linear 0.5.
print(linear_to_srgb(0.5))

In [None]:
# Compare our render of view 0 against the linear EXR reference (RGB
# premultiplied by alpha), and show the summed absolute error.
ind = 0
with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:
  #res_gt = np.float32(Image.open(f)) / 255.0
  res_gt = imageio.imread(f, 'exr')

  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :3] * alpha

plt.imshow(res_gt, cmap='gray', interpolation='nearest')


res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)
res = res * alpha #res.repeat(3, 2) * alpha
plt.figure()
plt.imshow(res, cmap='gray', interpolation='nearest')

plt.figure()
plt.imshow(jnp.abs(res - res_gt).sum(-1))
#plt.imshow(res[..., 1] - res_gt[..., 1])
plt.colorbar()

In [None]:
# Show the EXR background render of view 0, then reproduce the background by
# interpolating the GT envmap at each camera ray's (azimuth, elevation).
ind = 0
with open(f'{DIRECTORY}/r_{ind}_bg.exr', 'rb') as f:
  #res_gt = np.float32(Image.open(f)) / 255.0
  res_gt = imageio.imread(f, 'exr')

  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :3] * alpha

res_gt_srgb = linear_to_srgb(res_gt)
plt.imshow(res_gt_srgb, interpolation='nearest')


#omega_phi, omega_theta
dirs = rays_d_vec[ind].reshape(H, W, 3)
#elevation = omega_theta.reshape(envmap_H, envmap_W) / jnp.pi * (envmap_H - 1 + 1)
#azimuth = omega_phi.reshape(envmap_H, envmap_W) / (2.0 * jnp.pi) * (envmap_W - 1) + 0.5
#inds = jnp.stack([elevation, azimuth], axis=0)
#env = interp2d(envmap_gt, inds)
# Spherical angles of each camera ray: azimuth atan2(y, x), elevation from +z.
dirs_azimuth = np.arctan2(dirs[..., 1], dirs[..., 0])
nr = np.sqrt(dirs[..., 1] ** 2 + dirs[..., 0] ** 2)
dirs_elevation = np.arctan2(nr, dirs[..., 2])
# NOTE(review): scipy.interpolate.interp2d is deprecated and removed in recent
# SciPy releases; porting to RegularGridInterpolator would change the
# out-of-range behavior, so confirm before switching. Older SciPy may also
# need an explicit `import scipy.interpolate`.
import scipy
plt.figure()
channels = []
for i in range(3):
  interp = scipy.interpolate.interp2d(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0], envmap_gt[:, :, i])
  #ch = interp(omega_phi.reshape(envmap_H, envmap_W)[0, :], omega_theta.reshape(envmap_H, envmap_W)[:, 0])
  # Evaluated on the grid of row-0 azimuths x column-0 elevations only.
  ch = interp(dirs_azimuth[0, :], dirs_elevation[:, 0])
  channels.append(ch)
#plt.imshow(interp(omega_phi, omega_theta).reshape(envmap_H, envmap_W))
plt.imshow(linear_to_srgb(jnp.stack(channels, axis=-1)))
#res = render(envmap_gt, envmap_gt[:, :, 0]*0.0+1.0, {'albedo': jnp.ones((H*W, 3))*0.15}, normals_gt[ind], rays_d_vec[ind], alpha.reshape(-1, 1)).reshape(H, W, 3)
#res = res * alpha #res.repeat(3, 2) * alpha
#plt.figure()
#plt.imshow(res, cmap='gray', interpolation='nearest')

#plt.figure()
#plt.imshow(jnp.abs(res - res_gt).sum(-1))
#plt.imshow(res[..., 1] - res_gt[..., 1])
#plt.colorbar()

In [None]:
# When shifting by half a pixel this is the error

In [None]:
# GT envmap.
plt.imshow(envmap_gt)

In [None]:
# Min / mean / max of our render.
res.min(), res.mean(), res.max()

In [None]:
# Min / mean / max of the reference render.
res_gt.min(), res_gt.mean(), res_gt.max()

In [None]:
# Flip between ours and the reference.
media.show_video([res, res_gt], fps=2, height=256)

In [None]:
# Dynamic range of the GT envmap.
envmap_gt.min(), envmap_gt.max()

In [None]:
# Dynamic range of the GT envmap (same as the previous cell).
envmap_gt.min(), envmap_gt.max()

In [None]:
# Single-channel check of view 1: compare the EXR render against a clamped
# cosine max(0, n . z) / pi, i.e. a Lambertian response to light along +z.
ind = 1
with open(f'{DIRECTORY}/r_{ind}.exr', 'rb') as f:
  #res_gt = np.float32(Image.open(f)) / 255.0
  res_gt = imageio.imread(f, 'exr')

  alpha = res_gt[:, :, 3:]
  res_gt = res_gt[:, :, :1] * alpha

plt.imshow(res_gt, cmap='gray')

res = jnp.maximum(0.0, (normals_gt[ind].reshape(H, W, 3) * jnp.array([0.0, 0.0, 1.0])[None, None, :]).sum(-1, keepdims=True)) / jnp.pi
res = res * alpha #res.repeat(3, 2) * alpha
plt.figure()
plt.imshow(res, cmap='gray')

plt.figure()
plt.imshow(res - res_gt)
plt.colorbar()

In [None]:
# Reference minus ours.
plt.imshow(res_gt - res)

In [None]:
# Largest normal length in view 0 (expected ~1 if the normals are unit length).
jnp.linalg.norm(normals_gt[0].reshape(H, W, 3), axis=-1).max()

In [None]:
def foo(rays_d, rays_o, rad=1.0):
  """Ray-trace a sphere of radius `rad` at the origin and plot its normals.

  Solves |o + t d|^2 = rad^2 for the nearer root t of each ray. The normals at
  the hit points are shown as an [H, W, 3] image (normal * 0.5 + 0.5).

  Args:
    rays_d: ray directions, [..., 3]; need not be unit length.
    rays_o: ray origins, broadcastable against `rays_d`.
    rad: sphere radius.
  """
  d_norm_sq = (rays_d ** 2).sum(-1, keepdims=True)
  o_norm_sq = (rays_o ** 2).sum(-1, keepdims=True)
  d_dot_o = (rays_o * rays_d).sum(-1, keepdims=True)
  # Quarter discriminant of |d|^2 t^2 + 2 (o . d) t + (|o|^2 - rad^2) = 0.
  disc = d_norm_sq * (rad ** 2  - o_norm_sq) + d_dot_o ** 2
  # Nearer root: (-(o . d) - sqrt(disc)) / |d|^2. Without the division this
  # was only correct for unit-length directions.
  t_surface = jnp.where(disc > 0, (- jnp.sqrt(disc) - d_dot_o) / d_norm_sq, jnp.inf)

  pts = rays_o + rays_d * t_surface

  normals = pts / jnp.linalg.norm(pts, axis=-1, keepdims=True)

  plt.imshow(normals.reshape(H, W, 3) * 0.5 + 0.5)

# Analytically traced sphere normals for view 111, next to the GT normals.
ind = 111
foo(rays_d_vec[ind], rays_o_vec[ind])
plt.figure()
#R = 
n = normals_gt[ind]
plt.imshow(n.reshape(H, W, 3) * 0.5 + 0.5)

In [None]:
# Ray direction at row 64, last column, of the last view.
rays_d_vec[-1].reshape(H, W, 3)[64, -1, :]

In [None]:
# GT normals of view 0 for display (rotation by R.T disabled).
n = normals_gt[0].reshape(H, W, 3) #@ R.T
plt.imshow(n * 0.5 + 0.5)

In [None]:
# Reproduce the top-center ray (pixel x = W/2, y = 0) of a hand-built camera
# 4 units along +z. The pinhole focal length comes from a 0.6911 rad field
# of view across the image width.
focal = .5 * W / np.tan(0.5 * 0.691111147403717)
pixtocams = camera_utils.get_pixtocam(focal, W, H)
c2w = jnp.array([[ 0.0, 1.0, 0.0, 0.0],
                 [-1.0, 0.0, 0.0, 0.0],
                 [ 0.0, 0.0, 1.0, 4.0],
                 [ 0.0, 0.0, 0.0, 1.0]])
origins, directions, _, _, _ = camera_utils.pixels_to_rays(
    jnp.array([W/2.0]),
    jnp.array([0]),
    pixtocams,
    c2w)
print(origins)
print(directions)  # 'up' is x

In [None]:
# Dataset ray at row 0, column 64 of view 0, for comparison.
rays_d_vec[0].reshape(H, W, 3)[0, 64, :]

In [None]:
# Surface hit points of all GT rays (presumably non-finite where a ray misses;
# non-finite points are filtered out before subsampling).
surface_pts = rays_o_vec + t_surface_gt * rays_d_vec
print(surface_pts.shape)

In [None]:
# Shape of the finite surface points (computed in the subsampling cell).
finite_surface_points.shape

In [None]:
@jax.jit
def get_dists_sq(x, y):
  """Squared Euclidean distance over the last axis (x and y broadcast)."""
  return ((x - y) ** 2).sum(-1)

def subsample_point_cloud(points, num_points, num_candidates=1000):
  """Greedy, randomized farthest-point subsampling of a point cloud.

  Starts from points[0]. Each step draws `num_candidates` distinct random
  candidates and keeps the one whose nearest already-selected point is
  farthest away.

  Args:
    points: [P, 3] array of points.
    num_points: number of points to select, including points[0].
    num_candidates: candidates scored per step. It is clamped to P, so clouds
      with fewer points no longer make np.random.choice raise.

  Returns:
    (new_points, inds): a list of the selected [3] points and a list of their
    indices into `points`. The lists are aligned, so points[inds[k]] is
    new_points[k]; index 0 was previously missing.

  NOTE: the selected set grows every step, so the jitted get_dists_sq is
  re-traced for each new shape.
  """
  num_candidates = min(num_candidates, points.shape[0])
  new_points = [points[0]]
  inds = [0]
  for i in range(num_points - 1):
    points_to_use_indices = np.random.choice(points.shape[0], size=(num_candidates,), replace=False)
    if i % 100 == 0:
      print(i)
    dists = get_dists_sq(points[points_to_use_indices, None, :], jnp.stack(new_points, axis=0)[None, :, :])
    # Candidate whose distance to its nearest selected point is largest.
    best = int(points_to_use_indices[int(jnp.argmax(dists.min(axis=1)))])
    new_points.append(points[best])
    inds.append(best)
  return new_points, inds

# NOTE(review): this overwrites surface_pts with its flattened [P, 3] form.
# The subsampling-indices cell later unravels against surface_pts.shape[:3];
# confirm which shape it expects.
surface_pts = surface_pts.reshape(-1, 3)
finite_surface_points = surface_pts[jnp.all(jnp.isfinite(surface_pts), axis=-1)]
# subsample_point_cloud returns (points, indices). Unpack before stacking:
# stacking the tuple itself mixes [3] points with scalar indices and fails.
surface_pts_subsampled, subsampling_inds = subsample_point_cloud(finite_surface_points, 10000)

surface_pts_subsampled = jnp.stack(surface_pts_subsampled, axis=0)

In [None]:
# 3D scatter of the subsampled surface points.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.scatter(surface_pts_subsampled[..., 0],
           surface_pts_subsampled[..., 1],
           surface_pts_subsampled[..., 2])


In [None]:
# Save the subsampled surface points. The path must be an f-string; without
# the prefix it wrote to the literal '{DIRECTORY}'.
with open(f'{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'wb') as f:
  np.save(f, surface_pts_subsampled)


In [None]:
# Top-down (x, y) view of the subsampled points.
plt.plot(surface_pts_subsampled[..., 0], surface_pts_subsampled[..., 1], '.')

In [None]:
# Shape of the subsampled point array.
surface_pts_subsampled.shape

# Cache visibility using MLP

Optimize an MLP mapping from position and direction to visibility:
$$(\mathbf{x}, \boldsymbol{\omega}) \mapsto v$$
where $v \in [0, 1]$ is a scalar visibility (enforced by a sigmoid on the MLP output), and the position $\mathbf{x}$ and direction $\boldsymbol{\omega}$ are 3-vectors that are positionally encoded before being fed to the MLP.

In [None]:
#with open('{DIRECTORY}/hotdog_surface_pts.npy', 'rb') as f:
#  surface_pts = np.load(f)


In [None]:
# Shapes of the visibility training data.
surface_normals_subsampled.shape, omega_xyz.shape, occlusion_masks.shape

In [None]:
# Load the visibility training data. Paths must be f-strings; without the
# prefix they opened the literal '{DIRECTORY}'.
with open(f'{DIRECTORY}/hotdog_surface_pts_subsampled.npy', 'rb') as f:
  surface_pts_subsampled = np.load(f)

with open(f'{DIRECTORY}/subsampling_indices.npy', 'rb') as f:
  indices = np.load(f)

# 8-bit visibility images rescaled to [0, 1].
with open(f'{DIRECTORY}/visibility_images.npy', 'rb') as f:
  occlusion_masks = jnp.float32(np.load(f)) / 255.0

with open(f'{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'rb') as f:
  surface_normals_subsampled = np.load(f)

envmap_H, envmap_W = occlusion_masks.shape[1:]

# Texel-center angles of an equirectangular grid: phi in [-pi, pi), theta in [0, pi).
omega_phi, omega_theta = jnp.meshgrid(jnp.linspace(-jnp.pi, jnp.pi, envmap_W+1)[:-1] + 2.0 * jnp.pi / (2.0 * envmap_W),
                                      jnp.linspace(0.0,     jnp.pi, envmap_H+1)[:-1] +       jnp.pi / (2.0 * envmap_H))

# Angular area of one texel (d theta * d phi).
dtheta_dphi = (omega_theta[1, 1] - omega_theta[0, 0]) * (omega_phi[1, 1] - omega_phi[0, 0])

omega_theta = omega_theta.flatten()
omega_phi = omega_phi.flatten()

omega_x = jnp.sin(omega_theta) * jnp.cos(omega_phi)
omega_y = jnp.sin(omega_theta) * jnp.sin(omega_phi)
omega_z = jnp.cos(omega_theta)
omega_xyz = jnp.stack([omega_x,
                       omega_y,
                       omega_z], axis=-1)


# Turn the negative hemisphere into nans
occlusion_masks = jnp.where(jnp.sum(surface_normals_subsampled[:, None, :] * omega_xyz[None, :, :], axis=-1).reshape(-1, envmap_H, envmap_W) > 0.0,
                            occlusion_masks, jnp.nan)
#                            occlusion_masks, 0.0)
#                            occlusion_masks, 0.0)


In [None]:
# GT visibility of subsampled point 70 (NaN below its local horizon).
plt.imshow(occlusion_masks[70])

In [None]:
class MLP(nn.Module):
  """ReLU MLP; `features` lists hidden widths followed by the output width."""
  features: Sequence[int]

  @nn.compact
  def __call__(self, x, y=None):
    # An optional second input is concatenated along the feature axis.
    h = x if y is None else jnp.concatenate([x, y], axis=-1)
    *hidden_widths, out_width = self.features
    for width in hidden_widths:
      h = nn.relu(nn.Dense(width)(h))
    return nn.Dense(out_width)(h)


# When True, posenc prepends the raw input to its sinusoidal features.
append_identity = True


def posenc(x, L_encoding):
  """Sinusoidal positional encoding over the last axis of `x`.

  For each frequency 2**l (l = 0..L_encoding-1) and each input coordinate,
  emits sin(2**l x) and sin(2**l x + pi/2) (= cos). The output is flattened
  in (l, coordinate, sin/cos) order. If L_encoding <= 0, `x` is returned
  unchanged.
  """
  if L_encoding <= 0:
    return x
  freqs = 2**jnp.arange(L_encoding)
  angles = x[..., None, :] * freqs[:, None]  # [..., L, D]
  feats = jnp.sin(jnp.stack([angles, angles + 0.5 * jnp.pi], axis=-1))  # [..., L, D, 2]
  feats = jnp.reshape(feats, x.shape[:-1] + (-1,))
  return jnp.concatenate([x, feats], axis=-1) if append_identity else feats


# Initialize visibility MLP
L_encoding_vis_x = 3
L_encoding_vis_dir = 6
# Each posenc'd 3-vector has 3 + 6 * L features (append_identity=True); the
# position and direction encodings are concatenated.
mlp_input_features = 6 + 6 * (L_encoding_vis_x + L_encoding_vis_dir)

# One logit per (position, direction) pair; the sigmoid is applied by callers.
num_components = 1
mlp_vis = MLP([128]*4 + [num_components])

params_vis = mlp_vis.init(jax.random.PRNGKey(0),
                          np.zeros([1, mlp_input_features]))

init_lr_vis = 1e-2

init_vis, update_vis, get_params_vis = jax.experimental.optimizers.adam(init_lr_vis)
state_vis = init_vis(params_vis)

#x_input = surface_pts_subsampled.reshape(-1, 1, 3).repeat(envmap_H * envmap_W, axis=1).reshape(-1, 3)
#dir_input = omega_xyz.reshape(1, -1, 3).repeat(surface_pts_subsampled.shape[0], axis=0).reshape(-1, 3)
#vis_gt = jnp.zeros(())


def get_vis_loss(params, x_input, dir_input, vis_gt):
  """Squared visibility error, ignoring NaN GT entries.

  x_input, dir_input: [B, 3] positions and directions (positionally encoded here).
  vis_gt: [B, 1] GT visibility; NaN entries are skipped by nansum.
  Returns the summed squared error divided by x_batch_size * envmap_H *
  envmap_W (module-level globals).
  """

  vis_pred = jax.nn.sigmoid(mlp_vis.apply(params,
                                          posenc(x_input, L_encoding_vis_x),
                                          posenc(dir_input, L_encoding_vis_dir)))

  # NOTE(review): NaN (ignored) entries still count in the denominator.
  loss = jnp.nansum((vis_pred - vis_gt) ** 2) / x_batch_size / envmap_H / envmap_W

  return loss

@jax.jit
def step(state, x_input, dir_input, vis_gt, i):
  """One Adam step on the visibility loss; returns (new_state, loss)."""
  loss_and_grad = jax.value_and_grad(get_vis_loss)
  loss, grads = loss_and_grad(get_params_vis(state), x_input, dir_input, vis_gt)
  new_state = update_vis(i, grads, state)
  return new_state, loss

#get_vis_loss(state_vis)

# Surface points per batch; each is paired with every envmap direction.
x_batch_size = 50

losses = []
for i in range(500):
  # Random subsampled-point indices (with replacement); despite the name these
  # index points, not images.
  image_indices = np.random.randint(0, surface_pts_subsampled.shape[0], size=(x_batch_size,))
  
  # Flatten to (point, direction) pairs: x_batch_size * envmap_H * envmap_W rows.
  x_input = surface_pts_subsampled[image_indices].reshape(-1, 1, 3).repeat(envmap_H * envmap_W, axis=1).reshape(-1, 3)
  dir_input = omega_xyz.reshape(1, -1, 3).repeat(x_batch_size, axis=0).reshape(-1, 3)
  vis_gt = occlusion_masks[image_indices].reshape(-1, 1)

  state_vis, loss = step(state_vis, x_input, dir_input, vis_gt, i)

  losses.append(loss)

plt.semilogy(losses)

In [None]:
# Balance of GT visibility over non-NaN entries: sum of (2 v - 1).
jnp.nansum(occlusion_masks * 2 - 1)

In [None]:
@jax.jit
def evaluate_ind(params, x_input, dir_input, ind):
  """Predicted visibility (sigmoid of the MLP output); `ind` is unused."""
  encoded_x = posenc(x_input, L_encoding_vis_x)
  encoded_dir = posenc(dir_input, L_encoding_vis_dir)
  logits = mlp_vis.apply(params, encoded_x, encoded_dir)
  return jax.nn.sigmoid(logits)

def evaluate(params):
  """Count hard-mask disagreements between the visibility MLP and the GT.

  For each subsampled surface point, predicts visibility over all envmap
  directions, thresholds at 0.5, and sums |prediction - GT| with nansum.

  Returns:
    (total_error, total_error / number_of_points).
  """
  num_pts = surface_pts_subsampled.shape[0]
  total_error = 0.0
  for ind in range(num_pts):
    if (ind + 1) % 1000 == 0:
      print(ind + 1)
    x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)
    pred = evaluate_ind(params, x_input, omega_xyz, ind).reshape(-1)
    hard_pred = jnp.float32(pred > 0.5)
    gt = occlusion_masks[ind].reshape(-1)
    total_error += jnp.nansum(jnp.abs(hard_pred - gt))
  return total_error, total_error / num_pts

# Predicted soft mask, hard mask and GT visibility for subsampled point 70.
ind = 70

x_input = surface_pts_subsampled[ind] * jnp.ones_like(omega_xyz)
dir_input = omega_xyz

mask = evaluate_ind(get_params_vis(state_vis), x_input, dir_input, ind).reshape(envmap_H, envmap_W)
plt.imshow(mask)
plt.figure()
plt.imshow(mask > 0.5)
plt.figure()
plt.imshow(occlusion_masks[ind])

# The full evaluation loops over every subsampled point and is slow, so it is
# opt-in. Previously the call was commented out while print(avg_error) still
# ran, which raised NameError.
run_full_evaluation = False
if run_full_evaluation:
  error, avg_error = evaluate(get_params_vis(state_vis))
  print(avg_error)

In [None]:
# For each subsampled surface point, find the (image, row, col) cell of the full
# surface-point grid that is closest to it, and save those indices.
indices = []
for ind in range(10000):
  d = ((surface_pts - surface_pts_subsampled[ind]) ** 2).sum(-1)
  # Background cells of the point grid are presumably non-finite (the Blender
  # scripts check np.isfinite on the same data). np.argmin would return the
  # first NaN, so ignore NaNs explicitly.
  i, r, c = np.unravel_index(np.nanargmin(d), shape=surface_pts.shape[:3])
  indices.append((i, r, c))
# Keep as an (N, 3) integer array so later cells can use `.shape` and
# `indices[:, k]`.
indices = np.array(indices)

# f-string so that DIRECTORY is actually substituted into the path.
with open(f'{DIRECTORY}/subsampling_indices.npy', 'wb') as f:
  np.save(f, indices)


In [None]:
#with open('{DIRECTORY}/subsampling_indices.npy', 'wb') as f:
#  np.save(f, np.array(indices))


In [None]:
# Save the subsampled surface normals; the f-prefix is required so that
# {DIRECTORY} is substituted instead of being written literally.
with open(f'{DIRECTORY}/hotdog_surface_normals_subsampled.npy', 'wb') as f:
  np.save(f, surface_normals_subsampled)


In [None]:
# np.shape works whether `indices` is a list of (i, r, c) tuples or an ndarray.
normals_gt.shape, np.shape(indices)

In [None]:
# Gather the ground-truth normal at each subsampled point's (image, row, col).
# tuple(np.asarray(indices).T) gives three index arrays for the first three
# axes, so this works whether `indices` is a list of tuples or an (N, 3) array.
surface_normals_subsampled = normals_gt.reshape(-1, 128, 128, 3)[tuple(np.asarray(indices).T)]
surface_normals_subsampled.shape

In [None]:
# For a few points, plot the occlusion mask restricted to the hemisphere around
# the point's surface normal: directions whose dot product with the normal is
# not positive are set to 0.
# NOTE: `ind` and `a` stay defined after the loop (ind == 100), and a later
# cell reads `ind`.
for ind in [0, 10, 20, 50, 100]:
#ind = 10
  a = jnp.where(jnp.sum(surface_normals_subsampled[ind] * omega_xyz.reshape(envmap_H, envmap_W, 3), axis=-1) > 0.0, occlusion_masks[ind], 0.0)
  
  plt.figure(); plt.imshow(a)

In [None]:
# Grid location (image, row, col) of subsampled point 100.
indices[100]

In [None]:
# Gathered normal of subsampled point 109.
surface_normals_subsampled[109]

In [None]:
# Ground-truth normal at image 327, row 96, col 49 — presumably a manual
# cross-check against the gathered normal above; TODO confirm which index.
normals_gt.reshape(512, 128, 128, 3)[327, 96, 49, :]

In [None]:
# Ground-truth occlusion mask for the current global `ind` (whichever cell set
# it last).
plt.imshow(occlusion_masks[ind])

In [None]:
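# A simple script that uses Blender to render views of a single object by rotating the camera around it.
# This variant instead places an equirectangular (panoramic) camera at object surface points to render
# visibility images; it does not produce depth maps.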

import argparse, sys, os
import json
import bpy
import mathutils
import numpy as np
        
def listify_matrix(matrix):
    """Return `matrix` (any iterable of iterable rows) as a list of lists.

    Useful e.g. to store a mathutils matrix in JSON.
    """
    return [list(row) for row in matrix]

def delistify_matrix(lst):
    """Build a mathutils.Matrix from a nested list (inverse of listify_matrix).

    Starts from mathutils.Matrix() and copies lst[r][c] for r, c in 0..3.
    """
    result = mathutils.Matrix()
    for row in range(4):
        src_row = lst[row]
        for col in range(4):
            result[row][col] = src_row[col]
    return result


DEBUG = False
# Output image size for the equirectangular render (applied to
# scene.render.resolution_* below); presumably matches the 50x99 environment map.
envmap_H = 50
envmap_W = 99
FORMAT = 'PNG'

# filename is /.../<model>.blend/.../<script>.py
# ind_i/ind_f bracket <model> in that path. They are only used by the
# commented-out model_name line; model_name is hard-coded below.
ind_f = __file__.find('.blend')
ind_i = __file__[:ind_f].rfind('/') + 1
#model_name = __file__[ind_i:ind_f] + '_uniform'
#model_name = 'hotdog_farfield_occlusions_lambertian_new_no_self_occ_uniform_linear_128x128'
model_name = 'hotdog_occlusions_lambertian_linear_128x128'

# Read from file
#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',
#                    f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']

partitions = ['train']
# One transforms JSON per partition (NOTE(review): absolute, user-specific path).
transforms_files = [f'/Users/dorverbin/Downloads/blend_files/{model_name}/transforms_{partition}.json' for partition in partitions]
# Outputs go to <blend dir>/<model_name>/visibility/<partition>/.
RESULTS_PATH = os.path.join(model_name, 'visibility')

# '//' is Blender's prefix for paths relative to the .blend file's directory.
fp = bpy.path.abspath(f"//{RESULTS_PATH}")

# Create the output directory and one subdirectory per partition.
if not os.path.exists(fp):
    os.makedirs(fp)
for partition in partitions:
    if not os.path.exists(os.path.join(fp, partition)):
        os.makedirs(os.path.join(fp, partition))

# Data to store in JSON file
#out_data = {}

# Render Optimizations
bpy.context.scene.render.use_persistent_data = True


# Enable compositing nodes; only a render-layers node is (re)created below.
bpy.context.scene.use_nodes = True
tree = bpy.context.scene.node_tree
links = tree.links

# Output image settings: 8-bit PNG.
bpy.context.scene.render.image_settings.file_format = str('PNG')
bpy.context.scene.render.image_settings.color_depth = str(8)
print("Only 32 if using EXR. When I use binary forget about it")

# Disable the display transform ('None') and use a linear sequencer color
# space. This is applied unconditionally, even though the output here is PNG.
bpy.data.scenes['Scene'].display_settings.display_device = 'None'
bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear'

# Remove all existing compositor nodes. Iterate over a snapshot so the
# collection is never mutated while it is being iterated.
for node in list(tree.nodes):
    tree.nodes.remove(node)

# All compositor nodes were just removed, so this branch always runs.
if 'Custom Outputs' not in tree.nodes:
    # Create input render layer node.
    render_layers = tree.nodes.new('CompositorNodeRLayers')
    render_layers.label = 'Custom Outputs'
    render_layers.name = 'Custom Outputs'

# Background: no dithering, and an opaque film so the world background is
# rendered (the mask is controlled through the World node tree, see
# toggle_mask_visibility).
bpy.context.scene.render.dither_intensity = 0.0
bpy.context.scene.render.film_transparent = False

# Create collection for objects not to render with background
# NOTE(review): leftover comment; nothing below creates such a collection.


# Render resolution = environment-map size, at 100% scale.
scene = bpy.context.scene
scene.render.resolution_x = envmap_W
scene.render.resolution_y = envmap_H
scene.render.resolution_percentage = 100

cam = scene.objects['Camera']

# Define equirect camera
cam.data.type = 'PANO'
cam.data.cycles.panorama_type = 'EQUIRECTANGULAR'
#cam.data.cycles.latitude_min = np.pi / (2.0 * envmap_H) - np.pi / 2.0
#cam.data.cycles.latitude_max = np.pi / 2.0 - np.pi / (2.0 * envmap_H)


#cam.location = (0, 4.0, 0.5)

#cam_constraint = cam.constraints.new(type='TRACK_TO')
#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
#cam_constraint.up_axis = 'UP_Y'
#b_empty = parent_obj_to_camera(cam)
#cam_constraint.target = b_empty


#scene.render.image_settings.file_format = 'PNG'  # set output format to .png
scene.render.image_settings.file_format = FORMAT  # set output format to .png

# Fixed camera orientation: columns are local x -> world +Y, local y -> world +Z,
# local z -> world -X (so, with Blender's camera looking down local -Z, it faces
# world +X). Only referenced by the disabled code in the main loop below.
canonical_mat = [[ 0.0, 0.0, -1.0, 0.0],
                 [ 1.0, 0.0,  0.0, 0.0],
                 [ 0.0, 1.0,  0.0, 0.0],
                 [ 0.0, 0.0,  0.0, 1.0]]

# Surface points at which the camera is placed. The subsampled file is indexed
# as [ind, :] below (presumably (N, 3)); the commented full file was indexed
# [image, row, col, :]. NOTE(review): absolute, user-specific paths.
#hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts.npy')
hotdog_points = np.load('/Users/dorverbin/Downloads/hotdog_surface_pts_subsampled.npy')

# Every object whose name does not contain 'Camera' (substring match).
all_object_names_except_camera = [k for k in scene.objects.keys() if 'Camera' not in k]

def toggle_object_visibility(do_hide_objects):
    """Set hide_render = do_hide_objects on every object listed in
    all_object_names_except_camera (names not containing 'Camera')."""
    for obj_name in all_object_names_except_camera:
        target = scene.objects[obj_name]
        target.hide_render = do_hide_objects
    
def toggle_mask_visibility(do_hide_mask):
    """Set input 1 of the World node tree's "Math" node: 1.01 when hiding the
    mask, 0.8 otherwise.

    NOTE(review): presumably a threshold inside the world shader; confirm
    against the .blend file.
    """
    if do_hide_mask:
        value = 1.01
    else:
        value = 0.8
    math_node = bpy.data.worlds["World"].node_tree.nodes["Math"]
    math_node.inputs[1].default_value = value


# Main loop, one pass per partition. Only VIEWS is derived from the transforms
# file here; the per-view code that would use it is in the disabled string below.
for transforms_file, partition in zip(transforms_files, partitions):
    if transforms_file is not None:
        with open(transforms_file) as in_file:
            transforms_data = json.load(in_file)
        
        VIEWS = len(transforms_data['frames'])
    else:
        raise RuntimeError('Must specify transforms file')

    #out_data['frames'] = []

    #for i in range(0, VIEWS, 20):
    #if partition == 'train':
    #    continue
    # NOTE(review): the triple-quoted string below is a bare string expression,
    # i.e. disabled code that never runs. It rendered per-pixel views from a
    # 4-D (image, row, col, xyz) point grid, followed by a mask render.
    """
    for i in range(1):
        #cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix'])    
        #print(cam.matrix_world)

        # Start by rendering mask
        cam.matrix_world = delistify_matrix(canonical_mat)

        toggle_object_visibility(False)
        toggle_mask_visibility(True)


        for r in [40]:
            for c in range(30, 40):
                p = hotdog_points[i, r, c, :]
                if not np.all(np.isfinite(p)):
                    continue
                
                point_x, point_y, point_z = p
            
                cam.matrix_world[0][3] = point_x
                cam.matrix_world[1][3] = point_y
                cam.matrix_world[2][3] = point_z
                
                scene.render.filepath = os.path.join(fp, partition, f'r_{i}_{r}_{c}')          
                bpy.ops.render.render(write_still=True)  # render still
                
            
        toggle_object_visibility(True)
        toggle_mask_visibility(False)
        bpy.data.worlds["World"].node_tree.nodes["Value"].outputs[0].default_value = cam.matrix_world[0][3]
        bpy.data.worlds["World"].node_tree.nodes["Value.001"].outputs[0].default_value = cam.matrix_world[1][3]
        bpy.data.worlds["World"].node_tree.nodes["Value.002"].outputs[0].default_value = cam.matrix_world[2][3]                
        
        
        # Render mask
        scene.render.filepath = os.path.join(fp, partition, f'r_{i}_mask')
        if DEBUG:
            break
        else:
            bpy.ops.render.render(write_still=True)  # render still
    """

    # Active setup: scene objects renderable (hide_render=False) and
    # toggle_mask_visibility(True).
    toggle_object_visibility(False)
    toggle_mask_visibility(True)

    # Only surface point 100 is processed; the loop over all points is commented out.
    #for ind in range(hotdog_points.shape[0]):
    for ind in [100]:    
        p = hotdog_points[ind, :]
        # Skip points with non-finite coordinates.
        if not np.all(np.isfinite(p)):
            continue
        
        point_x, point_y, point_z = p
    
        # Move the camera to the surface point (translation only).
        # NOTE(review): the camera rotation is left as-is; only the disabled
        # code above reset it to canonical_mat first.
        cam.matrix_world[0][3] = point_x
        cam.matrix_world[1][3] = point_y
        cam.matrix_world[2][3] = point_z
        
        # NOTE(review): rendering is commented out, so nothing is written to disk.
        #scene.render.filepath = os.path.join(fp, partition, f'r_{ind}')          
        #bpy.ops.render.render(write_still=True)  # render still
        



        #frame_data = {
        #    #'file_path': scene.render.filepath,
        #    'file_path': f'./{partition}/r_{i}',
        #    'transform_matrix': listify_matrix(cam.matrix_world)
        #}
        #out_data['frames'].append(frame_data)

        #if transforms_file is None:
        #    b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff
        #    b_empty.rotation_euler[2] += radians(2*stepsize)

    #if not DEBUG:
    #    with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:
    #        json.dump(out_data, out_file, indent=4)




In [None]:
# A simple script that uses Blender to render views of a single object by rotating the camera around it.
# Also produces depth and normal maps at the same time.

import argparse, sys, os
import json
import bpy
import mathutils
import numpy as np
        

def listify_matrix(matrix):
    """Return `matrix` as a plain list of row lists (e.g. for json.dump)."""
    return list(map(list, matrix))

def delistify_matrix(lst):
    """Copy the 4x4 nested list `lst` into a new mathutils.Matrix, row-major."""
    matrix = mathutils.Matrix()
    for flat_index in range(16):
        r, c = divmod(flat_index, 4)
        matrix[r][c] = lst[r][c]
    return matrix
      
def parse_bin(s):
  """Read s = '.<bits>' as a binary fraction, e.g. '.101' -> 0.625.

  The first character (expected to be '.') is skipped without being checked.
  """
  bits = s[1:]
  return int(bits, 2) / 2.**len(bits)


def phi2(i):
  """Base-2 radical inverse (van der Corput): mirror i's binary digits about
  the binary point, e.g. 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, 4 -> 0.125."""
  mirrored = format(i, 'b')[::-1]
  return parse_bin('.' + mirrored)

def nice_uniform(N):
  """Return the 2-D Hammersley set of size N as two lists.

  Returns:
    (u, v) with u[i] = i / N and v[i] = phi2(i), for i in range(N).
  """
  u = [i / float(N) for i in range(N)]
  v = [phi2(i) for i in range(N)]
  return u, v

def nice_uniform_spherical(N, hemisphere=True):
  """Hammersley points mapped to spherical angles, uniform in solid angle.

  Implementation of
  http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html

  Args:
    N: number of points.
    hemisphere: if true, cover the upper hemisphere (cos(theta) in [0, 1]);
      otherwise the whole sphere (cos(theta) in [-1, 1]).

  Returns:
    (theta, phi): numpy arrays of polar and azimuthal angles, each of length N.
  """
  u, v = nice_uniform(N)

  # Uniform in solid angle means cos(theta) is uniform: 1 - u for the
  # hemisphere and 1 - 2u for the full sphere. (Doubling theta instead would
  # cluster points near the south pole.)
  cos_theta_span = 2.0 - int(hemisphere)
  theta = np.arccos(1.0 - cos_theta_span * np.array(u))
  phi   = 2.0 * np.pi * np.array(v)

  return theta, phi
    
    
    
# Sample camera directions on the upper hemisphere only. This global is read
# by get_all_camera_matrices at call time.
hemisphere = True
# Default camera distance from the origin: the norm of (4.0, 0.5).
camera_dist = np.sqrt(4.0**2 + 0.5**2)


def get_all_camera_matrices(N_cameras, camera_dist=camera_dist):
  """Build N_cameras 4x4 camera-to-world matrices around the origin.

  Directions come from nice_uniform_spherical(N_cameras, hemisphere), using the
  module-level `hemisphere` flag. For each direction d:
    * translation = camera_dist * d,
    * local z axis = d normalized (pointing away from the origin; with
      Blender's look-down-local-(-Z) convention the camera faces the origin),
    * local y axis = world +Z with its z-component removed, nudged by 1e-10 in
      x so cameras pointing straight up/down still get a defined axis,
    * local x axis = cross(y, z).

  Returns:
    A list of N_cameras numpy arrays of shape (4, 4).
  """
  polar, azimuth = nice_uniform_spherical(N_cameras, hemisphere)
  directions = np.stack([np.sin(polar) * np.cos(azimuth),
                         np.sin(polar) * np.sin(azimuth),
                         np.cos(polar)], axis=-1)

  poses = []
  for direction in directions:
    forward = direction / np.linalg.norm(direction)

    up = np.array([0.0, 0.0, 1.0])
    up -= forward * forward.dot(up)
    up[0] += 1e-10  # keeps `up` non-zero when forward is (anti)parallel to +Z
    up /= np.linalg.norm(up)

    pose = np.eye(4)
    pose[:3, 0] = np.cross(up, forward)
    pose[:3, 1] = up
    pose[:3, 2] = forward
    pose[:3, 3] = direction * camera_dist
    poses.append(pose)
  return poses
         
DEBUG = False
VIEWS = 512  # Only used if not specifying transforms_file
RESOLUTION = 128 #800
# NOTE(review): DEPTH_SCALE is not referenced anywhere else in this script.
DEPTH_SCALE = 1.4
FORMAT = 'OPEN_EXR'
# 8 bits per channel for PNG, 32 (float) otherwise.
COLOR_DEPTH = 8 if FORMAT == 'PNG' else 32
 
       
# filename is /.../<model>.blend/.../<script>.py
# model_name is the <model> part of that path plus '_uniform'.
# NOTE(review): if '.blend' is not in __file__, find() returns -1 and the slice
# is not meaningful.
ind_f = __file__.find('.blend')
ind_i = __file__[:ind_f].rfind('/') + 1
model_name = __file__[ind_i:ind_f] + '_uniform'


# Output directory name: '<model>_linear' for EXR, '<model>' for PNG,
# plus '_<RES>x<RES>' for any non-800 resolution.
if FORMAT == 'OPEN_EXR':
    RESULTS_PATH = f'{model_name}_linear'
elif FORMAT == 'PNG':
    RESULTS_PATH = model_name
else:
    raise RuntimeError('format unknown')

if RESOLUTION != 800:
    RESULTS_PATH += f'_{RESOLUTION}x{RESOLUTION}'

# Read from file
#transforms_files = [f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_train.json',
#                    f'/Users/dorverbin/Downloads/nerf_synthetic/hotdog/transforms_test.json']
#partitions = ['train', 'test']
# No transforms file: camera poses are generated by get_all_camera_matrices.
transforms_files = [None]
partitions = ['occlusions']


# '//' is Blender's prefix for paths relative to the .blend file's directory.
fp = bpy.path.abspath(f"//{RESULTS_PATH}")

# Create the output directory and one subdirectory per partition.
if not os.path.exists(fp):
    os.makedirs(fp)
for partition in partitions:
    if not os.path.exists(os.path.join(fp, partition)):
        os.makedirs(os.path.join(fp, partition))

# Data to store in JSON file
out_data = {
    'camera_angle_x': bpy.data.objects['Camera'].data.angle_x,
}

# Render Optimizations
bpy.context.scene.render.use_persistent_data = True


# Enable compositing nodes; the depth and normal outputs are built below.
bpy.context.scene.use_nodes = True
tree = bpy.context.scene.node_tree
links = tree.links

# Enable the normal pass (consumed by the 'Normal Output' node) and set the
# main output format and bit depth.
bpy.context.scene.view_layers["RenderLayer"].use_pass_normal = True
bpy.context.scene.render.image_settings.file_format = str(FORMAT)
bpy.context.scene.render.image_settings.color_depth = str(COLOR_DEPTH)

# If using OpenEXR, set to linear color space; otherwise use sRGB.
if FORMAT == 'OPEN_EXR':
    bpy.data.scenes['Scene'].display_settings.display_device = 'None'
    bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'Linear'
else:
    bpy.data.scenes['Scene'].display_settings.display_device = 'sRGB'
    bpy.data.scenes['Scene'].sequencer_colorspace_settings.name = 'sRGB'

# Remove all existing compositor nodes. Iterate over a snapshot so the
# collection is never mutated while it is being iterated.
for node in list(tree.nodes):
    tree.nodes.remove(node)

# Build the compositor graph. All nodes were just removed, so this branch
# always runs.
if 'Custom Outputs' not in tree.nodes:
    # Create input render layer node.
    render_layers = tree.nodes.new('CompositorNodeRLayers')
    render_layers.label = 'Custom Outputs'
    render_layers.name = 'Custom Outputs'

    # Depth output. For OPEN_EXR it stores 1 / (depth + 1); otherwise depth is
    # linearly remapped from [0, 8] to [1, 0].
    depth_file_output = tree.nodes.new(type="CompositorNodeOutputFile")
    depth_file_output.label = 'Depth Output'
    depth_file_output.name = 'Depth Output'
    if FORMAT == 'OPEN_EXR':
        add_one = tree.nodes.new('CompositorNodeMath')
        add_one.operation = 'ADD'
        add_one.inputs[1].default_value = 1.0
        links.new(render_layers.outputs['Depth'], add_one.inputs[0])

        recip = tree.nodes.new('CompositorNodeMath')
        recip.operation = 'DIVIDE'
        recip.inputs[0].default_value = 1.0
        links.new(add_one.outputs[0], recip.inputs[1])

        links.new(recip.outputs[0], depth_file_output.inputs[0])

    else:
        # Remap as other types can not represent the full range of depth.
        # (Named depth_map_range rather than `map` to avoid shadowing the builtin.)
        depth_map_range = tree.nodes.new(type="CompositorNodeMapRange")
        # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map.
        depth_map_range.inputs['From Min'].default_value = 0
        depth_map_range.inputs['From Max'].default_value = 8
        depth_map_range.inputs['To Min'].default_value = 1
        depth_map_range.inputs['To Max'].default_value = 0
        links.new(render_layers.outputs['Depth'], depth_map_range.inputs[0])

        links.new(depth_map_range.outputs[0], depth_file_output.inputs[0])

    # Normal output, always written as PNG.
    normal_file_output = tree.nodes.new(type="CompositorNodeOutputFile")
    normal_file_output.label = 'Normal Output'
    normal_file_output.name = 'Normal Output'
    normal_file_output.format.file_format = 'PNG'

    # Separate normals into channels, map each with (x + 1) / 2 (presumably
    # from [-1, 1] to [0, 1]), and recombine.
    # NOTE(review): the SepRGBA/CombRGBA nodes are deprecated in newer Blender
    # versions in favor of Separate/Combine Color nodes.
    sep_rgba = tree.nodes.new('CompositorNodeSepRGBA')
    links.new(render_layers.outputs['Normal'], sep_rgba.inputs[0])

    comb_rgba = tree.nodes.new('CompositorNodeCombRGBA')
    add_ones = []
    divide_by_twos = []
    for i in range(3):
        add_ones.append(tree.nodes.new('CompositorNodeMath'))
        add_ones[i].operation = 'ADD'
        add_ones[i].inputs[1].default_value = 1.0
        links.new(sep_rgba.outputs[i], add_ones[i].inputs[0])

        divide_by_twos.append(tree.nodes.new('CompositorNodeMath'))
        divide_by_twos[i].operation = 'DIVIDE'
        divide_by_twos[i].inputs[1].default_value = 2.0
        links.new(add_ones[i].outputs[0], divide_by_twos[i].inputs[0])

        links.new(divide_by_twos[i].outputs[0], comb_rgba.inputs[i])

    # Pass alpha through unchanged.
    links.new(sep_rgba.outputs[3], comb_rgba.inputs[3])

    links.new(comb_rgba.outputs[0], normal_file_output.inputs[0])

# Background: no dithering, transparent film.
bpy.context.scene.render.dither_intensity = 0.0
bpy.context.scene.render.film_transparent = True

# Delete leftover objects of type EMPTY whose name contains 'Empty'
# (presumably ones created by parent_obj_to_camera in an earlier run).
# Exact type comparison: a bare ('EMPTY') is a string, not a tuple, so an
# `in` test against it would be a substring check.
objs = [ob for ob in bpy.context.scene.objects if ob.type == 'EMPTY' and 'Empty' in ob.name]
# NOTE(review): passing a context-override dict positionally to an operator is
# deprecated since Blender 3.2 and removed in 4.0; there, use
# `with bpy.context.temp_override(selected_objects=objs): bpy.ops.object.delete()`.
bpy.ops.object.delete({"selected_objects": objs})

def parent_obj_to_camera(b_camera):
    """Parent `b_camera` to a new empty object at the world origin.

    The new object (named "Empty") is linked into the scene's collection and
    made the active object.

    Returns:
        The newly created empty object.
    """
    empty_obj = bpy.data.objects.new("Empty", None)
    empty_obj.location = (0, 0, 0)
    b_camera.parent = empty_obj  # setup parenting

    bpy.context.scene.collection.objects.link(empty_obj)
    bpy.context.view_layer.objects.active = empty_obj
    return empty_obj


# Square render at RESOLUTION x RESOLUTION, 100% scale.
scene = bpy.context.scene
scene.render.resolution_x = RESOLUTION
scene.render.resolution_y = RESOLUTION
scene.render.resolution_percentage = 100

cam = scene.objects['Camera']
#cam.location = (0, 4.0, 0.5)

#cam_constraint = cam.constraints.new(type='TRACK_TO')
#cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
#cam_constraint.up_axis = 'UP_Y'
#b_empty = parent_obj_to_camera(cam)
#cam_constraint.target = b_empty


#scene.render.image_settings.file_format = 'PNG'  # set output format to .png
scene.render.image_settings.file_format = FORMAT  # set output format (FORMAT; 'OPEN_EXR' here)

from math import radians

# Degrees per view. Only used in the "Rotation" log line, the JSON 'rotation'
# field, and the DEBUG branch of the main loop.
stepsize = 360.0 / VIEWS
# NOTE(review): rotation_mode is not referenced anywhere else in this script.
rotation_mode = 'XYZ'




# Clear the file-output nodes' base paths; the per-frame slot paths set in the
# main loop are built from the absolute scene.render.filepath.
if not DEBUG:
    for output_node in [tree.nodes['Depth Output'], tree.nodes['Normal Output']]:
        output_node.base_path = ''

# Main loop, one pass per partition: pick camera poses (from a transforms file
# or generated on the hemisphere), set up output paths, render, and write a
# transforms_<partition>.json.
for transforms_file, partition in zip(transforms_files, partitions):
    if transforms_file is not None:
        with open(transforms_file) as in_file:
            transforms_data = json.load(in_file)
        
        VIEWS = len(transforms_data['frames'])
    else:
        cameras = get_all_camera_matrices(VIEWS)

    out_data['frames'] = []

    #for i in range(0, VIEWS, 20):
    #if partition == 'train':
    #    continue
    print(VIEWS)
    # Only view 350 is processed (debug setting).
    for i in [350]:
        if transforms_file is None:
            # Copy the generated 4x4 pose into the camera entry by entry.
            for a in range(4):
                for b in range(4):
                    cam.matrix_world[a][b] = cameras[i][a][b]
            print(cameras[i], i)
            #cam.location.x = cameras[i][0][3]
            #cam.location.y = cameras[i][1][3]
            #cam.location.z = cameras[i][2][3]
            #if RANDOM_VIEWS:
            #    scene.render.filepath = fp + '/r_' + str(i)
            #    b_empty.rotation_euler = np.random.uniform(0, 2*np.pi, size=3)
            #else:
            #    print("Rotation {}, {}".format((stepsize * i), radians(stepsize * i)))
            #    scene.render.filepath = fp + '/r_{0:03d}'.format(int(i * stepsize))
        else:
            cam.matrix_world = delistify_matrix(transforms_data['frames'][i]['transform_matrix'])    
            print(cam.matrix_world)
            #peter.matrix_world = cam.matrix_world
        # NOTE(review): b_empty, CIRCLE_FIXED_START and vertical_diff are not
        # defined in this script as shown (parent_obj_to_camera is only called
        # in commented-out code), so this branch would raise NameError.
        if DEBUG:
            i = np.random.randint(0,VIEWS)
            b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff
            b_empty.rotation_euler[2] += radians(2*stepsize*i)
       
        # Pass the camera position into the world node tree's Value nodes
        # (presumably consumed by the world shader).
        bpy.data.worlds["World"].node_tree.nodes["Value"].outputs[0].default_value = cam.matrix_world[0][3]
        bpy.data.worlds["World"].node_tree.nodes["Value.001"].outputs[0].default_value = cam.matrix_world[1][3]
        bpy.data.worlds["World"].node_tree.nodes["Value.002"].outputs[0].default_value = cam.matrix_world[2][3]                
        
        print("Rotation {}, {}".format((stepsize * i), radians(stepsize * i)))
        scene.render.filepath = os.path.join(fp, partition, f'r_{i}')

        tree.nodes['Depth Output'].file_slots[0].path = scene.render.filepath + "_disp_"
        tree.nodes['Normal Output'].file_slots[0].path = scene.render.filepath + "_normal_"

        # NOTE(review): this unconditional `break` exits before the render, so
        # the render call, frame_data and the append below are unreachable,
        # and the JSON written after the loop has an empty 'frames' list.
        break
        if DEBUG:
            break
        else:
            bpy.ops.render.render(write_still=True)  # render still

        # NOTE(review): 'rotation' is the same value (one step) for every frame.
        frame_data = {
            #'file_path': scene.render.filepath,
            'file_path': f'./{partition}/r_{i}',
            'rotation': radians(stepsize),
            'transform_matrix': listify_matrix(cam.matrix_world)
        }
        out_data['frames'].append(frame_data)

        #if transforms_file is None:
        #    b_empty.rotation_euler[0] = CIRCLE_FIXED_START[0] + (np.cos(radians(stepsize*i))+1)/2 * vertical_diff
        #    b_empty.rotation_euler[2] += radians(2*stepsize)

    if not DEBUG:
        with open(fp + '/' + f'transforms_{partition}.json', 'w') as out_file:
            json.dump(out_data, out_file, indent=4)
