-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Rework type capability queries #160113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* Fix infinite recursion with nested structs. * Drop `::getCapbilities` function from derived types, so that there's only one entry point that queries type extensions. * Move all capability logic to a new helper class -- this way the `::getCapabilities` functions can't diverge across concrete types and 'convenience types' like CompositeType. Fixes: llvm#159963
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) Changes
Fixes: #159963 Patch is 24.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160113.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 6beffc17d6d58..475e3f495e065 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -89,9 +89,6 @@ class ScalarType : public SPIRVType {
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
std::optional<int64_t> getSizeInBytes();
};
@@ -116,9 +113,6 @@ class CompositeType : public SPIRVType {
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
std::optional<int64_t> getSizeInBytes();
};
@@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// type.
unsigned getArrayStride() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
/// Returns the array size in bytes. Since array type may have an explicit
/// stride declaration (in bytes), we also include it in the calculation.
std::optional<int64_t> getSizeInBytes();
@@ -186,9 +177,6 @@ class ImageType
ImageSamplerUseInfo getSamplerUseInfo() const;
ImageFormat getImageFormat() const;
// TODO: Add support for Access qualifier
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V pointer type
@@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
Type getPointeeType() const;
StorageClass getStorageClass() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V run-time array type
@@ -228,9 +213,6 @@ class RuntimeArrayType
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V sampled image type
@@ -252,10 +234,6 @@ class SampledImageType
Type imageType);
Type getImageType() const;
-
- void
- getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<spirv::StorageClass> storage = std::nullopt);
};
/// SPIR-V struct type. Two kinds of struct types are supported:
@@ -405,9 +383,6 @@ class StructType
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
llvm::hash_code
@@ -440,9 +415,6 @@ class CooperativeMatrixType
/// Returns the use parameter of the cooperative matrix.
CooperativeMatrixUseKHR getUse() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
ArrayRef<int64_t> getShape() const;
@@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
/// SPIR-V TensorARM Type
@@ -531,9 +500,6 @@ class TensorArmType
ArrayRef<int64_t> getShape() const;
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
} // namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 8244e64abba12..7dd67bf56caa4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -45,17 +45,67 @@ class TypeExtensionVisitor {
return;
TypeSwitch<SPIRVType>(type)
- .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+ .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
[this](auto concreteType) { addConcrete(concreteType); })
- .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+ .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
+ .Case<SampledImageType>([this](SampledImageType concreteType) {
+ add(concreteType.getImageType());
+ })
.Case<StructType>([this](StructType concreteType) {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
+ .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ }
+
+ void add(Type type) { add(cast<SPIRVType>(type)); }
+
+private:
+ // Types that add unique extensions.
+ void addConcrete(CooperativeMatrixType type);
+ void addConcrete(PointerType type);
+ void addConcrete(ScalarType type);
+ void addConcrete(TensorArmType type);
+
+ SPIRVType::ExtensionArrayRefVector &extensions;
+ std::optional<StorageClass> storage;
+ llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
+};
+
+// Helper function to collect capabilities implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type capability information. All capability
+// logic should be added to this class, while the
+// `SPIRVType::getCapabilities` function should not handle capability-related
+// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`.
+class TypeCapabilityVisitor {
+public:
+ TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage)
+ : capabilities(capabilities), storage(storage) {}
+
+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+ // as seen and dispatches to the right concrete `.add` function.
+ void add(SPIRVType type) {
+ if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
+ return;
+
+ TypeSwitch<SPIRVType>(type)
+ .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
+ RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
+ [this](auto concreteType) { addConcrete(concreteType); })
+ .Case<ArrayType>([this](ArrayType concreteType) {
+ add(concreteType.getElementType());
+ })
.Case<SampledImageType>([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
+ .Case<StructType>([this](StructType concreteType) {
+ for (Type elementType : concreteType.getElementTypes())
+ add(elementType);
+ })
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
}
@@ -63,12 +113,16 @@ class TypeExtensionVisitor {
private:
// Types that add unique extensions.
- void addConcrete(ScalarType type);
- void addConcrete(PointerType type);
void addConcrete(CooperativeMatrixType type);
+ void addConcrete(ImageType type);
+ void addConcrete(MatrixType type);
+ void addConcrete(PointerType type);
+ void addConcrete(RuntimeArrayType type);
+ void addConcrete(ScalarType type);
void addConcrete(TensorArmType type);
+ void addConcrete(VectorType type);
- SPIRVType::ExtensionArrayRefVector &extensions;
+ SPIRVType::CapabilityArrayRefVector &capabilities;
std::optional<StorageClass> storage;
llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
};
@@ -118,13 +172,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
-void ArrayType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
-}
-
std::optional<int64_t> ArrayType::getSizeInBytes() {
auto elementType = llvm::cast<SPIRVType>(getElementType());
std::optional<int64_t> size = elementType.getSizeInBytes();
@@ -188,30 +235,14 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}
-void CompositeType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
- StructType>(
- [&](auto type) { type.getCapabilities(capabilities, storage); })
- .Case<VectorType>([&](VectorType type) {
- auto vecSize = getNumElements();
- if (vecSize == 8 || vecSize == 16) {
- static const Capability caps[] = {Capability::Vector16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
- }
- return llvm::cast<ScalarType>(type.getElementType())
- .getCapabilities(capabilities, storage);
- })
- .Case<TensorArmType>([&](TensorArmType type) {
- static constexpr Capability cap{Capability::TensorsARM};
- capabilities.push_back(cap);
- return llvm::cast<ScalarType>(type.getElementType())
- .getCapabilities(capabilities, storage);
- })
- .Default([](Type) { llvm_unreachable("invalid composite type"); });
+void TypeCapabilityVisitor::addConcrete(VectorType type) {
+ add(type.getElementType());
+
+ int64_t vecSize = type.getNumElements();
+ if (vecSize == 8 || vecSize == 16) {
+ static constexpr auto cap = Capability::Vector16;
+ capabilities.push_back(cap);
+ }
}
std::optional<int64_t> CompositeType::getSizeInBytes() {
@@ -317,12 +348,9 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
extensions.push_back(ext);
}
-void CooperativeMatrixType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
- static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
+ add(type.getElementType());
+ static constexpr auto caps = Capability::CooperativeMatrixKHR;
capabilities.push_back(caps);
}
@@ -428,14 +456,14 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
-void ImageType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass>) {
- if (auto dimCaps = spirv::getCapabilities(getDim()))
+void TypeCapabilityVisitor::addConcrete(ImageType type) {
+ if (auto dimCaps = spirv::getCapabilities(type.getDim()))
capabilities.push_back(*dimCaps);
- if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
+ if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
capabilities.push_back(*fmtCaps);
+
+ add(type.getElementType());
}
//===----------------------------------------------------------------------===//
@@ -486,15 +514,15 @@ void TypeExtensionVisitor::addConcrete(PointerType type) {
extensions.push_back(*scExts);
}
-void PointerType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
+void TypeCapabilityVisitor::addConcrete(PointerType type) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- llvm::cast<SPIRVType>(getPointeeType())
- .getCapabilities(capabilities, getStorageClass());
+ std::optional<StorageClass> oldStorageClass = storage;
+ storage = type.getStorageClass();
+ add(type.getPointeeType());
+ storage = oldStorageClass;
- if (auto scCaps = spirv::getCapabilities(getStorageClass()))
+ if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
capabilities.push_back(*scCaps);
}
@@ -534,16 +562,11 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
-void RuntimeArrayType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- {
- static const Capability caps[] = {Capability::Shader};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
- }
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
+ add(type.getElementType());
+
+ static constexpr auto cap = Capability::Shader;
+ capabilities.push_back(cap);
}
//===----------------------------------------------------------------------===//
@@ -601,10 +624,8 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
}
}
-void ScalarType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- unsigned bitwidth = getIntOrFloatBitWidth();
+void TypeCapabilityVisitor::addConcrete(ScalarType type) {
+ unsigned bitwidth = type.getIntOrFloatBitWidth();
// 8- or 16-bit integer/floating-point numbers will require extra capabilities
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
@@ -613,15 +634,13 @@ void ScalarType::getCapabilities(
#define STORAGE_CASE(storage, cap8, cap16) \
case StorageClass::storage: { \
if (bitwidth == 8) { \
- static const Capability caps[] = {Capability::cap8}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::cap8; \
+ capabilities.push_back(cap); \
return; \
} \
if (bitwidth == 16) { \
- static const Capability caps[] = {Capability::cap16}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::cap16; \
+ capabilities.push_back(cap); \
return; \
} \
/* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
@@ -640,9 +659,8 @@ void ScalarType::getCapabilities(
case StorageClass::Input:
case StorageClass::Output: {
if (bitwidth == 16) {
- static const Capability caps[] = {Capability::StorageInputOutput16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
+ static constexpr auto cap = Capability::StorageInputOutput16;
+ capabilities.push_back(cap);
return;
}
break;
@@ -658,12 +676,11 @@ void ScalarType::getCapabilities(
#define WIDTH_CASE(type, width) \
case width: { \
- static const Capability caps[] = {Capability::type##width}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::type##width; \
+ capabilities.push_back(cap); \
} break
- if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
switch (bitwidth) {
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
@@ -675,14 +692,14 @@ void ScalarType::getCapabilities(
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
- assert(llvm::isa<FloatType>(*this));
+ assert(isa<FloatType>(type));
switch (bitwidth) {
case 16: {
- if (isa<BFloat16Type>(*this)) {
- static const Capability cap = Capability::BFloat16TypeKHR;
+ if (isa<BFloat16Type>(type)) {
+ static constexpr auto cap = Capability::BFloat16TypeKHR;
capabilities.push_back(cap);
} else {
- static const Capability cap = Capability::Float16;
+ static constexpr auto cap = Capability::Float16;
capabilities.push_back(cap);
}
break;
@@ -740,23 +757,7 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
void SPIRVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
- scalarType.getCapabilities(capabilities, storage);
- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
- compositeType.getCapabilities(capabilities, storage);
- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
- imageType.getCapabilities(capabilities, storage);
- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
- sampledImageType.getCapabilities(capabilities, storage);
- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
- matrixType.getCapabilities(capabilities, storage);
- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
- ptrType.getCapabilities(capabilities, storage);
- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- tensorArmType.getCapabilities(capabilities, storage);
- } else {
- llvm_unreachable("invalid SPIR-V Type to getCapabilities");
- }
+ TypeCapabilityVisitor{capabilities, storage}.add(*this);
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
@@ -814,12 +815,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-void SampledImageType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
-}
-
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
@@ -1172,13 +1167,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
structDecorations);
}
-void StructType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- for (Type elementType : getElementTypes())
- llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
-}
-
llvm::hash_code spirv::hash_value(
const StructType::MemberDecorationInfo &memberDeco...
[truncated]
|
IgWod-IMG
approved these changes
Sep 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % small inconsistency in formatting
SeongjaeP
pushed a commit
to SeongjaeP/llvm-project
that referenced
this pull request
Sep 23, 2025
* Fix infinite recursion with nested structs. * Drop `::getCapbilities` function from derived types, so that there's only one entry point that queries type extensions. * Move all capability logic to a new helper class -- this way the `::getCapabilities` functions can't diverge across concrete types and 'convenience types' like CompositeType. Fixes: llvm#159963
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
::getCapbilities
function from derived types, so that there's only one entry point that queries type extensions.::getCapabilities
functions can't diverge across concrete types and 'convenience types' like CompositeType.Fixes: #159963