Skip to content

Commit

Permalink
implement inverse_pose
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Aug 28, 2018
1 parent 0d84f55 commit 001097c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
46 changes: 36 additions & 10 deletions test/test_functional.py
Expand Up @@ -7,16 +7,17 @@

# test utilies

def create_eye_batch(batch_size):
def create_eye_batch(batch_size, eye_size):
"""Creates a batch of identity matrices of shape Bx3x3
"""
return torch.eye(3).view(1, 3, 3).expand(batch_size, -1, -1)
return torch.eye(eye_size).view(
1, eye_size, eye_size).expand(batch_size, -1, -1)

def create_random_homography(batch_size, std_val=1e-1):
def create_random_homography(batch_size, eye_size, std_val=1e-1):
"""Creates a batch of random homographies of shape Bx3x3
"""
std = std_val * torch.rand(batch_size, 3, 3)
eye = create_eye_batch(batch_size)
std = std_val * torch.rand(batch_size, eye_size, eye_size)
eye = create_eye_batch(batch_size, eye_size)
return eye + std

def tensor_to_gradcheck_var(tensor):
Expand Down Expand Up @@ -73,19 +74,21 @@ def test_convert_points_from_homogeneous_gradcheck(self):
def test_inverse(self):
# generate input data
batch_size = 2
homographies = create_random_homography(batch_size)
eye_size = 3 # identity 3x3
homographies = create_random_homography(batch_size, eye_size)
homographies_inv = dgm.inverse(homographies)

# H_inv * H == I
res = torch.matmul(homographies_inv, homographies)
eye = create_eye_batch(batch_size)
eye = create_eye_batch(batch_size, eye_size)
error = torch.sum((res - eye) ** 2)
self.assertAlmostEqual(error.item(), 0.0)

def test_inverse_gradcheck(self):
# generate input data
batch_size = 2
homographies = create_random_homography(batch_size)
eye_size = 3 # identity 3x3
homographies = create_random_homography(batch_size, eye_size)
homographies = tensor_to_gradcheck_var(homographies) # to var

# evaluate function gradient
Expand All @@ -96,8 +99,9 @@ def test_transform_points(self):
batch_size = 2
num_points = 2
num_dims = 2
eye_size = 3 # identity 3x3
points_src = torch.rand(batch_size, 2, num_dims)
dst_homo_src = create_random_homography(batch_size)
dst_homo_src = create_random_homography(batch_size, eye_size)

# transform the points from dst to ref
points_dst = dgm.transform_points(dst_homo_src, points_src)
Expand All @@ -115,9 +119,10 @@ def test_transform_points_gradcheck(self):
batch_size = 2
num_points = 2
num_dims = 2
eye_size = 3 # identity 3x3
points_src = torch.rand(batch_size, 2, num_dims)
points_src = tensor_to_gradcheck_var(points_src) # to var
dst_homo_src = create_random_homography(batch_size)
dst_homo_src = create_random_homography(batch_size, eye_size)
dst_homo_src = tensor_to_gradcheck_var(dst_homo_src) # to var

# evaluate function gradient
Expand Down Expand Up @@ -167,6 +172,27 @@ def test_deg2rad_gradcheck(self):
res = gradcheck(dgm.deg2rad, (tensor_to_gradcheck_var(x_deg),),
raise_exception=True)

@unittest.skip("Need to verify output")
def test_inverse_pose(self):
# generate input data
batch_size = 2
eye_size = 4 # identity 4x4
dst_pose_src = create_random_homography(batch_size, eye_size)

# compute the inverse of the pose
src_pose_dst = dgm.inverse_pose(dst_pose_src)
# TODO: add assert with proper check

def test_inverse_pose_gradcheck(self):
# generate input data
batch_size = 2
eye_size = 4 # identity 4x4
dst_pose_src = create_random_homography(batch_size, eye_size)
dst_pose_src = tensor_to_gradcheck_var(dst_pose_src) # to var

# evaluate function gradient
res = gradcheck(dgm.inverse_pose, (dst_pose_src,),
raise_exception=True)

if __name__ == '__main__':
unittest.main()
39 changes: 36 additions & 3 deletions torchgeometry/functional.py
@@ -1,7 +1,8 @@
import torch

__all__ = ["pi", "rad2deg", "deg2rad", "convert_points_from_homogeneous",
"convert_points_to_homogeneous", "transform_points", "inverse"]
"convert_points_to_homogeneous", "transform_points", "inverse",
"inverse_pose"]


"""Constant with number pi
Expand All @@ -15,7 +16,8 @@ def rad2deg(x):
Args:
x (Tensor): tensor of unspecified size.
Returns: tensor with same size as input.
Returns:
Tensor: tensor with same size as input.
"""
if not torch.is_tensor(x):
raise TypeError("Input type is not a torch.Tensor. Got {}"
Expand All @@ -30,7 +32,8 @@ def deg2rad(x):
Args:
x (Tensor): tensor of unspecified size.
Returns: tensor with same size as input.
Returns:
Tensor: tensor with same size as input.
"""
if not torch.is_tensor(x):
raise TypeError("Input type is not a torch.Tensor. Got {}"
Expand Down Expand Up @@ -113,3 +116,33 @@ def inverse(homography):
.format(points.shape))
# iterate, compute inverse and stack tensors
return torch.stack([torch.inverse(homo) for homo in homography])


def inverse_pose(pose):
"""Inverts a 4x4 pose.
Args:
points (Tensor): tensor of either size (4, 4) or (B, 4, 4).
Returns:
Tensor: tensor with same size as input.
"""
if not torch.is_tensor(pose):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(pose)))
if not pose.shape[-2:] == (4, 4):
raise ValueError("Input size must be a 4x4 tensor. Got {}"
.format(pose.shape))
pose_shape = pose.shape
if len(pose_shape) == 2:
pose = torch.unsqueeze(pose, dim=0)

pose_inv = pose.clone()
pose_inv[..., :3, 0:3] = torch.transpose(pose[..., :3, :3], 1, 2)
pose_inv[..., :3, 2:3] = torch.matmul(
-1.0 * pose_inv[..., :3, :3], pose[..., :3, 2:3])

if len(pose_shape) == 2:
pose_inv = torch.squeeze(pose_inv, dim=0)

return pose_inv

0 comments on commit 001097c

Please sign in to comment.