Skip to content

[mlir][Math][LLVMSPV] Add Math to LLVM-SPV conversion pass#198370

Draft
akhilgoe wants to merge 6 commits into
llvm:mainfrom
akhilgoe:akhil/llvmspv_math
Draft

[mlir][Math][LLVMSPV] Add Math to LLVM-SPV conversion pass#198370
akhilgoe wants to merge 6 commits into
llvm:mainfrom
akhilgoe:akhil/llvmspv_math

Conversation

@akhilgoe
Copy link
Copy Markdown
Contributor

This PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to OpExtInst calls into the OpenCL SPIR-V extended instruction set via mangled __spirv_ocl_ entry points for f32/f64 variants.

@llvmorg-github-actions
Copy link
Copy Markdown

@llvm/pr-subscribers-mlir

Author: Akhil Goel (akhilgoe)

Changes

This PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to OpExtInst calls into the OpenCL SPIR-V extended instruction set via mangled __spirv_ocl_ entry points for f32/f64 variants.


Patch is 31.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/198370.diff

8 Files Affected:

  • (added) mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h (+27)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+22)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+11-7)
  • (added) mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt (+24)
  • (added) mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp (+143)
  • (added) mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir (+414)
diff --git a/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
new file mode 100644
index 0000000000000..448b65600f930
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
@@ -0,0 +1,27 @@
+//===- MathToLLVMSPV.h - Utils for converting Math to LLVMSPV -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+#define MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to OCL LLVM-SPV
+/// builtin calls.
+void populateMathToOCLExtSetLLVMSPVConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..8f6e080ad55c0 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -48,6 +48,7 @@
 #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..548b1351fae02 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -838,6 +838,28 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLLVMSPV
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLLVMSPV : Pass<"convert-math-to-llvm-spv", "ModuleOp"> {
+  let summary = "Convert Math dialect to LLVM SPV builtin calls";
+  let description = [{
+    This pass converts supported Math ops to function calls for SPIR-V
+    math intrinsics.
+
+    The extensionSetName option specifies the instruction set chosen for
+    math op lowerings.  
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "LLVM::LLVMDialect",
+    "vector::VectorDialect"];
+  let options = [Option<"extensionSetName", "extension-set-name", "std::string",
+                        /*default=*/"\"\"",
+                        "SPIR-V Extension set to use for math lowering">];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToLibm
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..7a2e745a3a64c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToLLVMSPV)
 add_subdirectory(MathToNVVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 9f36e5c369d06..cb9b6da071839 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -54,14 +54,14 @@ using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
 template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
-  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
-                                StringRef f32Func, StringRef f64Func,
-                                StringRef f32ApproxFunc, StringRef f16Func,
-                                StringRef i32Func = "",
-                                PatternBenefit benefit = 1)
+  explicit OpToFuncCallLowering(
+      const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func,
+      StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func = "",
+      PatternBenefit benefit = 1,
+      LLVM::cconv::CConv cconv = LLVM::cconv::CConv::C)
       : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
         f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
-        i32Func(i32Func) {}
+        i32Func(i32Func), cconv(cconv) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -104,6 +104,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp =
         LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
+    callOp.setCConv(cconv);
 
     if (resultType == adaptor.getOperands().front().getType()) {
       rewriter.replaceOp(op, {callOp.getResult()});
@@ -171,7 +172,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     // location as debug info metadata inside of a function cannot be used
     // outside of that function.
     auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
-    return LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    auto newFuncOp = LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    newFuncOp.setCConv(cconv);
+    return newFuncOp;
   }
 
   StringRef getFunctionName(Type type, SourceOp op) const {
@@ -202,6 +205,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   const std::string f32ApproxFunc;
   const std::string f16Func;
   const std::string i32Func;
+  const LLVM::cconv::CConv cconv;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
new file mode 100644
index 0000000000000..34279187b1c21
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_conversion_library(MLIRMathToLLVMSPV
+  MathToLLVMSPV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVMSPV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
new file mode 100644
index 0000000000000..82a32b3966981
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
@@ -0,0 +1,143 @@
+//===-- MathToLLVMSPV.cpp - conversion from Math to SPIR-V builtin calls --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-llvm-spv"
+
+static bool isExtensionSetSupported(StringRef name) {
+  return name == "OpenCL.std";
+}
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           /*f32ApproxFunc=*/"", /*f16Func=*/"",
+                                           /*i32Func=*/"", benefit,
+                                           LLVM::cconv::CConv::SPIR_FUNC);
+}
+
+template <typename OpTy>
+static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter,
+                                        RewritePatternSet &patterns,
+                                        PatternBenefit benefit,
+                                        StringRef opName) {
+  std::string mangledName =
+      "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
+  populateOpPatterns<OpTy>(converter, patterns, benefit, mangledName + "f",
+                           mangledName + "d");
+}
+
+void mlir::populateMathToOCLExtSetLLVMSPVConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  populateOCLExtSetOpPatterns<math::AcosOp>(converter, patterns, benefit,
+                                            "acos");
+  populateOCLExtSetOpPatterns<math::AcoshOp>(converter, patterns, benefit,
+                                             "acosh");
+  populateOCLExtSetOpPatterns<math::AsinOp>(converter, patterns, benefit,
+                                            "asin");
+  populateOCLExtSetOpPatterns<math::AsinhOp>(converter, patterns, benefit,
+                                             "asinh");
+  populateOCLExtSetOpPatterns<math::AtanOp>(converter, patterns, benefit,
+                                            "atan");
+  populateOCLExtSetOpPatterns<math::Atan2Op>(converter, patterns, benefit,
+                                             "atan2");
+  populateOCLExtSetOpPatterns<math::AtanhOp>(converter, patterns, benefit,
+                                             "atanh");
+  populateOCLExtSetOpPatterns<math::CbrtOp>(converter, patterns, benefit,
+                                            "cbrt");
+  populateOCLExtSetOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+                                                "copysign");
+  populateOCLExtSetOpPatterns<math::CosOp>(converter, patterns, benefit, "cos");
+  populateOCLExtSetOpPatterns<math::CoshOp>(converter, patterns, benefit,
+                                            "cosh");
+  populateOCLExtSetOpPatterns<math::ErfOp>(converter, patterns, benefit, "erf");
+  populateOCLExtSetOpPatterns<math::ErfcOp>(converter, patterns, benefit,
+                                            "erfc");
+  populateOCLExtSetOpPatterns<math::ExpOp>(converter, patterns, benefit, "exp");
+  populateOCLExtSetOpPatterns<math::Exp2Op>(converter, patterns, benefit,
+                                            "exp2");
+  populateOCLExtSetOpPatterns<math::ExpM1Op>(converter, patterns, benefit,
+                                             "expm1");
+  populateOCLExtSetOpPatterns<math::LogOp>(converter, patterns, benefit, "log");
+  populateOCLExtSetOpPatterns<math::Log10Op>(converter, patterns, benefit,
+                                             "log10");
+  populateOCLExtSetOpPatterns<math::Log1pOp>(converter, patterns, benefit,
+                                             "log1p");
+  populateOCLExtSetOpPatterns<math::Log2Op>(converter, patterns, benefit,
+                                            "log2");
+  populateOCLExtSetOpPatterns<math::PowFOp>(converter, patterns, benefit,
+                                            "pow");
+  populateOCLExtSetOpPatterns<math::RsqrtOp>(converter, patterns, benefit,
+                                             "rsqrt");
+  populateOCLExtSetOpPatterns<math::SinOp>(converter, patterns, benefit, "sin");
+  populateOCLExtSetOpPatterns<math::SinhOp>(converter, patterns, benefit,
+                                            "sinh");
+  populateOCLExtSetOpPatterns<math::SqrtOp>(converter, patterns, benefit,
+                                            "sqrt");
+  populateOCLExtSetOpPatterns<math::TanOp>(converter, patterns, benefit, "tan");
+  populateOCLExtSetOpPatterns<math::TanhOp>(converter, patterns, benefit,
+                                            "tanh");
+}
+
+namespace {
+struct ConvertMathToLLVMSPVPass final
+    : impl::ConvertMathToLLVMSPVBase<ConvertMathToLLVMSPVPass> {
+  using impl::ConvertMathToLLVMSPVBase<
+      ConvertMathToLLVMSPVPass>::ConvertMathToLLVMSPVBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLLVMSPVPass::runOnOperation() {
+  auto m = getOperation();
+  MLIRContext *ctx = m.getContext();
+
+  if (!isExtensionSetSupported(extensionSetName)) {
+    m.emitError() << "Unsupported extension set '" << extensionSetName << "'!";
+    return signalPassFailure();
+  }
+
+  RewritePatternSet patterns(&getContext());
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  if (extensionSetName == "OpenCL.std") {
+    populateMathToOCLExtSetLLVMSPVConversionPatterns(converter, patterns,
+                                                     /*benefit=*/1);
+    target
+        .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
+  }
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
new file mode 100644
index 0000000000000..d47927aa2b5ad
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
@@ -0,0 +1,414 @@
+// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm-spv='extension-set-name=OpenCL.std' -gpu-module-to-binary | FileCheck %s 
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(f32, f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(f32, f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powf(f32, f32) -> f32
+  // CHECK-LABEL: func @math_bin_f32
+  func.func @math_bin_f32(%arg_f32_1 : f32, %arg_f32_2 : f32) -> (f32, f32, f32) {
+    %result1 = math.copysign %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result2 = math.atan2 %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result3 = math.powf %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    func.return %result1, %result2, %result3 : f32, f32, f32
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(f64, f64) -> f64
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(f64, f64) -> f64
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powd(f64, f64) -> f64
+  // CHECK-LABEL: func @math_bin_f64
+  func.func @math_bin_f64(%arg_f64_1 : f64, %arg_f64_2 : f64) -> (f64, f64, f64) {
+    %result1 = math.copysign %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    %result2 = math.atan2 %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    %result3 = math.powf %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result1, %result2, %result3 : f64, f64, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosd(f64) -> f64
+  // CHECK-LABEL: func @math_acos
+  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(f64) -> f64
+  // CHECK-LABEL: func @math_acosh
+  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asind(f64) -> f64
+  // CHECK-LABEL: func @math_asin
+  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asind(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(f64) -> f64
+  // CHECK-LABEL: func @math_asinh
+  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atand(f64) -> f64
+  // CHECK-LABEL: func @math_atan
+  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanf(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atand(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(f64) -> f64
+  // CHECK-LABEL: func @math_atanh
+  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atanh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(%{{.*}}) : (f32) -> f32
+    %result64 = math.atanh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_cbrtf(f32...
[truncated]

@llvmorg-github-actions
Copy link
Copy Markdown

@llvm/pr-subscribers-mlir-gpu

Author: Akhil Goel (akhilgoe)

Changes

This PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to OpExtInst calls into the OpenCL SPIR-V extended instruction set via mangled __spirv_ocl_ entry points for f32/f64 variants.


Patch is 31.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/198370.diff

8 Files Affected:

  • (added) mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h (+27)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+22)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+11-7)
  • (added) mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt (+24)
  • (added) mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp (+143)
  • (added) mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir (+414)
diff --git a/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
new file mode 100644
index 0000000000000..448b65600f930
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
@@ -0,0 +1,27 @@
+//===- MathToLLVMSPV.h - Utils for converting Math to LLVMSPV -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+#define MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to OCL LLVM-SPV
+/// builtin calls.
+void populateMathToOCLExtSetLLVMSPVConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..8f6e080ad55c0 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -48,6 +48,7 @@
 #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..548b1351fae02 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -838,6 +838,28 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLLVMSPV
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLLVMSPV : Pass<"convert-math-to-llvm-spv", "ModuleOp"> {
+  let summary = "Convert Math dialect to LLVM SPV builtin calls";
+  let description = [{
+    This pass converts supported Math ops to function calls for SPIR-V
+    math intrinsics.
+
+    The extensionSetName option specifies the instruction set chosen for
+    math op lowerings.  
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "LLVM::LLVMDialect",
+    "vector::VectorDialect"];
+  let options = [Option<"extensionSetName", "extension-set-name", "std::string",
+                        /*default=*/"\"\"",
+                        "SPIR-V Extension set to use for math lowering">];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToLibm
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..7a2e745a3a64c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToLLVMSPV)
 add_subdirectory(MathToNVVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 9f36e5c369d06..cb9b6da071839 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -54,14 +54,14 @@ using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
 template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
-  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
-                                StringRef f32Func, StringRef f64Func,
-                                StringRef f32ApproxFunc, StringRef f16Func,
-                                StringRef i32Func = "",
-                                PatternBenefit benefit = 1)
+  explicit OpToFuncCallLowering(
+      const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func,
+      StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func = "",
+      PatternBenefit benefit = 1,
+      LLVM::cconv::CConv cconv = LLVM::cconv::CConv::C)
       : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
         f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
-        i32Func(i32Func) {}
+        i32Func(i32Func), cconv(cconv) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -104,6 +104,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp =
         LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
+    callOp.setCConv(cconv);
 
     if (resultType == adaptor.getOperands().front().getType()) {
       rewriter.replaceOp(op, {callOp.getResult()});
@@ -171,7 +172,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     // location as debug info metadata inside of a function cannot be used
     // outside of that function.
     auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
-    return LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    auto newFuncOp = LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    newFuncOp.setCConv(cconv);
+    return newFuncOp;
   }
 
   StringRef getFunctionName(Type type, SourceOp op) const {
@@ -202,6 +205,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   const std::string f32ApproxFunc;
   const std::string f16Func;
   const std::string i32Func;
+  const LLVM::cconv::CConv cconv;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
new file mode 100644
index 0000000000000..34279187b1c21
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_conversion_library(MLIRMathToLLVMSPV
+  MathToLLVMSPV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVMSPV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
new file mode 100644
index 0000000000000..82a32b3966981
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
@@ -0,0 +1,143 @@
+//===-- MathToLLVMSPV.cpp - conversion from Math to SPIR-V builtin calls --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-llvm-spv"
+
+static bool isExtensionSetSupported(StringRef name) {
+  return name == "OpenCL.std";
+}
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           /*f32ApproxFunc=*/"", /*f16Func=*/"",
+                                           /*i32Func=*/"", benefit,
+                                           LLVM::cconv::CConv::SPIR_FUNC);
+}
+
+template <typename OpTy>
+static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter,
+                                        RewritePatternSet &patterns,
+                                        PatternBenefit benefit,
+                                        StringRef opName) {
+  std::string mangledName =
+      "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
+  populateOpPatterns<OpTy>(converter, patterns, benefit, mangledName + "f",
+                           mangledName + "d");
+}
+
+void mlir::populateMathToOCLExtSetLLVMSPVConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  populateOCLExtSetOpPatterns<math::AcosOp>(converter, patterns, benefit,
+                                            "acos");
+  populateOCLExtSetOpPatterns<math::AcoshOp>(converter, patterns, benefit,
+                                             "acosh");
+  populateOCLExtSetOpPatterns<math::AsinOp>(converter, patterns, benefit,
+                                            "asin");
+  populateOCLExtSetOpPatterns<math::AsinhOp>(converter, patterns, benefit,
+                                             "asinh");
+  populateOCLExtSetOpPatterns<math::AtanOp>(converter, patterns, benefit,
+                                            "atan");
+  populateOCLExtSetOpPatterns<math::Atan2Op>(converter, patterns, benefit,
+                                             "atan2");
+  populateOCLExtSetOpPatterns<math::AtanhOp>(converter, patterns, benefit,
+                                             "atanh");
+  populateOCLExtSetOpPatterns<math::CbrtOp>(converter, patterns, benefit,
+                                            "cbrt");
+  populateOCLExtSetOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+                                                "copysign");
+  populateOCLExtSetOpPatterns<math::CosOp>(converter, patterns, benefit, "cos");
+  populateOCLExtSetOpPatterns<math::CoshOp>(converter, patterns, benefit,
+                                            "cosh");
+  populateOCLExtSetOpPatterns<math::ErfOp>(converter, patterns, benefit, "erf");
+  populateOCLExtSetOpPatterns<math::ErfcOp>(converter, patterns, benefit,
+                                            "erfc");
+  populateOCLExtSetOpPatterns<math::ExpOp>(converter, patterns, benefit, "exp");
+  populateOCLExtSetOpPatterns<math::Exp2Op>(converter, patterns, benefit,
+                                            "exp2");
+  populateOCLExtSetOpPatterns<math::ExpM1Op>(converter, patterns, benefit,
+                                             "expm1");
+  populateOCLExtSetOpPatterns<math::LogOp>(converter, patterns, benefit, "log");
+  populateOCLExtSetOpPatterns<math::Log10Op>(converter, patterns, benefit,
+                                             "log10");
+  populateOCLExtSetOpPatterns<math::Log1pOp>(converter, patterns, benefit,
+                                             "log1p");
+  populateOCLExtSetOpPatterns<math::Log2Op>(converter, patterns, benefit,
+                                            "log2");
+  populateOCLExtSetOpPatterns<math::PowFOp>(converter, patterns, benefit,
+                                            "pow");
+  populateOCLExtSetOpPatterns<math::RsqrtOp>(converter, patterns, benefit,
+                                             "rsqrt");
+  populateOCLExtSetOpPatterns<math::SinOp>(converter, patterns, benefit, "sin");
+  populateOCLExtSetOpPatterns<math::SinhOp>(converter, patterns, benefit,
+                                            "sinh");
+  populateOCLExtSetOpPatterns<math::SqrtOp>(converter, patterns, benefit,
+                                            "sqrt");
+  populateOCLExtSetOpPatterns<math::TanOp>(converter, patterns, benefit, "tan");
+  populateOCLExtSetOpPatterns<math::TanhOp>(converter, patterns, benefit,
+                                            "tanh");
+}
+
+namespace {
+struct ConvertMathToLLVMSPVPass final
+    : impl::ConvertMathToLLVMSPVBase<ConvertMathToLLVMSPVPass> {
+  using impl::ConvertMathToLLVMSPVBase<
+      ConvertMathToLLVMSPVPass>::ConvertMathToLLVMSPVBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLLVMSPVPass::runOnOperation() {
+  auto m = getOperation();
+  MLIRContext *ctx = m.getContext();
+
+  if (!isExtensionSetSupported(extensionSetName)) {
+    m.emitError() << "Unsupported extension set '" << extensionSetName << "'!";
+    return signalPassFailure();
+  }
+
+  RewritePatternSet patterns(&getContext());
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  if (extensionSetName == "OpenCL.std") {
+    populateMathToOCLExtSetLLVMSPVConversionPatterns(converter, patterns,
+                                                     /*benefit=*/1);
+    target
+        .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
+  }
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
new file mode 100644
index 0000000000000..d47927aa2b5ad
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
@@ -0,0 +1,414 @@
+// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm-spv='extension-set-name=OpenCL.std' -gpu-module-to-binary | FileCheck %s 
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(f32, f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(f32, f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powf(f32, f32) -> f32
+  // CHECK-LABEL: func @math_bin_f32
+  func.func @math_bin_f32(%arg_f32_1 : f32, %arg_f32_2 : f32) -> (f32, f32, f32) {
+    %result1 = math.copysign %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result2 = math.atan2 %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result3 = math.powf %arg_f32_1, %arg_f32_2 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    func.return %result1, %result2, %result3 : f32, f32, f32
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(f64, f64) -> f64
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(f64, f64) -> f64
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powd(f64, f64) -> f64
+  // CHECK-LABEL: func @math_bin_f64
+  func.func @math_bin_f64(%arg_f64_1 : f64, %arg_f64_2 : f64) -> (f64, f64, f64) {
+    %result1 = math.copysign %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    %result2 = math.atan2 %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    %result3 = math.powf %arg_f64_1, %arg_f64_2 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result1, %result2, %result3 : f64, f64, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosd(f64) -> f64
+  // CHECK-LABEL: func @math_acos
+  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(f64) -> f64
+  // CHECK-LABEL: func @math_acosh
+  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asind(f64) -> f64
+  // CHECK-LABEL: func @math_asin
+  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asind(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(f64) -> f64
+  // CHECK-LABEL: func @math_asinh
+  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atand(f64) -> f64
+  // CHECK-LABEL: func @math_atan
+  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanf(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atand(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(f32) -> f32
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(f64) -> f64
+  // CHECK-LABEL: func @math_atanh
+  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atanh %arg_f32 : f32
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(%{{.*}}) : (f32) -> f32
+    %result64 = math.atanh %arg_f64 : f64
+    // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_cbrtf(f32...
[truncated]

@akhilgoe
Copy link
Copy Markdown
Contributor Author

Moving to draft for now.

@akhilgoe akhilgoe marked this pull request as draft May 18, 2026 19:51
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 28, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

@joker-eph
Copy link
Copy Markdown
Contributor

This PR adds a conversion pass

Does it? I can't spot it...

@akhilgoe
Copy link
Copy Markdown
Contributor Author

akhilgoe commented May 28, 2026

@joker-eph The initial commit added a conversion pass. This morning I moved the patterns to MathToXeVM pass pipeline and restricted to OpenCL transforms.

Comment on lines +131 to +132
std::string mangledName =
"_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::string mangledName =
"_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
std::string prefix = "__spirv_ocl_";
std::string mangledName =
"_Z" + std::to_string(prefix.size() + opName.size()) + prefix + opName.str();

Just a nit to avoid "magic numbers"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a great suggestion, thanks!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@joker-eph
Copy link
Copy Markdown
Contributor

Thanks for clarifying: looks like these lowering are specific to XeVM runtime right? Seems right then.

@akhilgoe
Copy link
Copy Markdown
Contributor Author

Yes, this is relevant for XeVM/SPIR-V targets. XeVM already lowers fastmath ops to native OpenCL intrinsics, so feels natural to move this to the existing pipeline.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants