Skip to content

Commit

Permalink
[mlir][Vector] Fold InsertStridedSliceOp of two splat with the same i…
Browse files Browse the repository at this point in the history
…nput to splat.

This patch folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D128891
  • Loading branch information
jacquesguan authored and jacquesguan committed Jul 1, 2022
1 parent 2ceb9c3 commit 91ab4d4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :

let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Vector_OuterProductOp :
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
return success();
}

namespace {
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto srcSplatOp =
insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
auto destSplatOp =
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

if (!srcSplatOp || !destSplatOp)
return failure();

if (srcSplatOp.getInput() != destSplatOp.getInput())
return failure();

rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
return success();
}
};
} // namespace

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldInsertStridedSliceSplat>(context);
}

OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getSourceVectorType() == getDestVectorType())
return getSource();
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Expand Up @@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
%1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
return %1 : vector<4x16xi16>
}

// -----

// CHECK-LABEL: @insert_strided_slice_splat
// CHECK-SAME: (%[[ARG:.*]]: f32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
%splat0 = vector.splat %x : vector<4x4xf32>
%splat1 = vector.splat %x : vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
}

0 comments on commit 91ab4d4

Please sign in to comment.