60 changes: 60 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SubElementInterfaces.td"

// TODO: Currently the attributes defined in this file are prefixed with
Expand Down Expand Up @@ -424,6 +425,65 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//

def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
ElementsAttrInterface, TypedAttrInterface
]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
let description = [{
Syntax:

```
dense-resource-elements-attribute ::=
`dense_resource` `<` resource-handle `>` `:` shaped-type
```

A dense resource elements attribute is an elements attribute backed by a
handle to a builtin dialect resource containing a densely packed array of
values. This class provides the low-level attribute, which should only be
interacted with in very generic terms, actual access to the underlying
resource data is intended to be managed through one of the subclasses, such
as; `DenseBoolResourceElementsAttr`, `DenseUI64ResourceElementsAttr`,
`DenseI32ResourceElementsAttr`, `DenseF32ResourceElementsAttr`,
`DenseF64ResourceElementsAttr`, etc.

Examples:

```mlir
// A tensor referencing a builtin dialect resource, `resource_1`, with two
// unsigned i32 elements.
dense_resource<resource_1> : tensor<2xui32>
```
}];
let parameters = (ins
AttributeSelfTypeParameter<"", "ShapedType">:$type,
ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle
);
let builders = [
AttrBuilderWithInferredContext<(ins
"ShapedType":$type, "DenseResourceElementsHandle":$handle
)>
];
let extraClassDeclaration = [{
protected:
/// A builder that inserts a new resource into the builtin dialect's blob
/// manager using the provided blob. The handle of the inserted blob is used
/// when building the attribute. The provided `blobName` is used as a hint
/// for the key of the new handle for the `blob` resource, but may be
/// changed if necessary to ensure uniqueness during insertion.
static DenseResourceElementsAttr get(
ShapedType type, StringRef blobName, AsmResourceBlob blob
);

public:
}];
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 12 additions & 10 deletions mlir/include/mlir/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ class Dialect {

/// Lookup an interface for the given ID if one is registered, otherwise
/// nullptr.
const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
auto it = registeredInterfaces.find(interfaceID);
return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
}
template <typename InterfaceT>
const InterfaceT *getRegisteredInterface() {
return static_cast<const InterfaceT *>(
InterfaceT *getRegisteredInterface() {
return static_cast<InterfaceT *>(
getRegisteredInterface(InterfaceT::getInterfaceID()));
}

Expand All @@ -189,6 +189,12 @@ class Dialect {
(void)std::initializer_list<int>{
0, (addInterface(std::make_unique<Args>(this)), 0)...};
}
template <typename InterfaceT, typename... Args>
InterfaceT &addInterface(Args &&...args) {
InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
addInterface(std::unique_ptr<DialectInterface>(interface));
return *interface;
}

protected:
/// The constructor takes a unique namespace for this dialect as well as the
Expand Down Expand Up @@ -305,15 +311,11 @@ struct isa_impl<
};
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect *> {
using ret_type =
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
const T *>;
using ret_type = T *;
};
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect> {
using ret_type =
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
const T &>;
using ret_type = T &;
};

template <typename T>
Expand All @@ -325,7 +327,7 @@ struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
}
template <typename To>
static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
const To &>
To &>
doitImpl(::mlir::Dialect &dialect) {
return *dialect.getRegisteredInterface<To>();
}
Expand Down
215 changes: 215 additions & 0 deletions mlir/include/mlir/IR/DialectResourceBlobManager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
//===- DialectResourceBlobManager.h - Dialect Blob Management ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utility classes for referencing and managing asm resource
// blobs. These classes are intended to more easily facilitate the sharing of
// large blobs, and their definition.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
#define MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H

#include "mlir/IR/AsmState.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/SMLoc.h"

namespace mlir {
//===----------------------------------------------------------------------===//
// DialectResourceBlobManager
//===---------------------------------------------------------------------===//

/// This class defines a manager for dialect resource blobs. Blobs are uniqued
/// by a given key, and represented using AsmResourceBlobs.
class DialectResourceBlobManager {
public:
/// The class represents an individual entry of a blob.
class BlobEntry {
public:
/// Return the key used to reference this blob.
StringRef getKey() const { return key; }

/// Return the blob owned by this entry if one has been initialized. Returns
/// nullptr otherwise.
const AsmResourceBlob *getBlob() const { return blob ? &*blob : nullptr; }
AsmResourceBlob *getBlob() { return blob ? &*blob : nullptr; }

/// Set the blob owned by this entry.
void setBlob(AsmResourceBlob &&newBlob) { blob = std::move(newBlob); }

private:
BlobEntry() = default;
BlobEntry(BlobEntry &&) = default;
BlobEntry &operator=(const BlobEntry &) = delete;
BlobEntry &operator=(BlobEntry &&) = delete;

/// Initialize this entry with the given key and blob.
void initialize(StringRef newKey, Optional<AsmResourceBlob> newBlob) {
key = newKey;
blob = std::move(newBlob);
}

/// The key used for this blob.
StringRef key;

/// The blob that is referenced by this entry if it is valid.
Optional<AsmResourceBlob> blob;

/// Allow access to the constructors.
friend DialectResourceBlobManager;
friend class llvm::StringMapEntryStorage<BlobEntry>;
};

/// Return the blob registered for the given name, or nullptr if no blob
/// is registered.
BlobEntry *lookup(StringRef name);
const BlobEntry *lookup(StringRef name) const {
return const_cast<DialectResourceBlobManager *>(this)->lookup(name);
}

/// Update the blob for the entry defined by the provided name. This method
/// asserts that an entry for the given name exists in the manager.
void update(StringRef name, AsmResourceBlob &&newBlob);

/// Insert a new entry with the provided name and optional blob data. The name
/// may be modified during insertion if another entry already exists with that
/// name. Returns the inserted entry.
BlobEntry &insert(StringRef name, Optional<AsmResourceBlob> blob = {});
/// Insertion method that returns a dialect specific handle to the inserted
/// entry.
template <typename HandleT>
HandleT insert(typename HandleT::Dialect *dialect, StringRef name,
Optional<AsmResourceBlob> blob = {}) {
BlobEntry &entry = insert(name, std::move(blob));
return HandleT(&entry, dialect);
}

private:
/// A mutex to protect access to the blob map.
llvm::sys::SmartRWMutex<true> blobMapLock;

/// The internal map of tracked blobs. StringMap stores entries in distinct
/// allocations, so we can freely take references to the data without fear of
/// invalidation during additional insertion/deletion.
llvm::StringMap<BlobEntry> blobMap;
};

//===----------------------------------------------------------------------===//
// ResourceBlobManagerDialectInterface
//===---------------------------------------------------------------------===//

/// This class implements a dialect interface that provides common functionality
/// for interacting with a resource blob manager.
class ResourceBlobManagerDialectInterface
: public DialectInterface::Base<ResourceBlobManagerDialectInterface> {
public:
ResourceBlobManagerDialectInterface(Dialect *dialect)
: Base(dialect),
blobManager(std::make_shared<DialectResourceBlobManager>()) {}

/// Return the blob manager held by this interface.
DialectResourceBlobManager &getBlobManager() { return *blobManager; }
const DialectResourceBlobManager &getBlobManager() const {
return *blobManager;
}

/// Set the blob manager held by this interface.
void
setBlobManager(std::shared_ptr<DialectResourceBlobManager> newBlobManager) {
blobManager = std::move(newBlobManager);
}

private:
/// The blob manager owned by the dialect implementing this interface.
std::shared_ptr<DialectResourceBlobManager> blobManager;
};

/// This class provides a base class for dialects implementing the resource blob
/// interface. It provides several additional dialect specific utilities on top
/// of the generic interface. `HandleT` is the type of the handle used to
/// reference a resource blob.
template <typename HandleT>
class ResourceBlobManagerDialectInterfaceBase
: public ResourceBlobManagerDialectInterface {
public:
using ResourceBlobManagerDialectInterface::
ResourceBlobManagerDialectInterface;

/// Update the blob for the entry defined by the provided name. This method
/// asserts that an entry for the given name exists in the manager.
void update(StringRef name, AsmResourceBlob &&newBlob) {
getBlobManager().update(name, std::move(newBlob));
}

/// Insert a new resource blob entry with the provided name and optional blob
/// data. The name may be modified during insertion if another entry already
/// exists with that name. Returns a dialect specific handle to the inserted
/// entry.
HandleT insert(StringRef name, Optional<AsmResourceBlob> blob = {}) {
return getBlobManager().template insert<HandleT>(
cast<typename HandleT::Dialect>(getDialect()), name, std::move(blob));
}

/// Build resources for each of the referenced blobs within this manager.
void buildResources(AsmResourceBuilder &provider,
ArrayRef<AsmDialectResourceHandle> referencedResources) {
for (const AsmDialectResourceHandle &handle : referencedResources) {
if (const auto *dialectHandle = dyn_cast<HandleT>(&handle)) {
if (auto *blob = dialectHandle->getBlob())
provider.buildBlob(dialectHandle->getKey(), *blob);
}
}
}
};

//===----------------------------------------------------------------------===//
// DialectResourceBlobHandle
//===----------------------------------------------------------------------===//

/// This class defines a dialect specific handle to a resource blob. These
/// handles utilize a StringRef for the internal key, and an AsmResourceBlob as
/// the underlying data.
template <typename DialectT>
struct DialectResourceBlobHandle
: public AsmDialectResourceHandleBase<DialectResourceBlobHandle<DialectT>,
DialectResourceBlobManager::BlobEntry,
DialectT> {
using AsmDialectResourceHandleBase<DialectResourceBlobHandle<DialectT>,
DialectResourceBlobManager::BlobEntry,
DialectT>::AsmDialectResourceHandleBase;
using ManagerInterface = ResourceBlobManagerDialectInterfaceBase<
DialectResourceBlobHandle<DialectT>>;

/// Return the human readable string key for this handle.
StringRef getKey() const { return this->getResource()->getKey(); }

/// Return the blob referenced by this handle if the underlying resource has
/// been initialized. Returns nullptr otherwise.
AsmResourceBlob *getBlob() { return this->getResource()->getBlob(); }
const AsmResourceBlob *getBlob() const {
return this->getResource()->getBlob();
}

/// Get the interface for the dialect that owns handles of this type. Asserts
/// that the dialect is registered.
static ManagerInterface &getManagerInterface(MLIRContext *ctx) {
auto *dialect = ctx->getOrLoadDialect<DialectT>();
assert(dialect && "dialect not registered");

auto *iface = dialect->template getRegisteredInterface<ManagerInterface>();
assert(iface && "dialect doesn't provide the blob manager interface?");
return *iface;
}
};

} // namespace mlir

