Skip to content

Commit

Permalink
Add support for ONNX-only
Browse files Browse the repository at this point in the history
Summary:
This PR is composed of different fixes to enable and end-to-end ONNX export functionality for detectron2 models

* `add_export_config` API is publicly available exposed even when caffe2 is not compiled along with PyTorch (that is the new default behavior on latest PyTorch). A warning message informing users about its deprecation on future versions is also added

* `tensor.shape[0]` replaces `len(tensor)` and `for idx, img in enumerate(tensors)` replaces `for tmp_var1, tmp_var2 in zip(tensors, batched_imgs)`  so that the tracer does not lose reference to the user input on the graphs.
  * Before the changes above, the graph (see below) does not have an actual `input`. Instead, the input is exported as a model weight
![image](https://user-images.githubusercontent.com/5469809/171214657-199ca795-b0f2-4b0d-8ee7-5900db814a86.png)
  * After the fix, the user images are properly acknowledged as model's input (see below) during ONNX export
![image](https://user-images.githubusercontent.com/5469809/171227463-955092f4-bbb3-4920-8bc4-d88e79c9f687.png)

* Added unit tests (`tests/torch_export_onnx.py`) for detectron2 models

* ONNX is added as dependency for the CI to be able to run the aforementioned tests

* Added custom symbolic functions to allow CI pipelines to succeed. The symbolics are needed because PyTorch 1.8, 1.9 and 1.10 adopted by detectron2 have several bugs. They can be removed when 1.11+ is adopted by detectron2's CI infra

Fixes #3488
Fixes pytorch/pytorch#69674 (PyTorch repo)

Pull Request resolved: #4291

Reviewed By: wat3rBro

Differential Revision: D37152780

Pulled By: mcimpoi

fbshipit-source-id: edd39319fae29d9c3fb7ee907bbeda7b64c48b67
  • Loading branch information
thiagocrepaldi authored and facebook-github-bot committed Jul 16, 2022
1 parent cef4068 commit 48b598b
Show file tree
Hide file tree
Showing 19 changed files with 553 additions and 42 deletions.
10 changes: 5 additions & 5 deletions .circleci/config.yml
Expand Up @@ -10,7 +10,7 @@ cpu: &cpu

gpu: &gpu
machine:
# NOTE: use a cuda vesion that's supported by all our pytorch versions
# NOTE: use a cuda version that's supported by all our pytorch versions
image: ubuntu-1604-cuda-11.1:202012-01
resource_class: gpu.nvidia.small

Expand Down Expand Up @@ -94,7 +94,7 @@ setup_venv: &setup_venv
setup_venv_win: &setup_venv_win
- run:
name: Setup Virutal Env for Windows
name: Setup Virtual Env for Windows
command: |
pip install virtualenv
python -m virtualenv env
Expand All @@ -113,7 +113,7 @@ install_linux_dep: &install_linux_dep
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/iopath'
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
# Don't use pytest-xdist: cuda tests are unstable under multi-process workers.
pip install --progress-bar off ninja opencv-python-headless pytest tensorboard pycocotools
pip install --progress-bar off ninja opencv-python-headless pytest tensorboard pycocotools onnx
pip install --progress-bar off torch==$PYTORCH_VERSION -f $PYTORCH_INDEX
if [[ "$TORCHVISION_VERSION" == "master" ]]; then
pip install git+https://github.com/pytorch/vision.git
Expand All @@ -139,7 +139,7 @@ run_unittests: &run_unittests
- run:
name: Run Unit Tests
command: |
pytest -v --durations=15 tests # parallel causes some random failures
pytest -sv --durations=15 tests # parallel causes some random failures
uninstall_tests: &uninstall_tests
- run:
Expand Down Expand Up @@ -227,7 +227,7 @@ jobs:
command: |
pip install certifi --ignore-installed # required on windows to workaround some cert issue
pip install numpy cython # required on windows before pycocotools
pip install opencv-python-headless pytest-xdist pycocotools tensorboard
pip install opencv-python-headless pytest-xdist pycocotools tensorboard onnx
pip install -U git+https://github.com/facebookresearch/iopath
pip install -U git+https://github.com/facebookresearch/fvcore
pip install torch==$env:PYTORCH_VERSION torchvision==$env:TORCHVISION_VERSION -f $env:PYTORCH_INDEX
Expand Down
2 changes: 1 addition & 1 deletion .flake8
Expand Up @@ -2,7 +2,7 @@
# Keep in sync with setup.cfg which is used for source packages.

[flake8]
ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811
ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
Expand Up @@ -78,4 +78,4 @@ jobs:
python -m detectron2.utils.collect_env
./datasets/prepare_for_tests.sh
- name: Run unittests
run: python -m pytest -n 4 --durations=15 -v tests/
run: python -m pytest -n 4 --durations=15 -sv tests/
4 changes: 3 additions & 1 deletion detectron2/export/README.md
@@ -1,6 +1,6 @@

This directory contains code to prepare a detectron2 model for deployment.
Currently it supports exporting a detectron2 model to Caffe2 format through ONNX.
Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.

Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.

Expand All @@ -11,3 +11,5 @@ Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion to

Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
help export Detectron2 models to TorchScript.

Thanks to ONNX Converter team at Microsoft who help export Detectron2 models to ONNX.
18 changes: 16 additions & 2 deletions detectron2/export/__init__.py
@@ -1,5 +1,10 @@
# -*- coding: utf-8 -*-

import warnings

from .flatten import TracingAdapter
from .torchscript import dump_torchscript_IR, scripting_with_instances

try:
from caffe2.proto import caffe2_pb2 as _tmp

Expand All @@ -9,7 +14,16 @@
else:
from .api import *

from .flatten import TracingAdapter
from .torchscript import scripting_with_instances, dump_torchscript_IR

# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
STABLE_ONNX_OPSET_VERSION = 11


def add_export_config(cfg):
warnings.warn(
"add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
)
return cfg


__all__ = [k for k in globals().keys() if not k.startswith("_")]
5 changes: 0 additions & 5 deletions detectron2/export/api.py
Expand Up @@ -14,16 +14,11 @@
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph

__all__ = [
"add_export_config",
"Caffe2Model",
"Caffe2Tracer",
]


def add_export_config(cfg):
return cfg


class Caffe2Tracer:
"""
Make a detectron2 model traceable with Caffe2 operators.
Expand Down
12 changes: 7 additions & 5 deletions detectron2/layers/wrappers.py
Expand Up @@ -8,6 +8,7 @@
is implemented
"""

import warnings
from typing import List, Optional
import torch
from torch.nn import functional as F
Expand Down Expand Up @@ -102,11 +103,12 @@ def forward(self, x):
# 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
# later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
if not torch.jit.is_scripting():
if x.numel() == 0 and self.training:
# https://github.com/pytorch/pytorch/issues/12013
assert not isinstance(
self.norm, torch.nn.SyncBatchNorm
), "SyncBatchNorm does not support empty inputs!"
with warnings.catch_warnings(record=True):
if x.numel() == 0 and self.training:
# https://github.com/pytorch/pytorch/issues/12013
assert not isinstance(
self.norm, torch.nn.SyncBatchNorm
), "SyncBatchNorm does not support empty inputs!"

x = F.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
Expand Down
3 changes: 1 addition & 2 deletions detectron2/modeling/meta_arch/rcnn.py
Expand Up @@ -218,8 +218,7 @@ def inference(
if do_postprocess:
assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
else:
return results
return results

def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
"""
Expand Down
6 changes: 4 additions & 2 deletions detectron2/structures/image_list.py
Expand Up @@ -121,7 +121,9 @@ def from_tensors(
)
batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device)
batched_imgs = move_device_like(batched_imgs, tensors[0])
for img, pad_img in zip(tensors, batched_imgs):
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
for i, img in enumerate(tensors):
# Use `batched_imgs` directly instead of `img, pad_img = zip(tensors, batched_imgs)`
# Tracing mode cannot capture `copy_()` of temporary locals
batched_imgs[i, ..., : img.shape[-2], : img.shape[-1]].copy_(img)

return ImageList(batched_imgs.contiguous(), image_sizes)
6 changes: 5 additions & 1 deletion detectron2/structures/instances.py
@@ -1,8 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import warnings
from typing import Any, Dict, List, Tuple, Union
import torch

from detectron2.structures import Boxes


class Instances:
"""
Expand Down Expand Up @@ -71,7 +74,8 @@ def set(self, name: str, value: Any) -> None:
The length of `value` must be the number of instances,
and must agree with other existing fields in this object.
"""
data_len = len(value)
with warnings.catch_warnings(record=True):
data_len = len(value)
if len(self._fields):
assert (
len(self) == data_len
Expand Down
2 changes: 1 addition & 1 deletion detectron2/utils/develop.py
Expand Up @@ -23,7 +23,7 @@ def create_dummy_class(klass, dependency, message=""):

class _DummyMetaClass(type):
# throw error on class attribute access
def __getattr__(_, __):
def __getattr__(_, __): # noqa: B902
raise ImportError(err)

class _Dummy(object, metaclass=_DummyMetaClass):
Expand Down

0 comments on commit 48b598b

Please sign in to comment.