From 001d08f6b27c41ebb8f19af2b0e3eca9bdf341c6 Mon Sep 17 00:00:00 2001 From: Ubuntu <450283+lialan@users.noreply.github.com> Date: Sun, 17 Nov 2024 01:22:35 +0000 Subject: [PATCH] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on non-static index Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash. This patch is to make sure it return gracefully instead of crashing. --- .../Dialect/Vector/Transforms/VectorTransforms.cpp | 11 +++++------ mlir/test/Dialect/Vector/vector-transforms.mlir | 13 +++++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 7f6b2303f86e1..20cd9cba6909a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -596,12 +596,11 @@ struct BubbleDownVectorBitCastForExtract unsigned expandRatio = castDstType.getNumElements() / castSrcType.getNumElements(); - auto getFirstIntValue = [](ArrayRef values) -> uint64_t { - assert(values[0].is() && "Unexpected non-constant index"); - return cast(values[0].get()).getInt(); - }; - - uint64_t index = getFirstIntValue(extractOp.getMixedPosition()); + // Get the first element of the mixed position as integer. + auto mixedPos = extractOp.getMixedPosition(); + if (mixedPos.size() > 0 && !mixedPos[0].is()) + return failure(); + uint64_t index = cast(mixedPos[0].get()).getInt(); // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir index 89e8ca1d93109..de12a87253a67 100644 --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector) -> vector { %0 = vector.bitcast %arg0 : vector to vector return %0 : vector } + +// Make sure not crash on dynamic index `vector.extract`: +func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 { + %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16> + %1 = vector.extract %0[%index] : i16 from vector<8xi16> + return %1 : i16 +} + +// CHECK-LABEL: func.func @vector_extract_dynamic_index +// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 { +// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16> +// CHECK: return %[[EXTRACT]]