diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 26d8f1401c32eb..b6715dc9fcd7ad 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -52,16 +52,6 @@ def SPIRV_Dialect : Dialect { let hasRegionResultAttrVerify = 1; let extraClassDeclaration = [{ - //===------------------------------------------------------------------===// - // Type - //===------------------------------------------------------------------===// - - /// Checks if the given `type` is valid in SPIR-V dialect. - static bool isValidType(Type type); - - /// Checks if the given `scalar type` is valid in SPIR-V dialect. - static bool isValidScalarType(Type type); - //===------------------------------------------------------------------===// // Attribute //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 385e79a0445eb4..85b35f73f82c5e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -78,6 +78,8 @@ class SPIRVType : public Type { static bool classof(Type type); + bool isScalarOrVector(); + /// The extension requirements for each type are following the /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) /// convention. @@ -109,6 +111,11 @@ class ScalarType : public SPIRVType { static bool classof(Type type); + /// Returns true if the given integer type is valid for the SPIR-V dialect. + static bool isValid(FloatType); + /// Returns true if the given float type is valid for the SPIR-V dialect. + static bool isValid(IntegerType); + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, @@ -122,6 +129,9 @@ class CompositeType : public SPIRVType { static bool classof(Type type); + /// Returns true if the given vector type is valid for the SPIR-V dialect. + static bool isValid(VectorType); + unsigned getNumElements() const; Type getElementType(unsigned) const; diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp index 44930b91e0ffd4..d4ce17c93706d8 100644 --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -59,7 +59,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType, Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, VulkanLayoutUtils::Size &alignment) { - if (spirv::SPIRVDialect::isValidScalarType(type)) { + if (type.isa()) { alignment = VulkanLayoutUtils::getScalarTypeAlignment(type); // Vulkan spec does not specify any padding for a scalar type. size = alignment; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp index f378047f36eaca..953d95b449d153 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/Functional.h" @@ -358,15 +359,6 @@ struct ConvertSelectionOpToSelect rhs.getOperation()->getAttrList().getDictionary(); } - // Checks that given type is valid for `spv.SelectOp`. - // According to SPIR-V spec: - // "Before version 1.4, Result Type must be a pointer, scalar, or vector. - // Starting with version 1.4, Result Type can additionally be a composite type - // other than a vector." - bool isValidType(Type type) const { - return spirv::SPIRVDialect::isValidScalarType(type) || - type.isa(); - } // Returns a source value for the given block. Value getSrcValue(Block *block) const { @@ -401,11 +393,20 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( return failure(); } + // Checks that given type is valid for `spv.SelectOp`. + // According to SPIR-V spec: + // "Before version 1.4, Result Type must be a pointer, scalar, or vector. + // Starting with version 1.4, Result Type can additionally be a composite type + // other than a vector." + bool isScalarOrVector = trueBrStoreOp.value() + .getType() + .cast() + .isScalarOrVector(); + // Check that each `spv.Store` uses the same pointer, memory access // attributes and a valid type of the value. if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || - !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || - !isValidType(trueBrStoreOp.value().getType())) { + !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) { return failure(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index f2868a34f07652..8ed417cad58d86 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -152,42 +152,6 @@ template <> Optional parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser); -static bool isValidSPIRVIntType(IntegerType type) { - return llvm::is_contained(ArrayRef({1, 8, 16, 32, 64}), - type.getWidth()); -} - -bool SPIRVDialect::isValidScalarType(Type type) { - if (type.isa()) { - return !type.isBF16(); - } - if (auto intType = type.dyn_cast()) { - return isValidSPIRVIntType(intType); - } - return false; -} - -static bool isValidSPIRVVectorType(VectorType type) { - return type.getRank() == 1 && - SPIRVDialect::isValidScalarType(type.getElementType()) && - type.getNumElements() >= 2 && type.getNumElements() <= 4; -} - -bool SPIRVDialect::isValidType(Type type) { - // Allow SPIR-V dialect types - if (type.getKind() >= Type::FIRST_SPIRV_TYPE && - type.getKind() <= TypeKind::LAST_SPIRV_TYPE) { - return true; - } - if (SPIRVDialect::isValidScalarType(type)) { - return true; - } - if (auto vectorType = type.dyn_cast()) { - return isValidSPIRVVectorType(vectorType); - } - return false; -} - static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser) { Type type; @@ -206,7 +170,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, return Type(); } } else if (auto t = type.dyn_cast()) { - if (!isValidSPIRVIntType(t)) { + if (!ScalarType::isValid(t)) { parser.emitError(typeLoc, "only 1/8/16/32/64-bit integer type allowed but found ") << type; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index e5b630b82fb198..5451048aabe005 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -99,7 +99,7 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. static Optional getTypeNumBytes(Type t) { - if (spirv::SPIRVDialect::isValidScalarType(t)) { + if (t.isa()) { auto bitWidth = t.getIntOrFloatBitWidth(); // According to the SPIR-V spec: // "There is no physical size or bit pattern defined for values with boolean @@ -163,7 +163,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) : targetEnv(targetAttr) { addConversion([](Type type) -> Optional { // If the type is already valid in SPIR-V, directly return. - return spirv::SPIRVDialect::isValidType(type) ? type : Optional(); + return type.isa() ? type : Optional(); }); addConversion([](IndexType indexType) { return SPIRVTypeConverter::getIndexType(indexType.getContext()); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 377242482b2a7d..f6b862156c49ef 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1373,7 +1373,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) { bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. - if (!SPIRVDialect::isValidType(type)) + if (!type.isa()) return false; if (type.getKind() >= Type::FIRST_SPIRV_TYPE && @@ -2460,7 +2460,7 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) { case StandardAttributes::Integer: case StandardAttributes::Float: { // Make sure bitwidth is allowed. - if (!spirv::SPIRVDialect::isValidType(value.getType())) + if (!value.getType().isa()) return constOp.emitOpError("default value bitwidth disallowed"); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 92dc5b82bb8af8..3f963bd1d8a87b 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -163,13 +163,19 @@ bool CompositeType::classof(Type type) { case TypeKind::Array: case TypeKind::RuntimeArray: case TypeKind::Struct: - case StandardTypes::Vector: return true; + case StandardTypes::Vector: + return isValid(type.cast()); default: return false; } } +bool CompositeType::isValid(VectorType type) { + return type.getRank() == 1 && type.getElementType().isa() && + type.getNumElements() >= 2 && type.getNumElements() <= 4; +} + Type CompositeType::getElementType(unsigned index) const { switch (getKind()) { case spirv::TypeKind::Array: @@ -560,7 +566,30 @@ void RuntimeArrayType::getCapabilities( // ScalarType //===----------------------------------------------------------------------===// -bool ScalarType::classof(Type type) { return type.isIntOrFloat(); } +bool ScalarType::classof(Type type) { + if (auto floatType = type.dyn_cast()) { + return isValid(floatType); + } + if (auto intType = type.dyn_cast()) { + return isValid(intType); + } + return false; +} + +bool ScalarType::isValid(FloatType type) { return !type.isBF16(); } + +bool ScalarType::isValid(IntegerType type) { + switch (type.getWidth()) { + case 1: + case 8: + case 16: + case 32: + case 64: + return true; + default: + return false; + } +} void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { @@ -678,9 +707,19 @@ void ScalarType::getCapabilities( //===----------------------------------------------------------------------===// bool SPIRVType::classof(Type type) { - return type.isa() || type.isa() || - (type.getKind() >= Type::FIRST_SPIRV_TYPE && - type.getKind() <= TypeKind::LAST_SPIRV_TYPE); + // Allow SPIR-V dialect types + if (type.getKind() >= Type::FIRST_SPIRV_TYPE && + type.getKind() <= TypeKind::LAST_SPIRV_TYPE) + return true; + if (type.isa()) + return true; + if (auto vectorType = type.dyn_cast()) + return CompositeType::isValid(vectorType); + return false; +} + +bool SPIRVType::isScalarOrVector() { + return isIntOrFloat() || isa(); } void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 516b9eca854403..1ca9cad977af08 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -21,12 +21,6 @@ using namespace mlir; -/// Checks if the `type` is a scalar or vector type. It is assumed that they are -/// valid for SPIR-V dialect already. -static bool isScalarOrVectorType(Type type) { - return spirv::SPIRVDialect::isValidScalarType(type) || type.isa(); -} - /// Creates a global variable for an argument based on the ABI info. static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, @@ -45,7 +39,7 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, // info create a variable of type !spv.ptr>. If not // it must already be a !spv.ptr>. auto varType = funcOp.getType().getInput(argIndex); - if (isScalarOrVectorType(varType)) { + if (varType.cast().isScalarOrVector()) { auto storageClass = static_cast(abiInfo.storage_class().getInt()); varType = @@ -198,7 +192,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( // at the start of the function. It is probably better to do the load just // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. - if (isScalarOrVectorType(argType.value())) { + if (argType.value().cast().isScalarOrVector()) { auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext()); auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);