Skip to content

Commit 38fe0f4

Browse files
authored
Expose materialize functions in Chlo to Stablehlo lowering (#2665)
We want to perform constant propogation through `chlo.lgamma` in Enzyme-JaX [Kevin](EnzymeAD/Enzyme-JAX#182 (comment)) mentioned he was open to exposing some materialize functions (which are currently static, and not callable from [our pass](https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp) atm) @wsmoses @GleasonK
1 parent 2f6be83 commit 38fe0f4

File tree

3 files changed

+83
-35
lines changed

3 files changed

+83
-35
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,7 @@ cc_library(
11331133
"stablehlo/transforms/VhloToVersion.cpp",
11341134
],
11351135
hdrs = [
1136+
"stablehlo/transforms/ChloDecompositionUtils.h",
11361137
"stablehlo/transforms/MapStablehloToVhlo.h",
11371138
"stablehlo/transforms/PassUtils.h",
11381139
"stablehlo/transforms/Passes.h",
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License.
11+
==============================================================================*/
12+
13+
#ifndef STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_
14+
#define STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_
15+
16+
#include "mlir/IR/Value.h"
17+
#include "mlir/IR/ValueRange.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
namespace mlir {
21+
namespace stablehlo {
22+
23+
// Utility functions used in the Chlo to stablehlo legalization.
24+
25+
Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args);
26+
27+
Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args);
28+
29+
Value materializeZeta(OpBuilder &rewriter, Location loc, ValueRange args);
30+
31+
Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args);
32+
33+
} // namespace stablehlo
34+
} // namespace mlir
35+
36+
#endif // STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_

stablehlo/transforms/ChloLegalizeToStablehlo.cpp

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "stablehlo/dialect/BroadcastUtils.h"
4747
#include "stablehlo/dialect/ChloOps.h"
4848
#include "stablehlo/dialect/StablehloOps.h"
49+
#include "stablehlo/transforms/ChloDecompositionUtils.h"
4950
#include "stablehlo/transforms/PassUtils.h"
5051
#include "stablehlo/transforms/Passes.h"
5152

@@ -462,8 +463,7 @@ struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
462463

463464
template <typename FTy>
464465
static Value materializeChebyshevPolynomialApproximation(
465-
ConversionPatternRewriter &rewriter, Location loc, Value x,
466-
ArrayRef<FTy> coefficients) {
466+
OpBuilder &rewriter, Location loc, Value x, ArrayRef<FTy> coefficients) {
467467
Value b0 = getConstantLike(rewriter, loc, 0.0, x);
468468
Value b1 = getConstantLike(rewriter, loc, 0.0, x);
469469
Value b2 = getConstantLike(rewriter, loc, 0.0, x);
@@ -483,9 +483,10 @@ static Value materializeChebyshevPolynomialApproximation(
483483
}
484484

485485
template <typename FTy>
486-
static Value materializeBesselI1eApproximation(
487-
ConversionPatternRewriter &rewriter, Location loc, Value x,
488-
ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) {
486+
static Value materializeBesselI1eApproximation(OpBuilder &rewriter,
487+
Location loc, Value x,
488+
ArrayRef<FTy> kI1eCoeffsA,
489+
ArrayRef<FTy> kI1eCoeffsB) {
489490
Value z = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
490491
Value half = getConstantLike(rewriter, loc, 0.5, x);
491492
Value two = getConstantLike(rewriter, loc, 2.0, x);
@@ -515,8 +516,8 @@ static Value materializeBesselI1eApproximation(
515516
loc, rewriter.create<mlir::stablehlo::SignOp>(loc, x), select);
516517
}
517518

518-
Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
519-
Location loc, ValueRange args) {
519+
Value materializeBesselI1eApproximationF32(OpBuilder &rewriter, Location loc,
520+
ValueRange args) {
520521
Value x = args.front();
521522
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
522523
"expect f32 element type");
@@ -541,8 +542,9 @@ Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
541542
kI1eCoeffsB);
542543
}
543544

