Skip to content

Commit f1b9221

Browse files
sogartarsilvasean
authored andcommitted
[MLIR][Math] Add erf to math dialect
Add math.erf lowering to libm call. Add math.erf polynomial approximation. Reviewed By: silvas, ezhulenev Differential Revision: https://reviews.llvm.org/D112200
1 parent b283d55 commit f1b9221

File tree

9 files changed

+395
-2
lines changed

9 files changed

+395
-2
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,39 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
285285
}];
286286
}
287287

288+
//===----------------------------------------------------------------------===//
289+
// ErfOp
290+
//===----------------------------------------------------------------------===//
291+
292+
def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
293+
let summary = "error function of the specified value";
294+
let description = [{
295+
Syntax:
296+
297+
```
298+
operation ::= ssa-id `=` `math.erf` ssa-use `:` type
299+
```
300+
301+
The `erf` operation computes the error function. It takes one operand
302+
and returns one result of the same type. This type may be a float scalar
303+
type, a vector whose element type is float, or a tensor of floats. It has
304+
no standard attributes.
305+
306+
Example:
307+
308+
```mlir
309+
// Scalar error function value.
310+
%a = math.erf %b : f64
311+
312+
// SIMD vector element-wise error function value.
313+
%f = math.erf %g : vector<4xf32>
314+
315+
// Tensor element-wise error function value.
316+
%x = math.erf %y : tensor<4x?xf8>
317+
```
318+
}];
319+
}
320+
288321

289322
//===----------------------------------------------------------------------===//
290323
// ExpOp
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- Approximation.h - Math dialect -----------------------------*- C++-*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_
10+
#define MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_
11+
12+
#include "mlir/Dialect/Math/IR/Math.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
15+
namespace mlir {
16+
namespace math {
17+
18+
struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
19+
public:
20+
using OpRewritePattern::OpRewritePattern;
21+
22+
LogicalResult matchAndRewrite(math::ErfOp op,
23+
PatternRewriter &rewriter) const final;
24+
};
25+
26+
} // namespace math
27+
} // namespace mlir
28+
29+
#endif // MLIR_DIALECT_MATH_TRANSFORMATIONS_APPROXIMATION_H_

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
116116
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
117117
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
118118
"atan2f", "atan2", benefit);
119+
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
120+
"erf", benefit);
119121
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
120122
"expm1f", "expm1", benefit);
121123
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1515
#include "mlir/Dialect/Math/IR/Math.h"
16+
#include "mlir/Dialect/Math/Transforms/Approximation.h"
1617
#include "mlir/Dialect/Math/Transforms/Passes.h"
1718
#include "mlir/Dialect/Vector/VectorOps.h"
1819
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -21,9 +22,12 @@
2122
#include "mlir/Transforms/Bufferize.h"
2223
#include "mlir/Transforms/DialectConversion.h"
2324
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25+
#include "llvm/ADT/ArrayRef.h"
2426
#include <climits>
27+
#include <cstddef>
2528

2629
using namespace mlir;
30+
using namespace mlir::math;
2731
using namespace mlir::vector;
2832

2933
using TypePredicate = llvm::function_ref<bool(Type)>;
@@ -183,6 +187,24 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
183187
return exp2ValueF32;
184188
}
185189

190+
namespace {
191+
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
192+
llvm::ArrayRef<Value> coeffs, Value x) {
193+
auto width = vectorWidth(x.getType(), isF32);
194+
if (coeffs.size() == 0) {
195+
return broadcast(builder, f32Cst(builder, 0.0f), *width);
196+
} else if (coeffs.size() == 1) {
197+
return coeffs[0];
198+
}
199+
Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
200+
coeffs[coeffs.size() - 2]);
201+
for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
202+
res = builder.create<math::FmaOp>(x, res, coeffs[i]);
203+
}
204+
return res;
205+
}
206+
} // namespace
207+
186208
//----------------------------------------------------------------------------//
187209
// TanhOp approximation.
188210
//----------------------------------------------------------------------------//
@@ -465,6 +487,122 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
465487
return success();
466488
}
467489

