Fix dtype propagation (#1141)
Summary:
Previously, dtypes were not propagated correctly through composed transforms, resulting in errors when different dtypes were mixed. Specifying a dtype in the constructor did not fix this, nor did specifying the dtype for each composition function invocation (e.g. as a keyword argument to `rotate_axis_angle`).
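
A minimal sketch of the failure mode (illustrative, not taken from the commit): composing transforms on a `torch.float64` base previously created `torch.float32` children, so composing the matrices raised a dtype mismatch. With the fix, the constructor dtype flows through the whole chain:

```python
import torch
from pytorch3d.transforms import Transform3d

# Before the fix, .translate() and .scale() built their child transforms
# without forwarding self.dtype, so the children defaulted to float32 and
# clashed with the float64 base transform below.
t = Transform3d(dtype=torch.float64).translate(3, 2, 1).scale(0.5)

points = torch.rand(1, 3, 3, dtype=torch.float64)
out = t.transform_points(points)  # stays float64 end to end after the fix
assert out.dtype == torch.float64
```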

With the change, I also had to modify the default dtype of `RotateAxisAngle`, which was `torch.float64`; it is now `torch.float32`, matching all the other transforms. This was required because, once dtypes were actually propagated, the old default caused dtype mismatches in several tests.
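
To illustrate the behavior change (a sketch, not from the commit): code that relied on the old `torch.float64` default now has to request it explicitly:

```python
import torch
from pytorch3d.transforms import RotateAxisAngle

# The default dtype is now float32, consistent with the other transforms.
r = RotateAxisAngle(angle=90.0, axis="X")
assert r.get_matrix().dtype == torch.float32  # was float64 before this change

# Callers that depended on double precision must opt in explicitly.
r64 = RotateAxisAngle(angle=90.0, axis="X", dtype=torch.float64)
assert r64.get_matrix().dtype == torch.float64
```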

This change of default dtype in turn broke two tests due to precision differences (calculations previously done in `torch.float64` now run in `torch.float32`), so I relaxed the precision tolerances, choosing the lowest power of ten that let the tests pass.
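
As a sketch of why the tolerances needed loosening (my illustration, not from the commit): a 90-degree rotation computed in `torch.float32` leaves residuals around 4e-8 where `torch.float64` produced exact zeros, which exceeds the default `atol=1e-8` of `torch.allclose`:

```python
import math

import torch

angle = torch.tensor(math.pi / 2)  # float32 by default
residual = torch.cos(angle)        # ~ -4.37e-08 instead of exactly 0.0

zero = torch.tensor(0.0)
print(torch.allclose(residual, zero))             # False: default atol is 1e-8
print(torch.allclose(residual, zero, atol=1e-7))  # True: the relaxed tolerance
```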

Pull Request resolved: #1141

Reviewed By: patricklabatut

Differential Revision: D35192970

Pulled By: bottler

fbshipit-source-id: ba0293e8b3595dfc94b3cf8048e50b7a5e5ed7cf
janEbert authored and facebook-github-bot committed Mar 29, 2022
1 parent 21262e3 commit b602edc
Showing 2 changed files with 53 additions and 15 deletions.
pytorch3d/transforms/transform3d.py (32 changes: 20 additions & 12 deletions)

```diff
@@ -390,16 +390,24 @@ def transform_normals(self, normals) -> torch.Tensor:
         return normals_out
 
     def translate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Translate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Translate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def scale(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Scale(device=self.device, *args, **kwargs))
+        return self.compose(
+            Scale(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Rotate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
+        return self.compose(
+            RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def clone(self) -> "Transform3d":
         """
@@ -488,7 +496,7 @@ def __init__(
             - A 1D torch tensor
         """
         xyz = _handle_input(x, y, z, dtype, device, "Translate")
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
         N = xyz.shape[0]
 
         mat = torch.eye(4, dtype=dtype, device=self.device)
@@ -532,7 +540,7 @@ def __init__(
            - 1D torch tensor
        """
        xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
        N = xyz.shape[0]
 
        # TODO: Can we do this all in one go somehow?
@@ -571,7 +579,7 @@ def __init__(
         """
         device_ = get_device(R, device)
-        super().__init__(device=device_)
+        super().__init__(device=device_, dtype=dtype)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
@@ -598,7 +606,7 @@ def __init__(
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype: torch.dtype = torch.float64,
+        dtype: torch.dtype = torch.float32,
         device: Optional[Device] = None,
     ) -> None:
         """
@@ -629,7 +637,7 @@
         # is for transforming column vectors. Therefore we transpose this matrix.
         # R will always be of shape (N, 3, 3)
         R = _axis_angle_rotation(axis, angle).transpose(1, 2)
-        super().__init__(device=angle.device, R=R)
+        super().__init__(device=angle.device, R=R, dtype=dtype)
 
 
 def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -646,8 +654,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
         c = torch.tensor(c, dtype=dtype, device=device)
     if c.dim() == 0:
         c = c.view(1)
-    if c.device != device:
-        c = c.to(device=device)
+    if c.device != device or c.dtype != dtype:
+        c = c.to(device=device, dtype=dtype)
     return c
 
 
@@ -696,7 +704,7 @@ def _handle_input(
         if y is not None or z is not None:
             msg = "Expected y and z to be None (in %s)" % name
             raise ValueError(msg)
-        return x.to(device=device_)
+        return x.to(device=device_, dtype=dtype)
 
     if allow_singleton and y is None and z is None:
         y = x
```
tests/test_transforms.py (36 changes: 33 additions & 3 deletions)

```diff
@@ -87,6 +87,36 @@ def test_to(self):
         t = t.cuda()
         t = t.cpu()
 
+    def test_dtype_propagation(self):
+        """
+        Check that a given dtype is correctly passed along to child
+        transformations.
+        """
+        # Use at least two dtypes so we avoid only testing on the
+        # default dtype.
+        for dtype in [torch.float32, torch.float64]:
+            R = torch.tensor(
+                [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
+                dtype=dtype,
+            )
+            tf = (
+                Transform3d(dtype=dtype)
+                .rotate(R)
+                .rotate_axis_angle(
+                    R[0],
+                    "X",
+                )
+                .translate(3, 2, 1)
+                .scale(0.5)
+            )
+
+            self.assertEqual(tf.dtype, dtype)
+            for inner_tf in tf._transforms:
+                self.assertEqual(inner_tf.dtype, dtype)
+
+            transformed = tf.transform_points(R)
+            self.assertEqual(transformed.dtype, dtype)
+
     def test_clone(self):
         """
         Check that cloned transformations contain different _matrix objects.
@@ -219,8 +249,8 @@ def test_rotate_axis_angle(self):
         normals_out_expected = torch.tensor(
             [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
         ).view(1, 3, 3)
-        self.assertTrue(torch.allclose(points_out, points_out_expected))
-        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
+        self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7))
+        self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7))
 
     def test_transform_points_fail(self):
         t1 = Scale(0.1, 0.1, 0.1)
@@ -951,7 +981,7 @@ def test_rotate_x_python_scalar(self):
         self.assertTrue(
             torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
         )
-        self.assertTrue(torch.allclose(t._matrix, matrix))
+        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 
     def test_rotate_x_torch_scalar(self):
         angle = torch.tensor(90.0)
```
