
Conversation

ianayl (Contributor) commented Sep 19, 2025

This PR introduces a MathToXeVM pass, which implements support for the afn fastmath flag for SPIRV/XeVM targets: it takes supported math ops carrying the afn flag and converts them into function calls to OpenCL native_ intrinsics.

These intrinsic functions are supported by the SPIRV backend and are automatically converted into OpExtInst calls to native_ ops from the OpenCL SPIRV extended instruction set when emitting SPIRV/XeVM.

Note:

  • This pass also supports converting arith.divf to its native equivalent; the convert-arith pass option can be used to turn this behavior off.
  • This pass preserves fastmath flags on the generated calls, but these flags are currently ignored by the SPIRV backend. Generating SPIRV that truly preserves fastmath flags therefore requires additional support in the SPIRV backend.
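
For illustration, a minimal before/after sketch based on the tests added in this PR (%x is a placeholder f32 value):

    // Input: a math op carrying the afn fastmath flag.
    %exp = math.exp %x fastmath<afn> : f32

    // After convert-math-to-xevm: a declaration of the mangled native intrinsic
    // is created at module scope, and the op becomes a call that keeps the
    // fastmath flags.
    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
    ...
    %exp = llvm.call @_Z22__spirv_ocl_native_expf(%x) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32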

github-actions bot commented Sep 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

mshahneo (Contributor) commented:

@silee2, @Jianhui-Li, @charithaintc

ianayl (Contributor, Author) commented Sep 22, 2025

Another thing I'm realizing: I should probably emit these calls so that existing attributes are preserved, especially the existing fastmath flags. I'm opting to preserve the fastmath flags in order to retain more detail: compilers that ingest SPIRV can use them for further optimization.
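
As a sketch of the intended behavior, matching the tests later in this PR (value names simplified), the full flag set carries over onto the generated call:

    %y = math.exp %x fastmath<afn,reassoc,nnan> : f32
    // becomes:
    %y = llvm.call @_Z22__spirv_ocl_native_expf(%x) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32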

@ianayl ianayl marked this pull request as ready for review September 25, 2025 22:30
@llvmbot llvmbot added the mlir label Sep 25, 2025
llvmbot (Member) commented Sep 25, 2025

@llvm/pr-subscribers-mlir

Author: Ian Li (ianayl)

Changes

This PR introduces a MathToXeVM pass, which implements support for the afn fastmath flag for SPIRV/XeVM targets: it takes supported math ops carrying the afn flag and converts them into function calls to OpenCL native_ intrinsics.

These intrinsic functions are supported by the SPIRV backend and are automatically converted into OpExtInst calls to native_ ops from the OpenCL SPIRV extended instruction set when emitting SPIRV/XeVM.

Note:

  • This pass also supports converting arith.divf to its native equivalent; the convert-arith pass option can be used to turn this behavior off.
  • This pass preserves fastmath flags on the generated calls, but these flags are currently ignored by the SPIRV backend. Generating SPIRV that truly preserves fastmath flags therefore requires additional support in the SPIRV backend.

Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159878.diff

8 Files Affected:

  • (added) mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h (+27)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+26)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/MathToXeVM/CMakeLists.txt (+24)
  • (added) mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp (+188)
  • (added) mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir (+158)
  • (added) mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir (+118)
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..91d3c92fd6296
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,27 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                          bool convertArith);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,32 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+  let summary =
+      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+  let description = [{
+    This pass converts supported math ops marked with the `afn` fastmath flag
+    to function calls for OpenCL `native_` math intrinsics: These intrinsics
+    are typically mapped directly to native device instructions, often resulting
+    in better performance. However, the precision/error of these intrinsics
+    are implementation-defined, and thus math ops are only converted when they
+    have the `afn` fastmath flag enabled.
+  }];
+  let options = [Option<
+      "convertArith", "convert-arith", "bool", /*default=*/"true",
+      "Convert supported Arith ops (e.g. arith.divf) as well.">];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "xevm::XeVMDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToEmitC
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..711c6876bb168
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+# TODO check if everything here is needed
+add_mlir_conversion_library(MLIRMathToXeVM
+  MathToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..46833735a79dd
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,188 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFuncCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
+      return failure();
+
+    SmallVector<Type, 1> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      // This pass only supports operations on vectors that already have
+      // SPIRV-supported vector sizes: distributing unsupported vector sizes
+      // into SPIRV-supported sizes is done in other blocking passes.
+      if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+        return failure();
+      operandTypes.push_back(operand.getType());
+    }
+    LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
+    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat()) {
+      return true;
+    } else if (auto vecType = dyn_cast<VectorType>(type)) {
+      if (!vecType.getElementType().isFloat())
+        return false;
+      // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+      ArrayRef<int64_t> shape = vecType.getShape();
+      if (shape.size() != 1)
+        return false;
+      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+          shape[0] == 16)
+        return true;
+    }
+    return false;
+  }
+
+  LLVM::LLVMFuncOp
+  appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+    // This function assumes op types have already been validated using
+    // isSPIRVCompatibleFloatOrVec.
+    using LLVM::LLVMFuncOp;
+
+    std::string mangledNativeFunc =
+        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+    auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+      if (type.isF32())
+        mangledNativeFunc += "f";
+      else if (type.isF16())
+        mangledNativeFunc += "Dh";
+      else if (type.isF64())
+        mangledNativeFunc += "d";
+    };
+
+    for (auto type : operandTypes) {
+      if (auto vecType = dyn_cast<VectorType>(type)) {
+        mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        appendFloatToMangledFunc(vecType.getElementType());
+      } else
+        appendFloatToMangledFunc(type);
+    }
+
+    auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+    auto funcOp =
+        SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+    if (funcOp)
+      return funcOp;
+
+    auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+    assert(parentFunc && "expected there to be a parent function");
+    OpBuilder b(parentFunc);
+
+    // Create a valid global location removing any metadata attached to the
+    // location, as debug info metadata inside of a function cannot be used
+    // outside of that function.
+    auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+    auto globalloc =
+        op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+    return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+  }
+
+  const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                                bool convertArith) {
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_exp");
+  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_cos");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+      patterns.getContext(), "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_sin");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_tan");
+  if (convertArith)
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+        patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+  using Base::Base;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+  auto m = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToXeVMConversionPatterns(patterns, convertArith);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..ba5de228da411
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH' 
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
+
+module @test_module {
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  //
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+  //
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+  // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+  // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+  // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+
+  // CHECK-LABEL: func @math_ops
+  func.func @math_ops() {
+
+    %c1_f16 = arith.constant 1. : f16
+    %c1_f32 = arith.constant 1. : f32
+    %c1_f64 = arith.constant 1. : f64
+
+    // CHECK: math.exp
+    %exp_normal_f16 = math.exp %c1_f16 : f16
+    // CHECK: math.exp
+    %exp_normal_f32 = math.exp %c1_f32 : f32
+    // CHECK: math.exp
+    %exp_normal_f64 = math.exp %c1_f64 : f64
+
+    // Check float operations are converted properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+    %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+    %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+    
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+
+    // CHECK: math.exp
+    %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+    // CHECK: math.exp
+    %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+    // CHECK: math.exp
+    %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+    // Check vector operations:
+
+    %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+    %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+    %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+    %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+    %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+    %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+    %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+    %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+    %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+    %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+    %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+    %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
+    %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
+    %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+    // Check unsupported vector sizes are not converted:
+
+    %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+    %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+    // CHECK: math.exp
+    %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+    // CHECK: math.exp
+    %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+    // Check fastmath flags propagate properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
+    %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
+    %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
+
+    // Check all other math operations:
+
+    // native_divide(gentype x, gentype y)
+    // TODO: convert arith.divf to arith/native_divide if option is enabled
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+    // CHECK: l...
[truncated]

mshahneo (Contributor) commented:

@akroviakov

akroviakov (Contributor) left a comment:

LGTM with a few comments

// supported vector sizes: Distributing unsupported vector sizes to SPIRV
// supported vetor sizes are done in other blocking optimization passes.
if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
return failure();
akroviakov (Contributor):

Same as above for notifyMatchFailure.

ianayl (Contributor, Author):

Fixed, let me know if the failure reasons are detailed enough

akroviakov (Contributor):

It would be helpful to print the actual type in the failure reason to accelerate debugging, something like llvm::formatv("incompatible operand type: '{0}'", srcType)

ianayl (Contributor, Author):

I didn't know that was a thing, done!

@mshahneo mshahneo changed the title [MLIR][SPIRV][XeVM] Add support for fastmath afn option using native OpenCL intrinsics [MLIR][SPIRV][XeVM] Add MathToXeVM (math-to-xevm) pass Oct 1, 2025
@mshahneo mshahneo requested a review from jpienaar October 1, 2025 19:57
ianayl (Contributor, Author) commented Oct 2, 2025

The CI failure seems to be an infrastructure issue; I don't seem to have permissions to retrigger the CI. If needed, I'll make a dummy commit to restart it.

  | # Removing /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci
  | 🚨 Error: Failed to remove "/scratch/powerllvm/cpap8006/llvm-project/libcxx-ci" (unlinkat /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci/.ci/all_requirements.txt: permission denied)
  | # Waiting 10 seconds
  | # Removing /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci
  | 🚨 Error: Failed to remove "/scratch/powerllvm/cpap8006/llvm-project/libcxx-ci" (unlinkat /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci/.ci/all_requirements.txt: permission denied)
  | # Waiting 10 seconds
  | # Removing /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci
  | 🚨 Error: Failed to remove "/scratch/powerllvm/cpap8006/llvm-project/libcxx-ci" (unlinkat /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci/.ci/all_requirements.txt: permission denied)
  | # Waiting 10 seconds
  | # Removing /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci
  | 🚨 Error: Failed to remove "/scratch/powerllvm/cpap8006/llvm-project/libcxx-ci" (unlinkat /scratch/powerllvm/cpap8006/llvm-project/libcxx-ci/.ci/all_requirements.txt: permission denied)
  | # Waiting 10 seconds
  | 🚨 Error: exit status 128
