diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 57c02c9a35ba3..6f68f83ed05f9 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp : let hasFolder = 1; let hasVerifier = 1; + let hasCanonicalizer = 1; } def Vector_OuterProductOp : diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c3b4d1e13de47..3edd23fef6242 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() { return success(); } +namespace { +/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type, +/// SplatOp(X):dst_type) to SplatOp(X):dst_type. +class FoldInsertStridedSliceSplat final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, + PatternRewriter &rewriter) const override { + auto srcSplatOp = + insertStridedSliceOp.getSource().getDefiningOp(); + auto destSplatOp = + insertStridedSliceOp.getDest().getDefiningOp(); + + if (!srcSplatOp || !destSplatOp) + return failure(); + + if (srcSplatOp.getInput() != destSplatOp.getInput()) + return failure(); + + rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); + return success(); + } +}; +} // namespace + +void vector::InsertStridedSliceOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add(context); +} + OpFoldResult InsertStridedSliceOp::fold(ArrayRef operands) { if (getSourceVectorType() == getDestVectorType()) return getSource(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index d16a6fa2c7e11..7f50d90380452 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> { %1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16> return %1 : vector<4x16xi16> } + +// ----- + +// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32> +// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> +func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.splat %x : vector<4x4xf32> + %splat1 = vector.splat %x : vector<8x16xf32> + %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} + : vector<4x4xf32> into vector<8x16xf32> + return %0 : vector<8x16xf32> +}