-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Rework type extension queries #160020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
* Fix infinite recursion with nested structs. * Move all extension logic to a new helper class -- this way `::getExtensions` functions can't diverge across concrete types and 'convenience types' like `CompositeType`. We should also fix `::getCapabilities` in a similar way and move the testcase to `vce-deduction.mlir`. Issue: llvm#159963
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) Changes
We should also fix Issue: #159963 Full diff: https://github.com/llvm/llvm-project/pull/160020.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d890dac96b118..85250700b9bf9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -14,14 +14,73 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::spirv;
+namespace {
+// Helper function to collect extensions implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type extension information. All extension
+// logic should be added to this class, while
+// `*Type::getExtensions` functions should not handle extension-related logic
+// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
+class TypeExtensionVisitor {
+ SPIRVType::ExtensionArrayRefVector &extensions;
+ std::optional<StorageClass> storage;
+ DenseSet<Type> seen;
+
+public:
+ TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage)
+ : extensions(extensions), storage(storage) {}
+
+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+ // as seen and dispatches to the right concrete `.add` function.
+ void add(SPIRVType type) {
+ if (auto [_it, inserted] = seen.insert(type); !inserted)
+ return;
+
+ TypeSwitch<SPIRVType>(type)
+ .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
+ VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
+ ImageType, SampledImageType>(
+ [this](auto concreteType) { add(concreteType); })
+ .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ }
+
+ // Convenience overloads for use in `T::getExtensions` functions.
+ void add(Type type) { add(cast<SPIRVType>(type)); }
+ void add(Type *type) { add(cast<SPIRVType>(*type)); }
+
+ // Types that add unique extensions.
+ void add(ScalarType type);
+ void add(PointerType type);
+ void add(CooperativeMatrixType type);
+ void add(TensorArmType type);
+
+ // Trivial passthrough without any new extensions.
+ void add(VectorType type) { add(type.getElementType()); }
+ void add(ArrayType type) { add(type.getElementType()); }
+ void add(RuntimeArrayType type) { add(type.getElementType()); }
+ void add(StructType type) {
+ for (Type elementType : type.getElementTypes())
+ add(elementType);
+ }
+ void add(MatrixType type) { add(type.getElementType()); }
+ void add(ImageType type) { add(type.getElementType()); }
+ void add(SampledImageType type) { add(type.getImageType()); }
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
@@ -67,7 +126,7 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void ArrayType::getCapabilities(
@@ -143,22 +202,7 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
- StructType>(
- [&](auto type) { type.getExtensions(extensions, storage); })
- .Case<VectorType>([&](VectorType type) {
- return llvm::cast<ScalarType>(type.getElementType())
- .getExtensions(extensions, storage);
- })
- .Case<TensorArmType>([&](TensorArmType type) {
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
- extensions.push_back(ext);
- return llvm::cast<ScalarType>(type.getElementType())
- .getExtensions(extensions, storage);
- })
-
- .Default([](Type) { llvm_unreachable("invalid composite type"); });
+ TypeExtensionVisitor{extensions, storage}.add(cast<SPIRVType>(*this));
}
void CompositeType::getCapabilities(
@@ -284,12 +328,16 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
return getImpl()->use;
}
+void TypeExtensionVisitor::add(CooperativeMatrixType type) {
+ add(type.getElementType());
+ static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
+ extensions.push_back(ext);
+}
+
void CooperativeMatrixType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
- extensions.push_back(exts);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void CooperativeMatrixType::getCapabilities(
@@ -403,9 +451,9 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
-void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
- std::optional<StorageClass>) {
- // Image types do not require extra extensions thus far.
+void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void ImageType::getCapabilities(
@@ -454,17 +502,23 @@ StorageClass PointerType::getStorageClass() const {
return getImpl()->storageClass;
}
-void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
+void TypeExtensionVisitor::add(PointerType type) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- llvm::cast<SPIRVType>(getPointeeType())
- .getExtensions(extensions, getStorageClass());
+ std::optional<StorageClass> oldStorageClass = storage;
+ storage = type.getStorageClass();
+ add(type.getPointeeType());
+ storage = oldStorageClass;
- if (auto scExts = spirv::getExtensions(getStorageClass()))
+ if (auto scExts = spirv::getExtensions(type.getStorageClass()))
extensions.push_back(*scExts);
}
+void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+ TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
void PointerType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -516,7 +570,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
void RuntimeArrayType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void RuntimeArrayType::getCapabilities(
@@ -553,10 +607,9 @@ bool ScalarType::isValid(IntegerType type) {
return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
}
-void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- if (isa<BFloat16Type>(*this)) {
- static const Extension ext = Extension::SPV_KHR_bfloat16;
+void TypeExtensionVisitor::add(ScalarType type) {
+ if (isa<BFloat16Type>(type)) {
+ static constexpr auto ext = Extension::SPV_KHR_bfloat16;
extensions.push_back(ext);
}
@@ -570,18 +623,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
case StorageClass::PushConstant:
case StorageClass::StorageBuffer:
case StorageClass::Uniform:
- if (getIntOrFloatBitWidth() == 8) {
- static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
+ if (type.getIntOrFloatBitWidth() == 8) {
+ static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
+ extensions.push_back(ext);
}
[[fallthrough]];
case StorageClass::Input:
case StorageClass::Output:
- if (getIntOrFloatBitWidth() == 16) {
- static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
+ if (type.getIntOrFloatBitWidth() == 16) {
+ static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
+ extensions.push_back(ext);
}
break;
default:
@@ -589,6 +640,11 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
}
}
+void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+ TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
void ScalarType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -722,23 +778,7 @@ bool SPIRVType::isScalarOrVector() {
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
- scalarType.getExtensions(extensions, storage);
- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
- compositeType.getExtensions(extensions, storage);
- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
- imageType.getExtensions(extensions, storage);
- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
- sampledImageType.getExtensions(extensions, storage);
- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
- matrixType.getExtensions(extensions, storage);
- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
- ptrType.getExtensions(extensions, storage);
- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- tensorArmType.getExtensions(extensions, storage);
- } else {
- llvm_unreachable("invalid SPIR-V Type to getExtensions");
- }
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void SPIRVType::getCapabilities(
@@ -821,7 +861,7 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
void SampledImageType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void SampledImageType::getCapabilities(
@@ -1184,8 +1224,7 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- for (Type elementType : getElementTypes())
- llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void StructType::getCapabilities(
@@ -1289,7 +1328,7 @@ unsigned MatrixType::getNumElements() const {
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void MatrixType::getCapabilities(
@@ -1347,13 +1386,16 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type TensorArmType::getElementType() const { return getImpl()->elementType; }
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+void TypeExtensionVisitor::add(TensorArmType type) {
+ add(type.getElementType());
+ static constexpr auto ext = Extension::SPV_ARM_tensors;
+ extensions.push_back(ext);
+}
+
void TensorArmType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
-
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
- extensions.push_back(ext);
+ TypeExtensionVisitor{extensions, storage}.add(this);
}
void TensorArmType::getCapabilities(
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index 71bf2f3d918e8..d24f37b553bb5 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
// `scf.parallel` conversion is not supported yet.
// Make sure that we do not accidentally invalidate this function by removing
@@ -19,3 +19,14 @@ func.func @func(%arg0: i64) {
}
return
}
+
+// -----
+
+// Make sure we don't crash on recursive structs.
+// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
+
+// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
+spirv.module Physical64 GLSL450 {
+ spirv.GlobalVariable @recursive:
+ !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we have a default implementation of ::getExtensions()
in SPIRVType
class. This way we avoid repeating the same implementation in each inherited type, as they can just use the parent default. I think this also would help enforce that ::getExtensions()
is not abused as any custom implementation of ::getExtensions()
would have to be explicit.
Good idea, done. I wanted to preserve the existing API but didn't realize that all these |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
::getExtensions
function from derived types, so that there's only one entry point that queries type extensions.::getExtensions
functions can't diverge across concrete types and 'convenience types' likeCompositeType
.We should also fix
::getCapabilities
in a similar way and move the testcase tovce-deduction.mlir
.Issue: #159963