Merge pull request #18 from fabio-sim/fabio/tensorrt-eval
TensorRT eval
fabio-sim committed Jul 20, 2023
2 parents c7c34f7 + 727f38a commit e82a1a4
Showing 7 changed files with 110 additions and 21 deletions.
24 changes: 13 additions & 11 deletions .gitignore
@@ -1,11 +1,13 @@
*.egg-info
*.pyc
/.idea/
/data/
/outputs/
__pycache__
/lightglue/weights/
lightglue/_flash/
*-checkpoint.ipynb
*.pth
*.onnx
*.egg-info
*.pyc
/.idea/
/data/
/outputs/
__pycache__
/lightglue/weights/
lightglue/_flash/
*-checkpoint.ipynb
*.pth
*.onnx
*.engine
*.profile
9 changes: 5 additions & 4 deletions README.md
@@ -2,7 +2,7 @@

# LightGlue ONNX

Open Neural Network Exchange (ONNX) compatible implementation of [LightGlue: Local Feature Matching at Light Speed](https://github.com/cvg/LightGlue). The ONNX model format allows for interoperability across different platforms with support for multiple execution providers, and removes Python-specific dependencies such as PyTorch. Experimental support for TensorRT.
Open Neural Network Exchange (ONNX) compatible implementation of [LightGlue: Local Feature Matching at Light Speed](https://github.com/cvg/LightGlue). The ONNX model format allows for interoperability across different platforms with support for multiple execution providers, and removes Python-specific dependencies such as PyTorch. Supports TensorRT and OpenVINO.

<p align="center"><a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="LightGlue figure" width=80%></a>

@@ -57,6 +57,7 @@ runner = LightGlueRunner(
extractor_path="weights/superpoint.onnx",
lightglue_path="weights/superpoint_lightglue.onnx",
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
# TensorrtExecutionProvider, OpenVINOExecutionProvider
)

# Run inference
@@ -77,9 +78,9 @@ python infer.py \
--viz
```

## TensorRT Support (Experimental)
## TensorRT Support

TensorRT inference is supported via the TensorRT Execution Provider in ONNXRuntime.
TensorRT inference is supported via the TensorRT Execution Provider in ONNXRuntime. Please follow the [official documentation](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) to install TensorRT. The exported ONNX models (whether standalone or end-to-end) must undergo [shape inference](/tools/symbolic_shape_infer.py) for compatibility with TensorRT:

```bash
python tools/symbolic_shape_infer.py \
@@ -101,7 +102,7 @@ CUDA_MODULE_LOADING=LAZY && python infer.py \
--viz
```

The first run will take longer because TensorRT needs to initialise the `.engine` and `.profile` files. Subsequent runs should use the cached files. Note that the ONNX models should not be exported with `--mp` or `--flash`. Only the SuperPoint extractor type is supported. Note that you might want to export with static input image shapes and `--max_num_keypoints` for better runtime optimisation. The same methodology can be applied to end-to-end models.
The first run will take longer because TensorRT needs to initialise the `.engine` and `.profile` files. Subsequent runs should use the cached files. Note that the ONNX models should not be exported with `--mp` or `--flash`. Only the SuperPoint extractor type is supported.
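As a rough illustration (not part of this diff), the TensorRT Execution Provider and its engine-cache options can be passed as a `(name, options)` tuple in the `providers` list, which is the same list accepted by `LightGlueRunner` above (assuming the runner forwards it to `onnxruntime.InferenceSession` unchanged). A minimal sketch, assuming the model has already been through shape inference; the full min/opt/max shape-range options used for evaluation are listed in `evaluation/EVALUATION.md`:

```python
import onnxruntime as ort

# Minimal TensorRT EP configuration: FP16 plus engine/profile caching.
trt_ep_options = {
    "trt_fp16_enable": True,
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "weights/cache",
}

session = ort.InferenceSession(
    "weights/superpoint_lightglue.onnx",
    providers=[
        ("TensorrtExecutionProvider", trt_ep_options),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ],
)

# Confirm that the TensorRT EP was actually registered for this session.
print(session.get_providers())
```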

## Inference Time Comparison

Binary file modified assets/latency.png
7 changes: 4 additions & 3 deletions docs/README.zh.md
@@ -2,7 +2,7 @@

# LightGlue ONNX

An Open Neural Network Exchange (ONNX)-compatible implementation of [LightGlue: Local Feature Matching at Light Speed](https://github.com/cvg/LightGlue). The ONNX format enables interoperability across different platforms and supports multiple execution providers, while removing Python-specific dependencies such as PyTorch. Supports TensorRT (experimental).
An Open Neural Network Exchange (ONNX)-compatible implementation of [LightGlue: Local Feature Matching at Light Speed](https://github.com/cvg/LightGlue). The ONNX format enables interoperability across different platforms and supports multiple execution providers, while removing Python-specific dependencies such as PyTorch. Supports TensorRT and OpenVINO.

<p align="center"><a href="https://arxiv.org/abs/2306.13643"><img src="../assets/easy_hard.jpg" alt="LightGlue figure" width=80%></a>

@@ -54,6 +54,7 @@ runner = LightGlueRunner(
    extractor_path="weights/superpoint.onnx",
    lightglue_path="weights/superpoint_lightglue.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    # TensorrtExecutionProvider, OpenVINOExecutionProvider
)

# Run inference
@@ -74,9 +75,9 @@ python infer.py \
--viz
```

## TensorRT (Experimental)
## TensorRT

TensorRT inference uses ONNXRuntime's TensorRT Execution Provider.
TensorRT inference uses ONNXRuntime's TensorRT Execution Provider. Please install [TensorRT](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) first.

```bash
python tools/symbolic_shape_infer.py \
27 changes: 27 additions & 0 deletions eval.py
@@ -57,6 +57,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Whether to use Flash Attention (CUDA only). Flash Attention must be installed.",
)
parser.add_argument(
"--trt",
action="store_true",
help="Whether to use TensorRT (experimental).",
)

# ONNXRuntime-specific args
parser.add_argument(
@@ -91,6 +96,7 @@ def create_models(
device="cuda",
mp=False,
flash=False,
trt=False,
extractor_path=None,
lightglue_path=None,
):
@@ -111,12 +117,14 @@
if device == "cuda"
else ["CPUExecutionProvider"]
)

if extractor_path is None:
extractor_path = (
f"weights/{extractor_type}_{max_num_keypoints}"
f"{'_mp' if mp else ''}"
".onnx"
)

extractor = ort.InferenceSession(
extractor_path,
sess_options=sess_opts,
@@ -130,6 +138,23 @@ def create_models(
f"{'_flash' if flash else ''}"
".onnx"
)

    if trt:
        assert device == "cuda", "TensorRT is only supported on CUDA devices."
        providers = [
            (
                "TensorrtExecutionProvider",
                {
                    "trt_fp16_enable": True,
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": "weights/cache",
                    "trt_profile_min_shapes": f"kpts0:1x1x2,kpts1:1x1x2,desc0:1x1x256,desc1:1x1x256",
                    "trt_profile_opt_shapes": f"kpts0:1x{max_num_keypoints}x2,kpts1:1x{max_num_keypoints}x2,desc0:1x{max_num_keypoints}x256,desc1:1x{max_num_keypoints}x256",
                    "trt_profile_max_shapes": f"kpts0:1x{max_num_keypoints}x2,kpts1:1x{max_num_keypoints}x2,desc0:1x{max_num_keypoints}x256,desc1:1x{max_num_keypoints}x256",
                },
            )
        ] + providers

    lightglue = ort.InferenceSession(
        lightglue_path,
        sess_options=sess_opts,
@@ -211,6 +236,7 @@ def evaluate(
device="cuda",
mp=False,
flash=False,
trt=False,
extractor_path=None,
lightglue_path=None,
):
@@ -224,6 +250,7 @@
        device=device,
        mp=mp,
        flash=flash,
        trt=trt,
        extractor_path=extractor_path,
        lightglue_path=lightglue_path,
    )
32 changes: 30 additions & 2 deletions evaluation/EVALUATION.md
@@ -8,14 +8,42 @@ Following the implementation details of the [LightGlue paper](https://arxiv.org/

Each image is resized such that its longer side is 1024 before being fed into the feature extractor. The average inference time of the LightGlue matcher is then measured for different values of the extractor's `max_num_keypoints` parameter: 512, 1024, 2048, and 4096. The [SuperPoint](http://arxiv.org/abs/1712.07629) extractor is used.
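For illustration only (not necessarily the repository's exact preprocessing), the resize rule can be sketched with OpenCV as:

```python
import cv2

def resize_longer_side(image, target: int = 1024):
    # Scale so that the longer image side equals `target`, preserving aspect ratio.
    h, w = image.shape[:2]
    scale = target / max(h, w)
    return cv2.resize(image, (round(w * scale), round(h * scale)))
```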

All experiments are conducted on a [Google Colab](lightglue-onnx.ipynb) GPU Runtime (Tesla T4).
All experiments are conducted on a [Google Colab](lightglue-onnx.ipynb) GPU Runtime (Tesla T4) with `CUDA==11.8.1` and `TensorRT==8.6.1`.
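Before benchmarking, it can help to confirm that the TensorRT Execution Provider is actually visible to ONNXRuntime in this environment; a quick check (a sketch, assuming `onnxruntime-gpu` is built with TensorRT support):

```python
import onnxruntime as ort

print(ort.__version__)
# Should list "TensorrtExecutionProvider" and "CUDAExecutionProvider".
print(ort.get_available_providers())
```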

## Results

The measured run times are plotted in the figure below.

<p align="center"><a href="https://github.com/fabio-sim/LightGlue-ONNX/blob/main/evaluation/EVALUATION.md"><img src="../assets/latency.png" alt="Latency Comparison" width=100%></a>

<table align="center"><thead><tr><th>Number of Keypoints</th><th></th><th>512</th><th>1024</th><th>2048</th><th>4096</th></tr><tr><th>Model</th><th>Device</th><th colspan="4">Latency (ms)</th></tr></thead><tbody><tr><td>LightGlue</td><td>CUDA</td><td>35.42</td><td>47.36</td><td>112.87</td><td>187.51</td></tr><tr><td>LightGlue-ONNX</td><td>CUDA</td><td>30.44</td><td>82.24</td><td>269.39</td><td>519.41</td></tr><tr><td>LightGlue-MP</td><td>CUDA</td><td>36.32</td><td>37.10</td><td>61.58</td><td>127.59</td></tr><tr><td>LightGlue-ONNX-MP</td><td>CUDA</td><td>24.2</td><td>66.27</td><td>227.91</td><td>473.71</td></tr><tr><td>LightGlue-MP-Flash</td><td>CUDA</td><td>38.3</td><td>38.8</td><td>42.9</td><td>55.9</td></tr><tr><td>LightGlue-ONNX-MP-Flash</td><td>CUDA</td><td>21.2</td><td>57.4</td><td>191.1</td><td>368.9</td></tr><tr><td>LightGlue</td><td>CPU</td><td>1121</td><td>3818</td><td>15968</td><td>37587</td></tr><tr><td>LightGlue-ONNX</td><td>CPU</td><td>759</td><td>2961</td><td>10493</td><td>20143</td></tr></tbody></table>
<table align="center"><thead><tr><th>Number of Keypoints</th><th></th><th>512</th><th>1024</th><th>2048</th><th>4096</th></tr><tr><th>Model</th><th>Device</th><th colspan="4">Latency (ms)</th></tr></thead><tbody><tr><td>LightGlue</td><td>CUDA</td><td>35.42</td><td>47.36</td><td>112.87</td><td>187.51</td></tr><tr><td>LightGlue-ONNX</td><td>CUDA</td><td>30.44</td><td>82.24</td><td>269.39</td><td>519.41</td></tr><tr><td>LightGlue-MP</td><td>CUDA</td><td>36.32</td><td>37.10</td><td>61.58</td><td>127.59</td></tr><tr><td>LightGlue-ONNX-MP</td><td>CUDA</td><td>24.2</td><td>66.27</td><td>227.91</td><td>473.71</td></tr><tr><td>LightGlue-MP-Flash</td><td>CUDA</td><td>38.3</td><td>38.8</td><td>42.9</td><td>55.9</td></tr><tr><td>LightGlue-ONNX-MP-Flash</td><td>CUDA</td><td>21.2</td><td>57.4</td><td>191.1</td><td>368.9</td></tr><tr><td>LightGlue-ONNX-TRT</td><td>TensorRT-CUDA</td><td>7.08</td><td>15.88</td><td>47.04</td><td>107.89</td></tr><tr><td>LightGlue</td><td>CPU</td><td>1121</td><td>3818</td><td>15968</td><td>37587</td></tr><tr><td>LightGlue-ONNX</td><td>CPU</td><td>759</td><td>2961</td><td>10493</td><td>20143</td></tr></tbody></table>

At smaller numbers of keypoints, the difference between the CUDA ONNX and PyTorch latencies is small; however, it becomes much more noticeable at larger keypoint counts, where PyTorch is faster. The cause remains to be investigated (possibly different operator implementations). On the other hand, ONNX is faster overall for CPU inference.

## TensorRT

Note that TensorRT incurs an upfront initialisation cost: the `.engine` and `.profile` files are built during the first run (subsequent runs can use the cached files). Depending on the machine, this build can take more than 10 minutes. When using dynamic axes with the TensorRT Execution Provider, it is recommended to pass the min/opt/max shape-range options so that TensorRT does not have to rebuild the runtime profile whenever an unexpected shape is encountered. Corresponding snippet from [`eval.py`](/eval.py):

```python
import onnxruntime as ort

max_num_keypoints = 512 # 1024, 2048

trt_ep_options = {
"trt_fp16_enable": True,
"trt_engine_cache_enable": True,
"trt_engine_cache_path": "weights/cache",
"trt_profile_min_shapes": f"kpts0:1x1x2,kpts1:1x1x2,desc0:1x1x256,desc1:1x1x256",
"trt_profile_opt_shapes": f"kpts0:1x{max_num_keypoints}x2,kpts1:1x{max_num_keypoints}x2,desc0:1x{max_num_keypoints}x256,desc1:1x{max_num_keypoints}x256",
"trt_profile_max_shapes": f"kpts0:1x{max_num_keypoints}x2,kpts1:1x{max_num_keypoints}x2,desc0:1x{max_num_keypoints}x256,desc1:1x{max_num_keypoints}x256"
}

sess = ort.InferenceSession(
"superpoint_lightglue.onnx",
providers=[
("TensorrtExecutionProvider", trt_ep_options),
"CUDAExecutionProvider",
"CPUExecutionProvider"
]
)
```
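
After the first run, the cached `.engine` and `.profile` files should appear under `weights/cache` (the `trt_engine_cache_path` above); deleting that directory forces TensorRT to rebuild the engine on the next run.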
32 changes: 31 additions & 1 deletion evaluation/lightglue-onnx.ipynb
@@ -31,7 +31,8 @@
"outputs": [],
"source": [
"!nvidia-smi\n",
"!nvcc --version"
"!nvcc --version\n",
"!lsb_release -a"
]
},
{
@@ -51,6 +52,22 @@
"# !pip install -q flash-attn==1.0.8 --no-build-isolation # Time-consuming (~30 minutes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install TensorRT\n",
"!wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.6.1/local_repos/nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8_1.0-1_amd64.deb\n",
"!sudo dpkg -i nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8_1.0-1_amd64.deb\n",
"!sudo cp /var/nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8/*-keyring.gpg /usr/share/keyrings/\n",
"!sudo cp /var/nv-tensorrt-local-repo-ubuntu2204-8.6.1-cuda-11.8/nv-tensorrt-local-0628887B-keyring.gpg /usr/share/keyrings/\n",
"!sudo apt-get update\n",
"!sudo apt-get install tensorrt python3-libnvinfer-dev\n",
"!dpkg-query -W tensorrt"
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -134,6 +151,19 @@
" --dynamic --max_num_keypoints 512"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For TensorRT\n",
"!cd LightGlue-ONNX && python tools/symbolic_shape_infer.py \\\n",
" --input weights/superpoint_lightglue.onnx \\\n",
" --output weights/superpoint_lightglue.onnx \\\n",
" --auto_merge"
]
},
{
"attachments": {},
"cell_type": "markdown",
