Copyright 2018 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import numpy as np
import scipy
from PIL import Image
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import jax.experimental.optimizers
import time
import functools

In [None]:
envmap_H = 50
envmap_W = 100

# NOTE(review): `gfile` and `DIRECTORY` are not defined in this notebook; they are presumably
# supplied by the hosting environment -- confirm before running elsewhere.
with gfile.Open(f'{DIRECTORY}/ninomaru_teien_4k.jpg', 'rb') as f:
  # Force 3-channel RGB so the array is always [H, W, 3], and resample with Lanczos.
  # `Image.ANTIALIAS` (an alias of LANCZOS) was removed in Pillow 10.
  envmap_gt = np.array(Image.open(f).convert('RGB').resize((envmap_W, envmap_H), Image.LANCZOS)) / 255.0

# Row 0 is theta = 0 (the pole), where every phi is the same direction: use its mean color.
envmap_gt[0, :, :] = envmap_gt[0, :, :].mean(axis=0, keepdims=True)

In [None]:
# Preview the downsampled ground-truth environment map.
plt.figure(figsize=(12, 12))
plt.imshow(envmap_gt)

In [None]:
# Equirectangular sampling grid, both arrays shaped (envmap_H, envmap_W):
# rows step the polar angle theta over [0, pi), columns step the azimuth phi over [0, 2*pi).
phi, theta = jnp.meshgrid(
    jnp.linspace(0.0, 2 * jnp.pi, envmap_W + 1)[:-1],
    jnp.linspace(0.0, jnp.pi, envmap_H + 1)[:-1],
)

# Unit direction of every grid sample; xyz flattens them row-major into [H*W, 3].
x = jnp.sin(theta) * jnp.cos(phi)
y = jnp.sin(theta) * jnp.sin(phi)
z = jnp.cos(theta)
xyz = jnp.stack([component.flatten() for component in (x, y, z)], axis=-1)

def thph2vec(theta, phi):
  """Convert spherical angles to unit vectors.

  Args:
    theta: polar angle measured from +z, in radians.
    phi:   azimuth measured from +x towards +y, in radians.

  Returns:
    Array of shape broadcast(theta, phi).shape + (3,) holding (x, y, z).
  """
  sin_theta = jnp.sin(theta)
  components = (sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), jnp.cos(theta))
  return jnp.stack(components, axis=-1)


def rotate(xyz, new_z_dir, flipped=True):
  """Change basis to an orthonormal frame whose z axis is `new_z_dir`.

  The frame's x axis is normalize(new_z_dir x (0, 1, 0)) and its y axis is new_z_dir x new_x,
  which makes (x, y, z) right-handed. When `new_z_dir` is (anti)parallel to the y axis that
  cross product vanishes. Normalizing it would divide by ~0 (NaN, or a direction dictated by
  rounding noise), so (1, 0, 0) is used as the reference axis instead.

  Args:
    xyz:       [..., 3] row vectors.
    new_z_dir: [3] new z axis; assumed to be unit length (as produced by thph2vec).
    flipped:   if True (default) return xyz @ R, i.e. each vector's coordinates in the new
               frame; otherwise return xyz @ R.T (the inverse change of basis).

  Returns:
    Array with the same shape as `xyz`.
  """
  new_x_dir = jnp.cross(new_z_dir, jnp.array([0.0, 1.0, 0.0]))
  fallback_x_dir = jnp.cross(new_z_dir, jnp.array([1.0, 0.0, 0.0]))
  # Only a comparison uses this norm, so no gradient flows through it.
  degenerate = jnp.linalg.norm(new_x_dir) < 1e-6
  new_x_dir = jnp.where(degenerate, fallback_x_dir, new_x_dir)
  new_x_dir = new_x_dir / jnp.linalg.norm(new_x_dir)
  new_y_dir = jnp.cross(new_z_dir, new_x_dir)

  R = jnp.stack([new_x_dir, new_y_dir, new_z_dir], axis=1)  # columns are the new x, y, z axes

  return xyz @ R if flipped else xyz @ R.T

def vec2thph(x, y, z):
  """Convert Cartesian components to spherical angles (theta, phi).

  theta is the polar angle from +z in [0, pi], and phi is the azimuth in [-pi, pi] (arctan2
  range). The 1e-10 offsets presumably keep sqrt/arctan2 gradients finite at the poles.
  """
  eps = 1e-10
  rho = jnp.sqrt(x**2 + y**2 + eps)
  return jnp.arctan2(rho, z), jnp.arctan2(y, x + eps)

