Skip to content

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 22, 2025

  • Fix infinite recursion with nested structs.
  • Drop ::getCapbilities function from derived types, so that there's only one entry point that queries type extensions.
  • Move all capability logic to a new helper class -- this way the ::getCapabilities functions can't diverge across concrete types and 'convenience types' like CompositeType.

Fixes: #159963

* Fix infinite recursion with nested structs.
* Drop `::getCapbilities` function from derived types, so that there's only one entry point that queries type extensions.
* Move all capability logic to a new helper class -- this way the `::getCapabilities` functions can't diverge across concrete types and 'convenience types' like CompositeType.

Fixes: llvm#159963
@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes
  • Fix infinite recursion with nested structs.
  • Drop ::getCapbilities function from derived types, so that there's only one entry point that queries type extensions.
  • Move all capability logic to a new helper class -- this way the ::getCapabilities functions can't diverge across concrete types and 'convenience types' like CompositeType.

Fixes: #159963


Patch is 24.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160113.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (-34)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+110-130)
  • (modified) mlir/test/Conversion/SCFToSPIRV/unsupported.mlir (+1-12)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir (+12-1)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 6beffc17d6d58..475e3f495e065 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -89,9 +89,6 @@ class ScalarType : public SPIRVType {
   /// Returns true if the given float type is valid for the SPIR-V dialect.
   static bool isValid(IntegerType);
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   std::optional<int64_t> getSizeInBytes();
 };
 
@@ -116,9 +113,6 @@ class CompositeType : public SPIRVType {
   /// implementation dependent.
   bool hasCompileTimeKnownNumElements() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   std::optional<int64_t> getSizeInBytes();
 };
 
@@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
   /// type.
   unsigned getArrayStride() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   /// Returns the array size in bytes. Since array type may have an explicit
   /// stride declaration (in bytes), we also include it in the calculation.
   std::optional<int64_t> getSizeInBytes();
@@ -186,9 +177,6 @@ class ImageType
   ImageSamplerUseInfo getSamplerUseInfo() const;
   ImageFormat getImageFormat() const;
   // TODO: Add support for Access qualifier
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V pointer type
@@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
   Type getPointeeType() const;
 
   StorageClass getStorageClass() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V run-time array type
@@ -228,9 +213,6 @@ class RuntimeArrayType
   /// Returns the array stride in bytes. 0 means no stride decorated on this
   /// type.
   unsigned getArrayStride() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V sampled image type
@@ -252,10 +234,6 @@ class SampledImageType
                    Type imageType);
 
   Type getImageType() const;
-
-  void
-  getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                  std::optional<spirv::StorageClass> storage = std::nullopt);
 };
 
 /// SPIR-V struct type. Two kinds of struct types are supported:
@@ -405,9 +383,6 @@ class StructType
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
              ArrayRef<MemberDecorationInfo> memberDecorations = {},
              ArrayRef<StructDecorationInfo> structDecorations = {});
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 llvm::hash_code
@@ -440,9 +415,6 @@ class CooperativeMatrixType
   /// Returns the use parameter of the cooperative matrix.
   CooperativeMatrixUseKHR getUse() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 
   ArrayRef<int64_t> getShape() const;
@@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 
   /// Returns the elements' type (i.e, single element type).
   Type getElementType() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 /// SPIR-V TensorARM Type
@@ -531,9 +500,6 @@ class TensorArmType
   ArrayRef<int64_t> getShape() const;
   bool hasRank() const { return !getShape().empty(); }
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 } // namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 8244e64abba12..7dd67bf56caa4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -45,17 +45,67 @@ class TypeExtensionVisitor {
       return;
 
     TypeSwitch<SPIRVType>(type)
-        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+        .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
             [this](auto concreteType) { addConcrete(concreteType); })
-        .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+        .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
             [this](auto concreteType) { add(concreteType.getElementType()); })
+        .Case<SampledImageType>([this](SampledImageType concreteType) {
+          add(concreteType.getImageType());
+        })
         .Case<StructType>([this](StructType concreteType) {
           for (Type elementType : concreteType.getElementTypes())
             add(elementType);
         })
