diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index 56370388dea87..cfd8c4b8f11f7 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// ClampFOp +//===----------------------------------------------------------------------===// + +def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> { + let summary = "floating point clamping operation"; + let description = [{ + The `clampf` operation takes three operands and returns one result, each of + these is required to be the same type. Operands must be of floating point type + (i.e., scalar, tensor or vector). + + The semantics of the operation are described by: + ``` + clampf(value, min, max) = maxf(minf(value, min), max) + ``` + + Example: + + ```mlir + %d = math.clampf %value to [%min, %max] : f64 + ``` + }]; + let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max, + DefaultValuedAttr:$fastmath); + let assemblyFormat = [{ + $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)? + attr-dict `:` type($result) + }]; +} + //===----------------------------------------------------------------------===// // CopySignOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index c0fe5d3be448a..b3abbf728a3c6 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -23,22 +23,16 @@ class ConversionTarget; class RewritePatternSet; class TypeConverter; -void populateExpandCtlzPattern(RewritePatternSet &patterns); -void populateExpandTanPattern(RewritePatternSet &patterns); -void populateExpandSinhPattern(RewritePatternSet &patterns); -void populateExpandCoshPattern(RewritePatternSet &patterns); -void populateExpandTanhPattern(RewritePatternSet &patterns); -void populateExpandAsinhPattern(RewritePatternSet &patterns); -void populateExpandAcoshPattern(RewritePatternSet &patterns); -void populateExpandAtanhPattern(RewritePatternSet &patterns); -void populateExpandFmaFPattern(RewritePatternSet &patterns); -void populateExpandCeilFPattern(RewritePatternSet &patterns); -void populateExpandExp2FPattern(RewritePatternSet &patterns); -void populateExpandPowFPattern(RewritePatternSet &patterns); -void populateExpandFPowIPattern(RewritePatternSet &patterns); -void populateExpandRoundFPattern(RewritePatternSet &patterns); -void populateExpandRoundEvenPattern(RewritePatternSet &patterns); -void populateExpandRsqrtPattern(RewritePatternSet &patterns); +namespace math { +/// Adds patterns to expand math operations into other more fundamental +/// operations. For example, hyperbolic functions are expanded into expressions +/// using `exp`. If `opMnemonics` is empty then all available patterns will be +/// added, otherwise only the patterns corresponding to ops in `opMnemonics` +/// will be added to the set. +void populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef opMnemonics = {}); +} // namespace math + void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); struct MathPolynomialApproximationOptions { diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index a84c89020d4f3..4d415aeac8f58 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> { let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; } +def MathExpandOpsPass : Pass<"math-expand-ops"> { + let summary = "Expand math operations."; + let description = [{ + Expands some math operations into more fundamental operations, allowing them + to be subsequently lowered through these. For example, hyperbolic functions + are transformed into their expanded form containing only `exp` functions. + + The `ops` parameter can be used to apply only a subset of all the + available expansions, these must correspond to the operation mnemonic. + For example, `ops=sinh,acosh` will expand only `math.sinh` and + `math.acosh` operations. If the list is empty, then all expansions are + applied. + }]; + let dependentDialects = ["arith::ArithDialect"]; + let options = [ + ListOption<"opMnemonics", "ops", "std::string", + "Operations to expand."> + ]; +} + #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index e1c0c2410c126..d37a056e8e158 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp - ExpandPatterns.cpp + ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp similarity index 89% rename from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp rename to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp index 4a40a3055ed62..cd68039d0d964 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp @@ -13,14 +13,18 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +namespace mlir::math { +#define GEN_PASS_DEF_MATHEXPANDOPSPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + /// Create a float constant. static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b) { @@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, return success(); } -void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { - patterns.add(convertCtlzOp); -} - -void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { - patterns.add(convertSinhOp); -} - -void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { - patterns.add(convertCoshOp); -} - -void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { - patterns.add(convertTanOp); -} - -void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { - patterns.add(convertTanhOp); -} - -void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { - patterns.add(convertAsinhOp); -} - -void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { - patterns.add(convertAcoshOp); -} - -void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { - patterns.add(convertAtanhOp); -} - -void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { - patterns.add(convertFmaFOp); -} - -void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { - patterns.add(convertCeilOp); -} - -void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { - patterns.add(convertExp2fOp); -} - -void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertPowfOp); -} - -void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { - patterns.add(convertFPowIOp); -} - -void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundOp); +// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf` +static LogicalResult convertClampfOp(math::ClampFOp op, + PatternRewriter &rewriter) { + auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(), + op.getMin(), op.getFastmath()); + rewriter.replaceOpWithNewOp(op, minOp, op.getMax(), + op.getFastmath()); + return success(); } -void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundEvenOp); +void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef opMnemonics) { + auto filter = [&](StringRef name) { + // This should be a static assert and `consume_front` take a twine, but none + // is currently possible. TODO: augment `StringRef::consume_front` and make + // `getDialectNamespace` use `std::string_view`. + assert("math" == MathDialect::getDialectNamespace()); + name.consume_front("math."); + return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0); + }; + if (filter(CountLeadingZerosOp::getOperationName())) + patterns.add(convertCtlzOp); + if (filter(SinhOp::getOperationName())) + patterns.add(convertSinhOp); + if (filter(CoshOp::getOperationName())) + patterns.add(convertCoshOp); + if (filter(TanOp::getOperationName())) + patterns.add(convertTanOp); + if (filter(TanhOp::getOperationName())) + patterns.add(convertTanhOp); + if (filter(AsinhOp::getOperationName())) + patterns.add(convertAsinhOp); + if (filter(AcoshOp::getOperationName())) + patterns.add(convertAcoshOp); + if (filter(AtanhOp::getOperationName())) + patterns.add(convertAtanhOp); + if (filter(FmaOp::getOperationName())) + patterns.add(convertFmaFOp); + if (filter(CeilOp::getOperationName())) + patterns.add(convertCeilOp); + if (filter(Exp2Op::getOperationName())) + patterns.add(convertExp2fOp); + if (filter(PowFOp::getOperationName())) + patterns.add(convertPowfOp); + if (filter(FPowIOp::getOperationName())) + patterns.add(convertFPowIOp); + if (filter(RoundOp::getOperationName())) + patterns.add(convertRoundOp); + if (filter(RoundEvenOp::getOperationName())) + patterns.add(convertRoundEvenOp); + if (filter(RsqrtOp::getOperationName())) + patterns.add(convertRsqrtOp); + if (filter(ClampFOp::getOperationName())) + patterns.add(convertClampfOp); } -void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { - patterns.add(convertRsqrtOp); -} +//===----------------------------------------------------------------------===// +// MathExpandOpsPass pass +//===----------------------------------------------------------------------===// +namespace { +struct MathExpandOpsPass final + : math::impl::MathExpandOpsPassBase { + using MathExpandOpsPassBase::MathExpandOpsPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SmallVector mnemonics = + llvm::to_vector_of(opMnemonics); + math::populateExpansionPatterns(patterns, mnemonics); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 1420acaa40d35..615c607efc3c3 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -1,7 +1,9 @@ -// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER // CHECK-LABEL: func @tanh func.func @tanh(%arg: f32) -> f32 { + // CHECK-FILTER-NOT: math.tanh %res = math.tanh %arg : f32 return %res : f32 } @@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 { // CHECK-LABEL: func @vector_tanh func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> { // CHECK-NOT: math.tanh + // CHECK-FILTER-NOT: math.tanh %res = math.tanh %arg : vector<4xf32> return %res : vector<4xf32> } @@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> { // CHECK-LABEL: func @tan func.func @tan(%arg: f32) -> f32 { + // CHECK-FILTER-NOT: math.tan %res = math.tan %arg : f32 return %res : f32 } @@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 { // CHECK-LABEL: func @vector_tan func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> { + // CHECK-FILTER-NOT: math.tan %res = math.tan %arg : vector<4xf32> return %res : vector<4xf32> } @@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func.func @ctlz(%arg: i32) -> i32 { + // CHECK-FILTER: math.ctlz %res = math.ctlz %arg : i32 return %res : i32 } @@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 { // ----- func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> { + // CHECK-FILTER: math.ctlz %res = math.ctlz %arg : vector<4xi32> return %res : vector<4xi32> } @@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 { // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]] // CHECK-NEXT: return [[ADDF]] + // CHECK-FILTER: math.ceil %ret = math.ceil %a : f64 return %ret : f64 } @@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 { // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]] // CHECK: [[EXP:%.+]] = math.exp [[MULF]] // CHECK: return [[EXP]] + // CHECK-FILTER: math.exp2 %ret = math.exp2 %a : f64 return %ret : f64 } @@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{ %a = math.rsqrt %arg : tensor<*xf32> return %a: tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @clampf_scalar_op +// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16) +// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16 +// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16 +// CHECK: return %[[V1]] : f16 + +func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 { + %a = math.clampf %arg to [%min, %max] : f16 + return %a: f16 +} + +// CHECK-LABEL: func.func @clampf_vector_op +// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>) +// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath : vector<3x4xf32> +// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath : vector<3x4xf32> +// CHECK: return %[[V1]] : vector<3x4xf32> + +func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{ + %a = math.clampf %arg to [%min, %max] fastmath : vector<3x4xf32> + return %a: vector<3x4xf32> +} diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index 8feadedd1860e..cb10fc4397ffc 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s // CHECK-LABEL: func @atan( @@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>) math.isnormal %t : tensor<4x?xf32> return } + +// CHECK-LABEL: func @clampf( +func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>, + %as: f32, %ms: f32, %Ms: f32, + %at: tensor, %mt: tensor, %Mt: tensor) { + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath : vector<3x4xf32> + %rv = math.clampf %av to [%mv, %Mv] fastmath : vector<3x4xf32> + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32 + %rs = math.clampf %as to [%ms, %Ms] fastmath : f32 + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor + %rt = math.clampf %at to [%mt, %Mt] : tensor + return +} diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt index 91e70d1785369..900dff3b5e9f1 100644 --- a/mlir/test/lib/Dialect/Math/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMathTestPasses TestAlgebraicSimplification.cpp - TestExpandMath.cpp TestPolynomialApproximation.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp deleted file mode 100644 index efc1acf8bb6cd..0000000000000 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===- TestExpandMath.cpp - Test expand math op into exp form -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains test passes for expanding math operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestExpandMathPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass) - - void runOnOperation() override; - StringRef getArgument() const final { return "test-expand-math"; } - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - StringRef getDescription() const final { return "Test expanding math"; } -}; -} // namespace - -void TestExpandMathPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateExpandCtlzPattern(patterns); - populateExpandExp2FPattern(patterns); - populateExpandTanPattern(patterns); - populateExpandSinhPattern(patterns); - populateExpandCoshPattern(patterns); - populateExpandTanhPattern(patterns); - populateExpandAsinhPattern(patterns); - populateExpandAcoshPattern(patterns); - populateExpandAtanhPattern(patterns); - populateExpandFmaFPattern(patterns); - populateExpandCeilFPattern(patterns); - populateExpandPowFPattern(patterns); - populateExpandFPowIPattern(patterns); - populateExpandRoundFPattern(patterns); - populateExpandRoundEvenPattern(patterns); - populateExpandRsqrtPattern(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); -} - -namespace mlir { -namespace test { -void registerTestExpandMathPass() { PassRegistration(); } -} // namespace test -} // namespace mlir diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index b599c9d8435d4..3f9d3f2125e1a 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \ // RUN: | mlir-runner \ // RUN: -e main -entry-point-result=void -O0 \ // RUN: -shared-libs=%mlir_c_runner_utils \ diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 14714c452503a..7b992b4ee029b 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -98,7 +98,6 @@ void registerTestDiagnosticsMetadataPass(); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestEmulateNarrowTypePass(); -void registerTestExpandMathPass(); void registerTestFooAnalysisPass(); void registerTestComposeSubView(); void registerTestMultiBuffering(); @@ -245,7 +244,6 @@ void registerTestPasses() { mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); mlir::test::registerTestEmulateNarrowTypePass(); - mlir::test::registerTestExpandMathPass(); mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering();