refactor homogeneous transforms module #79

Merged
merged 9 commits on Feb 28, 2019
2 changes: 0 additions & 2 deletions docs/source/conversions.rst
@@ -7,7 +7,6 @@ Conversions
.. autofunction:: deg2rad
.. autofunction:: convert_points_from_homogeneous
.. autofunction:: convert_points_to_homogeneous
.. autofunction:: transform_points
.. autofunction:: angle_axis_to_rotation_matrix
.. autofunction:: rotation_matrix_to_angle_axis
.. autofunction:: rotation_matrix_to_quaternion
@@ -18,4 +17,3 @@ Conversions
.. autoclass:: DegToRad
.. autoclass:: ConvertPointsFromHomogeneous
.. autoclass:: ConvertPointsToHomogeneous
.. autoclass:: TransformPoints
8 changes: 4 additions & 4 deletions docs/source/transformations.rst
@@ -3,7 +3,7 @@ Linear Transformations

.. currentmodule:: torchgeometry

.. autofunction:: relative_pose
.. autofunction:: inverse_pose

.. autoclass:: InversePose
.. autofunction:: relative_transformation
.. autofunction:: inverse_transformation
.. autofunction:: compose_transformations
.. autofunction:: transform_points
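
For reference, a minimal usage sketch of the API listed above. This is only a sketch: the function names and shapes are taken from the docs and tests in this PR, and tgm abbreviates torchgeometry as in the docstrings.

import torch
import torchgeometry as tgm

# 4x4 homogeneous transformation matrices, as exercised by the new tests below
trans_01 = torch.eye(4)
trans_12 = torch.eye(4)
trans_12[:3, -1] += 10.  # add a translation offset

trans_02 = tgm.compose_transformations(trans_01, trans_12)      # chain two transforms
trans_10 = tgm.inverse_transformation(trans_01)                 # invert a transform
trans_12_hat = tgm.relative_transformation(trans_01, trans_02)  # relative transform

# transform_points is documented here as well: a Bx4x4 pose applied to BxNx3 points
points_src = torch.rand(1, 5, 3)
points_dst = tgm.transform_points(trans_02.unsqueeze(0), points_src)  # 1x5x3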
243 changes: 141 additions & 102 deletions test/test_transformations.py
@@ -8,111 +8,150 @@
from common import TEST_DEVICES


class TestTransformPose:
def identity_matrix(batch_size):
r"""Creates a batched homogeneous identity matrix"""
return torch.eye(4).repeat(batch_size, 1, 1) # Nx4x4


def euler_angles_to_rotation_matrix(x, y, z):
r"""Create a rotation matrix from x, y, z angles"""
assert x.dim() == 1, x.shape
assert x.shape == y.shape == z.shape
ones, zeros = torch.ones_like(x), torch.zeros_like(x)
# the rotation matrix for the x-axis
rx_tmp = [
ones, zeros, zeros, zeros,
zeros, torch.cos(x), -torch.sin(x), zeros,
zeros, torch.sin(x), torch.cos(x), zeros,
zeros, zeros, zeros, ones]
rx = torch.stack(rx_tmp, dim=-1).view(-1, 4, 4)
# the rotation matrix for the y-axis
ry_tmp = [
torch.cos(y), zeros, torch.sin(y), zeros,
zeros, ones, zeros, zeros,
-torch.sin(y), zeros, torch.cos(y), zeros,
zeros, zeros, zeros, ones]
ry = torch.stack(ry_tmp, dim=-1).view(-1, 4, 4)
# the rotation matrix for the z-axis
rz_tmp = [
torch.cos(z), -torch.sin(z), zeros, zeros,
torch.sin(z), torch.cos(z), zeros, zeros,
zeros, zeros, ones, zeros,
zeros, zeros, zeros, ones]
rz = torch.stack(rz_tmp, dim=-1).view(-1, 4, 4)
return torch.matmul(rz, torch.matmul(ry, rx)) # Bx4x4


class TestComposeTransforms:

def test_translation_4x4(self):
offset = 10
trans_01 = identity_matrix(batch_size=1)[0]
trans_12 = identity_matrix(batch_size=1)[0]
trans_12[..., :3, -1] += offset # add offset to translation vector

trans_02 = tgm.compose_transformations(trans_01, trans_12)
assert utils.check_equal_torch(trans_02, trans_12)

def _generate_identity_matrix(self, batch_size, device_type):
eye = torch.eye(4).repeat(batch_size, 1, 1) # Nx4x4
return eye.to(torch.device(device_type))
@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_translation_Bx4x4(self, batch_size):
offset = 10
trans_01 = identity_matrix(batch_size)
trans_12 = identity_matrix(batch_size)
trans_12[..., :3, -1] += offset # add offset to translation vector

def _test_identity(self):
pose_1 = self.pose_1.clone()
pose_2 = self.pose_2.clone()
pose_21 = tgm.relative_pose(pose_1, pose_2)
assert utils.check_equal_torch(pose_21, torch.eye(4).unsqueeze(0))
trans_02 = tgm.compose_transformations(trans_01, trans_12)
assert utils.check_equal_torch(trans_02, trans_12)

def _test_translation(self):
offset = 10.
pose_1 = self.pose_1.clone()
pose_2 = self.pose_2.clone()
pose_2[..., :3, -1:] += offset # add translation

# compute relative pose
pose_21 = tgm.relative_pose(pose_1, pose_2)
assert utils.check_equal_torch(pose_21[..., :3, -1:], offset)

def _test_rotation(self):
pose_1 = self.pose_1.clone()
pose_2 = torch.zeros_like(pose_1) # Rz (90deg)
pose_2[..., 0, 1] = -1.0
pose_2[..., 1, 0] = 1.0
pose_2[..., 2, 2] = 1.0
pose_2[..., 3, 3] = 1.0

# compute relative pose
pose_21 = tgm.relative_pose(pose_1, pose_2)
assert utils.check_equal_torch(pose_21, pose_2)

def _test_integration(self):
pose_1 = self.pose_1.clone()
pose_2 = self.pose_2.clone()

# apply random rotations and translations
batch_size, device = pose_2.shape[0], pose_2.device
pose_2[..., :3, :3] = torch.rand(batch_size, 3, 3, device=device)
pose_2[..., :3, -1:] = torch.rand(batch_size, 3, 1, device=device)

pose_21 = tgm.relative_pose(pose_1, pose_2)
assert utils.check_equal_torch(
torch.matmul(pose_21, pose_1), pose_2)

@pytest.mark.skip("Converting a tensor to a Python boolean ...")
def test_jit(self):
pose_1 = self.pose_1.clone()
pose_2 = self.pose_2.clone()

pose_21 = tgm.relative_pose(pose_1, pose_2)
pose_21_jit = torch.jit.trace(
tgm.relative_pose, (pose_1, pose_2,))(pose_1, pose_2)
assert utils.check_equal_torch(pose_21, pose_21_jit)

def _test_gradcheck(self):
pose_1 = self.pose_1.clone()
pose_2 = self.pose_2.clone()

pose_1 = utils.tensor_to_gradcheck_var(pose_1) # to var
pose_2 = utils.tensor_to_gradcheck_var(pose_2) # to var
assert gradcheck(tgm.relative_pose, (pose_1, pose_2,),
@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_gradcheck(self, batch_size):
trans_01 = identity_matrix(batch_size)
trans_12 = identity_matrix(batch_size)

trans_01 = utils.tensor_to_gradcheck_var(trans_01) # to var
trans_12 = utils.tensor_to_gradcheck_var(trans_12) # to var
assert gradcheck(tgm.compose_transformations, (trans_01, trans_12,),
raise_exception=True)


class TestInverseTransformation:

def test_translation_4x4(self):
offset = 10
trans_01 = identity_matrix(batch_size=1)[0]
trans_01[..., :3, -1] += offset # add offset to translation vector

trans_10 = tgm.inverse_transformation(trans_01)
trans_01_hat = tgm.inverse_transformation(trans_10)
assert utils.check_equal_torch(trans_01, trans_01_hat)

@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_translation_Bx4x4(self, batch_size):
offset = 10
trans_01 = identity_matrix(batch_size)
trans_01[..., :3, -1] += offset # add offset to translation vector

trans_10 = tgm.inverse_transformation(trans_01)
trans_01_hat = tgm.inverse_transformation(trans_10)
assert utils.check_equal_torch(trans_01, trans_01_hat)

@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_rotation_translation_Bx4x4(self, batch_size):
offset = 10
x, y, z = 0, 0, tgm.pi
ones = torch.ones(batch_size)
rmat_01 = euler_angles_to_rotation_matrix(x * ones, y * ones, z * ones)

trans_01 = identity_matrix(batch_size)
trans_01[..., :3, -1] += offset # add offset to translation vector
trans_01[..., :3, :3] = rmat_01[..., :3, :3]

trans_10 = tgm.inverse_transformation(trans_01)
trans_01_hat = tgm.inverse_transformation(trans_10)
assert utils.check_equal_torch(trans_01, trans_01_hat)

@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_gradcheck(self, batch_size):
trans_01 = identity_matrix(batch_size)
trans_01 = utils.tensor_to_gradcheck_var(trans_01) # to var
assert gradcheck(tgm.inverse_transformation, (trans_01,),
raise_exception=True)

@pytest.mark.parametrize("device_type", TEST_DEVICES)

class TestRelativeTransformation:

def test_translation_4x4(self):
offset = 10.
trans_01 = identity_matrix(batch_size=1)[0]
trans_02 = identity_matrix(batch_size=1)[0]
trans_02[..., :3, -1] += offset # add offset to translation vector

trans_12 = tgm.relative_transformation(trans_01, trans_02)
trans_02_hat = tgm.compose_transformations(trans_01, trans_12)
assert utils.check_equal_torch(trans_02_hat, trans_02)

@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_run_all(self, batch_size, device_type):
# generate identity matrices
self.pose_1 = self._generate_identity_matrix(
batch_size, device_type)
self.pose_2 = self.pose_1.clone()

# run tests
self._test_identity()
self._test_translation()
self._test_rotation()
self._test_integration()
self._test_gradcheck()


# TODO: embed into a class
@pytest.mark.parametrize("device_type", TEST_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 5, 6])
def test_inverse_pose(batch_size, device_type):
# generate input data
eye_size = 4 # identity 4x4
dst_pose_src = utils.create_random_homography(batch_size, eye_size)
dst_pose_src = dst_pose_src.to(torch.device(device_type))
dst_pose_src[:, -1] = 0.0
dst_pose_src[:, -1, -1] = 1.0

# compute the inverse of the pose
src_pose_dst = tgm.inverse_pose(dst_pose_src)

# H_inv * H == I
eye = torch.matmul(src_pose_dst, dst_pose_src)
assert utils.check_equal_torch(eye, torch.eye(4), eps=1e-3)

# functional
eye = torch.matmul(tgm.InversePose()(dst_pose_src), dst_pose_src)
assert utils.check_equal_torch(eye, torch.eye(4), eps=1e-3)

# evaluate function gradient
dst_pose_src = utils.tensor_to_gradcheck_var(dst_pose_src) # to var
assert gradcheck(tgm.inverse_pose, (dst_pose_src,),
raise_exception=True)
def test_rotation_translation_Bx4x4(self, batch_size):
offset = 10.
x, y, z = 0., 0., tgm.pi
ones = torch.ones(batch_size)
rmat_02 = euler_angles_to_rotation_matrix(x * ones, y * ones, z * ones)

trans_01 = identity_matrix(batch_size)
trans_02 = identity_matrix(batch_size)
trans_02[..., :3, -1] += offset # add offset to translation vector
trans_02[..., :3, :3] = rmat_02[..., :3, :3]

trans_12 = tgm.relative_transformation(trans_01, trans_02)
trans_02_hat = tgm.compose_transformations(trans_01, trans_12)
assert utils.check_equal_torch(trans_02_hat, trans_02)

@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_gradcheck(self, batch_size):
trans_01 = identity_matrix(batch_size)
trans_02 = identity_matrix(batch_size)

trans_01 = utils.tensor_to_gradcheck_var(trans_01) # to var
trans_02 = utils.tensor_to_gradcheck_var(trans_02) # to var
assert gradcheck(tgm.relative_transformation, (trans_01, trans_02,),
raise_exception=True)
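
The round trip checked in TestRelativeTransformation above pins down the intended semantics: composing trans_01 with the relative transformation recovers trans_02. A minimal sketch of that identity, assuming plain 4x4 matrices as in the test_translation_4x4 case:

import torch
import torchgeometry as tgm

trans_01 = torch.eye(4)
trans_02 = torch.eye(4)
trans_02[:3, -1] = torch.tensor([1., 2., 3.])  # translate the second frame

# relative_transformation returns trans_12 such that composing trans_01 with
# trans_12 reproduces trans_02
trans_12 = tgm.relative_transformation(trans_01, trans_02)
trans_02_hat = tgm.compose_transformations(trans_01, trans_12)
assert torch.allclose(trans_02_hat, trans_02)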
77 changes: 2 additions & 75 deletions torchgeometry/core/conversions.py
@@ -8,7 +8,6 @@
"deg2rad",
"convert_points_from_homogeneous",
"convert_points_to_homogeneous",
"transform_points",
"angle_axis_to_rotation_matrix",
"rotation_matrix_to_angle_axis",
"rotation_matrix_to_quaternion",
@@ -19,7 +18,6 @@
"DegToRad",
"ConvertPointsFromHomogeneous",
"ConvertPointsToHomogeneous",
"TransformPoints",
]


@@ -113,47 +111,6 @@ def convert_points_to_homogeneous(points):
return nn.functional.pad(points, (0, 1), "constant", 1.0)


def transform_points(dst_pose_src, points_src):
r"""Function that applies transformations to a set of points.

See :class:`~torchgeometry.TransformPoints` for details.

Args:
dst_pose_src (Tensor): tensor for transformations.
points_src (Tensor): tensor of points.

Returns:
Tensor: tensor of N-dimensional points.

Shape:
- Input: :math:`(B, D+1, D+1)` and :math:`(B, N, D)`
- Output: :math:`(B, N, D)`

Examples::

>>> input = torch.rand(2, 4, 3) # BxNx3
>>> pose = torch.eye(4).repeat(2, 1, 1) # Bx4x4
>>> output = tgm.transform_points(pose, input) # BxNx3
"""
if not torch.is_tensor(dst_pose_src) or not torch.is_tensor(points_src):
raise TypeError("Input type is not a torch.Tensor")
if not dst_pose_src.device == points_src.device:
raise TypeError("Tensor must be in the same device")
if not dst_pose_src.shape[0] == points_src.shape[0]:
raise ValueError("Input batch size must be the same for both tensors")
if not dst_pose_src.shape[-1] == (points_src.shape[-1] + 1):
raise ValueError("Last input dimensions must differe by one unit")
# to homogeneous
points_src_h = convert_points_to_homogeneous(points_src) # BxNxD+1
# transform coordinates
points_dst_h = torch.matmul(
dst_pose_src.unsqueeze(1), points_src_h.unsqueeze(-1))
points_dst_h = torch.squeeze(points_dst_h, dim=-1)
# to euclidean
points_dst = convert_points_from_homogeneous(points_dst_h) # BxNxD
return points_dst


def angle_axis_to_rotation_matrix(angle_axis):
"""Convert 3d vector of axis-angle rotation to 4x4 rotation matrix

@@ -348,8 +305,8 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
mask_c3 = mask_c3.view(-1, 1).type_as(q3)

q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
q *= 0.5
return q

@@ -520,33 +477,3 @@ def __init__(self):

def forward(self, input):
return convert_points_to_homogeneous(input)


class TransformPoints(nn.Module):
r"""Creates an object to transform a set of points.

Args:
dst_pose_src (Tensor): tensor for transformations of
shape :math:`(B, D+1, D+1)`.

Returns:
Tensor: tensor of N-dimensional points.

Shape:
- Input: :math:`(B, N, D)`
- Output: :math:`(B, N, D)`

Examples::

>>> input = torch.rand(2, 4, 3) # BxNx3
>>> transform = torch.eye(4).repeat(2, 1, 1) # Bx4x4
>>> transform_op = tgm.TransformPoints(transform)
>>> output = transform_op(input) # BxNx3
"""

def __init__(self, dst_homo_src):
super(TransformPoints, self).__init__()
self.dst_homo_src = dst_homo_src

def forward(self, points_src):
return transform_points(self.dst_homo_src, points_src)