Skip to content

Commit

Permalink
fix depth warper and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Sep 26, 2018
1 parent a5a9cc8 commit 4be19b5
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
15 changes: 11 additions & 4 deletions test/test_depth_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,37 @@ def test_depth_warper(self):
fx, fy = 1., 1.
rx, ry, rz = 0., 0., 0.
tx, ty, tz = 0., 0., 0.
offset_x = 0. # we will apply a 10units offset to `i` camera
offset = 1. # we will apply a 1unit offset to `i` camera

pinhole_src = utils.create_pinhole(fx, fy, cx, cy, \
height, width, rx, ry, rx, tx, ty, tz)
pinhole_src = pinhole_src.expand(batch_size, -1)

pinhole_dst = utils.create_pinhole(fx, fy, cx, cy, \
height, width, rx, ry, rx, tx + offset_x, ty, tz)
height, width, rx, ry, rx, tx + offset, ty + offset, tz)
pinhole_dst = pinhole_dst.expand(batch_size, -1)

# create checkerboard
board = utils.create_checkerboard(height, width, 4)
patch_src = torch.from_numpy(board).view(
1, 1, height, width).expand(batch_size, 1, height, width)

# instantiate warper
# instantiate warper and compute relative homographies
warper = dgm.DepthWarper(pinhole_src, height, width)
warper.compute_homographies(pinhole_dst, scale=torch.ones(batch_size, 1))

# generate synthetic inverse depth
inv_depth_src = torch.ones(batch_size, 1, height, width)

import pdb;pdb.set_trace()
# warpd source patch by depth
patch_dst = warper(inv_depth_src, patch_src)

# compute error
res = utils.check_equal_torch( \
patch_src[..., :-int(offset), :-int(offset)], \
patch_dst[..., int(offset):, int(offset):])
self.assertTrue(res)

pass


Expand Down
2 changes: 1 addition & 1 deletion torchgeometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .functional import *
from .homography_warper import HomographyWarper
#from .depth_warper import DepthWarper
from .depth_warper import DepthWarper

from torchgeometry import utils
from torchgeometry import transforms
Expand Down
27 changes: 17 additions & 10 deletions torchgeometry/depth_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@


class DepthWarper(nn.Module):
"""Warps a patch by inverse depth.
"""
def __init__(self, pinholes, width=None, height=None):
super(DepthWarper, self).__init__()
# TODO: add type and value checkings
self.width = width
self.height = height
self._pinholes = pinholes
self._i_Hs_ref = None # to be filled later
self._pinhole_ref = None # to be filled later

def compute_homographies(self, pinhole, scale):
# TODO: add type and value checkings
pinhole_ref = scale_pinhole(pinhole, scale)
if self.width is None:
self.width = pinhole_ref[..., 4]
Expand All @@ -26,7 +30,7 @@ def compute_homographies(self, pinhole, scale):
self._i_Hs_ref = homography_i_H_ref(pinhole_i, pinhole_ref)

def _compute_projection(self, x, y, invd):
point = torch.FloatTensor([[x], [y], [1.0], [invd]])).to(x.device)
point = torch.FloatTensor([[x], [y], [1.0], [invd]]).to(x.device)
flow = torch.matmul(self._i_Hs_ref, point)
z = 1. / flow[:, :, 2]
x = (flow[:, :, 0] * z)
Expand All @@ -52,6 +56,7 @@ def compute_subpixel_step(self):
# compute grids

def warp(self, inv_depth_ref, roi=None):
# TODO: add type and value checkings
assert self._i_Hs_ref is not None, 'call compute_homographies'
if roi == None:
roi = (0, self.height, 0, self.width)
Expand All @@ -62,9 +67,6 @@ def warp(self, inv_depth_ref, roi=None):
area = width * height

# take sub region
#inv_depth_ref = inv_depth_ref.squeeze(0)[start_row:end_row, start_col:
# end_col].contiguous()
import pdb;pdb.set_trace()
inv_depth_ref = inv_depth_ref[..., start_row:end_row, \
start_col:end_col].contiguous()

Expand All @@ -79,15 +81,20 @@ def warp(self, inv_depth_ref, roi=None):
xv = torch.ger(ones_x, x).view(area)
yv = torch.ger(y, ones_y).view(area)

flow = [xv, yv, ones, inv_depth_ref.view(area)]
flow = torch.stack(flow, 0)
grid = [xv, yv, ones, inv_depth_ref.view(area)]
grid = torch.stack(grid, 0)
batch_size = inv_depth_ref.shape[0]
grid = grid.unsqueeze(0).expand(batch_size, -1, -1)

flow = torch.matmul(self._i_Hs_ref, flow)
flow = torch.matmul(self._i_Hs_ref, grid)
assert len(flow.shape) == 3, flow.shape

factor_x = (self.width - 1) / 2
factor_y = (self.height - 1) / 2

z = 1. / flow[:, 2] # Nx(H*W)
x = (flow[:, 0] * z - self.width / 2) / (self.width / 2)
y = (flow[:, 1] * z - self.height / 2) / (self.height / 2)
x = (flow[:, 0] * z - factor_x) / factor_x
y = (flow[:, 1] * z - factor_y) / factor_y

flow = torch.stack([x, y], 1) # Nx2x(H*W)

Expand All @@ -96,5 +103,5 @@ def warp(self, inv_depth_ref, roi=None):
return flows.permute(0, 2, 3, 1) # NxHxWx2

def forward(self, inv_depth_ref, data):
# be aware that grid_sample only supports float or double type
# TODO: add type and value checkings
return torch.nn.functional.grid_sample(data, self.warp(inv_depth_ref))
4 changes: 2 additions & 2 deletions torchgeometry/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def inv_pinhole_matrix(pinhole, eps=1e-6):
def scale_pinhole(pinhole, scale):
"""Scales the pinhole matrix from a pinhole model.
"""
assert len(pinhole) == 2 and pinhole.shape[1] == 12, pinhole.shape
assert len(scale) == 2 and scale.shape[1] == 1, scale.shape
assert len(pinhole.shape) == 2 and pinhole.shape[1] == 12, pinhole.shape
assert len(scale.shape) == 2 and scale.shape[1] == 1, scale.shape
pinhole_scaled = pinhole.clone()
pinhole_scaled[..., :6] = pinhole[..., :6] * scale
return pinhole_scaled
Expand Down

0 comments on commit 4be19b5

Please sign in to comment.