-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][SPIRV][XeVM] Add MathToXeVM (math-to-xevm
) pass
#159878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ianayl
wants to merge
13
commits into
llvm:main
Choose a base branch
from
ianayl:spirv-fastmath-nativeops
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
8af8375
Initial mockup of the MathToXeVM pass
ianayl 59160d6
clang-format
ianayl 89f94ea
fix cmake
ianayl bf4ed92
update tests
ianayl d7ae5f2
Add support for all other native opts
ianayl 9d3a540
finish testing for native-spirv-builtins
ianayl 0232c26
Accomodate for arith.divf
ianayl 3887fe5
clang-format
ianayl 31c911e
remove todos
ianayl 20ac595
Address reviewer comments
ianayl 7b8d029
clang-format
ianayl 17ad71c
improve comment
ianayl 203c1f0
Improve logging
ianayl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ | ||
#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ | ||
|
||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/LLVMIR/XeVMDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include <memory> | ||
|
||
namespace mlir { | ||
class Pass; | ||
|
||
#define GEN_PASS_DECL_CONVERTMATHTOXEVM | ||
#include "mlir/Conversion/Passes.h.inc" | ||
|
||
/// Populate the given list with patterns that convert from Math to XeVM calls. | ||
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, | ||
bool convertArith); | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
add_mlir_conversion_library(MLIRMathToXeVM | ||
MathToXeVM.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRMathDialect | ||
MLIRLLVMCommonConversion | ||
MLIRPass | ||
MLIRTransformUtils | ||
MLIRVectorDialect | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" | ||
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" | ||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" | ||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "llvm/Support/FormatVariadic.h" | ||
|
||
#include "../GPUCommon/GPUOpsLowering.h" | ||
#include "../GPUCommon/OpToFuncCallLowering.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_CONVERTMATHTOXEVM | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
|
||
#define DEBUG_TYPE "math-to-xevm" | ||
|
||
// GPUCommon/OpToFunctionCallLowering is not used here, as it doesn't handle | ||
// native functions/intrinsics that take vector operands. | ||
|
||
/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics. | ||
template <typename Op> | ||
struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> { | ||
|
||
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, | ||
PatternBenefit benefit = 1) | ||
: OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(Op op, typename Op::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
if (!isSPIRVCompatibleFloatOrVec(op.getType())) | ||
return failure(); | ||
|
||
arith::FastMathFlags fastFlags = op.getFastmath(); | ||
if (!(static_cast<uint32_t>(fastFlags) & | ||
static_cast<uint32_t>(arith::FastMathFlags::afn))) | ||
return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation"); | ||
|
||
SmallVector<Type, 1> operandTypes; | ||
for (auto operand : adaptor.getOperands()) { | ||
Type opTy = operand.getType(); | ||
// This pass only supports operations on vectors that are already in SPIRV | ||
// supported vector sizes: Distributing unsupported vector sizes to SPIRV | ||
// supported vector sizes are done in other blocking optimization passes. | ||
if (!isSPIRVCompatibleFloatOrVec(opTy)) | ||
return rewriter.notifyMatchFailure( | ||
op, llvm::formatv("incompatible operand type: '{0}'", opTy)); | ||
operandTypes.push_back(opTy); | ||
} | ||
|
||
auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>(); | ||
auto funcOpRes = LLVM::lookupOrCreateFn( | ||
rewriter, moduleOp, getMangledNativeFuncName(operandTypes), | ||
operandTypes, op.getType()); | ||
assert(!failed(funcOpRes)); | ||
LLVM::LLVMFuncOp funcOp = funcOpRes.value(); | ||
|
||
auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( | ||
op, funcOp, adaptor.getOperands()); | ||
// Preserve fastmath flags in our MLIR op when converting to llvm function | ||
// calls, in order to allow further fastmath optimizations: We thus need to | ||
// convert arith fastmath attrs into attrs recognized by llvm. | ||
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op); | ||
ianayl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0]; | ||
callOp->setAttr(fastAttr.getName(), fastAttr.getValue()); | ||
return success(); | ||
} | ||
|
||
inline bool isSPIRVCompatibleFloatOrVec(Type type) const { | ||
if (type.isFloat()) { | ||
return true; | ||
} else if (auto vecType = dyn_cast<VectorType>(type)) { | ||
if (!vecType.getElementType().isFloat()) | ||
return false; | ||
// SPIRV distinguishes between vectors and matrices: OpenCL native math | ||
// intrsinics are not compatible with matrices. | ||
ArrayRef<int64_t> shape = vecType.getShape(); | ||
if (shape.size() != 1) | ||
return false; | ||
// SPIRV only allows vectors of size 2, 3, 4, 8, 16. | ||
if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || | ||
shape[0] == 16) | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
inline std::string | ||
getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const { | ||
std::string mangledFuncName = | ||
"_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str(); | ||
|
||
auto appendFloatToMangledFunc = [&mangledFuncName](Type type) { | ||
if (type.isF32()) | ||
mangledFuncName += "f"; | ||
else if (type.isF16()) | ||
mangledFuncName += "Dh"; | ||
else if (type.isF64()) | ||
mangledFuncName += "d"; | ||
}; | ||
|
||
for (auto type : operandTypes) { | ||
if (auto vecType = dyn_cast<VectorType>(type)) { | ||
mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_"; | ||
appendFloatToMangledFunc(vecType.getElementType()); | ||
} else | ||
appendFloatToMangledFunc(type); | ||
} | ||
|
||
return mangledFuncName; | ||
} | ||
|
||
const StringRef nativeFunc; | ||
}; | ||
|
||
void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, | ||
bool convertArith) { | ||
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(), | ||
"__spirv_ocl_native_exp"); | ||
patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(), | ||
"__spirv_ocl_native_cos"); | ||
patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>( | ||
patterns.getContext(), "__spirv_ocl_native_exp2"); | ||
patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(), | ||
"__spirv_ocl_native_log"); | ||
patterns.add<ConvertNativeFuncPattern<math::Log2Op>>( | ||
patterns.getContext(), "__spirv_ocl_native_log2"); | ||
patterns.add<ConvertNativeFuncPattern<math::Log10Op>>( | ||
patterns.getContext(), "__spirv_ocl_native_log10"); | ||
patterns.add<ConvertNativeFuncPattern<math::PowFOp>>( | ||
patterns.getContext(), "__spirv_ocl_native_powr"); | ||
patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>( | ||
patterns.getContext(), "__spirv_ocl_native_rsqrt"); | ||
patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(), | ||
"__spirv_ocl_native_sin"); | ||
patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>( | ||
patterns.getContext(), "__spirv_ocl_native_sqrt"); | ||
patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(), | ||
"__spirv_ocl_native_tan"); | ||
if (convertArith) | ||
patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>( | ||
patterns.getContext(), "__spirv_ocl_native_divide"); | ||
} | ||
|
||
namespace { | ||
struct ConvertMathToXeVMPass | ||
: public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> { | ||
using Base::Base; | ||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
void ConvertMathToXeVMPass::runOnOperation() { | ||
RewritePatternSet patterns(&getContext()); | ||
populateMathToXeVMConversionPatterns(patterns, convertArith); | ||
ConversionTarget target(getContext()); | ||
target.addLegalDialect<BuiltinDialect, func::FuncDialect, | ||
vector::VectorDialect, LLVM::LLVMDialect>(); | ||
if (failed( | ||
applyPartialConversion(getOperation(), target, std::move(patterns)))) | ||
signalPassFailure(); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.