Merge pull request #55 from arraiyopensource/feat/relative_pose
implement relative_pose function
edgarriba committed Jan 23, 2019
2 parents b8f3358 + 792d32f commit 6130892
Showing 5 changed files with 169 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ TGM focuses on Image and tensor warping functions such as:
:caption: Package Reference

geometric
transformations
pinhole
conversions
warping
6 changes: 6 additions & 0 deletions docs/source/transformations.rst
@@ -0,0 +1,6 @@
Linear Transformations
----------------------

.. currentmodule:: torchgeometry

.. autofunction:: relative_pose
90 changes: 90 additions & 0 deletions test/test_transformations.py
@@ -0,0 +1,90 @@
import pytest

import torch
import torchgeometry as tgm
from torch.autograd import gradcheck

import utils # test utilities
from common import TEST_DEVICES


class TestTransformPose:

    def _generate_identity_matrix(self, batch_size, device_type):
        eye = torch.eye(4).repeat(batch_size, 1, 1)  # Nx4x4
        return eye.to(torch.device(device_type))

    def _test_identity(self):
        pose_1 = self.pose_1.clone()
        pose_2 = self.pose_2.clone()
        pose_21 = tgm.relative_pose(pose_1, pose_2)
        assert utils.check_equal_torch(pose_21, torch.eye(4).unsqueeze(0))

    def _test_translation(self):
        offset = 10.
        pose_1 = self.pose_1.clone()
        pose_2 = self.pose_2.clone()
        pose_2[..., :3, -1:] += offset  # add translation

        # compute relative pose
        pose_21 = tgm.relative_pose(pose_1, pose_2)
        assert utils.check_equal_torch(pose_21[..., :3, -1:], offset)

    def _test_rotation(self):
        pose_1 = self.pose_1.clone()
        pose_2 = torch.zeros_like(pose_1)  # Rz (90deg)
        pose_2[..., 0, 1] = -1.0
        pose_2[..., 1, 0] = 1.0
        pose_2[..., 2, 2] = 1.0
        pose_2[..., 3, 3] = 1.0

        # compute relative pose
        pose_21 = tgm.relative_pose(pose_1, pose_2)
        assert utils.check_equal_torch(pose_21, pose_2)

    def _test_integration(self):
        pose_1 = self.pose_1.clone()
        pose_2 = self.pose_2.clone()

        # apply random rotations and translations
        batch_size, device = pose_2.shape[0], pose_2.device
        pose_2[..., :3, :3] = torch.rand(batch_size, 3, 3, device=device)
        pose_2[..., :3, -1:] = torch.rand(batch_size, 3, 1, device=device)

        pose_21 = tgm.relative_pose(pose_1, pose_2)
        assert utils.check_equal_torch(
            torch.matmul(pose_21, pose_1), pose_2)

    @pytest.mark.skip("Converting a tensor to a Python boolean ...")
    def test_jit(self):
        pose_1 = self.pose_1.clone()
        pose_2 = self.pose_2.clone()

        pose_21 = tgm.relative_pose(pose_1, pose_2)
        pose_21_jit = torch.jit.trace(
            tgm.relative_pose, (pose_1, pose_2,))(pose_1, pose_2)
        assert utils.check_equal_torch(pose_21, pose_21_jit)

    def _test_gradcheck(self):
        pose_1 = self.pose_1.clone()
        pose_2 = self.pose_2.clone()

        pose_1 = utils.tensor_to_gradcheck_var(pose_1)  # to var
        pose_2 = utils.tensor_to_gradcheck_var(pose_2)  # to var
        assert gradcheck(tgm.relative_pose, (pose_1, pose_2,),
                         raise_exception=True)

    @pytest.mark.parametrize("device_type", TEST_DEVICES)
    @pytest.mark.parametrize("batch_size", [1, 2, 5])
    def test_run_all(self, batch_size, device_type):
        # generate identity matrices
        self.pose_1 = self._generate_identity_matrix(
            batch_size, device_type)
        self.pose_2 = self.pose_1.clone()

        # run tests
        self._test_identity()
        self._test_translation()
        self._test_rotation()
        self._test_integration()
        self._test_gradcheck()
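
Note that `test_run_all` always seeds `pose_1` with identity matrices, so every expected value above collapses to `pose_2` itself (in `_test_integration`, `pose_21 @ pose_1 == pose_2` with `pose_1 = I` implies `pose_21 == pose_2`). A minimal plain-PyTorch sketch of that simplification, independent of this commit's code:

import torch

# with T_1 = I, both conventions T_1^{-1} @ T_2 and T_2 @ T_1^{-1}
# reduce to T_2, so these tests compare against pose_2 directly
pose_1 = torch.eye(4).unsqueeze(0)   # 1x4x4 identity reference
pose_2 = torch.eye(4).unsqueeze(0)
pose_2[..., :3, -1] += 10.0          # translate destination by 10

assert torch.allclose(torch.inverse(pose_1) @ pose_2, pose_2)
assert torch.allclose(pose_2 @ torch.inverse(pose_1), pose_2)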
1 change: 1 addition & 0 deletions torchgeometry/__init__.py
@@ -4,6 +4,7 @@
from .conversions import *
from .utils import *
from .imgwarp import *
from .transformations import *

from torchgeometry import image
from torchgeometry import losses
71 changes: 71 additions & 0 deletions torchgeometry/transformations.py
@@ -0,0 +1,71 @@
from typing import Optional

import torch


__all__ = [
    "relative_pose",
]


def relative_pose(pose_1: torch.Tensor, pose_2: torch.Tensor,
                  eps: Optional[float] = 1e-6) -> torch.Tensor:
    r"""Function that computes the relative transformation from a reference
    pose :math:`P_1^{\{W\}} = \begin{bmatrix} R_1 & t_1 \\ \mathbf{0} & 1
    \end{bmatrix}` to destination :math:`P_2^{\{W\}} = \begin{bmatrix} R_2 &
    t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.

    The relative transformation is computed as follows:

    .. math::

        P_1^{2} = \begin{bmatrix} R_2 R_1^T & R_1^T (t_2 - t_1) \\
        \mathbf{0} & 1 \end{bmatrix}

    Arguments:
        pose_1 (torch.Tensor): reference pose tensor of shape
            :math:`(N, 4, 4)`.
        pose_2 (torch.Tensor): destination pose tensor of shape
            :math:`(N, 4, 4)`.

    Shape:
        - Output: :math:`(N, 4, 4)`

    Returns:
        torch.Tensor: the relative transformation between the poses.

    Example::
        >>> pose_1 = torch.eye(4).unsqueeze(0)  # 1x4x4
        >>> pose_2 = torch.eye(4).unsqueeze(0)  # 1x4x4
        >>> pose_21 = tgm.relative_pose(pose_1, pose_2)  # 1x4x4
    """
    if not torch.is_tensor(pose_1):
        raise TypeError("Input pose_1 type is not a torch.Tensor. Got {}"
                        .format(type(pose_1)))
    if not torch.is_tensor(pose_2):
        raise TypeError("Input pose_2 type is not a torch.Tensor. Got {}"
                        .format(type(pose_2)))
    if not (len(pose_1.shape) == 3 and pose_1.shape[-2:] == (4, 4)):
        raise ValueError("Input must be of the shape Nx4x4."
                         " Got {}".format(pose_1.shape))
    if not pose_1.shape == pose_2.shape:
        raise ValueError("Input pose_1 and pose_2 must be of the same shape."
                         " Got {} and {}".format(pose_1.shape, pose_2.shape))
    # unpack input data
    r_mat_1 = pose_1[..., :3, :3]  # Nx3x3
    r_mat_2 = pose_2[..., :3, :3]  # Nx3x3
    t_vec_1 = pose_1[..., :3, -1:]  # Nx3x1
    t_vec_2 = pose_2[..., :3, -1:]  # Nx3x1

    # compute relative pose
    r_mat_1_trans = r_mat_1.transpose(1, 2)
    r_mat_21 = torch.matmul(r_mat_2, r_mat_1_trans)
    t_vec_21 = torch.matmul(r_mat_1_trans, t_vec_2 - t_vec_1)

    # pack output data
    pose_21 = torch.zeros_like(pose_1)
    pose_21[..., :3, :3] = r_mat_21
    pose_21[..., :3, -1:] = t_vec_21
    pose_21[..., -1, -1] += 1.0
    return pose_21 + eps
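
For reference, a minimal usage sketch mirroring `_test_translation` above, assuming the package is imported as `tgm` as in the tests:

import torch
import torchgeometry as tgm

# identity reference pose and a destination translated by 10 units
pose_1 = torch.eye(4).unsqueeze(0)   # 1x4x4
pose_2 = torch.eye(4).unsqueeze(0)   # 1x4x4
pose_2[..., :3, -1] += 10.0

pose_21 = tgm.relative_pose(pose_1, pose_2)
print(pose_21[..., :3, -1])  # translation component, approximately (10, 10, 10)

The values are only approximate because the committed function adds `eps` to every entry of the result before returning it.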

2 comments on commit 6130892

@versatran01 commented on 6130892 Feb 18, 2019
Either this math is wrong or I misunderstood some of the notation here.

P_1^{2} = \begin{bmatrix} R_2 R_1^T & R_1^T (t_2 - t_1) \\ \mathbf{0} & 1 \end{bmatrix}

I believe the correct forms should be as follows (assuming that when you say pose, it is w.r.t. some common frame, thus T_1 is actually T_1^W and T_2 is actually T_2^W). Then the relative transformation from 1 to 2 is T_2^1, which is

T_2^1 = T_1^{-1} * T_2 = [R_1^T * R_2 | R_1^T * (t_2 - t_1)]

So the rotation is flipped.

or

T_1^2 = T_2^{-1} * T_1 = [R_2^T * R_1 | R_2^T * (t_1 - t_2)]
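
A quick numerical check of the T_2^1 form against a direct matrix inverse, as a plain-PyTorch sketch (the `make_pose` helper below is hypothetical, for illustration only, and not part of the library):

import torch

def make_pose(r_mat, t_vec):
    # pack a 3x3 rotation and a 3-vector translation into a 4x4 pose
    pose = torch.eye(4)
    pose[:3, :3] = r_mat
    pose[:3, 3] = t_vec
    return pose

# T_1: 90 degree rotation about z, translated by (1, 2, 3)
r_z = torch.tensor([[0., -1., 0.],
                    [1., 0., 0.],
                    [0., 0., 1.]])
pose_1 = make_pose(r_z, torch.tensor([1., 2., 3.]))

# T_2: identity rotation, translated by (4, 5, 6)
pose_2 = make_pose(torch.eye(3), torch.tensor([4., 5., 6.]))

# closed form proposed above: [R_1^T R_2 | R_1^T (t_2 - t_1)]
r_21 = pose_1[:3, :3].t() @ pose_2[:3, :3]
t_21 = pose_1[:3, :3].t() @ (pose_2[:3, 3] - pose_1[:3, 3])
pose_21 = make_pose(r_21, t_21)

# matches the matrix-inverse definition T_2^1 = T_1^{-1} @ T_2 ...
assert torch.allclose(pose_21, torch.inverse(pose_1) @ pose_2, atol=1e-6)
# ... and satisfies the round trip T_1 @ T_2^1 = T_2
assert torch.allclose(pose_1 @ pose_21, pose_2, atol=1e-6)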

@edgarriba (Member, Author) commented

@versatran01 you are probably right here. I have to revisit the tests, see what's happening, and potentially redesign them. I'm working on a refactor for DepthWarper right now (#73), but maybe you could give a hand with this fix. I suggest opening an issue with your comment about the formulation; that will also help keep track of the things to do.
