-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Rework type size calculation #160162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Similar to `::getExtensions` and `::getCapabilities`, introduce a single entry point for type size calculation. Also fix potential infinite recursion with `StructType`s (even non-recursive structs), although I don't know to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Jakub Kuderski (kuhar) ChangesSimilar to Also fix potential infinite recursion with Full diff: https://github.com/llvm/llvm-project/pull/160162.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 475e3f495e065..e46b576810316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
static bool isValid(FloatType);
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
-
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
@@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
/// Return true if the number of elements is known at compile time and is not
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
-
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V array type
@@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
-
- /// Returns the array size in bytes. Since array type may have an explicit
- /// stride declaration (in bytes), we also include it in the calculation.
- std::optional<int64_t> getSizeInBytes();
};
// SPIR-V image type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7c2f43bea9ddb..5ed7652987859 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
+#include <optional>
using namespace mlir;
using namespace mlir::spirv;
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
-std::optional<int64_t> ArrayType::getSizeInBytes() {
- auto elementType = llvm::cast<SPIRVType>(getElementType());
- std::optional<int64_t> size = elementType.getSizeInBytes();
- if (!size)
- return std::nullopt;
- return (*size + getArrayStride()) * getNumElements();
-}
-
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
}
}
-std::optional<int64_t> CompositeType::getSizeInBytes() {
- if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
- return arrayType.getSizeInBytes();
- if (auto structType = llvm::dyn_cast<StructType>(*this))
- return structType.getSizeInBytes();
- if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
- std::optional<int64_t> elementSize =
- llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
- if (!elementSize)
- return std::nullopt;
- return *elementSize * vectorType.getNumElements();
- }
- if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- std::optional<int64_t> elementSize =
- llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
- if (!elementSize)
- return std::nullopt;
- return *elementSize * tensorArmType.getNumElements();
- }
- return std::nullopt;
-}
-
//===----------------------------------------------------------------------===//
// CooperativeMatrixType
//===----------------------------------------------------------------------===//
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
#undef WIDTH_CASE
}
-std::optional<int64_t> ScalarType::getSizeInBytes() {
- auto bitWidth = getIntOrFloatBitWidth();
- // According to the SPIR-V spec:
- // "There is no physical size or bit pattern defined for values with boolean
- // type. If they are stored (in conjunction with OpVariable), they can only
- // be used with logical addressing operations, not physical, and only with
- // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
- // Private, Function, Input, and Output."
- if (bitWidth == 1)
- return std::nullopt;
- return bitWidth / 8;
-}
-
//===----------------------------------------------------------------------===//
// SPIRVType
//===----------------------------------------------------------------------===//
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
- return scalarType.getSizeInBytes();
- if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
- return compositeType.getSizeInBytes();
- return std::nullopt;
+ return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
+ .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
+ // According to the SPIR-V spec:
+ // "There is no physical size or bit pattern defined for values with
+ // boolean type. If they are stored (in conjunction with OpVariable),
+ // they can only be used with logical addressing operations, not
+ // physical, and only with non-externally visible shader Storage
+ // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
+ // Output."
+ int64_t bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 1)
+ return std::nullopt;
+ return bitWidth / 8;
+ })
+ .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
+ // Since array type may have an explicit stride declaration (in bytes),
+ // we also include it in the calculation.
+ auto elementType = cast<SPIRVType>(type.getElementType());
+ if (std::optional<int64_t> size = elementType.getSizeInBytes())
+ return (*size + type.getArrayStride()) * type.getNumElements();
+ return std::nullopt;
+ })
+ .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
+ if(std::optional<int64_t> elementSize =
+ cast<ScalarType>(type.getElementType()).getSizeInBytes())
+ return *elementSize * type.getNumElements();
+ return std::nullopt;
+ })
+ .Default(std::optional<int64_t>());
}
//===----------------------------------------------------------------------===//
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Similar to
::getExtensions
and::getCapabilities
, introduce a single entry point for type size calculation.Also fix potential infinite recursion with
StructType
s (even non-recursive structs), although I don't know to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.