diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 475e3f495e065..e46b576810316 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -88,8 +88,6 @@ class ScalarType : public SPIRVType { static bool isValid(FloatType); /// Returns true if the given float type is valid for the SPIR-V dialect. static bool isValid(IntegerType); - - std::optional getSizeInBytes(); }; // SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V @@ -112,8 +110,6 @@ class CompositeType : public SPIRVType { /// Return true if the number of elements is known at compile time and is not /// implementation dependent. bool hasCompileTimeKnownNumElements() const; - - std::optional getSizeInBytes(); }; // SPIR-V array type @@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase getSizeInBytes(); }; // SPIR-V image type diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 7c2f43bea9ddb..4733bfca8be21 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/ErrorHandling.h" #include +#include using namespace mlir; using namespace mlir::spirv; @@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; } unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } -std::optional ArrayType::getSizeInBytes() { - auto elementType = llvm::cast(getElementType()); - std::optional size = elementType.getSizeInBytes(); - if (!size) - return std::nullopt; - return (*size + getArrayStride()) * getNumElements(); -} - //===----------------------------------------------------------------------===// // CompositeType //===----------------------------------------------------------------------===// @@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) { } } -std::optional CompositeType::getSizeInBytes() { - if (auto arrayType = llvm::dyn_cast(*this)) - return arrayType.getSizeInBytes(); - if (auto structType = llvm::dyn_cast(*this)) - return structType.getSizeInBytes(); - if (auto vectorType = llvm::dyn_cast(*this)) { - std::optional elementSize = - llvm::cast(vectorType.getElementType()).getSizeInBytes(); - if (!elementSize) - return std::nullopt; - return *elementSize * vectorType.getNumElements(); - } - if (auto tensorArmType = llvm::dyn_cast(*this)) { - std::optional elementSize = - llvm::cast(tensorArmType.getElementType()).getSizeInBytes(); - if (!elementSize) - return std::nullopt; - return *elementSize * tensorArmType.getNumElements(); - } - return std::nullopt; -} - //===----------------------------------------------------------------------===// // CooperativeMatrixType //===----------------------------------------------------------------------===// @@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) { #undef WIDTH_CASE } -std::optional ScalarType::getSizeInBytes() { - auto bitWidth = getIntOrFloatBitWidth(); - // According to the SPIR-V spec: - // "There is no physical size or bit pattern defined for values with boolean - // type. If they are stored (in conjunction with OpVariable), they can only - // be used with logical addressing operations, not physical, and only with - // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, - // Private, Function, Input, and Output." - if (bitWidth == 1) - return std::nullopt; - return bitWidth / 8; -} - //===----------------------------------------------------------------------===// // SPIRVType //===----------------------------------------------------------------------===// @@ -760,11 +718,35 @@ void SPIRVType::getCapabilities( } std::optional SPIRVType::getSizeInBytes() { - if (auto scalarType = llvm::dyn_cast(*this)) - return scalarType.getSizeInBytes(); - if (auto compositeType = llvm::dyn_cast(*this)) - return compositeType.getSizeInBytes(); - return std::nullopt; + return TypeSwitch>(*this) + .Case([](ScalarType type) -> std::optional { + // According to the SPIR-V spec: + // "There is no physical size or bit pattern defined for values with + // boolean type. If they are stored (in conjunction with OpVariable), + // they can only be used with logical addressing operations, not + // physical, and only with non-externally visible shader Storage + // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and + // Output." + int64_t bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 1) + return std::nullopt; + return bitWidth / 8; + }) + .Case([](ArrayType type) -> std::optional { + // Since array type may have an explicit stride declaration (in bytes), + // we also include it in the calculation. + auto elementType = cast(type.getElementType()); + if (std::optional size = elementType.getSizeInBytes()) + return (*size + type.getArrayStride()) * type.getNumElements(); + return std::nullopt; + }) + .Case([](auto type) -> std::optional { + if (std::optional elementSize = + cast(type.getElementType()).getSizeInBytes()) + return *elementSize * type.getNumElements(); + return std::nullopt; + }) + .Default(std::optional()); } //===----------------------------------------------------------------------===//