-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add pattern to break down reductions into arith ops #75727
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) ChangesThe number of vector elements considered 'small' is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Also update the Patch is 22.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75727.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds = false);
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value acc, Value mask = Value());
+ Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath = nullptr,
+ Value mask = nullptr);
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to be simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add> %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+ RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
zeroIdx);
}
- Value result = vector::makeArithReduction(
- rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+ Value result =
+ vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+ cast, /*fastmath=*/nullptr, mask);
rewriter.replaceOp(rootOp, result);
return success();
}
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, acc, mask);
+ result, acc,
+ reductionOp.getFastmathAttr(), mask);
rewriter.replaceOp(rootOp, result);
return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath,
Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for ADD reduction");
break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
break;
case CombiningKind::MAXF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MAXIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MINF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MINIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for MUL reduction");
break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc,
+ /*fastmath=*/nullptr, mask);
}
/// Return the positions of the reductions in the given map.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..143360079916a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
@@ -44,6 +45,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,50 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
}
};
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+ BreakDownVectorReduction(MLIRContext *context,
+ unsigned maxNumElementsToExtract,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit),
+ maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+ LogicalResult matchAndRewrite(vector::ReductionOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType type = op.getSourceVectorType();
+ if (type.isScalable() || op.isMasked())
+ return failure();
+ assert(type.getRank() == 1 && "Expected a 1-d vector");
+
+ int64_t numElems = type.getNumElements();
+ if (numElems > maxNumElementsToExtract) {
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("has too many vector elements ({0}) to break down "
+ "(max allowed: {1})",
+ numElems, maxNumElementsToExtract));
+ }
+
+ Location loc = op.getLoc();
+ SmallVector<Value> extracted(numElems, nullptr);
+ for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+ extractedElem = rewriter.create<vector::ExtractOp>(
+ loc, op.getVector(), static_cast<int64_t>(idx));
+
+ Value res = extracted.front();
+ for (auto extractedElem : llvm::drop_begin(extracted))
+ res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+ extractedElem, op.getFastmathAttr());
+ if (Value acc = op.getAcc())
+ res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+ op.getFastmathAttr());
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ unsigned maxNumElementsToExtract = 0;
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1702,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
PatternBenefit(benefit.getBenefit() + 1));
}
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+ RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+ PatternBenefit benefit) {
+ patterns.add<BreakDownVectorReduction>(patterns.getContext(),
+ maxNumElementsToExtract, benefit);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..b7bc19594491bd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL: func.func @reduce_2x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_fp32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+ %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+ %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+ %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+ %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+ return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+ %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+ %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+ %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+ %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+ %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+ %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+ %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: return %[[E0]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<1xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_acc_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_1x_acc_fp32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_acc_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT: return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_acc_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_fp32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+ return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @reduce_3x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<3xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pattern itself is LGTM, but I'm not very familiar with this minnum/minimum buisiness.
BTW, I forgot to add to the PR description that I plan to use it to remove most of the vector reduction to SPIR-V lowering patterns, only leaving the specialized ones that we can expand to vector dot product. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % some comments, thanks!
Propagate fast math flags. Distinguish `minf`/`maxf` and `minimumf`/`maximumf`. Required for future patterns in llvm#75727.
1434d6a
to
0c2a497
Compare
Split off the Added example as suggested by @banach-space. |
Propagate fast math flags. Distinguish `minf`/`maxf` and `minimumf`/`maximumf`. Required for future patterns in #75727.
The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on llvm#75846. Please enter the commit message for your changes. Lines starting
0c2a497
to
111ead9
Compare
Propagate fast math flags. Distinguish `minf`/`maxf` and `minimumf`/`maximumf`. Required for future patterns in llvm/llvm-project#75727.
The number of vector elements considered 'small' enough to extract is
parameterized.
This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.
Depends on #75846.