4646#include " stablehlo/dialect/BroadcastUtils.h"
4747#include " stablehlo/dialect/ChloOps.h"
4848#include " stablehlo/dialect/StablehloOps.h"
49+ #include " stablehlo/transforms/ChloDecompositionUtils.h"
4950#include " stablehlo/transforms/PassUtils.h"
5051#include " stablehlo/transforms/Passes.h"
5152
@@ -462,8 +463,7 @@ struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
462463
463464template <typename FTy>
464465static Value materializeChebyshevPolynomialApproximation (
465- ConversionPatternRewriter &rewriter, Location loc, Value x,
466- ArrayRef<FTy> coefficients) {
466+ OpBuilder &rewriter, Location loc, Value x, ArrayRef<FTy> coefficients) {
467467 Value b0 = getConstantLike (rewriter, loc, 0.0 , x);
468468 Value b1 = getConstantLike (rewriter, loc, 0.0 , x);
469469 Value b2 = getConstantLike (rewriter, loc, 0.0 , x);
@@ -483,9 +483,10 @@ static Value materializeChebyshevPolynomialApproximation(
483483}
484484
485485template <typename FTy>
486- static Value materializeBesselI1eApproximation (
487- ConversionPatternRewriter &rewriter, Location loc, Value x,
488- ArrayRef<FTy> kI1eCoeffsA , ArrayRef<FTy> kI1eCoeffsB ) {
486+ static Value materializeBesselI1eApproximation (OpBuilder &rewriter,
487+ Location loc, Value x,
488+ ArrayRef<FTy> kI1eCoeffsA ,
489+ ArrayRef<FTy> kI1eCoeffsB ) {
489490 Value z = rewriter.create <mlir::stablehlo::AbsOp>(loc, x);
490491 Value half = getConstantLike (rewriter, loc, 0.5 , x);
491492 Value two = getConstantLike (rewriter, loc, 2.0 , x);
@@ -515,8 +516,8 @@ static Value materializeBesselI1eApproximation(
515516 loc, rewriter.create <mlir::stablehlo::SignOp>(loc, x), select);
516517}
517518
518- Value materializeBesselI1eApproximationF32 (ConversionPatternRewriter &rewriter,
519- Location loc, ValueRange args) {
519+ Value materializeBesselI1eApproximationF32 (OpBuilder &rewriter, Location loc ,
520+ ValueRange args) {
520521 Value x = args.front ();
521522 assert (cast<ShapedType>(x.getType ()).getElementType ().isF32 () &&
522523 " expect f32 element type" );
@@ -541,8 +542,9 @@ Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
541542 kI1eCoeffsB );
542543}
543544
544- static Value materializeBesselI1eApproximationF64 (
545- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
545+ static Value materializeBesselI1eApproximationF64 (OpBuilder &rewriter,
546+ Location loc,
547+ ValueRange args) {
546548 Value x = args.front ();
547549 assert (cast<ShapedType>(x.getType ()).getElementType ().isF64 () &&
548550 " expect f64 element type" );
@@ -586,8 +588,8 @@ static Value materializeBesselI1eApproximationF64(
586588static Value materializeWithUpcast (ConversionPatternRewriter &rewriter,
587589 Location loc, ValueRange args,
588590 FloatType minPrecisionTy,
589- Value callback (ConversionPatternRewriter & ,
590- Location, ValueRange)) {
591+ Value callback (OpBuilder &, Location ,
592+ ValueRange)) {
591593 Type originalTy = getElementTypeOrSelf (args.front ().getType ());
592594 auto floatOriginalTy = dyn_cast<FloatType>(originalTy);
593595 bool needsUpcast =
@@ -645,9 +647,9 @@ struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
645647};
646648
647649template <typename FTy>
648- static Value materializePolynomialApproximation (
649- ConversionPatternRewriter &rewriter, Location loc, Value x,
650- ArrayRef<FTy> coefficients) {
650+ static Value materializePolynomialApproximation (OpBuilder &rewriter,
651+ Location loc, Value x,
652+ ArrayRef<FTy> coefficients) {
651653 if (coefficients.empty ()) return getConstantLike (rewriter, loc, 0.0 , x);
652654
653655 Value poly = getConstantLike (rewriter, loc, coefficients[0 ], x);
@@ -836,7 +838,7 @@ static Value materializeErfcApproximationF64(
836838// argument and derive the final approximation for all |x| >= 1.
837839// This implementation is based on Cephes.
838840static Value materializeErfcApproximationF32ForMagnitudeGeOne (
839- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
841+ OpBuilder &rewriter, Location loc, ValueRange args) {
840842 Value x = args.front ();
841843 assert (cast<ShapedType>(x.getType ()).getElementType ().isF32 () &&
842844 " expect f32 element type" );
@@ -902,7 +904,7 @@ static Value materializeErfcApproximationF32ForMagnitudeGeOne(
902904// Precondition is |x| <= 1. Use erfc approximation, otherwise.
903905// This implementation is based on Cephes.
904906static Value materializeErfApproximationF32ForMagnitudeLeOne (
905- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
907+ OpBuilder &rewriter, Location loc, ValueRange args) {
906908 Value x = args.front ();
907909 assert (cast<ShapedType>(x.getType ()).getElementType ().isF32 () &&
908910 " expect f32 element type" );
@@ -921,8 +923,8 @@ static Value materializeErfApproximationF32ForMagnitudeLeOne(
921923}
922924
923925// This is the same approximation as used in Eigen.
924- static Value materializeErfApproximationF32 (ConversionPatternRewriter &rewriter,
925- Location loc, ValueRange args) {
926+ static Value materializeErfApproximationF32 (OpBuilder &rewriter, Location loc ,
927+ ValueRange args) {
926928 Value x = args.front ();
927929 assert (cast<ShapedType>(x.getType ()).getElementType ().isF32 () &&
928930 " expect f32 element type" );
@@ -958,8 +960,8 @@ static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
958960 erf, ubErf);
959961}
960962
961- static Value materializeErfcApproximationF32 (
962- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
963+ static Value materializeErfcApproximationF32 (OpBuilder &rewriter, Location loc,
964+ ValueRange args) {
963965 Value x = args.front ();
964966 assert (cast<ShapedType>(x.getType ()).getElementType ().isF32 () &&
965967 " expect f32 element type" );
@@ -1041,8 +1043,7 @@ struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
10411043 }
10421044};
10431045
1044- static Value erfInv32 (ConversionPatternRewriter &b, Location loc,
1045- ValueRange args) {
1046+ static Value erfInv32 (OpBuilder &b, Location loc, ValueRange args) {
10461047 constexpr int kDegree = 9 ;
10471048 constexpr std::array<float , 9 > wLessThan5Constants = {
10481049 2 .81022636e-08f , 3 .43273939e-07f , -3 .5233877e-06f ,
@@ -1248,6 +1249,8 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
12481249 12.507343278686904814458936853 , -0.13857109526572011689554707 ,
12491250 9.984369578019570859563e-6 , 1.50563273514931155834e-7 };
12501251
1252+ } // namespace
1253+
12511254// Compute the Lgamma function using Lanczos' approximation from "A Precision
12521255// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
12531256// series B. Vol. 1:
@@ -1257,8 +1260,7 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
12571260// with t(z) = z + kLanczosGamma + 1/2
12581261// a(z) = kBaseLanczosCoeff
12591262// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
1260- static Value materializeLgamma (ConversionPatternRewriter &rewriter,
1261- Location loc, ValueRange args) {
1263+ Value materializeLgamma (OpBuilder &rewriter, Location loc, ValueRange args) {
12621264 // If the input is less than 0.5 use Euler's reflection formula.
12631265 // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
12641266 // Let z be
@@ -1393,6 +1395,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
13931395 getConstantLikeInfValue (rewriter, loc, x, /* negative=*/ false ), lgamma);
13941396}
13951397
1398+ namespace {
1399+
13961400// Express `cosh` as
13971401// cosh(x) = (e^x + e^-x) / 2
13981402// = e^(x + log(1/2)) + e^(-x + log(1/2))
@@ -1403,8 +1407,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
14031407// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
14041408// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
14051409// we deem this acceptable.
1406- static Value materializeCoshApproximation (ConversionPatternRewriter &rewriter,
1407- Location loc, ValueRange operands) {
1410+ static Value materializeCoshApproximation (OpBuilder &rewriter, Location loc ,
1411+ ValueRange operands) {
14081412 mlir::chlo::CoshOp::Adaptor transformed (operands);
14091413 Value x = transformed.getOperand ();
14101414
@@ -1431,6 +1435,8 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
14311435 }
14321436};
14331437
1438+ } // namespace
1439+
14341440// Compute the Digamma function using Lanczos' approximation from "A Precision
14351441// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
14361442// series B. Vol. 1:
@@ -1439,8 +1445,7 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
14391445// a(z) = kBaseLanczosCoeff
14401446// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
14411447// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
1442- static Value materializeDigamma (ConversionPatternRewriter &rewriter,
1443- Location loc, ValueRange args) {
1448+ Value materializeDigamma (OpBuilder &rewriter, Location loc, ValueRange args) {
14441449 // If the input is less than 0.5 use Euler's reflection formula.
14451450 // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
14461451 // Let z be
@@ -1545,14 +1550,16 @@ static Value materializeDigamma(ConversionPatternRewriter &rewriter,
15451550 digamma);
15461551}
15471552
1553+ namespace {
1554+
15481555static Value getConstantLikeSmallestFiniteValue (OpBuilder &b, Location loc,
15491556 Value val) {
15501557 auto ty = cast<FloatType>(getElementTypeOrSelf (val.getType ()));
15511558 return getConstantLike (
15521559 b, loc, llvm::APFloat::getSmallest (ty.getFloatSemantics ()), val);
15531560}
15541561
1555- static Value materializeZeta (ConversionPatternRewriter &rewriter, Location loc,
1562+ static Value materializeZeta (OpBuilder &rewriter, Location loc,
15561563 ValueRange args) {
15571564 // Implementation ported from:
15581565 // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917
@@ -1703,8 +1710,9 @@ static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
17031710 return output;
17041711}
17051712
1706- static Value materializePolygamma (ConversionPatternRewriter &rewriter,
1707- Location loc, ValueRange args) {
1713+ } // namespace
1714+
1715+ Value materializePolygamma (OpBuilder &rewriter, Location loc, ValueRange args) {
17081716 mlir::chlo::PolygammaOp::Adaptor transformed (args);
17091717 Value n = transformed.getN ();
17101718 Value x = transformed.getX ();
@@ -1747,6 +1755,8 @@ static Value materializePolygamma(ConversionPatternRewriter &rewriter,
17471755 result);
17481756}
17491757
1758+ namespace {
1759+
17501760struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
17511761 using OpConversionPattern::OpConversionPattern;
17521762
@@ -1901,8 +1911,9 @@ struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
19011911// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
19021912// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
19031913// we deem this acceptable.
1904- static Value materializeSinhApproximationForLargeX (
1905- ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) {
1914+ static Value materializeSinhApproximationForLargeX (OpBuilder &rewriter,
1915+ Location loc,
1916+ ValueRange operands) {
19061917 mlir::chlo::SinhOp::Adaptor transformed (operands);
19071918 Value x = transformed.getOperand ();
19081919
@@ -1918,8 +1929,8 @@ static Value materializeSinhApproximationForLargeX(
19181929// Express `sinh` as
19191930// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
19201931// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
1921- static Value materializeSinhApproximation (ConversionPatternRewriter &rewriter,
1922- Location loc, ValueRange operands) {
1932+ static Value materializeSinhApproximation (OpBuilder &rewriter, Location loc ,
1933+ ValueRange operands) {
19231934 Value largeSinhResult =
19241935 materializeSinhApproximationForLargeX (rewriter, loc, operands);
19251936
0 commit comments