Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Expose type and attribute names in the MLIRContext and abstract type/attr classes #72189

Merged
merged 17 commits into from
Nov 30, 2023

Conversation

math-fehr
Copy link
Contributor

This patch expose the type and attribute names in C++ as methods in the AbstractType and AbstractAttribute classes, and keep a map of names to AbstractType and AbstractAttribute in the MLIRContext. Type and attribute names should be unique.

It adds support in ODS to generate the getName methods in AbstractType and AbstractAttribute, through the use of two new variables, typeName and attrName. It also adds names to C++-defined type and attributes.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 14, 2023

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-dlti
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-quant
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-core

Author: Fehr Mathieu (math-fehr)

Changes

This patch expose the type and attribute names in C++ as methods in the AbstractType and AbstractAttribute classes, and keep a map of names to AbstractType and AbstractAttribute in the MLIRContext. Type and attribute names should be unique.

It adds support in ODS to generate the getName methods in AbstractType and AbstractAttribute, through the use of two new variables, typeName and attrName. It also adds names to C++-defined type and attributes.


Patch is 51.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72189.diff

26 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTI.h (+8)
  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+32-7)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+10-7)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+4)
  • (modified) mlir/include/mlir/Dialect/Quant/QuantTypes.h (+8)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h (+6)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+20)
  • (modified) mlir/include/mlir/IR/AttrTypeBase.td (+9-3)
  • (modified) mlir/include/mlir/IR/AttributeSupport.h (+11-5)
  • (modified) mlir/include/mlir/IR/BuiltinAttributes.h (+2)
  • (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+17)
  • (modified) mlir/include/mlir/IR/BuiltinLocationAttributes.td (+6)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+28-27)
  • (modified) mlir/include/mlir/IR/TypeSupport.h (+17-5)
  • (modified) mlir/include/mlir/TableGen/AttrOrTypeDef.h (+8)
  • (modified) mlir/lib/IR/ExtensibleDialect.cpp (+14-2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+25-2)
  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+16)
  • (modified) mlir/test/lib/Dialect/Test/TestTypes.h (+2)
  • (modified) mlir/test/mlir-tblgen/attrdefs.td (+5)
  • (modified) mlir/test/mlir-tblgen/op-attribute.td (+1)
  • (modified) mlir/test/mlir-tblgen/op-decl-and-defs.td (+3-1)
  • (modified) mlir/test/mlir-tblgen/typedefs.td (+2)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+17)
  • (modified) mlir/unittests/IR/TypeTest.cpp (+5)
  • (modified) mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (+9)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index cf78b2312c24609..6dabbaa7ae1ad36 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -55,6 +55,10 @@ class DataLayoutEntryAttr
 
   /// Prints this attribute.
   void print(AsmPrinter &os) const;
+
+  static constexpr StringRef getAttrName() {
+    return "builtin.data_layout_entry";
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -109,6 +113,10 @@ class DataLayoutSpecAttr
 
   /// Prints this attribute.
   void print(AsmPrinter &os) const;
+
+  static constexpr StringRef getAttrName() {
+    return "builtin.data_layout_spec";
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..29ba15f8fe073a6 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -45,6 +45,8 @@ class AsyncTokenType
 public:
   // Used for generic hooks in TypeBase.
   using Base::Base;
+
+  static constexpr StringRef getTypeName() { return "gpu.async_token"; }
 };
 
 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
@@ -128,6 +130,8 @@ class MMAMatrixType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "gpu.mma_matrix"; }
+
   /// Get MMAMatrixType and verify construction Invariants.
   static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
                            StringRef operand);
@@ -168,18 +172,39 @@ void addAsyncDependency(Operation *op, Value token);
 // Handle types for sparse.
 enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };
 
-template <SparseHandleKind K>
-class SparseHandleType
-    : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
+class SparseDnTensorHandleType
+    : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
+public:
+  using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
+                                       TypeStorage>::Base;
+  using Base::Base;
+
+  static constexpr StringRef getTypeName() {
+    return "gpu.sparse.dntensor_handle";
+  }
+};
+
+class SparseSpMatHandleType
+    : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
 public:
   using Base =
-      typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
+      typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
   using Base::Base;
+
+  static constexpr StringRef getTypeName() { return "gpu.sparse.spmat_handle"; }
 };
 
-using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
-using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
-using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
+class SparseSpGEMMOpHandleType
+    : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
+public:
+  using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
+                                       TypeStorage>::Base;
+  using Base::Base;
+
+  static constexpr StringRef getTypeName() {
+    return "gpu.sparse.spgemmop_handle";
+  }
+};
 
 } // namespace gpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index ba2f14f173aa0c3..66d9a11a1693bc7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -58,18 +58,19 @@ namespace LLVM {
 //===----------------------------------------------------------------------===//
 
 // Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName)                                    \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName)                          \
   class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> {      \
   public:                                                                      \
     using Base::Base;                                                          \
+    static constexpr StringRef getTypeName() { return TypeName; }              \
   }
 
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
@@ -110,6 +111,8 @@ class LLVMStructType
   /// Inherit base constructors.
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "llvm.struct"; }
+
   /// Checks if the given type can be contained in a structure type.
   static bool isValidElementType(Type type);
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 0bd068c1be7c90a..96cdbf01b4bd91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
     elements can be processed as one in SIMD context.
   }];
 
+  let typeName = "llvm.fixed_vec";
+
   let parameters = (ins "Type":$elementType, "unsigned":$numElements);
   let assemblyFormat = [{
     `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
@@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
     elements can be processed as one in SIMD context.
   }];
 
+  let typeName = "llvm.scalable_vec";
+
   let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
   let assemblyFormat = [{
     `<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 2776b3e6e17ba50..f0300ea8cbb8ab2 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -198,6 +198,8 @@ class AnyQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.any"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static AnyQuantizedType get(unsigned flags, Type storageType,
@@ -257,6 +259,8 @@ class UniformQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.uniform"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static UniformQuantizedType get(unsigned flags, Type storageType,
@@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.uniform_per_axis"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static UniformQuantizedPerAxisType
@@ -383,6 +389,8 @@ class CalibratedQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.calibrated"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static CalibratedQuantizedType get(Type expressedType, double min,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 3b914dc4cc82f11..8de961295153b8d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -79,6 +79,8 @@ class InterfaceVarABIAttr
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               IntegerAttr descriptorSet, IntegerAttr binding,
                               IntegerAttr storageClass);
+
+  static constexpr StringRef getAttrName() { return "spirv.interface_var_abi"; }
 };
 
 /// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@@ -129,6 +131,8 @@ class VerCapExtAttr
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               IntegerAttr version, ArrayAttr capabilities,
                               ArrayAttr extensions);
+
+  static constexpr StringRef getAttrName() { return "spirv.ver_cap_ext"; }
 };
 
 /// An attribute that specifies the target version, allowed extensions and
@@ -183,6 +187,8 @@ class TargetEnvAttr
 
   /// Returns the target resource limits.
   ResourceLimitsAttr getResourceLimits() const;
+
+  static constexpr StringRef getAttrName() { return "spirv.target_env"; }
 };
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4be2582f8fd68cc..e536adc8dce4abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.array"; }
+
   static ArrayType get(Type elementType, unsigned elementCount);
 
   /// Returns an array type with the given stride in bytes.
