-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector][gpu] Use makeArithReduction
in lowering patterns. NFC.
#75952
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Use the `vector::makeArithReduction` helper as the source-of-truth of reduction to arith ops lowering.
kuhar
requested review from
antiagainst,
banach-space,
Hardcode84,
qedawkins and
unterumarmung
December 19, 2023 16:54
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Jakub Kuderski (kuhar) ChangesUse the Full diff: https://github.com/llvm/llvm-project/pull/75952.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32b..a9f903e696dfb1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -16,15 +16,44 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
namespace {
+static vector::CombiningKind
+convertReductionKind(gpu::AllReduceOperation mode) {
+ switch (mode) {
+#define MAP_CASE(X) \
+ case gpu::AllReduceOperation::X: \
+ return vector::CombiningKind::X
+
+ MAP_CASE(ADD);
+ MAP_CASE(MUL);
+ MAP_CASE(MINUI);
+ MAP_CASE(MINSI);
+ MAP_CASE(MINF);
+ MAP_CASE(MAXSI);
+ MAP_CASE(MAXUI);
+ MAP_CASE(MAXF);
+ MAP_CASE(AND);
+ MAP_CASE(OR);
+ MAP_CASE(XOR);
+ MAP_CASE(MINIMUMF);
+ MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+ }
+
+ llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
@@ -181,7 +210,7 @@ struct GpuAllReduceRewriter {
/// block is expected to have 2 arguments. The gpu.yield return the
/// accumulated value of the same type.
AccumulatorFactory getFactory(Region &body) {
- return AccumulatorFactory([&](Value lhs, Value rhs) {
+ return [&body, this](Value lhs, Value rhs) -> Value {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
@@ -209,51 +238,14 @@ struct GpuAllReduceRewriter {
// Return accumulator result.
rewriter.setInsertionPointToStart(split);
return split->addArgument(lhs.getType(), lhs.getLoc());
- });
+ };
}
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
- using Kind = gpu::AllReduceOperation;
- bool isFloatingPoint = isa<FloatType>(valueType);
- switch (opName) {
- case Kind::ADD:
- return isFloatingPoint ? getFactory<arith::AddFOp>()
- : getFactory<arith::AddIOp>();
- case Kind::MUL:
- return isFloatingPoint ? getFactory<arith::MulFOp>()
- : getFactory<arith::MulIOp>();
- case Kind::MINSI:
- return getFactory<arith::MinSIOp>();
- case Kind::MINUI:
- return getFactory<arith::MinUIOp>();
- case Kind::MINF:
- return getFactory<arith::MinNumFOp>();
- case Kind::MAXSI:
- return getFactory<arith::MaxSIOp>();
- case Kind::MAXUI:
- return getFactory<arith::MaxUIOp>();
- case Kind::MAXF:
- return getFactory<arith::MaxNumFOp>();
- case Kind::AND:
- return getFactory<arith::AndIOp>();
- case Kind::OR:
- return getFactory<arith::OrIOp>();
- case Kind::XOR:
- return getFactory<arith::XOrIOp>();
- case Kind::MINIMUMF:
- return getFactory<arith::MinimumFOp>();
- case Kind::MAXIMUMF:
- return getFactory<arith::MaximumFOp>();
- }
- llvm_unreachable("unknown GPU AllReduceOperation");
- }
-
- /// Returns an accumulator factory that creates an op of type T.
- template <typename T>
- AccumulatorFactory getFactory() {
- return [this](Value lhs, Value rhs) {
- return create<T>(lhs.getType(), lhs, rhs);
+ return [opName, this](Value lhs, Value rhs) {
+ return vector::makeArithReduction(rewriter, loc,
+ convertReductionKind(opName), lhs, rhs);
};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index ef6e6f5264a221..c3ae7e74693cdd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -38,66 +38,6 @@
using namespace mlir;
using namespace mlir::vector;
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are : add, mul, min (signed/unsigned),
-/// max(signed/unsigned), and, or, xor. The supported float
-/// operations are : add, mul, min and max.
-static Value genOperator(Location loc, Value x, Value y,
- vector::CombiningKind kind,
- PatternRewriter &rewriter) {
- using vector::CombiningKind;
-
- auto elType = cast<VectorType>(x.getType()).getElementType();
- bool isInt = elType.isIntOrIndex();
-
- Value combinedResult{nullptr};
- switch (kind) {
- case CombiningKind::ADD:
- if (isInt)
- combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
- break;
- case CombiningKind::MUL:
- if (isInt)
- combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
- break;
- case CombiningKind::MINUI:
- combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
- break;
- case CombiningKind::MINSI:
- combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
- break;
- case CombiningKind::MAXUI:
- combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
- break;
- case CombiningKind::MAXSI:
- combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
- break;
- case CombiningKind::AND:
- combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
- break;
- case CombiningKind::OR:
- combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
- break;
- case CombiningKind::XOR:
- combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
- break;
- case CombiningKind::MINF:
- case CombiningKind::MINIMUMF:
- combinedResult = rewriter.create<arith::MinimumFOp>(loc, x, y);
- break;
- case CombiningKind::MAXF:
- case CombiningKind::MAXIMUMF:
- combinedResult = rewriter.create<arith::MaximumFOp>(loc, x, y);
- break;
- }
- return combinedResult;
-}
-
/// This function checks to see if the vector combining kind
/// is consistent with the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
@@ -224,8 +164,8 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
}
} else {
Value y = inclusive ? input : lastInput;
- output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
- assert(output != nullptr);
+ output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
+ lastOutput, y);
}
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, output, result, offsets, strides);
|
antiagainst
approved these changes
Dec 19, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Use the
vector::makeArithReduction
helper as the source-of-truth of reduction to arith ops lowering.