[mlir][Math][LLVMSPV] Add Math to LLVM-SPV conversion pass#198370
[mlir][Math][LLVMSPV] Add Math to LLVM-SPV conversion pass#198370akhilgoe wants to merge 6 commits into
Conversation
|
@llvm/pr-subscribers-mlir Author: Akhil Goel (akhilgoe) ChangesThis PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to Patch is 31.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/198370.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
new file mode 100644
index 0000000000000..448b65600f930
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
@@ -0,0 +1,27 @@
+//===- MathToLLVMSPV.h - Utils for converting Math to LLVMSPV -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+#define MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to OCL LLVM-SPV
+/// builtin calls.
+void populateMathToOCLExtSetLLVMSPVConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..8f6e080ad55c0 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -48,6 +48,7 @@
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..548b1351fae02 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -838,6 +838,28 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToLLVMSPV
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLLVMSPV : Pass<"convert-math-to-llvm-spv", "ModuleOp"> {
+ let summary = "Convert Math dialect to LLVM SPV builtin calls";
+ let description = [{
+ This pass converts supported Math ops to function calls for SPIR-V
+ math intrinsics.
+
+ The extensionSetName option specifies the instruction set chosen for
+ math op lowerings.
+ }];
+ let dependentDialects = [
+ "func::FuncDialect",
+ "LLVM::LLVMDialect",
+ "vector::VectorDialect"];
+ let options = [Option<"extensionSetName", "extension-set-name", "std::string",
+ /*default=*/"\"\"",
+ "SPIR-V Extension set to use for math lowering">];
+}
+
//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..7a2e745a3a64c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToLLVMSPV)
add_subdirectory(MathToNVVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 9f36e5c369d06..cb9b6da071839 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -54,14 +54,14 @@ using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
- explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
- StringRef f32Func, StringRef f64Func,
- StringRef f32ApproxFunc, StringRef f16Func,
- StringRef i32Func = "",
- PatternBenefit benefit = 1)
+ explicit OpToFuncCallLowering(
+ const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func,
+ StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func = "",
+ PatternBenefit benefit = 1,
+ LLVM::cconv::CConv cconv = LLVM::cconv::CConv::C)
: ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
- i32Func(i32Func) {}
+ i32Func(i32Func), cconv(cconv) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -104,6 +104,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
+ callOp.setCConv(cconv);
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -171,7 +172,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
- return LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ auto newFuncOp = LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ newFuncOp.setCConv(cconv);
+ return newFuncOp;
}
StringRef getFunctionName(Type type, SourceOp op) const {
@@ -202,6 +205,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32ApproxFunc;
const std::string f16Func;
const std::string i32Func;
+ const LLVM::cconv::CConv cconv;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
new file mode 100644
index 0000000000000..34279187b1c21
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_conversion_library(MLIRMathToLLVMSPV
+ MathToLLVMSPV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVMSPV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
new file mode 100644
index 0000000000000..82a32b3966981
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
@@ -0,0 +1,143 @@
+//===-- MathToLLVMSPV.cpp - conversion from Math to SPIR-V builtin calls --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-llvm-spv"
+
+static bool isExtensionSetSupported(StringRef name) {
+ return name == "OpenCL.std";
+}
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ /*f32ApproxFunc=*/"", /*f16Func=*/"",
+ /*i32Func=*/"", benefit,
+ LLVM::cconv::CConv::SPIR_FUNC);
+}
+
+template <typename OpTy>
+static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit,
+ StringRef opName) {
+ std::string mangledName =
+ "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
+ populateOpPatterns<OpTy>(converter, patterns, benefit, mangledName + "f",
+ mangledName + "d");
+}
+
+void mlir::populateMathToOCLExtSetLLVMSPVConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ populateOCLExtSetOpPatterns<math::AcosOp>(converter, patterns, benefit,
+ "acos");
+ populateOCLExtSetOpPatterns<math::AcoshOp>(converter, patterns, benefit,
+ "acosh");
+ populateOCLExtSetOpPatterns<math::AsinOp>(converter, patterns, benefit,
+ "asin");
+ populateOCLExtSetOpPatterns<math::AsinhOp>(converter, patterns, benefit,
+ "asinh");
+ populateOCLExtSetOpPatterns<math::AtanOp>(converter, patterns, benefit,
+ "atan");
+ populateOCLExtSetOpPatterns<math::Atan2Op>(converter, patterns, benefit,
+ "atan2");
+ populateOCLExtSetOpPatterns<math::AtanhOp>(converter, patterns, benefit,
+ "atanh");
+ populateOCLExtSetOpPatterns<math::CbrtOp>(converter, patterns, benefit,
+ "cbrt");
+ populateOCLExtSetOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+ "copysign");
+ populateOCLExtSetOpPatterns<math::CosOp>(converter, patterns, benefit, "cos");
+ populateOCLExtSetOpPatterns<math::CoshOp>(converter, patterns, benefit,
+ "cosh");
+ populateOCLExtSetOpPatterns<math::ErfOp>(converter, patterns, benefit, "erf");
+ populateOCLExtSetOpPatterns<math::ErfcOp>(converter, patterns, benefit,
+ "erfc");
+ populateOCLExtSetOpPatterns<math::ExpOp>(converter, patterns, benefit, "exp");
+ populateOCLExtSetOpPatterns<math::Exp2Op>(converter, patterns, benefit,
+ "exp2");
+ populateOCLExtSetOpPatterns<math::ExpM1Op>(converter, patterns, benefit,
+ "expm1");
+ populateOCLExtSetOpPatterns<math::LogOp>(converter, patterns, benefit, "log");
+ populateOCLExtSetOpPatterns<math::Log10Op>(converter, patterns, benefit,
+ "log10");
+ populateOCLExtSetOpPatterns<math::Log1pOp>(converter, patterns, benefit,
+ "log1p");
+ populateOCLExtSetOpPatterns<math::Log2Op>(converter, patterns, benefit,
+ "log2");
+ populateOCLExtSetOpPatterns<math::PowFOp>(converter, patterns, benefit,
+ "pow");
+ populateOCLExtSetOpPatterns<math::RsqrtOp>(converter, patterns, benefit,
+ "rsqrt");
+ populateOCLExtSetOpPatterns<math::SinOp>(converter, patterns, benefit, "sin");
+ populateOCLExtSetOpPatterns<math::SinhOp>(converter, patterns, benefit,
+ "sinh");
+ populateOCLExtSetOpPatterns<math::SqrtOp>(converter, patterns, benefit,
+ "sqrt");
+ populateOCLExtSetOpPatterns<math::TanOp>(converter, patterns, benefit, "tan");
+ populateOCLExtSetOpPatterns<math::TanhOp>(converter, patterns, benefit,
+ "tanh");
+}
+
+namespace {
+struct ConvertMathToLLVMSPVPass final
+ : impl::ConvertMathToLLVMSPVBase<ConvertMathToLLVMSPVPass> {
+ using impl::ConvertMathToLLVMSPVBase<
+ ConvertMathToLLVMSPVPass>::ConvertMathToLLVMSPVBase;
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLLVMSPVPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ if (!isExtensionSetSupported(extensionSetName)) {
+ m.emitError() << "Unsupported extension set '" << extensionSetName << "'!";
+ return signalPassFailure();
+ }
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ if (extensionSetName == "OpenCL.std") {
+ populateMathToOCLExtSetLLVMSPVConversionPatterns(converter, patterns,
+ /*benefit=*/1);
+ target
+ .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
+ }
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
new file mode 100644
index 0000000000000..d47927aa2b5ad
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
@@ -0,0 +1,414 @@
+// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm-spv='extension-set-name=OpenCL.std' -gpu-module-to-binary | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(f32, f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(f32, f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powf(f32, f32) -> f32
+ // CHECK-LABEL: func @math_bin_f32
+ func.func @math_bin_f32(%arg_f32_1 : f32, %arg_f32_2 : f32) -> (f32, f32, f32) {
+ %result1 = math.copysign %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result2 = math.atan2 %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result3 = math.powf %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ func.return %result1, %result2, %result3 : f32, f32, f32
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(f64, f64) -> f64
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(f64, f64) -> f64
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powd(f64, f64) -> f64
+ // CHECK-LABEL: func @math_bin_f64
+ func.func @math_bin_f64(%arg_f64_1 : f64, %arg_f64_2 : f64) -> (f64, f64, f64) {
+ %result1 = math.copysign %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ %result2 = math.atan2 %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ %result3 = math.powf %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result1, %result2, %result3 : f64, f64, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosd(f64) -> f64
+ // CHECK-LABEL: func @math_acos
+ func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(f64) -> f64
+ // CHECK-LABEL: func @math_acosh
+ func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asind(f64) -> f64
+ // CHECK-LABEL: func @math_asin
+ func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asind(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(f64) -> f64
+ // CHECK-LABEL: func @math_asinh
+ func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atand(f64) -> f64
+ // CHECK-LABEL: func @math_atan
+ func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanf(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atand(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(f64) -> f64
+ // CHECK-LABEL: func @math_atanh
+ func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atanh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.atanh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_cbrtf(f32...
[truncated]
|
|
@llvm/pr-subscribers-mlir-gpu Author: Akhil Goel (akhilgoe) ChangesThis PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to Patch is 31.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/198370.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
new file mode 100644
index 0000000000000..448b65600f930
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h
@@ -0,0 +1,27 @@
+//===- MathToLLVMSPV.h - Utils for converting Math to LLVMSPV -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+#define MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to OCL LLVM-SPV
+/// builtin calls.
+void populateMathToOCLExtSetLLVMSPVConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVMSPV_MATHTOLLVMSPV_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..8f6e080ad55c0 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -48,6 +48,7 @@
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..548b1351fae02 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -838,6 +838,28 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToLLVMSPV
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLLVMSPV : Pass<"convert-math-to-llvm-spv", "ModuleOp"> {
+ let summary = "Convert Math dialect to LLVM SPV builtin calls";
+ let description = [{
+ This pass converts supported Math ops to function calls for SPIR-V
+ math intrinsics.
+
+ The extensionSetName option specifies the instruction set chosen for
+ math op lowerings.
+ }];
+ let dependentDialects = [
+ "func::FuncDialect",
+ "LLVM::LLVMDialect",
+ "vector::VectorDialect"];
+ let options = [Option<"extensionSetName", "extension-set-name", "std::string",
+ /*default=*/"\"\"",
+ "SPIR-V Extension set to use for math lowering">];
+}
+
//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..7a2e745a3a64c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToLLVMSPV)
add_subdirectory(MathToNVVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 9f36e5c369d06..cb9b6da071839 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -54,14 +54,14 @@ using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
- explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
- StringRef f32Func, StringRef f64Func,
- StringRef f32ApproxFunc, StringRef f16Func,
- StringRef i32Func = "",
- PatternBenefit benefit = 1)
+ explicit OpToFuncCallLowering(
+ const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func,
+ StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func = "",
+ PatternBenefit benefit = 1,
+ LLVM::cconv::CConv cconv = LLVM::cconv::CConv::C)
: ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
- i32Func(i32Func) {}
+ i32Func(i32Func), cconv(cconv) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -104,6 +104,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
+ callOp.setCConv(cconv);
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -171,7 +172,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
- return LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ auto newFuncOp = LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ newFuncOp.setCConv(cconv);
+ return newFuncOp;
}
StringRef getFunctionName(Type type, SourceOp op) const {
@@ -202,6 +205,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32ApproxFunc;
const std::string f16Func;
const std::string i32Func;
+ const LLVM::cconv::CConv cconv;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
new file mode 100644
index 0000000000000..34279187b1c21
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_conversion_library(MLIRMathToLLVMSPV
+ MathToLLVMSPV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVMSPV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
new file mode 100644
index 0000000000000..82a32b3966981
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVMSPV/MathToLLVMSPV.cpp
@@ -0,0 +1,143 @@
+//===-- MathToLLVMSPV.cpp - conversion from Math to SPIR-V builtin calls --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVMSPV/MathToLLVMSPV.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOLLVMSPV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-llvm-spv"
+
+static bool isExtensionSetSupported(StringRef name) {
+ return name == "OpenCL.std";
+}
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ /*f32ApproxFunc=*/"", /*f16Func=*/"",
+ /*i32Func=*/"", benefit,
+ LLVM::cconv::CConv::SPIR_FUNC);
+}
+
+template <typename OpTy>
+static void populateOCLExtSetOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit,
+ StringRef opName) {
+ std::string mangledName =
+ "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str();
+ populateOpPatterns<OpTy>(converter, patterns, benefit, mangledName + "f",
+ mangledName + "d");
+}
+
+void mlir::populateMathToOCLExtSetLLVMSPVConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ populateOCLExtSetOpPatterns<math::AcosOp>(converter, patterns, benefit,
+ "acos");
+ populateOCLExtSetOpPatterns<math::AcoshOp>(converter, patterns, benefit,
+ "acosh");
+ populateOCLExtSetOpPatterns<math::AsinOp>(converter, patterns, benefit,
+ "asin");
+ populateOCLExtSetOpPatterns<math::AsinhOp>(converter, patterns, benefit,
+ "asinh");
+ populateOCLExtSetOpPatterns<math::AtanOp>(converter, patterns, benefit,
+ "atan");
+ populateOCLExtSetOpPatterns<math::Atan2Op>(converter, patterns, benefit,
+ "atan2");
+ populateOCLExtSetOpPatterns<math::AtanhOp>(converter, patterns, benefit,
+ "atanh");
+ populateOCLExtSetOpPatterns<math::CbrtOp>(converter, patterns, benefit,
+ "cbrt");
+ populateOCLExtSetOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+ "copysign");
+ populateOCLExtSetOpPatterns<math::CosOp>(converter, patterns, benefit, "cos");
+ populateOCLExtSetOpPatterns<math::CoshOp>(converter, patterns, benefit,
+ "cosh");
+ populateOCLExtSetOpPatterns<math::ErfOp>(converter, patterns, benefit, "erf");
+ populateOCLExtSetOpPatterns<math::ErfcOp>(converter, patterns, benefit,
+ "erfc");
+ populateOCLExtSetOpPatterns<math::ExpOp>(converter, patterns, benefit, "exp");
+ populateOCLExtSetOpPatterns<math::Exp2Op>(converter, patterns, benefit,
+ "exp2");
+ populateOCLExtSetOpPatterns<math::ExpM1Op>(converter, patterns, benefit,
+ "expm1");
+ populateOCLExtSetOpPatterns<math::LogOp>(converter, patterns, benefit, "log");
+ populateOCLExtSetOpPatterns<math::Log10Op>(converter, patterns, benefit,
+ "log10");
+ populateOCLExtSetOpPatterns<math::Log1pOp>(converter, patterns, benefit,
+ "log1p");
+ populateOCLExtSetOpPatterns<math::Log2Op>(converter, patterns, benefit,
+ "log2");
+ populateOCLExtSetOpPatterns<math::PowFOp>(converter, patterns, benefit,
+ "pow");
+ populateOCLExtSetOpPatterns<math::RsqrtOp>(converter, patterns, benefit,
+ "rsqrt");
+ populateOCLExtSetOpPatterns<math::SinOp>(converter, patterns, benefit, "sin");
+ populateOCLExtSetOpPatterns<math::SinhOp>(converter, patterns, benefit,
+ "sinh");
+ populateOCLExtSetOpPatterns<math::SqrtOp>(converter, patterns, benefit,
+ "sqrt");
+ populateOCLExtSetOpPatterns<math::TanOp>(converter, patterns, benefit, "tan");
+ populateOCLExtSetOpPatterns<math::TanhOp>(converter, patterns, benefit,
+ "tanh");
+}
+
+namespace {
+struct ConvertMathToLLVMSPVPass final
+ : impl::ConvertMathToLLVMSPVBase<ConvertMathToLLVMSPVPass> {
+ using impl::ConvertMathToLLVMSPVBase<
+ ConvertMathToLLVMSPVPass>::ConvertMathToLLVMSPVBase;
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLLVMSPVPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ if (!isExtensionSetSupported(extensionSetName)) {
+ m.emitError() << "Unsupported extension set '" << extensionSetName << "'!";
+ return signalPassFailure();
+ }
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ if (extensionSetName == "OpenCL.std") {
+ populateMathToOCLExtSetLLVMSPVConversionPatterns(converter, patterns,
+ /*benefit=*/1);
+ target
+ .addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::SinOp, LLVM::SqrtOp>();
+ }
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
new file mode 100644
index 0000000000000..d47927aa2b5ad
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVMSPV/math-to-llvm-spv.mlir
@@ -0,0 +1,414 @@
+// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm-spv='extension-set-name=OpenCL.std' -gpu-module-to-binary | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(f32, f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(f32, f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powf(f32, f32) -> f32
+ // CHECK-LABEL: func @math_bin_f32
+ func.func @math_bin_f32(%arg_f32_1 : f32, %arg_f32_2 : f32) -> (f32, f32, f32) {
+ %result1 = math.copysign %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result2 = math.atan2 %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2f(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result3 = math.powf %arg_f32_1, %arg_f32_2 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ func.return %result1, %result2, %result3 : f32, f32, f32
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(f64, f64) -> f64
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(f64, f64) -> f64
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_powd(f64, f64) -> f64
+ // CHECK-LABEL: func @math_bin_f64
+ func.func @math_bin_f64(%arg_f64_1 : f64, %arg_f64_2 : f64) -> (f64, f64, f64) {
+ %result1 = math.copysign %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_copysignd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ %result2 = math.atan2 %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atan2d(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ %result3 = math.powf %arg_f64_1, %arg_f64_2 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_powd(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result1, %result2, %result3 : f64, f64, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acosd(f64) -> f64
+ // CHECK-LABEL: func @math_acos
+ func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acosd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(f64) -> f64
+ // CHECK-LABEL: func @math_acosh
+ func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_acoshd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asind(f64) -> f64
+ // CHECK-LABEL: func @math_asin
+ func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asind(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(f64) -> f64
+ // CHECK-LABEL: func @math_asinh
+ func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_asinhd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atand(f64) -> f64
+ // CHECK-LABEL: func @math_atan
+ func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanf(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atand(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(f32) -> f32
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(f64) -> f64
+ // CHECK-LABEL: func @math_atanh
+ func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atanh %arg_f32 : f32
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.atanh %arg_f64 : f64
+ // CHECK: llvm.call spir_funccc @_Z{{.*}}__spirv_ocl_atanhd(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func spir_funccc @_Z{{.*}}__spirv_ocl_cbrtf(f32...
[truncated]
|
|
Moving to draft for now. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Does it? I can't spot it... |
|
@joker-eph The initial commit added a conversion pass. This morning I moved the patterns to MathToXeVM pass pipeline and restricted to OpenCL transforms. |
| std::string mangledName = | ||
| "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str(); |
There was a problem hiding this comment.
| std::string mangledName = | |
| "_Z" + std::to_string(12 + opName.size()) + "__spirv_ocl_" + opName.str(); | |
| std::string prefix = "__spirv_ocl_"; | |
| std::string mangledName = | |
| "_Z" + std::to_string(prefix.size() + opName.size()) + prefix + opName.str(); |
Just a nit to avoid "magic numbers"
There was a problem hiding this comment.
That's a great suggestion, thanks!
|
Thanks for clarifying: looks like these lowering are specific to XeVM runtime right? Seems right then. |
|
Yes, this is relevant for XeVM/SPIR-V targets. XeVM already lowers fastmath ops to native OpenCL intrinsics, so feels natural to move this to the existing pipeline. |
This PR adds a conversion pass with patterns to convert supported math ops to SPIR-V builtin calls for the chosen extended instruction set. For OpenCL these lowerings correspond to
OpExtInstcalls into the OpenCL SPIR-V extended instruction set via mangled__spirv_ocl_entry points for f32/f64 variants.