def get_mask(new_z_theta, new_z_phi, thresh=0.9):
  """Binary visibility mask on the envmap grid for a spherical-cap occluder.

  The cap is centered on the direction (new_z_theta, new_z_phi). A grid direction is occluded
  (0.0) when its cosine with the cap center is >= `thresh`, and visible (1.0) otherwise.

  Args:
    new_z_theta: polar angle of the cap center, in radians.
    new_z_phi:   azimuth of the cap center, in radians.
    thresh:      cosine threshold; the cap's angular radius is arccos(thresh).

  Returns:
    float32 array of shape (envmap_H, envmap_W).
  """
  new_z_dir = thph2vec(new_z_theta, new_z_phi)
  # Only the rotated z coordinate is needed: it is the cosine between each grid direction
  # and the cap center.
  new_z = rotate(xyz, new_z_dir)[:, 2].reshape((envmap_H, envmap_W))
  return jnp.float32(new_z < thresh)

# Demo: a cap centered near +y (theta = phi = pi/2). Directions whose cosine with the center
# is >= 0.5 (within 60 degrees of it) get mask 0.
new_z_theta = jnp.pi / 2.0
new_z_phi = jnp.pi / 2
mask = get_mask(new_z_theta, new_z_phi, thresh=0.5)

plt.imshow(mask)

In [None]:
def get_rotated_grid(new_z_theta, new_z_phi):
  """Coordinates of every envmap grid direction in a frame whose z axis is (new_z_theta, new_z_phi).

  Returns:
    (new_x, new_y, new_z), each of shape (envmap_H, envmap_W).
  """
  frame_z = thph2vec(new_z_theta, new_z_phi)
  coords = rotate(xyz, frame_z)
  return tuple(coords[:, axis].reshape((envmap_H, envmap_W)) for axis in range(3))

def get_mask_and_peter(new_z_theta, new_z_phi, thresh=0.9, peter_grid=None, mask_grid=None):
  """Occluder mask and occluder ("peter") radiance on the envmap grid for one orientation.

  The envmap grid directions are expressed in a frame whose z axis is
  (new_z_theta, new_z_phi). `peter_grid` / `mask_grid` are looked up at those rotated angles.

  Args:
    new_z_theta, new_z_phi: orientation of the occluder frame's z axis, in radians.
    thresh: NOTE(review): not used by this function; the occluder's extent comes entirely
      from `mask_grid`.
    peter_grid: [H, W, 3] occluder radiance indexed by rotated (theta, phi), or None.
    mask_grid:  [H, W] visibility indexed by rotated (theta, phi); needed with peter_grid.

  Returns:
    (mask, peter, new_x, new_y, new_z). Without `peter_grid`, mask is all ones (nothing
    occluded) and peter is all zeros.
  """
  grid_x, grid_y, grid_z = get_rotated_grid(new_z_theta, new_z_phi)
  if peter_grid is not None:
    peter, mask = peter_fn(peter_grid, mask_grid, grid_x, grid_y, grid_z)
  else:
    mask = jnp.ones((envmap_H, envmap_W), dtype=jnp.float32)
    peter = jnp.zeros((envmap_H, envmap_W, 3), dtype=jnp.float32)
  return mask, peter, grid_x, grid_y, grid_z

def get_illumination(envmap, new_z_theta, new_z_phi, thresh, peter_grid=None, mask_grid=None):
  """Composite the occluder over `envmap` for one orientation.

  Returns:
    (illumination, mask, peter), where illumination = mask * envmap + (1 - mask) * peter
    per texel and channel.
  """
  mask, peter, *_unused_grid = get_mask_and_peter(
      new_z_theta, new_z_phi, thresh, peter_grid, mask_grid)
  keep = mask[:, :, None]
  illumination = keep * envmap + (1.0 - keep) * peter
  return illumination, mask, peter


def peter_fn(peter_grid, mask_grid, x, y, z):
  """Sample the occluder radiance and mask grids in the directions (x, y, z).

  The grids are laid out like the envmap: row i is theta = i * pi / envmap_H and column j is
  phi = j * 2 * pi / envmap_W.

  Args:
    peter_grid: [H, W, 3] occluder radiance.
    mask_grid:  [H, W] occluder visibility.
    x, y, z:    direction components (any matching shape).

  Returns:
    (peter, mask) interpolated bilinearly at every input direction.
  """
  th, ph = vec2thph(x, y, z)
  # vec2thph returns azimuth in [-pi, pi], but the grid columns cover [0, 2*pi). Unwrapped,
  # every negative azimuth would clamp to column 0 under mode='nearest'.
  ph = jnp.mod(ph, 2.0 * jnp.pi)
  inds = jnp.stack([th*envmap_H/jnp.pi, ph*envmap_W/2.0/jnp.pi], axis=0)
  # Append column 0 after the last column so samples between the last column and phi = 2*pi
  # blend across the seam instead of clamping.
  peter_grid = jnp.concatenate([peter_grid, peter_grid[:, :1]], axis=1)
  mask_grid = jnp.concatenate([mask_grid, mask_grid[:, :1]], axis=1)
  peter = sphere_interpolate(peter_grid, inds)
  mask = sphere_interpolate(mask_grid[..., None], inds)[..., 0]
  return peter, mask