@@ -162,6 +164,8 @@ class ImageType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.image"; }
+
   static ImageType
   get(Type elementType, Dim dim,
       ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
@@ -201,6 +205,8 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.pointer"; }
+
   static PointerType get(Type pointeeType, StorageClass storageClass);
 
   Type getPointeeType() const;
@@ -220,6 +226,8 @@ class RuntimeArrayType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.rtarray"; }
+
   static RuntimeArrayType get(Type elementType);
 
   /// Returns a runtime array type with the given stride in bytes.
@@ -244,6 +252,8 @@ class SampledImageType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.sampled_image"; }
+
   static SampledImageType get(Type imageType);
 
   static SampledImageType
@@ -288,6 +298,8 @@ class StructType
   // Type for specifying the offset of the struct members
   using OffsetInfo = uint32_t;
 
+  static constexpr StringRef getTypeName() { return "spirv.struct"; }
+
   // Type for specifying the decoration(s) on struct members
   struct MemberDecorationInfo {
     uint32_t memberIndex : 31;
@@ -387,6 +399,8 @@ class CooperativeMatrixType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.coopmatrix"; }
+
   static CooperativeMatrixType get(Type elementType, uint32_t rows,
                                    uint32_t columns, Scope scope,
                                    CooperativeMatrixUseKHR use);
@@ -414,6 +428,8 @@ class CooperativeMatrixNVType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.NV.coopmatrix"; }
+
   static CooperativeMatrixNVType get(Type elementType, Scope scope,
                                      unsigned rows, unsigned columns);
   Type getElementType() const;
@@ -438,6 +454,8 @@ class JointMatrixINTELType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.jointmatrix"; }
+
   static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
                                   unsigned columns, MatrixLayout matrixLayout);
   Type getElementType() const;
@@ -464,6 +482,8 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.matrix"; }
+
   static MatrixType get(Type columnType, uint32_t columnCount);
 
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 42a611ee8e42205..91c9283de8bd415 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -264,6 +264,9 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   // Make it possible to use such attributes as parameters for other attributes.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
+  // The unique attribute name.
+  string attrName = dialect.name # "." # mnemonic;
+
   // The call expression to convert from the storage type to the return
   // type. For example, an enum can be stored as an int but returned as an
   // enum class.
@@ -289,6 +292,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
   // Make it possible to use such type as parameters for other types.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
+  // The unique type name.
+  string typeName = dialect.name # "." # mnemonic;
+
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
                            "$_builder.getType<" # dialect.cppNamespace #
@@ -431,15 +437,15 @@ class AttributeSelfTypeParameter<string desc,
 /// This class defines an attribute that contains an array of elements. The
 /// elements can be any type, but if they are attributes, the nested elements
 /// are parsed and printed using the custom attribute syntax.
-class ArrayOfAttr<Dialect dialect, string attrName, string attrMnemonic,
+class ArrayOfAttr<Dialect dialect, string name, string attrMnemonic,
                   string eltName, list<Trait> traits = []>
-    : AttrDef<dialect, attrName, traits> {
+    : AttrDef<dialect, name, traits> {
   let parameters = (ins OptionalArrayRefParameter<eltName>:$value);
   let mnemonic = attrMnemonic;
   let assemblyFormat = "`[` (`]`) : ($value^ `]`)?";
 
   let returnType = "::llvm::ArrayRef<" # eltName # ">";
-  let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)";
+  let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)";
   let convertFromStorage = "$_self.getValue()";
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 75ea1ce24753c97..104f5570e3bfb31 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -45,7 +45,7 @@ class AbstractAttribute {
     return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
                              T::getWalkImmediateSubElementsFn(),
                              T::getReplaceImmediateSubElementsFn(),
-                             T::getTypeID());
+                             T::getTypeID(), T::getAttrName());
   }
 
   /// This method is used by Dialect objects to register attributes with
@@ -57,10 +57,10 @@ class AbstractAttribute {
       HasTraitFn &&hasTrait,
       WalkImmediateSubElementsFn walkImmediateSubElementsFn,
       ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
-      TypeID typeID) {
+      TypeID typeID, StringRef name) {
     return AbstractAttribute(dialect, std::move(interfaceMap),
                              std::move(hasTrait), walkImmediateSubElementsFn,
-                             replaceImmediateSubElementsFn, typeID);
+                             replaceImmediateSubElementsFn, typeID, name);
   }
 
   /// Return the dialect this attribute was registered to.
@@ -102,17 +102,20 @@ class AbstractAttribute {
   /// Return the unique identifier representing the concrete attribute class.
   TypeID getTypeID() const { return typeID; }
 
+  /// Return the unique name representing the type.
+  StringRef getName() const { return name; }
+
 private:
   AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
                     HasTraitFn &&hasTraitFn,
                     WalkImmediateSubElementsFn walkImmediateSubElementsFn,
                     ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
-                    TypeID typeID)
+                    TypeID typeID, StringRef name)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
         hasTraitFn(std::move(hasTraitFn)),
         walkImmediateSubElementsFn(walkImmediateSubElementsFn),
         replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
-        typeID(typeID) {}
+        typeID(typeID), name(name) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -141,6 +144,9 @@ class AbstractAttribute {
 
   /// The unique identifier of the derived Attribute class.
   const TypeID typeID;
+
+  /// The unique name of this type.
+  const StringRef name;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad3503..c45473ceccfe371 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -1044,6 +1044,8 @@ class DistinctAttr
   /// Creates a distinct attribute that associates a referenced attribute with a
   /// unique identifier.
   static DistinctAttr create(Attribute referencedAttr);
+
+  static constexpr StringRef getAttrName() { return "builtin.distinct"; }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 390231da662e2d4..f4a42ccf736bfea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -64,6 +64,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
     AffineMap getAffineMap() const { return getValue(); }
   }];
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.affine_map";
 }
 
 //===----------------------------------------------------------------------===//
@@ -134,6 +135,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
       });
     }
   }];
+  let attrName = "builtin.array";
 }
 
 //===----------------------------------------------------------------------===//
@@ -211,6 +213,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
     /// Return true if there are no elements in the dense array.
     bool empty() const { return !size(); }
   }];
+  let attrName = "builtin.dense_array";
 }
 
 //===----------------------------------------------------------------------===//
@@ -352,6 +355,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
   let genAccessors = 0;
   let genStorageClass = 0;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_int_or_fp_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -423,6 +427,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
   let genAccessors = 0;
   let genStorageClass = 0;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_string_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -481,6 +486,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
   ];
 
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_resource_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -579,6 +585,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
   public:
   }];
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dictionary";
 }
 
 //===----------------------------------------------------------------------===//
@@ -642,6 +649,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
   }];
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.float";
 }
 
 //===----------------------------------------------------------------------===//
@@ -730,6 +738,7 @@ def Builtin_IntegerAttr : Builtin_Attr...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 14, 2023

@llvm/pr-subscribers-mlir-llvm

Author: Fehr Mathieu (math-fehr)

Changes

This patch expose the type and attribute names in C++ as methods in the AbstractType and AbstractAttribute classes, and keep a map of names to AbstractType and AbstractAttribute in the MLIRContext. Type and attribute names should be unique.

It adds support in ODS to generate the getName methods in AbstractType and AbstractAttribute, through the use of two new variables, typeName and attrName. It also adds names to C++-defined type and attributes.


Patch is 51.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72189.diff

26 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTI.h (+8)
  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+32-7)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+10-7)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+4)
  • (modified) mlir/include/mlir/Dialect/Quant/QuantTypes.h (+8)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h (+6)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+20)
  • (modified) mlir/include/mlir/IR/AttrTypeBase.td (+9-3)
  • (modified) mlir/include/mlir/IR/AttributeSupport.h (+11-5)
  • (modified) mlir/include/mlir/IR/BuiltinAttributes.h (+2)
  • (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+17)
  • (modified) mlir/include/mlir/IR/BuiltinLocationAttributes.td (+6)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+28-27)
  • (modified) mlir/include/mlir/IR/TypeSupport.h (+17-5)
  • (modified) mlir/include/mlir/TableGen/AttrOrTypeDef.h (+8)
  • (modified) mlir/lib/IR/ExtensibleDialect.cpp (+14-2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+25-2)
  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+16)
  • (modified) mlir/test/lib/Dialect/Test/TestTypes.h (+2)
  • (modified) mlir/test/mlir-tblgen/attrdefs.td (+5)
  • (modified) mlir/test/mlir-tblgen/op-attribute.td (+1)
  • (modified) mlir/test/mlir-tblgen/op-decl-and-defs.td (+3-1)
  • (modified) mlir/test/mlir-tblgen/typedefs.td (+2)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+17)
  • (modified) mlir/unittests/IR/TypeTest.cpp (+5)
  • (modified) mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (+9)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index cf78b2312c24609..6dabbaa7ae1ad36 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -55,6 +55,10 @@ class DataLayoutEntryAttr
 
   /// Prints this attribute.
   void print(AsmPrinter &os) const;
