Skip to content

Commit

Permalink
[mlir][Vector] Add canonicalization pattern for vector.transpose(vect…
Browse files Browse the repository at this point in the history
…or.constant_mask)

We already had vector.transpose(vector.create_mask) ->
vector.create_mask. This patch adds the constant mask version of it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D147099
  • Loading branch information
dcaballe committed Mar 29, 2023
1 parent e2f1d5c commit 1cd434d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
38 changes: 26 additions & 12 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -5269,23 +5269,37 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TransposeOp transposeOp,
LogicalResult matchAndRewrite(TransposeOp transpOp,
PatternRewriter &rewriter) const override {
auto createMaskOp =
transposeOp.getVector().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
Value transposeSrc = transpOp.getVector();
auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
if (!createMaskOp && !constantMaskOp)
return failure();

// Get the transpose permutation and apply it to the vector.create_mask
// operands.
auto maskOperands = createMaskOp.getOperands();
// Get the transpose permutation and apply it to the vector.create_mask or
// vector.constant_mask operands.
SmallVector<int64_t> permutation;
transposeOp.getTransp(permutation);
SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
applyPermutationToVector(newOperands, permutation);
transpOp.getTransp(permutation);

if (createMaskOp) {
auto maskOperands = createMaskOp.getOperands();
SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
applyPermutationToVector(newOperands, permutation);

rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
transpOp, transpOp.getResultVectorType(), newOperands);
return success();
}

// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
applyPermutationToVector(newMaskDimSizes, permutation);

rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
transposeOp, transposeOp.getResultVectorType(), newOperands);
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
transpOp, transpOp.getResultVectorType(),
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
return success();
}
};
Expand Down
17 changes: 15 additions & 2 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Expand Up @@ -58,15 +58,28 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(
%dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
// CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
// CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
// CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
// CHECK-NOT: vector.transpose
%0 = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
%1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
}

// -----

// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
// CHECK: vector.constant_mask [3, 1, 2] : vector<4x2x3xi1>
// CHECK-NOT: vector.transpose
%0 = vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
%1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
}

// -----

func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
Expand Down

0 comments on commit 1cd434d

Please sign in to comment.