Skip to content

Commit

Permalink
add transform_points; inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Aug 22, 2018
1 parent ff88062 commit 1a1511c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 5 deletions.
52 changes: 50 additions & 2 deletions test/test_functional.py
Expand Up @@ -3,23 +3,71 @@
import torch
import torchgeometry as dgm

# test utilies

def create_eye_batch(batch_size):
return torch.eye(3).view(1, 3, 3).expand(batch_size, -1, -1)

def create_random_homography(batch_size, std_val=1e-1):
std = std_val * torch.rand(batch_size, 3, 3)
eye = create_eye_batch(batch_size)
return eye + std


class Tester(unittest.TestCase):

def test_convert_points_to_homogeneous(self):
points = torch.rand(1, 2, 3)
# generate input data
batch_size = 2
points = torch.rand(batch_size, 2, 3)

# to homogeneous
points_h = dgm.convert_points_to_homogeneous(points)
self.assertTrue((points_h[..., -1] == torch.ones(1, 2, 1)).all())

def test_convert_points_from_homogeneous(self):
points_h = torch.rand(1, 2, 3)
# generate input data
batch_size = 2
points_h = torch.rand(batch_size, 2, 3)
points_h[..., -1] = 1.0

# to euclidean
points = dgm.convert_points_from_homogeneous(points_h)

error = torch.sum((points_h[..., :2] - points) ** 2)
self.assertAlmostEqual(error, 0.0)

def test_inverse(self):
# generate input data
batch_size = 2
homographies = create_random_homography(batch_size)
homographies_inv = dgm.inverse(homographies)

# H_inv * H == I
res = torch.matmul(homographies_inv, homographies)
eye = create_eye_batch(batch_size)
error = torch.sum((res - eye) ** 2)
self.assertAlmostEqual(error, 0.0)

def test_transform_points(self):
# generate input data
batch_size = 2
num_points = 2
num_dims = 2
points_src = torch.rand(batch_size, 2, num_dims)
dst_homo_src = create_random_homography(batch_size)

# transform the points from dst to ref
points_dst = dgm.transform_points(dst_homo_src, points_src)

# transform the points from ref to dst
src_homo_dst = dgm.inverse(dst_homo_src)
points_dst_to_src = dgm.transform_points(src_homo_dst, points_dst)

# projected should be equal as initial
error = torch.sum((points_src - points_dst_to_src) ** 2)
self.assertAlmostEqual(error, 0.0)


if __name__ == '__main__':
unittest.main()
42 changes: 39 additions & 3 deletions torchgeometry/functional.py
Expand Up @@ -14,7 +14,7 @@ def convert_points_from_homogeneous(points, eps=1e-6):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(points)))

if len(points.shape) != 3:
if not len(points.shape) == 3:
raise ValueError("Input size must be a three dimensional tensor. Got {}"
.format(points.shape))

Expand All @@ -31,11 +31,47 @@ def convert_points_to_homogeneous(points):
Tensor: tensor of N+1-dimensional points of size (B, D, N+1).
"""
if not torch.is_tensor(points):
raise TypeError("Input ype is not a torch.Tensor. Got {}"
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(points)))

if len(points.shape) != 3:
if not len(points.shape) == 3:
raise ValueError("Input size must be a three dimensional tensor. Got {}"
.format(points.shape))

return torch.cat([points, torch.ones_like(points)[..., :1]], dim=-1)


def transform_points(dst_homo_src, points_src):
# TODO: add documentation
"""Applies Transformation to points.
"""
if not torch.is_tensor(dst_homo_src) or not torch.is_tensor(points_src):
raise TypeError("Input type is not a torch.Tensor")
if not dst_homo_src.device == points_src.device:
raise TypeError("Tensor must be in the same device")
if not len(dst_homo_src.shape) == 3 or not len(points_src.shape) == 3:
raise ValueError("Input size must be a three dimensional tensor")
if not dst_homo_src.shape[0] == points_src.shape[0]:
raise ValueError("Input batch size must be the same for both tensors")
if not dst_homo_src.shape[1] == (points_src.shape[1] + 1):
raise ValueError("Input dimensions must differe by one unit")
# to homogeneous
points_src_h = convert_points_to_homogeneous(points_src) # BxNx3
# transform coordinates
points_dst_h = torch.matmul(dst_homo_src, points_src_h.transpose(1, 2)) # Bx3xN
points_dst_h = points_dst_h.permute(0, 2, 1) # BxNx3
# to euclidean
points_dst = convert_points_from_homogeneous(points_dst_h) # BxNx2
return points_dst


def inverse(homography):
# TODO: add documentation
# NOTE: we expect in the future to have a native Pytorch function
"""Batched version of torch.inverse(...)
"""
if not len(homography.shape) == 3:
raise ValueError("Input size must be a three dimensional tensor. Got {}"
.format(points.shape))
# iterate, compute inverse and stack tensors
return torch.stack([torch.inverse(homo) for homo in homography])

0 comments on commit 1a1511c

Please sign in to comment.