Skip to content

[MLIR][ROCDL] Convert math::fpowi to ROCDL call #122640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025

Conversation

lialan
Copy link
Member

@lialan lialan commented Jan 12, 2025

  • Have to relax static assert to allow reuse of existing template patterns for conversion.

@llvmbot
Copy link
Member

llvmbot commented Jan 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: lialan (lialan)

Changes
  • Have to relax static assert to allow reuse of existing template patterns for conversion.

Full diff: https://github.com/llvm/llvm-project/pull/122640.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+3-1)
  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+3-1)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+18)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 3b94abd88f9ed2..caa3148dedff57 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
@@ -58,7 +59,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         "expected single result op");
 
     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
+                                  SourceOp>::value ||
+                      std::is_same_v<SourceOp, math::FPowIOp>,
                   "expected op with same operand and result types");
 
     if (!op->template getParentOfType<FunctionOpInterface>()) {
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index c17bfe4f71a98d..627bed011826a2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,7 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::FmaOp
   // Handled by mathToLLVM: math::LogOp (32-bit only)
   // FIXME: math::IPowIOp
-  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +114,8 @@ void mlir::populateMathToROCDLConversionPatterns(
                                   "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                   "__ocml_erf_f64", "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
+                                    "__ocml_pown_f64", "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e0ea18d41f66da..e4b2f01d6544ab 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -484,6 +484,24 @@ module @test_module {
 
 // -----
 
+module @test_module {
+  // CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
+  // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
+  // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
+  // CHECK-LABEL: func @math_fpowi
+  func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+    // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
+    %0 = math.fpowi %arg0, %arg3 : f16, i32
+    // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+    %1 = math.fpowi %arg1, %arg3 : f32, i32
+    // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+    %2 = math.fpowi %arg2, %arg3 : f64, i32
+    return %0, %1, %2 : f16, f32, f64
+  }
+}
+
+// -----
+
 // Math operation not inside function
 // Ensure it not crash
 

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@lialan lialan requested a review from kuhar January 13, 2025 16:08
@lialan lialan requested a review from kuhar January 13, 2025 18:08
@lialan lialan requested a review from kuhar January 14, 2025 02:27
@kuhar
Copy link
Member

kuhar commented Jan 14, 2025

Can you also take care of this #122640 (comment) before you land this PR?

@hanhanW hanhanW merged commit 9f114af into llvm:main Jan 14, 2025
8 checks passed
@lialan lialan deleted the lialan/rocdl_lib branch January 14, 2025 04:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants