|
13 | 13 |
|
14 | 14 | #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
15 | 15 | #include "mlir/Dialect/Math/IR/Math.h" |
| 16 | +#include "mlir/Dialect/Math/Transforms/Approximation.h" |
16 | 17 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
17 | 18 | #include "mlir/Dialect/Vector/VectorOps.h" |
18 | 19 | #include "mlir/Dialect/X86Vector/X86VectorDialect.h" |
|
21 | 22 | #include "mlir/Transforms/Bufferize.h" |
22 | 23 | #include "mlir/Transforms/DialectConversion.h" |
23 | 24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 25 | +#include "llvm/ADT/ArrayRef.h" |
24 | 26 | #include <climits> |
| 27 | +#include <cstddef> |
25 | 28 |
|
26 | 29 | using namespace mlir; |
| 30 | +using namespace mlir::math; |
27 | 31 | using namespace mlir::vector; |
28 | 32 |
|
29 | 33 | using TypePredicate = llvm::function_ref<bool(Type)>; |
@@ -183,6 +187,24 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { |
183 | 187 | return exp2ValueF32; |
184 | 188 | } |
185 | 189 |
|
| 190 | +namespace { |
| 191 | +Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, |
| 192 | + llvm::ArrayRef<Value> coeffs, Value x) { |
| 193 | + auto width = vectorWidth(x.getType(), isF32); |
| 194 | + if (coeffs.size() == 0) { |
| 195 | + return broadcast(builder, f32Cst(builder, 0.0f), *width); |
| 196 | + } else if (coeffs.size() == 1) { |
| 197 | + return coeffs[0]; |
| 198 | + } |
| 199 | + Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], |
| 200 | + coeffs[coeffs.size() - 2]); |
| 201 | + for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { |
| 202 | + res = builder.create<math::FmaOp>(x, res, coeffs[i]); |
| 203 | + } |
| 204 | + return res; |
| 205 | +} |
| 206 | +} // namespace |
| 207 | + |
186 | 208 | //----------------------------------------------------------------------------// |
187 | 209 | // TanhOp approximation. |
188 | 210 | //----------------------------------------------------------------------------// |
@@ -465,6 +487,122 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op, |
465 | 487 | return success(); |
466 | 488 | } |
467 | 489 |
|
| 490 | +//----------------------------------------------------------------------------// |
| 491 | +// Erf approximation. |
| 492 | +//----------------------------------------------------------------------------// |
| 493 | + |
| 494 | +// Approximates erf(x) with |
| 495 | +// a - P(x)/Q(x) |
| 496 | +// where P and Q are polynomials of degree 4. |
| 497 | +// Different coefficients are chosen based on the value of x. |
| 498 | +// The approximation error is ~2.5e-07. |
| 499 | +// Boost's minimax tool that utilizes the Remez method was used to find the |
| 500 | +// coefficients. |
| 501 | +LogicalResult |
| 502 | +ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, |
| 503 | + PatternRewriter &rewriter) const { |
| 504 | + auto width = vectorWidth(op.operand().getType(), isF32); |
| 505 | + if (!width.hasValue()) |
| 506 | + return rewriter.notifyMatchFailure(op, "unsupported operand type"); |
| 507 | + |
| 508 | + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
| 509 | + auto bcast = [&](Value value) -> Value { |
| 510 | + return broadcast(builder, value, *width); |
| 511 | + }; |
| 512 | + |
| 513 | + const int intervalsCount = 3; |
| 514 | + const int polyDegree = 4; |
| 515 | + |
| 516 | + Value zero = bcast(f32Cst(builder, 0)); |
| 517 | + Value one = bcast(f32Cst(builder, 1)); |
| 518 | + Value pp[intervalsCount][polyDegree + 1]; |
| 519 | + pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); |
| 520 | + pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00)); |
| 521 | + pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01)); |
| 522 | + pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01)); |
| 523 | + pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02)); |
| 524 | + pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); |
| 525 | + pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00)); |
| 526 | + pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01)); |
| 527 | + pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01)); |
| 528 | + pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02)); |
| 529 | + pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03)); |
| 530 | + pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03)); |
| 531 | + pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03)); |
| 532 | + pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04)); |
| 533 | + pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05)); |
| 534 | + |
| 535 | + Value qq[intervalsCount][polyDegree + 1]; |
| 536 | + qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00)); |
| 537 | + qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01)); |
| 538 | + qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01)); |
| 539 | + qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01)); |
| 540 | + qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02)); |
| 541 | + qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); |
| 542 | + qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01)); |
| 543 | + qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01)); |
| 544 | + qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02)); |
| 545 | + qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02)); |
| 546 | + qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); |
| 547 | + qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00)); |
| 548 | + qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00)); |
| 549 | + qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01)); |
| 550 | + qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02)); |
| 551 | + |
| 552 | + Value offsets[intervalsCount]; |
| 553 | + offsets[0] = bcast(f32Cst(builder, 0)); |
| 554 | + offsets[1] = bcast(f32Cst(builder, 0)); |
| 555 | + offsets[2] = bcast(f32Cst(builder, 1)); |
| 556 | + |
| 557 | + Value bounds[intervalsCount]; |
| 558 | + bounds[0] = bcast(f32Cst(builder, 0.8)); |
| 559 | + bounds[1] = bcast(f32Cst(builder, 2)); |
| 560 | + bounds[2] = bcast(f32Cst(builder, 3.75)); |
| 561 | + |
| 562 | + Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, |
| 563 | + op.operand(), zero); |
| 564 | + Value negArg = builder.create<arith::NegFOp>(op.operand()); |
| 565 | + Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand()); |
| 566 | + |
| 567 | + Value offset = offsets[0]; |
| 568 | + Value p[polyDegree + 1]; |
| 569 | + Value q[polyDegree + 1]; |
| 570 | + for (int i = 0; i <= polyDegree; ++i) { |
| 571 | + p[i] = pp[0][i]; |
| 572 | + q[i] = qq[0][i]; |
| 573 | + } |
| 574 | + |
| 575 | + // TODO: maybe use vector stacking to reduce the number of selects. |
| 576 | + Value isLessThanBound[intervalsCount]; |
| 577 | + for (int j = 0; j < intervalsCount - 1; ++j) { |
| 578 | + isLessThanBound[j] = |
| 579 | + builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]); |
| 580 | + for (int i = 0; i <= polyDegree; ++i) { |
| 581 | + p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]); |
| 582 | + q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]); |
| 583 | + } |
| 584 | + offset = |
| 585 | + builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]); |
| 586 | + } |
| 587 | + isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>( |
| 588 | + arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); |
| 589 | + |
| 590 | + Value pPoly = makePolynomialCalculation(builder, p, x); |
| 591 | + Value qPoly = makePolynomialCalculation(builder, q, x); |
| 592 | + Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly); |
| 593 | + Value formula = builder.create<arith::AddFOp>(offset, rationalPoly); |
| 594 | + formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1], |
| 595 | + formula, one); |
| 596 | + |
| 597 | + // erf is odd function: erf(x) = -erf(-x). |
| 598 | + Value negFormula = builder.create<arith::NegFOp>(formula); |
| 599 | + Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula); |
| 600 | + |
| 601 | + rewriter.replaceOp(op, res); |
| 602 | + |
| 603 | + return success(); |
| 604 | +} |
| 605 | + |
468 | 606 | //----------------------------------------------------------------------------// |
469 | 607 | // Exp approximation. |
470 | 608 | //----------------------------------------------------------------------------// |
@@ -848,8 +986,8 @@ void mlir::populateMathPolynomialApproximationPatterns( |
848 | 986 | RewritePatternSet &patterns, |
849 | 987 | const MathPolynomialApproximationOptions &options) { |
850 | 988 | patterns.add<TanhApproximation, LogApproximation, Log2Approximation, |
851 | | - Log1pApproximation, ExpApproximation, ExpM1Approximation, |
852 | | - SinAndCosApproximation<true, math::SinOp>, |
| 989 | + Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, |
| 990 | + ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, |
853 | 991 | SinAndCosApproximation<false, math::CosOp>>( |
854 | 992 | patterns.getContext()); |
855 | 993 | if (options.enableAvx2) |
|
0 commit comments