# Comparison of SuperGlue and LightGlue Demo
In this notebook we match the same pair of images with both SuperGlue and LightGlue and compare the resulting matches.

In [None]:
# Import necessary dependencies
from pathlib import Path
# NOTE(review): LightGlue_custom is not used in any cell shown here; confirm before removing.
from lightglue import LightGlue, LightGlue_custom, SuperPoint, SuperGlue
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch

# Inference-only notebook: turn off autograd globally so no gradient graphs are built.
torch.set_grad_enabled(False)
# Folder (relative to the working directory) containing the example images.
images = Path("assets")

## Load SuperPoint Extractor and Images for Matching
As of now, the SuperGlue matcher only supports SuperPoint features. Supporting other extractors would require changes to handle their different input (descriptor) dimensions.

In [None]:
# Run on CUDA when it is available, otherwise on the CPU ('mps' can be set by hand instead).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configs
# Feature type shared by the extractor and both matchers (SuperGlue only supports "superpoint").
matcher_features = "superpoint"
MAX_NUM_KEYPOINTS = 2048  # keypoint budget handed to SuperPoint as max_num_keypoints
extractor = SuperPoint(max_num_keypoints=MAX_NUM_KEYPOINTS).eval().to(device)

# The image pair that both matchers are evaluated on.
image0, image1 = (
    load_image(images / file_name)
    for file_name in ("sacre_coeur1.jpg", "sacre_coeur2.jpg")
)


## Run LightGlue Matcher
The top image shows the matches, while the bottom image shows the detected points pruned across layers.
For pairs with significant viewpoint- and illumination changes, LightGlue can exclude a lot of points early in the matching process (red points), which significantly reduces the inference time.

In [None]:
# Match image0/image1 with LightGlue, check that both match directions agree,
# compute summary statistics, and plot matches plus per-layer pruning.

# Build the LightGlue matcher for the configured feature type.
matcher_name = "LightGlue"
matcher = LightGlue(features=matcher_features).eval().to(device)

# Extract features (batched dicts, which is what the matcher consumes).
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))

# Match features with matcher
matches01 = matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [
    rbd(x) for x in [feats0, feats1, matches01]
]  # remove batch dimension

# Identify matched keypoints
kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

# Keep scores above the matcher's default filter threshold. The matcher above is
# built without overriding it, so presumably this is the threshold it applied.
conf_threshold = matcher.default_conf["filter_threshold"]
valid0 = matches01["matching_scores0"] > conf_threshold
valid1 = matches01["matching_scores1"] > conf_threshold
# Tensor reductions rather than the builtin sum(): one op instead of a Python loop
# over elements, and a tensor result even when nothing passes the threshold.
matching_num0 = int(valid0.sum())
matching_num1 = int(valid1.sum())
mconf0 = matches01["matching_scores0"][valid0]
mconf1 = matches01["matching_scores1"][valid1]
sum_mconf0 = mconf0.sum()
sum_mconf1 = mconf1.sum()

# Ensure consistency of matches: both directions must agree on count and total score.
if matching_num0 == 0:
    print("0 points met confidence threshold!")
assert matching_num0 == matching_num1, (
    f"match counts differ between directions: {matching_num0} vs {matching_num1}"
)
assert torch.isclose(sum_mconf0, sum_mconf1, atol=1e-3), (
    f"summed match confidences differ: {sum_mconf0.item():.4f} vs {sum_mconf1.item():.4f}"
)

# norm-score: mean confidence of accepted matches.
# match-prop: accepted matches relative to the smaller keypoint set.
# Both are NaN (instead of crashing) when their denominator is zero.
num_kpts0 = len(kpts0)
num_kpts1 = len(kpts1)
matching_score = sum_mconf0.item() / matching_num0 if matching_num0 else float("nan")
min_kpts = min(num_kpts0, num_kpts1)
match_prop = matching_num0 / min_kpts if min_kpts else float("nan")