+        .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+  }
+
+  void add(Type type) { add(cast<SPIRVType>(type)); }
+
+private:
+  // Types that add unique extensions.
+  void addConcrete(CooperativeMatrixType type);
+  void addConcrete(PointerType type);
+  void addConcrete(ScalarType type);
+  void addConcrete(TensorArmType type);
+
+  SPIRVType::ExtensionArrayRefVector &extensions;
+  std::optional<StorageClass> storage;
+  llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
+};
+
+// Helper function to collect capabilities implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type capability information. All capability
+// logic should be added to this class, while the
+// `SPIRVType::getCapabilities` function should not handle capability-related
+// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`.
+class TypeCapabilityVisitor {
+public:
+  TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
+                        std::optional<StorageClass> storage)
+      : capabilities(capabilities), storage(storage) {}
+
+  // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+  // as seen and dispatches to the right concrete `.add` function.
+  void add(SPIRVType type) {
+    if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
+      return;
+
+    TypeSwitch<SPIRVType>(type)
+        .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
+              RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
+            [this](auto concreteType) { addConcrete(concreteType); })
+        .Case<ArrayType>([this](ArrayType concreteType) {
+          add(concreteType.getElementType());
+        })
         .Case<SampledImageType>([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
+        .Case<StructType>([this](StructType concreteType) {
+          for (Type elementType : concreteType.getElementTypes())
+            add(elementType);
+        })
         .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
   }
 
@@ -63,12 +113,16 @@ class TypeExtensionVisitor {
 
 private:
   // Types that add unique extensions.
-  void addConcrete(ScalarType type);
-  void addConcrete(PointerType type);
   void addConcrete(CooperativeMatrixType type);
+  void addConcrete(ImageType type);
+  void addConcrete(MatrixType type);
+  void addConcrete(PointerType type);
+  void addConcrete(RuntimeArrayType type);
+  void addConcrete(ScalarType type);
   void addConcrete(TensorArmType type);
+  void addConcrete(VectorType type);
 
-  SPIRVType::ExtensionArrayRefVector &extensions;
+  SPIRVType::CapabilityArrayRefVector &capabilities;
   std::optional<StorageClass> storage;
   llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
 };
@@ -118,13 +172,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
-void ArrayType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-}
-
 std::optional<int64_t> ArrayType::getSizeInBytes() {
   auto elementType = llvm::cast<SPIRVType>(getElementType());
   std::optional<int64_t> size = elementType.getSizeInBytes();
@@ -188,30 +235,14 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
   return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
 }
 
-void CompositeType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
-            StructType>(
-          [&](auto type) { type.getCapabilities(capabilities, storage); })
-      .Case<VectorType>([&](VectorType type) {
-        auto vecSize = getNumElements();
-        if (vecSize == 8 || vecSize == 16) {
-          static const Capability caps[] = {Capability::Vector16};
-          ArrayRef<Capability> ref(caps, std::size(caps));
-          capabilities.push_back(ref);
-        }
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getCapabilities(capabilities, storage);
-      })
-      .Case<TensorArmType>([&](TensorArmType type) {
-        static constexpr Capability cap{Capability::TensorsARM};
-        capabilities.push_back(cap);
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getCapabilities(capabilities, storage);
-      })
-      .Default([](Type) { llvm_unreachable("invalid composite type"); });
+void TypeCapabilityVisitor::addConcrete(VectorType type) {
+  add(type.getElementType());
+
+  int64_t vecSize = type.getNumElements();
+  if (vecSize == 8 || vecSize == 16) {
+    static constexpr auto cap = Capability::Vector16;
+    capabilities.push_back(cap);
+  }
 }
 
 std::optional<int64_t> CompositeType::getSizeInBytes() {
@@ -317,12 +348,9 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
   extensions.push_back(ext);
 }
 
-void CooperativeMatrixType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
+  add(type.getElementType());
+  static constexpr auto caps = Capability::CooperativeMatrixKHR;
   capabilities.push_back(caps);
 }
 
@@ -428,14 +456,14 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
 
 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
