diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp index 9a0651a5445e6..77d7ce236e435 100644 --- a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp +++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp @@ -64,6 +64,8 @@ void mlir::populateConvertMathToEmitCPatterns( languageTarget); patterns.insert>(context, "round", languageTarget); + patterns.insert>( + context, "roundeven", languageTarget); patterns.insert>(context, "exp", languageTarget); patterns.insert>(context, "cos", @@ -82,4 +84,6 @@ void mlir::populateConvertMathToEmitCPatterns( languageTarget); patterns.insert>(context, "pow", languageTarget); + patterns.insert>(context, "sqrt", + languageTarget); } diff --git a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir index 111d93de1accb..85e9a269f3f2b 100644 --- a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir +++ b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir @@ -81,6 +81,16 @@ func.func @ceil(%arg0: f32, %arg1: f64) { return } +func.func @sqrt(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "sqrtf" + // c99-NEXT: emitc.call_opaque "sqrt" + // cpp11: emitc.call_opaque "std::sqrt" + // cpp11-NEXT: emitc.call_opaque "std::sqrt" + %0 = math.sqrt %arg0 : f32 + %1 = math.sqrt %arg1 : f64 + return +} + func.func @exp(%arg0: f32, %arg1: f64) { // c99: emitc.call_opaque "expf" // c99-NEXT: emitc.call_opaque "exp" @@ -110,3 +120,13 @@ func.func @round(%arg0: f32, %arg1: f64) { %1 = math.round %arg1 : f64 return } + +func.func @roundeven(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "roundevenf" + // c99-NEXT: emitc.call_opaque "roundeven" + // cpp11: emitc.call_opaque "std::roundeven" + // cpp11-NEXT: emitc.call_opaque "std::roundeven" + %0 = math.roundeven %arg0 : f32 + %1 = math.roundeven %arg1 : f64 + return +} diff --git a/mlir/test/Dialect/EmitC/math/ops.mlir b/mlir/test/Dialect/EmitC/math/ops.mlir new file mode 100644 index 0000000000000..759e05b111edc --- /dev/null +++ b/mlir/test/Dialect/EmitC/math/ops.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -math-expand-ops=ops=rsqrt -convert-math-to-emitc \ +// RUN: -convert-arith-to-emitc | FileCheck %s + +/// This file checks the conversion of math ops whose EmitC lowering requires +/// expansion across multiple dialects, s.a. arith. +/// The FileCheck coverage is intentionally minimal, since the full MathToEmitC +/// lowering is already covered in +/// `test/Conversion/MathToEmitC/math-to-emitc.mlir`, their expansion in +/// `test/Dialect/Math/expand-math.mlir`, but not the combination of the two. + +/// Vector cases excluded: `math.rsqrt` expands through `arith.constant` to +/// materialize the numerator `1.0`, and ArithToEmitC does not convert +/// `VectorType` to an EmitC type. + +/// Tensor cases excluded: `math.rsqrt` expands through `arith.divf`, and the +/// resulting `emitc.div` does not accept tensor operands. + +// CHECK-LABEL: func.func @rsqrt32 +// CHECK-SAME: %[[SRC:.*]]: f32) -> f32 +func.func @rsqrt32(%float: f32) -> (f32) { +// CHECK-NOT: math.sqrt +// CHECK: %[[CONST:.*]] = "emitc.constant"() <{value = 1.000000e+00 : f32}> +// CHECK: %[[SQRT:.*]] = emitc.call_opaque "sqrtf"(%[[SRC]]) +// CHECK: %[[DIV:.*]] = emitc.div %[[CONST]], %[[SQRT]] + %float_result = math.rsqrt %float : f32 +// CHECK: return %[[DIV]] : f32 + return %float_result : f32 +} + +// CHECK-LABEL: func.func @rsqrt64 +// CHECK-SAME: %[[SRC:.*]]: f64) -> f64 +func.func @rsqrt64(%float: f64) -> (f64) { +// CHECK-NOT: math.sqrt +// CHECK: %[[CONST:.*]] = "emitc.constant"() <{value = 1.000000e+00 : f64}> +// CHECK: %[[SQRT:.*]] = emitc.call_opaque "sqrt"(%[[SRC]]) +// CHECK: %[[DIV:.*]] = emitc.div %[[CONST]], %[[SQRT]] + %float_result = math.rsqrt %float : f64 +// CHECK: return %[[DIV]] : f64 + return %float_result : f64 +} + +/// `math.sqrt` is only lowered for f32/f64. +// CHECK-LABEL: func.func @negative_rsqrt16 +func.func @negative_rsqrt16(%float: f16) -> (f16) { +// CHECK: math.sqrt + %float_result = math.rsqrt %float : f16 + return %float_result : f16 +}