diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6f68f83ed05f9..9a29825eda506 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -487,6 +487,7 @@ def Vector_ShuffleOp : }]; let assemblyFormat = "operands $mask attr-dict `:` type(operands)"; let hasVerifier = 1; + let hasCanonicalizer = 1; } def Vector_ExtractElementOp : diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 38f38f8867059..6c1ba2161b83c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1882,6 +1882,36 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef operands) { return DenseElementsAttr::get(getVectorType(), results); } +namespace { + +/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. +class ShuffleSplat final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ShuffleOp op, + PatternRewriter &rewriter) const override { + auto v1Splat = op.getV1().getDefiningOp(); + auto v2Splat = op.getV2().getDefiningOp(); + + if (!v1Splat || !v2Splat) + return failure(); + + if (v1Splat.getInput() != v2Splat.getInput()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), v1Splat.getInput()); + return success(); + } +}; + +} // namespace + +void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 515a2d1726b6f..e7747c736867f 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1655,3 +1655,17 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3 : vector<2x4xf32> into vector<8x16xf32> return %1 : vector<8x16xf32> } + +// ----- + +// CHECK-LABEL: func @shuffle_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> +func.func @shuffle_splat(%x : i32) -> vector<4xi32> { + %v0 = vector.splat %x : vector<4xi32> + %v1 = vector.splat %x : vector<2xi32> + %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> + return %shuffle : vector<4xi32> +} +