-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Expose type and attribute names in the MLIRContext and abstract type/attr classes #72189
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-core Author: Fehr Mathieu (math-fehr) ChangesThis patch expose the type and attribute names in C++ as methods in the It adds support in ODS to generate the Patch is 51.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72189.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index cf78b2312c24609..6dabbaa7ae1ad36 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -55,6 +55,10 @@ class DataLayoutEntryAttr
/// Prints this attribute.
void print(AsmPrinter &os) const;
+
+ static constexpr StringRef getAttrName() {
+ return "builtin.data_layout_entry";
+ }
};
//===----------------------------------------------------------------------===//
@@ -109,6 +113,10 @@ class DataLayoutSpecAttr
/// Prints this attribute.
void print(AsmPrinter &os) const;
+
+ static constexpr StringRef getAttrName() {
+ return "builtin.data_layout_spec";
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..29ba15f8fe073a6 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -45,6 +45,8 @@ class AsyncTokenType
public:
// Used for generic hooks in TypeBase.
using Base::Base;
+
+ static constexpr StringRef getTypeName() { return "gpu.async_token"; }
};
/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
@@ -128,6 +130,8 @@ class MMAMatrixType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "gpu.mma_matrix"; }
+
/// Get MMAMatrixType and verify construction Invariants.
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
@@ -168,18 +172,39 @@ void addAsyncDependency(Operation *op, Value token);
// Handle types for sparse.
enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };
-template <SparseHandleKind K>
-class SparseHandleType
- : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
+class SparseDnTensorHandleType
+ : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
+public:
+ using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
+ TypeStorage>::Base;
+ using Base::Base;
+
+ static constexpr StringRef getTypeName() {
+ return "gpu.sparse.dntensor_handle";
+ }
+};
+
+class SparseSpMatHandleType
+ : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
public:
using Base =
- typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
+ typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
using Base::Base;
+
+ static constexpr StringRef getTypeName() { return "gpu.sparse.spmat_handle"; }
};
-using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
-using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
-using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
+class SparseSpGEMMOpHandleType
+ : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
+public:
+ using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
+ TypeStorage>::Base;
+ using Base::Base;
+
+ static constexpr StringRef getTypeName() {
+ return "gpu.sparse.spgemmop_handle";
+ }
+};
} // namespace gpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index ba2f14f173aa0c3..66d9a11a1693bc7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -58,18 +58,19 @@ namespace LLVM {
//===----------------------------------------------------------------------===//
// Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName) \
class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> { \
public: \
using Base::Base; \
+ static constexpr StringRef getTypeName() { return TypeName; } \
}
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
#undef DEFINE_TRIVIAL_LLVM_TYPE
@@ -110,6 +111,8 @@ class LLVMStructType
/// Inherit base constructors.
using Base::Base;
+ static constexpr StringRef getTypeName() { return "llvm.struct"; }
+
/// Checks if the given type can be contained in a structure type.
static bool isValidElementType(Type type);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 0bd068c1be7c90a..96cdbf01b4bd91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
elements can be processed as one in SIMD context.
}];
+ let typeName = "llvm.fixed_vec";
+
let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
@@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
elements can be processed as one in SIMD context.
}];
+ let typeName = "llvm.scalable_vec";
+
let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
let assemblyFormat = [{
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 2776b3e6e17ba50..f0300ea8cbb8ab2 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -198,6 +198,8 @@ class AnyQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.any"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
@@ -257,6 +259,8 @@ class UniformQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.uniform"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedType get(unsigned flags, Type storageType,
@@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.uniform_per_axis"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedPerAxisType
@@ -383,6 +389,8 @@ class CalibratedQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.calibrated"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static CalibratedQuantizedType get(Type expressedType, double min,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 3b914dc4cc82f11..8de961295153b8d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -79,6 +79,8 @@ class InterfaceVarABIAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
+
+ static constexpr StringRef getAttrName() { return "spirv.interface_var_abi"; }
};
/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@@ -129,6 +131,8 @@ class VerCapExtAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
+
+ static constexpr StringRef getAttrName() { return "spirv.ver_cap_ext"; }
};
/// An attribute that specifies the target version, allowed extensions and
@@ -183,6 +187,8 @@ class TargetEnvAttr
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits() const;
+
+ static constexpr StringRef getAttrName() { return "spirv.target_env"; }
};
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4be2582f8fd68cc..e536adc8dce4abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.array"; }
+
static ArrayType get(Type elementType, unsigned elementCount);
/// Returns an array type with the given stride in bytes.
@@ -162,6 +164,8 @@ class ImageType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.image"; }
+
static ImageType
get(Type elementType, Dim dim,
ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
@@ -201,6 +205,8 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.pointer"; }
+
static PointerType get(Type pointeeType, StorageClass storageClass);
Type getPointeeType() const;
@@ -220,6 +226,8 @@ class RuntimeArrayType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.rtarray"; }
+
static RuntimeArrayType get(Type elementType);
/// Returns a runtime array type with the given stride in bytes.
@@ -244,6 +252,8 @@ class SampledImageType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.sampled_image"; }
+
static SampledImageType get(Type imageType);
static SampledImageType
@@ -288,6 +298,8 @@ class StructType
// Type for specifying the offset of the struct members
using OffsetInfo = uint32_t;
+ static constexpr StringRef getTypeName() { return "spirv.struct"; }
+
// Type for specifying the decoration(s) on struct members
struct MemberDecorationInfo {
uint32_t memberIndex : 31;
@@ -387,6 +399,8 @@ class CooperativeMatrixType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.coopmatrix"; }
+
static CooperativeMatrixType get(Type elementType, uint32_t rows,
uint32_t columns, Scope scope,
CooperativeMatrixUseKHR use);
@@ -414,6 +428,8 @@ class CooperativeMatrixNVType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.NV.coopmatrix"; }
+
static CooperativeMatrixNVType get(Type elementType, Scope scope,
unsigned rows, unsigned columns);
Type getElementType() const;
@@ -438,6 +454,8 @@ class JointMatrixINTELType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.jointmatrix"; }
+
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
unsigned columns, MatrixLayout matrixLayout);
Type getElementType() const;
@@ -464,6 +482,8 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.matrix"; }
+
static MatrixType get(Type columnType, uint32_t columnCount);
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 42a611ee8e42205..91c9283de8bd415 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -264,6 +264,9 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such attributes as parameters for other attributes.
string cppType = dialect.cppNamespace # "::" # cppClassName;
+ // The unique attribute name.
+ string attrName = dialect.name # "." # mnemonic;
+
// The call expression to convert from the storage type to the return
// type. For example, an enum can be stored as an int but returned as an
// enum class.
@@ -289,6 +292,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;
+ // The unique type name.
+ string typeName = dialect.name # "." # mnemonic;
+
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
@@ -431,15 +437,15 @@ class AttributeSelfTypeParameter<string desc,
/// This class defines an attribute that contains an array of elements. The
/// elements can be any type, but if they are attributes, the nested elements
/// are parsed and printed using the custom attribute syntax.
-class ArrayOfAttr<Dialect dialect, string attrName, string attrMnemonic,
+class ArrayOfAttr<Dialect dialect, string name, string attrMnemonic,
string eltName, list<Trait> traits = []>
- : AttrDef<dialect, attrName, traits> {
+ : AttrDef<dialect, name, traits> {
let parameters = (ins OptionalArrayRefParameter<eltName>:$value);
let mnemonic = attrMnemonic;
let assemblyFormat = "`[` (`]`) : ($value^ `]`)?";
let returnType = "::llvm::ArrayRef<" # eltName # ">";
- let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)";
+ let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)";
let convertFromStorage = "$_self.getValue()";
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 75ea1ce24753c97..104f5570e3bfb31 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -45,7 +45,7 @@ class AbstractAttribute {
return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
T::getWalkImmediateSubElementsFn(),
T::getReplaceImmediateSubElementsFn(),
- T::getTypeID());
+ T::getTypeID(), T::getAttrName());
}
/// This method is used by Dialect objects to register attributes with
@@ -57,10 +57,10 @@ class AbstractAttribute {
HasTraitFn &&hasTrait,
WalkImmediateSubElementsFn walkImmediateSubElementsFn,
ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
- TypeID typeID) {
+ TypeID typeID, StringRef name) {
return AbstractAttribute(dialect, std::move(interfaceMap),
std::move(hasTrait), walkImmediateSubElementsFn,
- replaceImmediateSubElementsFn, typeID);
+ replaceImmediateSubElementsFn, typeID, name);
}
/// Return the dialect this attribute was registered to.
@@ -102,17 +102,20 @@ class AbstractAttribute {
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
+ /// Return the unique name representing the type.
+ StringRef getName() const { return name; }
+
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
HasTraitFn &&hasTraitFn,
WalkImmediateSubElementsFn walkImmediateSubElementsFn,
ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
- TypeID typeID)
+ TypeID typeID, StringRef name)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
hasTraitFn(std::move(hasTraitFn)),
walkImmediateSubElementsFn(walkImmediateSubElementsFn),
replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
- typeID(typeID) {}
+ typeID(typeID), name(name) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -141,6 +144,9 @@ class AbstractAttribute {
/// The unique identifier of the derived Attribute class.
const TypeID typeID;
+
+ /// The unique name of this type.
+ const StringRef name;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad3503..c45473ceccfe371 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -1044,6 +1044,8 @@ class DistinctAttr
/// Creates a distinct attribute that associates a referenced attribute with a
/// unique identifier.
static DistinctAttr create(Attribute referencedAttr);
+
+ static constexpr StringRef getAttrName() { return "builtin.distinct"; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 390231da662e2d4..f4a42ccf736bfea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -64,6 +64,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
AffineMap getAffineMap() const { return getValue(); }
}];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.affine_map";
}
//===----------------------------------------------------------------------===//
@@ -134,6 +135,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
});
}
}];
+ let attrName = "builtin.array";
}
//===----------------------------------------------------------------------===//
@@ -211,6 +213,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
/// Return true if there are no elements in the dense array.
bool empty() const { return !size(); }
}];
+ let attrName = "builtin.dense_array";
}
//===----------------------------------------------------------------------===//
@@ -352,6 +355,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
let genAccessors = 0;
let genStorageClass = 0;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_int_or_fp_elements";
}
//===----------------------------------------------------------------------===//
@@ -423,6 +427,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
let genAccessors = 0;
let genStorageClass = 0;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_string_elements";
}
//===----------------------------------------------------------------------===//
@@ -481,6 +486,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_resource_elements";
}
//===----------------------------------------------------------------------===//
@@ -579,6 +585,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
public:
}];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dictionary";
}
//===----------------------------------------------------------------------===//
@@ -642,6 +649,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
}];
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.float";
}
//===----------------------------------------------------------------------===//
@@ -730,6 +738,7 @@ def Builtin_IntegerAttr : Builtin_Attr...
[truncated]
|
@llvm/pr-subscribers-mlir-llvm Author: Fehr Mathieu (math-fehr) ChangesThis patch expose the type and attribute names in C++ as methods in the It adds support in ODS to generate the Patch is 51.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72189.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index cf78b2312c24609..6dabbaa7ae1ad36 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -55,6 +55,10 @@ class DataLayoutEntryAttr
/// Prints this attribute.
void print(AsmPrinter &os) const;
+
+ static constexpr StringRef getAttrName() {
+ return "builtin.data_layout_entry";
+ }
};
//===----------------------------------------------------------------------===//
@@ -109,6 +113,10 @@ class DataLayoutSpecAttr
/// Prints this attribute.
void print(AsmPrinter &os) const;
+
+ static constexpr StringRef getAttrName() {
+ return "builtin.data_layout_spec";
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..29ba15f8fe073a6 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -45,6 +45,8 @@ class AsyncTokenType
public:
// Used for generic hooks in TypeBase.
using Base::Base;
+
+ static constexpr StringRef getTypeName() { return "gpu.async_token"; }
};
/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
@@ -128,6 +130,8 @@ class MMAMatrixType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "gpu.mma_matrix"; }
+
/// Get MMAMatrixType and verify construction Invariants.
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
@@ -168,18 +172,39 @@ void addAsyncDependency(Operation *op, Value token);
// Handle types for sparse.
enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };
-template <SparseHandleKind K>
-class SparseHandleType
- : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
+class SparseDnTensorHandleType
+ : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
+public:
+ using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
+ TypeStorage>::Base;
+ using Base::Base;
+
+ static constexpr StringRef getTypeName() {
+ return "gpu.sparse.dntensor_handle";
+ }
+};
+
+class SparseSpMatHandleType
+ : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
public:
using Base =
- typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
+ typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
using Base::Base;
+
+ static constexpr StringRef getTypeName() { return "gpu.sparse.spmat_handle"; }
};
-using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
-using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
-using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
+class SparseSpGEMMOpHandleType
+ : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
+public:
+ using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
+ TypeStorage>::Base;
+ using Base::Base;
+
+ static constexpr StringRef getTypeName() {
+ return "gpu.sparse.spgemmop_handle";
+ }
+};
} // namespace gpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index ba2f14f173aa0c3..66d9a11a1693bc7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -58,18 +58,19 @@ namespace LLVM {
//===----------------------------------------------------------------------===//
// Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName) \
class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> { \
public: \
using Base::Base; \
+ static constexpr StringRef getTypeName() { return TypeName; } \
}
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
#undef DEFINE_TRIVIAL_LLVM_TYPE
@@ -110,6 +111,8 @@ class LLVMStructType
/// Inherit base constructors.
using Base::Base;
+ static constexpr StringRef getTypeName() { return "llvm.struct"; }
+
/// Checks if the given type can be contained in a structure type.
static bool isValidElementType(Type type);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 0bd068c1be7c90a..96cdbf01b4bd91f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
elements can be processed as one in SIMD context.
}];
+ let typeName = "llvm.fixed_vec";
+
let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
@@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
elements can be processed as one in SIMD context.
}];
+ let typeName = "llvm.scalable_vec";
+
let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
let assemblyFormat = [{
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 2776b3e6e17ba50..f0300ea8cbb8ab2 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -198,6 +198,8 @@ class AnyQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.any"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
@@ -257,6 +259,8 @@ class UniformQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.uniform"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedType get(unsigned flags, Type storageType,
@@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.uniform_per_axis"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedPerAxisType
@@ -383,6 +389,8 @@ class CalibratedQuantizedType
using Base::Base;
using Base::getChecked;
+ static constexpr StringRef getTypeName() { return "quant.calibrated"; }
+
/// Gets an instance of the type with all parameters specified but not
/// checked.
static CalibratedQuantizedType get(Type expressedType, double min,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 3b914dc4cc82f11..8de961295153b8d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -79,6 +79,8 @@ class InterfaceVarABIAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
+
+ static constexpr StringRef getAttrName() { return "spirv.interface_var_abi"; }
};
/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@@ -129,6 +131,8 @@ class VerCapExtAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
+
+ static constexpr StringRef getAttrName() { return "spirv.ver_cap_ext"; }
};
/// An attribute that specifies the target version, allowed extensions and
@@ -183,6 +187,8 @@ class TargetEnvAttr
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits() const;
+
+ static constexpr StringRef getAttrName() { return "spirv.target_env"; }
};
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4be2582f8fd68cc..e536adc8dce4abb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.array"; }
+
static ArrayType get(Type elementType, unsigned elementCount);
/// Returns an array type with the given stride in bytes.
@@ -162,6 +164,8 @@ class ImageType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.image"; }
+
static ImageType
get(Type elementType, Dim dim,
ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
@@ -201,6 +205,8 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.pointer"; }
+
static PointerType get(Type pointeeType, StorageClass storageClass);
Type getPointeeType() const;
@@ -220,6 +226,8 @@ class RuntimeArrayType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.rtarray"; }
+
static RuntimeArrayType get(Type elementType);
/// Returns a runtime array type with the given stride in bytes.
@@ -244,6 +252,8 @@ class SampledImageType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.sampled_image"; }
+
static SampledImageType get(Type imageType);
static SampledImageType
@@ -288,6 +298,8 @@ class StructType
// Type for specifying the offset of the struct members
using OffsetInfo = uint32_t;
+ static constexpr StringRef getTypeName() { return "spirv.struct"; }
+
// Type for specifying the decoration(s) on struct members
struct MemberDecorationInfo {
uint32_t memberIndex : 31;
@@ -387,6 +399,8 @@ class CooperativeMatrixType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.coopmatrix"; }
+
static CooperativeMatrixType get(Type elementType, uint32_t rows,
uint32_t columns, Scope scope,
CooperativeMatrixUseKHR use);
@@ -414,6 +428,8 @@ class CooperativeMatrixNVType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.NV.coopmatrix"; }
+
static CooperativeMatrixNVType get(Type elementType, Scope scope,
unsigned rows, unsigned columns);
Type getElementType() const;
@@ -438,6 +454,8 @@ class JointMatrixINTELType
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.jointmatrix"; }
+
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
unsigned columns, MatrixLayout matrixLayout);
Type getElementType() const;
@@ -464,6 +482,8 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
public:
using Base::Base;
+ static constexpr StringRef getTypeName() { return "spirv.matrix"; }
+
static MatrixType get(Type columnType, uint32_t columnCount);
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 42a611ee8e42205..91c9283de8bd415 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -264,6 +264,9 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such attributes as parameters for other attributes.
string cppType = dialect.cppNamespace # "::" # cppClassName;
+ // The unique attribute name.
+ string attrName = dialect.name # "." # mnemonic;
+
// The call expression to convert from the storage type to the return
// type. For example, an enum can be stored as an int but returned as an
// enum class.
@@ -289,6 +292,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;
+ // The unique type name.
+ string typeName = dialect.name # "." # mnemonic;
+
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
@@ -431,15 +437,15 @@ class AttributeSelfTypeParameter<string desc,
/// This class defines an attribute that contains an array of elements. The
/// elements can be any type, but if they are attributes, the nested elements
/// are parsed and printed using the custom attribute syntax.
-class ArrayOfAttr<Dialect dialect, string attrName, string attrMnemonic,
+class ArrayOfAttr<Dialect dialect, string name, string attrMnemonic,
string eltName, list<Trait> traits = []>
- : AttrDef<dialect, attrName, traits> {
+ : AttrDef<dialect, name, traits> {
let parameters = (ins OptionalArrayRefParameter<eltName>:$value);
let mnemonic = attrMnemonic;
let assemblyFormat = "`[` (`]`) : ($value^ `]`)?";
let returnType = "::llvm::ArrayRef<" # eltName # ">";
- let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)";
+ let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)";
let convertFromStorage = "$_self.getValue()";
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 75ea1ce24753c97..104f5570e3bfb31 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -45,7 +45,7 @@ class AbstractAttribute {
return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
T::getWalkImmediateSubElementsFn(),
T::getReplaceImmediateSubElementsFn(),
- T::getTypeID());
+ T::getTypeID(), T::getAttrName());
}
/// This method is used by Dialect objects to register attributes with
@@ -57,10 +57,10 @@ class AbstractAttribute {
HasTraitFn &&hasTrait,
WalkImmediateSubElementsFn walkImmediateSubElementsFn,
ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
- TypeID typeID) {
+ TypeID typeID, StringRef name) {
return AbstractAttribute(dialect, std::move(interfaceMap),
std::move(hasTrait), walkImmediateSubElementsFn,
- replaceImmediateSubElementsFn, typeID);
+ replaceImmediateSubElementsFn, typeID, name);
}
/// Return the dialect this attribute was registered to.
@@ -102,17 +102,20 @@ class AbstractAttribute {
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
+ /// Return the unique name representing the type.
+ StringRef getName() const { return name; }
+
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
HasTraitFn &&hasTraitFn,
WalkImmediateSubElementsFn walkImmediateSubElementsFn,
ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
- TypeID typeID)
+ TypeID typeID, StringRef name)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
hasTraitFn(std::move(hasTraitFn)),
walkImmediateSubElementsFn(walkImmediateSubElementsFn),
replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
- typeID(typeID) {}
+ typeID(typeID), name(name) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -141,6 +144,9 @@ class AbstractAttribute {
/// The unique identifier of the derived Attribute class.
const TypeID typeID;
+
+ /// The unique name of this type.
+ const StringRef name;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad3503..c45473ceccfe371 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -1044,6 +1044,8 @@ class DistinctAttr
/// Creates a distinct attribute that associates a referenced attribute with a
/// unique identifier.
static DistinctAttr create(Attribute referencedAttr);
+
+ static constexpr StringRef getAttrName() { return "builtin.distinct"; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 390231da662e2d4..f4a42ccf736bfea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -64,6 +64,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
AffineMap getAffineMap() const { return getValue(); }
}];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.affine_map";
}
//===----------------------------------------------------------------------===//
@@ -134,6 +135,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
});
}
}];
+ let attrName = "builtin.array";
}
//===----------------------------------------------------------------------===//
@@ -211,6 +213,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
/// Return true if there are no elements in the dense array.
bool empty() const { return !size(); }
}];
+ let attrName = "builtin.dense_array";
}
//===----------------------------------------------------------------------===//
@@ -352,6 +355,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
let genAccessors = 0;
let genStorageClass = 0;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_int_or_fp_elements";
}
//===----------------------------------------------------------------------===//
@@ -423,6 +427,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
let genAccessors = 0;
let genStorageClass = 0;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_string_elements";
}
//===----------------------------------------------------------------------===//
@@ -481,6 +486,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dense_resource_elements";
}
//===----------------------------------------------------------------------===//
@@ -579,6 +585,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
public:
}];
let skipDefaultBuilders = 1;
+ let attrName = "builtin.dictionary";
}
//===----------------------------------------------------------------------===//
@@ -642,6 +649,7 @@ def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> {
}];
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
+ let attrName = "builtin.float";
}
//===----------------------------------------------------------------------===//
@@ -730,6 +738,7 @@ def Builtin_IntegerAttr : Builtin_Attr...
[truncated]
|
mlir/lib/IR/MLIRContext.cpp
Outdated
@@ -212,6 +212,9 @@ class MLIRContextImpl { | |||
DenseMap<TypeID, AbstractType *> registeredTypes; | |||
StorageUniquer typeUniquer; | |||
|
|||
/// This is a mapping from type name to the abstract type describing it. | |||
llvm::StringMap<AbstractType *> nameToType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use DenseMap<StringAttr, AbstractType *> nameToType;
(more efficient storage, and more efficient lookup)?
Also when is it useful? (please add to the comment: the "what" is less important than the "why").
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh my bad, it got erased again in a following PR, I'll add it back now!
Sorry about that!
I moved Also, I added the change in Currently, I have a |
So my understanding is that For it to work, we would need to create a So for now I removed it, please tell me if you see a way I could make this work! |
@joker-eph @kuhar, does this PR sounds good to you now? |
Oh great... We can't add the name for Pretty sad about these string duplication, but we can live with it. Please document this next to the StringMap.
It's likely possible, but will require some deeper look. I may try this after you land it! |
Another solution would be to use a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
We could if we're confident this is enough. The API should then take a "StringLiteral" (it's more specific than a StringRef) I believe. |
I'm confident |
Well it's a low-level API, so let's document the lifetime constraint on the StringRef then! |
Oh yes my bad! I added documentation for that (and for the |
96b7196
to
bd8b02e
Compare
Congrats @math-fehr on landing a solid patch. I'll be using/abusing this soon for the Python bindings :) |
This patch expose the type and attribute names in C++ as methods in the
AbstractType
andAbstractAttribute
classes, and keep a map of names toAbstractType
andAbstractAttribute
in theMLIRContext
. Type and attribute names should be unique.It adds support in ODS to generate the
getName
methods inAbstractType
andAbstractAttribute
, through the use of two new variables,typeName
andattrName
. It also adds names to C++-defined type and attributes.