diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..91d3c92fd6296
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,27 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                          bool convertArith);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,32 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+  let summary =
+      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+  let description = [{
+    This pass converts supported math ops marked with the `afn` fastmath flag
+    into function calls to OpenCL `native_` math intrinsics. These intrinsics
+    are typically mapped directly to native device instructions, often
+    resulting in better performance. However, the precision/error of these
+    intrinsics is implementation-defined, so math ops are only converted when
+    they have the `afn` fastmath flag enabled.
+  }];
+  let options = [Option<
+      "convertArith", "convert-arith", "bool", /*default=*/"true",
+      "Convert supported Arith ops (e.g. `arith.divf`) as well.">];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "xevm::XeVMDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToEmitC
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..95aaba31a993e
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+  MathToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..156b9a38d07eb
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,182 @@
+//===-- MathToXeVM.cpp - Conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFuncCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Convert math ops marked with `afn` to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!(static_cast<uint32_t>(fastFlags) &
+          static_cast<uint32_t>(arith::FastMathFlags::afn)))
+      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+    SmallVector<Type> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      Type opTy = operand.getType();
+      // This pass only supports operations on vectors that already have
+      // SPIRV-supported sizes: distributing unsupported vector sizes to
+      // SPIRV-supported ones is done in other blocking optimization passes.
+      if (!isSPIRVCompatibleFloatOrVec(opTy))
+        return rewriter.notifyMatchFailure(
+            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+      operandTypes.push_back(opTy);
+    }
+
+    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+    auto funcOpRes = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+        operandTypes, op.getType());
+    assert(!failed(funcOpRes));
+    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
+    // Preserve fastmath flags in our MLIR op when converting to llvm function
+    // calls, in order to allow further fastmath optimizations: we thus need to
+    // convert arith fastmath attrs into attrs recognized by llvm.
+    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat()) {
+      return true;
+    } else if (auto vecType = dyn_cast<VectorType>(type)) {
+      if (!vecType.getElementType().isFloat())
+        return false;
+      // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+      ArrayRef<int64_t> shape = vecType.getShape();
+      if (shape.size() != 1)
+        return false;
+      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+          shape[0] == 16)
+        return true;
+    }
+    return false;
+  }
+
+  inline std::string
+  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+    std::string mangledFuncName =
+        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+      if (type.isF32())
+        mangledFuncName += "f";
+      else if (type.isF16())
+        mangledFuncName += "Dh";
+      else if (type.isF64())
+        mangledFuncName += "d";
+    };
+
+    for (auto type : operandTypes) {
+      if (auto vecType = dyn_cast<VectorType>(type)) {
+        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        appendFloatToMangledFunc(vecType.getElementType());
+      } else
+        appendFloatToMangledFunc(type);
+    }
+
+    return mangledFuncName;
+  }
+
+  const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                                bool convertArith) {
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_exp");
+  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_cos");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+      patterns.getContext(), "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_sin");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_tan");
+  if (convertArith)
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+        patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+  using Base::Base;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateMathToXeVMConversionPatterns(patterns, convertArith);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect, arith::ArithDialect,
+                         LLVM::LLVMDialect>();
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..04b5906489d00
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH'
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
+
+module @test_module {
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  //
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+  //
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+  // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+
+  // CHECK-LABEL: func @math_ops
+  func.func @math_ops() {
+
+    %c1_f16 = arith.constant 1. : f16
+    %c1_f32 = arith.constant 1. : f32
+    %c1_f64 = arith.constant 1. : f64
+
+    // CHECK: math.exp
+    %exp_normal_f16 = math.exp %c1_f16 : f16
+    // CHECK: math.exp
+    %exp_normal_f32 = math.exp %c1_f32 : f32
+    // CHECK: math.exp
+    %exp_normal_f64 = math.exp %c1_f64 : f64
+
+    // Check float operations are converted properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+    %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+    %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+
+    // CHECK: math.exp
+    %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+    // CHECK: math.exp
+    %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+    // CHECK: math.exp
+    %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+    // Check vector operations:
+
+    %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+    %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+    %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+    %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+    %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+    %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+    %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+    %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+    %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+    %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+    %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+    %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf32>) -> vector<16xf32>
+    %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<afn> : vector<16xf32>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf16>) -> vector<4xf16>
+    %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<afn> : vector<4xf16>
+
+    // Check unsupported vector sizes are not converted:
+
+    %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+    %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+    // CHECK: math.exp
+    %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+    // CHECK: math.exp
+    %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+    // Check fastmath flags propagate properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan,ninf,nsz,arcp,contract,afn>} : (f32) -> f32
+    %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<reassoc,nnan,afn>} : (f32) -> f32
+    %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<reassoc,nnan,afn> : f32
+
+    // Check all other math operations:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
+    %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
+
+    %c6_9_f32 = arith.constant 6.9 : f32
+    %c7_f32 = arith.constant 7. : f32
+
+    // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+    // CHECK-NO-ARITH: arith.divf
+    %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
+
+    return
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000000000..2492adafd6a50
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN:   -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+//
+// The MathToXeVM pass generates calls to OpenCL intrinsic functions when
+// converting math ops with a `fastmath` attr to native function calls. It is
+// assumed that the SPIRV backend correctly converts these intrinsic calls to
+// OpenCL ExtInst instructions in SPIRV (see llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+  gpu.module @kernel [#xevm.target] {
+    llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+      // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+      // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+      %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+      // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
+      // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
+      %c0_f32 = llvm.mlir.constant(0. : f32) : f32
+      // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+      // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+      %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+      // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
+      // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
+      %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
+      // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
+      // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
+      %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
+      // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
+      // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
+      %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
+      // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
+      // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
+      %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
+      // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
+      // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
+      %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>
+
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+      %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+      %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+      %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+      // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+      %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
+      // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+      %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
+      // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+      %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
+      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+      %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
+      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+      %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
+
+      // The SPIRV backend does not currently handle fastmath flags: the
+      // backend would need to generate OpDecorate calls to decorate math ops
+      // with FPFastMathMode/FPFastMathModeINTEL decorations.
+      //
+      // FIXME: When support for fastmath flags in the SPIRV backend is added,
+      // add tests here to ensure fastmath flags are converted to the correct
+      // OpDecorate calls.
+      //
+      // See:
+      // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
+      // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
+
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
+      %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
+      %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
+      %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
+      %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
+      %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
+      %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
+      %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
+      %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
+      %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
+      %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
+      %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+
+      llvm.return
+    }
+
+    llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+    llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+    llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32>
+    llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
+    llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+    llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+    llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
+    llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+    llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+    llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+    llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+    llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+  }
+}
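
Illustration (not part of the patch; a sketch based on the `math.exp`/f32 cases
exercised in math-to-xevm.mlir above): given an op carrying the `afn` flag,

  %y = math.exp %x fastmath<afn> : f32

the pass emits a declaration for the mangled OpenCL native intrinsic and
replaces the op with a direct call, preserving the fastmath flags on the call:

  llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
  ...
  %y = llvm.call @_Z22__spirv_ocl_native_expf(%x) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32

Ops without `afn` (e.g. `fastmath<none>`, or no fastmath attr at all) are left
unchanged by the partial conversion.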