From 8af837573e9325d5f9f1e38bd0e4c510a58d37a1 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Fri, 19 Sep 2025 16:44:38 -0700
Subject: [PATCH 01/13] Initial mockup of the MathToXeVM pass

---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  26 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt |  24 +++
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 159 ++++++++++++++++++
 .../Conversion/MathToXeVM/math-to-xevm.mlir   |  22 +++
 .../MathToXeVM/native-spirv-builtins.mlir     |  33 ++++
 8 files changed, 283 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
 create mode 100644 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
 create mode 100644 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
 create mode 100644 mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..7982aa3769e84
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,26 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..ead4d5c50046d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -84,6 +84,7 @@
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..20e0b95cc5c78 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,23 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+  let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+  let description = [{
+    This pass converts supported Math ops to XeVM.
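+
+    For example (an illustrative sketch based on the tests in this patch),
+    `math.exp %x fastmath<fast> : f16` is rewritten into a call to the OpenCL
+    builtin `native_exp`, while the same op without fastmath flags is left
+    untouched.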
+ }]; + let dependentDialects = [ + "arith::ArithDialect", + "func::FuncDialect", + "xevm::XeVMDialect", + "vector::VectorDialect", + ]; +} + //===----------------------------------------------------------------------===// // MathToEmitC //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 71986f83c4870..bebf1b8fff3f9 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) +add_subdirectory(MathToXeVM) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt new file mode 100644 index 0000000000000..3f389359a6a2c --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt @@ -0,0 +1,24 @@ +// TODO check if everything here is needed +add_mlir_conversion_library(MLIRMathToXeVM + MathToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRFuncDialect + MLIRGPUToGPURuntimeTransforms + MLIRMathDialect + MLIRLLVMCommonConversion + MLIRPass + MLIRTransformUtils + MLIRVectorDialect + MLIRVectorUtils + ) diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp new file mode 100644 index 0000000000000..e18350219ffe8 --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp @@ -0,0 +1,159 @@ +//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "../GPUCommon/GPUOpsLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-xevm" + +// GPUCommon/OpToFunctionCallLowering is not used here, as it doesn't handle +// native functions/intrinsics that take vector operands. + +/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics. 
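+///
+/// As a sketch of the rewrite (taken from the math-to-xevm.mlir test in this
+/// patch; appendOrGetFuncOp below builds the Itanium-style mangled name, e.g.
+/// "_Z22" for the 22-character builtin name and "Dh" for f16):
+///
+///   %result = math.exp %arg fastmath<fast> : f16
+/// becomes:
+///   %result = llvm.call @_Z22__spirv_ocl_native_expDh(%arg) : (f16) -> f16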
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: OCL doesn't provide native int intrinsics, but check what happens
+    // when IGC receives a native_exp on ints anyway
+    // TODO: what about vectorization?
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+      return failure();
+
+    // FIXME: Implement handling for vector sizes/dimensions that are not
+    // supported by SPIRV
+    SmallVector<Type> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+        return failure();
+      operandTypes.push_back(operand.getType());
+    }
+    LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat()) {
+      return true;
+    } else if (auto vecType = dyn_cast<VectorType>(type)) {
+      if (!vecType.getElementType().isFloat())
+        return false;
+      // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+      ArrayRef<int64_t> shape = vecType.getShape();
+      if (shape.size() != 1)
+        return false;
+      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+        return true;
+    }
+    return false;
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type> &operandTypes) const {
+    // This function assumes op types have already been validated using
+    // isSPIRVCompatibleFloatOrVec.
+    using LLVM::LLVMFuncOp;
+
+    std::string mangledNativeFunc =
+        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+    auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+      if (type.isF32())
+        mangledNativeFunc += "f";
+      else if (type.isF16())
+        mangledNativeFunc += "Dh";
+      else if (type.isF64())
+        mangledNativeFunc += "d";
+    };
+
+    for (auto type : operandTypes) {
+      if (auto vecType = dyn_cast<VectorType>(type)) {
+        mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        appendFloatToMangledFunc(vecType.getElementType());
+      } else
+        appendFloatToMangledFunc(type);
+    }
+
+    auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+    auto funcOp =
+      SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+    if (funcOp)
+      return funcOp;
+
+    auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+    assert(parentFunc && "expected there to be a parent function");
+    OpBuilder b(parentFunc);
+
+    // Create a valid global location removing any metadata attached to the
+    // location as debug info metadata inside of a function cannot be used
+    // outside of that function.
+    auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+    auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+    return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+  }
+
+  const StringRef nativeFunc;
+};
+
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
+      patterns.getContext(), "__spirv_ocl_native_exp");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+  ConvertMathToXeVMPass() = default;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+  auto m = getOperation();
+  //MLIRContext *ctx = m.getContext();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToXeVMConversionPatterns(patterns);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         arith::ArithDialect, LLVM::LLVMDialect>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..436d0e0941b9e
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+
+module @test_module {
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  // CHECK-LABEL: func @math_ops
+  func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+    %result16 = math.exp %arg_f16 fastmath<fast> : f16
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+    %result64 = math.exp %arg_f64 fastmath<fast> : f64
+
+    // CHECK: math.exp
+    %result_no_fast = math.exp %arg_f64 : f64
+
+    // TODO check fastmath
+
+    func.return %result16, %result64 : f16, f64
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000000000..f762f4b60f818
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN:   -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+//
+// MathToXeVM pass generates OpenCL intrinsics function calls when converting
+// Math ops with `fastmath` attr to native function calls. It is assumed that
+// the SPIRV backend would correctly convert these intrinsics calls to OpenCL
+// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+  gpu.module @kernel [#xevm.target] {
+    llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+      // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+      // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+      %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+      // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+      // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+      %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+      // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+      %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+      // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+      %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+      llvm.return
+    }
+    llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+    llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  }
+}

From 59160d682d89a019178cf1e74433e29a5e3a75ad Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Mon, 22 Sep 2025 08:10:29 -0700
Subject: [PATCH 02/13] clang-format

---
 mlir/include/mlir/Conversion/Passes.h         |  2 +-
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 28 +++++++++++--------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ead4d5c50046d..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
@@ -84,7 +85,6 @@
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 
 namespace mlir {
 
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e18350219ffe8..e1f1205d6efaa 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -38,8 +38,9 @@ using namespace mlir;
 template <typename Op>
 struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
-  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
-      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
 
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
@@ -51,7 +52,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
       return failure();
 
     // FIXME: Implement handling for vector sizes/dimensions that are not
@@ -63,7 +64,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       operandTypes.push_back(operand.getType());
     }
     LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+                                              adaptor.getOperands());
     return success();
   }
 
@@ -79,13 +81,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       if (shape.size() != 1)
         return false;
       // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
- if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16) + if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || + shape[0] == 16) return true; } return false; } - LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector &operandTypes) const { + LLVM::LLVMFuncOp + appendOrGetFuncOp(Op &op, const SmallVector &operandTypes) const { // This function assumes op types have already been validated using // isSPIRVCompatibleFloatOrVec. using LLVM::LLVMFuncOp; @@ -112,7 +116,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc); auto funcOp = - SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (funcOp) return funcOp; @@ -124,17 +128,17 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { // location as debug info metadata inside of a function cannot be used // outside of that function. auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes); - auto globalloc = op->getLoc()->template findInstanceOfOrUnknown(); + auto globalloc = + op->getLoc()->template findInstanceOfOrUnknown(); return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType); } const StringRef nativeFunc; }; - void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) { - patterns.add>( - patterns.getContext(), "__spirv_ocl_native_exp"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_exp"); } namespace { @@ -147,7 +151,7 @@ struct ConvertMathToXeVMPass void ConvertMathToXeVMPass::runOnOperation() { auto m = getOperation(); - //MLIRContext *ctx = m.getContext(); + // MLIRContext *ctx = m.getContext(); RewritePatternSet patterns(&getContext()); populateMathToXeVMConversionPatterns(patterns); From 89f94eadf57b3aeb6fb416ed48f1fc7ba1cf1476 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Mon, 22 Sep 2025 08:18:45 -0700 Subject: [PATCH 03/13] fix cmake --- mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt index 3f389359a6a2c..711c6876bb168 100644 --- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt @@ -1,4 +1,4 @@ -// TODO check if everything here is needed +# TODO check if everything here is needed add_mlir_conversion_library(MLIRMathToXeVM MathToXeVM.cpp From bf4ed92c054490aa46d1273c0020c8345ce5c320 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Tue, 23 Sep 2025 16:03:43 -0700 Subject: [PATCH 04/13] update tests --- .../Conversion/MathToXeVM/math-to-xevm.mlir | 73 +++++++++++++++++-- .../MathToXeVM/native-spirv-builtins.mlir | 40 ++++++++++ 2 files changed, 107 insertions(+), 6 deletions(-) diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir index 436d0e0941b9e..00719e9b881d2 100644 --- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir +++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir @@ -2,21 +2,82 @@ module @test_module { // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 + // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 + + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64> + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64> + // CHECK: llvm.func 
@_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64> + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64> + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64> // CHECK-LABEL: func @math_ops - func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + func.func @math_ops() { + + %c1_f16 = arith.constant 1. : f16 + %c1_f32 = arith.constant 1. : f32 + %c1_f64 = arith.constant 1. : f64 + + // CHECK: math.exp + %res_normal_f16 = math.exp %c1_f16 : f16 + // CHECK: math.exp + %res_normal_f32 = math.exp %c1_f32 : f32 + // CHECK: math.exp + %res_normal_f64 = math.exp %c1_f64 : f64 // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16 - %result16 = math.exp %arg_f16 fastmath : f16 + %res_fast_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32 + %res_fast_f32 = math.exp %c1_f32 fastmath : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64 + %res_fast_f64 = math.exp %c1_f64 fastmath : f64 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16 + %res_afn_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32 + %res_afn_f32 = math.exp %c1_f32 fastmath : f32 // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64 - %result64 = math.exp %arg_f64 fastmath : f64 + %res_afn_f64 = math.exp %c1_f64 fastmath : f64 // CHECK: math.exp - %result_no_fast = math.exp %arg_f64 : f64 + %res_none_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: math.exp + %res_none_f32 = math.exp %c1_f32 fastmath : f32 + // CHECK: math.exp + %res_none_f64 = math.exp %c1_f64 fastmath : f64 + + %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64> + %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64> + %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64> + %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64> + %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64> + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64> + %res_v2_f64 = math.exp %v2_c1_f64 fastmath : vector<2xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64> + %res_v3_f64 = math.exp %v3_c1_f64 fastmath : vector<3xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64> + %res_v4_f64 = math.exp %v4_c1_f64 fastmath : vector<4xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64> + %res_v8_f64 = math.exp %v8_c1_f64 fastmath : vector<8xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64> + %res_v16_f64 = math.exp %v16_c1_f64 fastmath : vector<16xf64> + + %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32> + %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16> - // TODO check fastmath + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32> + %res_v16_f32 = math.exp %v16_c1_f32 fastmath : vector<16xf32> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16> + %res_v4_f16 = math.exp %v4_c1_f16 fastmath : vector<4xf16> + + %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64> + %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64> + + // CHECK: math.exp + %res_v5_f64 = math.exp %v5_c1_f64 fastmath : vector<5xf64> + // CHECK: math.exp + 
%res_v32_f64 = math.exp %v32_c1_f64 fastmath : vector<32xf64> - func.return %result16, %result64 : f16, f64 + return } } \ No newline at end of file diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir index f762f4b60f818..92744c9e165da 100644 --- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir +++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir @@ -16,18 +16,58 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16 // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]] %c0_f16 = llvm.mlir.constant(0. : f16) : f16 + // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32 + // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]] + %c0_f32 = llvm.mlir.constant(0. : f32) : f32 // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64 // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]] %c0_f64 = llvm.mlir.constant(0. : f64) : f64 + // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2 + // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]] + %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64> + // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3 + // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]] + %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32> + // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4 + // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]] + %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64> + // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8 + // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]] + %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64> + // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16 + // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]] + %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16> + // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]] %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16 + // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]] + %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32 // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]] %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64 + // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]] + %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64> + // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]] + %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32> + // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]] + %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64> + // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]] + %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64> + // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]] + %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16> + llvm.return } llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 + llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 llvm.func @_Z22__spirv_ocl_native_expd(f64) -> 
f64 + llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64> + llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32> + llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64> + llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64> + llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16> + + } } From d7ae5f2a35f3d089fdc997b7374c473c227b9182 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Wed, 24 Sep 2025 16:40:30 -0700 Subject: [PATCH 05/13] Add support for all other native opts --- mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 29 +++- .../Conversion/MathToXeVM/math-to-xevm.mlir | 136 +++++++++++++----- .../MathToXeVM/native-spirv-builtins.mlir | 13 ++ 3 files changed, 139 insertions(+), 39 deletions(-) diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp index e1f1205d6efaa..055cfdf064e4e 100644 --- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp +++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/MathToXeVM/MathToXeVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" @@ -47,7 +48,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // TODO: OCL doesn't provide native int intrinsics, but check what happens // when IGC receives a native_exp on ints anyway - // TODO: what about vectorization? if (!isSPIRVCompatibleFloatOrVec(op.getType())) return failure(); @@ -56,7 +56,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { return failure(); // FIXME: Implement handling for vector sizes/dimensions that are not - // supported by SPIRV + // supported by SPIRV. 
SmallVector operandTypes; for (auto operand : adaptor.getOperands()) { if (!isSPIRVCompatibleFloatOrVec(operand.getType())) @@ -64,8 +64,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { operandTypes.push_back(operand.getType()); } LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes); - rewriter.replaceOpWithNewOp(op, funcOp, + auto callOp = rewriter.replaceOpWithNewOp(op, funcOp, adaptor.getOperands()); + arith::AttrConvertFastMathToLLVM fastAttrConverter(op); + mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0]; + callOp->setAttr(fastAttr.getName(), fastAttr.getValue()); return success(); } @@ -139,6 +142,26 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) { patterns.add>(patterns.getContext(), "__spirv_ocl_native_exp"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_cos"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_exp2"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_log"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_log2"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_log10"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_powr"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_rsqrt"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_sin"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_sqrt"); + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_tan"); } namespace { diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir index 00719e9b881d2..8e1d20dc94d78 100644 --- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir +++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir @@ -4,12 +4,26 @@ module @test_module { // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 - + // // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64> // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64> // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64> // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64> // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64> + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32> + // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16> + // + // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16 + // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32 + // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16 + // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32 + // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64 + // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16 + // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64 + // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 + // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 + // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 + // CHECK-LABEL: func @math_ops func.func @math_ops() { @@ -18,32 +32,36 @@ module @test_module { %c1_f64 = arith.constant 1. 
: f64 // CHECK: math.exp - %res_normal_f16 = math.exp %c1_f16 : f16 + %exp_normal_f16 = math.exp %c1_f16 : f16 // CHECK: math.exp - %res_normal_f32 = math.exp %c1_f32 : f32 + %exp_normal_f32 = math.exp %c1_f32 : f32 // CHECK: math.exp - %res_normal_f64 = math.exp %c1_f64 : f64 - - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16 - %res_fast_f16 = math.exp %c1_f16 fastmath : f16 - // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32 - %res_fast_f32 = math.exp %c1_f32 fastmath : f32 - // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64 - %res_fast_f64 = math.exp %c1_f64 fastmath : f64 + %exp_normal_f64 = math.exp %c1_f64 : f64 + + // Check float operations are converted properly: + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %exp_fast_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp_fast_f32 = math.exp %c1_f32 fastmath : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + %exp_fast_f64 = math.exp %c1_f64 fastmath : f64 - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16 - %res_afn_f16 = math.exp %c1_f16 fastmath : f16 - // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32 - %res_afn_f32 = math.exp %c1_f32 fastmath : f32 - // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64 - %res_afn_f64 = math.exp %c1_f64 fastmath : f64 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %exp_afn_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp_afn_f32 = math.exp %c1_f32 fastmath : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + %exp_afn_f64 = math.exp %c1_f64 fastmath : f64 // CHECK: math.exp - %res_none_f16 = math.exp %c1_f16 fastmath : f16 + %exp_none_f16 = math.exp %c1_f16 fastmath : f16 // CHECK: math.exp - %res_none_f32 = math.exp %c1_f32 fastmath : f32 + %exp_none_f32 = math.exp %c1_f32 fastmath : f32 // CHECK: math.exp - %res_none_f64 = math.exp %c1_f64 fastmath : f64 + %exp_none_f64 = math.exp %c1_f64 fastmath : f64 + + // Check vector operations: %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64> %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64> @@ -51,32 +69,78 @@ module @test_module { %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64> %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64> - %res_v2_f64 = math.exp %v2_c1_f64 fastmath : vector<2xf64> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64> - %res_v3_f64 = math.exp %v3_c1_f64 fastmath : vector<3xf64> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64> - %res_v4_f64 = math.exp %v4_c1_f64 fastmath : vector<4xf64> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64> - %res_v8_f64 = math.exp %v8_c1_f64 fastmath : vector<8xf64> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64> - %res_v16_f64 = math.exp %v16_c1_f64 fastmath : vector<16xf64> + // CHECK: llvm.call 
@_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<2xf64>) -> vector<2xf64> + %exp_v2_f64 = math.exp %v2_c1_f64 fastmath : vector<2xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<3xf64>) -> vector<3xf64> + %exp_v3_f64 = math.exp %v3_c1_f64 fastmath : vector<3xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<4xf64>) -> vector<4xf64> + %exp_v4_f64 = math.exp %v4_c1_f64 fastmath : vector<4xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<8xf64>) -> vector<8xf64> + %exp_v8_f64 = math.exp %v8_c1_f64 fastmath : vector<8xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<16xf64>) -> vector<16xf64> + %exp_v16_f64 = math.exp %v16_c1_f64 fastmath : vector<16xf64> %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32> %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32> - %res_v16_f32 = math.exp %v16_c1_f32 fastmath : vector<16xf32> - // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16> - %res_v4_f16 = math.exp %v4_c1_f16 fastmath : vector<4xf16> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<16xf32>) -> vector<16xf32> + %exp_v16_f32 = math.exp %v16_c1_f32 fastmath : vector<16xf32> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (vector<4xf16>) -> vector<4xf16> + %exp_v4_f16 = math.exp %v4_c1_f16 fastmath : vector<4xf16> + + // Check unsupported vector sizes are not converted: %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64> %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64> // CHECK: math.exp - %res_v5_f64 = math.exp %v5_c1_f64 fastmath : vector<5xf64> + %exp_v5_f64 = math.exp %v5_c1_f64 fastmath : vector<5xf64> // CHECK: math.exp - %res_v32_f64 = math.exp %v32_c1_f64 fastmath : vector<32xf64> + %exp_v32_f64 = math.exp %v32_c1_f64 fastmath : vector<32xf64> + + // Check fastmath flags propagate properly: + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath : f32 + + // Check all other math operations: + + // native_divide(gentype x, gentype y) + // TODO: convert arith.divf to arith/native_divide if option is enabled + + // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %cos_afn_f16 = math.cos %c1_f16 fastmath : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %exp2_afn_f32 = math.exp2 %c1_f32 fastmath : f32 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %log_afn_f16 = math.log %c1_f16 fastmath : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %log2_afn_f32 
= math.log2 %c1_f32 fastmath : f32 + + // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + %log10_afn_f64 = math.log10 %c1_f64 fastmath : f64 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16, f16) -> f16 + %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath : f16 + + // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath : f64 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + %sin_afn_f16 = math.sin %c1_f16 fastmath : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath : f32 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + %tan_afn_f64 = math.tan %c1_f64 fastmath : f64 return } diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir index 92744c9e165da..b83288c7ec99e 100644 --- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir +++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir @@ -57,6 +57,19 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]] %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16> + + // SPIRV backend does not currently handle fastmath flags: The SPIRV + // backend would need to generate OpDecorate calls to decorate math ops + // with FPFastMathMode/FPFastMathModeINTEL decorations. + // + // FIXME: When support for fastmath flags in the SPIRV backend is added, + // add tests here to ensure fastmath flags are converted to the correct + // OpDecorate calls. 
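+      //
+      // As a purely illustrative sketch (not output the backend produces
+      // today), such a decoration might look like:
+      //   OpDecorate %result FPFastMathMode NotNaN|NotInf|AllowRecip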
+ // + // See: + // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions + // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate + llvm.return } llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 From 9d3a540934345628510df63b1ac69312d76cea8a Mon Sep 17 00:00:00 2001 From: Ian Li Date: Thu, 25 Sep 2025 12:16:24 -0700 Subject: [PATCH 06/13] finish testing for native-spirv-builtins --- .../MathToXeVM/native-spirv-builtins.mlir | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir index b83288c7ec99e..6bc90e34060b4 100644 --- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir +++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir @@ -39,25 +39,24 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]] %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16> - // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]] + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]] %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16 - // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]] + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]] %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32 - // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]] + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]] %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64 - // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]] + // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]] %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64> - // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]] + // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]] %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32> - // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]] + // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]] %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64> - // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]] + // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]] %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64> - // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]] + // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]] %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16> - // SPIRV backend does not currently handle fastmath flags: The SPIRV // backend would need to generate OpDecorate calls to decorate math ops // with FPFastMathMode/FPFastMathModeINTEL decorations. 
@@ -70,8 +69,30 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]] + %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]] + %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]] + %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]] + %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]] + %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath} : (vector<8xf64>) -> vector<8xf64> + // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]] + %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16> + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]] + %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]] + %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]] + %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]] + %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + llvm.return } + llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 @@ -80,7 +101,15 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64> llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64> llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16> - - + llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32 + llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32 + llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64> + llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16> + llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64 + llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 + llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 } } From 0232c265a210231da53081548392247ab27017f3 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Thu, 25 Sep 2025 14:59:40 -0700 Subject: [PATCH 07/13] Accomodate for arith.divf --- .../mlir/Conversion/MathToXeVM/MathToXeVM.h | 2 +- 
 mlir/include/mlir/Conversion/Passes.td        | 13 +++++++++++--
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 19 ++++++++++---------
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 13 ++++++++++++-
 .../MathToXeVM/native-spirv-builtins.mlir     |  5 ++++-
 5 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 7982aa3769e84..6bb69361dcb6d 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,7 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 20e0b95cc5c78..976e1b6b183e1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,10 +801,19 @@ def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+  let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
   let description = [{
-    This pass converts supported Math ops to XeVM.
+    This pass converts supported math ops marked with the `afn` fastmath flag
+    to function calls for OpenCL `native_` math intrinsics: These intrinsics
+    are typically mapped directly to native device instructions, often resulting
+    in better performance. However, the precision/error of these intrinsics
+    are implementation-defined, and thus math ops are only converted when they
+    have the `afn` fastmath flag enabled.
   }];
+  let options = [
+    Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
+           "Convert supported Arith ops (e.g. arith.divf) as well.">
+  ];
   let dependentDialects = [
     "arith::ArithDialect",
     "func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 055cfdf064e4e..b75f8d3640a41 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -46,8 +46,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: OCL doesn't provide native int intrinsics, but check what happens
-    // when IGC receives a native_exp on ints anyway
     if (!isSPIRVCompatibleFloatOrVec(op.getType()))
       return failure();
 
@@ -55,10 +53,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
       return failure();
 
-    // FIXME: Implement handling for vector sizes/dimensions that are not
-    // supported by SPIRV.
     SmallVector<Type> operandTypes;
     for (auto operand : adaptor.getOperands()) {
+      // This pass only supports operations on vectors that are already in SPIRV
+      // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+      // supported vector sizes is done in other blocking optimization passes.
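+      // (For example, `vector<5xf64>` and `vector<32xf64>` operands are left
+      // unconverted here; the math-to-xevm.mlir test exercises these cases.)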
if (!isSPIRVCompatibleFloatOrVec(operand.getType())) return failure(); operandTypes.push_back(operand.getType()); @@ -128,7 +127,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { OpBuilder b(parentFunc); // Create a valid global location removing any metadata attached to the - // location as debug info metadata inside of a function cannot be used + // location, as debug info metadata inside of a function cannot be used // outside of that function. auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes); auto globalloc = @@ -139,7 +138,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern { const StringRef nativeFunc; }; -void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) { +void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) { patterns.add>(patterns.getContext(), "__spirv_ocl_native_exp"); patterns.add>(patterns.getContext(), @@ -162,22 +161,24 @@ void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) { "__spirv_ocl_native_sqrt"); patterns.add>(patterns.getContext(), "__spirv_ocl_native_tan"); + if (convertArith) + patterns.add>(patterns.getContext(), + "__spirv_ocl_native_divide"); } namespace { struct ConvertMathToXeVMPass : public impl::ConvertMathToXeVMBase { - ConvertMathToXeVMPass() = default; + using Base::Base; void runOnOperation() override; }; } // namespace void ConvertMathToXeVMPass::runOnOperation() { auto m = getOperation(); - // MLIRContext *ctx = m.getContext(); RewritePatternSet patterns(&getContext()); - populateMathToXeVMConversionPatterns(patterns); + populateMathToXeVMConversionPatterns(patterns, convertArith); ConversionTarget target(getContext()); target.addLegalDialect(); diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir index 8e1d20dc94d78..ba5de228da411 100644 --- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir +++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s +// RUN: mlir-opt %s -convert-math-to-xevm \ +// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH' +// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \ +// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH' module @test_module { // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 @@ -23,6 +26,7 @@ module @test_module { // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 + // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32 // CHECK-LABEL: func @math_ops func.func @math_ops() { @@ -142,6 +146,13 @@ module @test_module { // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 %tan_afn_f64 = math.tan %c1_f64 fastmath : f64 + %c6_9_f32 = arith.constant 6.9 : f32 + %c7_f32 = arith.constant 7. 
: f32 + + // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + // CHECK-NO-ARITH: arith.divf + %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath : f32 + return } } \ No newline at end of file diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir index 6bc90e34060b4..2492adafd6a50 100644 --- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir +++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir @@ -89,10 +89,12 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]] %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]] + %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 llvm.return } - + llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 @@ -111,5 +113,6 @@ module @test_ocl_intrinsics attributes {gpu.container_module} { llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 + llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32 } } From 3887fe5fe68d0f3caa4d1c2a850684ae97543140 Mon Sep 17 00:00:00 2001 From: Ian Li Date: Thu, 25 Sep 2025 15:12:54 -0700 Subject: [PATCH 08/13] clang-format --- .../mlir/Conversion/MathToXeVM/MathToXeVM.h | 3 +- mlir/include/mlir/Conversion/Passes.td | 10 ++--- mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 37 ++++++++++--------- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h index 6bb69361dcb6d..91d3c92fd6296 100644 --- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h +++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h @@ -20,7 +20,8 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" /// Populate the given list with patterns that convert from Math to XeVM calls. 
From 3887fe5fe68d0f3caa4d1c2a850684ae97543140 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Thu, 25 Sep 2025 15:12:54 -0700
Subject: [PATCH 08/13] clang-format

---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  3 +-
 mlir/include/mlir/Conversion/Passes.td        | 10 ++---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 37 ++++++++++---------
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 6bb69361dcb6d..91d3c92fd6296 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                          bool convertArith);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 976e1b6b183e1..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,7 +801,8 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+  let summary =
+      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
   let description = [{
     This pass converts supported math ops marked with the `afn` fastmath flag
     to function calls for OpenCL `native_` math intrinsics: These intrinsics
@@ -810,10 +811,9 @@ def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
     are implementation-defined, and thus math ops are only converted when they
     have the `afn` fastmath flag enabled.
   }];
-  let options = [
-    Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
-           "Convert supported Arith ops (e.g. arith.divf) as well.">
-  ];
+  let options = [Option<
+      "convertArith", "convert-arith", "bool", /*default=*/"true",
+      "Convert supported Arith ops (e.g. arith.divf) as well.">];
   let dependentDialects = [
     "arith::ArithDialect",
     "func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index b75f8d3640a41..46833735a79dd 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -63,8 +63,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       operandTypes.push_back(operand.getType());
     }
     LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
-    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
-                                                            adaptor.getOperands());
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -138,32 +138,33 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
   const StringRef nativeFunc;
 };
 
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) {
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                                bool convertArith) {
   patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_exp");
   patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_cos");
-  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(patterns.getContext(),
-                                                       "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_exp2");
   patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_log");
-  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(patterns.getContext(),
-                                                       "__spirv_ocl_native_log2");
-  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(patterns.getContext(),
-                                                        "__spirv_ocl_native_log10");
-  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(patterns.getContext(),
-                                                       "__spirv_ocl_native_powr");
-  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(patterns.getContext(),
-                                                        "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+      patterns.getContext(), "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_rsqrt");
   patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_sin");
-  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(patterns.getContext(),
-                                                       "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_sqrt");
   patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_tan");
   if (convertArith)
-    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(patterns.getContext(),
-                                                          "__spirv_ocl_native_divide");
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+        patterns.getContext(), "__spirv_ocl_native_divide");
 }
 
 namespace {

From 31c911e12f99d15879d1b33cccb6a5c1c622a54c Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Fri, 26 Sep 2025 14:19:38 -0700
Subject: [PATCH 09/13] remove todos

---
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt     | 5 -----
 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir | 3 ---
 2 files changed, 8 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 711c6876bb168..95aaba31a993e 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,3 @@
-# TODO check if everything here is needed
 add_mlir_conversion_library(MLIRMathToXeVM
   MathToXeVM.cpp
 
@@ -12,13 +11,9 @@ add_mlir_conversion_library(MLIRMathToXeVM
     Core
 
   LINK_LIBS PUBLIC
-  MLIRDialectUtils
-  MLIRFuncDialect
-  MLIRGPUToGPURuntimeTransforms
   MLIRMathDialect
   MLIRLLVMCommonConversion
   MLIRPass
   MLIRTransformUtils
   MLIRVectorDialect
-  MLIRVectorUtils
   )
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index ba5de228da411..e1d3b2615e121 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -113,9 +113,6 @@ module @test_module {
 
     // Check all other math operations:
 
-    // native_divide(gentype x, gentype y)
-    // TODO: convert arith.divf to arith/native_divide if option is enabled
-
     // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
     %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16

From 20ac59571ab178cd669028a9dd0a1b8f23df598b Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Wed, 1 Oct 2025 12:25:59 -0700
Subject: [PATCH 10/13] Address reviewer comments

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp     | 59 ++++++++-----------
 .../Conversion/MathToXeVM/math-to-xevm.mlir       | 42 ++++++-------
 2 files changed, 44 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 46833735a79dd..0c1f9d39e72a2 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -50,21 +51,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
-      return failure();
+    if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
 
     SmallVector<Type> operandTypes;
     for (auto operand : adaptor.getOperands()) {
       // This pass only supports operations on vectors that are already in SPIRV
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
-      // supported vetor sizes are done in other blocking optimization passes.
+      // supported vector sizes is done in other blocking optimization passes.
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
-        return failure();
+        return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
       operandTypes.push_back(operand.getType());
     }
-    LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+
+    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+    auto funcOpRes =
+        LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+    assert(!failed(funcOpRes));
+    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
     auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, funcOp, adaptor.getOperands());
+    // Preserve the fastmath flags in our MLIR op for later use: We need to
+    // convert our MLIR fastmath attrs into something compatible with llvm.
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -90,49 +99,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     return false;
   }
 
-  LLVM::LLVMFuncOp
-  appendOrGetFuncOp(Op &op, const SmallVector<Type> &operandTypes) const {
-    // This function assumes op types have already been validated using
-    // isSPIRVCompatibleFloatOrVec.
-    using LLVM::LLVMFuncOp;
-    std::string mangledNativeFunc =
+  inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+    std::string mangledFuncName =
         "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
 
-    auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
       if (type.isF32())
-        mangledNativeFunc += "f";
+        mangledFuncName += "f";
       else if (type.isF16())
-        mangledNativeFunc += "Dh";
+        mangledFuncName += "Dh";
       else if (type.isF64())
-        mangledNativeFunc += "d";
+        mangledFuncName += "d";
     };
 
     for (auto type : operandTypes) {
      if (auto vecType = dyn_cast<VectorType>(type)) {
-        mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
         appendFloatToMangledFunc(vecType.getElementType());
       } else
         appendFloatToMangledFunc(type);
     }
 
-    auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
-    auto funcOp =
-        SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
-    if (funcOp)
-      return funcOp;
-
-    auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
-    assert(parentFunc && "expected there to be a parent function");
-    OpBuilder b(parentFunc);
-
-    // Create a valid global location removing any metadata attached to the
-    // location, as debug info metadata inside of a function cannot be used
-    // outside of that function.
-    auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
-    auto globalloc =
-        op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
-    return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+    return mangledFuncName;
   }
 
   const StringRef nativeFunc;
@@ -176,13 +165,11 @@ struct ConvertMathToXeVMPass
 } // namespace
 
 void ConvertMathToXeVMPass::runOnOperation() {
-  auto m = getOperation();
-
   RewritePatternSet patterns(&getContext());
   populateMathToXeVMConversionPatterns(patterns, convertArith);
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect, LLVM::LLVMDialect>();
-  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
     signalPassFailure();
 }
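Replacing the hand-rolled appendOrGetFuncOp with LLVM::lookupOrCreateFn above hands symbol lookup and declaration creation to the LLVM dialect helper, which also removes the need for the location-metadata workaround in the old code. A reduced sketch of the helper's call shape; the surrounding function and variable names are illustrative, not from the patch:

    // Sketch: declare-or-reuse an llvm.func for native_exp(f32) in the
    // enclosing symbol table. `b` and `op` are stand-ins.
    #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
    #include "mlir/Dialect/LLVMIR/LLVMDialect.h"

    static mlir::FailureOr<mlir::LLVM::LLVMFuncOp>
    declareNativeExpF32(mlir::OpBuilder &b, mlir::Operation *op,
                        mlir::Type f32Ty) {
      // Walk up to the nearest symbol table (here, the module), then reuse
      // an existing declaration with this name or insert a new one.
      mlir::Operation *symTable =
          op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
      return mlir::LLVM::lookupOrCreateFn(b, symTable,
                                          "_Z22__spirv_ocl_native_expf",
                                          {f32Ty}, /*resultType=*/f32Ty);
    }

The switch from bare failure() to rewriter.notifyMatchFailure in the same hunk is a debuggability improvement: the message is surfaced in dialect-conversion debug output instead of the pattern failing silently.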
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index e1d3b2615e121..04b5906489d00 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -4,29 +4,29 @@
 // RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
 
 module @test_module {
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
   //
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
   //
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-  // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
-  // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-  // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+  // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
 
   // CHECK-LABEL: func @math_ops
   func.func @math_ops() {

From 7b8d0297f8c973d55c445b243295204497ef0d46 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Wed, 1 Oct 2025 12:27:58 -0700
Subject: [PATCH 11/13] clang-format

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 0c1f9d39e72a2..825a7bb79242a 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -51,7 +51,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+    if (!(static_cast<uint32_t>(fastFlags) &
+          static_cast<uint32_t>(arith::FastMathFlags::afn)))
       return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
 
     SmallVector<Type> operandTypes;
@@ -60,13 +61,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
       // supported vector sizes is done in other blocking optimization passes.
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
-        return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
+        return rewriter.notifyMatchFailure(
+            op, "no equivalent native operation for operand type");
       operandTypes.push_back(operand.getType());
     }
 
     auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
-    auto funcOpRes =
-        LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+    auto funcOpRes = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+        operandTypes, op.getType());
     assert(!failed(funcOpRes));
     LLVM::LLVMFuncOp funcOp = funcOpRes.value();
 
@@ -99,8 +102,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     return false;
   }
 
-
-  inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+  inline std::string
+  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
     std::string mangledFuncName =
         "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
 
@@ -170,6 +173,7 @@ void ConvertMathToXeVMPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect, LLVM::LLVMDialect>();
-  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
     signalPassFailure();
 }
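One aside on the `afn` test that clang-format just reflowed: the raw integer casts can also be spelled with the helpers generated alongside MLIR's bit enums. A sketch of the equivalent check, assuming the generated bitEnumContainsAll helper for arith's FastMathFlags; it is not what this patch series uses:

    // Alternative spelling of the afn check using the generated bit-enum
    // helper. Sketch only.
    #include "mlir/Dialect/Arith/IR/Arith.h"

    static bool hasAfn(mlir::arith::FastMathFlags flags) {
      return mlir::arith::bitEnumContainsAll(flags,
                                             mlir::arith::FastMathFlags::afn);
    }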
From 17ad71c766c02550b934d6ca9fa817e7787713c8 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Wed, 1 Oct 2025 12:42:48 -0700
Subject: [PATCH 12/13] improve comment

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 825a7bb79242a..bbac0af19f0b4 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -75,8 +75,9 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
     auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, funcOp, adaptor.getOperands());
-    // Preserve the fastmath flags in our MLIR op for later use: We need to
-    // convert our MLIR fastmath attrs into something compatible with llvm.
+    // Preserve fastmath flags in our MLIR op when converting to llvm function
+    // calls, in order to allow further fastmath optimizations: We thus need to
+    // convert arith fastmath attrs into attrs recognized by llvm.
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());

From 203c1f08a1c2e54c3ae20ce3c18c6fe71e910239 Mon Sep 17 00:00:00 2001
From: Ian Li
Date: Thu, 2 Oct 2025 08:54:38 -0700
Subject: [PATCH 13/13] Improve logging

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index bbac0af19f0b4..156b9a38d07eb 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"
 #include "../GPUCommon/OpToFuncCallLowering.h"
@@ -57,13 +58,14 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
     SmallVector<Type> operandTypes;
     for (auto operand : adaptor.getOperands()) {
+      Type opTy = operand.getType();
       // This pass only supports operations on vectors that are already in SPIRV
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
       // supported vector sizes is done in other blocking optimization passes.
-      if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+      if (!isSPIRVCompatibleFloatOrVec(opTy))
         return rewriter.notifyMatchFailure(
-            op, "no equivalent native operation for operand type");
-      operandTypes.push_back(operand.getType());
+            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+      operandTypes.push_back(opTy);
     }
 
     auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
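The llvm::formatv switch makes the failure message self-describing: formatv falls back to a type's raw_ostream printer when no explicit format provider exists, which is how an mlir::Type can be interpolated directly. A minimal standalone sketch of the call shape, with a plain string standing in for the type argument:

    // Standalone sketch of the formatv-based diagnostic used above.
    #include "llvm/Support/FormatVariadic.h"
    #include <string>

    int main() {
      // {0} indexes the first argument; the result converts to std::string.
      std::string msg =
          llvm::formatv("incompatible operand type: '{0}'", "vector<5xf32>");
      return msg == "incompatible operand type: 'vector<5xf32>'" ? 0 : 1;
    }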