diff --git a/docs/source/index.rst b/docs/source/index.rst
index a6f4311562..ba0ca7b458 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,6 +21,7 @@ TGM focuses on Image and tensor warping functions such as:
    :caption: Package Reference
 
    geometric
+   transformations
    pinhole
    conversions
    warping
diff --git a/docs/source/transformations.rst b/docs/source/transformations.rst
new file mode 100644
index 0000000000..1a48949452
--- /dev/null
+++ b/docs/source/transformations.rst
@@ -0,0 +1,6 @@
+Linear Transformations
+----------------------
+
+.. currentmodule:: torchgeometry
+
+.. autofunction:: relative_pose
diff --git a/test/test_transformations.py b/test/test_transformations.py
new file mode 100644
index 0000000000..2941a00eac
--- /dev/null
+++ b/test/test_transformations.py
@@ -0,0 +1,90 @@
+import pytest
+
+import torch
+import torchgeometry as tgm
+from torch.autograd import gradcheck
+
+import utils  # test utilities
+from common import TEST_DEVICES
+
+
+class TestTransformPose:
+
+    def _generate_identity_matrix(self, batch_size, device_type):
+        eye = torch.eye(4).repeat(batch_size, 1, 1)  # Nx4x4
+        return eye.to(torch.device(device_type))
+
+    def _test_identity(self):
+        pose_1 = self.pose_1.clone()
+        pose_2 = self.pose_2.clone()
+        pose_21 = tgm.relative_pose(pose_1, pose_2)
+        assert utils.check_equal_torch(pose_21, torch.eye(4).unsqueeze(0))
+
+    def _test_translation(self):
+        offset = 10.
+        pose_1 = self.pose_1.clone()
+        pose_2 = self.pose_2.clone()
+        pose_2[..., :3, -1:] += offset  # add translation
+
+        # compute relative pose
+        pose_21 = tgm.relative_pose(pose_1, pose_2)
+        assert utils.check_equal_torch(pose_21[..., :3, -1:], offset)
+
+    def _test_rotation(self):
+        pose_1 = self.pose_1.clone()
+        pose_2 = torch.zeros_like(pose_1)  # Rz (90deg)
+        pose_2[..., 0, 1] = -1.0
+        pose_2[..., 1, 0] = 1.0
+        pose_2[..., 2, 2] = 1.0
+        pose_2[..., 3, 3] = 1.0
+
+        # compute relative pose
+        pose_21 = tgm.relative_pose(pose_1, pose_2)
+        assert utils.check_equal_torch(pose_21, pose_2)
+
+    def _test_integration(self):
+        pose_1 = self.pose_1.clone()
+        pose_2 = self.pose_2.clone()
+
+        # fill the rotation and translation blocks with random values
+        batch_size, device = pose_2.shape[0], pose_2.device
+        pose_2[..., :3, :3] = torch.rand(batch_size, 3, 3, device=device)
+        pose_2[..., :3, -1:] = torch.rand(batch_size, 3, 1, device=device)
+
+        pose_21 = tgm.relative_pose(pose_1, pose_2)
+        assert utils.check_equal_torch(
+            torch.matmul(pose_21, pose_1), pose_2)
+
+    @pytest.mark.skip("Converting a tensor to a Python boolean ...")
+    def test_jit(self):
+        pose_1 = self.pose_1.clone()
+        pose_2 = self.pose_2.clone()
+
+        pose_21 = tgm.relative_pose(pose_1, pose_2)
+        pose_21_jit = torch.jit.trace(
+            tgm.relative_pose, (pose_1, pose_2,))(pose_1, pose_2)
+        assert utils.check_equal_torch(pose_21, pose_21_jit)
+
+    def _test_gradcheck(self):
+        pose_1 = self.pose_1.clone()
+        pose_2 = self.pose_2.clone()
+
+        pose_1 = utils.tensor_to_gradcheck_var(pose_1)  # to var
+        pose_2 = utils.tensor_to_gradcheck_var(pose_2)  # to var
+        assert gradcheck(tgm.relative_pose, (pose_1, pose_2,),
+                         raise_exception=True)
+
+    @pytest.mark.parametrize("device_type", TEST_DEVICES)
+    @pytest.mark.parametrize("batch_size", [1, 2, 5])
+    def test_run_all(self, batch_size, device_type):
+        # generate identity matrices
+        self.pose_1 = self._generate_identity_matrix(
+            batch_size, device_type)
+        self.pose_2 = self.pose_1.clone()
+
+        # run tests
+        self._test_identity()
+        self._test_translation()
+        self._test_rotation()
+        self._test_integration()
+        self._test_gradcheck()
diff --git a/torchgeometry/__init__.py b/torchgeometry/__init__.py
index 9126467b8c..645a9da739 100644
--- a/torchgeometry/__init__.py
+++ b/torchgeometry/__init__.py
@@ -4,6 +4,7 @@
 from .conversions import *
 from .utils import *
 from .imgwarp import *
+from .transformations import *
 
 from torchgeometry import image
 from torchgeometry import losses
diff --git a/torchgeometry/transformations.py b/torchgeometry/transformations.py
new file mode 100644
index 0000000000..6c46d9e454
--- /dev/null
+++ b/torchgeometry/transformations.py
@@ -0,0 +1,72 @@
+from typing import Optional
+
+import torch
+
+
+__all__ = [
+    "relative_pose",
+]
+
+
+def relative_pose(pose_1: torch.Tensor, pose_2: torch.Tensor,
+                  eps: Optional[float] = 1e-6) -> torch.Tensor:
+    r"""Function that computes the relative transformation from a reference
+    pose :math:`P_1^{\{W\}} = \begin{bmatrix} R_1 & t_1 \\ \mathbf{0} & 1
+    \end{bmatrix}` to a destination pose :math:`P_2^{\{W\}} = \begin{bmatrix}
+    R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
+
+    The relative transformation is computed as follows:
+
+    .. math::
+
+        P_1^{2} = P_2^{\{W\}} (P_1^{\{W\}})^{-1} = \begin{bmatrix}
+        R_2 R_1^T & t_2 - R_2 R_1^T t_1 \\ \mathbf{0} & 1 \end{bmatrix}
+
+    Arguments:
+        pose_1 (torch.Tensor): reference pose tensor of shape
+            :math:`(N, 4, 4)`.
+        pose_2 (torch.Tensor): destination pose tensor of shape
+            :math:`(N, 4, 4)`.
+        eps (Optional[float]): small value added to the returned pose.
+            Default: 1e-6.
+
+    Shape:
+        - Output: :math:`(N, 4, 4)`
+
+    Returns:
+        torch.Tensor: the relative transformation between the poses.
+
+    Example::
+
+        >>> pose_1 = torch.eye(4).unsqueeze(0)  # 1x4x4
+        >>> pose_2 = torch.eye(4).unsqueeze(0)  # 1x4x4
+        >>> pose_21 = tgm.relative_pose(pose_1, pose_2)  # 1x4x4
+    """
+    if not torch.is_tensor(pose_1):
+        raise TypeError("Input pose_1 type is not a torch.Tensor. Got {}"
+                        .format(type(pose_1)))
+    if not torch.is_tensor(pose_2):
+        raise TypeError("Input pose_2 type is not a torch.Tensor. Got {}"
+                        .format(type(pose_2)))
+    if not (len(pose_1.shape) == 3 and pose_1.shape[-2:] == (4, 4)):
+        raise ValueError("Input pose_1 must be of the shape Nx4x4."
+                         " Got {}".format(pose_1.shape))
+    if not pose_1.shape == pose_2.shape:
+        raise ValueError("Input pose_1 and pose_2 must be of the same shape."
+                         " Got {} and {}".format(pose_1.shape, pose_2.shape))
+    # unpack input data
+    r_mat_1 = pose_1[..., :3, :3]  # Nx3x3
+    r_mat_2 = pose_2[..., :3, :3]  # Nx3x3
+    t_vec_1 = pose_1[..., :3, -1:]  # Nx3x1
+    t_vec_2 = pose_2[..., :3, -1:]  # Nx3x1
+
+    # compute the relative pose such that pose_2 = pose_21 * pose_1
+    r_mat_21 = torch.matmul(r_mat_2, r_mat_1.transpose(1, 2))
+    t_vec_21 = t_vec_2 - torch.matmul(r_mat_21, t_vec_1)
+
+    # pack output data
+    pose_21 = torch.zeros_like(pose_1)
+    pose_21[..., :3, :3] = r_mat_21
+    pose_21[..., :3, -1:] = t_vec_21
+    pose_21[..., -1, -1] += 1.0
+    return pose_21 + eps
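
Usage note (not part of the diff): a minimal sketch of the property exercised by
_test_integration, here with a non-identity reference pose. It uses only the
public tgm.relative_pose call introduced above; the example poses and the
torch.allclose tolerance are illustrative choices, not library code.

    import torch
    import torchgeometry as tgm

    # reference pose: rotate 90 degrees about Z and translate by [1, 0, 0]
    pose_1 = torch.tensor([[[0., -1., 0., 1.],
                            [1., 0., 0., 0.],
                            [0., 0., 1., 0.],
                            [0., 0., 0., 1.]]])  # 1x4x4

    # destination pose: pure translation by [0, 0, 2]
    pose_2 = torch.eye(4).unsqueeze(0)  # 1x4x4
    pose_2[..., 2, 3] = 2.0

    pose_21 = tgm.relative_pose(pose_1, pose_2)  # 1x4x4

    # the relative pose maps the reference onto the destination:
    # pose_2 = pose_21 @ pose_1, up to the eps offset added to the output
    assert torch.allclose(torch.matmul(pose_21, pose_1), pose_2, atol=1e-4)

Because the function adds eps to every entry of the result, exact equality
checks on the output will fail; a small absolute tolerance absorbs the offset.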