Skip to content

Commit

Permalink
Fix function shape inference bug (onnx#4880)
Browse files Browse the repository at this point in the history
* Fix function shape inference bug

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Fix lintrunner issues

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
  • Loading branch information
gramalingam and jcwchen committed Feb 10, 2023
1 parent b85c67c commit f65a669
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
30 changes: 24 additions & 6 deletions onnx/shape_inference/implementation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,33 +294,41 @@ class ShapeInferenceImplBase {
if (checker::check_is_experimental_op(n)) {
has_experimental_op = true;
} else if (n.op_type() == "Constant" && n.output().size() == 1) {
const std::string& output_name = n.output(0);
for (const auto& attr : n.attribute()) {
if (attr.name() == "value") {
if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
input_data_by_name[n.output(0)] = &attr.t();
if (reuse_constant_tensors) {
input_data_by_name[output_name] = &attr.t();
} else {
input_data_by_name_holder[output_name] = attr.t();
input_data_by_name[output_name] = &input_data_by_name_holder[output_name];
}
} else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
input_sparse_data_by_name[n.output(0)] = &attr.sparse_tensor();
if (reuse_constant_tensors) {
input_sparse_data_by_name[output_name] = &attr.sparse_tensor();
}
}
} else {
switch (attr.type()) {
case AttributeProto::INTS: {
std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
addTemporaryConstant(n.output(0), ints);
addTemporaryConstant(output_name, ints);
break;
}
case AttributeProto::INT: {
std::vector<int64_t> ints({attr.i()});
addTemporaryConstant(n.output(0), ints);
addTemporaryConstant(output_name, ints);
break;
}
case AttributeProto::FLOATS: {
std::vector<float> floats{attr.floats().begin(), attr.floats().end()};
addTemporaryConstant(n.output(0), floats);
addTemporaryConstant(output_name, floats);
break;
}
case AttributeProto::FLOAT: {
std::vector<float> floats({attr.f()});
addTemporaryConstant(n.output(0), floats);
addTemporaryConstant(output_name, floats);
break;
}
default:
Expand Down Expand Up @@ -555,6 +563,10 @@ class ShapeInferenceImplBase {
}

void process(const FunctionProto& func_proto, InferenceContext& ctx) {
// Ensure Constant node tensor-attributes are copied
bool old_reuse_constant_tensors = reuse_constant_tensors;
reuse_constant_tensors = false;

// Get a temporary tensor-shape map
const auto num_func_inputs = func_proto.input_size();
std::vector<TypeProto> types_cache(num_func_inputs);
Expand Down Expand Up @@ -598,6 +610,8 @@ class ShapeInferenceImplBase {
type_proto->CopyFrom(*(iter->second));
}
}

reuse_constant_tensors = old_reuse_constant_tensors;
}

public:
Expand Down Expand Up @@ -659,6 +673,10 @@ class ShapeInferenceImplBase {
std::vector<std::string> inference_errors;

std::list<TypeProto> initializer_type_list;

// reuse_constant_tensors: controls whether we need to copy tensors occurring as attributes
// in Constant nodes. We avoid it for inference for graphs, but must make a copy for functions.
bool reuse_constant_tensors = true;
};

static void InferShapesImpl(
Expand Down
22 changes: 22 additions & 0 deletions onnx/test/model_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ def test_mi_constant_2(self):
"""
self._check_shape(model, [8, 4, 16])

def test_mi_constant_in_function(self):
    # Regression test for function shape inference with Constant nodes
    # (see the accompanying implementation.cc change): tensor attributes of
    # Constant nodes appearing inside a function body must be copied rather
    # than aliased during inference, since the per-call inference context
    # outlives a single function expansion.
    #
    # The model calls a local function `expand` that produces two outputs,
    # each Expand-ed with a different Constant shape, so stale/aliased
    # constant data would yield wrong inferred shapes.
    model = """
<
ir_version: 7,
opset_import: [ "" : 17, "local" : 1]
>
main (float x) => (y, z) {
y, z = local.expand(x)
}
<
opset_import: [ "" : 17 ],
domain: "local"
>
expand (x) => (y, z) {
shape1 = Constant<value = int64[2] {4,4}>()
shape2 = Constant<value = int64[3] {8,8,8}>()
z = Expand (x, shape2)
y = Expand (x, shape1)
}
"""
    # y = Expand(x, [4,4]) and z = Expand(x, [8,8,8]): the two outputs must
    # pick up their own Constant shape, not each other's.
    self._check_shape(model, [4, 4], [8, 8, 8])


# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()

0 comments on commit f65a669

Please sign in to comment.