Commit

Merge pull request #9 from fabio-sim/fabio/runtime
Add inference time comparison
fabio-sim committed Jul 3, 2023
2 parents baf9b92 + c06a431 commit 1da3ab3
Showing 8 changed files with 477 additions and 7 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -6,6 +6,7 @@ Open Neural Network Exchange (ONNX) compatible implementation of [LightGlue: Loc

## Updates

- **4 July 2023**: Add inference time comparisons.
- **1 July 2023**: Add support for extractor `max_num_keypoints`.
- **30 June 2023**: Add support for DISK extractor.
- **28 June 2023**: Add end-to-end SuperPoint+LightGlue export & inference pipeline.
@@ -67,6 +68,12 @@ python infer.py \
--viz
```

## Inference Time Comparison

In general, for smaller numbers of keypoints, the ONNX version performs comparably to the original PyTorch implementation. As the number of keypoints increases, however, the PyTorch CUDA implementation pulls ahead, while the ONNX version remains faster overall for CPU inference. See [EVALUATION.md](./evaluation/EVALUATION.md) for technical details.

<p align="center"><a href="https://github.com/fabio-sim/LightGlue-ONNX/blob/main/evaluation/EVALUATION.md"><img src="assets/latency.png" alt="Latency Comparison" width="80%"></a></p>

## Caveats

As the ONNX Runtime has limited support for features like dynamic control flow, certain configurations of the models cannot be exported to ONNX easily. These caveats are outlined below.
Binary file added assets/latency.png
227 changes: 227 additions & 0 deletions eval.py
@@ -0,0 +1,227 @@
import argparse
import time
from pathlib import Path

import numpy as np
from tqdm import tqdm


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "framework",
        type=str,
        choices=["torch", "ort"],
        help="The LightGlue framework to measure inference time. Options are 'torch' for PyTorch and 'ort' for ONNXRuntime.",
    )
    parser.add_argument(
        "--megadepth_path",
        type=Path,
        default=Path("megadepth_test_1500"),
        required=False,
        help="Path to the root of the MegaDepth dataset.",
    )
    parser.add_argument(
        "--img_size", type=int, default=512, required=False, help="Image size."
    )
    parser.add_argument(
        "--extractor_type",
        type=str,
        choices=["superpoint", "disk"],
        default="superpoint",
        required=False,
        help="Type of feature extractor. Supported extractors are 'superpoint' and 'disk'.",
    )
    parser.add_argument(
        "--max_num_keypoints",
        type=int,
        default=512,
        required=False,
        help="Maximum number of keypoints to extract.",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        default="cuda",
        required=False,
        help="cuda or cpu",
    )
    return parser.parse_args()


def get_megadepth_images(path: Path):
    sort_key = lambda p: int(p.stem.split("_")[0])
    images = sorted(
        list((path / "Undistorted_SfM/0015/images").glob("*.jpg")), key=sort_key
    ) + sorted(list((path / "Undistorted_SfM/0022/images").glob("*.jpg")), key=sort_key)
    return images


def create_models(
    framework: str, extractor_type="superpoint", max_num_keypoints=512, device="cuda"
):
    if framework == "torch":
        if extractor_type == "superpoint":
            extractor = (
                SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(device)
            )
        elif extractor_type == "disk":
            extractor = DISK(max_num_keypoints=max_num_keypoints).eval().to(device)

        lightglue = LightGlue(extractor_type).eval().to(device)
    elif framework == "ort":
        sess_opts = ort.SessionOptions()
        sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if device == "cuda"
            else ["CPUExecutionProvider"]
        )
        extractor = ort.InferenceSession(
            f"weights/{extractor_type}_{max_num_keypoints}.onnx",
            sess_options=sess_opts,
            providers=providers,
        )

        lightglue = ort.InferenceSession(
            f"weights/{extractor_type}_lightglue.onnx",
            sess_options=sess_opts,
            providers=providers,
        )

    return extractor, lightglue


def measure_inference(
    framework: str, extractor, lightglue, image0, image1, device="cuda"
) -> float:
    if framework == "torch":
        # Feature extraction time is not measured
        feats0, feats1 = extractor({"image": image0}), extractor({"image": image1})
        pred = {
            **{k + "0": v for k, v in feats0.items()},
            **{k + "1": v for k, v in feats1.items()},
            "image0": image0,
            "image1": image1,
        }

        # Measure only matching time
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        with torch.no_grad():
            result = lightglue(pred)
        end.record()
        torch.cuda.synchronize()

        return start.elapsed_time(end)
    elif framework == "ort":
        # Feature extraction time is not measured
        kpts0, scores0, desc0 = extractor.run(None, {"image": image0})
        kpts1, scores1, desc1 = extractor.run(None, {"image": image1})

        lightglue_inputs = {
            "kpts0": LightGlueRunner.normalize_keypoints(
                kpts0, image0.shape[2], image0.shape[3]
            ),
            "kpts1": LightGlueRunner.normalize_keypoints(
                kpts1, image1.shape[2], image1.shape[3]
            ),
            "desc0": desc0,
            "desc1": desc1,
        }
        lightglue_outputs = ["matches0", "matches1", "mscores0", "mscores1"]

        if device == "cuda":
            # Prepare IO-Bindings
            binding = lightglue.io_binding()

            for name, arr in lightglue_inputs.items():
                binding.bind_cpu_input(name, arr)

            for name in lightglue_outputs:
                binding.bind_output(name, "cuda")

            # Measure only matching time
            start = time.perf_counter()
            result = lightglue.run_with_iobinding(binding)
            end = time.perf_counter()
        else:
            start = time.perf_counter()
            result = lightglue.run(None, lightglue_inputs)
            end = time.perf_counter()

        return (end - start) * 1000


def evaluate(
    framework,
    megadepth_path=Path("megadepth_test_1500"),
    img_size=512,
    extractor_type="superpoint",
    max_num_keypoints=512,
    device="cuda",
):
    images = get_megadepth_images(megadepth_path)
    image_pairs = list(zip(images[::2], images[1::2]))

    extractor, lightglue = create_models(
        framework=framework,
        extractor_type=extractor_type,
        max_num_keypoints=max_num_keypoints,
        device=device,
    )

    # Warmup
    for image0, image1 in image_pairs[:10]:
        image0, _ = load_image(str(image0), resize=img_size)
        image1, _ = load_image(str(image1), resize=img_size)

        if framework == "torch":
            image0 = image0[None].to(device)
            image1 = image1[None].to(device)
        elif framework == "ort" and extractor_type == "superpoint":
            image0 = rgb_to_grayscale(image0)
            image1 = rgb_to_grayscale(image1)

        _ = measure_inference(framework, extractor, lightglue, image0, image1, device)

    # Measure
    timings = []
    for image0, image1 in tqdm(image_pairs):
        image0, _ = load_image(str(image0), resize=img_size)
        image1, _ = load_image(str(image1), resize=img_size)

        if framework == "torch":
            image0 = image0[None].to(device)
            image1 = image1[None].to(device)
        elif framework == "ort" and extractor_type == "superpoint":
            image0 = rgb_to_grayscale(image0)
            image1 = rgb_to_grayscale(image1)

        inference_time = measure_inference(
            framework, extractor, lightglue, image0, image1, device
        )
        timings.append(inference_time)

    # Results
    timings = np.array(timings)
    print(f"Mean inference time: {timings.mean():.2f} +/- {timings.std():.2f} ms")
    print(f"Median inference time: {np.median(timings):.2f} ms")


if __name__ == "__main__":
    args = parse_args()
    if args.framework == "torch":
        import torch

        from lightglue import DISK, LightGlue, SuperPoint
        from lightglue.utils import load_image
    elif args.framework == "ort":
        import onnxruntime as ort

        from onnx_runner import LightGlueRunner, load_image, rgb_to_grayscale

    evaluate(**vars(args))
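
For reference, a minimal usage sketch of the script above. The flags mirror `parse_args()`; the dataset path and the exported ONNX weights expected by `create_models()` (e.g. `weights/superpoint_512.onnx` and `weights/superpoint_lightglue.onnx`) are assumed to exist locally.

```bash
# Assumed setup: MegaDepth test images under ./megadepth_test_1500 and exported ONNX weights under ./weights
# Measure PyTorch matcher latency on CUDA
python eval.py torch --extractor_type superpoint --max_num_keypoints 512 --device cuda

# Measure ONNX Runtime matcher latency on CPU
python eval.py ort --extractor_type superpoint --max_num_keypoints 512 --device cpu
```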
21 changes: 21 additions & 0 deletions evaluation/EVALUATION.md
@@ -0,0 +1,21 @@
# Evaluation

The inference time of LightGlue-ONNX is compared to that of the original PyTorch implementation with default configuration.

## Methods

Following the implementation details of the [LightGlue paper](https://arxiv.org/abs/2306.13643), we report the inference time, or latency, of only the LightGlue matcher; that is, the time taken for feature extraction, postprocessing, copying data between the host & device, or finding inliers (e.g., CONSAC/MAGSAC) is not measured. The average inference time is defined as the mean over all samples in the [MegaDepth](https://arxiv.org/abs/1804.00607) test dataset. We use the data provided by [LoFTR](https://arxiv.org/abs/2104.00680) [here](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md) - a total of 403 image pairs.

Each image is resized such that its longer side is 1024 before being fed into the feature extractor. The average inference time of the LightGlue matcher is then measured for different values of the extractor's `max_num_keypoints` parameter: 512, 1024, 2048, and 4096. The [SuperPoint](http://arxiv.org/abs/1712.07629) extractor is used.
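
As a rough sketch (not part of the evaluation scripts themselves), such a sweep can be driven with the `eval.py` script added in this commit, assuming the corresponding ONNX weights (e.g. `weights/superpoint_2048.onnx`) have already been exported:

```bash
# Sweep the matcher latency over the reported keypoint budgets for both frameworks
for n in 512 1024 2048 4096; do
  python eval.py torch --img_size 1024 --max_num_keypoints $n --device cuda
  python eval.py ort --img_size 1024 --max_num_keypoints $n --device cuda
done
```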

All experiments are conducted on a [Google Colab](https://colab.research.google.com/github/fabio-sim/LightGlue-ONNX/blob/main/evaluation/lightglue-onnx.ipynb) GPU Runtime (Tesla T4).

## Results

The measured run times are plotted in the figure below.

![Latency vs. Number of Keypoints](../assets/latency.png)

All values are matcher latencies in milliseconds (ms).

| Model | Device | 512 keypoints | 1024 keypoints | 2048 keypoints | 4096 keypoints |
|:--|:--|--:|--:|--:|--:|
| LightGlue | CUDA | 35.42 | 47.36 | 112.87 | 187.51 |
| LightGlue-ONNX | CUDA | 30.44 | 82.24 | 269.39 | 519.41 |
| LightGlue | CPU | 1121 | 3818 | 15968 | 37587 |
| LightGlue-ONNX | CPU | 759 | 2961 | 10493 | 20143 |

At smaller numbers of keypoints, the difference between the CUDA ONNX and PyTorch latencies is small; at higher keypoint counts, however, the gap widens considerably and PyTorch is faster. The cause remains to be investigated (different operator implementations?). On the other hand, ONNX is faster overall for CPU inference.