#endif // MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
13 changes: 11 additions & 2 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,17 @@ class AsmParser {
template <typename ResourceT>
FailureOr<ResourceT> parseResourceHandle() {
SMLoc handleLoc = getCurrentLocation();
FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
getContext()->getOrLoadDialect<typename ResourceT::Dialect>());

// Try to load the dialect that owns the handle.
auto *dialect =
getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
if (!dialect) {
return emitError(handleLoc)
<< "dialect '" << ResourceT::Dialect::getDialectNamespace()
<< "' is unknown";
}

FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
if (failed(handle))
return failure();
if (auto *result = dyn_cast<ResourceT>(&*handle))
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ class AsmParserImpl : public BaseT {
/// Parse a handle to a resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) override {
const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
if (!interface) {
return parser.emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
Expand Down
41 changes: 40 additions & 1 deletion mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
#include "AsmParserImpl.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
Expand Down Expand Up @@ -97,6 +98,10 @@ Attribute Parser::parseAttribute(Type type) {
case Token::kw_dense:
return parseDenseElementsAttr(type);

// Parse a dense resource elements attribute.
case Token::kw_dense_resource:
return parseDenseResourceElementsAttr(type);

// Parse a dictionary attribute.
case Token::l_brace: {
NamedAttrList elements;
Expand Down Expand Up @@ -241,6 +246,7 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
case Token::kw_affine_map:
case Token::kw_affine_set:
case Token::kw_dense:
case Token::kw_dense_resource:
case Token::kw_false:
case Token::kw_loc:
case Token::kw_opaque:
Expand Down Expand Up @@ -928,6 +934,39 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return literalParser.getAttr(loc, type);
}

Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
auto loc = getToken().getLoc();
consumeToken(Token::kw_dense_resource);
if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
return nullptr;

// Parse the resource handle.
FailureOr<AsmDialectResourceHandle> rawHandle =
parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
return nullptr;

auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
if (!handle)
return emitError(loc, "invalid `dense_resource` handle type"), nullptr;

// Parse the type of the attribute if the user didn't provide one.
SMLoc typeLoc = loc;
if (!attrType) {
typeLoc = getToken().getLoc();
if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
return nullptr;
}

ShapedType shapedType = attrType.dyn_cast<ShapedType>();
if (!shapedType) {
emitError(typeLoc, "`dense_resource` expected a shaped type");
return nullptr;
}

return DenseResourceElementsAttr::get(shapedType, *handle);
}

/// Parse an opaque elements attribute.
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
SMLoc loc = getToken().getLoc();
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
return entry.second;
}

FailureOr<AsmDialectResourceHandle>
Parser::parseResourceHandle(Dialect *dialect) {
const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
if (!interface) {
return emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
}
StringRef resourceName;
return parseResourceHandle(interface, resourceName);
}

//===----------------------------------------------------------------------===//
// Code Completion

Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class Parser {
/// Parse a handle to a dialect resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);

//===--------------------------------------------------------------------===//
// Type Parsing
Expand Down Expand Up @@ -272,6 +273,9 @@ class Parser {
Attribute parseDenseElementsAttr(Type attrType);
ShapedType parseElementsLiteralType(Type type);

/// Parse a dense resource elements attribute.
Attribute parseDenseResourceElementsAttr(Type attrType);

/// Parse a DenseArrayAttr.
Attribute parseDenseArrayAttr();

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(complex)
TOK_KEYWORD(dense)
TOK_KEYWORD(dense_resource)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
Expand Down Expand Up @@ -1896,6 +1897,10 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
os << " ";
denseArrayAttr.printWithoutBraces(os);
os << "]";
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
os << "dense_resource<";
printResourceHandle(resourceAttr.getRawHandle());
os << ">";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {
Expand Down
134 changes: 129 additions & 5 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
Expand All @@ -36,11 +37,10 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerAttributes() {
addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
DenseIntOrFPElementsAttr, DenseStringElementsAttr,
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
>();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1576,6 +1576,130 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
return false;
}

//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//

DenseResourceElementsAttr
DenseResourceElementsAttr::get(ShapedType type,
DenseResourceElementsHandle handle) {
return Base::get(type.getContext(), type, handle);
}

DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
StringRef blobName,
AsmResourceBlob blob) {
// Extract the builtin dialect resource manager from context and construct a
// handle by inserting a new resource using the provided blob.
auto &manager =
DenseResourceElementsHandle::getManagerInterface(type.getContext());
return get(type, manager.insert(blobName, std::move(blob)));
}

//===----------------------------------------------------------------------===//
// DenseResourceElementsAttrBase

namespace {
/// Instantiations of this class provide utilities for interacting with native
/// data types in the context of DenseResourceElementsAttr.
template <typename T>
struct DenseResourceAttrUtil;
template <size_t width, bool isSigned>
struct DenseResourceElementsAttrIntUtil {
static bool checkElementType(Type eltType) {
IntegerType type = eltType.dyn_cast<IntegerType>();
if (!type || type.getWidth() != width)
return false;
return isSigned ? !type.isUnsigned() : !type.isSigned();
}
};
template <>
struct DenseResourceAttrUtil<bool> {
static bool checkElementType(Type eltType) {
return eltType.isSignlessInteger(1);
}
};
template <>
struct DenseResourceAttrUtil<int8_t>
: public DenseResourceElementsAttrIntUtil<8, true> {};
template <>
struct DenseResourceAttrUtil<uint8_t>
: public DenseResourceElementsAttrIntUtil<8, false> {};
template <>
struct DenseResourceAttrUtil<int16_t>
: public DenseResourceElementsAttrIntUtil<16, true> {};
template <>
struct DenseResourceAttrUtil<uint16_t>
: public DenseResourceElementsAttrIntUtil<16, false> {};
template <>
struct DenseResourceAttrUtil<int32_t>
: public DenseResourceElementsAttrIntUtil<32, true> {};
template <>
struct DenseResourceAttrUtil<uint32_t>
: public DenseResourceElementsAttrIntUtil<32, false> {};
template <>
struct DenseResourceAttrUtil<int64_t>
: public DenseResourceElementsAttrIntUtil<64, true> {};
template <>
struct DenseResourceAttrUtil<uint64_t>
: public DenseResourceElementsAttrIntUtil<64, false> {};
template <>
struct DenseResourceAttrUtil<float> {
static bool checkElementType(Type eltType) { return eltType.isF32(); }
};
template <>
struct DenseResourceAttrUtil<double> {
static bool checkElementType(Type eltType) { return eltType.isF64(); }
};
} // namespace

template <typename T>
DenseResourceElementsAttrBase<T>
DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
AsmResourceBlob blob) {
// Check that the blob is in the form we were expecting.
assert(blob.getDataAlignment() == alignof(T) &&
"alignment mismatch between expected alignment and blob alignment");
assert(((blob.getData().size() % sizeof(T)) == 0) &&
"size mismatch between expected element width and blob size");
assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
"invalid shape element type for provided type `T`");
return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
.template cast<DenseResourceElementsAttrBase<T>>();
}

template <typename T>
Optional<ArrayRef<T>>
DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
return blob->template getDataAs<T>();
return llvm::None;
}

template <typename T>
bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
resourceAttr.getElementType());
}

namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseResourceElementsAttr.
template class DenseResourceElementsAttrBase<bool>;
template class DenseResourceElementsAttrBase<int8_t>;
template class DenseResourceElementsAttrBase<int16_t>;
template class DenseResourceElementsAttrBase<int32_t>;
template class DenseResourceElementsAttrBase<int64_t>;
template class DenseResourceElementsAttrBase<uint8_t>;
template class DenseResourceElementsAttrBase<uint16_t>;
template class DenseResourceElementsAttrBase<uint32_t>;
template class DenseResourceElementsAttrBase<uint64_t>;
template class DenseResourceElementsAttrBase<float>;
template class DenseResourceElementsAttrBase<double>;
} // namespace detail
} // namespace mlir

//===----------------------------------------------------------------------===//
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 51 additions & 3 deletions mlir/lib/IR/BuiltinDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Builtin Dialect
// TableGen'erated dialect
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinBlobManagerInterface
//===----------------------------------------------------------------------===//

using BuiltinBlobManagerInterface =
ResourceBlobManagerDialectInterfaceBase<DenseResourceElementsHandle>;

//===----------------------------------------------------------------------===//
// BuiltinOpAsmDialectInterface
//===----------------------------------------------------------------------===//

namespace {
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
BuiltinOpAsmDialectInterface(Dialect *dialect,
BuiltinBlobManagerInterface &mgr)
: OpAsmDialectInterface(dialect), blobManager(mgr) {}

AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<AffineMapAttr>()) {
Expand All @@ -57,6 +71,38 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
}
return AliasResult::NoAlias;
}

//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//

std::string
getResourceKey(const AsmDialectResourceHandle &handle) const override {
return cast<DenseResourceElementsHandle>(handle).getKey().str();
}
FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const final {
return blobManager.insert(key);
}
LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
if (failed(blob))
return failure();

// Update the blob for this entry.
blobManager.update(entry.getKey(), std::move(*blob));
return success();
}
void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &provider) const final {
blobManager.buildResources(provider, referencedResources.getArrayRef());
}

private:
/// The blob manager for the dialect.
BuiltinBlobManagerInterface &blobManager;
};
} // namespace

Expand All @@ -68,7 +114,9 @@ void BuiltinDialect::initialize() {
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
>();
addInterfaces<BuiltinOpAsmDialectInterface>();

auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_library(MLIRIR
BuiltinTypeInterfaces.cpp
Diagnostics.cpp
Dialect.cpp
DialectResourceBlobManager.cpp
Dominance.cpp
ExtensibleDialect.cpp
FunctionImplementation.cpp
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/IR/DialectResourceBlobManager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- DialectResourceBlobManager.cpp - Dialect Blob Management -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/DialectResourceBlobManager.h"
#include "llvm/ADT/SmallString.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// DialectResourceBlobManager
//===---------------------------------------------------------------------===//

auto DialectResourceBlobManager::lookup(StringRef name) -> BlobEntry * {
llvm::sys::SmartScopedReader<true> reader(blobMapLock);

auto it = blobMap.find(name);
return it != blobMap.end() ? &it->second : nullptr;
}

void DialectResourceBlobManager::update(StringRef name,
AsmResourceBlob &&newBlob) {
BlobEntry *entry = lookup(name);
assert(entry && "`update` expects an existing entry for the provided name");
entry->setBlob(std::move(newBlob));
}

auto DialectResourceBlobManager::insert(StringRef name,
Optional<AsmResourceBlob> blob)
-> BlobEntry & {
llvm::sys::SmartScopedWriter<true> writer(blobMapLock);

// Functor used to attempt insertion with a given name.
auto tryInsertion = [&](StringRef name) -> BlobEntry * {
auto it = blobMap.try_emplace(name, BlobEntry());
if (it.second) {
it.first->second.initialize(it.first->getKey(), std::move(blob));
return &it.first->second;
}
return nullptr;
};

// Try inserting with the name provided by the user.
if (BlobEntry *entry = tryInsertion(name))
return *entry;

// If an entry already exists for the user provided name, tweak the name and
// re-attempt insertion until we find one that is unique.
llvm::SmallString<32> nameStorage(name);
nameStorage.push_back('_');
size_t nameCounter = 1;
do {
Twine(nameCounter++).toVector(nameStorage);

// Try inserting with the new name.
if (BlobEntry *entry = tryInsertion(nameStorage))
return *entry;
nameStorage.resize(name.size() + 1);
} while (true);
}
5 changes: 3 additions & 2 deletions mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {

/// Signal a completion for an attribute.
void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
"loc", "opaque", "sparse", "true", "unit"},
appendSimpleCompletions({"affine_set", "affine_map", "dense",
"dense_resource", "false", "loc", "opaque",
"sparse", "true", "unit"},
lsp::CompletionItemKind::Field,
/*sortText=*/"1");

Expand Down
13 changes: 13 additions & 0 deletions mlir/test/IR/dense-resource-elements-attr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s

// CHECK: attr = dense_resource<blob1> : tensor<3xi64>
"test.user_op"() {attr = dense_resource<blob1> : tensor<3xi64> } : () -> ()

{-#
dialect_resources: {
builtin: {
// CHECK: blob1: "0x08000000010000000000000002000000000000000300000000000000"
blob1: "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}
20 changes: 20 additions & 0 deletions mlir/test/IR/invalid-builtin-attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,23 @@ func.func @duplicate_dictionary_attr_key() {
"J// -----
" // expected-error {{expected}}

// -----

// expected-error@+1 {{expected '<' after 'dense_resource'}}
#attr = dense_resource>

// -----

// expected-error@+1 {{expected '>'}}
#attr = dense_resource<resource

// -----

// expected-error@+1 {{expected ':'}}
#attr = dense_resource<resource>

// -----

// expected-error@+1 {{`dense_resource` expected a shaped type}}
#attr = dense_resource<resource> : i32
4 changes: 2 additions & 2 deletions mlir/test/IR/invalid-file-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@

// -----

// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}}
// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'ml_program'}}
{-#
dialect_resources: {
builtin: {
ml_program: {
unknown_entry: "foo"
}
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
let mnemonic = "e1di64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
ResourceHandleParameter<"TestExternalElementsDataHandle">:$handle
ResourceHandleParameter<"TestDialectResourceBlobHandle">:$handle
);
let extraClassDeclaration = [{
/// Return the elements referenced by this attribute.
Expand Down
4 changes: 3 additions & 1 deletion mlir/test/lib/Dialect/Test/TestAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
//===----------------------------------------------------------------------===//

ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
return getHandle().getData()->getData();
if (auto *blob = getHandle().getBlob())
return blob->getDataAs<uint64_t>();
return llvm::None;
}

//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion mlir/test/lib/Dialect/Test/TestAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@

#include "TestAttrInterfaces.h.inc"
#include "TestOpEnums.h.inc"
#include "mlir/IR/DialectResourceBlobManager.h"

namespace test {
struct TestExternalElementsDataHandle;
class TestDialect;

/// A handle used to reference external elements instances.
using TestDialectResourceBlobHandle =
mlir::DialectResourceBlobHandle<TestDialect>;
} // namespace test

#define GET_ATTRDEF_CLASSES
Expand Down
96 changes: 24 additions & 72 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,55 +44,6 @@ void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}

//===----------------------------------------------------------------------===//
// External Elements Data
//===----------------------------------------------------------------------===//

ArrayRef<uint64_t> TestExternalElementsData::getData() const {
ArrayRef<char> data = AsmResourceBlob::getData();
return ArrayRef<uint64_t>((const uint64_t *)data.data(),
data.size() / sizeof(uint64_t));
}

TestExternalElementsData
TestExternalElementsData::allocate(size_t numElements) {
return TestExternalElementsData(
llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
[](const uint64_t *data, size_t) { delete[] data; },
/*dataIsMutable=*/true);
}

const TestExternalElementsData *
TestExternalElementsDataManager::getData(StringRef name) const {
auto it = dataMap.find(name);
return it != dataMap.end() ? &*it->second : nullptr;
}

std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
TestExternalElementsDataManager::insert(StringRef name) {
auto it = dataMap.try_emplace(name, nullptr);
if (it.second)
return it;

llvm::SmallString<32> nameStorage(name);
nameStorage.push_back('_');
size_t nameCounter = 1;
do {
nameStorage += std::to_string(nameCounter++);
auto it = dataMap.try_emplace(nameStorage, nullptr);
if (it.second)
return it;
nameStorage.resize(name.size() + 1);
} while (true);
}

void TestExternalElementsDataManager::setData(StringRef name,
TestExternalElementsData &&data) {
auto it = dataMap.find(name);
assert(it != dataMap.end() && "data not registered");
it->second = std::make_unique<TestExternalElementsData>(std::move(data));
}

//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
Expand All @@ -109,9 +60,18 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
"hasSingleBlockImplicitTerminator does not match "
"SingleBlockImplicitTerminatorOp");

struct TestResourceBlobManagerInterface
: public ResourceBlobManagerDialectInterfaceBase<
TestDialectResourceBlobHandle> {
using ResourceBlobManagerDialectInterfaceBase<
TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
};

// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
: OpAsmDialectInterface(dialect), blobManager(mgr) {}

//===------------------------------------------------------------------===//
// Aliases
Expand Down Expand Up @@ -176,45 +136,34 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {

std::string
getResourceKey(const AsmDialectResourceHandle &handle) const override {
return cast<TestExternalElementsDataHandle>(handle).getKey().str();
return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
}

FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const final {
TestDialect *dialect = cast<TestDialect>(getDialect());
TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();

// Resolve the reference by inserting a new entry into the manager.
auto it = mgr.insert(key).first;
return TestExternalElementsDataHandle(&*it, dialect);
return blobManager.insert(key);
}

LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
TestDialect *dialect = cast<TestDialect>(getDialect());
TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();

// The resource entries are external constant data.
auto blobAllocFn = [](unsigned size, unsigned align) {
assert(align == alignof(uint64_t) && "unexpected data alignment");
return TestExternalElementsData::allocate(size / sizeof(uint64_t));
};
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
if (failed(blob))
return failure();

mgr.setData(entry.getKey(), std::move(*blob));
// Update the blob for this entry.
blobManager.update(entry.getKey(), std::move(*blob));
return success();
}

void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &provider) const final {
for (const AsmDialectResourceHandle &handle : referencedResources) {
const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
}
blobManager.buildResources(provider, referencedResources.getArrayRef());
}

private:
/// The blob manager for the dialect.
TestResourceBlobManagerInterface &blobManager;
};

struct TestDialectFoldInterface : public DialectFoldInterface {
Expand Down Expand Up @@ -412,8 +361,11 @@ void TestDialect::initialize() {
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));

addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface, TestReductionPatternInterface>();
auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
addInterface<TestOpAsmInterface>(blobInterface);

addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
TestReductionPatternInterface>();
allowUnknownOperations();

// Instantiate our fallback op interface that we'll use on specific
Expand Down
63 changes: 1 addition & 62 deletions mlir/test/lib/Dialect/Test/TestDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
Expand All @@ -45,68 +46,6 @@ class DLTIDialect;
class RewritePatternSet;
} // namespace mlir

namespace test {
class TestDialect;

//===----------------------------------------------------------------------===//
// External Elements Data
//===----------------------------------------------------------------------===//

/// This class represents a single external elements instance. It keeps track of
/// the data, and deallocates when destructed.
class TestExternalElementsData : public mlir::AsmResourceBlob {
public:
using mlir::AsmResourceBlob::AsmResourceBlob;
TestExternalElementsData(mlir::AsmResourceBlob &&blob)
: mlir::AsmResourceBlob(std::move(blob)) {}

/// Return the data of this external elements instance.
llvm::ArrayRef<uint64_t> getData() const;

/// Allocate a new external elements instance with the given number of
/// elements.
static TestExternalElementsData allocate(size_t numElements);
};

/// A handle used to reference external elements instances.
struct TestExternalElementsDataHandle
: public mlir::AsmDialectResourceHandleBase<
TestExternalElementsDataHandle,
llvm::StringMapEntry<std::unique_ptr<TestExternalElementsData>>,
TestDialect> {
using AsmDialectResourceHandleBase::AsmDialectResourceHandleBase;

/// Return a key to use for this handle.
llvm::StringRef getKey() const { return getResource()->getKey(); }

/// Return the data referenced by this handle.
TestExternalElementsData *getData() const {
return getResource()->getValue().get();
}
};

/// This class acts as a manager for external elements data. It provides API
/// for creating and accessing registered elements data.
class TestExternalElementsDataManager {
using DataMap = llvm::StringMap<std::unique_ptr<TestExternalElementsData>>;

public:
/// Return the data registered for the given name, or nullptr if no data is
/// registered.
const TestExternalElementsData *getData(llvm::StringRef name) const;

/// Register an entry with the provided name, which may be modified if another
/// entry was already inserted with that name. Returns the inserted entry.
std::pair<DataMap::iterator, bool> insert(llvm::StringRef name);

/// Set the data for the given entry, which is expected to exist.
void setData(llvm::StringRef name, TestExternalElementsData &&data);

private:
llvm::StringMap<std::unique_ptr<TestExternalElementsData>> dataMap;
};
} // namespace test

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 0 additions & 8 deletions mlir/test/lib/Dialect/Test/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def Test_Dialect : Dialect {
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;

/// Returns the external elements data manager for this dialect.
TestExternalElementsDataManager &getExternalDataManager() {
return externalDataManager;
}

private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
Expand All @@ -55,9 +50,6 @@ def Test_Dialect : Dialect {
::llvm::SetVector<::mlir::Type> &stack) const;
void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
::llvm::SetVector<::mlir::Type> &stack) const;

/// An external data manager used to test external elements data.
TestExternalElementsDataManager externalDataManager;
}];
}

Expand Down
118 changes: 118 additions & 0 deletions mlir/unittests/IR/AttributeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//

template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
Expand Down Expand Up @@ -203,7 +209,119 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
}
} // namespace

//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//

template <typename AttrT, typename T>
static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
Type elementType) {
auto type = RankedTensorType::get(data.size(), elementType);
auto attr =
AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocate(data));

// Check that we can access and iterate the data properly.
Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
EXPECT_TRUE(attrData.hasValue());
EXPECT_EQ(*attrData, data);

// Check that we cast to this attribute when possible.
Attribute genericAttr = attr;
EXPECT_TRUE(genericAttr.template isa<AttrT>());
}
template <typename AttrT, typename T>
static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
T data[] = {0, 1, 2};
checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
builder.getIntegerType(intWidth));
}

namespace {
TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
MLIRContext context;
Builder builder(&context);

// Bool
bool boolData[] = {true, false, true};
checkNativeAccess<DenseBoolResourceElementsAttr>(
&context, llvm::makeArrayRef(boolData), builder.getI1Type());

// Unsigned integers
checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);

// Signed integers
checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);

// Float
float floatData[] = {0, 1, 2};
checkNativeAccess<DenseF32ResourceElementsAttr>(
&context, llvm::makeArrayRef(floatData), builder.getF32Type());

// Double
double doubleData[] = {0, 1, 2};
checkNativeAccess<DenseF64ResourceElementsAttr>(
&context, llvm::makeArrayRef(doubleData), builder.getF64Type());
}

TEST(DenseResourceElementsAttrTest, CheckNoCast) {
MLIRContext context;
Builder builder(&context);

// Create a i32 attribute.
ArrayRef<uint32_t> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));

EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
}

TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
MLIRContext context;
Builder builder(&context);

// Create a bool attribute with data of the incorrect type.
ArrayRef<uint32_t> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
ASSERT_DEATH(
{
DenseBoolResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));
},
"alignment mismatch between expected alignment and blob alignment");
}

TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
MLIRContext context;
Builder builder(&context);

// Create a bool attribute with incorrect type.
ArrayRef<bool> data;
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
ASSERT_DEATH(
{
DenseBoolResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocate(data));
},
"invalid shape element type for provided type `T`");
}
} // namespace

//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//

namespace {
TEST(SparseElementsAttrTest, GetZero) {
MLIRContext context;
context.allowUnregisteredDialects();
Expand Down