diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index f0f17c6adcb08..ea7a556297a76 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -48,6 +48,25 @@ struct MathPolynomialApproximationOptions { void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns); void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns); +// Adds patterns to convert to f32 around math functions for which `predicate` +// returns true. +void populateMathF32ExpansionPatterns( + RewritePatternSet &patterns, llvm::function_ref predicate); + +// Adds patterns to enable polynomial approximations for math functions for +// which `predicate` returns true. +void populateMathPolynomialApproximationPatterns( + RewritePatternSet &patterns, llvm::function_ref predicate); + +// Legacy. Calls both populateMathF32ExpansionPatterns and +// populateMathPolynomialApproximationPatterns with predicates enabling a +// certain set of math function rewrites, that probably can't be changed for +// compatibility reasons. Notice that unlike +// populateMathPolynomialApproximationPatterns(patterns, predicate), this +// overload also calls populateMathF32ExpansionPatterns. +// Prefer calling these functions directly: +// * populateMathF32ExpansionPatterns(patterns, predicate) +// * populateMathPolynomialApproximationPatterns(patterns, predicate) void populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options = {}); diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 24c892f68b503..777427de9465c 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern( patterns.add(patterns.getContext()); } +template +static void +populateMathF32ExpansionPattern(RewritePatternSet &patterns, + llvm::function_ref predicate) { + if (predicate(OpType::getOperationName())) { + patterns.add>(patterns.getContext()); + } +} + +void mlir::populateMathF32ExpansionPatterns( + RewritePatternSet &patterns, + llvm::function_ref predicate) { + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); +} + +template +static void populateMathPolynomialApproximationPattern( + RewritePatternSet &patterns, + llvm::function_ref predicate) { + if (predicate(OpType::getOperationName())) { + patterns.add(patterns.getContext()); + } +} + +void mlir::populateMathPolynomialApproximationPatterns( + RewritePatternSet &patterns, + llvm::function_ref predicate) { + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern< + CosOp, SinAndCosApproximation>(patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); + populateMathPolynomialApproximationPattern< + SinOp, SinAndCosApproximation>(patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); +} + void mlir::populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { - // Patterns for leveraging existing f32 lowerings on other data types. - patterns - .add, ReuseF32Expansion, - ReuseF32Expansion, ReuseF32Expansion, - ReuseF32Expansion, ReuseF32Expansion, - ReuseF32Expansion, ReuseF32Expansion, - ReuseF32Expansion, ReuseF32Expansion, - ReuseF32Expansion, ReuseF32Expansion>( - patterns.getContext()); - - patterns - .add, - SinAndCosApproximation>(patterns.getContext()); + mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) -> bool { + return llvm::is_contained( + {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(), + math::TanhOp::getOperationName(), math::LogOp::getOperationName(), + math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(), + math::ErfOp::getOperationName(), math::ExpOp::getOperationName(), + math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(), + math::SinOp::getOperationName(), math::CosOp::getOperationName()}, + name); + }); + + populateMathPolynomialApproximationPatterns( + patterns, [](StringRef name) -> bool { + return llvm::is_contained( + {math::AtanOp::getOperationName(), + math::Atan2Op::getOperationName(), + math::TanhOp::getOperationName(), math::LogOp::getOperationName(), + math::Log2Op::getOperationName(), + math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(), + math::AsinOp::getOperationName(), math::AcosOp::getOperationName(), + math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(), + math::CbrtOp::getOperationName(), math::SinOp::getOperationName(), + math::CosOp::getOperationName()}, + name); + }); + if (options.enableAvx2) { - patterns.add>( - patterns.getContext()); + auto predicateRsqrt = [](StringRef name) { + return name == math::RsqrtOp::getOperationName(); + }; + mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt); + mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt); } }