def sphere_interpolate(grids, inds):
  """Bilinearly sample every channel of a grid at fractional (row, column) positions.

  Args:
    grids: [H, W, d] array; rows are elevation and columns are azimuth.
    inds:  [2, ...] fractional positions: inds[0] indexes rows, inds[1] indexes columns.
           Out-of-range positions clamp to the nearest edge (mode='nearest'); azimuth does
           not wrap.

  Returns:
    [..., d] array of sampled values.
  """
  channels = [
      jax.scipy.ndimage.map_coordinates(grids[:, :, channel], inds, order=1, mode='nearest')
      for channel in range(grids.shape[-1])
  ]
  return jnp.stack(channels, axis=-1)


In [None]:
def cosine_scheduler(m, lr_init, lr_final):
  """Return a cosine-annealed schedule i -> learning rate.

  The rate equals lr_init at step 0 and lr_final at step m. Past m it rises again, because
  the cosine is periodic.
  """
  span = lr_final - lr_init

  def lr(i):
    weight = 0.5 * (1.0 - jnp.cos(jnp.pi * i / m))
    return weight * span + lr_init

  return lr

def parse_bin(s):
  """Value of a binary-fraction string such as '.101' (= 0.625); the first character is skipped."""
  digits = s[1:]
  return int(digits, 2) / 2.0 ** len(digits)


def phi2(i):
  """Base-2 radical inverse of the non-negative integer `i` (van der Corput sequence)."""
  mirrored_bits = format(i, 'b')[::-1]
  return parse_bin('.' + mirrored_bits)

def nice_uniform(N):
  """First N points of the 2-D Hammersley set, as lists u = [i / N] and v = [phi2(i)]."""
  u = [i / float(N) for i in range(N)]
  v = [phi2(i) for i in range(N)]
  return u, v

def nice_uniform_hemispherical(N):
  """Hammersley points mapped uniformly onto the upper hemisphere.

  Implementation of http://holger.dammertz.org/stuff/notes_HammersleyOnHemisphere.html

  Returns:
    [N, 2] array of (theta, phi), with cos(theta) = 1 - u and phi = 2 * pi * v.
  """
  u, v = nice_uniform(N)

  polar = np.arccos(1.0 - np.array(u))
  azimuth = 2.0 * np.pi * np.array(v)

  return np.stack([polar, azimuth], axis=-1)



In [None]:
# Ground-truth occluder on the envmap grid: directions with cos(theta) >= thresh, i.e. a cap
# of angular radius arccos(thresh) around the grid's theta = 0 pole, are occluded.
thresh = 0.85

# Occluder radiance: an RGB visualization of the directions, then zeroed, so the
# ground-truth occluder is black.
peter_grid_gt = np.stack([x, y, z], axis=-1) * 0.5 + 0.5
peter_grid_gt *= 0.0

# 1.0 where the envmap is visible, 0.0 inside the occluded cap.
mask_grid_gt = np.float32(np.cos(theta) < thresh)


# Build one occluded envmap per orientation. The orientations are Hammersley points on the
# hemisphere, so every cap center has theta in [0, pi/2].
num_images = 1000
envmaps_gt = []
masks_gt = []
peters_gt = []
mask_thetaphis = []
for i, (mask_theta, mask_phi) in enumerate(nice_uniform_hemispherical(num_images)):
  # The cap is re-centered on (mask_theta, mask_phi) and composited over envmap_gt.
  # NOTE(review): `thresh` is passed through but ignored by get_mask_and_peter; the cap size
  # comes from mask_grid_gt.
  envmap, mask, peter = get_illumination(envmap_gt, mask_theta, mask_phi, thresh, peter_grid_gt, mask_grid_gt)
  envmaps_gt.append(envmap)
  masks_gt.append(mask)
  peters_gt.append(peter)
  mask_thetaphis.append((mask_theta, mask_phi))


#print(len(mask_phis))
# Scatter of the sampled cap centers (theta on x, phi on y).
plt.figure()
plt.plot(*zip(*mask_thetaphis), '.')  

# Stack into [N, H, W] masks, [N, H, W, 3] peters/envmaps and [N, 2] orientations.
masks_gt = jnp.array(masks_gt)
peters_gt = jnp.array(peters_gt)
envmaps_gt = jnp.array(envmaps_gt)
orientations_gt = jnp.array(mask_thetaphis)
#valid = 1.0 - jnp.float32(invalid).flatten()[None, None, :]
M = masks_gt.shape[0]  # number of images



# Min over images: 0 wherever a texel is occluded in at least one image.
plt.figure()
plt.imshow(masks_gt.min(axis=0))
# Number of images in which each texel is occluded.
plt.figure()
plt.imshow((1.0 - masks_gt).sum(0))
plt.colorbar()


H = W = 96

# Pixel grid of an orthographically viewed unit sphere, slightly over-covering [-1, 1]^2.
xx, yy = np.meshgrid(np.linspace(-1.01, 1.01, W),
                     np.linspace(-1.01, 1.01, H))

