Skip to content

Commit

Permalink
Fix returning a proper rotation in levelling; supporting batches and …
Browse files Browse the repository at this point in the history
…default centroid

Summary:
`get_rotation_to_best_fit_xy` is useful to expose externally, however there was a bug (which we probably did not care about for our use case): it could return a rotation matrix with det(R) == −1.
The diff fixes that, and also makes centroid optional (it can be computed from points).

Reviewed By: bottler

Differential Revision: D39926791

fbshipit-source-id: 5120c7892815b829f3ddcc23e93d4a5ec0ca0013
  • Loading branch information
shapovalov authored and facebook-github-bot committed Sep 29, 2022
1 parent de98c9c commit 74bbd6f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
26 changes: 17 additions & 9 deletions pytorch3d/implicitron/tools/circle_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,30 @@
import torch


def _get_rotation_to_best_fit_xy(
points: torch.Tensor, centroid: torch.Tensor
def get_rotation_to_best_fit_xy(
points: torch.Tensor, centroid: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Returns a rotation r such that points @ r has a best fit plane
Returns a rotation R such that `points @ R` has a best fit plane
parallel to the xy plane
Args:
points: (N, 3) tensor of points in 3D
centroid: (3,) their centroid
points: (*, N, 3) tensor of points in 3D
centroid: (*, 1, 3), (3,) or scalar: their centroid
Returns:
(3,3) tensor rotation matrix
(*, 3, 3) tensor rotation matrix
"""
points_centered = points - centroid[None]
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
if centroid is None:
centroid = points.mean(dim=-2, keepdim=True)

points_centered = points - centroid
_, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered)
# in general, evec can form either right- or left-handed basis,
# but we need the former to have a proper rotation (not reflection)
return torch.cat(
(evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1
)


def _signed_area(path: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -191,7 +199,7 @@ def fit_circle_in_3d(
Circle3D object
"""
centroid = points.mean(0)
r = _get_rotation_to_best_fit_xy(points, centroid)
r = get_rotation_to_best_fit_xy(points, centroid)
normal = r[:, 2]
rotated_points = (points - centroid) @ r
result_2d = fit_circle_in_2d(
Expand Down
29 changes: 28 additions & 1 deletion tests/implicitron/test_circle_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
_signed_area,
fit_circle_in_2d,
fit_circle_in_3d,
get_rotation_to_best_fit_xy,
)
from pytorch3d.transforms import random_rotation
from pytorch3d.transforms import random_rotation, random_rotations
from tests.common_testing import TestCaseMixin


Expand All @@ -28,6 +29,32 @@ def _assertParallel(self, a, b, **kwargs):
"""
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)

def test_plane_levelling(self):
device = torch.device("cuda:0")
B = 16
N = 1024
random = torch.randn((B, N, 3), device=device)

# first, check that we always return a vaild rotation
rot = get_rotation_to_best_fit_xy(random)
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))

# then, check the result is what we expect
z_squeeze = 0.1
random[..., -1] *= z_squeeze
rot_gt = random_rotations(B, device=device)
rotated = random @ rot_gt.transpose(-1, -2)
rot_hat = get_rotation_to_best_fit_xy(rotated)
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
# covariance matrix of the levelled points is by design diag(1, 1, z_squeeze²)
self.assertClose(
(rotated @ rot_hat)[..., -1].std(dim=-1),
torch.ones_like(rot_hat[:, 0, 0]) * z_squeeze,
rtol=0.1,
)

def test_simple_3d(self):
device = torch.device("cuda:0")
for _ in range(7):
Expand Down

0 comments on commit 74bbd6f

Please sign in to comment.