Skip to content

Commit

Permalink
Speed up differentiable 5PC and fix the batch size issue (#2914)
Browse files Browse the repository at this point in the history
* remove one for loop to speed up the 5PC alg., and fix the issue of the inconsistent batch sizes between the input and returned E matrices, referred to [wang2023vggsfm]

* remove one for loop to speed up the 5PC alg., and fix the issue of the inconsistent batch sizes, referred to [wang2023vggsfm]

* typo

* typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove 10k-batch test

* add unit test for degenarate case, change 5CP outputs 10 solutions for each image pair.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix-missing type.

* fix the data type issue of float32 or float64.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/geometry/epipolar/test_essential.py

typo

Co-authored-by: Edgar Riba <edgar.riba@gmail.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edgar Riba <edgar.riba@gmail.com>
  • Loading branch information
3 people committed May 22, 2024
1 parent 5b01cc9 commit 2606afb
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 106 deletions.
7 changes: 7 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,10 @@ @inproceedings{wei2023generalized
booktitle = {ICCV},
year = {2023}
}

@article{wang2023vggsfm,
title={VGGSfM: Visual Geometry Grounded Deep Structure From Motion},
author={Wang, Jianyuan and Karaev, Nikita and Rupprecht, Christian and Novotny, David},
journal={arXiv preprint arXiv:2312.04563},
year={2023}
}
240 changes: 140 additions & 100 deletions kornia/geometry/epipolar/essential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

import torch

import kornia.geometry.epipolar as epi
from kornia.core import eye, ones_like, stack, where, zeros
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE
from kornia.geometry import solvers
from kornia.utils import eye_like, safe_inverse_with_mask, vec_like
from kornia.utils import eye_like, vec_like
from kornia.utils.helpers import _torch_solve_cast, _torch_svd_cast

from .numeric import cross_product_matrix
Expand All @@ -30,7 +29,7 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t
r"""Compute the essential matrix using the 5-point algorithm from Nister.
The linear system is solved by Nister's 5-point algorithm [@nister2004efficient],
and the solver implemented referred to [@barath2020magsac++][@wei2023generalized].
and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm].
Args:
points1: A set of carlibrated points in the first image with a tensor shape :math:`(B, N, 2), N>=8`.
Expand All @@ -51,7 +50,7 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t
x2, y2 = torch.chunk(points2, dim=-1, chunks=2) # Bx1xN
ones = ones_like(x1)

# build equations system and find null space.
# build the equation system and find the null space.
# https://www.cc.gatech.edu/~afb/classes/CS4495-Fall2013/slides/CS4495-09-TwoViews-2.pdf
# [x * x', x * y', x, y * x', y * y', y, x', y', 1]
# BxNx9
Expand All @@ -63,32 +62,59 @@ def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[t
else:
w_diag = torch.diag_embed(weights)
X = X.transpose(-2, -1) @ w_diag @ X

# use Nister's 5PC to solve essential matrix
E_Nister = null_to_Nister_solution(X, batch_size)

return E_Nister


def fun_select(null_mat: torch.Tensor, i: int, j: int, ratio: int = 3) -> torch.Tensor:
return null_mat[:, ratio * j + i]


def null_to_Nister_solution(X: torch.Tensor, batch_size: int) -> torch.Tensor:
r"""Use Nister's 5PC to solve essential matrix.
The linear system is solved by Nister's 5-point algorithm [@nister2004efficient],
and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm].
Args:
X: Coefficients for the null space :math:`(B, N, 2), N>=8`.
batch_size: batcs size of the input, the number of image pairs :math:`B`.
Returns:
the computed essential matrix with shape :math:`(B, 3, 3)`.
Note that the returned E matrices should be the same batch size with the input.
"""

# compute eigenvectors and retrieve the one with the smallest eigenvalue, using SVD
# turn off the grad check due to the unstable gradients from SVD.
# several close to zero values of eigenvalues.
_, _, V = _torch_svd_cast(X) # torch.svd

null_ = V[:, :, -4:] # the last four rows
nullSpace = V.transpose(-1, -2)[:, -4:, :]

coeffs = zeros(batch_size, 10, 20, device=null_.device, dtype=null_.dtype)
d = zeros(batch_size, 60, device=null_.device, dtype=null_.dtype)

def fun(i: int, j: int) -> torch.Tensor:
return null_[:, 3 * j + i]

# Determinant constraint
coeffs[:, 9] = (
solvers.multiply_deg_two_one_poly(
solvers.multiply_deg_one_poly(fun(0, 1), fun(1, 2)) - solvers.multiply_deg_one_poly(fun(0, 2), fun(1, 1)),
fun(2, 0),
solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 2))
- solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 1)),
fun_select(null_, 2, 0),
)
+ solvers.multiply_deg_two_one_poly(
solvers.multiply_deg_one_poly(fun(0, 2), fun(1, 0)) - solvers.multiply_deg_one_poly(fun(0, 0), fun(1, 2)),
fun(2, 1),
solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 0))
- solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 2)),
fun_select(null_, 2, 1),
)
+ solvers.multiply_deg_two_one_poly(
solvers.multiply_deg_one_poly(fun(0, 0), fun(1, 1)) - solvers.multiply_deg_one_poly(fun(0, 1), fun(1, 0)),
fun(2, 2),
solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 1))
- solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 0)),
fun_select(null_, 2, 2),
)
)

Expand All @@ -98,9 +124,9 @@ def fun(i: int, j: int) -> torch.Tensor:
for i in range(3):
for j in range(3):
d[:, indices[i, j] : indices[i, j] + 10] = (
solvers.multiply_deg_one_poly(fun(i, 0), fun(j, 0))
+ solvers.multiply_deg_one_poly(fun(i, 1), fun(j, 1))
+ solvers.multiply_deg_one_poly(fun(i, 2), fun(j, 2))
solvers.multiply_deg_one_poly(fun_select(null_, i, 0), fun_select(null_, j, 0))
+ solvers.multiply_deg_one_poly(fun_select(null_, i, 1), fun_select(null_, j, 1))
+ solvers.multiply_deg_one_poly(fun_select(null_, i, 2), fun_select(null_, j, 2))
)

for i in range(10):
Expand All @@ -113,9 +139,9 @@ def fun(i: int, j: int) -> torch.Tensor:
for i in range(3):
for j in range(3):
row = (
solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun(0, j))
+ solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun(1, j))
+ solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun(2, j))
solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun_select(null_, 0, j))
+ solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun_select(null_, 1, j))
+ solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun_select(null_, 2, j))
)
coeffs[:, cnt] = row
cnt += 1
Expand All @@ -125,10 +151,17 @@ def fun(i: int, j: int) -> torch.Tensor:
torch.linalg.matrix_rank(coeffs), ones_like(torch.linalg.matrix_rank(coeffs[:, :, :10])) * 10
)

# check if there is no solution
if singular_filter.sum() == 0:
return torch.eye(3, dtype=coeffs.dtype, device=coeffs.device)[None].expand(batch_size, 10, -1, -1).clone()

eliminated_mat = _torch_solve_cast(coeffs[singular_filter, :, :10], b[singular_filter])

coeffs_ = torch.cat((coeffs[singular_filter, :, :10], eliminated_mat), dim=-1)

# check the batch size after singular filter, for batch operation afterwards
batch_size_filtered = coeffs_.shape[0]

A = zeros(coeffs_.shape[0], 3, 13, device=coeffs_.device, dtype=coeffs_.dtype)

for i in range(3):
Expand All @@ -142,66 +175,88 @@ def fun(i: int, j: int) -> torch.Tensor:
A[:, i : i + 1, 9:13] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 16:20]
A[:, i : i + 1, 8:12] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 16:20]

# Bx11
cs = solvers.determinant_to_polynomial(A)
E_models = []

# for loop because of different numbers of solutions
for bi in range(A.shape[0]):
A_i = A[bi]
null_i = nullSpace[bi]

# companion matrix solver for polynomial
C = zeros((10, 10), device=cs.device, dtype=cs.dtype)
C[0:-1, 1:] = eye(C[0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype)
C[-1, :] = -cs[bi][:-1] / cs[bi][-1]

roots = torch.real(torch.linalg.eigvals(C))

if roots is None:
continue
n_sols = roots.size()
if n_sols == 0:
continue
Bs = stack(
(
A_i[:3, :1] * (roots**3) + A_i[:3, 1:2] * roots.square() + A_i[0:3, 2:3] * roots + A_i[0:3, 3:4],
A_i[0:3, 4:5] * (roots**3) + A_i[0:3, 5:6] * roots.square() + A_i[0:3, 6:7] * roots + A_i[0:3, 7:8],
),
dim=0,
).transpose(0, -1)

bs = (
A_i[0:3, 8:9] * (roots**4)
+ A_i[0:3, 9:10] * (roots**3)
+ A_i[0:3, 10:11] * roots.square()
+ A_i[0:3, 11:12] * roots
+ A_i[0:3, 12:13]
).T.unsqueeze(-1)

# We try to solve using top two rows,
xzs = (safe_inverse_with_mask(Bs[:, 0:2, 0:2])[0]) @ (bs[:, 0:2])

mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten()
if torch.sum(mask) != 0:
q, r = torch.linalg.qr(Bs[mask].clone()) #
xzs[mask] = _torch_solve_cast(r, q.transpose(-1, -2) @ bs[mask]) # [mask]

# models
Es = null_i[0] * (-xzs[:, 0]) + null_i[1] * (-xzs[:, 1]) + null_i[2] * roots.unsqueeze(-1) + null_i[3]

# Since the rows of N are orthogonal unit vectors, we can normalize the coefficients instead
inv = 1.0 / torch.sqrt((-xzs[:, 0]) ** 2 + (-xzs[:, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0)
Es *= inv
if Es.shape[0] < 10:
Es = torch.cat(
(Es.clone(), eye(3, device=Es.device, dtype=Es.dtype).repeat(10 - Es.shape[0], 1).reshape(-1, 9))
)
E_models.append(Es)

# if not E_models:
# return torch.eye(3, device=cs.device, dtype=cs.dtype).unsqueeze(0)
# else:
return torch.cat(E_models).view(-1, 3, 3).transpose(-1, -2)
# A: Bx3x13
# nullSpace: Bx4x9
# companion matrices to solve the polynomial, in batch
C = zeros((batch_size_filtered, 10, 10), device=cs.device, dtype=cs.dtype)
eye_mat = eye(C[0, 0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype)
C[:, 0:-1, 1:] = eye_mat

cs_de = cs[:, -1].unsqueeze(-1)
cs_de = torch.where(cs_de == 0, torch.tensor(1e-8, dtype=cs_de.dtype), cs_de)
C[:, -1, :] = -cs[:, :-1] / cs_de

roots = torch.real(torch.linalg.eigvals(C))

roots_unsqu = roots.unsqueeze(1)

Bs = stack(
(
A[:, :3, :1] * (roots_unsqu**3)
+ A[:, :3, 1:2] * roots_unsqu.square()
+ A[:, 0:3, 2:3] * roots_unsqu
+ A[:, 0:3, 3:4],
A[:, 0:3, 4:5] * (roots_unsqu**3)
+ A[:, 0:3, 5:6] * roots_unsqu.square()
+ A[:, 0:3, 6:7] * roots_unsqu
+ A[:, 0:3, 7:8],
),
dim=1,
)

Bs = Bs.transpose(1, -1)

bs = (
(
A[:, 0:3, 8:9] * (roots_unsqu**4)
+ A[:, 0:3, 9:10] * (roots_unsqu**3)
+ A[:, 0:3, 10:11] * roots_unsqu.square()
+ A[:, 0:3, 11:12] * roots_unsqu
+ A[:, 0:3, 12:13]
)
.transpose(1, 2)
.unsqueeze(-1)
)

xzs = torch.matmul(torch.inverse(Bs[:, :, 0:2, 0:2]), bs[:, :, 0:2])

mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten()

# mask: bx10x1x1
mask = (
abs(torch.matmul(Bs[:, :, 2, :].unsqueeze(2), xzs) - bs[:, :, 2, :].unsqueeze(2)) > 1e-3
) # .flatten(start_dim=1)

# bx10
mask = mask.squeeze(3).squeeze(2)

if torch.any(mask):
q_batch, r_batch = torch.linalg.qr(Bs[mask])
xyz_to_feed = torch.linalg.solve(r_batch, torch.matmul(q_batch.transpose(-1, -2), bs[mask]))
xzs[mask] = xyz_to_feed

nullSpace_filtered = nullSpace[singular_filter]

Es = (
nullSpace_filtered[:, 0:1] * (-xzs[:, :, 0])
+ nullSpace_filtered[:, 1:2] * (-xzs[:, :, 1])
+ nullSpace_filtered[:, 2:3] * roots.unsqueeze(-1)
+ nullSpace_filtered[:, 3:4]
)

inv = 1.0 / torch.sqrt((-xzs[:, :, 0]) ** 2 + (-xzs[:, :, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0)
Es *= inv

Es = Es.view(batch_size_filtered, -1, 3, 3).transpose(-1, -2)

# make sure the returned batch size equals to that of inputs
E_return = torch.eye(3, dtype=Es.dtype, device=Es.device)[None].expand(batch_size, 10, -1, -1).clone()
E_return[singular_filter] = Es

return E_return


def essential_from_fundamental(F_mat: torch.Tensor, K1: torch.Tensor, K2: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -482,32 +537,17 @@ def find_essential(
points1: torch.Tensor, points2: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
r"""
Args:
points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=5`.
points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=5`.
weights: Tensor containing the weights per point correspondence with a shape of :math:`(5, N)`.
Args:
points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=5`.
points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=5`.
weights: Tensor containing the weights per point correspondence with a shape of :math:`(5, N)`.
Returns:
the computed essential matrix with shape :math:`(B, 3, 3)`,
one model for each batch selected out of ten solutions by Sampson distances.
the computed essential matrices with shape :math:`(B, 10, 3, 3)`.
Note that all possible solutions are returned, i.e., 10 essential matrices for each image pair.
To choose the best one out of 10, try to check the one with the lowest Sampson distance.
"""
E = run_5point(points1, points2, weights).to(points1.dtype)

# select one out of 10 possible solutions from 5PC Nister solver.
solution_num = 10
batch_size = points1.shape[0]

error = zeros((batch_size, solution_num))

for b in range(batch_size):
error[b] = epi.sampson_epipolar_distance(points1[b], points2[b], E.view(batch_size, solution_num, 3, 3)[b]).sum(
-1
)

KORNIA_CHECK_SHAPE(error, ["f{batch_size}", "10"])

chosen_indices = torch.argmin(error, dim=-1)
result = stack([(E.view(-1, solution_num, 3, 3))[i, chosen_indices[i], :] for i in range(batch_size)])

return result
return E

0 comments on commit 2606afb

Please sign in to comment.