# Pixels outside the unit disk do not see the sphere.
invalid = xx ** 2 + yy ** 2 > 1.0

# Clamp before sqrt. Off-disk pixels would otherwise produce NaN (and a RuntimeWarning), and
# rounding can make 1 - x^2 - y^2 slightly negative on pixels that `invalid` marks as valid.
zz = np.sqrt(np.clip(1.0 - xx ** 2 - yy ** 2, 0.0, None))

# Give off-disk pixels a harmless placeholder normal (+z).
xx[invalid] = 0.0
yy[invalid] = 0.0
zz[invalid] = 1.0

normals_gt = jnp.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=-1)

plt.figure()
plt.imshow(normals_gt.reshape(H, W, 3)*0.5+0.5)

# Sanity check that the problem is over-determined. Unknowns: envmap texels x 3 channels plus
# one mask texel per image. Equations: one per on-disk pixel per image (RGB channels are not
# counted separately, so this is conservative).
num_eqs  = (~invalid).sum() * masks_gt.shape[0]
num_vars = envmap_H * envmap_W * 3 + masks_gt.shape[0] * envmap_H * envmap_W

#assert (1 - masks_gt).sum(0).min() >= 1, 'Some pixels are not covered by any mask'
assert num_eqs > num_vars, f"Need enough pixels ({num_eqs} eqs, {num_vars} vars)"

In [None]:
# Preview the resampled mask for an occluder frame at theta = 1 rad, phi = 4 rad
# (thresh passed explicitly as its default, 0.9).
plt.imshow(get_mask_and_peter(1, 4.0, 0.9, peter_grid_gt, mask_grid_gt)[0])

In [None]:
def render_pixels(normals, envmap, peter, mask, use_jacobian=True):
  """Render the RGB value of one surface point under a partially occluded envmap.

  Each envmap texel is replaced by `peter` where `mask` is 0. It is then weighted by the
  clamped cosine lobe max(0, <direction, normal>) and summed over the equirectangular grid
  with cell size dphi * dtheta, divided by pi.

  Args:
    normals:      [3] surface normal.
    envmap:       [H, W, 3] environment map on the (theta, phi) grid.
    peter:        [H, W, 3] occluder radiance (anything broadcastable, e.g. [1, 1, 3]).
    mask:         [H, W] visibility; 1 keeps `envmap`, 0 substitutes `peter`.
    use_jacobian: if True, weight each texel by sin(theta), the solid-angle factor of the grid.

  Returns:
    [3] rendered RGB value.

  The function is plain JAX, so jax.grad / value_and_grad differentiate it directly.
  """
  # xyz is [H*W, 3]: the unit direction of every envmap texel, row-major over (theta, phi).
  lobe = jnp.maximum(0.0, (xyz * normals[None, :]).sum(-1).reshape(envmap_H, envmap_W))  # [H, W]

  masked_envmap = envmap * mask[:, :, None] + (1.0 - mask[:, :, None]) * peter
  integrand = masked_envmap * lobe[:, :, None]
  if use_jacobian:
    integrand = integrand * jnp.sin(theta)[:, :, None]
  return integrand.sum(0).sum(0) * (phi[0, 1] - phi[0, 0]) * (theta[1, 0] - theta[0, 0]) / jnp.pi


def render_pixels_jvp(primals, tangents):
  """Forward-mode derivative of `render_pixels`.

  A custom JVP rule is unnecessary because autodiff handles `render_pixels` directly. This
  helper simply wraps `jax.jvp`.

  Args:
    primals:  (normals, envmap, peter, mask) at which to linearize.
    tangents: tangent arrays with the same structure as `primals`.

  Returns:
    (primal_out, tangent_out), as returned by jax.jvp.
  """
  return jax.jvp(render_pixels, tuple(primals), tuple(tangents))


@jax.jit
def make_data(envmap, peters, masks, normals):
  """Render every (image, pixel) pair.

  Args:
    envmap:  [H, W, 3] environment map shared by all images.
    peters:  [N, H, W, 3] per-image occluder radiance (or a per-image broadcastable shape,
             e.g. [N, 1, 1, 3]).
    masks:   [N, H, W] per-image visibility masks.
    normals: [n, 3] surface normals, one per pixel.

  Returns:
    [N, n, 3] rendered RGB values.
  """
  over_pixels = jax.vmap(render_pixels, in_axes=(0, None, None, None))
  over_images = jax.vmap(over_pixels, in_axes=(None, None, 0, 0))
  return over_images(normals, envmap, peters, masks)


