-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][VectorToLLVM] Add support for unrolling and lowering multi-dim… #160405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements target-independent rewrites and utilities to lower the | ||
// 'vector.scatter' operation. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Arith/Utils/Utils.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Location.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
|
||
#define DEBUG_TYPE "vector-scatter-lowering" | ||
|
||
using namespace mlir; | ||
using namespace mlir::vector; | ||
|
||
namespace { | ||
|
||
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
/// outermost dimension. For example:
/// ```
///   vector.scatter %base[%c0][%idx], %mask, %value :
///     memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
///
/// ==>
///
///   %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
///   %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
///   %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
///   vector.scatter %base[%c0][%i0], %m0, %v0 :
///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
///
///   %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
///   %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
///   %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
///   vector.scatter %base[%c0][%i1], %m1, %v1 :
///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
///
/// Supports vector types with a fixed leading dimension.
struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Replaces `op` with one scatter per slice of the outermost dimension.
  /// Returns whatever `unrollVectorOp` returns; fails early when the index
  /// operand is not a vector.
  LogicalResult matchAndRewrite(vector::ScatterOp op,
                                PatternRewriter &rewriter) const override {
    Value indexVec = op.getIndices();
    Value maskVec = op.getMask();
    Value valueVec = op.getValueToStore();

    // The shape to unroll over is taken from the index operand and handed to
    // `unrollVectorOp` explicitly rather than being read from an op result.
    VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
    if (!vectorTy)
      return rewriter.notifyMatchFailure(op, "expected vector-typed indices");

    // Emits the scatter for slice `index` of the outermost dimension. `subTy`
    // is part of the callback signature but is not needed here: the slice
    // types follow from the extracted operands.
    auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
                               VectorType subTy, int64_t index) {
      int64_t thisIdx[1] = {index};

      Value indexSubVec =
          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      Value maskSubVec =
          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
      Value valueSubVec =
          vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);

      // Base, offsets and alignment are shared by every slice; only the
      // index, mask and value operands are narrowed.
      vector::ScatterOp::create(rewriter, loc, op.getBase(), op.getOffsets(),
                                indexSubVec, maskSubVec, valueSubVec,
                                op.getAlignmentAttr());

      // Return a dummy value since unrollVectorOp expects a Value.
      return Value();
    };

    return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
  }
};
|
||
} // namespace | ||
|
||
/// Adds the `vector.scatter` unrolling pattern from this file (`UnrollScatter`)
/// to `patterns`, registered with the given `benefit`.
void mlir::vector::populateVectorScatterLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // The pattern is constructed against the context owned by the pattern set.
  auto *context = patterns.getContext();
  patterns.add<UnrollScatter>(context, benefit);
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -431,27 +431,48 @@ vector::unrollVectorValue(TypedValue<VectorType> vector, | |
} | ||
|
||
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, | ||
vector::UnrollVectorOpFn unrollFn) { | ||
assert(op->getNumResults() == 1 && "expected single result"); | ||
assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type"); | ||
VectorType resultTy = cast<VectorType>(op->getResult(0).getType()); | ||
if (resultTy.getRank() < 2) | ||
vector::UnrollVectorOpFn unrollFn, | ||
VectorType vectorTy) { | ||
// If vector type is not provided, get it from the result. | ||
if (!vectorTy) { | ||
assert(op->getNumResults() == 1 && | ||
"expected single result when vector type not provided"); | ||
vectorTy = dyn_cast<VectorType>(op->getResult(0).getType()); | ||
assert(vectorTy && "expected result to have vector type"); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we assert that, conversely, when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But then when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I may have misspoken, but the main point is to prefer assertions in code to verbiage in documentation. |
||
|
||
if (vectorTy.getRank() < 2) | ||
return rewriter.notifyMatchFailure(op, "already 1-D"); | ||
|
||
// Unrolling doesn't take vscale into account. Pattern is disabled for | ||
// vectors with leading scalable dim(s). | ||
if (resultTy.getScalableDims().front()) | ||
if (vectorTy.getScalableDims().front()) | ||
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); | ||
|
||
Location loc = op->getLoc(); | ||
Value result = ub::PoisonOp::create(rewriter, loc, resultTy); | ||
VectorType subTy = VectorType::Builder(resultTy).dropDim(0); | ||
|
||
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { | ||
// Only create result value if the operation produces results. | ||
Value result; | ||
if (op->getNumResults() > 0) { | ||
result = ub::PoisonOp::create(rewriter, loc, vectorTy); | ||
} | ||
|
||
VectorType subTy = VectorType::Builder(vectorTy).dropDim(0); | ||
|
||
for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) { | ||
Value subVector = unrollFn(rewriter, loc, subTy, i); | ||
result = vector::InsertOp::create(rewriter, loc, subVector, result, i); | ||
|
||
// Only insert if we have a result to build. | ||
if (op->getNumResults() > 0) { | ||
result = vector::InsertOp::create(rewriter, loc, subVector, result, i); | ||
} | ||
} | ||
|
||
if (op->getNumResults() > 0) { | ||
rewriter.replaceOp(op, result); | ||
} else { | ||
rewriter.eraseOp(op); | ||
} | ||
|
||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not keen on the proliferation of
populate*Patterns
functions. Is there any scenario where one would want scatter unrolling while not simultaneously wanting gather patterns? If not, I'd suggest merging the two into populateVectorGatherScatterUnrollPatterns
, which also has the benefit of making it clearer that these are unrolling patterns and not all lowering.