Skip to content

Commit

Permalink
Omit _check_valid_rotation_matrix by default
Browse files Browse the repository at this point in the history
Summary:
According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1.

Comparison after applying the change:
```
Profiling/Function    get_world_to_view (ms)   Transform_points(ms)    specular(ms)
before                12.751                    18.577                  21.384
after                 4.432 (34.7%)             9.248 (49.8%)           11.507 (53.8%)
```

Profiling trace:
https://pxl.cl/2h687
More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing

Reviewed By: kjchalup

Differential Revision: D40442503

fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
  • Loading branch information
Jiali Duan authored and facebook-github-bot committed Oct 20, 2022
1 parent 8339cf2 commit 46cb5aa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
6 changes: 5 additions & 1 deletion pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import os
import warnings
from typing import List, Optional, Union

Expand Down Expand Up @@ -636,7 +637,10 @@ def __init__(
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape))
R = R.to(device=device_, dtype=dtype)
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
# Note: aten::all_close in the check is computationally slow, so we
# only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device_)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
Expand Down
23 changes: 21 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import math
import os
import unittest
from unittest import mock

import torch
from pytorch3d.transforms import random_rotations
Expand Down Expand Up @@ -191,7 +192,25 @@ def test_translate(self):
self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected))

def test_rotate(self):
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
def test_rotate_check_rot_valid_on(self):
R = so3_exp_map(torch.randn((1, 3)))
t = Transform3d().rotate(R)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
1, 3, 3
)
normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3)
points_out = t.transform_points(points)
normals_out = t.transform_normals(normals)
points_out_expected = torch.bmm(points, R)
normals_out_expected = torch.bmm(normals, R)
self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected))

@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
def test_rotate_check_rot_valid_off(self):
R = so3_exp_map(torch.randn((1, 3)))
t = Transform3d().rotate(R)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
Expand Down

0 comments on commit 46cb5aa

Please sign in to comment.