From 1a7bb18eb1be9b55da8e4803d4e74fe2a58658ce Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 22 Sep 2025 10:00:34 -0400 Subject: [PATCH 1/2] [mlir][spirv] Rework type capability queries * Fix infinite recursion with nested structs. * Drop `::getCapbilities` function from derived types, so that there's only one entry point that queries type extensions. * Move all capability logic to a new helper class -- this way the `::getCapabilities` functions can't diverge across concrete types and 'convenience types' like CompositeType. Fixes: #159963 --- .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 34 --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 240 ++++++++---------- .../Conversion/SCFToSPIRV/unsupported.mlir | 13 +- .../SPIRV/Transforms/vce-deduction.mlir | 13 +- 4 files changed, 123 insertions(+), 177 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 6beffc17d6d58..475e3f495e065 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -89,9 +89,6 @@ class ScalarType : public SPIRVType { /// Returns true if the given float type is valid for the SPIR-V dialect. static bool isValid(IntegerType); - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); - std::optional getSizeInBytes(); }; @@ -116,9 +113,6 @@ class CompositeType : public SPIRVType { /// implementation dependent. bool hasCompileTimeKnownNumElements() const; - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); - std::optional getSizeInBytes(); }; @@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase storage = std::nullopt); - /// Returns the array size in bytes. Since array type may have an explicit /// stride declaration (in bytes), we also include it in the calculation. std::optional getSizeInBytes(); @@ -186,9 +177,6 @@ class ImageType ImageSamplerUseInfo getSamplerUseInfo() const; ImageFormat getImageFormat() const; // TODO: Add support for Access qualifier - - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); }; // SPIR-V pointer type @@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase storage = std::nullopt); }; // SPIR-V run-time array type @@ -228,9 +213,6 @@ class RuntimeArrayType /// Returns the array stride in bytes. 0 means no stride decorated on this /// type. unsigned getArrayStride() const; - - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); }; // SPIR-V sampled image type @@ -252,10 +234,6 @@ class SampledImageType Type imageType); Type getImageType() const; - - void - getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); }; /// SPIR-V struct type. Two kinds of struct types are supported: @@ -405,9 +383,6 @@ class StructType trySetBody(ArrayRef memberTypes, ArrayRef offsetInfo = {}, ArrayRef memberDecorations = {}, ArrayRef structDecorations = {}); - - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); }; llvm::hash_code @@ -440,9 +415,6 @@ class CooperativeMatrixType /// Returns the use parameter of the cooperative matrix. CooperativeMatrixUseKHR getUse() const; - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); - operator ShapedType() const { return llvm::cast(*this); } ArrayRef getShape() const; @@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase storage = std::nullopt); }; /// SPIR-V TensorARM Type @@ -531,9 +500,6 @@ class TensorArmType ArrayRef getShape() const; bool hasRank() const { return !getShape().empty(); } operator ShapedType() const { return llvm::cast(*this); } - - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); }; } // namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 8244e64abba12..7dd67bf56caa4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -45,17 +45,67 @@ class TypeExtensionVisitor { return; TypeSwitch(type) - .Case( + .Case( [this](auto concreteType) { addConcrete(concreteType); }) - .Case( + .Case( [this](auto concreteType) { add(concreteType.getElementType()); }) + .Case([this](SampledImageType concreteType) { + add(concreteType.getImageType()); + }) .Case([this](StructType concreteType) { for (Type elementType : concreteType.getElementTypes()) add(elementType); }) + .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); }); + } + + void add(Type type) { add(cast(type)); } + +private: + // Types that add unique extensions. + void addConcrete(CooperativeMatrixType type); + void addConcrete(PointerType type); + void addConcrete(ScalarType type); + void addConcrete(TensorArmType type); + + SPIRVType::ExtensionArrayRefVector &extensions; + std::optional storage; + llvm::SmallDenseSet>> seen; +}; + +// Helper function to collect capabilities implied by a type by visiting all its +// subtypes. Maintains a set of `seen` types to avoid recursion in structs. +// +// Serves as the source-of-truth for type capability information. All capability +// logic should be added to this class, while the +// `SPIRVType::getCapabilities` function should not handle capability-related +// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`. +class TypeCapabilityVisitor { +public: + TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities, + std::optional storage) + : capabilities(capabilities), storage(storage) {} + + // Main visitor entry point. Adds all extensions to the vector. Saves `type` + // as seen and dispatches to the right concrete `.add` function. + void add(SPIRVType type) { + if (auto [_it, inserted] = seen.insert({type, storage}); !inserted) + return; + + TypeSwitch(type) + .Case( + [this](auto concreteType) { addConcrete(concreteType); }) + .Case([this](ArrayType concreteType) { + add(concreteType.getElementType()); + }) .Case([this](SampledImageType concreteType) { add(concreteType.getImageType()); }) + .Case([this](StructType concreteType) { + for (Type elementType : concreteType.getElementTypes()) + add(elementType); + }) .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); }); } @@ -63,12 +113,16 @@ class TypeExtensionVisitor { private: // Types that add unique extensions. - void addConcrete(ScalarType type); - void addConcrete(PointerType type); void addConcrete(CooperativeMatrixType type); + void addConcrete(ImageType type); + void addConcrete(MatrixType type); + void addConcrete(PointerType type); + void addConcrete(RuntimeArrayType type); + void addConcrete(ScalarType type); void addConcrete(TensorArmType type); + void addConcrete(VectorType type); - SPIRVType::ExtensionArrayRefVector &extensions; + SPIRVType::CapabilityArrayRefVector &capabilities; std::optional storage; llvm::SmallDenseSet>> seen; }; @@ -118,13 +172,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; } unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } -void ArrayType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - llvm::cast(getElementType()) - .getCapabilities(capabilities, storage); -} - std::optional ArrayType::getSizeInBytes() { auto elementType = llvm::cast(getElementType()); std::optional size = elementType.getSizeInBytes(); @@ -188,30 +235,14 @@ bool CompositeType::hasCompileTimeKnownNumElements() const { return !llvm::isa(*this); } -void CompositeType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - TypeSwitch(*this) - .Case( - [&](auto type) { type.getCapabilities(capabilities, storage); }) - .Case([&](VectorType type) { - auto vecSize = getNumElements(); - if (vecSize == 8 || vecSize == 16) { - static const Capability caps[] = {Capability::Vector16}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); - } - return llvm::cast(type.getElementType()) - .getCapabilities(capabilities, storage); - }) - .Case([&](TensorArmType type) { - static constexpr Capability cap{Capability::TensorsARM}; - capabilities.push_back(cap); - return llvm::cast(type.getElementType()) - .getCapabilities(capabilities, storage); - }) - .Default([](Type) { llvm_unreachable("invalid composite type"); }); +void TypeCapabilityVisitor::addConcrete(VectorType type) { + add(type.getElementType()); + + int64_t vecSize = type.getNumElements(); + if (vecSize == 8 || vecSize == 16) { + static constexpr auto cap = Capability::Vector16; + capabilities.push_back(cap); + } } std::optional CompositeType::getSizeInBytes() { @@ -317,12 +348,9 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) { extensions.push_back(ext); } -void CooperativeMatrixType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - llvm::cast(getElementType()) - .getCapabilities(capabilities, storage); - static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR}; +void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) { + add(type.getElementType()); + static constexpr auto caps = Capability::CooperativeMatrixKHR; capabilities.push_back(caps); } @@ -428,14 +456,14 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } -void ImageType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional) { - if (auto dimCaps = spirv::getCapabilities(getDim())) +void TypeCapabilityVisitor::addConcrete(ImageType type) { + if (auto dimCaps = spirv::getCapabilities(type.getDim())) capabilities.push_back(*dimCaps); - if (auto fmtCaps = spirv::getCapabilities(getImageFormat())) + if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat())) capabilities.push_back(*fmtCaps); + + add(type.getElementType()); } //===----------------------------------------------------------------------===// @@ -486,15 +514,15 @@ void TypeExtensionVisitor::addConcrete(PointerType type) { extensions.push_back(*scExts); } -void PointerType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { +void TypeCapabilityVisitor::addConcrete(PointerType type) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - llvm::cast(getPointeeType()) - .getCapabilities(capabilities, getStorageClass()); + std::optional oldStorageClass = storage; + storage = type.getStorageClass(); + add(type.getPointeeType()); + storage = oldStorageClass; - if (auto scCaps = spirv::getCapabilities(getStorageClass())) + if (auto scCaps = spirv::getCapabilities(type.getStorageClass())) capabilities.push_back(*scCaps); } @@ -534,16 +562,11 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } -void RuntimeArrayType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - { - static const Capability caps[] = {Capability::Shader}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); - } - llvm::cast(getElementType()) - .getCapabilities(capabilities, storage); +void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) { + add(type.getElementType()); + + static constexpr auto cap = Capability::Shader; + capabilities.push_back(cap); } //===----------------------------------------------------------------------===// @@ -601,10 +624,8 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) { } } -void ScalarType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - unsigned bitwidth = getIntOrFloatBitWidth(); +void TypeCapabilityVisitor::addConcrete(ScalarType type) { + unsigned bitwidth = type.getIntOrFloatBitWidth(); // 8- or 16-bit integer/floating-point numbers will require extra capabilities // to appear in interface storage classes. See SPV_KHR_16bit_storage and @@ -613,15 +634,13 @@ void ScalarType::getCapabilities( #define STORAGE_CASE(storage, cap8, cap16) \ case StorageClass::storage: { \ if (bitwidth == 8) { \ - static const Capability caps[] = {Capability::cap8}; \ - ArrayRef ref(caps, std::size(caps)); \ - capabilities.push_back(ref); \ + static constexpr auto cap = Capability::cap8; \ + capabilities.push_back(cap); \ return; \ } \ if (bitwidth == 16) { \ - static const Capability caps[] = {Capability::cap16}; \ - ArrayRef ref(caps, std::size(caps)); \ - capabilities.push_back(ref); \ + static constexpr auto cap = Capability::cap16; \ + capabilities.push_back(cap); \ return; \ } \ /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \ @@ -640,9 +659,8 @@ void ScalarType::getCapabilities( case StorageClass::Input: case StorageClass::Output: { if (bitwidth == 16) { - static const Capability caps[] = {Capability::StorageInputOutput16}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); + static constexpr auto cap = Capability::StorageInputOutput16; + capabilities.push_back(cap); return; } break; @@ -658,12 +676,11 @@ void ScalarType::getCapabilities( #define WIDTH_CASE(type, width) \ case width: { \ - static const Capability caps[] = {Capability::type##width}; \ - ArrayRef ref(caps, std::size(caps)); \ - capabilities.push_back(ref); \ + static constexpr auto cap = Capability::type##width; \ + capabilities.push_back(cap); \ } break - if (auto intType = llvm::dyn_cast(*this)) { + if (auto intType = dyn_cast(type)) { switch (bitwidth) { WIDTH_CASE(Int, 8); WIDTH_CASE(Int, 16); @@ -675,14 +692,14 @@ void ScalarType::getCapabilities( llvm_unreachable("invalid bitwidth to getCapabilities"); } } else { - assert(llvm::isa(*this)); + assert(isa(type)); switch (bitwidth) { case 16: { - if (isa(*this)) { - static const Capability cap = Capability::BFloat16TypeKHR; + if (isa(type)) { + static constexpr auto cap = Capability::BFloat16TypeKHR; capabilities.push_back(cap); } else { - static const Capability cap = Capability::Float16; + static constexpr auto cap = Capability::Float16; capabilities.push_back(cap); } break; @@ -740,23 +757,7 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, void SPIRVType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { - if (auto scalarType = llvm::dyn_cast(*this)) { - scalarType.getCapabilities(capabilities, storage); - } else if (auto compositeType = llvm::dyn_cast(*this)) { - compositeType.getCapabilities(capabilities, storage); - } else if (auto imageType = llvm::dyn_cast(*this)) { - imageType.getCapabilities(capabilities, storage); - } else if (auto sampledImageType = llvm::dyn_cast(*this)) { - sampledImageType.getCapabilities(capabilities, storage); - } else if (auto matrixType = llvm::dyn_cast(*this)) { - matrixType.getCapabilities(capabilities, storage); - } else if (auto ptrType = llvm::dyn_cast(*this)) { - ptrType.getCapabilities(capabilities, storage); - } else if (auto tensorArmType = llvm::dyn_cast(*this)) { - tensorArmType.getCapabilities(capabilities, storage); - } else { - llvm_unreachable("invalid SPIR-V Type to getCapabilities"); - } + TypeCapabilityVisitor{capabilities, storage}.add(*this); } std::optional SPIRVType::getSizeInBytes() { @@ -814,12 +815,6 @@ SampledImageType::verifyInvariants(function_ref emitError, return success(); } -void SampledImageType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - llvm::cast(getImageType()).getCapabilities(capabilities, storage); -} - //===----------------------------------------------------------------------===// // StructType //===----------------------------------------------------------------------===// @@ -1172,13 +1167,6 @@ StructType::trySetBody(ArrayRef memberTypes, structDecorations); } -void StructType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - for (Type elementType : getElementTypes()) - llvm::cast(elementType).getCapabilities(capabilities, storage); -} - llvm::hash_code spirv::hash_value( const StructType::MemberDecorationInfo &memberDecorationInfo) { return llvm::hash_combine(memberDecorationInfo.memberIndex, @@ -1271,16 +1259,11 @@ unsigned MatrixType::getNumElements() const { return (getImpl()->columnCount) * getNumRows(); } -void MatrixType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - { - static const Capability caps[] = {Capability::Matrix}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); - } - // Add any capabilities associated with the underlying vectors (i.e., columns) - llvm::cast(getColumnType()).getCapabilities(capabilities, storage); +void TypeCapabilityVisitor::addConcrete(MatrixType type) { + add(type.getColumnType()); + + static constexpr auto cap = Capability::Matrix; + capabilities.push_back(cap); } //===----------------------------------------------------------------------===// @@ -1332,12 +1315,9 @@ void TypeExtensionVisitor::addConcrete(TensorArmType type) { extensions.push_back(ext); } -void TensorArmType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - llvm::cast(getElementType()) - .getCapabilities(capabilities, storage); - static constexpr Capability cap{Capability::TensorsARM}; +void TypeCapabilityVisitor::addConcrete(TensorArmType type) { + add(type.getElementType()); + static constexpr auto cap = Capability::TensorsARM; capabilities.push_back(cap); } diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir index d24f37b553bb5..1a1c24a09aa8c 100644 --- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s +// RUN: mlir-opt --convert-scf-to-spirv %s | FileCheck %s // `scf.parallel` conversion is not supported yet. // Make sure that we do not accidentally invalidate this function by removing @@ -19,14 +19,3 @@ func.func @func(%arg0: i64) { } return } - -// ----- - -// Make sure we don't crash on recursive structs. -// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase. - -// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}} -spirv.module Physical64 GLSL450 { - spirv.GlobalVariable @recursive: - !spirv.ptr, StorageBuffer>)>, StorageBuffer> -} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2d20ae0a13105..7dab87f8081ed 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -232,7 +232,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical Vulkan attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -242,3 +242,14 @@ spirv.module Logical Vulkan attributes { spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8> } } + +// Check that extension and capability queries handle recursive types. +// CHECK: requires #spirv.vce +spirv.module Physical64 GLSL450 attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits<>> +} { + spirv.GlobalVariable @recursive: + !spirv.ptr, StorageBuffer>)>, StorageBuffer> +} From f3c51fbf72e4af484c635e4864d4504f1d5ce6f5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 22 Sep 2025 11:01:08 -0400 Subject: [PATCH 2/2] Fix nits --- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 7dd67bf56caa4..7c2f43bea9ddb 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -564,7 +564,6 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) { add(type.getElementType()); - static constexpr auto cap = Capability::Shader; capabilities.push_back(cap); } @@ -1261,7 +1260,6 @@ unsigned MatrixType::getNumElements() const { void TypeCapabilityVisitor::addConcrete(MatrixType type) { add(type.getColumnType()); - static constexpr auto cap = Capability::Matrix; capabilities.push_back(cap); }