diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index cd6da35582469..89f956a5e7017 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -55,9 +55,8 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Check if the input is static, and if so, get its total size bool inputIsStatic = inputType.hasStaticShape(); int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1; - + // Compute result shape - bool resultIsStatic = true; auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t { // If this is not a placeholder, do not change it if (size >= 0) @@ -65,10 +64,8 @@ TensorType inferReshapeExpandedType(TensorType inputType, // If we do not know the total size of the tensor, keep this dimension // dynamic in the result shape. - if (!inputIsStatic) { - resultIsStatic = false; + if (!inputIsStatic) return ShapedType::kDynamic; - } // Calculate the product of all elements in 'newShape' except for the -1 // placeholder, which we discard by negating the result. @@ -84,12 +81,14 @@ TensorType inferReshapeExpandedType(TensorType inputType, return totalSize / totalSizeNoPlaceholder; }); + bool resultIsStatic = !ShapedType::isDynamicShape(resultShape); + // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically // shaped input from being reshaped into a statically shaped result. We may // simply turn the first result dimension dynamic to address this. if (!inputIsStatic && resultIsStatic) resultShape[0] = ShapedType::kDynamic; - + // The 'tensor.expand_shape' op also forbids a statically shaped input from // being reshaped into a dynamically shaped result, but the placeholder // inference algorithm above guarantees that this will never be the case. diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index b8c3d56f21f10..72e7e4cc84088 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -394,6 +394,21 @@ func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> ten // ----- +// This test would previously fail on GCC with certain compiler flags. +// The GCC issue would cause invalid IR after tosa-to-tensor, so this test +// locks down that the code goes through tosa-to-tensor and verifies. +// +// See https://github.com/llvm/llvm-project/pull/91521 for a full description. + +// CHECK-LABEL: reshape_bug_fix +// CHECK: tensor.expand_shape +func.func @reshape_bug_fix(%arg0: tensor) -> tensor<1x1x1x?xf32> { + %0 = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor<1x1x1x?xf32> + return %0 : tensor<1x1x1x?xf32> +} + +// ----- + // CHECK-LABEL: test_reshape_6d_down_s2s_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x3x5x7x11xf32> // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>