def get_diff(params, num_pts=20.0):
  """Compute the gradient of the image w.r.t. the polygon vertex parameters.

  Each polygon edge's line integral is approximated with `num_pts` midpoint samples.

  NOTE(review): `interp2d`, `f` and `L` are not defined anywhere in this file, so calling this
  raises NameError. It looks carried over from a 2-D polygon experiment and is not called in
  this notebook. `x`, `y` and `H` resolve to the module-level envmap grid and image size, which
  may not be what was intended -- confirm before reviving it.

  Args:
    params: presumably [V, 2] vertex coordinates (the [1, -1] sign vector used for `dx`
      requires a last dimension of 2).
    num_pts: number of samples per edge.
  """
  # Midpoint sample positions in (0, 1).
  t = (jnp.arange(num_pts) + 0.5) / num_pts
  # Points along each edge from v_{i-1} (t = 0) to v_i (t = 1); shape [2, V, num_pts].
  r = params[:, :, None].transpose(1, 0, 2) * t[None, None, :] + jnp.roll(params, 1, axis=0)[:, :, None].transpose(1, 0, 2) * (1.0 - t[None, None, :])

  # Offsets from every edge sample to every grid pixel, in pixel units.
  fq = (0.5 * jnp.stack([x, y], axis=0)[:, None, None, :, :] + 0.5) * (H - 1) - r[:, :, :, None, None]

  f_interp = interp2d(f[:, :, None], fq)
  L_interp = interp2d(L[:, :, None], r) 

  line_diff1 = (f_interp[:, :, :, :, 0] * L_interp[:, :, None, None, 0] * (1.0 - t[None, :, None, None])).sum(1) * (t[1] - t[0])  # the ith row is the integral from v_{i} to v_{i-1}
  line_diff2 = (f_interp[:, :, :, :, 0] * L_interp[:, :, None, None, 0] * t[None, :, None, None]).sum(1) * (t[1] - t[0])          # the ith row is the integral from v_{i-1} to v_{i}
  dx = jnp.fliplr(params - jnp.roll(params, 1, axis=0)) * jnp.array([1.0, -1.0])[None, :]  # this stores y_{i} - y_{i-1} and x_{i} - x_{i-1})

  # Add integral from v_{i+1} to v_{i} multiplied by y_{i} - y_{i+1} with integral from v_{i-1} to v_{i} multiplied by y_{i-1} - y_{i}
  diff = -jnp.roll(line_diff1[:, None, :, :] * dx[:, :, None, None], -1, axis=0) - line_diff2[:, None, :, :] * dx[:, :, None, None]
  return diff


def get_mask(vertices):
  """Rasterize polygon(s) into a float mask that is 0.0 inside and 1.0 outside.

  Edges are anti-aliased (cv2.LINE_AA), so boundary pixels take intermediate values.

  NOTE(review): `cv2` is never imported in this notebook, so calling this raises NameError.
  Once this cell runs, this definition also replaces the earlier
  get_mask(new_z_theta, new_z_phi, thresh); no executed code after this point calls either one.

  Args:
    vertices: polygon vertices in the format cv2.fillPoly expects (presumably a list of
      int32 [K, 2] pixel-coordinate arrays -- confirm).

  Returns:
    float32 array of shape (H, W).
  """
  canvas = np.full((H, W), 255, dtype=np.uint8)
  filled = cv2.fillPoly(canvas, vertices, 0, cv2.LINE_AA)
  return np.float32(filled) / 255.0

"""
@jax.custom_jvp
def params_to_data(envmap, peters, vertices, normals):
  
  #a, shape ()
  #b, shape ()
  #x, shape (T,)
  #filt, shape (T,)
  #
  #t is not an input but it also has shape (T,) 
  
  mask = get_mask(vertices)

  #shifted_peters = jax.vmap(periodic_interp, in_axes=(0, None, None))(t[None, :]-positions[:, None], t, peter)
  peter = jnp.zeros_like(mask)
  print("NO PETER!")
  
  y = mask * x + (1.0 - mask) * peter
  z = jnp.real(jnp.fft.ifft(jnp.fft.fft(filt) * jnp.fft.fft(y, axis=-1), axis=-1))

  return z

@params2signal.defjvp
def params2signal_jvp(primals, tangents):
  a, b, x, filt = primals
  a_dot, b_dot, x_dot, filt_dot = tangents

  primal_out = params2signal(a, b, x, filt)
  xa = periodic_interp(a, t, x)
  xb = periodic_interp(b, t, x)

  ftma = periodic_interp(t - a, t, filt)
  ftmb = periodic_interp(t - b, t, filt)

  # TODO: figure out how to reuse masks from the main function instead of
  #       recomputing them here.
  mask = (b > a) * (t > a) * (t < b) \
        + (1.0 - (b > a)) * (1.0 - (t > b) * (t < a))

  tangent_out = -xa * ftma * a_dot + xb * ftmb * b_dot + \
    jnp.real(jnp.fft.ifft(jnp.fft.fft(filt) * jnp.fft.fft(mask * x_dot, axis=-1), axis=-1)) + \
    jnp.real(jnp.fft.ifft(jnp.fft.fft(filt_dot) * jnp.fft.fft(mask * x, axis=-1), axis=-1))
  return primal_out, tangent_out
"""


In [None]:
# Render the ground-truth observations [N, n, 3] in pixel chunks. The vmapped renderer works on
# [N, chunk, H, W, 3] integrands, so chunking presumably bounds peak memory.
num_chunks = 64
assert normals_gt.shape[0] % num_chunks == 0
chunk_len = normals_gt.shape[0] // num_chunks
res_gt = jnp.concatenate(
    [make_data(envmap_gt, peters_gt, masks_gt, normals_gt[k * chunk_len:(k + 1) * chunk_len, :])
     for k in range(num_chunks)],
    axis=1)


In [None]:
def entropy(p):
  """Elementwise binary entropy (in nats) of probabilities `p`.

  The log arguments are floored at 1e-10, so p = 0 or p = 1 yields 0 rather than nan.
  """
  q = 1.0 - p
  floor = 1e-10
  return -p * jnp.log(jnp.maximum(floor, p)) - q * jnp.log(jnp.maximum(floor, q))


def norm(x, p=1, w=None):
  """Smoothed L_p "norm" (sum(|x| + 1e-3)^p)^(1/p), with optional elementwise weights.

  The 1e-3 offset is added to |x|, not to x. It therefore keeps the gradient of |x|^p finite
  at 0 for p < 1 and preserves symmetry: norm(-x) == norm(x). Adding it inside the abs shifted
  the singular point to x = -1e-3 and biased the result by sign.

  Args:
    x: array of values.
    p: exponent; p == 0 returns the count of non-zero entries (as float32).
    w: optional weights multiplied into x before the norm.

  Returns:
    Scalar array.
  """
  if w is not None:
    x = x * w
  if p == 0.0:
    return jnp.float32(jnp.sum(x != 0))
  return jnp.power(((jnp.abs(x) + 1e-3) ** p).sum(), 1.0/p)

@jax.jit
def get_loss(params_envmap, params_peters, params_masks, normals, gt, spatial_inds, img_inds, i):
  """Regularized reconstruction loss on a random (image, pixel) minibatch.

  Args:
    params_envmap, params_peters, params_masks: unconstrained parameters, mapped to
      (envmap, peters, masks) by params2components.
    normals:      [n, 3] normals of all sphere pixels (normals_gt).
    gt:           [N, n, 3] observed RGB values for N images and n pixels (res_gt).
    spatial_inds: pixel indices of the minibatch.
    img_inds:     image indices of the minibatch.
    i:            iteration number, forwarded to params2components.

  Returns:
    (loss, (data_loss, None, res, envmap_mse)). `res` is the rendered
    [len(img_inds), len(spatial_inds), 3] minibatch, and `envmap_mse` is the mean squared
    error of the current envmap against the module-level envmap_gt.
  """
  envmap, peters, masks = params2components(params_envmap, params_peters, params_masks, i)

  # Render only the minibatch: the selected images x the selected pixels.
  res = make_data(envmap, peters[img_inds], masks[img_inds], normals[spatial_inds])

  # Summed squared RGB error, divided by the module-level batch_size (defined in a later
  # cell and read when this function is traced).
  data_loss = jnp.sum(((gt[img_inds, :, :][:, spatial_inds, :] - res)**2)) / batch_size #/ gt.shape[0]
  loss = data_loss#.clone()

  # TV norm is the integral of gradient(mask) * sin(theta) dtheta dphi
  # NOTE(review): with `norm = x ** 2` below, this is a squared-gradient penalty, not a true
  # (L1) total-variation norm.
  #p = 0.1
  # Total number of mask texels; normalizes the entropy term.
  w = masks.shape[0] * masks.shape[1] * masks.shape[2]
  #mask_tv = norm(masks[:, 1:, :] - masks[:, :-1, :], p, w=jnp.sin(theta[None, :-1, :])) / w + norm(masks[:, :, 1:] - masks[:, :, :-1], p) / w
  
  
  #norm = lambda x: lossfun(x, -cosine_scheduler(num_iters, 2.0, -10.0)(i), 0.1)
  #norm = lambda x: lossfun(x, -1.0, 0.1)
  # Local squared "norm"; shadows the module-level norm() inside this function.
  norm = lambda x: x ** 2
  sintheta = jnp.sin(theta)
  # Finite differences along theta (axis 1) and phi (axis 2), each weighted by sin(theta) of
  # the lower-index sample.
  mask_tv = norm((masks[:, 1:, :] - masks[:, :-1, :])*sintheta[None, :-1, :]).sum() + norm((masks[:, :, 1:] - masks[:, :, :-1])*sintheta[None, :, :-1]).sum()
  # Add wrap gradient in phi (first column vs last column)
  mask_tv += norm((masks[:, :, :1] - masks[:, :, -1:])*sintheta[None, :, -1:]).sum()
  # Scale by the grid cell size dphi * dtheta.
  mask_tv *= (phi[0, 1] - phi[0, 0]) * (theta[1, 0] - theta[0, 0])
  loss += 1e-4 * mask_tv #/ w
  #loss += 0.01 * (1.0 - masks).sum() / w
  # Mean binary entropy of the masks; it is lowest at 0 and 1, so it pushes masks to be binary.
  loss += 1e-2 * entropy(masks).sum() / w

  #loss += 1e-5 * ((peters - peters.mean(0, keepdims=True)) ** 2).sum() / peters.shape[0]

  # NOTE(review): while params2components substitutes masks_gt / peters_gt, the mask
  # regularizers above are constant and contribute no gradient.
  return loss, (data_loss, None, res, ((envmap-envmap_gt)**2).mean())
  




# Initial envmap logits: small uniform noise around 0, so sigmoid(.) starts near 0.5.
params_envmap = (jax.random.uniform(jax.random.PRNGKey(0), shape=(envmap_H, envmap_W, 3)) - 0.5) * 0.1
#params_peters = (jax.random.uniform(jax.random.PRNGKey(7065), shape=(res_gt.shape[0], envmap_H, envmap_W, 3)) - 0.5) * 0.1 #- 3.0
# One RGB occluder-color logit triple per image, shaped [N, 1, 1, 3] to broadcast over the grid.
params_peters = (jax.random.uniform(jax.random.PRNGKey(7065), shape=(res_gt.shape[0], 1, 1, 3)) - 0.5) * 0.1
#params_peters = (jax.random.uniform(jax.random.PRNGKey(7065), shape=(1, 1, 1, 3)) - 0.5) * 0.1
# This switch should match the mask parameterization chosen in params2components.
if True:
  #params_masks =  (jax.random.uniform(jax.random.PRNGKey(1122), shape=(res_gt.shape[0], envmap_H, envmap_W)) - 0.5) * 10.0 + 3.0
  #params_masks =  jax.random.uniform(jax.random.PRNGKey(1122), shape=(res_gt.shape[0], envmap_H, envmap_W)) * 1.0 #+ 1.0
  # Per-texel mask logits in [1, 2): sigmoid gives ~0.73-0.88, so masks start mostly visible.
  params_masks =  jax.random.uniform(jax.random.PRNGKey(1122), shape=(res_gt.shape[0], envmap_H, envmap_W)) * 1.0 + 1.0
else:
  # Alternative: 256 blob centers per image as (theta / pi, phi / 2pi) fractions.
  params_masks = jax.random.uniform(jax.random.PRNGKey(1122), shape=(res_gt.shape[0], 256, 2))


#params_envmap = envmap_gt
#params_peters = peters_gt
#params_masks = masks_gt

# Loss value, aux outputs and gradients w.r.t. the three parameter groups (argnums 0-2).
jitvalgrad = jax.jit(jax.value_and_grad(get_loss, argnums=(0, 1, 2), has_aux=True))

num_iters = 50000
spatial_batch_size = 4
img_batch_size = 128
# get_loss reads this global when it is traced.
batch_size = spatial_batch_size * img_batch_size
print(f"batch_size is {batch_size}")


# One Adam optimizer per parameter group, each cosine-decayed to 0.3x its initial rate over
# num_iters steps.
# NOTE(review): newer JAX releases provide this module as jax.example_libraries.optimizers.
init_lr_envmap = 0.003
init_envmap, update_envmap, get_params_envmap = jax.experimental.optimizers.adam(cosine_scheduler(num_iters, init_lr_envmap, init_lr_envmap*0.3))
state_envmap = init_envmap(params_envmap)

init_lr_peters = 0.01
init_peters, update_peters, get_params_peters = jax.experimental.optimizers.adam(cosine_scheduler(num_iters, init_lr_peters, init_lr_peters*0.3))
state_peters = init_peters(params_peters)

init_lr_masks = 0.07
init_masks, update_masks, get_params_masks = jax.experimental.optimizers.adam(cosine_scheduler(num_iters, init_lr_masks, init_lr_masks*0.3))
state_masks = init_masks(params_masks)

