From 2606afbb3c25973d5b50a9d18c33b34b80121b6d Mon Sep 17 00:00:00 2001 From: Tong Wei <73064503+weitong8591@users.noreply.github.com> Date: Wed, 22 May 2024 11:52:28 +0200 Subject: [PATCH] Speed up differentiable 5PC and fix the batch size issue (#2914) * remove one for loop to speed up the 5PC alg., and fix the issue of the inconsistent batch sizes between the input and returned E matrices, referred to [wang2023vggsfm] * remove one for loop to speed up the 5PC alg., and fix the issue of the inconsistent batch sizes, referred to [wang2023vggsfm] * typo * typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove 10k-batch test * add unit test for degenarate case, change 5CP outputs 10 solutions for each image pair. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug fix-missing type. * fix the data type issue of float32 or float64. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests/geometry/epipolar/test_essential.py typo Co-authored-by: Edgar Riba --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edgar Riba --- docs/source/references.bib | 7 + kornia/geometry/epipolar/essential.py | 240 +++++++++++++--------- tests/geometry/epipolar/test_essential.py | 24 ++- 3 files changed, 165 insertions(+), 106 deletions(-) diff --git a/docs/source/references.bib b/docs/source/references.bib index fd8c4acf62..aee01394a6 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -398,3 +398,10 @@ @inproceedings{wei2023generalized booktitle = {ICCV}, year = {2023} } + +@article{wang2023vggsfm, + title={VGGSfM: Visual Geometry Grounded Deep Structure From Motion}, + author={Wang, Jianyuan and Karaev, Nikita and Rupprecht, Christian and Novotny, David}, + journal={arXiv preprint arXiv:2312.04563}, + year={2023} +} diff --git a/kornia/geometry/epipolar/essential.py b/kornia/geometry/epipolar/essential.py index 5f33467e7a..03fe926932 100644 --- a/kornia/geometry/epipolar/essential.py +++ b/kornia/geometry/epipolar/essential.py @@ -4,11 +4,10 @@ import torch -import kornia.geometry.epipolar as epi from kornia.core import eye, ones_like, stack, where, zeros from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE from kornia.geometry import solvers -from kornia.utils import eye_like, safe_inverse_with_mask, vec_like +from kornia.utils import eye_like, vec_like from kornia.utils.helpers import _torch_solve_cast, _torch_svd_cast from .numeric import cross_product_matrix @@ -30,7 +29,7 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t r"""Compute the essential matrix using the 5-point algorithm from Nister. The linear system is solved by Nister's 5-point algorithm [@nister2004efficient], - and the solver implemented referred to [@barath2020magsac++][@wei2023generalized]. + and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm]. Args: points1: A set of carlibrated points in the first image with a tensor shape :math:`(B, N, 2), N>=8`. @@ -51,7 +50,7 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t x2, y2 = torch.chunk(points2, dim=-1, chunks=2) # Bx1xN ones = ones_like(x1) - # build equations system and find null space. + # build the equation system and find the null space. # https://www.cc.gatech.edu/~afb/classes/CS4495-Fall2013/slides/CS4495-09-TwoViews-2.pdf # [x * x', x * y', x, y * x', y * y', y, x', y', 1] # BxNx9 @@ -63,32 +62,59 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t else: w_diag = torch.diag_embed(weights) X = X.transpose(-2, -1) @ w_diag @ X + + # use Nister's 5PC to solve essential matrix + E_Nister = null_to_Nister_solution(X, batch_size) + + return E_Nister + + +def fun_select(null_mat: torch.Tensor, i: int, j: int, ratio: int = 3) -> torch.Tensor: + return null_mat[:, ratio * j + i] + + +def null_to_Nister_solution(X: torch.Tensor, batch_size: int) -> torch.Tensor: + r"""Use Nister's 5PC to solve essential matrix. + + The linear system is solved by Nister's 5-point algorithm [@nister2004efficient], + and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm]. + + Args: + X: Coefficients for the null space :math:`(B, N, 2), N>=8`. + batch_size: batcs size of the input, the number of image pairs :math:`B`. + + Returns: + the computed essential matrix with shape :math:`(B, 3, 3)`. + Note that the returned E matrices should be the same batch size with the input. + """ + # compute eigenvectors and retrieve the one with the smallest eigenvalue, using SVD # turn off the grad check due to the unstable gradients from SVD. # several close to zero values of eigenvalues. _, _, V = _torch_svd_cast(X) # torch.svd + null_ = V[:, :, -4:] # the last four rows nullSpace = V.transpose(-1, -2)[:, -4:, :] coeffs = zeros(batch_size, 10, 20, device=null_.device, dtype=null_.dtype) d = zeros(batch_size, 60, device=null_.device, dtype=null_.dtype) - def fun(i: int, j: int) -> torch.Tensor: - return null_[:, 3 * j + i] - # Determinant constraint coeffs[:, 9] = ( solvers.multiply_deg_two_one_poly( - solvers.multiply_deg_one_poly(fun(0, 1), fun(1, 2)) - solvers.multiply_deg_one_poly(fun(0, 2), fun(1, 1)), - fun(2, 0), + solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 2)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 1)), + fun_select(null_, 2, 0), ) + solvers.multiply_deg_two_one_poly( - solvers.multiply_deg_one_poly(fun(0, 2), fun(1, 0)) - solvers.multiply_deg_one_poly(fun(0, 0), fun(1, 2)), - fun(2, 1), + solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 0)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 2)), + fun_select(null_, 2, 1), ) + solvers.multiply_deg_two_one_poly( - solvers.multiply_deg_one_poly(fun(0, 0), fun(1, 1)) - solvers.multiply_deg_one_poly(fun(0, 1), fun(1, 0)), - fun(2, 2), + solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 1)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 0)), + fun_select(null_, 2, 2), ) ) @@ -98,9 +124,9 @@ def fun(i: int, j: int) -> torch.Tensor: for i in range(3): for j in range(3): d[:, indices[i, j] : indices[i, j] + 10] = ( - solvers.multiply_deg_one_poly(fun(i, 0), fun(j, 0)) - + solvers.multiply_deg_one_poly(fun(i, 1), fun(j, 1)) - + solvers.multiply_deg_one_poly(fun(i, 2), fun(j, 2)) + solvers.multiply_deg_one_poly(fun_select(null_, i, 0), fun_select(null_, j, 0)) + + solvers.multiply_deg_one_poly(fun_select(null_, i, 1), fun_select(null_, j, 1)) + + solvers.multiply_deg_one_poly(fun_select(null_, i, 2), fun_select(null_, j, 2)) ) for i in range(10): @@ -113,9 +139,9 @@ def fun(i: int, j: int) -> torch.Tensor: for i in range(3): for j in range(3): row = ( - solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun(0, j)) - + solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun(1, j)) - + solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun(2, j)) + solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun_select(null_, 0, j)) + + solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun_select(null_, 1, j)) + + solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun_select(null_, 2, j)) ) coeffs[:, cnt] = row cnt += 1 @@ -125,10 +151,17 @@ def fun(i: int, j: int) -> torch.Tensor: torch.linalg.matrix_rank(coeffs), ones_like(torch.linalg.matrix_rank(coeffs[:, :, :10])) * 10 ) + # check if there is no solution + if singular_filter.sum() == 0: + return torch.eye(3, dtype=coeffs.dtype, device=coeffs.device)[None].expand(batch_size, 10, -1, -1).clone() + eliminated_mat = _torch_solve_cast(coeffs[singular_filter, :, :10], b[singular_filter]) coeffs_ = torch.cat((coeffs[singular_filter, :, :10], eliminated_mat), dim=-1) + # check the batch size after singular filter, for batch operation afterwards + batch_size_filtered = coeffs_.shape[0] + A = zeros(coeffs_.shape[0], 3, 13, device=coeffs_.device, dtype=coeffs_.dtype) for i in range(3): @@ -142,66 +175,88 @@ def fun(i: int, j: int) -> torch.Tensor: A[:, i : i + 1, 9:13] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 16:20] A[:, i : i + 1, 8:12] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 16:20] + # Bx11 cs = solvers.determinant_to_polynomial(A) - E_models = [] - - # for loop because of different numbers of solutions - for bi in range(A.shape[0]): - A_i = A[bi] - null_i = nullSpace[bi] - - # companion matrix solver for polynomial - C = zeros((10, 10), device=cs.device, dtype=cs.dtype) - C[0:-1, 1:] = eye(C[0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype) - C[-1, :] = -cs[bi][:-1] / cs[bi][-1] - - roots = torch.real(torch.linalg.eigvals(C)) - - if roots is None: - continue - n_sols = roots.size() - if n_sols == 0: - continue - Bs = stack( - ( - A_i[:3, :1] * (roots**3) + A_i[:3, 1:2] * roots.square() + A_i[0:3, 2:3] * roots + A_i[0:3, 3:4], - A_i[0:3, 4:5] * (roots**3) + A_i[0:3, 5:6] * roots.square() + A_i[0:3, 6:7] * roots + A_i[0:3, 7:8], - ), - dim=0, - ).transpose(0, -1) - - bs = ( - A_i[0:3, 8:9] * (roots**4) - + A_i[0:3, 9:10] * (roots**3) - + A_i[0:3, 10:11] * roots.square() - + A_i[0:3, 11:12] * roots - + A_i[0:3, 12:13] - ).T.unsqueeze(-1) - - # We try to solve using top two rows, - xzs = (safe_inverse_with_mask(Bs[:, 0:2, 0:2])[0]) @ (bs[:, 0:2]) - - mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten() - if torch.sum(mask) != 0: - q, r = torch.linalg.qr(Bs[mask].clone()) # - xzs[mask] = _torch_solve_cast(r, q.transpose(-1, -2) @ bs[mask]) # [mask] - - # models - Es = null_i[0] * (-xzs[:, 0]) + null_i[1] * (-xzs[:, 1]) + null_i[2] * roots.unsqueeze(-1) + null_i[3] - - # Since the rows of N are orthogonal unit vectors, we can normalize the coefficients instead - inv = 1.0 / torch.sqrt((-xzs[:, 0]) ** 2 + (-xzs[:, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0) - Es *= inv - if Es.shape[0] < 10: - Es = torch.cat( - (Es.clone(), eye(3, device=Es.device, dtype=Es.dtype).repeat(10 - Es.shape[0], 1).reshape(-1, 9)) - ) - E_models.append(Es) - # if not E_models: - # return torch.eye(3, device=cs.device, dtype=cs.dtype).unsqueeze(0) - # else: - return torch.cat(E_models).view(-1, 3, 3).transpose(-1, -2) + # A: Bx3x13 + # nullSpace: Bx4x9 + # companion matrices to solve the polynomial, in batch + C = zeros((batch_size_filtered, 10, 10), device=cs.device, dtype=cs.dtype) + eye_mat = eye(C[0, 0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype) + C[:, 0:-1, 1:] = eye_mat + + cs_de = cs[:, -1].unsqueeze(-1) + cs_de = torch.where(cs_de == 0, torch.tensor(1e-8, dtype=cs_de.dtype), cs_de) + C[:, -1, :] = -cs[:, :-1] / cs_de + + roots = torch.real(torch.linalg.eigvals(C)) + + roots_unsqu = roots.unsqueeze(1) + + Bs = stack( + ( + A[:, :3, :1] * (roots_unsqu**3) + + A[:, :3, 1:2] * roots_unsqu.square() + + A[:, 0:3, 2:3] * roots_unsqu + + A[:, 0:3, 3:4], + A[:, 0:3, 4:5] * (roots_unsqu**3) + + A[:, 0:3, 5:6] * roots_unsqu.square() + + A[:, 0:3, 6:7] * roots_unsqu + + A[:, 0:3, 7:8], + ), + dim=1, + ) + + Bs = Bs.transpose(1, -1) + + bs = ( + ( + A[:, 0:3, 8:9] * (roots_unsqu**4) + + A[:, 0:3, 9:10] * (roots_unsqu**3) + + A[:, 0:3, 10:11] * roots_unsqu.square() + + A[:, 0:3, 11:12] * roots_unsqu + + A[:, 0:3, 12:13] + ) + .transpose(1, 2) + .unsqueeze(-1) + ) + + xzs = torch.matmul(torch.inverse(Bs[:, :, 0:2, 0:2]), bs[:, :, 0:2]) + + mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten() + + # mask: bx10x1x1 + mask = ( + abs(torch.matmul(Bs[:, :, 2, :].unsqueeze(2), xzs) - bs[:, :, 2, :].unsqueeze(2)) > 1e-3 + ) # .flatten(start_dim=1) + + # bx10 + mask = mask.squeeze(3).squeeze(2) + + if torch.any(mask): + q_batch, r_batch = torch.linalg.qr(Bs[mask]) + xyz_to_feed = torch.linalg.solve(r_batch, torch.matmul(q_batch.transpose(-1, -2), bs[mask])) + xzs[mask] = xyz_to_feed + + nullSpace_filtered = nullSpace[singular_filter] + + Es = ( + nullSpace_filtered[:, 0:1] * (-xzs[:, :, 0]) + + nullSpace_filtered[:, 1:2] * (-xzs[:, :, 1]) + + nullSpace_filtered[:, 2:3] * roots.unsqueeze(-1) + + nullSpace_filtered[:, 3:4] + ) + + inv = 1.0 / torch.sqrt((-xzs[:, :, 0]) ** 2 + (-xzs[:, :, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0) + Es *= inv + + Es = Es.view(batch_size_filtered, -1, 3, 3).transpose(-1, -2) + + # make sure the returned batch size equals to that of inputs + E_return = torch.eye(3, dtype=Es.dtype, device=Es.device)[None].expand(batch_size, 10, -1, -1).clone() + E_return[singular_filter] = Es + + return E_return def essential_from_fundamental(F_mat: torch.Tensor, K1: torch.Tensor, K2: torch.Tensor) -> torch.Tensor: @@ -482,32 +537,17 @@ def find_essential( points1: torch.Tensor, points2: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" - Args: - points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=5`. - points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=5`. - weights: Tensor containing the weights per point correspondence with a shape of :math:`(5, N)`. + Args: + points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=5`. + points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=5`. + weights: Tensor containing the weights per point correspondence with a shape of :math:`(5, N)`. Returns: - the computed essential matrix with shape :math:`(B, 3, 3)`, - one model for each batch selected out of ten solutions by Sampson distances. + the computed essential matrices with shape :math:`(B, 10, 3, 3)`. + Note that all possible solutions are returned, i.e., 10 essential matrices for each image pair. + To choose the best one out of 10, try to check the one with the lowest Sampson distance. """ E = run_5point(points1, points2, weights).to(points1.dtype) - # select one out of 10 possible solutions from 5PC Nister solver. - solution_num = 10 - batch_size = points1.shape[0] - - error = zeros((batch_size, solution_num)) - - for b in range(batch_size): - error[b] = epi.sampson_epipolar_distance(points1[b], points2[b], E.view(batch_size, solution_num, 3, 3)[b]).sum( - -1 - ) - - KORNIA_CHECK_SHAPE(error, ["f{batch_size}", "10"]) - - chosen_indices = torch.argmin(error, dim=-1) - result = stack([(E.view(-1, solution_num, 3, 3))[i, chosen_indices[i], :] for i in range(batch_size)]) - - return result + return E diff --git a/tests/geometry/epipolar/test_essential.py b/tests/geometry/epipolar/test_essential.py index 6382c4b91a..f95d4a37bf 100644 --- a/tests/geometry/epipolar/test_essential.py +++ b/tests/geometry/epipolar/test_essential.py @@ -14,16 +14,16 @@ def test_smoke(self, device, dtype): points2 = torch.rand(1, 5, 2, device=device, dtype=dtype) weights = torch.ones(1, 5, device=device, dtype=dtype) E_mat = epi.essential.find_essential(points1, points2, weights) - assert E_mat.shape == (1, 3, 3) + assert E_mat.shape == (1, 10, 3, 3) - @pytest.mark.parametrize("batch_size, num_points", [(1, 5), (2, 6), (3, 7)]) + @pytest.mark.parametrize("batch_size, num_points", [(1, 5), (2, 6), (3, 7), (1000, 5)]) def test_shape(self, batch_size, num_points, device, dtype): B, N = batch_size, num_points points1 = torch.rand(B, N, 2, device=device, dtype=dtype) points2 = torch.rand(B, N, 2, device=device, dtype=dtype) weights = torch.ones(B, N, device=device, dtype=dtype) E_mat = epi.essential.find_essential(points1, points2, weights) - assert E_mat.shape == (B, 3, 3) + assert E_mat.shape == (B, 10, 3, 3) @pytest.mark.parametrize("batch_size, num_points", [(1, 5), (2, 6), (3, 7)]) def test_shape_noweights(self, batch_size, num_points, device, dtype): @@ -32,7 +32,7 @@ def test_shape_noweights(self, batch_size, num_points, device, dtype): points2 = torch.rand(B, N, 2, device=device, dtype=dtype) weights = None E_mat = epi.essential.find_essential(points1, points2, weights) - assert E_mat.shape == (B, 3, 3) + assert E_mat.shape == (B, 10, 3, 3) def test_epipolar_constraint(self, device, dtype): calibrated_x1 = torch.tensor( @@ -49,7 +49,8 @@ def test_epipolar_constraint(self, device, dtype): E = epi.essential.find_essential(calibrated_x1, calibrated_x2) if torch.all(E != 0): distance = epi.symmetrical_epipolar_distance(calibrated_x1, calibrated_x2, E) - mean_error = distance.mean() + # Note : here we check only the best model, although all solutions are returned + mean_error = distance.mean(-1).min() self.assert_close(mean_error, torch.tensor(0.0, device=device, dtype=dtype), atol=1e-4, rtol=1e-4) def test_synthetic_sampson(self, device, dtype): @@ -68,9 +69,20 @@ def test_synthetic_sampson(self, device, dtype): E_est = epi.essential.find_essential(calibrated_x1, calibrated_x2, weights) error = epi.sampson_epipolar_distance(calibrated_x1, calibrated_x2, E_est) self.assert_close( - error, torch.zeros((calibrated_x1.shape[:2]), device=device, dtype=dtype), atol=1e-4, rtol=1e-4 + error[:, torch.argmin(error.mean(-1).min())], + torch.zeros((calibrated_x1.shape[:2]), device=device, dtype=dtype), + atol=1e-4, + rtol=1e-4, ) + @pytest.mark.parametrize("batch_size, num_points", [(5, 5), (10, 5)]) + def test_degenerate_case(self, batch_size, num_points, device, dtype): + B, N = batch_size, num_points + points1_deg = torch.rand(B, N, 2, device=device, dtype=dtype) + weights = torch.ones_like(points1_deg)[..., 0] + E_mat_deg = epi.essential.find_essential(points1_deg, points1_deg, weights) + assert E_mat_deg.shape == (B, 10, 3, 3) + class TestEssentialFromFundamental(BaseTester): def test_smoke(self, device, dtype):