Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add pattern to break down reductions into arith ops #75727

Merged
merged 1 commit into from
Dec 18, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Dec 17, 2023

The number of vector elements considered 'small' enough to extract is
parameterized.

This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.

Depends on #75846.

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 17, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

The number of vector elements considered 'small' is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do.

Also update the vector::makeArithReduction helper to distinguish minf/maxf and minimumf/maximumf, and to propagate fast math flags.


Patch is 22.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75727.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4-2)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+19)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+53)
  • (added) mlir/test/Dialect/Vector/break-down-vector-reduction.mlir (+104)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (+8-8)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+23)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB,
                            bool testDynamicValueUsingBounds = false);
 
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value acc, Value mask = Value());
+                         Value v1, Value acc,
+                         arith::FastMathFlagsAttr fastmath = nullptr,
+                         Value mask = nullptr);
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);
 
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to be simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add> %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
                                                 zeroIdx);
     }
 
-    Value result = vector::makeArithReduction(
-        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    Value result =
+        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+                                   cast, /*fastmath=*/nullptr, mask);
     rewriter.replaceOp(rootOp, result);
     return success();
   }
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc, mask);
+                                          result, acc,
+                                          reductionOp.getFastmathAttr(), mask);
 
     rewriter.replaceOp(rootOp, result);
     return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
+                                       arith::FastMathFlagsAttr fastmath,
                                        Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
   Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for ADD reduction");
     break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
     break;
   case CombiningKind::MAXF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MAXIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MINF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MINIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MAXSI:
     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for MUL reduction");
     break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc,
+                            /*fastmath=*/nullptr, mask);
 }
 
 /// Return the positions of the reductions in the given map.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..143360079916a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <cassert>
 #include <cstdint>
 #include <functional>
 #include <optional>
@@ -44,6 +45,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,50 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+  BreakDownVectorReduction(MLIRContext *context,
+                           unsigned maxNumElementsToExtract,
+                           PatternBenefit benefit)
+      : OpRewritePattern(context, benefit),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType type = op.getSourceVectorType();
+    if (type.isScalable() || op.isMasked())
+      return failure();
+    assert(type.getRank() == 1 && "Expected a 1-d vector");
+
+    int64_t numElems = type.getNumElements();
+    if (numElems > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("has too many vector elements ({0}) to break down "
+                            "(max allowed: {1})",
+                            numElems, maxNumElementsToExtract));
+    }
+
+    Location loc = op.getLoc();
+    SmallVector<Value> extracted(numElems, nullptr);
+    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+      extractedElem = rewriter.create<vector::ExtractOp>(
+          loc, op.getVector(), static_cast<int64_t>(idx));
+
+    Value res = extracted.front();
+    for (auto extractedElem : llvm::drop_begin(extracted))
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+                                       extractedElem, op.getFastmathAttr());
+    if (Value acc = op.getAcc())
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+                                       op.getFastmathAttr());
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxNumElementsToExtract = 0;
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1702,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
                                     PatternBenefit(benefit.getBenefit() + 1));
 }
 
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
+                                         maxNumElementsToExtract, benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..b7bc19594491bd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL:   func.func @reduce_2x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_fp32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+  %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+  %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+  %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+  %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+  return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+  %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+  %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+  %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+  %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+  %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+  %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+  %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     return %[[E0]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<1xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_1x_acc_fp32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT:     return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK:          %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK:          %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK:          %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK:          %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT:     return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_fp32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_3x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<3xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
 //  CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
 //  CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = ...
[truncated]

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern itself is LGTM, but I'm not very familiar with this minnum/minimum buisiness.

@kuhar
Copy link
Member Author

kuhar commented Dec 17, 2023

BTW, I forgot to add to the PR description that I plan to use it to remove most of the vector reduction to SPIR-V lowering patterns, only leaving the specialized ones that we can expand to vector dot product.

@kuhar kuhar changed the title [mlir][vector] Add pattern to break down small reductions into arith ops [mlir][vector] Add pattern to break down reductions into arith ops Dec 18, 2023
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % some comments, thanks!

mlir/lib/Dialect/Vector/IR/VectorOps.cpp Outdated Show resolved Hide resolved
kuhar added a commit to kuhar/llvm-project that referenced this pull request Dec 18, 2023
Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.

Required for future patterns in
llvm#75727.
@kuhar kuhar force-pushed the break-down-vector-reduction branch from 1434d6a to 0c2a497 Compare December 18, 2023 19:47
@kuhar
Copy link
Member Author

kuhar commented Dec 18, 2023

Split off the makeArithReduction change into #75846 and rebased onto it.

Added example as suggested by @banach-space.

kuhar added a commit that referenced this pull request Dec 18, 2023
Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.

Required for future patterns in
#75727.
The number of vector elements considered 'small' enough to extract is
parameterized.

This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.

Depends on llvm#75846.

 Please enter the commit message for your changes. Lines starting
@kuhar kuhar force-pushed the break-down-vector-reduction branch from 0c2a497 to 111ead9 Compare December 18, 2023 22:49
@kuhar kuhar merged commit 0767711 into llvm:main Dec 18, 2023
3 of 4 checks passed
qihangkong pushed a commit to rvgpu/rvgpu-llvm that referenced this pull request Apr 23, 2024
Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.

Required for future patterns in
llvm/llvm-project#75727.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants