Skip to content

Commit

Permalink
ONNX transformation to cast int64 constants to int32 when possible (#655
Browse files Browse the repository at this point in the history
)

* add function to convert int64 constants to int32

* add utils file

* add in init

* add doc

* remove onnx check

* fix transformation and test

* Update optimum/onnxruntime/modeling_decoder.py

* more explicit naming
  • Loading branch information
fxmarty committed Jan 16, 2023
1 parent 2911a91 commit 4016c17
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 147 deletions.
7 changes: 6 additions & 1 deletion optimum/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_transformations import merge_decoders
from .graph_transformations import (
cast_slice_nodes_inputs_to_int32,
merge_decoders,
remove_duplicate_weights,
replace_atenops_to_gather,
)
198 changes: 54 additions & 144 deletions optimum/onnx/graph_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict
from typing import DefaultDict, Dict, List, Set, Tuple

import onnx
from onnx import ModelProto, ValueInfoProto


def _find_duplicate_weights(model) -> DefaultDict[Tuple[int, bytes], Set[str]]:
return _find_duplicate_initializers(model.graph.initializer)


def _find_duplicate_initializers(initializers) -> DefaultDict[Tuple[int, bytes], Set[str]]:
duplicates = defaultdict(set)
for initializer in initializers:
tensor_dims = tuple(getattr(initializer, "dims"))
for data_attr in ["raw_data", "int32_data", "int64_data", "uint64_data", "float_data", "double_data"]:
tensor_data = getattr(initializer, data_attr)
if tensor_data:
tensor_data = tuple(tensor_data)
break
duplicates[(initializer.data_type, tensor_data, tensor_dims)].add(initializer.name)
return duplicates

from onnx import ModelProto

def _create_name_sharing_dict(
duplicate_weights: DefaultDict[Tuple[int, bytes], Set[str]], suffix: str = None
) -> Dict[str, str]:
def _create_name_sharing_dict_for_duplicates(duplicates: Set[str]) -> Dict[str, str]:
common_name = duplicates.pop()
duplicates.add(common_name)
if suffix:
return {k: f"{common_name}_{suffix}" for k in duplicates}
else:
return {k: common_name for k in duplicates}
from ..utils import logging

name_sharing_dict = {}
for duplicates in duplicate_weights.values():
name_sharing_dict.update(_create_name_sharing_dict_for_duplicates(duplicates))
return name_sharing_dict

logger = logging.get_logger()

def _replace_input_names(model: ModelProto, name_sharing_dict: Dict[str, str]):
for node in model.graph.node:
for i in range(len(node.input)):
node.input[i] = name_sharing_dict.get(node.input[i], node.input[i])


def _remove_redundant_initializers(model: ModelProto, name_sharing_dict: Dict[str, str]):
to_pop = []
for idx, initializer in enumerate(model.graph.initializer):
if initializer.name != name_sharing_dict[initializer.name]:
to_pop.append(idx)

for idx in sorted(to_pop, reverse=True):
model.graph.initializer.pop(idx)
from .transformations_utils import (
_create_name_sharing_dict,
_deduplicated_cross_model_initializers,
_find_duplicate_weights,
_get_all_inputs,
_get_onnx_opset,
_remove_redundant_initializers,
_replace_input_names,
_unify_onnx_outputs,
cast_int64_tensorproto_to_int32,
)


def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelProto:
Expand All @@ -92,7 +59,7 @@ def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelP
return model


def replace_atenops_to_gather(model: ModelProto):
def replace_atenops_to_gather(model: ModelProto) -> ModelProto:
"""
Replaces broken ATenOp nodes back to Gather nodes.
Expand Down Expand Up @@ -122,101 +89,6 @@ def replace_atenops_to_gather(model: ModelProto):
return model


def _infer_output_shape(output: ValueInfoProto):
output_shape = []
for dim in output.type.tensor_type.shape.dim:
if getattr(dim, "dim_param"):
output_shape.append(getattr(dim, "dim_param"))
elif getattr(dim, "dim_value"):
output_shape.append(getattr(dim, "dim_value"))
else:
raise ValueError(f"Cannot find `dim_param` nor `dim_value` in the output dimension info.")

return output_shape


def _check_num_outputs(model1: ModelProto, model2: ModelProto):
if not len(model1.graph.output) == len(model2.graph.output):
raise ValueError(
f"Two model protos need to have the same outputs. But one has {len(model1.graph.output)} "
f"outputs while the other has {len(model2.graph.output)} outputs."
)


def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto):
"""
Unifies the outputs of two ONNX model protos. The outputs of model1 will be replaced by outputs of model2.
According to the rules of "If" op, two subgraphs must have the same number of outputs.
"""
_check_num_outputs(model1, model2)

for idx in range(len(model1.graph.output)):
model_output_1 = model1.graph.output[idx]
model_output_2 = model2.graph.output[idx]
if not model_output_1 == model_output_2:
if not (
model_output_1.name == model_output_2.name
and model_output_1.type.tensor_type.elem_type == model_output_2.type.tensor_type.elem_type
):
raise ValueError(
f"Can not match {model_output_1.name} with {model_output_2.name}. Make sure your"
f" model protos have same outputs, have same data types and are in the same order."
)
model1.graph.output.remove(model_output_1)

new_output = onnx.helper.make_tensor_value_info(
model_output_2.name,
model_output_2.type.tensor_type.elem_type,
_infer_output_shape(model_output_2),
)
model1.graph.output.insert(idx, new_output)

if not all(
model_output_1 == model_output_2
for model_output_1, model_output_2 in zip(model1.graph.output, model2.graph.output)
):
raise RuntimeError(f"Failed to unify outputs of given ONNX model protos.")


def _get_all_inputs(model_list: List[ModelProto]):
inputs = []
input_names = set()
for model in model_list:
for input in model.graph.input:
if input.name not in input_names:
input_names.add(input.name)
inputs.append(input)
return inputs


def _get_onnx_opset(model: ModelProto):
opset_import = model.opset_import[0]
return getattr(opset_import, "version")


def _deduplicated_cross_model_initializers(models: List[ModelProto], suffix: str = None):

all_initializers = []
for model in models:
all_initializers += list(model.graph.initializer)

duplicates = _find_duplicate_initializers(all_initializers)
name_sharing_dict = _create_name_sharing_dict(duplicates, suffix=suffix)
for model in models:
_replace_input_names(model, name_sharing_dict)

deduplicated_initializers = []
deduplicated_name = set()

for initializer in all_initializers:
if name_sharing_dict[initializer.name] not in deduplicated_name:
deduplicated_name.add(name_sharing_dict[initializer.name])
initializer.name = name_sharing_dict[initializer.name]
deduplicated_initializers.append(initializer)

return deduplicated_initializers


def merge_decoders(
decoder: ModelProto,
decoder_with_past: ModelProto,
Expand Down Expand Up @@ -296,3 +168,41 @@ def merge_decoders(
onnx.checker.check_model(merged_model)

return merged_model


def cast_slice_nodes_inputs_to_int32(model: ModelProto) -> ModelProto:
"""
Convert node inputs of `Slice` nodes from int64 to int32, casting the out of range values.
The constant node inputs are stored in `model.graph.node`, and the sole way to check which node
they are consumed by is to iterate over nodes and check `node.input` for a match.
Note that constant inputs to nodes as `Squeeze`, `Unsqueeze` can not be converted to int32, as the
these operators explicitely expect int64 inputs according to ONNX specifications:
https://github.com/onnx/onnx/blob/main/docs/Operators.md
"""
map_input_node = {}
map_node_inputs = {}

for node in model.graph.node:
for input_name in node.input:
map_input_node[input_name] = {"op_type": node.op_type, "node_name": node.name}
map_node_inputs[node.name] = node.input

for node in model.graph.node:
if (
node.op_type == "Constant"
and node.attribute[0].t.data_type == 7 # int64
and f"{node.name}_output_0" in map_input_node
and map_input_node[node.name + "_output_0"]["op_type"] == "Slice"
):
logger.debug(f"Converting {node.name} to int32")

# `Slice` node is homogeneous (requires parameters of same type), hence cast to int32 only if all of its inputs are constants
# refer to onnx/defs/schema.h
cast = all(
"Constant" in inp for inp in map_node_inputs[map_input_node[node.name + "_output_0"]["node_name"]][1:]
)
cast_int64_tensorproto_to_int32(node.attribute[0].t, cast=cast)

return model

0 comments on commit 4016c17

Please sign in to comment.