diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 50537f6c9abed..e415061768fe4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ +#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -73,7 +74,8 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); /// type. class LLVMArrayType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + SubElementTypeInterface::Trait> { public: /// Inherit base constructors. using Base::Base; @@ -111,6 +113,9 @@ class LLVMArrayType unsigned getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const; + + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -120,9 +125,9 @@ class LLVMArrayType /// LLVM dialect function type. It consists of a single return type (unlike MLIR /// which can have multiple), a list of parameter types and can optionally be /// variadic. -class LLVMFunctionType - : public Type::TypeBase { +class LLVMFunctionType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; @@ -150,11 +155,11 @@ class LLVMFunctionType LLVMFunctionType clone(TypeRange inputs, TypeRange results) const; /// Returns the result type of the function. - Type getReturnType(); + Type getReturnType() const; /// Returns the result type of the function as an ArrayRef, enabling better /// integration with generic MLIR utilities. - ArrayRef getReturnTypes(); + ArrayRef getReturnTypes() const; /// Returns the number of arguments to the function. unsigned getNumParams(); @@ -163,12 +168,15 @@ class LLVMFunctionType Type getParamType(unsigned i); /// Returns a list of argument types of the function. - ArrayRef getParams(); + ArrayRef getParams() const; ArrayRef params() { return getParams(); } /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type result, ArrayRef arguments, bool); + + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -179,9 +187,10 @@ class LLVMFunctionType /// object in memory. Pointers may be opaque or parameterized by the element /// type. Both opaque and non-opaque pointers are additionally parameterized by /// the address space. -class LLVMPointerType : public Type::TypeBase { +class LLVMPointerType + : public Type::TypeBase< + LLVMPointerType, Type, detail::LLVMPointerTypeStorage, + DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> { public: /// Inherit base constructors. using Base::Base; @@ -232,6 +241,9 @@ class LLVMPointerType : public Type::TypeBase walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -265,6 +277,7 @@ class LLVMPointerType : public Type::TypeBase { public: /// Inherit base constructors. @@ -359,6 +372,9 @@ class LLVMStructType LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; + + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -369,7 +385,8 @@ class LLVMStructType /// length that can be processed as one. class LLVMFixedVectorType : public Type::TypeBase { + detail::LLVMTypeAndSizeStorage, + SubElementTypeInterface::Trait> { public: /// Inherit base constructor. using Base::Base; @@ -388,7 +405,7 @@ class LLVMFixedVectorType static bool isValidElementType(Type type); /// Returns the element type of the vector. - Type getElementType(); + Type getElementType() const; /// Returns the number of elements in the fixed vector. unsigned getNumElements(); @@ -396,6 +413,9 @@ class LLVMFixedVectorType /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type elementType, unsigned numElements); + + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -407,7 +427,8 @@ class LLVMFixedVectorType /// elements can be processed as one in SIMD context. class LLVMScalableVectorType : public Type::TypeBase { + detail::LLVMTypeAndSizeStorage, + SubElementTypeInterface::Trait> { public: /// Inherit base constructor. using Base::Base; @@ -424,7 +445,7 @@ class LLVMScalableVectorType static bool isValidElementType(Type type); /// Returns the element type of the vector. - Type getElementType(); + Type getElementType() const; /// Returns the scaling factor of the number of elements in the vector. The /// vector contains at least the resulting number of elements, or any non-zero @@ -434,6 +455,9 @@ class LLVMScalableVectorType /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type elementType, unsigned minNumElements); + + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index b02d53a2efae7..49d2d8d24963b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -86,6 +86,12 @@ LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout, return dataLayout.getTypePreferredAlignment(getElementType()); } +void LLVMArrayType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // Function type. //===----------------------------------------------------------------------===// @@ -119,8 +125,10 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs, return get(results[0], llvm::to_vector(inputs), isVarArg()); } -Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); } -ArrayRef LLVMFunctionType::getReturnTypes() { +Type LLVMFunctionType::getReturnType() const { + return getImpl()->getReturnType(); +} +ArrayRef LLVMFunctionType::getReturnTypes() const { return getImpl()->getReturnType(); } @@ -134,7 +142,7 @@ Type LLVMFunctionType::getParamType(unsigned i) { bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); } -ArrayRef LLVMFunctionType::getParams() { +ArrayRef LLVMFunctionType::getParams() const { return getImpl()->getArgumentTypes(); } @@ -151,6 +159,13 @@ LLVMFunctionType::verify(function_ref emitError, return success(); } +void LLVMFunctionType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : llvm::concat(getReturnTypes(), getParams())) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // Pointer type. //===----------------------------------------------------------------------===// @@ -353,6 +368,12 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries, return success(); } +void LLVMPointerType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // Struct type. //===----------------------------------------------------------------------===// @@ -589,6 +610,13 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, return mlir::success(); } +void LLVMStructType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : getBody()) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // Vector types. //===----------------------------------------------------------------------===// @@ -621,7 +649,7 @@ LLVMFixedVectorType::getChecked(function_ref emitError, numElements); } -Type LLVMFixedVectorType::getElementType() { +Type LLVMFixedVectorType::getElementType() const { return static_cast(impl)->elementType; } @@ -640,6 +668,12 @@ LLVMFixedVectorType::verify(function_ref emitError, emitError, elementType, numElements); } +void LLVMFixedVectorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // LLVMScalableVectorType. //===----------------------------------------------------------------------===// @@ -658,7 +692,7 @@ LLVMScalableVectorType::getChecked(function_ref emitError, minNumElements); } -Type LLVMScalableVectorType::getElementType() { +Type LLVMScalableVectorType::getElementType() const { return static_cast(impl)->elementType; } @@ -680,6 +714,12 @@ LLVMScalableVectorType::verify(function_ref emitError, emitError, elementType, numElements); } +void LLVMScalableVectorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp index 9c0ea4f14d766..75c6fd004e3d4 100644 --- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp +++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -18,3 +18,46 @@ TEST_F(LLVMIRTest, IsStructTypeMutable) { ASSERT_TRUE(bool(structTy)); ASSERT_TRUE(structTy.hasTrait()); } + +TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) { + auto fooStructTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(fooStructTy)); + auto barStructTy = LLVMStructType::getIdentified(&context, "bar"); + ASSERT_TRUE(bool(barStructTy)); + + // Created two structs that are referencing each other. + Type fooBody[] = {LLVMPointerType::get(barStructTy)}; + ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false))); + Type barBody[] = {LLVMPointerType::get(fooStructTy)}; + ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false))); + + auto subElementInterface = fooStructTy.dyn_cast(); + ASSERT_TRUE(bool(subElementInterface)); + // Test if walkSubElements goes into infinite loops. + SmallVector subElementTypes; + subElementInterface.walkSubElements( + [](Attribute attr) {}, + [&](Type type) { subElementTypes.push_back(type); }); + // We don't record LLVMPointerType (because it's immutable), thus + // !llvm.ptr> will be visited twice. + ASSERT_EQ(subElementTypes.size(), 5U); + + // !llvm.ptr> + ASSERT_TRUE(subElementTypes[0].isa()); + + // !llvm.struct<"foo",...> + auto structType = subElementTypes[1].dyn_cast(); + ASSERT_TRUE(bool(structType)); + ASSERT_TRUE(structType.getName().equals("foo")); + + // !llvm.ptr> + ASSERT_TRUE(subElementTypes[2].isa()); + + // !llvm.struct<"bar",...> + structType = subElementTypes[3].dyn_cast(); + ASSERT_TRUE(bool(structType)); + ASSERT_TRUE(structType.getName().equals("bar")); + + // !llvm.ptr> + ASSERT_TRUE(subElementTypes[4].isa()); +}