-
-
Notifications
You must be signed in to change notification settings - Fork 949
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import unittest | ||
|
||
import torch | ||
import torchgeometry as dgm | ||
from torch.autograd import gradcheck | ||
|
||
import utils # test utils | ||
|
||
|
||
class Tester(unittest.TestCase): | ||
|
||
def test_depth_warper(self): | ||
# generate input data | ||
batch_size = 1 | ||
height, width = 8, 8 | ||
cx, cy = width / 2, height / 2 | ||
fx, fy = 1., 1. | ||
rx, ry, rz = 0., 0., 0. | ||
tx, ty, tz = 0., 0., 0. | ||
offset_x = 0. # we will apply a 10units offset to `i` camera | ||
|
||
pinhole_src = utils.create_pinhole(fx, fy, cx, cy, \ | ||
height, width, rx, ry, rx, tx, ty, tz) | ||
pinhole_src = pinhole_src.expand(batch_size, -1) | ||
|
||
pinhole_dst = utils.create_pinhole(fx, fy, cx, cy, \ | ||
height, width, rx, ry, rx, tx + offset_x, ty, tz) | ||
pinhole_dst = pinhole_dst.expand(batch_size, -1) | ||
|
||
# create checkerboard | ||
board = utils.create_checkerboard(height, width, 4) | ||
patch_src = torch.from_numpy(board).view( | ||
1, 1, height, width).expand(batch_size, 1, height, width) | ||
|
||
# instantiate warper | ||
warper = dgm.DepthWarper(pinhole_src, height, width) | ||
warper.compute_homographies(pinhole_dst, scale=torch.ones(batch_size, 1)) | ||
|
||
# generate synthetic inverse depth | ||
inv_depth_src = torch.ones(batch_size, 1, height, width) | ||
|
||
import pdb;pdb.set_trace() | ||
patch_dst = warper(inv_depth_src, patch_src) | ||
pass | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable | ||
|
||
from .functional import scale_pinhole, homography_i_H_ref | ||
|
||
|
||
class DepthWarper(nn.Module): | ||
def __init__(self, pinholes, width=None, height=None): | ||
super(DepthWarper, self).__init__() | ||
self.width = width | ||
self.height = height | ||
self._pinholes = pinholes | ||
self._i_Hs_ref = None # to be filled later | ||
self._pinhole_ref = None # to be filled later | ||
|
||
def compute_homographies(self, pinhole, scale): | ||
pinhole_ref = scale_pinhole(pinhole, scale) | ||
if self.width is None: | ||
self.width = pinhole_ref[..., 4] | ||
if self.height is None: | ||
self.height = pinhole_ref[..., 5] | ||
self._pinhole_ref = pinhole_ref | ||
# scale pinholes_i and compute homographies | ||
pinhole_i = scale_pinhole(self._pinholes, scale) | ||
self._i_Hs_ref = homography_i_H_ref(pinhole_i, pinhole_ref) | ||
|
||
def _compute_projection(self, x, y, invd): | ||
point = torch.FloatTensor([[x], [y], [1.0], [invd]])).to(x.device) | ||
flow = torch.matmul(self._i_Hs_ref, point) | ||
z = 1. / flow[:, :, 2] | ||
x = (flow[:, :, 0] * z) | ||
y = (flow[:, :, 1] * z) | ||
return torch.stack([x, y], 1) | ||
|
||
def compute_subpixel_step(self): | ||
"""This computes the required inverse depth step to achieve sub pixel | ||
accurate sampling of the depth cost volume, per camera. | ||
Szeliski, Richard, and Daniel Scharstein. "Symmetric sub-pixel stereo matching." European Conference on Computer Vision. Springer Berlin Heidelberg, 2002. | ||
""" | ||
delta_d = 0.01 | ||
xy_m1 = self._compute_projection(self.width / 2, self.height / 2, | ||
1.0 - delta_d) | ||
xy_p1 = self._compute_projection(self.width / 2, self.height / 2, | ||
1.0 + delta_d) | ||
dx = torch.norm((xy_p1 - xy_m1), 2, dim=2) / 2.0 | ||
dxdd = dx / (delta_d) # pixel*(1/meter) | ||
return torch.min( | ||
0.5 / | ||
dxdd) # half pixel sampling, we're interested in the min for all cameras | ||
|
||
# compute grids | ||
|
||
def warp(self, inv_depth_ref, roi=None): | ||
assert self._i_Hs_ref is not None, 'call compute_homographies' | ||
if roi == None: | ||
roi = (0, self.height, 0, self.width) | ||
start_row, end_row, start_col, end_col = roi | ||
assert start_row < end_row | ||
assert start_col < end_col | ||
height, width = end_row - start_row, end_col - start_col | ||
area = width * height | ||
|
||
# take sub region | ||
#inv_depth_ref = inv_depth_ref.squeeze(0)[start_row:end_row, start_col: | ||
# end_col].contiguous() | ||
import pdb;pdb.set_trace() | ||
inv_depth_ref = inv_depth_ref[..., start_row:end_row, \ | ||
start_col:end_col].contiguous() | ||
|
||
device = inv_depth_ref.device | ||
ones_x = torch.ones(height).to(device) | ||
ones_y = torch.ones(width).to(device) | ||
ones = torch.ones(area).to(device) | ||
|
||
x = torch.linspace(start_col, end_col - 1, width).to(device) | ||
y = torch.linspace(start_row, end_row - 1, height).to(device) | ||
|
||
xv = torch.ger(ones_x, x).view(area) | ||
yv = torch.ger(y, ones_y).view(area) | ||
|
||
flow = [xv, yv, ones, inv_depth_ref.view(area)] | ||
flow = torch.stack(flow, 0) | ||
|
||
flow = torch.matmul(self._i_Hs_ref, flow) | ||
assert len(flow.shape) == 3, flow.shape | ||
|
||
z = 1. / flow[:, 2] # Nx(H*W) | ||
x = (flow[:, 0] * z - self.width / 2) / (self.width / 2) | ||
y = (flow[:, 1] * z - self.height / 2) / (self.height / 2) | ||
|
||
flow = torch.stack([x, y], 1) # Nx2x(H*W) | ||
|
||
n, c, a = flow.shape | ||
flows = flow.view(n, c, height, width) # Nx2xHxW | ||
return flows.permute(0, 2, 3, 1) # NxHxWx2 | ||
|
||
def forward(self, inv_depth_ref, data): | ||
# be aware that grid_sample only supports float or double type | ||
return torch.nn.functional.grid_sample(data, self.warp(inv_depth_ref)) |