-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement acos operator in MLIR Math Dialect #74584
Conversation
Required for torch-mlir. Cf. llvm/torch-mlir#2604 "Implement torch.aten.acos".
@llvm/pr-subscribers-mlir-math @llvm/pr-subscribers-mlir Author: Frederik Harwath (frederik-h) ChangesRequired for torch-mlir. Full diff: https://github.com/llvm/llvm-project/pull/74584.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index f8e9fd601304b..9742d3d936dff 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -300,6 +300,35 @@ def Math_CosOp : Math_FloatUnaryOp<"cos"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// AcosOp
+//===----------------------------------------------------------------------===//
+
+def Math_AcosOp : Math_FloatUnaryOp<"acos"> {
+ let summary = "arcus cosine of the specified value";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `math.acos` ssa-use `:` type
+ ```
+
+ The `acos` operation computes the arcus cosine of a given value. It takes one
+ operand of floating point type (i.e., scalar, tensor or vector) and returns one
+ result of the same type. It has no standard attributes.
+
+ Example:
+
+ ```mlir
+ // Scalar arcus cosine value.
+ %a = math.acos %b : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
+
+
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 103c1fb8c3822..27c2cb9352071 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -162,6 +162,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
+ populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 28d1c062f235e..066a21c76f7d1 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -41,6 +41,24 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
[](const APInt &a) { return a.abs(); });
}
+//===----------------------------------------------------------------------===//
+// AcosOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
+ return constFoldUnaryOpConditional<FloatAttr>(
+ adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+ switch (a.getSizeInBits(a.getSemantics())) {
+ case 64:
+ return APFloat(acos(a.convertToDouble()));
+ case 32:
+ return APFloat(acosf(a.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// AtanOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 4837a0cce634a..f0c4512cbfdcc 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
+// CHECK-DAG: @acos(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @acosf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @atan(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @atanf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @erf(f64) -> f64 attributes {llvm.readnone}
@@ -29,6 +31,43 @@
// CHECK-DAG: @ceil(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @ceilf(f32) -> f32 attributes {llvm.readnone}
+// CHECK-LABEL: func @acos_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @acos_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @acosf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.acos %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @acos(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.acos %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @acos_vec_caller(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
+// CHECK: %[[OUT0_F32:.*]] = call @acosf(%[[IN0_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
+// CHECK: %[[OUT1_F32:.*]] = call @acosf(%[[IN1_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
+// CHECK: %[[OUT0_F64:.*]] = call @acos(%[[IN0_F64]]) : (f64) -> f64
+// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
+// CHECK: %[[OUT1_F64:.*]] = call @acos(%[[IN1_F64]]) : (f64) -> f64
+// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: }
+func.func @acos_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+ %float_result = math.acos %float : vector<2xf32>
+ %double_result = math.acos %double : vector<2xf64>
+ return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+
// CHECK-LABEL: func @atan_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
@stellaraccident @vivekkhandelwal1 Could you review this small PR or recommend someone who could do it? |
I don't have the right to do so. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks reasonable to me. Thanks.
Thank you! |
Required for torch-mlir.
Cf. llvm/torch-mlir#2604 "Implement torch.aten.acos".