Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix function shape inference bug #4880

Merged
merged 7 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions onnx/shape_inference/implementation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,33 +294,41 @@ class ShapeInferenceImplBase {
if (checker::check_is_experimental_op(n)) {
has_experimental_op = true;
} else if (n.op_type() == "Constant" && n.output().size() == 1) {
const std::string& output_name = n.output(0);
for (const auto& attr : n.attribute()) {
if (attr.name() == "value") {
if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
input_data_by_name[n.output(0)] = &attr.t();
if (reuse_constant_tensors) {
input_data_by_name[output_name] = &attr.t();
} else {
input_data_by_name_holder[output_name] = attr.t();
input_data_by_name[output_name] = &input_data_by_name_holder[output_name];
}
} else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
input_sparse_data_by_name[n.output(0)] = &attr.sparse_tensor();
if (reuse_constant_tensors) {
input_sparse_data_by_name[output_name] = &attr.sparse_tensor();
}
}
} else {
switch (attr.type()) {
case AttributeProto::INTS: {
std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
addTemporaryConstant(n.output(0), ints);
addTemporaryConstant(output_name, ints);
break;
}
case AttributeProto::INT: {
std::vector<int64_t> ints({attr.i()});
addTemporaryConstant(n.output(0), ints);
addTemporaryConstant(output_name, ints);
break;
}
case AttributeProto::FLOATS: {
std::vector<float> floats{attr.floats().begin(), attr.floats().end()};
addTemporaryConstant(n.output(0), floats);
addTemporaryConstant(output_name, floats);
break;
}
case AttributeProto::FLOAT: {
std::vector<float> floats({attr.f()});
addTemporaryConstant(n.output(0), floats);
addTemporaryConstant(output_name, floats);
jcwchen marked this conversation as resolved.
Show resolved Hide resolved
break;
}
default:
Expand Down Expand Up @@ -555,6 +563,10 @@ class ShapeInferenceImplBase {
}

void process(const FunctionProto& func_proto, InferenceContext& ctx) {
// Disable tensor reuse so tensor attributes of Constant nodes are copied rather than
// referenced in place — presumably the function's attribute storage may not outlive
// this inference pass (NOTE(review): confirm the FunctionProto lifetime assumption).
bool old_reuse_constant_tensors = reuse_constant_tensors;
reuse_constant_tensors = false;

// Get a temporary tensor-shape map
const auto num_func_inputs = func_proto.input_size();
std::vector<TypeProto> types_cache(num_func_inputs);
Expand Down Expand Up @@ -598,6 +610,8 @@ class ShapeInferenceImplBase {
type_proto->CopyFrom(*(iter->second));
}
}

reuse_constant_tensors = old_reuse_constant_tensors;
}

public:
Expand Down Expand Up @@ -659,6 +673,10 @@ class ShapeInferenceImplBase {
std::vector<std::string> inference_errors;

std::list<TypeProto> initializer_type_list;

  // reuse_constant_tensors: controls whether tensor attributes of Constant nodes may be
  // referenced in place (reused) or must be copied into input_data_by_name_holder.
  // Reuse is enabled when inferring shapes for graphs; for functions a copy is made instead.
  bool reuse_constant_tensors = true;
};

static void InferShapesImpl(
Expand Down
21 changes: 21 additions & 0 deletions onnx/test/model_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,27 @@ def test_mi_constant_2(self):
"""
self._check_shape(model, [8, 4, 16])

def test_mi_constant_in_function(self):
model = """
<
ir_version: 7,
opset_import: [ "" : 17, "local" : 1]
>
main (float x) => (y, z) {
y, z = local.expand(x)
}
<
opset_import: [ "" : 17 ],
domain: "local"
>
expand (x) => (y, z) {
shape1 = Constant<value = int64[2] {4,4}>()
shape2 = Constant<value = int64[3] {8,8,8}>()
z = Expand (x, shape2)
y = Expand (x, shape1)
}
"""
self._check_shape(model, [4, 4], [8, 8, 8])

if __name__ == "__main__":
unittest.main()