<a href="https://colab.research.google.com/github/jonbarron/planar_filter/blob/main/planar_filter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A reimplementation of the "Planar Filter" from Section 7.1 (Algorithms 1 and 2) of [Depth from motion for smartphone AR](https://research.google/pubs/pub48288/), Valentin et al., SIGGRAPH Asia 2018.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def solve_image_ldl3(A11, A12, A13, A22, A23, A33, b1, b2, b3):
  """Solve the symmetric 3x3 system `A x = b` elementwise, via LDL^T.

  Only the upper triangle of the symmetric matrix is passed in (A11, A12, A13,
  A22, A23, A33). The arguments may be scalars or arrays that broadcast
  together, in which case one independent 3x3 system is solved per element.
  The matrix is factored as A = L D L^T, with L unit lower-triangular and D
  diagonal, without pivoting, so the pivots must be non-zero.

  Returns:
    (x1, x2, x3), the components of the solution.
  """
  # Factorization: pivots p1..p3 (the diagonal of D) and multipliers l21,
  # l31, l32 (the strictly-lower part of L).
  p1 = A11
  l21 = A12 / p1
  l31 = A13 / p1
  p2 = A22 - l21 * A12
  l32 = (A23 - l31 * A12) / p2
  p3 = A33 - l31 * A13 - l32 * l32 * p2

  # Forward substitution, L y = b (y1 is just b1).
  y2 = b2 - l21 * b1
  y3 = b3 - l31 * b1 - l32 * y2

  # Diagonal scaling and back substitution, L^T x = D^-1 y.
  x3 = y3 / p3
  x2 = y2 / p2 - l32 * x3
  x1 = b1 / p1 - l21 * x2 - l31 * x3
  return x1, x2, x3

def planar_filter(Z, filt, eps):
  """Fit a locally weighted, regularized plane at every pixel of `Z`.

  For each pixel, the values of `Z` are fit with a plane by weighted least
  squares, where the weights come from `filt`, and `eps` pulls the fitted
  slopes toward zero (i.e. toward a fronto-parallel plane). All pixels are
  solved at once through vectorized 3x3 normal equations.

  Note: this isn't the same code as in the paper. x and y are swapped to match
  a more pythonic (x, y) convention, the output slopes are negated (see the
  comment at the return), and the "offset" component of the plane fit is not
  regularized, so setting eps -> infinity gives (0, 0, filt(Z) / filt(1)),
  i.e. (0, 0, filt(Z)) when `filt` is normalized.

  Args:
    Z: array whose last two axes are the image; any leading axes are treated
      as a batch.
    filt: a function that blurs its input by taking a linear, non-negative
      combination of input values. It is called on arrays shaped like `Z`
      and also on coordinate grids of shape `Z.shape[-2:]`, so it must handle
      both (broadcasting against the batch axes is enough).
    eps: slope regularization strength. Slopes are regularized in normalized
      coordinates (pixel indices centered and rescaled to roughly [-1, 1]).

  Returns:
    (Zx, Zy, Zz): per-pixel plane parameters. Zx and Zy are the plane's
    derivatives wrt the pixel index along axes -2 and -1 respectively; Zz is
    the plane's value at the pixel, which can itself be used as "the"
    filtered output.

  Raises:
    ValueError: if `Z` has fewer than two dimensions.
  """
  if len(Z.shape) < 2:
    raise ValueError(f'Z must have at least 2 dimensions, got shape {Z.shape}.')

  xy_shape = np.array(Z.shape[-2:])
  # Center the pixel indices and scale them to roughly [-1, 1]. A 1x1 image
  # has zero extent and all-zero coordinates, so any scale works there; use 1
  # instead of dividing by zero (which would make every output NaN).
  mean_extent = np.mean(xy_shape - 1)
  xy_scale = 2 / mean_extent if mean_extent > 0 else 1.0
  x, y = np.meshgrid(
      *[(np.arange(s) - (s - 1) / 2) * xy_scale for s in xy_shape],
      indexing='ij')

  # Filtered moments of 1, the coordinates, and Z.
  [F1, Fx, Fy, Fz, Fxx, Fxy, Fxz, Fyy, Fyz] = [
      filt(t) for t in [
          np.ones_like(x), x, y, Z, x**2, x*y, x*Z, y**2, y*Z]]

  # Normal equations (upper triangle only) of the weighted fit
  #   Z_j ~= a * (x - x_j) + b * (y - y_j) + c
  # over the neighbors j of each pixel (x, y), plus eps**2 * (a**2 + b**2).
  A11 = F1*x**2 - 2*x*Fx + Fxx + eps**2
  A22 = F1*y**2 - 2*y*Fy + Fyy + eps**2
  A12 = F1*y*x - x*Fy - y*Fx + Fxy
  A13 = F1*x - Fx
  A23 = F1*y - Fy
  A33 = F1  # The offset `c` is deliberately left unregularized.
  b1 = Fz*x - Fxz
  b2 = Fz*y - Fyz
  b3 = Fz
  Zx, Zy, Zz = solve_image_ldl3(A11, A12, A13, A22, A23, A33, b1, b2, b3)

  # The fit uses offsets (x - x_j) rather than (x_j - x), so (a, b) are the
  # negated slopes. Negate them back and convert from normalized coordinates
  # to per-pixel-index units.
  return -Zx*xy_scale, -Zy*xy_scale, Zz

In [None]:
# A simple linear blur filter. Any filter can be used with `planar_filter`,
# provided each output is a linear combination of the inputs with
# non-negative weights.
def blur(X, alpha):
  """Blur the last two axes of `X` with a separable exponential-decay filter.

  Along each of the last two axes, a causal recursive pass
  (Y[i+1] += alpha * Y[i]) is followed by an anti-causal one
  (Y[i] += alpha * Y[i+1]). Away from the array borders this amounts to
  convolving with a kernel proportional to alpha**|di| * alpha**|dj| (a
  Laplacian-shaped blur). The kernel is not normalized, so to get a weighted
  average divide by the blur of a weight image, e.g.
  blur(X * W, alpha) / blur(W, alpha).

  Args:
    X: array-like with at least 2 dimensions; any leading axes are treated as
      a batch and blurred independently. Integer/bool inputs are converted to
      float64, since the in-place float updates cannot be stored in them.
    alpha: per-pixel decay factor; presumably in [0, 1) so the kernel decays
      (not checked).

  Returns:
    A new array with the same shape as `X`; `X` itself is not modified.
  """
  # `np.array` copies, so the in-place updates below never touch `X`.
  Y = np.array(X, copy=True, subok=True)
  if not np.issubdtype(Y.dtype, np.inexact):
    Y = Y.astype(np.float64)

  # Axis -1 first, then axis -2, matching the original pass order exactly.
  for axis in (-1, -2):
    V = np.moveaxis(Y, axis, 0)  # A view: writing into V writes into Y.
    n = V.shape[0]
    for i in range(n - 1):  # Causal pass.
      V[i + 1] += alpha * V[i]
    for i in reversed(range(n - 1)):  # Anti-causal pass.
      V[i] += alpha * V[i + 1]
  return Y

