Commit

import depth warper
edgarriba committed Sep 25, 2018
1 parent 063b0c6 commit cd03569
Showing 2 changed files with 148 additions and 0 deletions.
48 changes: 48 additions & 0 deletions test/test_depth_warper.py
@@ -0,0 +1,48 @@
import unittest

import torch
import torchgeometry as dgm
from torch.autograd import gradcheck

import utils # test utils


class Tester(unittest.TestCase):

    def test_depth_warper(self):
        # generate input data
        batch_size = 1
        height, width = 8, 8
        cx, cy = width / 2, height / 2
        fx, fy = 1., 1.
        rx, ry, rz = 0., 0., 0.
        tx, ty, tz = 0., 0., 0.
        offset_x = 0.  # x-axis offset applied to the second (dst) camera

        pinhole_src = utils.create_pinhole(
            fx, fy, cx, cy, height, width, rx, ry, rz, tx, ty, tz)
        pinhole_src = pinhole_src.expand(batch_size, -1)

        pinhole_dst = utils.create_pinhole(
            fx, fy, cx, cy, height, width, rx, ry, rz,
            tx + offset_x, ty, tz)
        pinhole_dst = pinhole_dst.expand(batch_size, -1)

        # create checkerboard
        board = utils.create_checkerboard(height, width, 4)
        patch_src = torch.from_numpy(board).view(
            1, 1, height, width).expand(batch_size, 1, height, width)

        # instantiate warper and precompute the homographies
        warper = dgm.DepthWarper(pinhole_src, width=width, height=height)
        warper.compute_homographies(
            pinhole_dst, scale=torch.ones(batch_size, 1))

        # generate synthetic inverse depth
        inv_depth_src = torch.ones(batch_size, 1, height, width)

        # warp the source patch to the destination frame
        patch_dst = warper(inv_depth_src, patch_src)
        self.assertEqual(patch_dst.shape, patch_src.shape)


if __name__ == '__main__':
    unittest.main()
100 changes: 100 additions & 0 deletions torchgeometry/depth_warper.py
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn

from .functional import scale_pinhole, homography_i_H_ref


class DepthWarper(nn.Module):
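    """Warps patches from the `i` cameras into a reference view, given the
    inverse depth in the reference view. `compute_homographies` must be
    called first to precompute the homographies from each camera to the
    reference pinhole."""
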
    def __init__(self, pinholes, width=None, height=None):
        super(DepthWarper, self).__init__()
        self.width = width
        self.height = height
        self._pinholes = pinholes
        self._i_Hs_ref = None     # to be filled later
        self._pinhole_ref = None  # to be filled later

    def compute_homographies(self, pinhole, scale):
        pinhole_ref = scale_pinhole(pinhole, scale)
        if self.width is None:
            self.width = pinhole_ref[..., 4]
        if self.height is None:
            self.height = pinhole_ref[..., 5]
        self._pinhole_ref = pinhole_ref
        # scale pinholes_i and compute homographies
        pinhole_i = scale_pinhole(self._pinholes, scale)
        self._i_Hs_ref = homography_i_H_ref(pinhole_i, pinhole_ref)

    def _compute_projection(self, x, y, invd):
        # x, y may be plain floats, so place the point on the homographies' device
        point = torch.FloatTensor([[x], [y], [1.0], [invd]]).to(
            self._i_Hs_ref.device)
        flow = torch.matmul(self._i_Hs_ref, point)  # Nx4x1
        z = 1. / flow[:, 2]
        x = flow[:, 0] * z
        y = flow[:, 1] * z
        return torch.stack([x, y], 1)  # Nx2x1

    def compute_subpixel_step(self):
        """Computes the required inverse depth step to achieve sub-pixel
        accurate sampling of the depth cost volume, per camera.

        Szeliski, Richard, and Daniel Scharstein. "Symmetric sub-pixel
        stereo matching." European Conference on Computer Vision.
        Springer Berlin Heidelberg, 2002.
        """
        delta_d = 0.01
        xy_m1 = self._compute_projection(self.width / 2, self.height / 2,
                                         1.0 - delta_d)
        xy_p1 = self._compute_projection(self.width / 2, self.height / 2,
                                         1.0 + delta_d)
        dx = torch.norm((xy_p1 - xy_m1), 2, dim=2) / 2.0
        dxdd = dx / delta_d  # pixel * (1 / meter)
        # half-pixel sampling: take the min over all cameras
        return torch.min(0.5 / dxdd)

    # compute grids

    def warp(self, inv_depth_ref, roi=None):
        assert self._i_Hs_ref is not None, 'call compute_homographies'
        if roi is None:
            roi = (0, self.height, 0, self.width)
        start_row, end_row, start_col, end_col = roi
        assert start_row < end_row
        assert start_col < end_col
        height, width = end_row - start_row, end_col - start_col
        area = width * height

        # take the sub-region of the reference inverse depth
        inv_depth_ref = inv_depth_ref[..., start_row:end_row,
                                      start_col:end_col].contiguous()

        device = inv_depth_ref.device
        ones_x = torch.ones(height).to(device)
        ones_y = torch.ones(width).to(device)
        ones = torch.ones(area).to(device)

        x = torch.linspace(start_col, end_col - 1, width).to(device)
        y = torch.linspace(start_row, end_row - 1, height).to(device)

        # dense pixel grid, flattened row-major to length `area`
        xv = torch.ger(ones_x, x).view(area)
        yv = torch.ger(y, ones_y).view(area)

        # homogeneous coordinates [x, y, 1, inv_depth] projected by i_H_ref
        flow = [xv, yv, ones, inv_depth_ref.view(area)]
        flow = torch.stack(flow, 0)

        flow = torch.matmul(self._i_Hs_ref, flow)
        assert len(flow.shape) == 3, flow.shape

        # dehomogenize and normalize to [-1, 1] as expected by grid_sample
        z = 1. / flow[:, 2]  # Nx(H*W)
        x = (flow[:, 0] * z - self.width / 2) / (self.width / 2)
        y = (flow[:, 1] * z - self.height / 2) / (self.height / 2)

        flow = torch.stack([x, y], 1)  # Nx2x(H*W)

        n, c, a = flow.shape
        flows = flow.view(n, c, height, width)  # Nx2xHxW
        return flows.permute(0, 2, 3, 1)  # NxHxWx2

    def forward(self, inv_depth_ref, data):
        # be aware that grid_sample only supports float or double type
        return torch.nn.functional.grid_sample(data, self.warp(inv_depth_ref))
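
For reference, a minimal usage sketch assembled from the test above. It assumes the test-side helper `utils.create_pinhole` (and the 12-value pinhole layout it produces) is importable; the image, inverse depth, and 1-unit x-offset are arbitrary values chosen for illustration, not part of this commit.

import torch
import torchgeometry as dgm

import utils  # test-side helpers; assumed importable from the test directory

batch_size, height, width = 1, 8, 8
fx, fy, cx, cy = 1., 1., width / 2, height / 2

# two pinhole cameras; the second is shifted 1 unit along x (arbitrary choice)
pinhole_src = utils.create_pinhole(
    fx, fy, cx, cy, height, width, 0., 0., 0., 0., 0., 0.).expand(batch_size, -1)
pinhole_dst = utils.create_pinhole(
    fx, fy, cx, cy, height, width, 0., 0., 0., 1., 0., 0.).expand(batch_size, -1)

# precompute the homographies at the desired scale
warper = dgm.DepthWarper(pinhole_src, width=width, height=height)
warper.compute_homographies(pinhole_dst, scale=torch.ones(batch_size, 1))

# warp a float image given an inverse depth map in the reference view
image_src = torch.rand(batch_size, 1, height, width)
inv_depth = torch.ones(batch_size, 1, height, width)
image_dst = warper(inv_depth, image_src)  # NxCxHxW, resampled with grid_sample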
