-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[MLIR][ROCDL] Convert math::fpowi
to ROCDL call
#122640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
lialan
commented
Jan 12, 2025
- Have to relax static assert to allow reuse of existing template patterns for conversion.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: lialan (lialan) Changes
Full diff: https://github.com/llvm/llvm-project/pull/122640.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 3b94abd88f9ed2..caa3148dedff57 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"
namespace mlir {
@@ -58,7 +59,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
- SourceOp>::value,
+ SourceOp>::value ||
+ std::is_same_v<SourceOp, math::FPowIOp>,
"expected op with same operand and result types");
if (!op->template getParentOfType<FunctionOpInterface>()) {
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index c17bfe4f71a98d..627bed011826a2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,7 @@ void mlir::populateMathToROCDLConversionPatterns(
// Handled by mathToLLVM: math::FmaOp
// Handled by mathToLLVM: math::LogOp (32-bit only)
// FIXME: math::IPowIOp
- // FIXME: math::FPowIOp
+ // Handled by mathToLLVM: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +114,8 @@ void mlir::populateMathToROCDLConversionPatterns(
"__ocml_tan_f64", "__ocml_tan_f16");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
"__ocml_erf_f64", "__ocml_erf_f16");
+ populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
+ "__ocml_pown_f64", "__ocml_pown_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e0ea18d41f66da..e4b2f01d6544ab 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -484,6 +484,24 @@ module @test_module {
// -----
+module @test_module {
+ // CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
+ // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
+ // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
+ // CHECK-LABEL: func @math_fpowi
+ func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+ // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
+ %0 = math.fpowi %arg0, %arg3 : f16, i32
+ // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+ %1 = math.fpowi %arg1, %arg3 : f32, i32
+ // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+ %2 = math.fpowi %arg2, %arg3 : f64, i32
+ return %0, %1, %2 : f16, f32, f64
+ }
+}
+
+// -----
+
// Math operation not inside function
// Ensure it not crash
|
|
019492b
to
faec837
Compare
Can you also take care of this #122640 (comment) before you land this PR? |
faec837
to
425b774
Compare