diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 798d3c84f9618..dced379d1f979 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -239,29 +239,48 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " - "integer or floating-point values"; + "values"; let description = [{ - Syntax: - - ``` - tensor-literal ::= integer-literal | float-literal | bool-literal | [] | [tensor-literal (, tensor-literal)* ] - dense-intorfloat-elements-attribute ::= `dense` `<` tensor-literal `>` `:` - ( tensor-type | vector-type ) - ``` - - A dense int-or-float elements attribute is an elements attribute containing - a densely packed vector or tensor of integer or floating-point values. The - element type of this attribute is required to be either an `IntegerType` or - a `FloatType`. + A dense elements attribute stores one or multiple elements of the same type. + The term "dense" refers to the fact that elements are not stored as + individual MLIR attributes, but in a raw buffer. The attribute provides a + covenience API to access elements in the form of MLIR attributes, but users + should avoid that API in performance-critical code and utilize APIs that + operate on raw bytes instead. + + The number of elements is determined by the `type` shaped type. (Unranked + shaped types are not supported.) The element type of the shaped type must + implement the `DenseElementType` interface. This type interface defines the + bitwidth of an element and provides a serializer/deserializer to/from MLIR + attributes. + + Storage format: Given an element bitwidth "w", element "i" starts at byte + offset "i * ceildiv(w, 8)". In other words, each element starts at a full + byte offset. + + TODO: The name `DenseIntOrFPElements` is no longer accurate. The attribute + will be renamed in the future. Examples: ``` - // A splat tensor of integer values. + // Literal-first syntax: A splat tensor of integer values. dense<10> : tensor<2xi32> - // A tensor of 2 float32 elements. + + // Literal-first syntax: A tensor of 2 float32 elements. dense<[10.0, 11.0]> : tensor<2xf32> + + // Type-first syntax: A splat tensor of integer values. + dense : 10 : i32> + + // Type-first syntax: A tensor of 2 float32 elements. + dense : [10.0, 11.0]> ``` + + Note: The literal-first syntax is supported only for complex, float, index, + int element types. The parser/print have special casing for these types. + Dense element attributes with other element types must use the type-first + syntax. }]; let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, "ArrayRef":$rawData); diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h index 5f14517d8dd71..9425d554b427c 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h @@ -19,6 +19,29 @@ struct fltSemantics; namespace mlir { class FloatType; class MLIRContext; + +namespace detail { +/// Float type implementation of +/// DenseElementTypeInterface::getDenseElementBitSize. +size_t getFloatTypeDenseElementBitSize(Type type); + +/// Float type implementation of DenseElementTypeInterface::convertToAttribute. +Attribute convertFloatTypeToAttribute(Type type, llvm::ArrayRef rawData); + +/// Float type implementation of +/// DenseElementTypeInterface::convertFromAttribute. +LogicalResult +convertFloatTypeFromAttribute(Type type, Attribute attr, + llvm::SmallVectorImpl &result); + +/// Read `bitWidth` bits from byte-aligned position in `rawData` and return as +/// an APInt. Handles endianness correctly. +llvm::APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth); + +/// Write `value` to byte-aligned position `bitPos` in `rawData`. Handles +/// endianness correctly. +void writeBits(char *rawData, size_t bitPos, llvm::APInt value); +} // namespace detail } // namespace mlir #include "mlir/IR/BuiltinTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 9ef08b7020b99..93c8c0694b467 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -41,12 +41,70 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> { }]; } +//===----------------------------------------------------------------------===// +// DenseElementTypeInterface +//===----------------------------------------------------------------------===// + +def DenseElementTypeInterface : TypeInterface<"DenseElementType"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface allows custom types to be used as element types in + DenseElementsAttr. Types implementing this interface define: + + 1. The bit size for element storage. + 2. Helper methods for converting from/to Attribute. This assumes that there + is a corresponding attribute for each type that implements this + interface. + + The helper methods for converting from/to Attribute are utilized when + parsing/printing IR or iterating over the elements via Attribute. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the number of bits required to store one element in dense + storage. + + Note: The DenseElementsAttr infrastructure will automatically align + every element to a full byte in storage. This limitation could be lifted + in the future to support dense packing of non-byte-sized elements. + }], + /*retTy=*/"size_t", + /*methodName=*/"getDenseElementBitSize", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Attribute deserialization / attribute factory: Convert raw storage bytes + into an MLIR attribute. The size of `rawData` is + "ceilDiv(getDenseElementBitSize(), 8)". + }], + /*retTy=*/"::mlir::Attribute", + /*methodName=*/"convertToAttribute", + /*args=*/(ins "::llvm::ArrayRef":$rawData) + >, + InterfaceMethod< + /*desc=*/[{ + Attribute serialization: Convert an MLIR attribute into raw bytes. + Implementations must append "getDenseElementBitSize() / 8" values to + `result`. Return "failure" if the attribute is incompatible with this + element type. + }], + /*retTy=*/"::llvm::LogicalResult", + /*methodName=*/"convertFromAttribute", + /*args=*/(ins "::mlir::Attribute":$attr, + "::llvm::SmallVectorImpl&":$result) + >, + ]; +} + //===----------------------------------------------------------------------===// // FloatTypeInterface //===----------------------------------------------------------------------===// def FloatTypeInterface : TypeInterface<"FloatType", - [VectorElementTypeInterface]> { + [DenseElementTypeInterface, VectorElementTypeInterface]> { let cppNamespace = "::mlir"; let description = [{ This type interface should be implemented by all floating-point types. It @@ -83,6 +141,21 @@ def FloatTypeInterface : TypeInterface<"FloatType", /// The width includes the integer bit. unsigned getFPMantissaWidth(); }]; + + let extraTraitClassDeclaration = [{ + /// DenseElementTypeInterface implementations for float types. + size_t getDenseElementBitSize() const { + return ::mlir::detail::getFloatTypeDenseElementBitSize($_type); + } + ::mlir::Attribute convertToAttribute(::llvm::ArrayRef rawData) const { + return ::mlir::detail::convertFloatTypeToAttribute($_type, rawData); + } + ::llvm::LogicalResult + convertFromAttribute(::mlir::Attribute attr, + ::llvm::SmallVectorImpl &result) const { + return ::mlir::detail::convertFloatTypeFromAttribute($_type, attr, result); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 806064faeda00..e7d0a03a85e7d 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -45,7 +45,10 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> { // ComplexType //===----------------------------------------------------------------------===// -def Builtin_Complex : Builtin_Type<"Complex", "complex"> { +def Builtin_Complex : Builtin_Type<"Complex", "complex", + [DeclareTypeInterfaceMethods + ]> { let summary = "Complex number with a parameterized element type"; let description = [{ Syntax: @@ -560,7 +563,9 @@ def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">; //===----------------------------------------------------------------------===// def Builtin_Index : Builtin_Type<"Index", "index", - [VectorElementTypeInterface]> { + [DeclareTypeInterfaceMethods, + VectorElementTypeInterface]> { let summary = "Integer-like type with unknown platform-dependent bit width"; let description = [{ Syntax: @@ -591,7 +596,10 @@ def Builtin_Index : Builtin_Type<"Index", "index", //===----------------------------------------------------------------------===// def Builtin_Integer : Builtin_Type<"Integer", "integer", - [VectorElementTypeInterface, QuantStorageTypeInterface]> { + [VectorElementTypeInterface, QuantStorageTypeInterface, + DeclareTypeInterfaceMethods]> { let summary = "Integer type with arbitrary precision up to a fixed limit"; let description = [{ Syntax: diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h index 20e2dfb453ef7..e24dd2676cfa3 100644 --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -192,12 +192,13 @@ class InterfaceMap { template static InterfaceMap get() { constexpr size_t numInterfaces = num_interface_types_t::value; - if constexpr (numInterfaces == 0) + if constexpr (numInterfaces == 0) { return InterfaceMap(); - - InterfaceMap map; - (map.insertPotentialInterface(), ...); - return map; + } else { + InterfaceMap map; + map.insertPotentialInterfaces(); + return map; + } } /// Returns an instance of the concept object for the given interface if it @@ -217,6 +218,18 @@ class InterfaceMap { } private: + /// Insert the given interface types into the map (recursive expansion to + /// guarantee sequential, left-to-right evaluation across all compilers). + template + void insertPotentialInterfaces() { + insertPotentialInterface(); + } + template + void insertPotentialInterfaces() { + insertPotentialInterface(); + insertPotentialInterfaces(); + } + /// Insert the given interface type into the map, ignoring it if it doesn't /// actually represent an interface. template diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 5978a11d06bc9..dc9744a42b730 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -953,6 +954,119 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) { return eltParser.getAttr(); } +/// Try to parse a dense elements attribute with the type-first syntax. +/// Syntax: dense +/// This syntax is used for types other than int, float, index and complex. +/// +/// Returns: +/// - "null" attribute if this is not the type-first syntax. +/// - "failure" in case of a parse error. +/// - A valid Attribute otherwise. +static FailureOr parseDenseElementsAttrTyped(Parser &p, SMLoc loc) { + // Skip l_paren because "parseType" would try to parse it as a tuple/function + // type, but '(' starts a complex literal like in the literal-first syntax. + if (p.getToken().is(Token::l_paren)) + return Attribute(); + + // Parse type and valdiate that it's a shaped type. + auto typeLoc = p.getToken().getLoc(); + Type type; + OptionalParseResult typeResult = p.parseOptionalType(type); + if (!typeResult.has_value()) + return Attribute(); // Not type-first syntax. + if (failed(*typeResult)) + return failure(); // Type parse error. + + auto shapedType = dyn_cast(type); + if (!shapedType) { + p.emitError(typeLoc, "expected a shaped type for dense elements"); + return failure(); + } + if (!shapedType.hasStaticShape()) { + p.emitError(typeLoc, "dense elements type must have static shape"); + return failure(); + } + + // Check that the element type implements DenseElementTypeInterface. + auto denseEltType = dyn_cast(shapedType.getElementType()); + if (!denseEltType) { + p.emitError(typeLoc, + "element type must implement DenseElementTypeInterface " + "for type-first dense syntax"); + return failure(); + } + + // Parse colon. + if (p.parseToken(Token::colon, "expected ':' after type in dense attribute")) + return failure(); + + // Parse the element attributes and convert to raw bytes. + SmallVector rawData; + + // Helper to parse a single element. + auto parseSingleElement = [&]() -> ParseResult { + Attribute elemAttr = p.parseAttribute(); + if (!elemAttr) + return failure(); + if (failed(denseEltType.convertFromAttribute(elemAttr, rawData))) { + p.emitError("incompatible attribute for element type"); + return failure(); + } + return success(); + }; + + // Recursively parse elements matching the expected shape. + std::function)> parseElements; + parseElements = [&](ArrayRef remainingShape) -> ParseResult { + // Leaf: parse a single element. + if (remainingShape.empty()) + return parseSingleElement(); + + // Non-leaf: expect a list with the correct number of elements. + int64_t expectedCount = remainingShape.front(); + ArrayRef innerShape = remainingShape.drop_front(); + int64_t actualCount = 0; + + auto parseOne = [&]() -> ParseResult { + if (parseElements(innerShape)) + return failure(); + ++actualCount; + return success(); + }; + + if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOne)) + return failure(); + + if (actualCount != expectedCount) { + p.emitError() << "expected " << expectedCount + << " elements in dimension, got " << actualCount; + return failure(); + } + return success(); + }; + + // Parse elements. + if (!p.getToken().is(Token::l_square)) { + // Single element - parse as splat. + if (parseSingleElement()) + return failure(); + } else if (shapedType.getShape().empty()) { + // Scalar type shouldn't have a list. + p.emitError(loc, "expected single element for scalar type, got list"); + return failure(); + } else { + // Parse structured literal matching the shape. + if (parseElements(shapedType.getShape())) + return failure(); + } + + if (p.parseToken(Token::greater, "expected '>' to close dense attribute")) + return failure(); + + // Create the attribute from raw buffer. + return DenseElementsAttr::getFromRawBuffer(shapedType, rawData); +} + /// Parse a dense elements attribute. Attribute Parser::parseDenseElementsAttr(Type attrType) { auto attribLoc = getToken().getLoc(); @@ -960,7 +1074,16 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - // Parse the literal data if necessary. + // Try to parse the type-first syntax: dense + FailureOr typedResult = + parseDenseElementsAttrTyped(*this, attribLoc); + if (failed(typedResult)) + return nullptr; + if (*typedResult) + return *typedResult; + + // Try to parse the literal-first syntax, which is the default format for + // int, float, index and complex element types. TensorLiteralParser literalParser(*this); if (!consumeIf(Token::greater)) { if (literalParser.parse(/*allowHex=*/true) || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 81455699421cc..b3242f838fc1d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -507,11 +507,18 @@ class AsmPrinter::Impl { /// Print a dense string elements attribute. void printDenseStringElementsAttr(DenseStringElementsAttr attr); - /// Print a dense elements attribute. If 'allowHex' is true, a hex string is - /// used instead of individual elements when the elements attr is large. + /// Print a dense elements attribute in the literal-first syntax. If + /// 'allowHex' is true, a hex string is used instead of individual elements + /// when the elements attr is large. void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, bool allowHex); + /// Print a dense elements attribute using the type-first syntax and the + /// DenseElementTypeInterface, which provides the attribute printer for each + /// element. + void printTypeFirstDenseElementsAttr(DenseElementsAttr attr, + DenseElementType denseEltType); + /// Print a dense array attribute. void printDenseArrayAttr(DenseArrayAttr attr); @@ -2507,7 +2514,17 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr, printElidedElementsAttr(os); } else { os << "dense<"; - printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); + // Check if the element type implements DenseElementTypeInterface and is + // not a built-in type. Built-in types (int, float, index, complex) use + // the existing printing format for backwards compatibility. + Type eltType = intOrFpEltAttr.getElementType(); + if (isa(eltType)) { + printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); + } else { + printTypeFirstDenseElementsAttr(intOrFpEltAttr, + cast(eltType)); + typeElision = AttrTypeElision::Must; + } os << '>'; } @@ -2705,6 +2722,27 @@ void AsmPrinter::Impl::printDenseStringElementsAttr( printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } +void AsmPrinter::Impl::printTypeFirstDenseElementsAttr( + DenseElementsAttr attr, DenseElementType denseEltType) { + // Print the type first: dense + printType(attr.getType()); + os << " : "; + + ArrayRef rawData = attr.getRawData(); + // Storage is byte-aligned: align bit size up to next byte boundary. + size_t bitSize = denseEltType.getDenseElementBitSize(); + size_t byteSize = llvm::divideCeil(bitSize, static_cast(CHAR_BIT)); + + // Print elements: convert raw bytes to attribute, then print attribute. + printDenseElementsAttrImpl( + attr.isSplat(), attr.getType(), os, [&](unsigned index) { + size_t offset = attr.isSplat() ? 0 : index * byteSize; + ArrayRef elemData = rawData.slice(offset, byteSize); + Attribute elemAttr = denseEltType.convertToAttribute(elemData); + printAttributeImpl(elemAttr); + }); +} + void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { Type type = attr.getElementType(); unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 1f268603cf37f..8505149afdd9c 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -16,6 +16,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/AttributeSupport.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -32,12 +33,9 @@ namespace detail { /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { - // Align the width for complex to 8 to make storage and interpretation easier. - if (ComplexType comp = llvm::dyn_cast(eltType)) - return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; - if (eltType.isIndex()) - return IndexType::kInternalStorageBitWidth; - return eltType.getIntOrFloatBitWidth(); + if (auto denseEltType = llvm::dyn_cast(eltType)) + return denseEltType.getDenseElementBitSize(); + llvm_unreachable("unsupported element type"); } /// An attribute representing a reference to a dense vector or tensor object. diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 1a29fc534b40f..bbbc9198a68ab 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -10,6 +10,7 @@ #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -527,7 +528,7 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, } /// Writes value to the bit position `bitPos` in array `rawData`. -static void writeBits(char *rawData, size_t bitPos, APInt value) { +void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) { size_t bitWidth = value.getBitWidth(); // The bit position is guaranteed to be byte aligned. @@ -549,7 +550,8 @@ static void writeBits(char *rawData, size_t bitPos, APInt value) { /// Reads the next `bitWidth` bits from the bit position `bitPos` in array /// `rawData`. -static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { +APInt mlir::detail::readBits(const char *rawData, size_t bitPos, + size_t bitWidth) { // The bit position is guaranteed to be byte aligned. assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); APInt result(bitWidth, 0); @@ -595,39 +597,21 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = llvm::cast(getFromOpaquePointer(base)); Type eltTy = owner.getElementType(); - if (llvm::dyn_cast(eltTy)) - return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (llvm::isa(eltTy)) - return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (auto floatEltTy = llvm::dyn_cast(eltTy)) { - IntElementIterator intIt(owner, index); - FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); - return FloatAttr::get(eltTy, *floatIt); - } - if (auto complexTy = llvm::dyn_cast(eltTy)) { - auto complexEltTy = complexTy.getElementType(); - ComplexIntElementIterator complexIntIt(owner, index); - if (llvm::isa(complexEltTy)) { - auto value = *complexIntIt; - auto real = IntegerAttr::get(complexEltTy, value.real()); - auto imag = IntegerAttr::get(complexEltTy, value.imag()); - return ArrayAttr::get(complexTy.getContext(), - ArrayRef{real, imag}); - } - ComplexFloatElementIterator complexFloatIt( - llvm::cast(complexEltTy).getFloatSemantics(), complexIntIt); - auto value = *complexFloatIt; - auto real = FloatAttr::get(complexEltTy, value.real()); - auto imag = FloatAttr::get(complexEltTy, value.imag()); - return ArrayAttr::get(complexTy.getContext(), - ArrayRef{real, imag}); - } + // Handle strings specially. if (llvm::isa(owner)) { ArrayRef vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); } - llvm_unreachable("unexpected element type"); + + // All other types should implement DenseElementTypeInterface. + auto denseEltTy = llvm::cast(eltTy); + ArrayRef rawData = owner.getRawData(); + // Storage is byte-aligned: align bit size up to next byte boundary. + size_t bitSize = denseEltTy.getDenseElementBitSize(); + size_t byteSize = llvm::divideCeil(bitSize, CHAR_BIT); + size_t offset = owner.isSplat() ? 0 : index * byteSize; + return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize)); } //===----------------------------------------------------------------------===// @@ -888,79 +872,28 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(hasSameNumElementsOrSplat(type, values)); Type eltType = type.getElementType(); - // Take care complex type case first. - if (auto complexType = llvm::dyn_cast(eltType)) { - if (complexType.getElementType().isIntOrIndex()) { - SmallVector> complexValues; - complexValues.reserve(values.size()); - for (Attribute attr : values) { - assert(llvm::isa(attr) && "expected ArrayAttr for complex"); - auto arrayAttr = llvm::cast(attr); - assert(arrayAttr.size() == 2 && "expected 2 element for complex"); - auto attr0 = arrayAttr[0]; - auto attr1 = arrayAttr[1]; - complexValues.push_back( - std::complex(llvm::cast(attr0).getValue(), - llvm::cast(attr1).getValue())); - } - return DenseElementsAttr::get(type, complexValues); - } - // Must be float. - SmallVector> complexValues; - complexValues.reserve(values.size()); - for (Attribute attr : values) { - assert(llvm::isa(attr) && "expected ArrayAttr for complex"); - auto arrayAttr = llvm::cast(attr); - assert(arrayAttr.size() == 2 && "expected 2 element for complex"); - auto attr0 = arrayAttr[0]; - auto attr1 = arrayAttr[1]; - complexValues.push_back( - std::complex(llvm::cast(attr0).getValue(), - llvm::cast(attr1).getValue())); - } - return DenseElementsAttr::get(type, complexValues); - } - - // If the element type is not based on int/float/index, assume it is a string - // type. - if (!eltType.isIntOrIndexOrFloat()) { + // Handle strings specially. + if (!llvm::isa(eltType)) { SmallVector stringValues; stringValues.reserve(values.size()); for (Attribute attr : values) { assert(llvm::isa(attr) && - "expected string value for non integer/index/float element"); + "expected string value for non-DenseElementType element"); stringValues.push_back(llvm::cast(attr).getValue()); } return get(type, stringValues); } - // Otherwise, get the raw storage width to use for the allocation. - size_t bitWidth = getDenseElementBitWidth(eltType); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - - // Compress the attribute values into a character buffer. - SmallVector data( - llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); - APInt intVal; - for (unsigned i = 0, e = values.size(); i < e; ++i) { - if (auto floatAttr = llvm::dyn_cast(values[i])) { - assert(floatAttr.getType() == eltType && - "expected float attribute type to equal element type"); - intVal = floatAttr.getValue().bitcastToAPInt(); - } else if (auto intAttr = llvm::dyn_cast(values[i])) { - assert(intAttr.getType() == eltType && - "expected integer attribute type to equal element type"); - intVal = intAttr.getValue(); - } else { - // Unsupported attribute type. + // All other types go through DenseElementTypeInterface. + auto denseEltType = llvm::dyn_cast(eltType); + assert(denseEltType && + "attempted to get DenseElementsAttr with unsupported element type"); + SmallVector data; + for (Attribute attr : values) { + LogicalResult result = denseEltType.convertFromAttribute(attr, data); + if (failed(result)) return {}; - } - - assert(intVal.getBitWidth() == bitWidth && - "expected value to have same bitwidth as element type"); - writeBits(data.data(), i * storageBitWidth, intVal); } - return DenseIntOrFPElementsAttr::getRaw(type, data); } diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp index 2f063be3e7cd0..29303d95eb003 100644 --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -6,9 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/CheckedArithmetic.h" +#include "llvm/Support/MathExtras.h" +#include using namespace mlir; using namespace mlir::detail; @@ -19,6 +22,37 @@ using namespace mlir::detail; #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" +//===----------------------------------------------------------------------===// +// DenseElementTypeInterface implementations for float types +//===----------------------------------------------------------------------===// + +size_t mlir::detail::getFloatTypeDenseElementBitSize(Type type) { + return cast(type).getWidth(); +} + +Attribute mlir::detail::convertFloatTypeToAttribute(Type type, + ArrayRef rawData) { + auto floatType = cast(type); + APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth()); + APFloat floatVal(floatType.getFloatSemantics(), intVal); + return FloatAttr::get(type, floatVal); +} + +LogicalResult +mlir::detail::convertFloatTypeFromAttribute(Type type, Attribute attr, + SmallVectorImpl &result) { + auto floatType = cast(type); + auto floatAttr = dyn_cast(attr); + if (!floatAttr || floatAttr.getType() != type) + return failure(); + size_t byteSize = + llvm::divideCeil(floatType.getWidth(), static_cast(CHAR_BIT)); + size_t bitPos = result.size() * CHAR_BIT; + result.resize(result.size() + byteSize); + writeBits(result.data(), bitPos, floatAttr.getValue().bitcastToAPInt()); + return success(); +} + //===----------------------------------------------------------------------===// // FloatType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1e198043c590a..786c30851a071 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -12,14 +12,17 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TensorEncoding.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CheckedArithmetic.h" +#include using namespace mlir; using namespace mlir::detail; @@ -58,6 +61,39 @@ LogicalResult ComplexType::verify(function_ref emitError, return success(); } +size_t ComplexType::getDenseElementBitSize() const { + auto elemTy = cast(getElementType()); + return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2; +} + +Attribute ComplexType::convertToAttribute(ArrayRef rawData) const { + auto elemTy = cast(getElementType()); + size_t singleElementBytes = + llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8; + Attribute real = + elemTy.convertToAttribute(rawData.take_front(singleElementBytes)); + Attribute imag = + elemTy.convertToAttribute(rawData.take_back(singleElementBytes)); + return ArrayAttr::get(getContext(), {real, imag}); +} + +LogicalResult +ComplexType::convertFromAttribute(Attribute attr, + SmallVectorImpl &result) const { + auto arrayAttr = dyn_cast(attr); + if (!arrayAttr || arrayAttr.size() != 2) + return failure(); + auto elemTy = cast(getElementType()); + SmallVector realData, imagData; + if (failed(elemTy.convertFromAttribute(arrayAttr[0], realData))) + return failure(); + if (failed(elemTy.convertFromAttribute(arrayAttr[1], imagData))) + return failure(); + result.append(realData); + result.append(imagData); + return success(); +} + //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// @@ -85,6 +121,57 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); } +size_t IntegerType::getDenseElementBitSize() const { + // Return the actual bit width. Storage alignment is handled separately. + return getWidth(); +} + +Attribute IntegerType::convertToAttribute(ArrayRef rawData) const { + APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth()); + return IntegerAttr::get(*this, value); +} + +static void writeAPIntToVector(APInt apInt, SmallVectorImpl &result) { + size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT); + size_t bitPos = result.size() * CHAR_BIT; + result.resize(result.size() + byteSize); + detail::writeBits(result.data(), bitPos, apInt); +} + +LogicalResult +IntegerType::convertFromAttribute(Attribute attr, + SmallVectorImpl &result) const { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getType() != *this) + return failure(); + writeAPIntToVector(intAttr.getValue(), result); + return success(); +} + +//===----------------------------------------------------------------------===// +// Index Type +//===----------------------------------------------------------------------===// + +size_t IndexType::getDenseElementBitSize() const { + return kInternalStorageBitWidth; +} + +Attribute IndexType::convertToAttribute(ArrayRef rawData) const { + APInt value = + detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth); + return IntegerAttr::get(*this, value); +} + +LogicalResult +IndexType::convertFromAttribute(Attribute attr, + SmallVectorImpl &result) const { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getType() != *this) + return failure(); + writeAPIntToVector(intAttr.getValue(), result); + return success(); +} + //===----------------------------------------------------------------------===// // Float Types //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir new file mode 100644 index 0000000000000..8749e562087c2 --- /dev/null +++ b/mlir/test/IR/dense-elements-type-interface.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s + +// Test dense elements attribute with custom element type using DenseElementTypeInterface. +// Uses the new type-first syntax: dense +// Note: The type is embedded in the attribute, so it's not printed again at the end. + +// CHECK-LABEL: func @dense_custom_element_type +func.func @dense_custom_element_type() { + // CHECK: "test.dummy"() {attr = dense : [1 : i32, 2 : i32, 3 : i32]>} + "test.dummy"() {attr = dense : [1 : i32, 2 : i32, 3 : i32]>} : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @dense_custom_element_type_2d +func.func @dense_custom_element_type_2d() { + // CHECK: "test.dummy"() {attr = dense : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} + "test.dummy"() {attr = dense : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @dense_custom_element_splat +func.func @dense_custom_element_splat() { + // CHECK: "test.dummy"() {attr = dense : 42 : i32>} + "test.dummy"() {attr = dense : 42 : i32>} : () -> () + return +} + +// ----- + +// CHECK-LABEL func @dense_i32_1d +func.func @dense_i32_1d() { + // The default assembly format for int, index, float, complex element types is + // the literal-first syntax. Such a dense elements attribute can be parsed + // with the type-first syntax, but it will come back with the literal-first + // syntax. + // CHECK: "test.dummy"() {attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> () + "test.dummy"() {attr = dense : [1 : i32, 2 : i32, 3 : i32]>} : () -> () + return +} + +// ----- + +func.func @invalid_element() { + // expected-error @+1 {{expected attribute value}} + "test.dummy"() {attr = dense : [foo]>} : () -> () + return +} + +// ----- + +func.func @incompatible_attribute() { + // expected-error @+1 {{incompatible attribute for element type}} + "test.dummy"() {attr = dense : ["foo"]>} : () -> () + return +} + +// ----- + +func.func @shape_mismatch() { + // expected-error @+1 {{expected 3 elements in dimension, got 2}} + "test.dummy"() {attr = dense : [1 : i32, 2 : i32]>} : () -> () + return +} + +// ----- + +func.func @dynamic_shape() { + // expected-error @+1 {{dense elements type must have static shape}} + "test.dummy"() {attr = dense : [1 : i32, 2 : i32, 3 : i32]>} : () -> () + return +} + +// ----- + +func.func @invalid_type() { + // expected-error @+1 {{expected a shaped type for dense elements}} + "test.dummy"() {attr = dense} : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 964792ceebc07..08600ce713a17 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -18,6 +18,7 @@ include "TestDialect.td" include "TestAttrDefs.td" include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" @@ -512,4 +513,15 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> { let hasCustomAssemblyFormat = 1; } +def TestTypeDenseElement : Test_Type<"TestDenseElement", + [DeclareTypeInterfaceMethods + ]> { + let mnemonic = "dense_element"; + let description = [{ + A test type that implements DenseElementTypeInterface to test dense + elements with custom element types. Elements are stored as 32-bit integers. + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 71dd25b0093e0..ef3396fc4f610 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -15,6 +15,7 @@ #include "TestDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" @@ -22,6 +23,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/TypeSize.h" +#include #include using namespace mlir; @@ -605,3 +607,29 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const { printer.printNewline(); printer << ">"; } + +//===----------------------------------------------------------------------===// +// TestDenseElementType - DenseElementTypeInterface Implementation +//===----------------------------------------------------------------------===// + +// Elements are stored as 32-bit integers. +size_t TestDenseElementType::getDenseElementBitSize() const { return 32; } + +Attribute +TestDenseElementType::convertToAttribute(ArrayRef rawData) const { + assert(rawData.size() == 4 && "expected 4 bytes for TestDenseElement"); + int32_t value; + std::memcpy(&value, rawData.data(), sizeof(value)); + return IntegerAttr::get(IntegerType::get(getContext(), 32), value); +} + +LogicalResult TestDenseElementType::convertFromAttribute( + Attribute attr, SmallVectorImpl &result) const { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getType().getIntOrFloatBitWidth() != 32) + return failure(); + int32_t value = intAttr.getValue().getSExtValue(); + result.append(reinterpret_cast(&value), + reinterpret_cast(&value) + sizeof(value)); + return success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index 6499a96f495d0..705fb86e9e9b3 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -19,6 +19,7 @@ #include "TestTraits.h" #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h"