preserve conv param names during torch onnx export (#26)
* block automatic Conv-BN folding, which causes conv param names to be dropped
bfineran committed Jan 26, 2021
1 parent 4e57b13 commit 1596781
Showing 1 changed file with 91 additions and 1 deletion.
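
For context, a minimal sketch of the behavior this commit blocks (not part of the commit; the model and file names are illustrative): exporting a Conv followed by a BatchNorm in eval mode lets torch fold the pair during ONNX export, so the conv's original parameter name can disappear from the graph initializers.

import onnx
import torch

# small Conv-BN model; Sequential gives the conv the parameter name "0.weight"
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()

torch.onnx.export(model, torch.randn(1, 3, 16, 16), "conv_bn.onnx")

# with Conv-BN folding, the conv weight is re-baked into a fused initializer,
# so the original "0.weight" name may no longer appear in this list
print([init.name for init in onnx.load("conv_bn.onnx").graph.initializer])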
src/sparseml/pytorch/utils/exporter.py
@@ -6,7 +6,10 @@
from copy import deepcopy
from typing import Any, Iterable, List

import numpy
import onnx
import torch
from onnx import numpy_helper
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -51,6 +54,7 @@ def export_onnx(
sample_batch: Any,
name: str = "model.onnx",
opset: int = DEFAULT_ONNX_OPSET,
disable_bn_fusing: bool = True,
):
"""
Export an onnx file for the current module and for a sample batch.
@@ -62,6 +66,12 @@
:param name: name of the onnx file to save
        :param opset: onnx opset to use for the exported model. Default is 11; if
            the torch version is 1.2 or below, the default is 9
:param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
fusing during torch export. Default and suggested setting is True. Batch
norm fusing will change the exported parameter names as well as affect
sensitivity analyses of the exported graph. Additionally, the DeepSparse
inference engine, and other engines, perform batch norm fusing at model
compilation.
"""
sample_batch = tensors_to_device(sample_batch, "cpu")
onnx_path = os.path.join(self._output_dir, name)
@@ -96,8 +106,21 @@ def export_onnx(
submodule.observer_enabled[0] = 0
disabled_observers.append(submodule)

is_quant_module = any(
hasattr(submodule, "qconfig") and submodule.qconfig
for submodule in self._module.modules()
)
batch_norms_wrapped = False
if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
# prevent batch norm fusing by adding a trivial operation before every
# batch norm layer
export_module = deepcopy(self._module)
batch_norms_wrapped = _wrap_batch_norms(export_module)
else:
export_module = self._module

torch.onnx.export(
            export_module,
sample_batch,
onnx_path,
input_names=input_names,
@@ -111,6 +134,12 @@
for submodule in disabled_observers:
submodule.observer_enabled[0] = 1

# clean up graph from any injected / wrapped operations
if batch_norms_wrapped:
onnx_model = onnx.load(onnx_path)
_delete_trivial_onnx_adds(onnx_model)
onnx.save(onnx_model, onnx_path)

def export_pytorch(
self,
optimizer: Optimizer = None,
@@ -198,3 +227,64 @@ def export_samples(

assert len(exported_input) == len(exported_output)
exp_counter += len(exported_input)


class _AddNoOpWrapper(Module):
    # trivial wrapper to break up Conv-BN blocks

def __init__(self, module: Module):
super().__init__()
self.module = module

def forward(self, inp):
inp = inp + 0 # no-op
return self.module(inp)


def _get_submodule(module: Module, path: List[str]) -> Module:
if not path:
return module
return _get_submodule(getattr(module, path[0]), path[1:])


def _wrap_batch_norms(module: Module) -> bool:
# wrap all batch norm layers in module with a trivial wrapper
# to prevent BN fusing during export
batch_norms_wrapped = False
for name, submodule in module.named_modules():
if (
isinstance(submodule, torch.nn.BatchNorm1d)
or isinstance(submodule, torch.nn.BatchNorm2d)
or isinstance(submodule, torch.nn.BatchNorm3d)
):
submodule_path = name.split(".")
parent_module = _get_submodule(module, submodule_path[:-1])
setattr(parent_module, submodule_path[-1], _AddNoOpWrapper(submodule))
batch_norms_wrapped = True
return batch_norms_wrapped


def _delete_trivial_onnx_adds(model: onnx.ModelProto):
    # delete Add nodes in the graph whose second input is a Constant node equal to 0
add_nodes = [node for node in model.graph.node if node.op_type == "Add"]
for add_node in add_nodes:
try:
add_const_node = [
node for node in model.graph.node if node.output[0] == add_node.input[1]
][0]
add_const_val = numpy_helper.to_array(add_const_node.attribute[0].t)
if numpy.all(add_const_val == 0.0):
# update graph edges
parent_node = [
node
for node in model.graph.node
if add_node.input[0] in node.output
]
if not parent_node:
continue
parent_node[0].output[0] = add_node.output[0]
# remove node and constant
model.graph.node.remove(add_node)
model.graph.node.remove(add_const_node)
        except Exception:  # skip node on any error
continue
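
A usage sketch under stated assumptions (ModuleExporter and its constructor follow sparseml's exporter utilities; the exact import path is an assumption, not confirmed by this commit): with disable_bn_fusing left at its default of True, batch norms are wrapped before export so conv parameter names survive, and the injected no-op Adds are stripped back out of the saved ONNX graph afterwards.

import torch
from sparseml.pytorch.utils import ModuleExporter  # assumed import path

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()

exporter = ModuleExporter(model, output_dir="exported")
exporter.export_onnx(
    sample_batch=torch.randn(1, 3, 16, 16),
    name="model.onnx",
    disable_bn_fusing=True,  # the default added by this commit
)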