544-
static Value materializeBesselI1eApproximationF64(
545-
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
545+
static Value materializeBesselI1eApproximationF64(OpBuilder &rewriter,
546+
Location loc,
547+
ValueRange args) {
546548
Value x = args.front();
547549
assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
548550
"expect f64 element type");
@@ -586,8 +588,8 @@ static Value materializeBesselI1eApproximationF64(
586588
static Value materializeWithUpcast(ConversionPatternRewriter &rewriter,
587589
Location loc, ValueRange args,
588590
FloatType minPrecisionTy,
589-
Value callback(ConversionPatternRewriter &,
590-
Location, ValueRange)) {
591+
Value callback(OpBuilder &, Location,
592+
ValueRange)) {
591593
Type originalTy = getElementTypeOrSelf(args.front().getType());
592594
auto floatOriginalTy = dyn_cast<FloatType>(originalTy);
593595
bool needsUpcast =
@@ -645,9 +647,9 @@ struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
645647
};
646648

647649
template <typename FTy>
648-
static Value materializePolynomialApproximation(
649-
ConversionPatternRewriter &rewriter, Location loc, Value x,
650-
ArrayRef<FTy> coefficients) {
650+
static Value materializePolynomialApproximation(OpBuilder &rewriter,
651+
Location loc, Value x,
652+
ArrayRef<FTy> coefficients) {
651653
if (coefficients.empty()) return getConstantLike(rewriter, loc, 0.0, x);
652654

653655
Value poly = getConstantLike(rewriter, loc, coefficients[0], x);
@@ -836,7 +838,7 @@ static Value materializeErfcApproximationF64(
836838
// argument and derive the final approximation for all |x| >= 1.
837839
// This implementation is based on Cephes.
838840
static Value materializeErfcApproximationF32ForMagnitudeGeOne(
839-
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
841+
OpBuilder &rewriter, Location loc, ValueRange args) {
840842
Value x = args.front();
841843
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
842844
"expect f32 element type");
@@ -902,7 +904,7 @@ static Value materializeErfcApproximationF32ForMagnitudeGeOne(
902904
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
903905
// This implementation is based on Cephes.
904906
static Value materializeErfApproximationF32ForMagnitudeLeOne(
905-
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
907+
OpBuilder &rewriter, Location loc, ValueRange args) {
906908
Value x = args.front();
907909
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
908910
"expect f32 element type");
@@ -921,8 +923,8 @@ static Value materializeErfApproximationF32ForMagnitudeLeOne(
921923
}
922924

923925
// This is the same approximation as used in Eigen.
924-
static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
925-
Location loc, ValueRange args) {
926+
static Value materializeErfApproximationF32(OpBuilder &rewriter, Location loc,
927+
ValueRange args) {
926928
Value x = args.front();
927929
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
928930
"expect f32 element type");
@@ -958,8 +960,8 @@ static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
958960
erf, ubErf);
959961
}
960962

961-
static Value materializeErfcApproximationF32(
962-
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
963+
static Value materializeErfcApproximationF32(OpBuilder &rewriter, Location loc,
964+
ValueRange args) {
963965
Value x = args.front();
964966
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
965967
"expect f32 element type");
@@ -1041,8 +1043,7 @@ struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
10411043
}
10421044
};
10431045

1044-
static Value erfInv32(ConversionPatternRewriter &b, Location loc,
1045-
ValueRange args) {
1046+
static Value erfInv32(OpBuilder &b, Location loc, ValueRange args) {
10461047
constexpr int kDegree = 9;
10471048
constexpr std::array<float, 9> wLessThan5Constants = {
10481049
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
@@ -1248,6 +1249,8 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
12481249
12.507343278686904814458936853, -0.13857109526572011689554707,
12491250
9.984369578019570859563e-6, 1.50563273514931155834e-7};
12501251

1252+
} // namespace
1253+
12511254
// Compute the Lgamma function using Lanczos' approximation from "A Precision
12521255
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
12531256
// series B. Vol. 1:
@@ -1257,8 +1260,7 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
12571260
// with t(z) = z + kLanczosGamma + 1/2
12581261
// a(z) = kBaseLanczosCoeff
12591262
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
1260-
static Value materializeLgamma(ConversionPatternRewriter &rewriter,
1261-
Location loc, ValueRange args) {
1263+
Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args) {
12621264
// If the input is less than 0.5 use Euler's reflection formula.
12631265
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
12641266
// Let z be
@@ -1393,6 +1395,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
13931395
getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
13941396
}
13951397

1398+
namespace {
1399+
13961400
// Express `cosh` as
13971401
// cosh(x) = (e^x + e^-x) / 2
13981402
// = e^(x + log(1/2)) + e^(-x + log(1/2))
@@ -1403,8 +1407,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
14031407
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
14041408
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
14051409
// we deem this acceptable.
1406-
static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
1407-
Location loc, ValueRange operands) {
1410+
static Value materializeCoshApproximation(OpBuilder &rewriter, Location loc,
1411+
ValueRange operands) {
14081412
mlir::chlo::CoshOp::Adaptor transformed(operands);
14091413
Value x = transformed.getOperand();
14101414

@@ -1431,6 +1435,8 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
14311435
}
14321436
};
14331437

1438+
} // namespace
1439+
14341440
// Compute the Digamma function using Lanczos' approximation from "A Precision
14351441
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
14361442
// series B. Vol. 1:
@@ -1439,8 +1445,7 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
14391445
// a(z) = kBaseLanczosCoeff
14401446
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
14411447
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
1442-
static Value materializeDigamma(ConversionPatternRewriter &rewriter,
1443-
Location loc, ValueRange args) {
1448+
Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args) {
14441449
// If the input is less than 0.5 use Euler's reflection formula.
14451450
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
14461451
// Let z be
@@ -1545,14 +1550,16 @@ static Value materializeDigamma(ConversionPatternRewriter &rewriter,
15451550
digamma);
15461551
}
15471552

1553+
namespace {
1554+
15481555
static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
15491556
Value val) {
15501557
auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
15511558
return getConstantLike(
15521559
b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
15531560
}
15541561

1555-
static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
1562+
static Value materializeZeta(OpBuilder &rewriter, Location loc,
15561563
ValueRange args) {
15571564
// Implementation ported from:
15581565
// https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917
@@ -1703,8 +1710,9 @@ static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
17031710
return output;
17041711
}
17051712

1706-
static Value materializePolygamma(ConversionPatternRewriter &rewriter,
1707-
Location loc, ValueRange args) {
1713+
} // namespace
1714+
1715+
Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args) {
17081716
mlir::chlo::PolygammaOp::Adaptor transformed(args);
17091717
Value n = transformed.getN();
17101718
Value x = transformed.getX();
@@ -1747,6 +1755,8 @@ static Value materializePolygamma(ConversionPatternRewriter &rewriter,
17471755
result);
17481756
}
17491757

1758+
namespace {
1759+
17501760
struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
17511761
using OpConversionPattern::OpConversionPattern;
17521762

@@ -1901,8 +1911,9 @@ struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
19011911
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
19021912
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
19031913
// we deem this acceptable.
1904-
static Value materializeSinhApproximationForLargeX(
1905-
ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) {
1914+
static Value materializeSinhApproximationForLargeX(OpBuilder &rewriter,
1915+
Location loc,
1916+
ValueRange operands) {
19061917
mlir::chlo::SinhOp::Adaptor transformed(operands);
19071918
Value x = transformed.getOperand();
19081919

@@ -1918,8 +1929,8 @@ static Value materializeSinhApproximationForLargeX(
19181929
// Express `sinh` as
19191930
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
19201931
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
1921-
static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
1922-
Location loc, ValueRange operands) {
1932+
static Value materializeSinhApproximation(OpBuilder &rewriter, Location loc,
1933+
ValueRange operands) {
19231934
Value largeSinhResult =
19241935
materializeSinhApproximationForLargeX(rewriter, loc, operands);
19251936

0 commit comments

Comments
 (0)