+
+  static constexpr StringRef getAttrName() {
+    return "builtin.data_layout_entry";
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -109,6 +113,10 @@ class DataLayoutSpecAttr
 
   /// Prints this attribute.
   void print(AsmPrinter &os) const;
+
+  static constexpr StringRef getAttrName() {
+    return "builtin.data_layout_spec";
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..29ba15f8fe073a6 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -45,6 +45,8 @@ class AsyncTokenType
 public:
   // Used for generic hooks in TypeBase.
   using Base::Base;
+
+  static constexpr StringRef getTypeName() { return "gpu.async_token"; }
 };
 
 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
@@ -128,6 +130,8 @@ class MMAMatrixType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "gpu.mma_matrix"; }
+
   /// Get MMAMatrixType and verify construction Invariants.
   static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
                            StringRef operand);
@@ -168,18 +172,39 @@ void addAsyncDependency(Operation *op, Value token);
 // Handle types for sparse.
 enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };
 
-template <SparseHandleKind K>
-class SparseHandleType
-    : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
+class SparseDnTensorHandleType
+    : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
+public:
+  using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
+                                       TypeStorage>::Base;
+  using Base::Base;
+
+  static constexpr StringRef getTypeName() {
+    return "gpu.sparse.dntensor_handle";
+  }
+};
+
+class SparseSpMatHandleType
+    : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
 public:
   using Base =
-      typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
+      typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
   using Base::Base;
+
+  static constexpr StringRef getTypeName() { return "gpu.sparse.spmat_handle"; }
 };
 
-using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
-using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
-using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
+class SparseSpGEMMOpHandleType
+    : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
+public:
+  using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
+                                       TypeStorage>::Base;
+  using Base::Base;
+
+  static constexpr StringRef getTypeName() {
+    return "gpu.sparse.spgemmop_handle";
+  }
+};
 
 } // namespace gpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index ba2f14f173aa0c3..66d9a11a1693bc7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -58,18 +58,19 @@ namespace LLVM {
 //===----------------------------------------------------------------------===//
 
 // Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName)                                    \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName)                          \
   class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> {      \
   public:                                                                      \
     using Base::Base;                                                          \
+    static constexpr StringRef getTypeName() { return TypeName; }              \
   }
 
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
@@ -110,6 +111,8 @@ class LLVMStructType
   /// Inherit base constructors.
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "llvm.struct"; }
+
   /// Checks if the given type can be contained in a structure type.
   static bool isValidElementType(Type type);
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 0bd068c1be7c90a..96cdbf01b4bd91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
     elements can be processed as one in SIMD context.
   }];
 
+  let typeName = "llvm.fixed_vec";
+
   let parameters = (ins "Type":$elementType, "unsigned":$numElements);
   let assemblyFormat = [{
     `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
@@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
     elements can be processed as one in SIMD context.
   }];
 
+  let typeName = "llvm.scalable_vec";
+
   let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
   let assemblyFormat = [{
     `<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 2776b3e6e17ba50..f0300ea8cbb8ab2 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -198,6 +198,8 @@ class AnyQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.any"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static AnyQuantizedType get(unsigned flags, Type storageType,
@@ -257,6 +259,8 @@ class UniformQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.uniform"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static UniformQuantizedType get(unsigned flags, Type storageType,
@@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.uniform_per_axis"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static UniformQuantizedPerAxisType
@@ -383,6 +389,8 @@ class CalibratedQuantizedType
   using Base::Base;
   using Base::getChecked;
 
+  static constexpr StringRef getTypeName() { return "quant.calibrated"; }
+
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
   static CalibratedQuantizedType get(Type expressedType, double min,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 3b914dc4cc82f11..8de961295153b8d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -79,6 +79,8 @@ class InterfaceVarABIAttr
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               IntegerAttr descriptorSet, IntegerAttr binding,
                               IntegerAttr storageClass);
+
+  static constexpr StringRef getAttrName() { return "spirv.interface_var_abi"; }
 };
 
 /// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@@ -129,6 +131,8 @@ class VerCapExtAttr
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               IntegerAttr version, ArrayAttr capabilities,
                               ArrayAttr extensions);
