Skip to content

Commit

Permalink
Add a math.cbrt instruction and lowering to libm.
Browse files Browse the repository at this point in the history
There's currently no way to get accurate cube roots in the math dialect.
powf(x, 1/3.0) is too inaccurate in some cases.

Reviewed By: akuegel

Differential Revision: https://reviews.llvm.org/D140842
  • Loading branch information
jreiffers committed Jan 3, 2023
1 parent 367e618 commit 998a3a3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/Math/IR/MathOps.td
Expand Up @@ -196,6 +196,28 @@ def Math_Atan2Op : Math_FloatBinaryOp<"atan2">{
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// CbrtOp
//===----------------------------------------------------------------------===//

def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> {
let summary = "cube root of the specified value";
let description = [{
The `cbrt` operation computes the cube root. It takes one operand of
floating point type (i.e., scalar, tensor or vector) and returns one result
of the same type. It has no standard attributes.

Example:

```mlir
// Scalar cube root value.
%a = math.cbrt %b : f64
```

Note: This op is not equivalent to powf(..., 1/3.0).
}];
}

//===----------------------------------------------------------------------===//
// CeilOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Expand Up @@ -171,6 +171,8 @@ void mlir::populateMathToLibmConversionPatterns(
"atan", benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
"atan2f", "atan2", benefit);
patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(patterns.getContext(), "cbrtf",
"cbrt", benefit);
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
"erf", benefit);
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
Expand Up @@ -8,6 +8,8 @@
// CHECK-DAG: @expm1f(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @atan2(f64, f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @atan2f(f32, f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @cbrt(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @cbrtf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @tan(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @tanf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @tanh(f64) -> f64 attributes {llvm.readnone}
Expand Down Expand Up @@ -241,6 +243,18 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}

// CHECK-LABEL: func @cbrt_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.cbrt %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.cbrt %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}

// CHECK-LABEL: func @cos_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Math/ops.mlir
Expand Up @@ -26,6 +26,18 @@ func.func @atan2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}

// CHECK-LABEL: func @cbrt(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @cbrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = math.cbrt %[[F]] : f32
%0 = math.cbrt %f : f32
// CHECK: %{{.*}} = math.cbrt %[[V]] : vector<4xf32>
%1 = math.cbrt %v : vector<4xf32>
// CHECK: %{{.*}} = math.cbrt %[[T]] : tensor<4x4x?xf32>
%2 = math.cbrt %t : tensor<4x4x?xf32>
return
}

// CHECK-LABEL: func @cos(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @cos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
Expand Down

0 comments on commit 998a3a3

Please sign in to comment.