Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding LightGlue #2436

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def add_doctest_deps(doctest_namespace):

# the commit hash for the data version
sha: str = 'cb8f42bf28b9f347df6afba5558738f62a11f28a'
sha2: str = '824ff1518870864644df6842a4ec964040f64504'
sha2: str = 'f7d8da661701424babb64850e03c5e8faec7ea62'
sha3: str = '8b98f44abbe92b7a84631ed06613b08fee7dae14'


Expand All @@ -122,6 +122,7 @@ def data(request):
'loftr_homo': f'https://github.com/kornia/data_test/blob/{sha}/loftr_outdoor_and_homography_data.pt?raw=true',
'loftr_fund': f'https://github.com/kornia/data_test/blob/{sha}/loftr_indoor_and_fundamental_data.pt?raw=true',
'adalam_idxs': f'https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true',
'lightglue_idxs': f'https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true',
'disk_outdoor': f'https://github.com/kornia/data_test/blob/{sha3}/knchurch_disk.pt?raw=true',
'dexined': 'https://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth',
}
Expand Down
6 changes: 6 additions & 0 deletions docs/source/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ Matching
.. autoclass:: GeometryAwareDescriptorMatcher
:members: forward

.. autoclass:: LightGlueMatcher
:members: forward

.. autoclass:: LightGlue
:members: forward

.. autoclass:: LocalFeature
:members: forward

Expand Down
9 changes: 9 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ @article{HardNet2020
year={2020}
}

@article{LightGlue2023,
author={Philipp Lindenberger and
Paul-Edouard Sarlin and
Marc Pollefeys},
title = {LightGlue: Local Feature Matching at Light Speed},
journal={arXiv ePrint 2306.13643},
year={2023}
}

@inproceedings{FRN2019,
title={Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks},
author={Saurabh Singh and Shankar Krishnan},
Expand Down
4 changes: 4 additions & 0 deletions kornia/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
KeyNetAffNetHardNet,
KeyNetHardNet,
LAFDescriptor,
LightGlueMatcher,
LocalFeature,
LocalFeatureMatcher,
SIFTFeature,
Expand Down Expand Up @@ -37,6 +38,7 @@
scale_laf,
set_laf_orientation,
)
from .lightglue.lightglue import LightGlue
from .loftr import LoFTR
from .matching import (
DescriptorMatcher,
Expand Down Expand Up @@ -105,6 +107,8 @@
"laf_to_boundary_points",
"ellipse_to_laf",
"make_upright",
"LightGlue",
"LightGlueMatcher",
"get_laf_scale",
"get_laf_center",
"get_laf_orientation",
Expand Down
66 changes: 66 additions & 0 deletions kornia/feature/integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .hardnet import HardNet
from .keynet import KeyNetDetector
from .laf import extract_patches_from_pyramid, get_laf_center, scale_laf
from .lightglue.lightglue import LightGlue
from .matching import GeometryAwareDescriptorMatcher, _no_match
from .orientation import LAFOrienter, OriNet, PassLAF
from .responses import BlobDoG, BlobDoGSingle, BlobHessian, CornerGFTT
from .scale_space_detector import (
Expand Down Expand Up @@ -393,3 +395,67 @@ def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
'confidence': concatenate(out_confidence, dim=0).view(-1),
'batch_indexes': concatenate(out_batch_indexes, dim=0).view(-1),
}


class LightGlueMatcher(GeometryAwareDescriptorMatcher):
"""LightGlue-based matcher in kornia API.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""LightGlue-based matcher in kornia API.
"""LightGlue-based matcher in kornia API.

This is based on the original code from paper "LightGlue: Local Feature Matching at Light Speed". See :cite:`LightGlue2023` for
more details.

Args:
feature_name: type of matching, can be `disk` or `superpoint`.
params: LightGlue params.
"""

known_modes = ['superpoint', 'disk']

def __init__(self, feature_name: str = 'disk', params: Dict = {}) -> None:
feature_name_: str = feature_name.lower()
super().__init__(feature_name_)
self.feature_name = feature_name_
self.params = params
self.matcher = LightGlue(self.feature_name, **params)

def forward(
self,
desc1: Tensor,
desc2: Tensor,
lafs1: Tensor,
lafs2: Tensor,
hw1: Optional[Tuple[int, int]] = None,
hw2: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
desc1: Batch of descriptors of a shape :math:`(B1, D)`.
desc2: Batch of descriptors of a shape :math:`(B2, D)`.
lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
lafs2: LAFs of a shape :math:`(1, B1, 2, 3)`.

Return:
- Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
- Long tensor indexes of matching descriptors in desc1 and desc2,
shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
"""
if (desc1.shape[0] < 2) or (desc2.shape[0] < 2):
return _no_match(desc1)

input_dict = {
"keypoints0": get_laf_center(lafs1),
"keypoints1": get_laf_center(lafs2),
"descriptors0": desc1[None],
"descriptors1": desc2[None],
}
if hw1 is None:
hw1 = input_dict['keypoints0'].max(dim=1)[0].squeeze().flip(0)
if hw2 is None:
hw2 = input_dict['keypoints1'].max(dim=1)[0].squeeze().flip(0)
input_dict['image_size0'] = hw1.flip(0)[None].to(lafs1.device)
input_dict['image_size1'] = hw2.flip(0)[None].to(lafs1.device)
for k, v in input_dict.items():
print(k, v.shape)
pred = self.matcher(input_dict)
matches0, mscores0 = pred['matches0'], pred['matching_scores0']
valid = matches0 > -1
matches = torch.stack([torch.where(valid)[1], matches0[valid]], -1)
return mscores0[valid].reshape(-1, 1), matches
Empty file.
Loading
Loading