In [None]:
# Test that planar_filter'ing correctly recovers planes on single images of
# random sizes, from a sparse (every 4th pixel per axis) sampling.
np.random.seed(0)
for i_test in range(10):

  # Make a random plane on a randomly sized grid.
  x, y = np.meshgrid(range(int(32 + 32*np.random.uniform())), range(int(32 + 32*np.random.uniform())), indexing='ij')
  sx, sy, shift = np.random.normal(size=(3))
  Z_true = sx * x + sy * y + shift

  # Mask out most of the pixels.
  mask = (np.mod(x, 4) == 0) & (np.mod(y, 4) == 0)
  Z = mask * Z_true
  W = np.float32(mask)

  # A normalized blur: a weighted average over the unmasked pixels.
  alpha = 0.2
  filt = lambda t: blur(t * W, alpha) / blur(W, alpha)

  # Normal filtering, and planar_filter'ing.
  Zf = filt(Z)
  Zx, Zy, Zz = planar_filter(Z, filt, 1e-4)

  basic_max_error = np.max(np.abs(Zf - Z_true))
  planar_max_error = np.max(np.abs(Zz - Z_true))
  print(f'Errors = {basic_max_error:0.5f} | {planar_max_error:0.5f}')

  # Plane fitting correctly recovers the true plane values.
  assert planar_max_error < 0.01

  # Plane fitting correctly recovers the slope of the plane.
  assert np.abs(np.median(Zx) - sx) < 0.001
  assert np.abs(np.median(Zy) - sy) < 0.001

  # Setting `eps` -> infinity behaves as expected.
  Zx0, Zy0, Zf_recon = planar_filter(Z, filt, 1e10)
  assert np.max(np.abs(Zx0)) < 0.001
  assert np.max(np.abs(Zy0)) < 0.001
  assert np.max(np.abs(Zf_recon - Zf)) < 0.001

  # One figure per test (no stray empty figure): basic filtering, planar
  # filtering, and the ground truth.
  _, ax = plt.subplots(1, 3, figsize=(12, 4))
  ax[0].imshow(Zf)
  ax[0].set_title('filt(Z)')
  ax[1].imshow(Zz)
  ax[1].set_title('planar_filter(Z)')
  ax[2].imshow(Z_true)
  ax[2].set_title('Z_true')

In [None]:
# Test that planar_filter'ing works correctly on batches of data: ten random
# planes share one grid and one sampling mask, and are filtered in one call.
np.random.seed(0)
x, y = np.meshgrid(range(32), range(48), indexing='ij')
mask = (np.mod(x, 4) == 0) & (np.mod(y, 4) == 0)
W = np.float32(mask)

# Draw the planes, keeping the masked input, the dense truth, and the
# (slope_x, slope_y, offset) parameters of each.
Zs, Zs_true, s_true = [], [], []
for i_test in range(10):
  sx, sy, shift = np.random.normal(size=(3))
  Z_true = sx * x + sy * y + shift
  Z = mask * Z_true
  Zs.append(Z)
  Zs_true.append(Z_true)
  s_true.append((sx, sy, shift))
Zs, Zs_true = np.stack(Zs, 0), np.stack(Zs_true, 0)

# A normalized blur: a weighted average over the unmasked pixels.
alpha = 0.2
def filt(X):
  return blur(X * W, alpha) / blur(W, alpha)

Zsf = filt(Zs)
Zsx, Zsy, Zsz = planar_filter(Zs, filt, 1e-4)

basic_max_error = np.abs(Zsf - Zs_true).max()
planar_max_error = np.abs(Zsz - Zs_true).max()
print(f'Errors = {basic_max_error:0.5f} | {planar_max_error:0.5f}')
assert planar_max_error < 0.01

# Every pixel's slope matches its own image's true slope: the per-image truths
# are broadcast against the (batch, H, W) estimates.
assert np.all(np.abs(np.array(s_true)[:, 0, None, None] - Zsx) < 1e-3)
assert np.all(np.abs(np.array(s_true)[:, 1, None, None] - Zsy) < 1e-3)

# eps -> infinity: zero slopes, and the offset reduces to plain filtering.
Zsx0, Zsy0, Zsf_recon = planar_filter(Zs, filt, 1e10)
assert np.abs(Zsx0).max() < 0.001
assert np.abs(Zsy0).max() < 0.001
assert np.abs(Zsf_recon - Zsf).max() < 0.001

# Filtered, planar-filtered, and true images side by side (columns), with the
# batch stacked vertically.
plt.figure(figsize=(20,20))
plt.imshow(np.concatenate(
    [np.reshape(img, [-1, img.shape[-1]]) for img in (Zsf, Zsz, Zs_true)], 1))