490+
//----------------------------------------------------------------------------//
491+
// Erf approximation.
492+
//----------------------------------------------------------------------------//
493+
494+
// Approximates erf(x) with
495+
// a - P(x)/Q(x)
496+
// where P and Q are polynomials of degree 4.
497+
// Different coefficients are chosen based on the value of x.
498+
// The approximation error is ~2.5e-07.
499+
// Boost's minimax tool that utilizes the Remez method was used to find the
500+
// coefficients.
501+
LogicalResult
502+
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
503+
PatternRewriter &rewriter) const {
504+
auto width = vectorWidth(op.operand().getType(), isF32);
505+
if (!width.hasValue())
506+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
507+
508+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
509+
auto bcast = [&](Value value) -> Value {
510+
return broadcast(builder, value, *width);
511+
};
512+
513+
const int intervalsCount = 3;
514+
const int polyDegree = 4;
515+
516+
Value zero = bcast(f32Cst(builder, 0));
517+
Value one = bcast(f32Cst(builder, 1));
518+
Value pp[intervalsCount][polyDegree + 1];
519+
pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
520+
pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00));
521+
pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01));
522+
pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01));
523+
pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02));
524+
pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00));
525+
pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00));
526+
pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01));
527+
pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01));
528+
pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02));
529+
pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03));
530+
pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03));
531+
pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03));
532+
pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04));
533+
pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05));
534+
535+
Value qq[intervalsCount][polyDegree + 1];
536+
qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00));
537+
qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01));
538+
qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01));
539+
qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01));
540+
qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02));
541+
qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
542+
qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01));
543+
qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01));
544+
qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02));
545+
qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02));
546+
qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00));
547+
qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00));
548+
qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00));
549+
qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01));
550+
qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02));
551+
552+
Value offsets[intervalsCount];
553+
offsets[0] = bcast(f32Cst(builder, 0));
554+
offsets[1] = bcast(f32Cst(builder, 0));
555+
offsets[2] = bcast(f32Cst(builder, 1));
556+
557+
Value bounds[intervalsCount];
558+
bounds[0] = bcast(f32Cst(builder, 0.8));
559+
bounds[1] = bcast(f32Cst(builder, 2));
560+
bounds[2] = bcast(f32Cst(builder, 3.75));
561+
562+
Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
563+
op.operand(), zero);
564+
Value negArg = builder.create<arith::NegFOp>(op.operand());
565+
Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand());
566+
567+
Value offset = offsets[0];
568+
Value p[polyDegree + 1];
569+
Value q[polyDegree + 1];
570+
for (int i = 0; i <= polyDegree; ++i) {
571+
p[i] = pp[0][i];
572+
q[i] = qq[0][i];
573+
}
574+
575+
// TODO: maybe use vector stacking to reduce the number of selects.
576+
Value isLessThanBound[intervalsCount];
577+
for (int j = 0; j < intervalsCount - 1; ++j) {
578+
isLessThanBound[j] =
579+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
580+
for (int i = 0; i <= polyDegree; ++i) {
581+
p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
582+
q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
583+
}
584+
offset =
585+
builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
586+
}
587+
isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
588+
arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
589+
590+
Value pPoly = makePolynomialCalculation(builder, p, x);
591+
Value qPoly = makePolynomialCalculation(builder, q, x);
592+
Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
593+
Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
594+
formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
595+
formula, one);
596+
597+
// erf is odd function: erf(x) = -erf(-x).
598+
Value negFormula = builder.create<arith::NegFOp>(formula);
599+
Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);
600+
601+
rewriter.replaceOp(op, res);
602+
603+
return success();
604+
}
605+
468606
//----------------------------------------------------------------------------//
469607
// Exp approximation.
470608
//----------------------------------------------------------------------------//
@@ -848,8 +986,8 @@ void mlir::populateMathPolynomialApproximationPatterns(
848986
RewritePatternSet &patterns,
849987
const MathPolynomialApproximationOptions &options) {
850988
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
851-
Log1pApproximation, ExpApproximation, ExpM1Approximation,
852-
SinAndCosApproximation<true, math::SinOp>,
989+
Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
990+
ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
853991
SinAndCosApproximation<false, math::CosOp>>(
854992
patterns.getContext());
855993
if (options.enableAvx2)

mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
22

3+
// CHECK-DAG: @erf(f64) -> f64
4+
// CHECK-DAG: @erff(f32) -> f32
35
// CHECK-DAG: @expm1(f64) -> f64
46
// CHECK-DAG: @expm1f(f32) -> f32
57
// CHECK-DAG: @atan2(f64, f64) -> f64
@@ -32,6 +34,18 @@ func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
3234
return %float_result, %double_result : f32, f64
3335
}
3436

37+
// CHECK-LABEL: func @erf_caller
38+
// CHECK-SAME: %[[FLOAT:.*]]: f32
39+
// CHECK-SAME: %[[DOUBLE:.*]]: f64
40+
func @erf_caller(%float: f32, %double: f64) -> (f32, f64) {
41+
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @erff(%[[FLOAT]]) : (f32) -> f32
42+
%float_result = math.erf %float : f32
43+
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @erf(%[[DOUBLE]]) : (f64) -> f64
44+
%double_result = math.erf %double : f64
45+
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
46+
return %float_result, %double_result : f32, f64
47+
}
48+
3549
// CHECK-LABEL: func @expm1_caller
3650
// CHECK-SAME: %[[FLOAT:.*]]: f32
3751
// CHECK-SAME: %[[DOUBLE:.*]]: f64

mlir/test/Dialect/Math/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
5050
return
5151
}
5252

53+
// CHECK-LABEL: func @erf(
54+
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
55+
func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
56+
// CHECK: %{{.*}} = math.erf %[[F]] : f32
57+
%0 = math.erf %f : f32
58+
// CHECK: %{{.*}} = math.erf %[[V]] : vector<4xf32>
59+
%1 = math.erf %v : vector<4xf32>
60+
// CHECK: %{{.*}} = math.erf %[[T]] : tensor<4x4x?xf32>
61+
%2 = math.erf %t : tensor<4x4x?xf32>
62+
return
63+
}
64+
5365
// CHECK-LABEL: func @exp(
5466
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
5567
func @exp(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {

0 commit comments

Comments
 (0)