+
+  static constexpr StringRef getAttrName() { return "spirv.ver_cap_ext"; }
 };
 
 /// An attribute that specifies the target version, allowed extensions and
@@ -183,6 +187,8 @@ class TargetEnvAttr
 
   /// Returns the target resource limits.
   ResourceLimitsAttr getResourceLimits() const;
+
+  static constexpr StringRef getAttrName() { return "spirv.target_env"; }
 };
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4be2582f8fd68cc..e536adc8dce4abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.array"; }
+
   static ArrayType get(Type elementType, unsigned elementCount);
 
   /// Returns an array type with the given stride in bytes.
@@ -162,6 +164,8 @@ class ImageType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.image"; }
+
   static ImageType
   get(Type elementType, Dim dim,
       ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
@@ -201,6 +205,8 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.pointer"; }
+
   static PointerType get(Type pointeeType, StorageClass storageClass);
 
   Type getPointeeType() const;
@@ -220,6 +226,8 @@ class RuntimeArrayType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.rtarray"; }
+
   static RuntimeArrayType get(Type elementType);
 
   /// Returns a runtime array type with the given stride in bytes.
@@ -244,6 +252,8 @@ class SampledImageType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.sampled_image"; }
+
   static SampledImageType get(Type imageType);
 
   static SampledImageType
@@ -288,6 +298,8 @@ class StructType
   // Type for specifying the offset of the struct members
   using OffsetInfo = uint32_t;
 
+  static constexpr StringRef getTypeName() { return "spirv.struct"; }
+
   // Type for specifying the decoration(s) on struct members
   struct MemberDecorationInfo {
     uint32_t memberIndex : 31;
@@ -387,6 +399,8 @@ class CooperativeMatrixType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.coopmatrix"; }
+
   static CooperativeMatrixType get(Type elementType, uint32_t rows,
                                    uint32_t columns, Scope scope,
                                    CooperativeMatrixUseKHR use);
@@ -414,6 +428,8 @@ class CooperativeMatrixNVType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.NV.coopmatrix"; }
+
   static CooperativeMatrixNVType get(Type elementType, Scope scope,
                                      unsigned rows, unsigned columns);
   Type getElementType() const;
@@ -438,6 +454,8 @@ class JointMatrixINTELType
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.jointmatrix"; }
+
   static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
                                   unsigned columns, MatrixLayout matrixLayout);
   Type getElementType() const;
@@ -464,6 +482,8 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 public:
   using Base::Base;
 
+  static constexpr StringRef getTypeName() { return "spirv.matrix"; }
+
   static MatrixType get(Type columnType, uint32_t columnCount);
 
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 42a611ee8e42205..91c9283de8bd415 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -264,6 +264,9 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   // Make it possible to use such attributes as parameters for other attributes.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
+  // The unique attribute name.
+  string attrName = dialect.name # "." # mnemonic;
+
   // The call expression to convert from the storage type to the return
   // type. For example, an enum can be stored as an int but returned as an
   // enum class.
@@ -289,6 +292,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
   // Make it possible to use such type as parameters for other types.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
+  // The unique type name.
+  string typeName = dialect.name # "." # mnemonic;
+
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
                            "$_builder.getType<" # dialect.cppNamespace #
@@ -431,15 +437,15 @@ class AttributeSelfTypeParameter<string desc,
 /// This class defines an attribute that contains an array of elements. The
 /// elements can be any type, but if they are attributes, the nested elements
 /// are parsed and printed using the custom attribute syntax.
-class ArrayOfAttr<Dialect dialect, string attrName, string attrMnemonic,
+class ArrayOfAttr<Dialect dialect, string name, string attrMnemonic,
                   string eltName, list<Trait> traits = []>
-    : AttrDef<dialect, attrName, traits> {
+    : AttrDef<dialect, name, traits> {
   let parameters = (ins OptionalArrayRefParameter<eltName>:$value);
   let mnemonic = attrMnemonic;
   let assemblyFormat = "`[` (`]`) : ($value^ `]`)?";
 
   let returnType = "::llvm::ArrayRef<" # eltName # ">";
-  let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)";
+  let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)";
   let convertFromStorage = "$_self.getValue()";
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 75ea1ce24753c97..104f5570e3bfb31 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -45,7 +45,7 @@ class AbstractAttribute {
     return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
                              T::getWalkImmediateSubElementsFn(),
                              T::getReplaceImmediateSubElementsFn(),
-                             T::getTypeID());
+                             T::getTypeID(), T::getAttrName());
   }
 
   /// This method is used by Dialect objects to register attributes with
