Skip to content

Commit

Permalink
[mlir] [VectorOps] Improve vector.create_mask lowering
Browse files Browse the repository at this point in the history
Use vector compares for the 1-D case. This approach scales much better
than generating insertion operations, and exposes SIMD directly to backend.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D82402
  • Loading branch information
aartbik committed Jun 23, 2020
1 parent 433c9ad commit 55d09df
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
42 changes: 27 additions & 15 deletions mlir/lib/Dialect/Vector/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();

if (rank == 1) {
// Express constant 1-D case in explicit vector form:
// [T,..,T,F,..,F].
SmallVector<bool, 4> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
Expand Down Expand Up @@ -1364,8 +1366,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
/// %1 = select %0, %l, %zeroes |
/// %r = vector.insert %1, %pr [i] | d-times
/// %x = ....
/// When rank == 1, the selection operator is not needed,
/// and we can assign the true/false value right away.
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
Expand All @@ -1375,30 +1376,41 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
auto loc = op.getLoc();
auto dstType = op.getResult().getType().cast<VectorType>();
auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
int64_t rank = dstType.getRank();
Value idx = op.getOperand(0);

Value trueVal;
Value falseVal;
if (rank > 1) {
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
falseVal = rewriter.create<ConstantOp>(loc, lowType,
rewriter.getZeroAttr(lowType));
if (rank == 1) {
// Express dynamic 1-D case in explicit vector form:
// mask = [0,1,..,n-1] < [a,a,..,a]
SmallVector<int64_t, 4> values(dim);
for (int64_t d = 0; d < dim; d++)
values[d] = d;
Value indices =
rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
Value bound =
rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
bounds);
return success();
}

VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
rewriter.getZeroAttr(lowType));
Value result = rewriter.create<ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) {
for (int64_t d = 0; d < dim; d++) {
Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
if (rank > 1)
val = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, val, result, pos);
rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
}
rewriter.replaceOp(op, result);
return success();
Expand Down
51 changes: 21 additions & 30 deletions mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -710,18 +710,12 @@ func @genbool_3d() -> vector<2x3x4xi1> {
}

// CHECK-LABEL: func @genbool_var_1d
// CHECK-SAME: %[[A:.*0]]: index
// CHECK-DAG: %[[VF:.*]] = constant dense<false> : vector<3xi1>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
// CHECK: return %[[T5]] : vector<3xi1>
// CHECK-SAME: %[[A:.*]]: index
// CHECK: %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
// CHECK: return %[[T2]] : vector<3xi1>

func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
%0 = vector.create_mask %arg0 : vector<3xi1>
Expand All @@ -731,24 +725,21 @@ func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
// CHECK-LABEL: func @genbool_var_2d
// CHECK-SAME: %[[A:.*0]]: index
// CHECK-SAME: %[[B:.*1]]: index
// CHECK-DAG: %[[Z1:.*]] = constant dense<false> : vector<3xi1>
// CHECK-DAG: %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
// CHECK: %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
// CHECK: %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
// CHECK: %[[T8:.*]] = vector.insert %7, %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
// CHECK: %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T11]] : vector<2x3xi1>
// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
// CHECK: %[[CF:.*]] = constant dense<false> : vector<3xi1>
// CHECK: %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c1:.*]] = constant 1 : index
// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
// CHECK: %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T8]] : vector<2x3xi1>

func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
Expand Down

0 comments on commit 55d09df

Please sign in to comment.