Skip to content

Commit

Permalink
[mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)
Browse files Browse the repository at this point in the history

This is a follow-on to D158753, and allows the lowering of a
transfer read/write of n-D vectors with a single trailing scalable dimension
to primitive vector ops.

The final conversion to LLVM depends on D158517 and D158752; without
these patches, type conversion will fail (or an assert will be hit in the
LLVM backend) if the final IR contains an array of scalable vectors.

This patch adds `transform.apply_patterns.vector.lower_create_mask`
which allows the lowering of vector.create_mask/constant_mask to be
tested independently of --convert-vector-to-llvm.

Reviewed By: c-rhodes, awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D159482
  • Loading branch information
MacDue committed Sep 11, 2023
1 parent 6bf923d commit ccef726
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def ApplyLowerContractionPatternsOp : Op<Transform_Dialect,
}];
}

// Transform-dialect op named `apply_patterns.vector.lower_create_mask` within
// the dialect. It declares the PatternDescriptorOpInterface methods (so a C++
// definition of those methods must be provided) and its assembly format is
// just the attribute dictionary (`attr-dict`); no operands or results are
// declared here.
def ApplyLowerCreateMaskPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_create_mask",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector create_mask-like operations should be lowered to
finer-grained vector primitives.
}];

let assemblyFormat = "attr-dict";
}

def ApplyLowerMasksPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_masks",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
vector::populateVectorReductionToContractPatterns(patterns);
}

/// Pattern-population hook for ApplyLowerCreateMaskPatternsOp: hands
/// `patterns` to vector::populateVectorMaskOpLoweringPatterns, so whatever
/// that helper registers is the pattern set this transform op contributes.
void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorMaskOpLoweringPatterns(patterns);
}

void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransferDropUnitDimsPatterns(patterns);
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
return rewriter.notifyMatchFailure(
op, "0-D and 1-D vectors are handled separately");

if (dstType.getScalableDims().front())
return rewriter.notifyMatchFailure(
op, "Cannot unroll leading scalable dim in dstType");

auto loc = op.getLoc();
auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
Value idx = op.getOperand(0);

VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
Value falseVal = rewriter.create<arith::ConstantOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ struct TransferReadToVectorLoadLowering
vectorShape.end());
for (unsigned i : broadcastedDims)
unbroadcastedVectorShape[i] = 1;
VectorType unbroadcastedVectorType = VectorType::get(
VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
unbroadcastedVectorShape, read.getVectorType().getElementType());

// `vector.load` supports vector types as memref's elements only when the
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,28 @@ func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5x

// -----

// Masked, in-bounds vector.transfer_read of a scalable 1-D vector
// (vector<[4]xf32>) from a memref<1x?xf32>. The checks below expect an
// llvm.intr.masked.load whose mask, pass-through and result types all remain
// scalable, with the zero padding value materialised as the pass-through.
// CHECK-LABEL: func @transfer_read_1d_scalable_mask
// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: return %[[r]] : vector<[4]xf32>
func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
return %vec : vector<[4]xf32>
}

// -----
// Masked, in-bounds vector.transfer_write of a scalable 1-D vector
// (vector<[4]xf32>) into a memref<1x?xf32>. The check below expects an
// llvm.intr.masked.store whose value and mask types remain scalable.
// CHECK-LABEL: func @transfer_write_1d_scalable_mask
// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
return
}

// -----

func.func @genbool_0d_f() -> vector<i1> {
%0 = vector.constant_mask [0] : vector<i1>
return %0 : vector<i1>
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s

/// A 2-D create_mask with a fixed leading dim (3) and a scalable trailing dim
/// ([4]). The expected output unrolls the leading dim: a single 1-D scalable
/// create_mask is inserted into rows 0 and 1, and an all-false 1-D mask is
/// inserted into row 2, because the leading mask operand is the constant 2.

// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable(
// CHECK-SAME: %[[arg:.*]]: index) -> vector<3x[4]xi1> {
// CHECK-NEXT: %[[zero_mask_1d:.*]] = arith.constant dense<false> : vector<[4]xi1>
// CHECK-NEXT: %[[zero_mask_2d:.*]] = arith.constant dense<false> : vector<3x[4]xi1>
// CHECK-NEXT: %[[create_mask_1d:.*]] = vector.create_mask %[[arg]] : vector<[4]xi1>
// CHECK-NEXT: %[[res_0:.*]] = vector.insert %[[create_mask_1d]], %[[zero_mask_2d]] [0] : vector<[4]xi1> into vector<3x[4]xi1>
// CHECK-NEXT: %[[res_1:.*]] = vector.insert %[[create_mask_1d]], %[[res_0]] [1] : vector<[4]xi1> into vector<3x[4]xi1>
// CHECK-NEXT: %[[res_2:.*]] = vector.insert %[[zero_mask_1d]], %[[res_1]] [2] : vector<[4]xi1> into vector<3x[4]xi1>
// CHECK-NEXT: return %[[res_2]] : vector<3x[4]xi1>
func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<3x[4]xi1> {
// Leading-dim bound: only the first 2 of the 3 rows are enabled.
%c2 = arith.constant 2 : index
%mask = vector.create_mask %c2, %a : vector<3x[4]xi1>
return %mask : vector<3x[4]xi1>
}

// -----

/// The following cannot be lowered as the current lowering requires unrolling
/// the leading dim, and here the leading dim ([4]) is the scalable one.
/// The check below therefore expects the 2-D vector.create_mask to still be
/// present in the output with its original vector<[4]x4xi1> type.

// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable(
// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> {
// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1>
func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> {
%c1 = arith.constant 1 : index
%mask = vector.create_mask %a, %c1 : vector<[4]x4xi1>
return %mask : vector<[4]x4xi1>
}

/// Transform script driving this test: it matches the `func.func` ops in the
/// payload module and applies only the
/// `transform.apply_patterns.vector.lower_create_mask` pattern set to them.
transform.sequence failures(suppress) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op

transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_create_mask
} : !transform.any_op
}

0 comments on commit ccef726

Please sign in to comment.