From 998a3a38948c9d220ddc759b8a6eee987e3ad320 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 2 Jan 2023 15:23:12 +0100 Subject: [PATCH] Add a math.cbrt instruction and lowering to libm. There's currently no way to get accurate cube roots in the math dialect. powf(x, 1/3.0) is too inaccurate in some cases. Reviewed By: akuegel Differential Revision: https://reviews.llvm.org/D140842 --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 22 +++++++++++++++++++ mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 2 ++ .../MathToLibm/convert-to-libm.mlir | 14 ++++++++++++ mlir/test/Dialect/Math/ops.mlir | 12 ++++++++++ 4 files changed, 50 insertions(+) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index 3f2a8d7cb4647..f8e9fd601304b 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -196,6 +196,28 @@ def Math_Atan2Op : Math_FloatBinaryOp<"atan2">{ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// CbrtOp +//===----------------------------------------------------------------------===// + +def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> { + let summary = "cube root of the specified value"; + let description = [{ + The `cbrt` operation computes the cube root. It takes one operand of + floating point type (i.e., scalar, tensor or vector) and returns one result + of the same type. It has no standard attributes. + + Example: + + ```mlir + // Scalar cube root value. + %a = math.cbrt %b : f64 + ``` + + Note: This op is not equivalent to powf(..., 1/3.0). + }]; +} + //===----------------------------------------------------------------------===// // CeilOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index d40666d6608c5..8a8adb5924666 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -171,6 +171,8 @@ void mlir::populateMathToLibmConversionPatterns( "atan", benefit); patterns.add>(patterns.getContext(), "atan2f", "atan2", benefit); + patterns.add>(patterns.getContext(), "cbrtf", + "cbrt", benefit); patterns.add>(patterns.getContext(), "erff", "erf", benefit); patterns.add>(patterns.getContext(), diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index d911f8b1b8fbe..b0459d8bfcead 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -8,6 +8,8 @@ // CHECK-DAG: @expm1f(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @atan2(f64, f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @atan2f(f32, f32) -> f32 attributes {llvm.readnone} +// CHECK-DAG: @cbrt(f64) -> f64 attributes {llvm.readnone} +// CHECK-DAG: @cbrtf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @tan(f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @tanf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @tanh(f64) -> f64 attributes {llvm.readnone} @@ -241,6 +243,18 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) { return %float_result, %double_result : f32, f64 } +// CHECK-LABEL: func @cbrt_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.cbrt %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.cbrt %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} + // CHECK-LABEL: func @cos_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index 0f744f52d1c96..7e45d9bc6f74a 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -26,6 +26,18 @@ func.func @atan2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { return } +// CHECK-LABEL: func @cbrt( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @cbrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.cbrt %[[F]] : f32 + %0 = math.cbrt %f : f32 + // CHECK: %{{.*}} = math.cbrt %[[V]] : vector<4xf32> + %1 = math.cbrt %v : vector<4xf32> + // CHECK: %{{.*}} = math.cbrt %[[T]] : tensor<4x4x?xf32> + %2 = math.cbrt %t : tensor<4x4x?xf32> + return +} + // CHECK-LABEL: func @cos( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @cos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {