diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index cd6da35582469..89f956a5e7017 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -55,9 +55,8 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Check if the input is static, and if so, get its total size bool inputIsStatic = inputType.hasStaticShape(); int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1; - + // Compute result shape - bool resultIsStatic = true; auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t { // If this is not a placeholder, do not change it if (size >= 0) @@ -65,10 +64,8 @@ TensorType inferReshapeExpandedType(TensorType inputType, // If we do not know the total size of the tensor, keep this dimension // dynamic in the result shape. - if (!inputIsStatic) { - resultIsStatic = false; + if (!inputIsStatic) return ShapedType::kDynamic; - } // Calculate the product of all elements in 'newShape' except for the -1 // placeholder, which we discard by negating the result. @@ -84,12 +81,14 @@ TensorType inferReshapeExpandedType(TensorType inputType, return totalSize / totalSizeNoPlaceholder; }); + bool resultIsStatic = !ShapedType::isDynamicShape(resultShape); + // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically // shaped input from being reshaped into a statically shaped result. We may // simply turn the first result dimension dynamic to address this. if (!inputIsStatic && resultIsStatic) resultShape[0] = ShapedType::kDynamic; - + // The 'tensor.expand_shape' op also forbids a statically shaped input from // being reshaped into a dynamically shaped result, but the placeholder // inference algorithm above guarantees that this will never be the case.