diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d8913251e56e9..db199a46e1637 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1705,6 +1705,47 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   return extractOp.getResult();
 }
 
+/// Fold an ExtractOp with a static position whose source is a 1-D ShuffleOp
+/// into an extract from the shuffle operand (V1 or V2) selected by the mask.
+/// Updates `extractOp` in place and returns its result, or returns a null
+/// Value when the fold does not apply.
+///
+///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
+///     : vector<8xf32>, vector<8xf32>
+///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
+/// ->
+///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
+static Value foldExtractFromShuffle(ExtractOp extractOp) {
+  // Dynamic positions are not folded as the resulting code would be more
+  // complex than the input code.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
+  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+  if (!shuffleOp)
+    return Value();
+
+  // TODO: 0-D (two 0-D inputs give a 1-D result) or n-D vectors unsupported.
+  if (shuffleOp.getResultVectorType().getRank() != 1 ||
+      shuffleOp.getV1().getType().getRank() != 1)
+    return Value();
+
+  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
+  int64_t extractIdx = extractOp.getStaticPosition()[0];
+  int64_t shuffleIdx = shuffleOp.getMask()[extractIdx];
+
+  // Mask indices below V1's size select from V1; the rest select from V2.
+  if (shuffleIdx < inputVecSize) {
+    extractOp.setOperand(0, shuffleOp.getV1());
+    extractOp.setStaticPosition({shuffleIdx});
+  } else {
+    extractOp.setOperand(0, shuffleOp.getV2());
+    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
+  }
+
+  return extractOp.getResult();
+}
+
 // Fold extractOp with source coming from ShapeCast op.
 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   // TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1994,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return res;
   if (auto res = foldExtractFromBroadcast(*this))
     return res;
+  if (auto res = foldExtractFromShuffle(*this))
+    return res;
   if (auto res = foldExtractFromShapeCast(*this))
     return res;
   if (auto val = foldExtractFromExtractStrided(*this))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index df87f86765a3a..5ae769090dac6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -740,6 +740,27 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
   %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+// -----
+
+// Static-position extracts from a 1-D vector.shuffle fold into extracts from
+// the shuffle operand selected by the mask, leaving the shuffle unused.
+// CHECK-LABEL: @fold_extract_shuffle
+//  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<8xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract %[[B]][0] : f32 from vector<8xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract %[[A]][7] : f32 from vector<8xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract %[[B]][7] : f32 from vector<8xf32>
+//       CHECK:   return %[[E0]], %[[E1]], %[[E2]], %[[E3]]
+func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
+    -> (f32, f32, f32, f32) {
+  %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
+  %e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
+  %e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
+  %e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
+  %e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
+  return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
+}
 
 // -----
 