@@ -57,10 +57,10 @@ class AbstractAttribute {
       HasTraitFn &&hasTrait,
       WalkImmediateSubElementsFn walkImmediateSubElementsFn,
       ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
-      TypeID typeID) {
+      TypeID typeID, StringRef name) {
     return AbstractAttribute(dialect, std::move(interfaceMap),
                              std::move(hasTrait), walkImmediateSubElementsFn,
-                             replaceImmediateSubElementsFn, typeID);
+                             replaceImmediateSubElementsFn, typeID, name);
   }
 
   /// Return the dialect this attribute was registered to.
@@ -102,17 +102,20 @@ class AbstractAttribute {
   /// Return the unique identifier representing the concrete attribute class.
   TypeID getTypeID() const { return typeID; }
 
+  /// Return the unique name representing the type.
+  StringRef getName() const { return name; }
+
 private:
   AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
                     HasTraitFn &&hasTraitFn,
                     WalkImmediateSubElementsFn walkImmediateSubElementsFn,
                     ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
-                    TypeID typeID)
+                    TypeID typeID, StringRef name)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
         hasTraitFn(std::move(hasTraitFn)),
         walkImmediateSubElementsFn(walkImmediateSubElementsFn),
         replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
-        typeID(typeID) {}
+        typeID(typeID), name(name) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -141,6 +144,9 @@ class AbstractAttribute {
 
   /// The unique identifier of the derived Attribute class.
   const TypeID typeID;
+
+  /// The unique name of this type.
+  const StringRef name;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad3503..c45473ceccfe371 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -1044,6 +1044,8 @@ class DistinctAttr
   /// Creates a distinct attribute that associates a referenced attribute with a
   /// unique identifier.
   static DistinctAttr create(Attribute referencedAttr);
+
+  static constexpr StringRef getAttrName() { return "builtin.distinct"; }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 390231da662e2d4..f4a42ccf736bfea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -64,6 +64,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
     AffineMap getAffineMap() const { return getValue(); }
   }];
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.affine_map";
 }
 
 //===----------------------------------------------------------------------===//
@@ -134,6 +135,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
       });
     }
   }];
+  let attrName = "builtin.array";
 }
 
 //===----------------------------------------------------------------------===//
@@ -211,6 +213,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
     /// Return true if there are no elements in the dense array.
     bool empty() const { return !size(); }
   }];
+  let attrName = "builtin.dense_array";
 }
 
 //===----------------------------------------------------------------------===//
@@ -352,6 +355,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
   let genAccessors = 0;
   let genStorageClass = 0;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_int_or_fp_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -423,6 +427,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
   let genAccessors = 0;
   let genStorageClass = 0;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_string_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -481,6 +486,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
   ];
 
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dense_resource_elements";
 }
 
 //===----------------------------------------------------------------------===//
@@ -579,6 +585,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
   public:
   }];
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.dictionary";
 }
 
 //===----------------------------------------------------------------------===//
@@ -642,6 +649,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
   }];
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
+  let attrName = "builtin.float";
 }
 
 //===----------------------------------------------------------------------===//
@@ -730,6 +738,7 @@ def Builtin_IntegerAttr : Builtin_Attr...
[truncated]

@math-fehr
Copy link
Contributor Author

@@ -212,6 +212,9 @@ class MLIRContextImpl {
DenseMap<TypeID, AbstractType *> registeredTypes;
StorageUniquer typeUniquer;

/// This is a mapping from type name to the abstract type describing it.
llvm::StringMap<AbstractType *> nameToType;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use DenseMap<StringAttr, AbstractType *> nameToType; (more efficient storage, and more efficient lookup)?

Also when is it useful? (please add to the comment: the "what" is less important than the "why").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh my bad, it got erased again in a following PR, I'll add it back now!
Sorry about that!

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 20, 2023
@math-fehr
Copy link
Contributor Author

I moved getTypeName and getAttrName to just a name field, so that it will be the same for both attributes and types, and also would not clash with any name parameter that could be define in ODS.
Do you think that make sense, or do you think that this is too much different than operations (that have a getName static method). Or do you think this should also be changed in operations?

Also, I added the change in MLIRContext so that we store a DensMap of StringAttr rather than a StringMap.

Currently, I have a Address boundary error when build/NATIVE/bin/mlir-linalg-ods-yaml-gen is being used in the build. I'm trying to figure out if it's because of the ::name change, or the StringAttr.

@math-fehr
Copy link
Contributor Author

So my understanding is that DenseMap<StringAttr, AbstractType*> was what was creating the problem, but I'm not sure if this is a mistake on my end or if this is not possible.

For it to work, we would need to create a StringAttr before loading the builtin dialect, is that possible? This is because when we are loading the builtin dialect, we need to create those StringAttr for the type names when we load each type/attribute.

So for now I removed it, please tell me if you see a way I could make this work!

@math-fehr
Copy link
Contributor Author

@joker-eph @kuhar, does this PR sounds good to you now?
I'll wait for your approval to merge it, but otherwise I think enough time passed that anyone interested commented?

@joker-eph
Copy link
Collaborator

So my understanding is that DenseMap<StringAttr, AbstractType*> was what was creating the problem, but I'm not sure if this is a mistake on my end or if this is not possible.

For it to work, we would need to create a StringAttr before loading the builtin dialect, is that possible? This is because when we are loading the builtin dialect, we need to create those StringAttr for the type names when we load each type/attribute.

Oh great... We can't add the name for StringAttr before registering the storage is that it?

Pretty sad about these string duplication, but we can live with it. Please document this next to the StringMap.

So for now I removed it, please tell me if you see a way I could make this work!

It's likely possible, but will require some deeper look. I may try this after you land it!

@math-fehr
Copy link
Contributor Author

So my understanding is that DenseMap<StringAttr, AbstractType*> was what was creating the problem, but I'm not sure if this is a mistake on my end or if this is not possible.
For it to work, we would need to create a StringAttr before loading the builtin dialect, is that possible? This is because when we are loading the builtin dialect, we need to create those StringAttr for the type names when we load each type/attribute.

Oh great... We can't add the name for StringAttr before registering the storage is that it?

Pretty sad about these string duplication, but we can live with it. Please document this next to the StringMap.

So for now I removed it, please tell me if you see a way I could make this work!

It's likely possible, but will require some deeper look. I may try this after you land it!

Another solution would be to use a DenseMap of StringRef instead, as all names are basically StringLiteral for C++ defined attributes, and StringRef for dynamically defined attributes (as they already store their names).
Do you think that makes more sense for this PR?

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@joker-eph
Copy link
Collaborator

Do you think that makes more sense for this PR?

We could if we're confident this is enough. The API should then take a "StringLiteral" (it's more specific than a StringRef) I believe.

@math-fehr
Copy link
Contributor Author

Do you think that makes more sense for this PR?

We could if we're confident this is enough. The API should then take a "StringLiteral" (it's more specific than a StringRef) I believe.

I'm confident StringRef will work, since it works for both statically-defined and dynamically defined types and attributes.
For StringLiteral, it wouldn't work for dynamically defined attributes and types.

@joker-eph
Copy link
Collaborator

Well it's a low-level API, so let's document the lifetime constraint on the StringRef then!

@math-fehr
Copy link
Contributor Author

Oh yes my bad! I added documentation for that (and for the StringRef in the AbstractType/AbstractAttribute)

@math-fehr math-fehr merged commit 3dbac2c into llvm:main Nov 30, 2023
3 checks passed
@makslevental
Copy link
Contributor

Congrats @math-fehr on landing a solid patch. I'll be using/abusing this soon for the Python bindings :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants