[mlir][Vector] Add support for masked vector gather ops
This patch adds support for masked vector.gather ops using the
vector.mask representation. It includes the MaskableOpInterface
implementation, Linalg vectorizer support, and lowering to LLVM.
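
For illustration, a minimal sketch of the two forms involved, mirroring the new test added below (value names such as %mask, %base, %indices are illustrative): the masked representation wraps a regular vector.gather in a vector.mask region, and the lowering folds the outer mask into the gather's mask operand.

  // Masked representation: the gather executes under the vector.mask's mask.
  %0 = vector.mask %mask { vector.gather %base[%c0] [%indices], %all_true, %pass_thru
         : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> }
       : vector<4xi1> -> vector<4xf32>

  // After mask lowering: %mask becomes the gather's own mask operand.
  %0 = vector.gather %base[%c0] [%indices], %mask, %pass_thru
         : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>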

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143939
dcaballe committed Feb 15, 2023
1 parent 9452356 commit 1ac874c
Showing 5 changed files with 97 additions and 30 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1846,7 +1846,7 @@ def Vector_MaskedStoreOp :
 }
 
 def Vector_GatherOp :
-  Vector_Op<"gather">,
+  Vector_Op<"gather", [DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOf<[AnyInteger, Index]>:$index_vec,
58 changes: 31 additions & 27 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -416,8 +416,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                             value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -532,14 +531,16 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
-vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                VectorizationState &state,
+                                                Operation *op,
+                                                LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
@@ -597,32 +598,33 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
 ///
 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
 ///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
-static Value
-calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
-                      const IRMapping &bvm,
-                      const SmallVectorImpl<int64_t> &targetShape) {
+static Value calculateGatherOffset(RewriterBase &rewriter,
+                                   tensor::ExtractOp extractOp,
+                                   const IRMapping &bvm,
+                                   const ArrayRef<int64_t> targetShape) {
   // The vector of indices for GatherOp should be shaped as the output vector
-  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
   auto loc = extractOp.getLoc();
 
-  Value offset = b.create<vector::BroadcastOp>(
-      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+  Value offset = broadcastIfNeeded(
+      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
     auto dimSize = broadcastIfNeeded(
-        b,
-        b.create<arith::ConstantIndexOp>(
+        rewriter,
+        rewriter.create<arith::ConstantIndexOp>(
             loc,
             extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
         indexVecType.getShape());
 
-    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
 
-    auto extractOpIndex = broadcastIfNeeded(
-        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
+    auto extractOpIndex =
+        broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
+                          indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
+    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
@@ -632,17 +634,16 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
-                                                  Operation *op,
-                                                  LinalgOp linalgOp,
-                                                  const IRMapping &bvm) {
+static VectorizationResult
+vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
+                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = state.getCanonicalVecShape();
 
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
@@ -662,9 +663,10 @@ static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
   Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
   // Generate the gather load
-  auto gatherOp = rewriter.create<vector::GatherOp>(
+  Operation *gatherOp = rewriter.create<vector::GatherOp>(
       loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);
+  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
@@ -904,14 +906,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgIndex(rewriter, op, linalgOp);
+    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
+    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
@@ -1007,8 +1009,10 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
     return failure();
 
   if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
+  }
 
   SmallVector<CustomVectorizationPrecondition> customPreconditions;
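
As a hedged sketch of what the vectorizer change above enables (the IR below is illustrative, not taken from the patch; #map, %src, %idx, %init, and the vector values are hypothetical names): a linalg.generic whose body performs a data-dependent tensor.extract is vectorized into a vector.gather, which state.maskOperation can now wrap in a vector.mask region.

  // Hypothetical input: each output element loads %src at a dynamic index.
  #map = affine_map<(d0) -> (d0)>
  %r = linalg.generic
         {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
         ins(%idx : tensor<4xindex>) outs(%init : tensor<4xf32>) {
    ^bb0(%i: index, %out: f32):
      %v = tensor.extract %src[%i] : tensor<64xf32>
      linalg.yield %v : f32
  } -> tensor<4xf32>

  // After masked vectorization, the extract becomes a gather under vector.mask:
  %g = vector.mask %m { vector.gather %src[%c0] [%offsets], %all_true, %zeroes
         : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> }
       : vector<4xi1> -> vector<4xf32>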
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4597,6 +4597,16 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type GatherOp::getExpectedMaskType() {
+  auto vecType = this->getIndexVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
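
Note that the expected mask type is derived from the index vector, not from the base: a gather produces one element per index, so the mask must share the index vector's shape with an i1 element type. A small illustrative example (the types and names here are assumptions, not from the patch):

  // vector<8xindex> indices -> the op expects a vector<8xi1> mask.
  %g = vector.gather %base[%c0] [%idxs], %mask, %pass_thru
         : memref<64xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>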
28 changes: 26 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -109,6 +110,29 @@ struct MaskedTransferWriteOpPattern
   }
 };
 
+/// Lowers a masked `vector.gather` operation.
+struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
+public:
+  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    Value passthru = maskingOp.hasPassthru()
+                         ? maskingOp.getPassthru()
+                         : rewriter.create<arith::ConstantOp>(
+                               gatherOp.getLoc(),
+                               rewriter.getZeroAttr(gatherOp.getVectorType()));
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<GatherOp>(
+        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
+        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        passthru);
+    return success();
+  }
+};
+
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -136,8 +160,8 @@ struct LowerVectorMaskPass
 /// not its nested `MaskableOpInterface`.
 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
+               MaskedGatherOpPattern>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
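
One detail of the pattern worth calling out: when the vector.mask op carries no passthru, the lowering materializes a zero splat of the gather's result type, so masked-off lanes yield zeros. Sketched on illustrative IR (names %m, %t, %idxs are assumptions):

  // Before: vector.mask without a passthru operand.
  %0 = vector.mask %m { vector.gather %t[%c0] [%idxs], %all_true, %cst
         : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> }
       : vector<4xi1> -> vector<4xf32>

  // After: %m replaces the gather's mask; a dense<0.0> splat becomes the passthru.
  %zero = arith.constant dense<0.000000e+00> : vector<4xf32>
  %0 = vector.gather %t[%c0] [%idxs], %m, %zero
         : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>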
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -48,3 +48,32 @@ func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf32>
 // CHECK:           return %[[VAL_4]] : tensor<?xf32>
 // CHECK:         }
 
+// -----
+
+func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c3 = arith.constant 3 : index
+  %0 = vector.create_mask %c3 : vector<4xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_1 = arith.constant dense<true> : vector<4xi1>
+  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %c0_4 = arith.constant 0 : index
+  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
+  return %3 : tensor<3xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_gather(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<64xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
+// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+