From 04d8bf6a435da136331cdb33be3f5cf85a678e2c Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Fri, 17 Apr 2020 07:42:16 -0700 Subject: [PATCH] Efficient PnP. Summary: Efficient PnP algorithm to fit 2D to 3D correspondences under perspective assumption. Benchmarked both variants of nullspace and pick one; SVD takes 7 times longer in the 100K points case. Reviewed By: davnov134, gkioxari Differential Revision: D20095754 fbshipit-source-id: 2b4519729630e6373820880272f674829eaed073 --- pytorch3d/ops/perspective_n_points.py | 401 ++++++++++++++++++++++++++ tests/bm_perspective_n_points.py | 25 ++ tests/bm_points_alignment.py | 1 - tests/common_testing.py | 76 ++++- tests/test_common_testing.py | 56 ++++ tests/test_perspective_n_points.py | 131 +++++++++ 6 files changed, 679 insertions(+), 11 deletions(-) create mode 100644 pytorch3d/ops/perspective_n_points.py create mode 100644 tests/bm_perspective_n_points.py create mode 100644 tests/test_common_testing.py create mode 100644 tests/test_perspective_n_points.py diff --git a/pytorch3d/ops/perspective_n_points.py b/pytorch3d/ops/perspective_n_points.py new file mode 100644 index 000000000..3de0c826d --- /dev/null +++ b/pytorch3d/ops/perspective_n_points.py @@ -0,0 +1,401 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +""" +This file contains Efficient PnP algorithm for Perspective-n-Points problem. +It finds a camera position (defined by rotation `R` and translation `T`) that +minimises re-projection error between the given 3D points `x` and +the corresponding uncalibrated 2D points `y`. +""" + +import warnings +from typing import NamedTuple, Optional + +import torch +import torch.nn.functional as F +from pytorch3d.ops import points_alignment, utils as oputil + + +class EpnpSolution(NamedTuple): + x_cam: torch.Tensor + R: torch.Tensor + T: torch.Tensor + err_2d: torch.Tensor + err_3d: torch.Tensor + + +def _define_control_points(x, weight, storage_opts=None): + """ + Returns control points that define barycentric coordinates + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + storage_opts: dict of keyword arguments to the tensor constructor. + """ + storage_opts = storage_opts or {} + x_mean = oputil.wmean(x, weight) + x_std = oputil.wmean((x - x_mean) ** 2, weight) ** 0.5 + c_world = F.pad(torch.eye(3, **storage_opts), (0, 0, 0, 1), value=0.0).expand_as( + x[:, :4, :] + ) + return c_world * x_std + x_mean + + +def _compute_alphas(x, c_world): + """ + Computes barycentric coordinates of x in the frame c_world. + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + c_world: control points in world coordinates. + """ + x = F.pad(x, (0, 1), value=1.0) + c = F.pad(c_world, (0, 1), value=1.0) + return torch.matmul(x, torch.inverse(c)) # B x N x 4 + + +def _build_M(y, alphas, weight): + """ Returns the matrix defining the reprojection equations. + Args: + y: projected points in camera coordinates of size B x N x 2 + alphas: barycentric coordinates of size B x N x 4 + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + """ + bs, n, _ = y.size() + + # prepend t with the column of v's + def prepad(t, v): + return F.pad(t, (1, 0), value=v) + + # outer left-multiply by alphas + def lm_alphas(t): + return torch.matmul(alphas[..., None], t).reshape(bs, n, 12) + + M = torch.cat( + ( + lm_alphas( + prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0) + ), # u constraints + lm_alphas( + prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0) + ), # v constraints + ), + dim=-1, + ).reshape(bs, -1, 12) + + if weight is not None: + M = M * weight.repeat(1, 2)[:, :, None] + + return M + + +def _null_space(m, kernel_dim): + """ Finds the null space (kernel) basis of the matrix + Args: + m: the batch of input matrices, B x N x 12 + kernel_dim: number of dimensions to approximate the kernel + Returns: + * a batch of null space basis vectors + of size B x 4 x 3 x kernel_dim + * a batch of spectral values where near-0s correspond to actual + kernel vectors, of size B x kernel_dim + """ + mTm = torch.bmm(m.transpose(1, 2), m) + s, v = torch.symeig(mTm, eigenvectors=True) + return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim] + + +def _reproj_error(y_hat, y, weight): + """ Projects estimated 3D points and computes the reprojection error + Args: + y_hat: a batch of predicted 2D points in homogeneous coordinates + y: a batch of ground-truth 2D points + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + Returns: + Optionally weighted RMSE of difference between y and y_hat. + """ + y_hat = y_hat / y_hat[..., 2:] + dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5 + return oputil.wmean(dist, weight)[:, 0, 0] + + +def _algebraic_error(x_w_rotated, x_cam, weight): + """ Computes the residual of Umeyama in 3D. + Args: + x_w_rotated: The given 3D points rotated with the predicted camera. + x_cam: the lifted 2D points y + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + Returns: + Optionally weighted MSE of difference between x_w_rotated and x_cam. + """ + dist = ((x_w_rotated - x_cam) ** 2).sum(dim=-1, keepdim=True) + return oputil.wmean(dist, weight)[:, 0, 0] + + +def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9): + """ Given a solution, adjusts the scale and flip + Args: + c_cam: control points in camera coordinates + alphas: barycentric coordinates of the points + x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`. + weights: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + eps: epsilon to threshold negative `z` values + """ + # position of reference points in camera coordinates + x_cam = torch.matmul(alphas, c_cam) + + x_cam = x_cam * (1.0 - 2.0 * (oputil.wmean(x_cam[..., 2:], weight) < 0).float()) + if torch.any(x_cam[..., 2:] < -eps): + neg_rate = oputil.wmean((x_cam[..., 2:] < 0).float(), weight, dim=(0, 1)).item() + warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0)) + + R, T, s = points_alignment.corresponding_points_alignment( + x_world, x_cam, weight, estimate_scale=True + ) + x_cam = x_cam / s[:, None, None] + T = T / s[:, None] + x_w_rotated = torch.matmul(x_world, R) + T[:, None, :] + err_2d = _reproj_error(x_w_rotated, y, weight) + err_3d = _algebraic_error(x_w_rotated, x_cam, weight) + + return EpnpSolution(x_cam, R, T, err_2d, err_3d) + + +def _gen_pairs(input, dim=-2, reducer=lambda l, r: ((l - r) ** 2).sum(dim=-1)): + """ Generates all pairs of different rows and then applies the reducer + Args: + input: a tensor + dim: a dimension to generate pairs across + reducer: a function of generated pair of rows to apply (beyond just concat) + Returns: + for default args, for A x B x C input, will output A x (B choose 2) + """ + n = input.size()[dim] + range = torch.arange(n) + idx = torch.combinations(range).to(input).long() + left = input.index_select(dim, idx[:, 0]) + right = input.index_select(dim, idx[:, 1]) + return reducer(left, right) + + +def _kernel_vec_distances(v): + """ Computes the coefficients for linearisation of the quadratic system + to match all pairwise distances between 4 control points (dim=1). + The last dimension corresponds to the coefficients for quadratic terms + Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors. + Arg: + v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4 + Returns: + a tensor of B x 6 x [(D choose 2) + D]; + for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]. + """ + dv = _gen_pairs(v, dim=-3, reducer=lambda l, r: l - r) # B x 6 x 3 x D + + # we should take dot-product of all (i,j), i < j, with coeff 2 + rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=lambda l, r: (l * r).sum(dim=-2)) + # this should produce B x 6 x (D choose 2) tensor + + # we should take dot-product of all (i,i) + rows_ii = (dv ** 2).sum(dim=-2) + # this should produce B x 6 x D tensor + + return torch.cat((rows_ii, rows_2ij), dim=-1) + + +def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx): + """ Solves an over-determined linear system for selected LHS columns. + A batched version of `torch.lstsq`. + Args: + rhs: right-hand side vectors + lhs: left-hand side matrices + lhs_col_idx: a slice of columns in lhs + Returns: + a least-squares solution for lhs * X = rhs + """ + lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long()) + return torch.matmul(torch.pinverse(lhs), rhs[:, :, None]) + + +def _find_null_space_coords_1(kernel_dsts, cw_dst): + """ Solves case 1 from the paper [1]; solve for 4 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6]) + + beta = beta * beta[:, :1, :].sign() + return beta / (beta[:, :1, :] ** 0.5) + + +def _find_null_space_coords_2(kernel_dsts, cw_dst): + """ Solves case 2 from the paper; solve for 3 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1]) + + coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign() + coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ( + (beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0) + ).float() + + return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1) + + +def _find_null_space_coords_3(kernel_dsts, cw_dst): + """ Solves case 3 from the paper; solve for 5 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7]) + + coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign() + coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ( + (beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0) + ).float() + coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :] + + return torch.cat( + (coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1 + ) + + +def efficient_pnp( + x: torch.Tensor, + y: torch.Tensor, + weights: Optional[torch.Tensor] = None, + skip_quadratic_eq: bool = False, +) -> EpnpSolution: + """ + Implements Efficient PnP algorithm [1] for Perspective-n-Points problem: + finds a camera position (defined by rotation `R` and translation `T`) that + minimizes re-projection error between the given 3D points `x` and + the corresponding uncalibrated 2D points `y`, i.e. solves + + `y[i] = Proj(x[i] R[i] + T[i])` + + in the least-squares sense, where `i` are indices within the batch, and + `Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`. + In the noise-less case, 4 points are enough to find the solution as long + as they are not co-planar. + + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`. + weights: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + skip_quadratic_eq: If True, assumes the solution space for the + linear system is one-dimensional, i.e. takes the scaled eigenvector + that corresponds to the smallest eigenvalue as a solution. + If False, finds the candidate coordinates in the potentially + 4D null space by approximately solving the systems of quadratic + equations. The best candidate is chosen by examining the 2D + re-projection error. While this option finds a better solution, + especially when the number of points is small or perspective + distortions are low (the points are far away), it may be more + difficult to back-propagate through. + + Returns: + `EpnpSolution` namedtuple containing elements: + **x_cam**: Batch of transformed points `x` that is used to find + the camera parameters, of shape `(minibatch, num_points, 3)`. + In the general (noisy) case, they are not exactly equal to + `x[i] R[i] + T[i]` but are some affine transform of `x[i]`s. + **R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`. + **T**: Batch of translation vectors of shape `(minibatch, 3)`. + **err_2d**: Batch of mean 2D re-projection errors of shape + `(minibatch,)`. Specifically, if `yhat` is the re-projection for + the `i`-th batch element, it returns `sum_j norm(yhat_j - y_j)` + where `j` iterates over points and `norm` denotes the L2 norm. + **err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`. + Specifically, those are squared distances between `x_world` and + estimated points on the rays defined by `y`. + + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + # define control points in a world coordinate system (centered on the 3d + # points centroid); 4 x 3 + # TODO: more stable when initialised with the center and eigenvectors! + c_world = _define_control_points( + x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device} + ) + + # find the linear combination of the control points to represent the 3d points + alphas = _compute_alphas(x, c_world) + + M = _build_M(y, alphas, weights) + + # Compute kernel M + kernel, spectrum = _null_space(M, 4) + + c_world_distances = _gen_pairs(c_world) + kernel_dsts = _kernel_vec_distances(kernel) + + betas = ( + [] + if skip_quadratic_eq + else [ + fnsc(kernel_dsts, c_world_distances) + for fnsc in [ + _find_null_space_coords_1, + _find_null_space_coords_2, + _find_null_space_coords_3, + ] + ] + ) + + c_cam_variants = [kernel] + [ + torch.matmul(kernel, beta[:, None, :, :]) for beta in betas + ] + + solutions = [ + _compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights) + for c_cam in c_cam_variants + ] + + sol_zipped = EpnpSolution(*(torch.stack(list(col)) for col in zip(*solutions))) + best = torch.argmin(sol_zipped.err_2d, dim=0) + + def gather1d(source, idx): + # reduces the dim=1 by picking the slices in a 1D tensor idx + # in other words, it is batched index_select. + return source.gather( + 0, + idx.reshape(1, -1, *([1] * (len(source.shape) - 2))).expand_as(source[:1]), + )[0] + + return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped]) diff --git a/tests/bm_perspective_n_points.py b/tests/bm_perspective_n_points.py new file mode 100644 index 000000000..75d77a374 --- /dev/null +++ b/tests/bm_perspective_n_points.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools + +from fvcore.common.benchmark import benchmark +from test_perspective_n_points import TestPerspectiveNPoints + + +def bm_perspective_n_points() -> None: + case_grid = { + "batch_size": [1, 10, 100], + "num_pts": [100, 100000], + "skip_q": [False, True], + } + + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + test = TestPerspectiveNPoints() + benchmark( + test.case_with_gaussian_points, + "PerspectiveNPoints", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py index 942e76aae..39e5bb9a9 100644 --- a/tests/bm_points_alignment.py +++ b/tests/bm_points_alignment.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from copy import deepcopy diff --git a/tests/common_testing.py b/tests/common_testing.py index 450c3c38d..141e28af8 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -1,12 +1,15 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import unittest -from typing import Optional +from typing import Callable, Optional, Union import numpy as np import torch +TensorOrArray = Union[torch.Tensor, np.ndarray] + + class TestCaseMixin(unittest.TestCase): def assertSeparate(self, tensor1, tensor2) -> None: """ @@ -28,10 +31,61 @@ def assertAllSeparate(self, tensor_list) -> None: ptrs = [i.storage().data_ptr() for i in tensor_list] self.assertCountEqual(ptrs, set(ptrs)) + def assertNormsClose( + self, + input: TensorOrArray, + other: TensorOrArray, + norm_fn: Callable[[TensorOrArray], TensorOrArray], + *, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, + msg: Optional[str] = None, + ) -> None: + """ + Verifies that two tensors or arrays have the same shape and are close + given absolute and relative tolerance; raises AssertionError otherwise. + A custom norm function is computed before comparison. If no such pre- + processing needed, pass `torch.abs` or, equivalently, call `assertClose`. + Args: + input, other: two tensors or two arrays. + norm_fn: The function evaluates + `all(norm_fn(input - other) <= atol + rtol * norm_fn(other))`. + norm_fn is a tensor -> tensor function; the output has: + * all entries non-negative, + * shape defined by the input shape only. + rtol, atol, equal_nan: as for torch.allclose. + msg: message in case the assertion is violated. + Note: + Optional arguments here are all keyword-only, to avoid confusion + with msg arguments on other assert functions. + """ + + self.assertEqual(np.shape(input), np.shape(other)) + + diff = norm_fn(input - other) + other_ = norm_fn(other) + + # We want to generalise allclose(input, output), which is essentially + # all(diff <= atol + rtol * other) + # but with a sophisticated handling non-finite values. + # We work that around by calling allclose() with the following arguments: + # allclose(diff + other_, other_). This computes what we want because + # all(|diff + other_ - other_| <= atol + rtol * |other_|) == + # all(|norm_fn(input - other)| <= atol + rtol * |norm_fn(other)|) == + # all(norm_fn(input - other) <= atol + rtol * norm_fn(other)). + + backend = torch if torch.is_tensor(input) else np + close = backend.allclose( + diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + self.assertTrue(close, msg) + def assertClose( self, - input, - other, + input: TensorOrArray, + other: TensorOrArray, *, rtol: float = 1e-05, atol: float = 1e-08, @@ -39,7 +93,10 @@ def assertClose( msg: Optional[str] = None, ) -> None: """ - Verify that two tensors or arrays are the same shape and close. + Verifies that two tensors or arrays have the same shape and are close + given absolute and relative tolerance, i.e. checks + `all(|input - other| <= atol + rtol * |other|)`; + raises AssertionError otherwise. Args: input, other: two tensors or two arrays. rtol, atol, equal_nan: as for torch.allclose. @@ -51,10 +108,9 @@ def assertClose( self.assertEqual(np.shape(input), np.shape(other)) - if torch.is_tensor(input): - close = torch.allclose( - input, other, rtol=rtol, atol=atol, equal_nan=equal_nan - ) - else: - close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + backend = torch if torch.is_tensor(input) else np + close = backend.allclose( + input, other, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + self.assertTrue(close, msg) diff --git a/tests/test_common_testing.py b/tests/test_common_testing.py new file mode 100644 index 000000000..c8976ad50 --- /dev/null +++ b/tests/test_common_testing.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin + + +class TestOpsUtils(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + def test_all_close(self): + device = torch.device("cuda:0") + n_points = 20 + noise_std = 1e-3 + msg = "tratata" + + # test absolute tolerance + x = torch.rand(n_points, 3, device=device) + x_noise = x + noise_std * torch.rand(n_points, 3, device=device) + assert torch.allclose(x, x_noise, atol=10 * noise_std) + assert not torch.allclose(x, x_noise, atol=0.1 * noise_std) + self.assertClose(x, x_noise, atol=10 * noise_std) + with self.assertRaises(AssertionError) as context: + self.assertClose(x, x_noise, atol=0.1 * noise_std, msg=msg) + self.assertTrue(msg in str(context.exception)) + + # test numpy + def to_np(t): + return t.data.cpu().numpy() + + self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std) + with self.assertRaises(AssertionError) as context: + self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg) + self.assertTrue(msg in str(context.exception)) + + # test relative tolerance + assert torch.allclose(x, x_noise, rtol=100 * noise_std) + assert not torch.allclose(x, x_noise, rtol=noise_std) + self.assertClose(x, x_noise, rtol=100 * noise_std) + with self.assertRaises(AssertionError) as context: + self.assertClose(x, x_noise, rtol=noise_std, msg=msg) + self.assertTrue(msg in str(context.exception)) + + # test norm aggregation + # if one of the spatial dimensions is small, norm aggregation helps + x_noise[:, 0] = x_noise[:, 0] - x[:, 0] + x[:, 0] = 0.0 + assert not torch.allclose(x, x_noise, rtol=100 * noise_std) + self.assertNormsClose( + x, x_noise, rtol=100 * noise_std, norm_fn=lambda t: t.norm(dim=-1) + ) diff --git a/tests/test_perspective_n_points.py b/tests/test_perspective_n_points.py new file mode 100644 index 000000000..feaefe66b --- /dev/null +++ b/tests/test_perspective_n_points.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops import perspective_n_points +from pytorch3d.transforms import rotation_conversions + + +def reproj_error(x_world, y, R, T, weight=None): + # applies the affine transform, projects, and computes the reprojection error + y_hat = torch.matmul(x_world, R) + T[:, None, :] + y_hat = y_hat / y_hat[..., 2:] + if weight is None: + weight = y.new_ones((1, 1)) + return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean( + dim=-1 + ) + + +class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False): + sol = perspective_n_points.efficient_pnp( + x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q + ) + + err_2d = reproj_error(x_world, y, sol.R, sol.T) + R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R) + R_quat = rotation_conversions.matrix_to_quaternion(R) + + num_pts = x_world.shape[-2] + # quadratic part is more stable with fewer points + num_pts_thresh = 5 if skip_q else 4 + if check_output and num_pts > num_pts_thresh: + assert_msg = ( + f"test_perspective_n_points assertion failure for " + f"n_points={num_pts}, " + f"skip_quadratic={skip_q}, " + f"no noise." + ) + + self.assertClose(err_2d, sol.err_2d, msg=assert_msg) + self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg) + + def norm_fn(t): + return t.norm(dim=-1) + + self.assertNormsClose( + T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg + ) + self.assertNormsClose( + R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg + ) + + if print_stats: + torch.set_printoptions(precision=5, sci_mode=False) + for err_2d, err_3d, R_gt, T_gt in zip( + sol.err_2d, + sol.err_3d, + torch.cat((sol.R, R), dim=-1), + torch.stack((sol.T, T[:, 0, :]), dim=-1), + ): + print("2D Error: %1.4f" % err_2d.item()) + print("3D Error: %1.4f" % err_3d.item()) + print("R_hat | R_gt\n", R_gt) + print("T_hat | T_gt\n", T_gt) + + def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False): + x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1) + x_cam[:, :2] *= x_cam[:, 2:] # unproject + + R = rotation_conversions.random_rotations(16).to(y) + T = torch.randn_like(R[:, :1, :]) + x_world = torch.matmul(x_cam - T, R.transpose(1, 2)) + + if print_stats: + print("Run without noise") + + if benchmark: # return curried call + torch.cuda.synchronize() + + def result(): + self._run_and_print(x_world, y, R, T, False, skip_q) + torch.cuda.synchronize() + + return result + + self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True) + + # in the noisy case, there are no guarantees, so we check it doesn't crash + if print_stats: + print("Run with noise") + x_world += torch.randn_like(x_world) * 0.1 + self._run_and_print(x_world, y, R, T, print_stats, skip_q) + + def case_with_gaussian_points( + self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False + ): + return self._testcase_from_2d( + torch.randn((num_pts, 2)).cuda() / 3.0, + print_stats=print_stats, + benchmark=benchmark, + skip_q=skip_q, + ) + + def test_perspective_n_points(self, print_stats=False): + if print_stats: + print("RUN ON A DENSE GRID") + u = torch.linspace(-1.0, 1.0, 20) + v = torch.linspace(-1.0, 1.0, 15) + for skip_q in [False, True]: + self._testcase_from_2d( + torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q + ) + + for num_pts in range(6, 3, -1): + for skip_q in [False, True]: + if print_stats: + print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}") + + self.case_with_gaussian_points( + num_pts=num_pts, + print_stats=print_stats, + benchmark=False, + skip_q=skip_q, + )