Skip to content

Commit

Permalink
Allow empty shapes and do not validate them for inputs/outputs (micro…
Browse files Browse the repository at this point in the history
…soft#18442)

### Description
Allow empty shapes and do not validate them for inputs/outputs at the
InferenceSession::ValidateInputsOutputs().

### Motivation and Context
microsoft#17301 disallowed empty
shapes.
However, many models depend on them as a way to pass shapes of different
ranks.
  • Loading branch information
yuslepukhin authored and pull[bot] committed Feb 27, 2024
1 parent 454f1ae commit 980e0bf
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2025,9 +2025,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::span<const std::stri
expected_element_type, "tensor", input_output_moniker));

// check for shape
if (iter->second.tensor_shape.has_value()) {
const auto& opt_shape = iter->second.tensor_shape;
if (opt_shape.has_value() && !opt_shape->GetDims().empty()) {
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(),
*iter->second.tensor_shape, input_output_moniker));
*opt_shape, input_output_moniker));
}
} else if (input_output_ml_value.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
Expand All @@ -2038,9 +2039,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::span<const std::stri
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type,
"sparse_tensor", input_output_moniker));
// Check shape
if (iter->second.tensor_shape.has_value()) {
const auto& opt_shape = iter->second.tensor_shape;
if (opt_shape.has_value() && !opt_shape->GetDims().empty()) {
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(),
*iter->second.tensor_shape, input_output_moniker));
*opt_shape, input_output_moniker));
}
} else if (is_sparse_initializer(name) &&
expected_type->IsTensorType()) {
Expand All @@ -2049,9 +2051,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::span<const std::stri
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(sparse_tensor.DataType(), expected_element_type,
"sparse_tensor", input_output_moniker));
// Check shape
if (iter->second.tensor_shape.has_value()) {
const auto& opt_shape = iter->second.tensor_shape;
if (opt_shape.has_value() && !opt_shape->GetDims().empty()) {
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(),
*iter->second.tensor_shape, input_output_moniker));
*opt_shape, input_output_moniker));
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name,
Expand All @@ -2061,7 +2064,6 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::span<const std::stri
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name ", name,
" is a sparse tensor, which is not supported in this build.");
#endif

} else if (input_output_ml_value.IsTensorSequence()) {
if (!expected_type->IsTensorSequenceType()
#if !defined(DISABLE_OPTIONAL_TYPE)
Expand Down

0 comments on commit 980e0bf

Please sign in to comment.