-void ImageType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass>) {
-  if (auto dimCaps = spirv::getCapabilities(getDim()))
+void TypeCapabilityVisitor::addConcrete(ImageType type) {
+  if (auto dimCaps = spirv::getCapabilities(type.getDim()))
     capabilities.push_back(*dimCaps);
 
-  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
+  if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
     capabilities.push_back(*fmtCaps);
+
+  add(type.getElementType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -486,15 +514,15 @@ void TypeExtensionVisitor::addConcrete(PointerType type) {
     extensions.push_back(*scExts);
 }
 
-void PointerType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
+void TypeCapabilityVisitor::addConcrete(PointerType type) {
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
-  llvm::cast<SPIRVType>(getPointeeType())
-      .getCapabilities(capabilities, getStorageClass());
+  std::optional<StorageClass> oldStorageClass = storage;
+  storage = type.getStorageClass();
+  add(type.getPointeeType());
+  storage = oldStorageClass;
 
-  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
+  if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
     capabilities.push_back(*scCaps);
 }
 
@@ -534,16 +562,11 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
 
-void RuntimeArrayType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  {
-    static const Capability caps[] = {Capability::Shader};
-    ArrayRef<Capability> ref(caps, std::size(caps));
-    capabilities.push_back(ref);
-  }
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
+  add(type.getElementType());
+
+  static constexpr auto cap = Capability::Shader;
+  capabilities.push_back(cap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,10 +624,8 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
   }
 }
 
-void ScalarType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  unsigned bitwidth = getIntOrFloatBitWidth();
+void TypeCapabilityVisitor::addConcrete(ScalarType type) {
+  unsigned bitwidth = type.getIntOrFloatBitWidth();
 
   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
@@ -613,15 +634,13 @@ void ScalarType::getCapabilities(
 #define STORAGE_CASE(storage, cap8, cap16)                                     \
   case StorageClass::storage: {                                                \
     if (bitwidth == 8) {                                                       \
-      static const Capability caps[] = {Capability::cap8};                     \
-      ArrayRef<Capability> ref(caps, std::size(caps));                         \
-      capabilities.push_back(ref);                                             \
+      static constexpr auto cap = Capability::cap8;                            \
+      capabilities.push_back(cap);                                             \
       return;                                                                  \
     }                                                                          \
     if (bitwidth == 16) {                                                      \
-      static const Capability caps[] = {Capability::cap16};                    \
-      ArrayRef<Capability> ref(caps, std::size(caps));                         \
-      capabilities.push_back(ref);                                             \
+      static constexpr auto cap = Capability::cap16;                           \
+      capabilities.push_back(cap);                                             \
       return;                                                                  \
     }                                                                          \
     /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
@@ -640,9 +659,8 @@ void ScalarType::getCapabilities(
     case StorageClass::Input:
     case StorageClass::Output: {
       if (bitwidth == 16) {
-        static const Capability caps[] = {Capability::StorageInputOutput16};
-        ArrayRef<Capability> ref(caps, std::size(caps));
-        capabilities.push_back(ref);
+        static constexpr auto cap = Capability::StorageInputOutput16;
+        capabilities.push_back(cap);
         return;
       }
       break;
@@ -658,12 +676,11 @@ void ScalarType::getCapabilities(
 
 #define WIDTH_CASE(type, width)                                                \
   case width: {                                                                \
-    static const Capability caps[] = {Capability::type##width};                \
-    ArrayRef<Capability> ref(caps, std::size(caps));                           \
-    capabilities.push_back(ref);                                               \
+    static constexpr auto cap = Capability::type##width;                       \
+    capabilities.push_back(cap);                                               \
   } break
 
-  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
+  if (auto intType = dyn_cast<IntegerType>(type)) {
     switch (bitwidth) {
       WIDTH_CASE(Int, 8);
       WIDTH_CASE(Int, 16);
@@ -675,14 +692,14 @@ void ScalarType::getCapabilities(
       llvm_unreachable("invalid bitwidth to getCapabilities");
     }
   } else {
-    assert(llvm::isa<FloatType>(*this));
+    assert(isa<FloatType>(type));
     switch (bitwidth) {
     case 16: {
-      if (isa<BFloat16Type>(*this)) {
-        static const Capability cap = Capability::BFloat16TypeKHR;
+      if (isa<BFloat16Type>(type)) {
+        static constexpr auto cap = Capability::BFloat16TypeKHR;
         capabilities.push_back(cap);
       } else {
-        static const Capability cap = Capability::Float16;
+        static constexpr auto cap = Capability::Float16;
         capabilities.push_back(cap);
       }
       break;
@@ -740,23 +757,7 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
 void SPIRVType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
-    scalarType.getCapabilities(capabilities, storage);
-  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
-    compositeType.getCapabilities(capabilities, storage);
-  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
-    imageType.getCapabilities(capabilities, storage);
-  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
-    sampledImageType.getCapabilities(capabilities, storage);
-  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
-    matrixType.getCapabilities(capabilities, storage);
-  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
-    ptrType.getCapabilities(capabilities, storage);
-  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    tensorArmType.getCapabilities(capabilities, storage);
-  } else {
-    llvm_unreachable("invalid SPIR-V Type to getCapabilities");
-  }
+  TypeCapabilityVisitor{capabilities, storage}.add(*this);
 }
 
 std::optional<int64_t> SPIRVType::getSizeInBytes() {
@@ -814,12 +815,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-void SampledImageType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
-}
-
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
@@ -1172,13 +1167,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
                       structDecorations);
 }
 
-void StructType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  for (Type elementType : getElementTypes())
-    llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
-}
-
 llvm::hash_code spirv::hash_value(
     const StructType::MemberDecorationInfo &memberDeco...
[truncated]

Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % small inconsistency in formatting

@kuhar kuhar enabled auto-merge (squash) September 22, 2025 15:01
@kuhar kuhar merged commit ca7c058 into llvm:main Sep 22, 2025
10 checks passed
SeongjaeP pushed a commit to SeongjaeP/llvm-project that referenced this pull request Sep 23, 2025
* Fix infinite recursion with nested structs.
* Drop `::getCapbilities` function from derived types, so that there's
only one entry point that queries type extensions.
* Move all capability logic to a new helper class -- this way the
`::getCapabilities` functions can't diverge across concrete types and
'convenience types' like CompositeType.

Fixes: llvm#159963
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] crashes with -convert-scf-to-spirv on SPIRV recursive struct
3 participants