diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 531feccccb032..6beffc17d6d58 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -89,8 +89,6 @@ class ScalarType : public SPIRVType { /// Returns true if the given float type is valid for the SPIR-V dialect. static bool isValid(IntegerType); - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); @@ -118,8 +116,6 @@ class CompositeType : public SPIRVType { /// implementation dependent. bool hasCompileTimeKnownNumElements() const; - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); @@ -148,8 +144,6 @@ class ArrayType : public Type::TypeBase storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); @@ -193,8 +187,6 @@ class ImageType ImageFormat getImageFormat() const; // TODO: Add support for Access qualifier - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; @@ -213,8 +205,6 @@ class PointerType : public Type::TypeBase storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; @@ -239,8 +229,6 @@ class RuntimeArrayType /// type. unsigned getArrayStride() const; - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; @@ -265,8 +253,6 @@ class SampledImageType Type getImageType() const; - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); @@ -420,8 +406,6 @@ class StructType ArrayRef memberDecorations = {}, ArrayRef structDecorations = {}); - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; @@ -456,8 +440,6 @@ class CooperativeMatrixType /// Returns the use parameter of the cooperative matrix. CooperativeMatrixUseKHR getUse() const; - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); @@ -512,8 +494,6 @@ class MatrixType : public Type::TypeBase storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; @@ -552,8 +532,6 @@ class TensorArmType bool hasRank() const { return !getShape().empty(); } operator ShapedType() const { return llvm::cast(*this); } - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); }; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index d890dac96b118..8244e64abba12 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -14,14 +14,67 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" #include using namespace mlir; using namespace mlir::spirv; +namespace { +// Helper function to collect extensions implied by a type by visiting all its +// subtypes. Maintains a set of `seen` types to avoid recursion in structs. +// +// Serves as the source-of-truth for type extension information. All extension +// logic should be added to this class, while the +// `SPIRVType::getExtensions` function should not handle extension-related logic +// directly and only invoke `TypeExtensionVisitor::add(Type *)`. +class TypeExtensionVisitor { +public: + TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage) + : extensions(extensions), storage(storage) {} + + // Main visitor entry point. Adds all extensions to the vector. Saves `type` + // as seen and dispatches to the right concrete `.add` function. + void add(SPIRVType type) { + if (auto [_it, inserted] = seen.insert({type, storage}); !inserted) + return; + + TypeSwitch(type) + .Case( + [this](auto concreteType) { addConcrete(concreteType); }) + .Case( + [this](auto concreteType) { add(concreteType.getElementType()); }) + .Case([this](StructType concreteType) { + for (Type elementType : concreteType.getElementTypes()) + add(elementType); + }) + .Case([this](SampledImageType concreteType) { + add(concreteType.getImageType()); + }) + .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); }); + } + + void add(Type type) { add(cast(type)); } + +private: + // Types that add unique extensions. + void addConcrete(ScalarType type); + void addConcrete(PointerType type); + void addConcrete(CooperativeMatrixType type); + void addConcrete(TensorArmType type); + + SPIRVType::ExtensionArrayRefVector &extensions; + std::optional storage; + llvm::SmallDenseSet>> seen; +}; + +} // namespace + //===----------------------------------------------------------------------===// // ArrayType //===----------------------------------------------------------------------===// @@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; } unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } -void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getElementType()).getExtensions(extensions, storage); -} - void ArrayType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const { return !llvm::isa(*this); } -void CompositeType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - TypeSwitch(*this) - .Case( - [&](auto type) { type.getExtensions(extensions, storage); }) - .Case([&](VectorType type) { - return llvm::cast(type.getElementType()) - .getExtensions(extensions, storage); - }) - .Case([&](TensorArmType type) { - static constexpr Extension ext{Extension::SPV_ARM_tensors}; - extensions.push_back(ext); - return llvm::cast(type.getElementType()) - .getExtensions(extensions, storage); - }) - - .Default([](Type) { llvm_unreachable("invalid composite type"); }); -} - void CompositeType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const { return getImpl()->use; } -void CooperativeMatrixType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getElementType()).getExtensions(extensions, storage); - static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix}; - extensions.push_back(exts); +void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) { + add(type.getElementType()); + static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix; + extensions.push_back(ext); } void CooperativeMatrixType::getCapabilities( @@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } -void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, - std::optional) { - // Image types do not require extra extensions thus far. -} - void ImageType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional) { @@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const { return getImpl()->storageClass; } -void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { +void TypeExtensionVisitor::addConcrete(PointerType type) { // Use this pointer type's storage class because this pointer indicates we are // using the pointee type in that specific storage class. - llvm::cast(getPointeeType()) - .getExtensions(extensions, getStorageClass()); + std::optional oldStorageClass = storage; + storage = type.getStorageClass(); + add(type.getPointeeType()); + storage = oldStorageClass; - if (auto scExts = spirv::getExtensions(getStorageClass())) + if (auto scExts = spirv::getExtensions(type.getStorageClass())) extensions.push_back(*scExts); } @@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } -void RuntimeArrayType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getElementType()).getExtensions(extensions, storage); -} - void RuntimeArrayType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) { return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth()); } -void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - if (isa(*this)) { - static const Extension ext = Extension::SPV_KHR_bfloat16; +void TypeExtensionVisitor::addConcrete(ScalarType type) { + if (isa(type)) { + static constexpr auto ext = Extension::SPV_KHR_bfloat16; extensions.push_back(ext); } @@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, case StorageClass::PushConstant: case StorageClass::StorageBuffer: case StorageClass::Uniform: - if (getIntOrFloatBitWidth() == 8) { - static const Extension exts[] = {Extension::SPV_KHR_8bit_storage}; - ArrayRef ref(exts, std::size(exts)); - extensions.push_back(ref); + if (type.getIntOrFloatBitWidth() == 8) { + static constexpr auto ext = Extension::SPV_KHR_8bit_storage; + extensions.push_back(ext); } [[fallthrough]]; case StorageClass::Input: case StorageClass::Output: - if (getIntOrFloatBitWidth() == 16) { - static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; - ArrayRef ref(exts, std::size(exts)); - extensions.push_back(ref); + if (type.getIntOrFloatBitWidth() == 16) { + static constexpr auto ext = Extension::SPV_KHR_16bit_storage; + extensions.push_back(ext); } break; default: @@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() { void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { - if (auto scalarType = llvm::dyn_cast(*this)) { - scalarType.getExtensions(extensions, storage); - } else if (auto compositeType = llvm::dyn_cast(*this)) { - compositeType.getExtensions(extensions, storage); - } else if (auto imageType = llvm::dyn_cast(*this)) { - imageType.getExtensions(extensions, storage); - } else if (auto sampledImageType = llvm::dyn_cast(*this)) { - sampledImageType.getExtensions(extensions, storage); - } else if (auto matrixType = llvm::dyn_cast(*this)) { - matrixType.getExtensions(extensions, storage); - } else if (auto ptrType = llvm::dyn_cast(*this)) { - ptrType.getExtensions(extensions, storage); - } else if (auto tensorArmType = llvm::dyn_cast(*this)) { - tensorArmType.getExtensions(extensions, storage); - } else { - llvm_unreachable("invalid SPIR-V Type to getExtensions"); - } + TypeExtensionVisitor{extensions, storage}.add(*this); } void SPIRVType::getCapabilities( @@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref emitError, return success(); } -void SampledImageType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getImageType()).getExtensions(extensions, storage); -} - void SampledImageType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef memberTypes, structDecorations); } -void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - for (Type elementType : getElementTypes()) - llvm::cast(elementType).getExtensions(extensions, storage); -} - void StructType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const { return (getImpl()->columnCount) * getNumRows(); } -void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getColumnType()).getExtensions(extensions, storage); -} - void MatrixType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { @@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional> shape, Type TensorArmType::getElementType() const { return getImpl()->elementType; } ArrayRef TensorArmType::getShape() const { return getImpl()->shape; } -void TensorArmType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - - llvm::cast(getElementType()).getExtensions(extensions, storage); - static constexpr Extension ext{Extension::SPV_ARM_tensors}; +void TypeExtensionVisitor::addConcrete(TensorArmType type) { + add(type.getElementType()); + static constexpr auto ext = Extension::SPV_ARM_tensors; extensions.push_back(ext); } diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir index 71bf2f3d918e8..d24f37b553bb5 100644 --- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s // `scf.parallel` conversion is not supported yet. // Make sure that we do not accidentally invalidate this function by removing @@ -19,3 +19,14 @@ func.func @func(%arg0: i64) { } return } + +// ----- + +// Make sure we don't crash on recursive structs. +// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase. + +// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}} +spirv.module Physical64 GLSL450 { + spirv.GlobalVariable @recursive: + !spirv.ptr, StorageBuffer>)>, StorageBuffer> +}