Skip to content

Commit

Permalink
[mlir][Vector] Fold transpose splat to splat with transposed type.
Browse files Browse the repository at this point in the history
This revision folds transpose splat to a new splat with the transposed vector type. For a splat, there is no need to actually do transpose for it, it would be more effective to just build a new splat as the result.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123765
  • Loading branch information
jacquesguan authored and jacquesguan committed Apr 18, 2022
1 parent c105bcb commit 5479044
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
21 changes: 20 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -4397,11 +4397,30 @@ struct FoldTransposedScalarBroadcast final
}
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern<TransposeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
if (!splatOp)
return failure();

rewriter.replaceOpWithNewOp<vector::SplatOp>(
transposeOp, transposeOp.getResultType(), splatOp.getInput());
return success();
}
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
results
.add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
context);
}

void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Expand Up @@ -1483,6 +1483,17 @@ func @transpose_splat_constant() -> vector<8x4xf32> {
return %0 : vector<8x4xf32>
}

// CHECK-LABEL: func @transpose_splat2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: }
func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
%splat = vector.splat %arg : vector<4x3xf32>
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func @insert_element_fold
Expand Down
Expand Up @@ -281,22 +281,22 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index

// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
%mask0 = vector.splat %m : vector<7x14xi1>
%0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>

// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1>
%mask1 = vector.splat %m : vector<14x16xi1>
%1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>

// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1>
%mask2 = vector.splat %m : vector<7x14xi1>
%2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>

Expand Down Expand Up @@ -328,17 +328,20 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
// CHECK-LABEL: func @transfer_write_permutations
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[ARG2:.*]]: vector<7x14x8x16xf32>
// CHECK-SAME: %[[ARG3:.*]]: vector<8x16xf32>
// CHECK-SAME: %[[M:.*]]: i1
func @transfer_write_permutations(
%arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
%v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index

// CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
%mask0 = vector.splat %m : vector<7x14x8x16xi1>
%0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>

vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
Expand Down

0 comments on commit 5479044

Please sign in to comment.