def params2components(params_envmap, params_peters, params_masks, iteration, *,
                      sigmoid_masks=True, override_with_gt=True, temperature=1.0):
  """Map unconstrained optimizer parameters to (envmap, peters, masks).

  The former hard-coded `if True:` switches are now keyword arguments whose defaults
  reproduce the previous behavior.

  Args:
    params_envmap: envmap logits; squashed with a sigmoid.
    params_peters: per-image occluder-color logits; squashed with a sigmoid.
    params_masks:  mask logits when `sigmoid_masks` is True. Otherwise [N, K, 2] blob centers
                   as (theta / pi, phi / 2pi) fractions, rendered as clipped Gaussian bumps.
    iteration:     current optimization step (unused; kept for annealing schedules).
    sigmoid_masks: choose the sigmoid parameterization (True) or the Gaussian-blob one.
    override_with_gt: debug setting (the notebook's default). When True, discard the optimized
                   peters/masks and return the module-level peters_gt / masks_gt, so only the
                   envmap is learned.
    temperature:   sigmoid temperature for the mask logits.

  Returns:
    (envmap, peters, masks)
  """
  envmap = jax.nn.sigmoid(params_envmap)
  peters = jax.nn.sigmoid(params_peters)
  if sigmoid_masks:
    masks = jax.nn.sigmoid(params_masks/temperature)
  else:
    width = 0.3
    # [N, K, H, W] squared angular distance (in theta/phi coordinates) from each blob center.
    distsq = (params_masks[:, :, 0][:, :, None, None] * jnp.pi - theta[None, None, :, :]) ** 2 + (params_masks[:, :, 1][:, :, None, None] * 2 * jnp.pi - phi[None, None, :, :]) ** 2
    masks = jnp.exp(-distsq/2/width).sum(1)
    masks = jnp.clip(masks, 0.0, 1.0)

  if override_with_gt:
    peters = peters_gt
    masks = masks_gt
    # Printed once per call (at trace time when called under jit), not 20 times.
    print("SETTING STUFF TO GT")

  return envmap, peters, masks



@jax.jit
def update_params(i, state_envmap, state_peters, state_masks, normals, gt, spatial_inds, img_inds):
  """Run one Adam step on all three parameter groups.

  Returns:
    (new_state_envmap, new_state_peters, new_state_masks, loss, data_loss, res,
     envmap_error, grads), where grads is the (envmap, peters, masks) gradient tuple.
  """
  params = (get_params_envmap(state_envmap),
            get_params_peters(state_peters),
            get_params_masks(state_masks))

  (loss, (data_loss, _, res, envmap_error)), grads = jitvalgrad(
      *params, normals, gt, spatial_inds, img_inds, i)
  grad_envmap, grad_peters, grad_masks = grads

  new_envmap = update_envmap(i, grad_envmap, state_envmap)
  new_peters = update_peters(i, grad_peters, state_peters)
  new_masks = update_masks(i, grad_masks, state_masks)
  return new_envmap, new_peters, new_masks, loss, data_loss, res, envmap_error, grads

from numpy.random import default_rng

rng = default_rng()

# valid[0, 0, k] is 1.0 for on-sphere pixels and 0.0 for background pixels.
valid = 1.0 - jnp.float32(invalid).flatten()[None, None, :]
nz = np.nonzero(valid[0, 0, :])[0]  # valid element indices



t = time.time()
losses = []
envmap_errors = []
for i in range(num_iters):
  if i % 100 == 0:
    print(i)
  # Minibatch: a few random on-sphere pixels, observed in a random subset of images.
  spatial_inds = nz[rng.choice(nz.shape[0], size=spatial_batch_size, replace=False)]
  img_inds = rng.choice(res_gt.shape[0], size=img_batch_size, replace=False)
  state_envmap, state_peters, state_masks, loss, data_loss, z, envmap_error, g = update_params(
      i, state_envmap, state_peters, state_masks, normals_gt, res_gt, spatial_inds, img_inds)
  losses.append(data_loss)
  envmap_errors.append(envmap_error)

  if (i % 5000 == 0 and i > 0) or i == num_iters - 1 or i == 2000:

    envmap, peters, masks = params2components(get_params_envmap(state_envmap), get_params_peters(state_peters), get_params_masks(state_masks), i)

    print(f"Iteration {i}")
    plt.figure()
    plt.semilogy(np.array(losses))
    plt.semilogy(np.array(envmap_errors))
    plt.figure(figsize=[18, 12])
    plt.subplot(321)
    plt.imshow(envmap) 
    plt.axis('off')
    plt.subplot(322)
    plt.imshow(envmap_gt)
    plt.axis('off')
    # Red = estimated mask, green = ground-truth mask, blue = 1, for a few images.
    for ind, mask_ind in zip([3, 4, 5], [0, 30, 60]):
      plt.subplot(3, 2, ind)
      mask_vs_gt = jnp.stack([masks[mask_ind], masks_gt[mask_ind], jnp.ones_like(masks[0])], axis=-1)
      plt.imshow(mask_vs_gt)
      plt.axis('off')
    plt.subplot(326)
    #plt.imshow(jnp.where(theta[:, :, None] >= np.arccos(thresh), np.ones_like(peter_grid_gt)*0.5, peter_grid_gt))
    # One color row per image, built from the first texel of each occluder. This gives an
    # [N, N, 3] image both for per-image constant peters ([N, 1, 1, 3]) and for full grids
    # ([N, H, W, 3], e.g. when params2components substitutes peters_gt). Repeating the whole
    # first row would instead build an [N, N*W, 3] image.
    plt.imshow(peters[:, 0, :1, :].repeat(peters.shape[0], 1))
    plt.axis('off')
    plt.show()
    #print(peters[0, 0, 0, :])

print(f"Optimization took {time.time()-t:.3f} seconds.")

