diff --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h index fb82d520dbc265..4bafb4ab7fb68a 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h +++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h @@ -32,6 +32,12 @@ class FIROpsDialect final : public mlir::Dialect { mlir::Type type) const override; void printAttribute(mlir::Attribute attr, mlir::DialectAsmPrinter &p) const override; + +private: + // Register the Attributes of this dialect. + void registerAttributes(); + // Register the Types of this dialect. + void registerTypes(); }; } // namespace fir diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp index 035245dbe935cd..a2fdf7cd43d045 100644 --- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp +++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp @@ -243,3 +243,12 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, os << "<(unknown attribute)>"; } } + +//===----------------------------------------------------------------------===// +// FIROpsDialect +//===----------------------------------------------------------------------===// + +void FIROpsDialect::registerAttributes() { + addAttributes(); +} diff --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp index 889b5ef5536687..f80aa7d3380e8e 100644 --- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp +++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp @@ -19,13 +19,8 @@ using namespace fir; fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx) : mlir::Dialect("fir", ctx, mlir::TypeID::get()) { - addTypes(); - addAttributes(); + registerTypes(); + registerAttributes(); addOperations< #define GET_OP_LIST #include "flang/Optimizer/Dialect/FIROps.cpp.inc" diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index 873f589e8a4b02..beab54e4f1f88c 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -866,3 +866,15 @@ mlir::LogicalResult fir::VectorType::verify( bool fir::VectorType::isValidElementType(mlir::Type t) { return isa_real(t) || isa_integer(t); } + +//===----------------------------------------------------------------------===// +// FIROpsDialect +//===----------------------------------------------------------------------===// + +void FIROpsDialect::registerTypes() { + addTypes(); +} diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md index 6a261da8a6c27a..7942fa2e18689b 100644 --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -319,7 +319,9 @@ public: Once the dialect types have been defined, they must then be registered with a `Dialect`. This is done via a similar mechanism to -[operations](LangRef.md#operations), with the `addTypes` method. +[operations](LangRef.md#operations), with the `addTypes` method. The one +distinct difference with operations, is that when a type is registered the +definition of its storage class must be visible. ```c++ struct MyDialect : public Dialect { diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md index 7074521eb5f4d5..315cf3237b407e 100644 --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -187,6 +187,9 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) } ``` +(An important note here is that when registering a type, the definition of the +storage class must be visible.) + With this we can now use our `StructType` when generating MLIR from Toy. See examples/toy/Ch7/mlir/MLIRGen.cpp for more details. diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 8170b5b579cb8b..cbcf53f313c1f7 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -76,33 +76,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface { } }; -//===----------------------------------------------------------------------===// -// ToyDialect -//===----------------------------------------------------------------------===// - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) - : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { - addOperations< -#define GET_OP_LIST -#include "toy/Ops.cpp.inc" - >(); - addInterfaces(); - addTypes(); -} - -mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - if (type.isa()) - return builder.create(loc, type, - value.cast()); - return builder.create(loc, type, - value.cast()); -} - //===----------------------------------------------------------------------===// // Toy Operations //===----------------------------------------------------------------------===// @@ -566,3 +539,30 @@ void ToyDialect::printType(mlir::Type type, #define GET_OP_CLASSES #include "toy/Ops.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); + addTypes(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (type.isa()) + return builder.create(loc, type, + value.cast()); + return builder.create(loc, type, + value.cast()); +} diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td index c6d76be4849403..afdf50673ed4e7 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td @@ -64,6 +64,9 @@ def PDL_Dialect : Dialect { let name = "pdl"; let cppNamespace = "::mlir::pdl"; + let extraClassDeclaration = [{ + void registerTypes(); + }]; } #endif // MLIR_DIALECT_PDL_IR_PDLDIALECT diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index f18dcef1997a82..d293a6a88afd4d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -52,6 +52,9 @@ def SPIRV_Dialect : Dialect { let hasRegionResultAttrVerify = 1; let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + //===------------------------------------------------------------------===// // Attribute //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td index 383f87bd5d6089..a52257c9d1f48f 100644 --- a/mlir/include/mlir/IR/BuiltinDialect.td +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -22,6 +22,17 @@ def Builtin_Dialect : Dialect { let name = ""; let cppNamespace = "::mlir"; + let extraClassDeclaration = [{ + private: + // Register the builtin Attributes. + void registerAttributes(); + // Register the builtin Location Attributes. + void registerLocationAttributes(); + // Register the builtin Types. + void registerTypes(); + + public: + }]; } #endif // BUILTIN_BASE diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index fc7ffa74f3b5e8..2b66edb51ac99e 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -135,7 +135,13 @@ class StorageUniquer { /// instances of this class type. `id` is the type identifier that will be /// used to identify this type when creating instances of it via 'get'. template void registerParametricStorageType(TypeID id) { - registerParametricStorageTypeImpl(id); + // If the storage is trivially destructible, we don't need a destructor + // function. + if (std::is_trivially_destructible::value) + return registerParametricStorageTypeImpl(id, nullptr); + registerParametricStorageTypeImpl(id, [](BaseStorage *storage) { + static_cast(storage)->~Storage(); + }); } /// Utility override when the storage type represents the type id. template void registerParametricStorageType() { @@ -244,8 +250,10 @@ class StorageUniquer { function_ref ctorFn); /// Implementation for registering an instance of a derived type with - /// parametric storage. - void registerParametricStorageTypeImpl(TypeID id); + /// parametric storage. This method takes an optional destructor function that + /// destructs storage instances when necessary. + void registerParametricStorageTypeImpl( + TypeID id, function_ref destructorFn); /// Implementation for getting an instance of a derived type with default /// storage. diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp index a796b0725d3671..895c3b0f4734ec 100644 --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -20,6 +20,12 @@ using namespace mlir; +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" + void arm_sve::ArmSVEDialect::initialize() { addOperations< #define GET_OP_LIST @@ -31,12 +37,6 @@ void arm_sve::ArmSVEDialect::initialize() { >(); } -#define GET_OP_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" - //===----------------------------------------------------------------------===// // ScalableVectorType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 02b4687b35a4ab..bfc90b2ac674bf 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "TypeDetail.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index e82c4ab6fb1661..beb43d7072f2ec 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -25,10 +25,7 @@ void PDLDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" - >(); + registerTypes(); } /// Returns true if the given operation is used by a "binding" pdl operation diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp index b16fade224fc1f..20f013af246fce 100644 --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -26,6 +26,13 @@ using namespace mlir::pdl; // PDLDialect //===----------------------------------------------------------------------===// +void PDLDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" + >(); +} + static Type parsePDLType(DialectAsmParser &parser) { StringRef typeTag; if (parser.parseKeyword(&typeTag)) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp index c74c34d88dd779..a514b44a899176 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -350,3 +351,11 @@ spirv::TargetEnvAttr::verify(function_ref emitError, return success(); } + +//===----------------------------------------------------------------------===// +// SPIR-V Dialect +//===----------------------------------------------------------------------===// + +void spirv::SPIRVDialect::registerAttributes() { + addAttributes(); +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index a3b639fa4e0579..81b3ee5e525f54 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -115,10 +115,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { //===----------------------------------------------------------------------===// void SPIRVDialect::initialize() { - addTypes(); - - addAttributes(); + registerAttributes(); + registerTypes(); // Add SPIR-V ops. addOperations< diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 2bfd9b8f084ff6..17ee2dfb0ec050 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1154,3 +1154,12 @@ void MatrixType::getCapabilities( // Add any capabilities associated with the underlying vectors (i.e., columns) getColumnType().cast().getCapabilities(capabilities, storage); } + +//===----------------------------------------------------------------------===// +// SPIR-V Dialect +//===----------------------------------------------------------------------===// + +void SPIRVDialect::registerTypes() { + addTypes(); +} diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 8945253644d822..39d016b8d0fee9 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -100,36 +100,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind, return false; } -//===----------------------------------------------------------------------===// -// VectorDialect -//===----------------------------------------------------------------------===// - -void VectorDialect::initialize() { - addAttributes(); - - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/Vector/VectorOps.cpp.inc" - >(); -} - -/// Materialize a single constant operation from a given attribute value with -/// the desired resultant type. -Operation *VectorDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { - return builder.create(loc, type, value); -} - -IntegerType vector::getVectorSubscriptType(Builder &builder) { - return builder.getIntegerType(64); -} - -ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, - ArrayRef values) { - return builder.getI64ArrayAttr(values); -} - //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -230,6 +200,36 @@ void VectorDialect::printAttribute(Attribute attr, llvm_unreachable("Unknown attribute type"); } +//===----------------------------------------------------------------------===// +// VectorDialect +//===----------------------------------------------------------------------===// + +void VectorDialect::initialize() { + addAttributes(); + + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Vector/VectorOps.cpp.inc" + >(); +} + +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *VectorDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create(loc, type, value); +} + +IntegerType vector::getVectorSubscriptType(Builder &builder) { + return builder.getIntegerType(64); +} + +ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, + ArrayRef values) { + return builder.getI64ArrayAttr(values); +} + //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 1d76122996de46..5efb8f7c70fff7 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" @@ -28,6 +29,18 @@ using namespace mlir::detail; #define GET_ATTRDEF_CLASSES #include "mlir/IR/BuiltinAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerAttributes() { + addAttributes(); +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index b19d541e50452e..28aef1500a00c4 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -60,17 +60,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface { } // end anonymous namespace. void BuiltinDialect::initialize() { - addTypes(); - addAttributes(); - addAttributes(); + registerTypes(); + registerAttributes(); + registerLocationAttributes(); addOperations< #define GET_OP_LIST #include "mlir/IR/BuiltinOps.cpp.inc" diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 652883f745e32f..758e16bf199993 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -30,6 +30,17 @@ using namespace mlir::detail; #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/IR/BuiltinTypes.cpp.inc" + >(); +} + //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// @@ -514,7 +525,7 @@ LogicalResult MemRefType::verify(function_ref emitError, if (!BaseMemRefType::isValidElementType(elementType)) return emitError() << "invalid memref element type"; - // Negative sizes are not allowed except for `-1` that means dynamic size. + // Negative sizes are not allowed except for `-1` that means dynamic size. for (int64_t s : shape) if (s < -1) return emitError() << "invalid memref size"; diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index 93a9d265209de5..cf730199e69399 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Location.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Identifier.h" #include "llvm/ADT/SetVector.h" @@ -20,6 +21,17 @@ using namespace mlir::detail; #define GET_ATTRDEF_CLASSES #include "mlir/IR/BuiltinLocationAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerLocationAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/IR/BuiltinLocationAttributes.cpp.inc" + >(); +} + //===----------------------------------------------------------------------===// // LocationAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index 7a802430f010cd..e7805150e37d82 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -100,12 +100,23 @@ class ParametricStorageUniquer { return storage; } + /// Destroy all of the storage instances within the given shard. + void destroyShardInstances(Shard &shard) { + if (!destructorFn) + return; + for (HashedStorage &instance : shard.instances) + destructorFn(instance.storage); + } + public: #if LLVM_ENABLE_THREADS != 0 /// Initialize the storage uniquer with a given number of storage shards to - /// use. The provided shard number is required to be a valid power of 2. - ParametricStorageUniquer(size_t numShards = 8) - : shards(new std::atomic[numShards]), numShards(numShards) { + /// use. The provided shard number is required to be a valid power of 2. The + /// destructor function is used to destroy any allocated storage instances. + ParametricStorageUniquer(function_ref destructorFn, + size_t numShards = 8) + : shards(new std::atomic[numShards]), numShards(numShards), + destructorFn(destructorFn) { assert(llvm::isPowerOf2_64(numShards) && "the number of shards is required to be a power of 2"); for (size_t i = 0; i < numShards; i++) @@ -113,9 +124,12 @@ class ParametricStorageUniquer { } ~ParametricStorageUniquer() { // Free all of the allocated shards. - for (size_t i = 0; i != numShards; ++i) - if (Shard *shard = shards[i].load()) + for (size_t i = 0; i != numShards; ++i) { + if (Shard *shard = shards[i].load()) { + destroyShardInstances(*shard); delete shard; + } + } } /// Get or create an instance of a parametric type. BaseStorage * @@ -204,10 +218,17 @@ class ParametricStorageUniquer { /// The number of available shards. size_t numShards; + /// Function to used to destruct any allocated storage instances. + function_ref destructorFn; + #else /// If multi-threading is disabled, ignore the shard parameter as we will - /// always use one shard. - ParametricStorageUniquer(size_t numShards = 0) {} + /// always use one shard. The destructor function is used to destroy any + /// allocated storage instances. + ParametricStorageUniquer(function_ref destructorFn, + size_t numShards = 0) + : destructorFn(destructorFn) {} + ~ParametricStorageUniquer() { destroyShardInstances(shard); } /// Get or create an instance of a parametric type. BaseStorage * @@ -228,6 +249,9 @@ class ParametricStorageUniquer { private: /// The main uniquer shard that is used for allocating storage instances. Shard shard; + + /// Function to used to destruct any allocated storage instances. + function_ref destructorFn; #endif }; } // end anonymous namespace @@ -323,9 +347,10 @@ auto StorageUniquer::getParametricStorageTypeImpl( /// Implementation for registering an instance of a derived type with /// parametric storage. -void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) { +void StorageUniquer::registerParametricStorageTypeImpl( + TypeID id, function_ref destructorFn) { impl->parametricUniquers.try_emplace( - id, std::make_unique()); + id, std::make_unique(destructorFn)); } /// Implementation for getting an instance of a derived type with default diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index 72921a22e475e3..8dcc3498c96470 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -100,6 +100,13 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const { // TestDialect //===----------------------------------------------------------------------===// +void TestDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "TestAttrDefs.cpp.inc" + >(); +} + Attribute TestDialect::parseAttribute(DialectAsmParser &parser, Type type) const { StringRef attrTag; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 0244d10736234f..991094d3b0b06e 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -166,20 +166,14 @@ struct TestInlinerInterface : public DialectInlinerInterface { //===----------------------------------------------------------------------===// void TestDialect::initialize() { + registerAttributes(); + registerTypes(); addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "TestAttrDefs.cpp.inc" - >(); addInterfaces(); - addTypes(); allowUnknownOperations(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 39b055691db23a..1968ebd46f6ff5 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -32,6 +32,9 @@ def Test_Dialect : Dialect { let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; void printAttribute(Attribute attr, diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index f8d0c6a83f0739..38ab9c819974db 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -164,6 +164,13 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params, // TestDialect //===----------------------------------------------------------------------===// +void TestDialect::registerTypes() { + addTypes(); +} + static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt index 7ea17583bc3eef..6616a793ec12fb 100644 --- a/mlir/unittests/Support/CMakeLists.txt +++ b/mlir/unittests/Support/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_unittest(MLIRSupportTests DebugCounterTest.cpp IndentedOstreamTest.cpp MathExtrasTest.cpp + StorageUniquerTest.cpp ) target_link_libraries(MLIRSupportTests diff --git a/mlir/unittests/Support/StorageUniquerTest.cpp b/mlir/unittests/Support/StorageUniquerTest.cpp new file mode 100644 index 00000000000000..6db6783bb89f93 --- /dev/null +++ b/mlir/unittests/Support/StorageUniquerTest.cpp @@ -0,0 +1,60 @@ +//===- StorageUniquerTest.cpp - StorageUniquer Tests ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/StorageUniquer.h" +#include "gmock/gmock.h" + +using namespace mlir; + +namespace { +/// Simple storage class used for testing. +template +struct SimpleStorage : public StorageUniquer::BaseStorage { + using Base = SimpleStorage; + using KeyTy = std::tuple; + + SimpleStorage(KeyTy key) : key(key) {} + + /// Get an instance of this storage instance. + template + static ConcreteT *get(StorageUniquer &uniquer, ParamsT &&...params) { + return uniquer.get( + /*initFn=*/{}, std::make_tuple(std::forward(params)...)); + } + + /// Construct an instance with the given storage allocator. + static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, + KeyTy key) { + return new (alloc.allocate()) + ConcreteT(std::forward(key)); + } + bool operator==(const KeyTy &key) const { return this->key == key; } + + KeyTy key; +}; +} // namespace + +TEST(StorageUniquerTest, NonTrivialDestructor) { + struct NonTrivialStorage : public SimpleStorage { + using Base::Base; + ~NonTrivialStorage() { + bool *wasDestructed = std::get<0>(key); + *wasDestructed = true; + } + }; + + // Verify that the storage instance destructor was properly called. + bool wasDestructed = false; + { + StorageUniquer uniquer; + uniquer.registerParametricStorageType(); + NonTrivialStorage::get(uniquer, &wasDestructed); + } + + EXPECT_TRUE(wasDestructed); +}