Skip to content

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 22, 2025

  • Fix infinite recursion with nested structs.
  • Drop ::getExtensions function from derived types, so that there's only one entry point that queries type extensions.
  • Move all extension logic to a new helper class -- this way the ::getExtensions functions can't diverge across concrete types and 'convenience types' like CompositeType.

We should also fix ::getCapabilities in a similar way and move the testcase to vce-deduction.mlir.

Issue: #159963

* Fix infinite recursion with nested structs.
* Move all extension logic to a new helper class -- this way
  `::getExtensions` functions can't diverge across concrete types and
  'convenience types' like `CompositeType`.

We should also fix `::getCapabilities` in a similar way and move the
testcase to `vce-deduction.mlir`.

Issue: llvm#159963
@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes
  • Fix infinite recursion with nested structs.
  • Move all extension logic to a new helper class -- this way ::getExtensions functions can't diverge across concrete types and 'convenience types' like CompositeType.

We should also fix ::getCapabilities in a similar way and move the testcase to vce-deduction.mlir.

Issue: #159963


Full diff: https://github.com/llvm/llvm-project/pull/160020.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+108-66)
  • (modified) mlir/test/Conversion/SCFToSPIRV/unsupported.mlir (+12-1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d890dac96b118..85250700b9bf9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -14,14 +14,73 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <cstdint>
 
 using namespace mlir;
 using namespace mlir::spirv;
 
+namespace {
+// Helper function to collect extensions implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type extension information. All extension
+// logic should be added to this class, while
+// `*Type::getExtensions` functions should not handle extension-related logic
+// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
+class TypeExtensionVisitor {
+  SPIRVType::ExtensionArrayRefVector &extensions;
+  std::optional<StorageClass> storage;
+  DenseSet<Type> seen;
+
+public:
+  TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
+                       std::optional<StorageClass> storage)
+      : extensions(extensions), storage(storage) {}
+
+  // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+  // as seen and dispatches to the right concrete `.add` function.
+  void add(SPIRVType type) {
+    if (auto [_it, inserted] = seen.insert(type); !inserted)
+      return;
+
+    TypeSwitch<SPIRVType>(type)
+        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
+              VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
+              ImageType, SampledImageType>(
+            [this](auto concreteType) { add(concreteType); })
+        .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+  }
+
+  // Convenience overloads for use in `T::getExtensions` functions.
+  void add(Type type) { add(cast<SPIRVType>(type)); }
+  void add(Type *type) { add(cast<SPIRVType>(*type)); }
+
+  // Types that add unique extensions.
+  void add(ScalarType type);
+  void add(PointerType type);
+  void add(CooperativeMatrixType type);
+  void add(TensorArmType type);
+
+  // Trivial passthrough without any new extensions.
+  void add(VectorType type) { add(type.getElementType()); }
+  void add(ArrayType type) { add(type.getElementType()); }
+  void add(RuntimeArrayType type) { add(type.getElementType()); }
+  void add(StructType type) {
+    for (Type elementType : type.getElementTypes())
+      add(elementType);
+  }
+  void add(MatrixType type) { add(type.getElementType()); }
+  void add(ImageType type) { add(type.getElementType()); }
+  void add(SampledImageType type) { add(type.getImageType()); }
+};
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ArrayType
 //===----------------------------------------------------------------------===//
@@ -67,7 +126,7 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void ArrayType::getCapabilities(
@@ -143,22 +202,7 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
-            StructType>(
-          [&](auto type) { type.getExtensions(extensions, storage); })
-      .Case<VectorType>([&](VectorType type) {
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getExtensions(extensions, storage);
-      })
-      .Case<TensorArmType>([&](TensorArmType type) {
-        static constexpr Extension ext{Extension::SPV_ARM_tensors};
-        extensions.push_back(ext);
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getExtensions(extensions, storage);
-      })
-
-      .Default([](Type) { llvm_unreachable("invalid composite type"); });
+  TypeExtensionVisitor{extensions, storage}.add(cast<SPIRVType>(*this));
 }
 
 void CompositeType::getCapabilities(
@@ -284,12 +328,16 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
   return getImpl()->use;
 }
 
+void TypeExtensionVisitor::add(CooperativeMatrixType type) {
+  add(type.getElementType());
+  static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
+  extensions.push_back(ext);
+}
+
 void CooperativeMatrixType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
-  extensions.push_back(exts);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void CooperativeMatrixType::getCapabilities(
@@ -403,9 +451,9 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
 
 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
-void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
-                              std::optional<StorageClass>) {
-  // Image types do not require extra extensions thus far.
+void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                              std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void ImageType::getCapabilities(
@@ -454,17 +502,23 @@ StorageClass PointerType::getStorageClass() const {
   return getImpl()->storageClass;
 }
 
-void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                                std::optional<StorageClass> storage) {
+void TypeExtensionVisitor::add(PointerType type) {
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
-  llvm::cast<SPIRVType>(getPointeeType())
-      .getExtensions(extensions, getStorageClass());
+  std::optional<StorageClass> oldStorageClass = storage;
+  storage = type.getStorageClass();
+  add(type.getPointeeType());
+  storage = oldStorageClass;
 
-  if (auto scExts = spirv::getExtensions(getStorageClass()))
+  if (auto scExts = spirv::getExtensions(type.getStorageClass()))
     extensions.push_back(*scExts);
 }
 
+void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                                std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
 void PointerType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
@@ -516,7 +570,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
 void RuntimeArrayType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void RuntimeArrayType::getCapabilities(
@@ -553,10 +607,9 @@ bool ScalarType::isValid(IntegerType type) {
   return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
 }
 
-void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                               std::optional<StorageClass> storage) {
-  if (isa<BFloat16Type>(*this)) {
-    static const Extension ext = Extension::SPV_KHR_bfloat16;
+void TypeExtensionVisitor::add(ScalarType type) {
+  if (isa<BFloat16Type>(type)) {
+    static constexpr auto ext = Extension::SPV_KHR_bfloat16;
     extensions.push_back(ext);
   }
 
@@ -570,18 +623,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
   case StorageClass::PushConstant:
   case StorageClass::StorageBuffer:
   case StorageClass::Uniform:
-    if (getIntOrFloatBitWidth() == 8) {
-      static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
-      ArrayRef<Extension> ref(exts, std::size(exts));
-      extensions.push_back(ref);
+    if (type.getIntOrFloatBitWidth() == 8) {
+      static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
+      extensions.push_back(ext);
     }
     [[fallthrough]];
   case StorageClass::Input:
   case StorageClass::Output:
-    if (getIntOrFloatBitWidth() == 16) {
-      static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
-      ArrayRef<Extension> ref(exts, std::size(exts));
-      extensions.push_back(ref);
+    if (type.getIntOrFloatBitWidth() == 16) {
+      static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
+      extensions.push_back(ext);
     }
     break;
   default:
@@ -589,6 +640,11 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
   }
 }
 
+void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                               std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
 void ScalarType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
@@ -722,23 +778,7 @@ bool SPIRVType::isScalarOrVector() {
 
 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               std::optional<StorageClass> storage) {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
-    scalarType.getExtensions(extensions, storage);
-  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
-    compositeType.getExtensions(extensions, storage);
-  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
-    imageType.getExtensions(extensions, storage);
-  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
-    sampledImageType.getExtensions(extensions, storage);
-  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
-    matrixType.getExtensions(extensions, storage);
-  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
-    ptrType.getExtensions(extensions, storage);
-  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    tensorArmType.getExtensions(extensions, storage);
-  } else {
-    llvm_unreachable("invalid SPIR-V Type to getExtensions");
-  }
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void SPIRVType::getCapabilities(
@@ -821,7 +861,7 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 void SampledImageType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void SampledImageType::getCapabilities(
@@ -1184,8 +1224,7 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
-  for (Type elementType : getElementTypes())
-    llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void StructType::getCapabilities(
@@ -1289,7 +1328,7 @@ unsigned MatrixType::getNumElements() const {
 
 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void MatrixType::getCapabilities(
@@ -1347,13 +1386,16 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
 ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
 
+void TypeExtensionVisitor::add(TensorArmType type) {
+  add(type.getElementType());
+  static constexpr auto ext = Extension::SPV_ARM_tensors;
+  extensions.push_back(ext);
+}
+
 void TensorArmType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static constexpr Extension ext{Extension::SPV_ARM_tensors};
-  extensions.push_back(ext);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void TensorArmType::getCapabilities(
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index 71bf2f3d918e8..d24f37b553bb5 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
 
 // `scf.parallel` conversion is not supported yet.
 // Make sure that we do not accidentally invalidate this function by removing
@@ -19,3 +19,14 @@ func.func @func(%arg0: i64) {
   }
   return
 }
+
+// -----
+
+// Make sure we don't crash on recursive structs.
+// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
+
+// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
+spirv.module Physical64 GLSL450 {
+  spirv.GlobalVariable @recursive:
+    !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}

Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have a default implementation of ::getExtensions() in SPIRVType class. This way we avoid repeating the same implementation in each inherited type, as they can just use the parent default. I think this also would help enforce that ::getExtensions() is not abused as any custom implementation of ::getExtensions() would have to be explicit.

@kuhar
Copy link
Member Author

kuhar commented Sep 22, 2025

Could we have a default implementation of ::getExtensions() in SPIRVType class. This way we avoid repeating the same implementation in each inherited type, as they can just use the parent default. I think this also would help enforce that ::getExtensions() is not abused as any custom implementation of ::getExtensions() would have to be explicit.

Good idea, done. I wanted to preserve the existing API but didn't realize that all these ::getExtensions functions in derived classes only shadow the main definition in SPIRVType.

@kuhar kuhar requested a review from IgWod-IMG September 22, 2025 12:02
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar kuhar merged commit 32b1f16 into llvm:main Sep 22, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants