Skip to content

Commit

Permalink
fix a skipped shape infer code (#6049)
Browse files Browse the repository at this point in the history
### Description
<!-- - Describe your changes. -->

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? -->
<!-- - If it fixes an open issue, please link to the issue here. -->

Signed-off-by: Liqun Fu <liqfu@microsoft.com>
  • Loading branch information
liqunfu committed Apr 3, 2024
1 parent 366ac64 commit fa0b899
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions onnx/shape_inference/implementation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,29 +488,29 @@ class ShapeInferenceImplBase {
ProcessCall(n, *(iter->second), ctx);
} else {
has_unsupported_op = true;
return;
}
} else {
has_unsupported_op = true;
return;
}
if (!has_unsupported_op) {
for (int i = 0; i < n.output_size(); ++i) {
// skip type and shape propagation for missing optional outputs.
if (!n.output(i).empty())
UpdateType(n.output(i), ctx.getOutputType(i));
}
// Constant values are tracked to improve inference/checking for subsequent nodes.
ProcessConstant(n);
// If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed
// to improve inference/checking for subsequent nodes.
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
if (generated_shape_data_by_name == nullptr) {
fail_shape_inference(
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
}
DataPropagationContextImpl data_propagation_ctx(
n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
schema->GetDataPropagationFunction()(data_propagation_ctx);
for (int i = 0; i < n.output_size(); ++i) {
// skip type and shape propagation for missing optional outputs.
if (!n.output(i).empty())
UpdateType(n.output(i), ctx.getOutputType(i));
}
// Constant values are tracked to improve inference/checking for subsequent nodes.
ProcessConstant(n);
// If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed
// to improve inference/checking for subsequent nodes.
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
if (generated_shape_data_by_name == nullptr) {
fail_shape_inference(
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
}
DataPropagationContextImpl data_propagation_ctx(
n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
schema->GetDataPropagationFunction()(data_propagation_ctx);
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
Expand Down

0 comments on commit fa0b899

Please sign in to comment.