# Plot primary: the matches, annotated with the statistics above.
axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
label_text = [
    f"{matcher_name} with {matcher_features}",
    f"Keypoints: {num_kpts0}:{num_kpts1}",
    f"Matches: {matching_num0}",
    f"norm-score: {matching_score:.4f}",
    f"match-prop: {match_prop:.4f}",
    f"matching-num: {matching_num0}",
    f"conf-thresh: {conf_threshold:.4f}",
]
for row, label in enumerate(label_text):
    # Labels start at (0.01, 0.99) and step down by 0.05 per line.
    viz2d.add_text(0, text=label, pos=(0.01, 0.99 - 0.05 * row), fs=15)

# Plot secondary: keypoints colored by the layer at which they were pruned.
kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)
label_text_sec = [
    "Detected Points",
    f"Stop after {matches01['stop']} layers",
    "Colors indicate respective layers",
]
for row, label in enumerate(label_text_sec):
    viz2d.add_text(0, text=label, pos=(0.01, 0.99 - 0.05 * row), fs=15)



## Run SuperGlue Matcher
The top image shows the matches, while the bottom image shows the keypoints detected by SuperPoint.

In [None]:
# Match image0/image1 with SuperGlue, check that both match directions agree,
# compute summary statistics, and plot matches plus all detected keypoints.

# Build the SuperGlue matcher for the configured feature type.
matcher_name = "SuperGlue"
matcher = SuperGlue(features=matcher_features).eval().to(device)

# Extract features again: the matcher needs batched dicts, and feats0/feats1 may
# already hold un-batched copies left over from an earlier cell.
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))

# Match features with matcher
matches01 = matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [
    rbd(x) for x in [feats0, feats1, matches01]
]  # remove batch dimension

# Identify matched keypoints
kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

# Keep scores above the matcher's default filter threshold. The matcher above is
# built without overriding it, so presumably this is the threshold it applied.
conf_threshold = matcher.default_conf["filter_threshold"]
valid0 = matches01["matching_scores0"] > conf_threshold
valid1 = matches01["matching_scores1"] > conf_threshold
# Tensor reductions rather than the builtin sum(): one op instead of a Python loop
# over elements, and a tensor result even when nothing passes the threshold.
matching_num0 = int(valid0.sum())
matching_num1 = int(valid1.sum())
mconf0 = matches01["matching_scores0"][valid0]
mconf1 = matches01["matching_scores1"][valid1]
sum_mconf0 = mconf0.sum()
sum_mconf1 = mconf1.sum()

# Ensure consistency of matches: both directions must agree on count and total score.
if matching_num0 == 0:
    print("0 points met confidence threshold!")
assert matching_num0 == matching_num1, (
    f"match counts differ between directions: {matching_num0} vs {matching_num1}"
)
assert torch.isclose(sum_mconf0, sum_mconf1, atol=1e-3), (
    f"summed match confidences differ: {sum_mconf0.item():.4f} vs {sum_mconf1.item():.4f}"
)

# norm-score: mean confidence of accepted matches.
# match-prop: accepted matches relative to the smaller keypoint set.
# Both are NaN (instead of crashing) when their denominator is zero.
num_kpts0 = len(kpts0)
num_kpts1 = len(kpts1)
matching_score = sum_mconf0.item() / matching_num0 if matching_num0 else float("nan")
min_kpts = min(num_kpts0, num_kpts1)
match_prop = matching_num0 / min_kpts if min_kpts else float("nan")

# Plot primary: the matches, annotated with the statistics above.
axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
label_text = [
    f"{matcher_name} with {matcher_features}",
    f"Keypoints: {num_kpts0}:{num_kpts1}",
    f"Matches: {matching_num0}",
    f"norm-score: {matching_score:.4f}",
    f"match-prop: {match_prop:.4f}",
    f"matching-num: {matching_num0}",
    f"conf-thresh: {conf_threshold:.4f}",
]
for row, label in enumerate(label_text):
    # Labels start at (0.01, 0.99) and step down by 0.05 per line.
    viz2d.add_text(0, text=label, pos=(0.01, 0.99 - 0.05 * row), fs=15)

# Plot secondary: every keypoint detected in each image.
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors="b", ps=15)
viz2d.add_text(0, "Detected Points", fs=15)
