Skip to content

Commit

Permalink
avoid symeig
Browse files Browse the repository at this point in the history
Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch.

Reviewed By: patricklabatut

Differential Revision: D34375784

fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
  • Loading branch information
bottler authored and facebook-github-bot committed Feb 21, 2022
1 parent 59972b1 commit db1f7c4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
9 changes: 9 additions & 0 deletions pytorch3d/common/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
# PyTorch version >= 1.9
return torch.linalg.qr(A)
return torch.qr(A)


def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvalues=True)
3 changes: 2 additions & 1 deletion pytorch3d/ops/perspective_n_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
import torch.nn.functional as F
from pytorch3d.common.compat import eigh
from pytorch3d.ops import points_alignment, utils as oputil


Expand Down Expand Up @@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = torch.symeig(mTm, eigenvectors=True)
s, v = eigh(mTm)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]


Expand Down
11 changes: 6 additions & 5 deletions pytorch3d/ops/points_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import TYPE_CHECKING, Tuple, Union

import torch
from pytorch3d.common.compat import eigh
from pytorch3d.common.workaround import symeig3x3

from ..common.workaround import symeig3x3
from .utils import convert_pointclouds_to_tensor, get_point_covariances


Expand Down Expand Up @@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(

# get the local coord frames as principal directions of
# the per-point covariance
# this is done with torch.symeig, which returns the
# this is done with torch.symeig / torch.linalg.eigh, which returns the
# eigenvectors (=principal directions) in an ascending order of their
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
# corresponds to the normal direction
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
# corresponds to the normal direction; or with a custom equivalent.
if use_symeig_workaround:
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
else:
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
curvatures, local_coord_frames = eigh(cov)

# disambiguate the directions of individual principal vectors
if disambiguate_directions:
Expand Down

0 comments on commit db1f7c4

Please sign in to comment.