diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 6ad631f1eec7a..fa8788de6c3d2 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -170,11 +170,6 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); } -// TODO: Move to some common place? -static std::string getDecorationString(spirv::Decoration decor) { - return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor)); -} - namespace { /// Converts elementwise unary, binary and ternary arith operations to SPIR-V diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h index 7425f4b5311ce..2ea54baaf8953 100644 --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -9,14 +9,27 @@ #ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H #define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" +#include namespace mlir { namespace spirv { +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +/// Converts a SPIR-V Decoration enum value to its snake_case string +/// representation for use in MLIR attributes. +inline std::string getDecorationString(spirv::Decoration decor) { + return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor)); +} + /// Converts elementwise unary, binary and ternary standard operations to SPIR-V /// operations. template diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index 399ccf3925f3a..2491c7cbd3d22 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "../SPIRVCommon/Pattern.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" @@ -45,14 +46,12 @@ static constexpr const char kSPIRVModule[] = "__spv__"; /// Returns the string name of the `DescriptorSet` decoration. static std::string descriptorSetName() { - return llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); + return spirv::getDecorationString(spirv::Decoration::DescriptorSet); } /// Returns the string name of the `Binding` decoration. static std::string bindingName() { - return llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::Binding)); + return spirv::getDecorationString(spirv::Decoration::Binding); } /// Calculates the index of the kernel's operand that is represented by the