Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector][gpu] Use makeArithReduction in lowering patterns. NFC. #75952

Merged
merged 1 commit into from
Dec 20, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Dec 19, 2023

Use the vector::makeArithReduction helper as the source-of-truth of reduction to arith ops lowering.

Use the `vector::makeArithReduction` helper as the source-of-truth of
reduction to arith ops lowering.
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-vector

Author: Jakub Kuderski (kuhar)

Changes

Use the vector::makeArithReduction helper as the source-of-truth of reduction to arith ops lowering.


Full diff: https://github.com/llvm/llvm-project/pull/75952.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp (+34-42)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (+2-62)
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32b..a9f903e696dfb1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -16,15 +16,44 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 
 namespace {
 
+static vector::CombiningKind
+convertReductionKind(gpu::AllReduceOperation mode) {
+  switch (mode) {
+#define MAP_CASE(X)                                                            \
+  case gpu::AllReduceOperation::X:                                             \
+    return vector::CombiningKind::X
+
+    MAP_CASE(ADD);
+    MAP_CASE(MUL);
+    MAP_CASE(MINUI);
+    MAP_CASE(MINSI);
+    MAP_CASE(MINF);
+    MAP_CASE(MAXSI);
+    MAP_CASE(MAXUI);
+    MAP_CASE(MAXF);
+    MAP_CASE(AND);
+    MAP_CASE(OR);
+    MAP_CASE(XOR);
+    MAP_CASE(MINIMUMF);
+    MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+  }
+
+  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
 struct GpuAllReduceRewriter {
   using AccumulatorFactory = std::function<Value(Value, Value)>;
 
@@ -181,7 +210,7 @@ struct GpuAllReduceRewriter {
   /// block is expected to have 2 arguments. The gpu.yield return the
   /// accumulated value of the same type.
   AccumulatorFactory getFactory(Region &body) {
-    return AccumulatorFactory([&](Value lhs, Value rhs) {
+    return [&body, this](Value lhs, Value rhs) -> Value {
       Block *block = rewriter.getInsertionBlock();
       Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
 
@@ -209,51 +238,14 @@ struct GpuAllReduceRewriter {
       // Return accumulator result.
       rewriter.setInsertionPointToStart(split);
       return split->addArgument(lhs.getType(), lhs.getLoc());
-    });
+    };
   }
 
   /// Returns an accumulator factory that creates an op specified by opName.
   AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
-    using Kind = gpu::AllReduceOperation;
-    bool isFloatingPoint = isa<FloatType>(valueType);
-    switch (opName) {
-    case Kind::ADD:
-      return isFloatingPoint ? getFactory<arith::AddFOp>()
-                             : getFactory<arith::AddIOp>();
-    case Kind::MUL:
-      return isFloatingPoint ? getFactory<arith::MulFOp>()
-                             : getFactory<arith::MulIOp>();
-    case Kind::MINSI:
-      return getFactory<arith::MinSIOp>();
-    case Kind::MINUI:
-      return getFactory<arith::MinUIOp>();
-    case Kind::MINF:
-      return getFactory<arith::MinNumFOp>();
-    case Kind::MAXSI:
-      return getFactory<arith::MaxSIOp>();
-    case Kind::MAXUI:
-      return getFactory<arith::MaxUIOp>();
-    case Kind::MAXF:
-      return getFactory<arith::MaxNumFOp>();
-    case Kind::AND:
-      return getFactory<arith::AndIOp>();
-    case Kind::OR:
-      return getFactory<arith::OrIOp>();
-    case Kind::XOR:
-      return getFactory<arith::XOrIOp>();
-    case Kind::MINIMUMF:
-      return getFactory<arith::MinimumFOp>();
-    case Kind::MAXIMUMF:
-      return getFactory<arith::MaximumFOp>();
-    }
-    llvm_unreachable("unknown GPU AllReduceOperation");
-  }
-
-  /// Returns an accumulator factory that creates an op of type T.
-  template <typename T>
-  AccumulatorFactory getFactory() {
-    return [this](Value lhs, Value rhs) {
-      return create<T>(lhs.getType(), lhs, rhs);
+    return [opName, this](Value lhs, Value rhs) {
+      return vector::makeArithReduction(rewriter, loc,
+                                        convertReductionKind(opName), lhs, rhs);
     };
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index ef6e6f5264a221..c3ae7e74693cdd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -38,66 +38,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are : add, mul, min (signed/unsigned),
-/// max(signed/unsigned), and, or, xor. The supported float
-/// operations are : add, mul, min and max.
-static Value genOperator(Location loc, Value x, Value y,
-                         vector::CombiningKind kind,
-                         PatternRewriter &rewriter) {
-  using vector::CombiningKind;
-
-  auto elType = cast<VectorType>(x.getType()).getElementType();
-  bool isInt = elType.isIntOrIndex();
-
-  Value combinedResult{nullptr};
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (isInt)
-      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
-    break;
-  case CombiningKind::MUL:
-    if (isInt)
-      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
-    break;
-  case CombiningKind::MINUI:
-    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINSI:
-    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXUI:
-    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXSI:
-    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
-    break;
-  case CombiningKind::AND:
-    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
-    break;
-  case CombiningKind::OR:
-    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
-    break;
-  case CombiningKind::XOR:
-    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINF:
-  case CombiningKind::MINIMUMF:
-    combinedResult = rewriter.create<arith::MinimumFOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXF:
-  case CombiningKind::MAXIMUMF:
-    combinedResult = rewriter.create<arith::MaximumFOp>(loc, x, y);
-    break;
-  }
-  return combinedResult;
-}
-
 /// This function checks to see if the vector combining kind
 /// is consistent with the integer or float element type.
 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
@@ -224,8 +164,8 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
         }
       } else {
         Value y = inclusive ? input : lastInput;
-        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
-        assert(output != nullptr);
+        output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
+                                            lastOutput, y);
       }
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, output, result, offsets, strides);

@kuhar kuhar merged commit 9f74e6e into llvm:main Dec 20, 2023
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants