30 changes: 24 additions & 6 deletions mlir/include/mlir/IR/OpAsmInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
//
//===----------------------------------------------------------------------===//
//
// This file contains Interfaces for interacting with the AsmParser and
// AsmPrinter.
// This file contains interfaces and other utilities for interacting with the
// AsmParser and AsmPrinter.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_OPASMINTERFACE
#define MLIR_OPASMINTERFACE
#ifndef MLIR_IR_OPASMINTERFACE_TD
#define MLIR_IR_OPASMINTERFACE_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

/// Interface for hooking into the OpAsmPrinter and OpAsmParser.
//===----------------------------------------------------------------------===//
// OpAsmOpInterface
//===----------------------------------------------------------------------===//

def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
let description = [{
This interface provides hooks to interact with the AsmPrinter and AsmParser
Expand Down Expand Up @@ -105,4 +109,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
];
}

#endif // MLIR_OPASMINTERFACE
//===----------------------------------------------------------------------===//
// ResourceHandleParameter
//===----------------------------------------------------------------------===//

/// This parameter represents a handle to a resource that is encoded into the
/// "dialect_resources" section of the assembly format. This parameter expects a
/// C++ `handleType` that derives from `AsmDialectResourceHandleBase` and
/// implements a derived handle to the desired resource type.
class ResourceHandleParameter<string handleType, string desc = "">
: AttrOrTypeParameter<handleType, desc> {
let parser = "$_parser.parseResourceHandle<" # handleType # ">()";
let printer = "$_printer.printResourceHandle($_self)";
}

#endif // MLIR_IR_OPASMINTERFACE_TD
166 changes: 163 additions & 3 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,82 @@
#include "llvm/Support/SMLoc.h"

namespace mlir {

class AsmParsedResourceEntry;
class AsmResourceBuilder;
class Builder;

//===----------------------------------------------------------------------===//
// AsmDialectResourceHandle
//===----------------------------------------------------------------------===//

/// This class represents an opaque handle to a dialect resource entry.
class AsmDialectResourceHandle {
public:
AsmDialectResourceHandle() = default;
AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
: resource(resource), opaqueID(resourceID), dialect(dialect) {}
bool operator==(const AsmDialectResourceHandle &other) const {
return resource == other.resource;
}

/// Return an opaque pointer to the referenced resource.
void *getResource() const { return resource; }

/// Return the type ID of the resource.
TypeID getTypeID() const { return opaqueID; }

/// Return the dialect that owns the resource.
Dialect *getDialect() const { return dialect; }

private:
/// The opaque handle to the dialect resource.
void *resource = nullptr;
/// The type of the resource referenced.
TypeID opaqueID;
/// The dialect owning the given resource.
Dialect *dialect;
};

/// This class represents a CRTP base class for dialect resource handles. It
/// abstracts away various utilities necessary for defined derived resource
/// handles.
template <typename DerivedT, typename ResourceT, typename DialectT>
class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
public:
using Dialect = DialectT;

/// Construct a handle from a pointer to the resource. The given pointer
/// should be guaranteed to live beyond the life of this handle.
AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
: AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
: AsmDialectResourceHandle(handle) {
assert(handle.getTypeID() == TypeID::get<DerivedT>());
}

/// Return the resource referenced by this handle.
ResourceT *getResource() {
return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
}
const ResourceT *getResource() const {
return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
}

/// Return the dialect that owns the resource.
DialectT *getDialect() const {
return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
}

/// Support llvm style casting.
static bool classof(const AsmDialectResourceHandle *handle) {
return handle->getTypeID() == TypeID::get<DerivedT>();
}
};

inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
return llvm::hash_value(param.getResource());
}

//===----------------------------------------------------------------------===//
// AsmPrinter
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -108,6 +181,9 @@ class AsmPrinter {
/// special or non-printable characters in it.
virtual void printSymbolName(StringRef symbolRef);

/// Print a handle to the given dialect resource.
void printResourceHandle(const AsmDialectResourceHandle &resource);

/// Print an optional arrow followed by a type list.
template <typename TypeRange>
void printOptionalArrowTypeList(TypeRange &&types) {
Expand Down Expand Up @@ -870,6 +946,24 @@ class AsmParser {
StringRef attrName,
NamedAttrList &attrs) = 0;

//===--------------------------------------------------------------------===//
// Resource Parsing
//===--------------------------------------------------------------------===//

/// Parse a handle to a resource within the assembly format.
template <typename ResourceT>
FailureOr<ResourceT> parseResourceHandle() {
SMLoc handleLoc = getCurrentLocation();
FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
if (failed(handle))
return failure();
if (auto *result = dyn_cast<ResourceT>(&*handle))
return std::move(*result);
return emitError(handleLoc) << "provided resource handle differs from the "
"expected resource type";
}

//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1026,6 +1120,12 @@ class AsmParser {
/// next token.
virtual ParseResult parseXInDimensionList() = 0;

protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect.
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;

private:
AsmParser(const AsmParser &) = delete;
void operator=(const AsmParser &) = delete;
Expand Down Expand Up @@ -1338,6 +1438,12 @@ using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}

//===------------------------------------------------------------------===//
// Aliases
//===------------------------------------------------------------------===//

/// Holds the result of `getAlias` hook call.
enum class AliasResult {
/// The object (type or attribute) is not supported by the hook
Expand All @@ -1350,8 +1456,6 @@ class OpAsmDialectInterface
FinalAlias
};

OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}

/// Hooks for getting an alias identifier alias for a given symbol, that is
/// not necessarily a part of this dialect. The identifier is used in place of
/// the symbol when printing textual IR. These aliases must not contain `.` or
Expand All @@ -1362,6 +1466,41 @@ class OpAsmDialectInterface
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
return AliasResult::NoAlias;
}

//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//

/// Declare a resource with the given key, returning a handle to use for any
/// references of this resource key within the IR during parsing. The result
/// of `getResourceKey` on the returned handle is permitted to be different
/// than `key`.
virtual FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const {
return failure();
}

/// Return a key to use for the given resource. This key should uniquely
/// identify this resource within the dialect.
virtual std::string
getResourceKey(const AsmDialectResourceHandle &handle) const {
llvm_unreachable(
"Dialect must implement `getResourceKey` when defining resources");
}

/// Hook for parsing resource entries. Returns failure if the entry was not
/// valid, or could otherwise not be processed correctly. Any necessary errors
/// can be emitted via the provided entry.
virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;

/// Hook for building resources to use during printing. The given `op` may be
/// inspected to help determine what information to include.
/// `referencedResources` contains all of the resources detected when printing
/// 'op'.
virtual void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &builder) const {}
};
} // namespace mlir

Expand All @@ -1372,4 +1511,25 @@ class OpAsmDialectInterface
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.h.inc"

namespace llvm {
template <>
struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
static inline mlir::AsmDialectResourceHandle getEmptyKey() {
return {DenseMapInfo<void *>::getEmptyKey(),
DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
}
static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
return {DenseMapInfo<void *>::getTombstoneKey(),
DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
}
static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
return DenseMapInfo<void *>::getHashValue(handle.getResource());
}
static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
const mlir::AsmDialectResourceHandle &rhs) {
return lhs.getResource() == rhs.getResource();
}
};
} // namespace llvm

#endif
33 changes: 17 additions & 16 deletions mlir/include/mlir/Parser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_PARSER_PARSER_H
#define MLIR_PARSER_PARSER_H

#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include <cstddef>
Expand Down Expand Up @@ -84,7 +85,7 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
/// SSA uses and definitions). `asmState` should only be provided if this
/// detailed information is desired.
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr,
AsmParserState *asmState = nullptr);

Expand All @@ -96,7 +97,7 @@ LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
/// non-null, it is populated with a file location representing the start of the
/// source file that is being parsed.
LogicalResult parseSourceFile(llvm::StringRef filename, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr);

/// This parses the file specified by the indicated filename using the provided
Expand All @@ -111,7 +112,7 @@ LogicalResult parseSourceFile(llvm::StringRef filename, Block *block,
/// `asmState` should only be provided if this detailed information is desired.
LogicalResult parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr,
AsmParserState *asmState = nullptr);

Expand All @@ -123,22 +124,22 @@ LogicalResult parseSourceFile(llvm::StringRef filename,
/// populated with a file location representing the start of the source file
/// that is being parsed.
LogicalResult parseSourceString(llvm::StringRef sourceStr, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr);

namespace detail {
/// The internal implementation of the templated `parseSourceFile` methods
/// below, that simply forwards to the non-templated version.
template <typename ContainerOpT, typename... ParserArgs>
inline OwningOpRef<ContainerOpT> parseSourceFile(MLIRContext *ctx,
inline OwningOpRef<ContainerOpT> parseSourceFile(const ParserConfig &config,
ParserArgs &&...args) {
LocationAttr sourceFileLoc;
Block block;
if (failed(parseSourceFile(std::forward<ParserArgs>(args)..., &block, ctx,
if (failed(parseSourceFile(std::forward<ParserArgs>(args)..., &block, config,
&sourceFileLoc)))
return OwningOpRef<ContainerOpT>();
return detail::constructContainerOpForParserIfNecessary<ContainerOpT>(
&block, ctx, sourceFileLoc);
&block, config.getContext(), sourceFileLoc);
}
} // namespace detail

Expand All @@ -152,8 +153,8 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(MLIRContext *ctx,
/// `SingleBlockImplicitTerminator` trait.
template <typename ContainerOpT>
inline OwningOpRef<ContainerOpT>
parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context) {
return detail::parseSourceFile<ContainerOpT>(context, sourceMgr);
parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) {
return detail::parseSourceFile<ContainerOpT>(config, sourceMgr);
}

/// This parses the file specified by the indicated filename. If the source IR
Expand All @@ -166,8 +167,8 @@ parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context) {
/// `SingleBlockImplicitTerminator` trait.
template <typename ContainerOpT>
inline OwningOpRef<ContainerOpT> parseSourceFile(StringRef filename,
MLIRContext *context) {
return detail::parseSourceFile<ContainerOpT>(context, filename);
const ParserConfig &config) {
return detail::parseSourceFile<ContainerOpT>(config, filename);
}

/// This parses the file specified by the indicated filename using the provided
Expand All @@ -181,8 +182,8 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(StringRef filename,
template <typename ContainerOpT>
inline OwningOpRef<ContainerOpT> parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr,
MLIRContext *context) {
return detail::parseSourceFile<ContainerOpT>(context, filename, sourceMgr);
const ParserConfig &config) {
return detail::parseSourceFile<ContainerOpT>(config, filename, sourceMgr);
}

/// This parses the provided string containing MLIR. If the source IR contained
Expand All @@ -195,13 +196,13 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(llvm::StringRef filename,
/// `SingleBlockImplicitTerminator` trait.
template <typename ContainerOpT>
inline OwningOpRef<ContainerOpT> parseSourceString(llvm::StringRef sourceStr,
MLIRContext *context) {
const ParserConfig &config) {
LocationAttr sourceFileLoc;
Block block;
if (failed(parseSourceString(sourceStr, &block, context, &sourceFileLoc)))
if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc)))
return OwningOpRef<ContainerOpT>();
return detail::constructContainerOpForParserIfNecessary<ContainerOpT>(
&block, context, sourceFileLoc);
&block, config.getContext(), sourceFileLoc);
}

/// This parses a single MLIR attribute to an MLIR context if it was valid. If
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Pass/PassRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

namespace mlir {
class OpPassManager;
class ParserConfig;
class Pass;
class PassManager;

namespace detail {
class PassOptions;
Expand Down Expand Up @@ -272,6 +274,18 @@ class PassNameCLParser {
std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
};

//===----------------------------------------------------------------------===//
// Pass Reproducer
//===----------------------------------------------------------------------===//

/// Attach an assembly resource parser that handles MLIR reproducer
/// configurations. Any found reproducer information will be attached to the
/// given pass manager, e.g. the reproducer pipeline, verification flags, etc.
// FIXME: Remove the `enableThreading` flag when possible. Some tools, e.g.
// mlir-opt, force disable threading during parsing.
void attachPassReproducerAsmResource(ParserConfig &config, PassManager &pm,
bool &enableThreading);

} // namespace mlir

#endif // MLIR_PASS_PASSREGISTRY_H_
179 changes: 179 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.cpp.inc"

LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
<< "' for dialect '" << getDialect()->getNamespace()
<< "'";
}

//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1254,6 +1261,15 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
return name;
}

//===----------------------------------------------------------------------===//
// Resources
//===----------------------------------------------------------------------===//

AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;

//===----------------------------------------------------------------------===//
// AsmState
//===----------------------------------------------------------------------===//
Expand All @@ -1278,6 +1294,17 @@ class AsmStateImpl {
/// Get the state used for SSA names.
SSANameState &getSSANameState() { return nameState; }

/// Return the dialects within the context that implement
/// OpAsmDialectInterface.
DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
return interfaces;
}

/// Return the non-dialect resource printers.
auto getResourcePrinters() {
return llvm::make_pointee_range(externalResourcePrinters);
}

/// Get the printer flags.
const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

Expand All @@ -1292,6 +1319,9 @@ class AsmStateImpl {
/// Collection of OpAsm interfaces implemented in the context.
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

/// A collection of non-dialect resource printers.
SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

/// The state used for attribute and type aliases.
AliasState aliasState;

Expand All @@ -1303,6 +1333,9 @@ class AsmStateImpl {

/// An optional location map to be populated.
AsmState::LocationMap *locationMap;

// Allow direct access to the impl fields.
friend AsmState;
};
} // namespace detail
} // namespace mlir
Expand Down Expand Up @@ -1352,6 +1385,11 @@ const OpPrintingFlags &AsmState::getPrinterFlags() const {
return impl->getPrinterFlags();
}

void AsmState::attachResourcePrinter(
std::unique_ptr<AsmResourcePrinter> printer) {
impl->externalResourcePrinters.emplace_back(std::move(printer));
}

//===----------------------------------------------------------------------===//
// AsmPrinter::Impl
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1403,6 +1441,14 @@ class AsmPrinter::Impl {
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);

/// Print a reference to the given resource that is owned by the given
/// dialect.
void printResourceHandle(const AsmDialectResourceHandle &resource) {
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
os << interface->getResourceKey(resource);
dialectResources[resource.getDialect()].insert(resource);
}

void printAffineMap(AffineMap map);
void
printAffineExpr(AffineExpr expr,
Expand Down Expand Up @@ -1462,6 +1508,9 @@ class AsmPrinter::Impl {

/// A tracker for the number of new lines emitted during printing.
NewLineCounter newLine;

/// A set of dialect resources that were referenced during printing.
DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
};
} // namespace mlir

Expand Down Expand Up @@ -2241,6 +2290,11 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
Impl subPrinter(attrNameStr, printerFlags, state);
DialectAsmPrinter printer(subPrinter);
dialect.printAttribute(attr, printer);

// FIXME: Delete this when we no longer require a nested printer.
for (auto &it : subPrinter.dialectResources)
for (const auto &resource : it.second)
dialectResources[it.first].insert(resource);
}
printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
}
Expand All @@ -2255,6 +2309,11 @@ void AsmPrinter::Impl::printDialectType(Type type) {
Impl subPrinter(typeNameStr, printerFlags, state);
DialectAsmPrinter printer(subPrinter);
dialect.printType(type, printer);

// FIXME: Delete this when we no longer require a nested printer.
for (auto &it : subPrinter.dialectResources)
for (const auto &resource : it.second)
dialectResources[it.first].insert(resource);
}
printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
}
Expand Down Expand Up @@ -2325,6 +2384,11 @@ void AsmPrinter::printSymbolName(StringRef symbolRef) {
::printSymbolReference(symbolRef, impl->getStream());
}

void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
impl->printResourceHandle(resource);
}

//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2654,6 +2718,51 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
void printUserIDs(Operation *user, bool prefixComma = false);

private:
/// This class represents a resource builder implementation for the MLIR
/// textual assembly format.
class ResourceBuilder : public AsmResourceBuilder {
public:
using ValueFn = function_ref<void(raw_ostream &)>;
using PrintFn = function_ref<void(StringRef, ValueFn)>;

ResourceBuilder(OperationPrinter &p, PrintFn printFn)
: p(p), printFn(printFn) {}
~ResourceBuilder() override = default;

void buildBool(StringRef key, bool data) final {
printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); });
}

void buildString(StringRef key, StringRef data) final {
printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); });
}

void buildBlob(StringRef key, ArrayRef<char> data,
uint32_t dataAlignment) final {
printFn(key, [&](raw_ostream &os) {
// Store the blob in a hex string containing the alignment and the data.
os << "\"0x"
<< llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignment),
sizeof(dataAlignment)))
<< llvm::toHex(StringRef(data.data(), data.size())) << "\"";
});
}

private:
OperationPrinter &p;
PrintFn printFn;
};

/// Print the metadata dictionary for the file, eliding it if it is empty.
void printFileMetadataDictionary(Operation *op);

/// Print the resource sections for the file metadata dictionary.
/// `checkAddMetadataDict` is used to indicate that metadata is going to be
/// added, and the file metadata dictionary should be started if it hasn't
/// yet.
void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
Operation *op);

// Contains the stack of default dialects to use when printing regions.
// A new dialect is pushed to the stack before parsing regions nested under an
// operation implementing `OpAsmOpInterface`, and popped when done. At the
Expand All @@ -2679,6 +2788,76 @@ void OperationPrinter::printTopLevelOperation(Operation *op) {

// Output the aliases at the top level that can be deferred.
state->getAliasState().printDeferredAliases(os, newLine);

// Output any file level metadata.
printFileMetadataDictionary(op);
}

void OperationPrinter::printFileMetadataDictionary(Operation *op) {
bool sawMetadataEntry = false;
auto checkAddMetadataDict = [&] {
if (!std::exchange(sawMetadataEntry, true))
os << newLine << "{-#" << newLine;
};

// Add the various types of metadata.
printResourceFileMetadata(checkAddMetadataDict, op);

// If the file dictionary exists, close it.
if (sawMetadataEntry)
os << newLine << "#-}" << newLine;
}

void OperationPrinter::printResourceFileMetadata(
function_ref<void()> checkAddMetadataDict, Operation *op) {
// Functor used to add data entries to the file metadata dictionary.
bool hadResource = false;
auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
auto &&...providerArgs) {
bool hadEntry = false;
auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
checkAddMetadataDict();

// Emit the top-level resource entry if we haven't yet.
if (!std::exchange(hadResource, true))
os << " " << dictName << "_resources: {" << newLine;
// Emit the parent resource entry if we haven't yet.
if (!std::exchange(hadEntry, true))
os << " " << name << ": {" << newLine;
else
os << "," << newLine;

os << " " << key << ": ";
valueFn(os);
};
ResourceBuilder entryBuilder(*this, printFn);
provider.buildResources(op, providerArgs..., entryBuilder);

if (hadEntry)
os << newLine << " }";
};

// Print the `dialect_resources` section if we have any dialects with
// resources.
for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
StringRef name = interface.getDialect()->getNamespace();
auto it = dialectResources.find(interface.getDialect());
if (it != dialectResources.end())
processProvider("dialect", name, interface, it->second);
else
processProvider("dialect", name, interface,
SetVector<AsmDialectResourceHandle>());
}
if (hadResource)
os << newLine << " }";

// Print the `external_resources` section if we have any external clients with
// resources.
hadResource = false;
for (const auto &printer : state->getResourcePrinters())
processProvider("external", printer.getName(), printer);
if (hadResource)
os << newLine << " }";
}

/// Print a block argument in the usual format of:
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Parser/AffineParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context,
/*RequiresNullTerminator=*/false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState symbolState;
ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
ParserConfig config(context);
ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr);
Parser parser(state);

raw_ostream &os = printDiagnosticInfo ? llvm::errs() : llvm::nulls();
Expand Down
29 changes: 20 additions & 9 deletions mlir/lib/Parser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,11 @@ class AsmParserImpl : public BaseT {
return success();
}

/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return parser.getToken().isAny(Token::bare_identifier, Token::inttype) ||
parser.getToken().isKeyword();
}

/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
// Check that the current token has the same spelling.
if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
if (!parser.isCurrentTokenAKeyword() ||
parser.getTokenSpelling() != keyword)
return failure();
parser.consumeToken();
return success();
Expand All @@ -260,7 +255,7 @@ class AsmParserImpl : public BaseT {
/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
if (!parser.isCurrentTokenAKeyword())
return failure();

*keyword = parser.getTokenSpelling();
Expand All @@ -273,7 +268,7 @@ class AsmParserImpl : public BaseT {
parseOptionalKeyword(StringRef *keyword,
ArrayRef<StringRef> allowedKeywords) override {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
if (!parser.isCurrentTokenAKeyword())
return failure();

StringRef currentKeyword = parser.getTokenSpelling();
Expand Down Expand Up @@ -439,6 +434,22 @@ class AsmParserImpl : public BaseT {
return success();
}

//===--------------------------------------------------------------------===//
// Resource Parsing
//===--------------------------------------------------------------------===//

/// Parse a handle to a resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) override {
const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
if (!interface) {
return parser.emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
}
StringRef resourceName;
return parser.parseResourceHandle(interface, resourceName);
}

//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
Expand Down
27 changes: 13 additions & 14 deletions mlir/lib/Parser/DialectSymbolParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
inputStr, /*BufferName=*/"<mlir_parser_buffer>",
/*RequiresNullTerminator=*/false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
ParserConfig config(context);
ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr);
Parser parser(state);

Token startTok = parser.getToken();
Expand Down Expand Up @@ -237,6 +238,7 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
/// attribute-alias ::= `#` alias-name
///
Attribute Parser::parseExtendedAttr(Type type) {
MLIRContext *ctx = getContext();
Attribute attr = parseExtendedSymbol<Attribute>(
*this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
Expand All @@ -250,17 +252,16 @@ Attribute Parser::parseExtendedAttr(Type type) {
if (Dialect *dialect =
builder.getContext()->getOrLoadDialect(dialectName)) {
return parseSymbol<Attribute>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
symbolData, ctx, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
return dialect->parseAttribute(customParser, attrType);
});
}

// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
[&] { return emitError(loc); },
StringAttr::get(state.context, dialectName), symbolData,
attrType ? attrType : NoneType::get(state.context));
[&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
symbolData, attrType ? attrType : NoneType::get(ctx));
});

// Ensure that the attribute has the same type as requested.
Expand All @@ -280,25 +281,23 @@ Attribute Parser::parseExtendedAttr(Type type) {
/// type-alias ::= `!` alias-name
///
Type Parser::parseExtendedType() {
MLIRContext *ctx = getContext();
return parseExtendedSymbol<Type>(
*this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
SMLoc loc) -> Type {
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
auto *dialect = state.context->getOrLoadDialect(dialectName);

if (dialect) {
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
return parseSymbol<Type>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
symbolData, ctx, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
return dialect->parseType(customParser);
});
}

// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
[&] { return emitError(loc); },
StringAttr::get(state.context, dialectName), symbolData);
return OpaqueType::getChecked([&] { return emitError(loc); },
StringAttr::get(ctx, dialectName),
symbolData);
});
}

Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ Token Lexer::lexToken() {
case ')':
return formToken(Token::r_paren, tokStart);
case '{':
if (*curPtr == '-' && *(curPtr + 1) == '#') {
curPtr += 2;
return formToken(Token::file_metadata_begin, tokStart);
}
return formToken(Token::l_brace, tokStart);
case '}':
return formToken(Token::r_brace, tokStart);
Expand Down Expand Up @@ -140,12 +144,14 @@ Token Lexer::lexToken() {
case '@':
return lexAtIdentifier(tokStart);

case '!':
LLVM_FALLTHROUGH;
case '^':
LLVM_FALLTHROUGH;
case '#':
if (*curPtr == '-' && *(curPtr + 1) == '}') {
curPtr += 2;
return formToken(Token::file_metadata_end, tokStart);
}
LLVM_FALLTHROUGH;
case '!':
case '^':
case '%':
return lexPrefixedIdentifier(tokStart);
case '"':
Expand Down
276 changes: 255 additions & 21 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Parser.h"
#include "AsmParserImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
Expand Down Expand Up @@ -291,6 +292,48 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
return success();
}

ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
return failure();

*keyword = getTokenSpelling();
consumeToken();
return success();
}

//===----------------------------------------------------------------------===//
// Resource Parsing

FailureOr<AsmDialectResourceHandle>
Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
StringRef &name) {
assert(dialect && "expected valid dialect interface");
SMLoc nameLoc = getToken().getLoc();
if (failed(parseOptionalKeyword(&name)))
return emitError("expected identifier key for 'resource' entry");
auto &resources = getState().symbols.dialectResources;

// If this is the first time encountering this handle, ask the dialect to
// resolve a reference to this handle. This allows for us to remap the name of
// the handle if necessary.
std::pair<std::string, AsmDialectResourceHandle> &entry =
resources[dialect][name];
if (entry.first.empty()) {
FailureOr<AsmDialectResourceHandle> result = dialect->declareResource(name);
if (failed(result)) {
return emitError(nameLoc)
<< "unknown 'resource' key '" << name << "' for dialect '"
<< dialect->getDialect()->getNamespace() << "'";
}
entry.first = dialect->getResourceKey(*result);
entry.second = *result;
}

name = entry.first;
return entry.second;
}

//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2064,17 +2107,103 @@ class TopLevelOperationParser : public Parser {

private:
/// Parse an attribute alias declaration.
///
/// attribute-alias-def ::= '#' alias-name `=` attribute-value
///
ParseResult parseAttributeAliasDef();

/// Parse an attribute alias declaration.
/// Parse a type alias declaration.
///
/// type-alias-def ::= '!' alias-name `=` type
///
ParseResult parseTypeAliasDef();

/// Parse a top-level file metadata dictionary.
///
/// file-metadata-dict ::= '{-#' file-metadata-entry* `#-}'
///
ParseResult parseFileMetadataDictionary();

/// Parse a resource metadata dictionary.
ParseResult parseResourceFileMetadata(
function_ref<ParseResult(StringRef, SMLoc)> parseBody);
ParseResult parseDialectResourceFileMetadata();
ParseResult parseExternalResourceFileMetadata();
};

/// This class represents an implementation of a resource entry for the MLIR
/// textual format.
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p)
: key(key), keyLoc(keyLoc), value(value), p(p) {}
~ParsedResourceEntry() override = default;

StringRef getKey() const final { return key; }

InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); }

FailureOr<bool> parseAsBool() const final {
if (value.is(Token::kw_true))
return true;
if (value.is(Token::kw_false))
return false;
return p.emitError(value.getLoc(),
"expected 'true' or 'false' value for key '" + key +
"'");
}

FailureOr<std::string> parseAsString() const final {
if (value.isNot(Token::string))
return p.emitError(value.getLoc(),
"expected string value for key '" + key + "'");
return value.getStringValue();
}

FailureOr<AsmResourceBlob>
parseAsBlob(BlobAllocatorFn allocator) const final {
// Blob data within then textual format is represented as a hex string.
// TODO: We could avoid an additional alloc+copy here if we pre-allocated
// the buffer to use during hex processing.
Optional<std::string> blobData =
value.is(Token::string) ? value.getHexStringValue() : llvm::None;
if (!blobData)
return p.emitError(value.getLoc(),
"expected hex string blob for key '" + key + "'");

// Extract the alignment of the blob data, which gets stored at the
// beginning of the string.
if (blobData->size() < sizeof(uint32_t)) {
return p.emitError(value.getLoc(),
"expected hex string blob for key '" + key +
"' to encode alignment in first 4 bytes");
}
uint32_t align = 0;
memcpy(&align, blobData->data(), sizeof(uint32_t));

// Get the data portion of the blob.
StringRef data = StringRef(*blobData).drop_front(sizeof(uint32_t));
if (data.empty())
return AsmResourceBlob();

// Allocate memory for the blob using the provided allocator and copy the
// data into it.
AsmResourceBlob blob = allocator(data.size(), align);
assert(llvm::isAddrAligned(llvm::Align(align), blob.getData().data()) &&
blob.isMutable() &&
"blob allocator did not return a properly aligned address");
memcpy(blob.getMutableData().data(), data.data(), data.size());
return blob;
}

private:
StringRef key;
SMLoc keyLoc;
Token value;
Parser &p;
};
} // namespace

/// Parses an attribute alias declaration.
///
/// attribute-alias-def ::= '#' alias-name `=` attribute-value
///
ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
assert(getToken().is(Token::hash_identifier));
StringRef aliasName = getTokenSpelling().drop_front();
Expand Down Expand Up @@ -2103,10 +2232,6 @@ ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
return success();
}

/// Parse a type alias declaration.
///
/// type-alias-def ::= '!' alias-name `=` type
///
ParseResult TopLevelOperationParser::parseTypeAliasDef() {
assert(getToken().is(Token::exclamation_identifier));
StringRef aliasName = getTokenSpelling().drop_front();
Expand Down Expand Up @@ -2135,6 +2260,108 @@ ParseResult TopLevelOperationParser::parseTypeAliasDef() {
return success();
}

ParseResult TopLevelOperationParser::parseFileMetadataDictionary() {
consumeToken(Token::file_metadata_begin);
return parseCommaSeparatedListUntil(
Token::file_metadata_end, [&]() -> ParseResult {
// Parse the key of the metadata dictionary.
SMLoc keyLoc = getToken().getLoc();
StringRef key;
if (failed(parseOptionalKeyword(&key)))
return emitError("expected identifier key in file "
"metadata dictionary");
if (parseToken(Token::colon, "expected ':'"))
return failure();

// Process the metadata entry.
if (key == "dialect_resources")
return parseDialectResourceFileMetadata();
if (key == "external_resources")
return parseExternalResourceFileMetadata();
return emitError(keyLoc, "unknown key '" + key +
"' in file metadata dictionary");
});
}

ParseResult TopLevelOperationParser::parseResourceFileMetadata(
function_ref<ParseResult(StringRef, SMLoc)> parseBody) {
if (parseToken(Token::l_brace, "expected '{'"))
return failure();

return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
// Parse the top-level name entry.
SMLoc nameLoc = getToken().getLoc();
StringRef name;
if (failed(parseOptionalKeyword(&name)))
return emitError("expected identifier key for 'resource' entry");

if (parseToken(Token::colon, "expected ':'") ||
parseToken(Token::l_brace, "expected '{'"))
return failure();
return parseBody(name, nameLoc);
});
}

ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() {
return parseResourceFileMetadata([&](StringRef name,
SMLoc nameLoc) -> ParseResult {
// Lookup the dialect and check that it can handle a resource entry.
Dialect *dialect = getContext()->getOrLoadDialect(name);
if (!dialect)
return emitError(nameLoc, "dialect '" + name + "' is unknown");
const auto *handler = dyn_cast<OpAsmDialectInterface>(dialect);
if (!handler) {
return emitError() << "unexpected 'resource' section for dialect '"
<< dialect->getNamespace() << "'";
}

return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
// Parse the name of the resource entry.
SMLoc keyLoc = getToken().getLoc();
StringRef key;
if (failed(parseResourceHandle(handler, key)) ||
parseToken(Token::colon, "expected ':'"))
return failure();
Token valueTok = getToken();
consumeToken();

ParsedResourceEntry entry(key, keyLoc, valueTok, *this);
return handler->parseResource(entry);
});
});
}

ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() {
return parseResourceFileMetadata([&](StringRef name,
SMLoc nameLoc) -> ParseResult {
AsmResourceParser *handler = state.config.getResourceParser(name);

// TODO: Should we require handling external resources in some scenarios?
if (!handler) {
emitWarning(getEncodedSourceLocation(nameLoc))
<< "ignoring unknown external resources for '" << name << "'";
}

return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
// Parse the name of the resource entry.
SMLoc keyLoc = getToken().getLoc();
StringRef key;
if (failed(parseOptionalKeyword(&key)))
return emitError(
"expected identifier key for 'external_resources' entry");
if (parseToken(Token::colon, "expected ':'"))
return failure();
Token valueTok = getToken();
consumeToken();

if (!handler)
return success();
ParsedResourceEntry entry(key, keyLoc, valueTok, *this);
return handler->parseResource(entry);
});
});
}

ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
Location parserLoc) {
// Create a top-level operation to contain the parsed state.
Expand Down Expand Up @@ -2179,63 +2406,70 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
if (parseTypeAliasDef())
return failure();
break;

// Parse a file-level metadata dictionary.
case Token::file_metadata_begin:
if (parseFileMetadataDictionary())
return failure();
break;
}
}
}

//===----------------------------------------------------------------------===//

LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
Block *block, MLIRContext *context,
Block *block, const ParserConfig &config,
LocationAttr *sourceFileLoc,
AsmParserState *asmState) {
const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());

Location parserLoc = FileLineColLoc::get(
context, sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0);
Location parserLoc =
FileLineColLoc::get(config.getContext(), sourceBuf->getBufferIdentifier(),
/*line=*/0, /*column=*/0);
if (sourceFileLoc)
*sourceFileLoc = parserLoc;

SymbolState aliasState;
ParserState state(sourceMgr, context, aliasState, asmState);
ParserState state(sourceMgr, config, aliasState, asmState);
return TopLevelOperationParser(state).parse(block, parserLoc);
}

LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
llvm::SourceMgr sourceMgr;
return parseSourceFile(filename, sourceMgr, block, context, sourceFileLoc);
return parseSourceFile(filename, sourceMgr, block, config, sourceFileLoc);
}

LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc,
AsmParserState *asmState) {
if (sourceMgr.getNumBuffers() != 0) {
// TODO: Extend to support multiple buffers.
return emitError(mlir::UnknownLoc::get(context),
return emitError(mlir::UnknownLoc::get(config.getContext()),
"only main buffer parsed at the moment");
}
auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code error = fileOrErr.getError())
return emitError(mlir::UnknownLoc::get(context),
return emitError(mlir::UnknownLoc::get(config.getContext()),
"could not open input file " + filename);

// Load the MLIR source file.
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc());
return parseSourceFile(sourceMgr, block, context, sourceFileLoc, asmState);
return parseSourceFile(sourceMgr, block, config, sourceFileLoc, asmState);
}

LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block,
MLIRContext *context,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
auto memBuffer = MemoryBuffer::getMemBuffer(sourceStr);
if (!memBuffer)
return failure();

SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
return parseSourceFile(sourceMgr, block, context, sourceFileLoc);
return parseSourceFile(sourceMgr, block, config, sourceFileLoc);
}
22 changes: 20 additions & 2 deletions mlir/lib/Parser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ class Parser {

Builder builder;

Parser(ParserState &state) : builder(state.context), state(state) {}
Parser(ParserState &state)
: builder(state.config.getContext()), state(state) {}

// Helper methods to get stuff from the parser-global state.
ParserState &getState() const { return state; }
MLIRContext *getContext() const { return state.context; }
MLIRContext *getContext() const { return state.config.getContext(); }
const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }

/// Parse a comma-separated list of elements up until the specified end token.
Expand Down Expand Up @@ -153,6 +154,23 @@ class Parser {
const llvm::fltSemantics &semantics,
size_t typeSizeInBits);

/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return getToken().isAny(Token::bare_identifier, Token::inttype) ||
getToken().isKeyword();
}

/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword);

//===--------------------------------------------------------------------===//
// Resource Parsing
//===--------------------------------------------------------------------===//

/// Parse a handle to a dialect resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);

//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
Expand Down
22 changes: 14 additions & 8 deletions mlir/lib/Parser/ParserState.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ namespace detail {

/// This class contains record of any parsed top-level symbols.
struct SymbolState {
// A map from attribute alias identifier to Attribute.
/// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;

// A map from type alias identifier to Type.
/// A map from type alias identifier to Type.
llvm::StringMap<Type> typeAliasDefinitions;

/// A map of dialect resource keys to the resolved resource name and handle
/// to use during parsing.
DenseMap<const OpAsmDialectInterface *,
llvm::StringMap<std::pair<std::string, AsmDialectResourceHandle>>>
dialectResources;

/// A set of locations into the main parser memory buffer for each of the
/// active nested parsers. Given that some nested parsers, i.e. custom dialect
/// parsers, operate on a temporary memory buffer, this provides an anchor
Expand All @@ -47,11 +53,11 @@ struct SymbolState {
/// This class refers to all of the state maintained globally by the parser,
/// such as the current lexer position etc.
struct ParserState {
ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
ParserState(const llvm::SourceMgr &sourceMgr, const ParserConfig &config,
SymbolState &symbols, AsmParserState *asmState)
: context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
symbols(symbols), parserDepth(symbols.nestedParserLocs.size()),
asmState(asmState) {
: config(config), lex(sourceMgr, config.getContext()),
curToken(lex.lexToken()), symbols(symbols),
parserDepth(symbols.nestedParserLocs.size()), asmState(asmState) {
// Set the top level lexer for the symbol state if one doesn't exist.
if (!symbols.topLevelLexer)
symbols.topLevelLexer = &lex;
Expand All @@ -64,8 +70,8 @@ struct ParserState {
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;

/// The context we're parsing into.
MLIRContext *const context;
/// The configuration used to setup the parser.
const ParserConfig &config;

/// The lexer for the source file we're parsing.
Lexer lex;
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Parser/Token.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ Optional<std::string> Token::getHexStringValue() const {
// Get the internal string data, without the quotes.
StringRef bytes = getSpelling().drop_front().drop_back();

// Try to extract the binary data from the hex string.
// Try to extract the binary data from the hex string. We expect the hex
// string to start with `0x` and have an even number of hex nibbles (nibbles
// should come in pairs).
std::string hex;
if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex))
if (!bytes.consume_front("0x") || (bytes.size() & 1) ||
!llvm::tryGetFromHex(bytes, hex))
return llvm::None;
return hex;
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Parser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ TOK_PUNCTUATION(r_square, "]")
TOK_PUNCTUATION(star, "*")
TOK_PUNCTUATION(vertical_bar, "|")

TOK_PUNCTUATION(file_metadata_begin, "{-#")
TOK_PUNCTUATION(file_metadata_end, "#-}")

// Keywords. These turn "foo" into Token::kw_foo enums.

// NOTE: Please key these alphabetized to make it easier to find something in
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ Type Parser::parseMemRefType() {
if (failed(parseStridedLayout(offset, strides)))
return failure();
// Construct strided affine map.
AffineMap map =
makeStridedLinearLayoutMap(strides, offset, state.context);
AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext());
layout = AffineMapAttr::get(map);
} else {
// Either it is MemRefLayoutAttrInterface or memory space attribute.
Expand Down
53 changes: 44 additions & 9 deletions mlir/lib/Pass/PassCrashRecovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -117,17 +118,16 @@ void RecoveryReproducerContext::generate(std::string &description) {
}
descOS << "reproducer generated at `" << stream->description() << "`";

// Output the current pass manager configuration to the crash stream.
auto &os = stream->os();
os << "// configuration: -pass-pipeline='" << pipeline << "'";
if (disableThreads)
os << " -mlir-disable-threading";
if (verifyPasses)
os << " -verify-each";
os << '\n';
AsmState state(preCrashOperation);
state.attachResourcePrinter(
"mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
builder.buildString("pipeline", pipeline);
builder.buildBool("disable_threading", disableThreads);
builder.buildBool("verify_each", verifyPasses);
});

// Output the .mlir module.
preCrashOperation->print(os);
preCrashOperation->print(stream->os(), state);
}

void RecoveryReproducerContext::disable() {
Expand Down Expand Up @@ -438,3 +438,38 @@ void PassManager::enableCrashReproducerGeneration(
addInstrumentation(
std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
}

//===----------------------------------------------------------------------===//
// Asm Resource
//===----------------------------------------------------------------------===//

void mlir::attachPassReproducerAsmResource(ParserConfig &config,
PassManager &pm,
bool &enableThreading) {
auto parseFn = [&](AsmParsedResourceEntry &entry) -> LogicalResult {
if (entry.getKey() == "pipeline") {
FailureOr<std::string> pipeline = entry.parseAsString();
if (failed(pipeline))
return failure();
return parsePassPipeline(*pipeline, pm);
}
if (entry.getKey() == "disable_threading") {
FailureOr<bool> value = entry.parseAsBool();

// FIXME: We should just update the context directly, but some places
// force disable threading during parsing.
if (succeeded(value))
enableThreading = !(*value);
return value;
}
if (entry.getKey() == "verify_each") {
FailureOr<bool> value = entry.parseAsBool();
if (succeeded(value))
pm.enableVerifier(*value);
return value;
}
return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
<< entry.getKey() << "'";
};
config.attachResourceParser("mlir_reproducer", parseFn);
}
41 changes: 12 additions & 29 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,25 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool wasThreadingEnabled = context->isMultithreadingEnabled();
context->disableMultithreading();

// Prepare the pass manager and apply any command line options.
PassManager pm(context, OpPassManager::Nesting::Implicit);
pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);

// Prepare the parser config, and attach any useful/necessary resource
// handlers.
ParserConfig config(context);
attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);

// Parse the input file and reset the context threading state.
TimingScope parserTiming = timing.nest("Parser");
OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, context));
OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
context->enableMultithreading(wasThreadingEnabled);
if (!module)
return failure();
parserTiming.stop();

// Apply any pass manager command line options.
PassManager pm(context, OpPassManager::Nesting::Implicit);
pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);

// Callback to build the pipeline.
if (failed(passManagerSetupFn(pm)))
return failure();
Expand Down Expand Up @@ -219,11 +224,6 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
"show-dialects", cl::desc("Print the list of registered dialects"),
cl::init(false));

static cl::opt<bool> runRepro(
"run-reproducer",
cl::desc("Append the command line options of the reproducer"),
cl::init(false));

InitLLVM y(argc, argv);

// Register any command line options.
Expand Down Expand Up @@ -260,23 +260,6 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
return failure();
}

// Parse reproducer options.
BumpPtrAllocator a;
StringSaver saver(a);
if (runRepro) {
auto pair = file->getBuffer().split('\n');
if (!pair.first.consume_front("// configuration:")) {
llvm::errs() << "Failed to find repro configuration, expect file to "
"begin with '// configuration:'\n";
return failure();
}
// Tokenize & parse the first line.
SmallVector<const char *, 4> newArgv;
newArgv.push_back(argv[0]);
llvm::cl::TokenizeGNUCommandLine(pair.first, saver, newArgv);
cl::ParseCommandLineOptions(newArgv.size(), &newArgv[0], helpHeader);
}

auto output = openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/IR/elements-attr-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ arith.constant [:i64 10, 11, -12, 13, 14]
arith.constant [:f32 10., 11., -12., 13., 14.]
// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
arith.constant [:f64 10., 11., -12., 13., 14.]

// Check that we handle an external constant parsed from the config.
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}}
// expected-error@below {{Test iterating `APInt`: unable to iterate type}}
// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
arith.constant #test.e1di64_elements<blob1> : tensor<3xi64>

{-#
dialect_resources: {
test: {
blob1: "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}
17 changes: 17 additions & 0 deletions mlir/test/IR/file-metadata-resources.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt %s -split-input-file | FileCheck %s

// Check that we only preserve the blob that got referenced.
// CHECK: test: {
// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: }

module attributes { test.blob_ref = #test.e1di64_elements<blob1> } {}

{-#
dialect_resources: {
test: {
blob1: "0x08000000010000000000000002000000000000000300000000000000",
blob2: "0x08000000040000000000000005000000000000000600000000000000"
}
}
#-}
142 changes: 142 additions & 0 deletions mlir/test/IR/invalid-file-metadata.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+2 {{expected identifier key in file metadata dictionary}}
{-#

// -----

// expected-error@+2 {{expected ':'}}
{-#
key
#-}

// -----

// expected-error@+2 {{unknown key 'some_key' in file metadata dictionary}}
{-#
some_key: {}
#-}

// -----

//===----------------------------------------------------------------------===//
// `dialect_resources`
//===----------------------------------------------------------------------===//

// expected-error@+2 {{expected '{'}}
{-#
dialect_resources: "value"
#-}

// -----

// expected-error@+3 {{expected identifier key for 'resource' entry}}
{-#
dialect_resources: {
10
}
#-}

// -----

// expected-error@+3 {{expected ':'}}
{-#
dialect_resources: {
entry "value"
}
#-}

// -----

// expected-error@+3 {{dialect 'foobar' is unknown}}
{-#
dialect_resources: {
foobar: {
entry: "foo"
}
}
#-}

// -----

// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}}
{-#
dialect_resources: {
builtin: {
unknown_entry: "foo"
}
}
#-}

// -----

// expected-error@+4 {{expected hex string blob for key 'invalid_blob'}}
{-#
dialect_resources: {
test: {
invalid_blob: 10
}
}
#-}

// -----

// expected-error@+4 {{expected hex string blob for key 'invalid_blob'}}
{-#
dialect_resources: {
test: {
invalid_blob: ""
}
}
#-}

// -----

// expected-error@+4 {{expected hex string blob for key 'invalid_blob' to encode alignment in first 4 bytes}}
{-#
dialect_resources: {
test: {
invalid_blob: "0x"
}
}
#-}

// -----

//===----------------------------------------------------------------------===//
// `external_resources`
//===----------------------------------------------------------------------===//

// expected-error@+2 {{expected '{'}}
{-#
external_resources: "value"
#-}

// -----

// expected-error@+3 {{expected identifier key for 'resource' entry}}
{-#
external_resources: {
10
}
#-}

// -----

// expected-error@+3 {{expected ':'}}
{-#
external_resources: {
entry "value"
}
#-}

// -----

// expected-warning@+3 {{ignoring unknown external resources for 'foobar'}}
{-#
external_resources: {
foobar: {
entry: "foo"
}
}
#-}
3 changes: 2 additions & 1 deletion mlir/test/Pass/crash-recovery-dynamic-failure.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ module @inner_mod1 {
module @foo {}
}

// REPRO_LOCAL_DYNAMIC_FAILURE: configuration: -pass-pipeline='builtin.module(test-pass-failure)'

// REPRO_LOCAL_DYNAMIC_FAILURE: module @inner_mod1
// REPRO_LOCAL_DYNAMIC_FAILURE: module @foo {

// REPRO_LOCAL_DYNAMIC_FAILURE: pipeline: "builtin.module(test-pass-failure)"
9 changes: 3 additions & 6 deletions mlir/test/Pass/crash-recovery.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ module @inner_mod1 {
module @foo {}
}

// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)'

// REPRO: module @inner_mod1
// REPRO: module @foo {

// REPRO_LOCAL: configuration: -pass-pipeline='builtin.module(test-pass-crash)'
// REPRO: pipeline: "builtin.module(test-module-pass,test-pass-crash)"

// REPRO_LOCAL: module @inner_mod1
// REPRO_LOCAL: module @foo {

// REPRO_LOCAL_DYNAMIC: configuration: -pass-pipeline='builtin.module(test-pass-crash)'
// REPRO_LOCAL: pipeline: "builtin.module(test-pass-crash)"

// REPRO_LOCAL_DYNAMIC: module @inner_mod1
// REPRO_LOCAL_DYNAMIC: module @foo {
// REPRO_LOCAL_DYNAMIC: pipeline: "builtin.module(test-pass-crash)"
16 changes: 10 additions & 6 deletions mlir/test/Pass/run-reproducer.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
// configuration: -mlir-disable-threading=true -pass-pipeline='func.func(cse,canonicalize)' -mlir-print-ir-before=cse

// Test of the reproducer run option. The first line has to be the
// configuration (matching what is produced by reproducer).

// RUN: mlir-opt %s -run-reproducer 2>&1 | FileCheck -check-prefix=BEFORE %s
// RUN: mlir-opt %s -mlir-print-ir-before=cse 2>&1 | FileCheck -check-prefix=BEFORE %s

func.func @foo() {
%0 = arith.constant 0 : i32
Expand All @@ -14,6 +9,15 @@ func.func @bar() {
return
}

{-#
external_resources: {
mlir_reproducer: {
pipeline: "func.func(cse,canonicalize)",
disable_threading: true
}
}
#-}

// BEFORE: // -----// IR Dump Before{{.*}}CSE //----- //
// BEFORE-NEXT: func @foo()
// BEFORE: // -----// IR Dump Before{{.*}}CSE //----- //
Expand Down
7 changes: 3 additions & 4 deletions mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class TestModuleCombinerPass
TestModuleCombinerPass() = default;
TestModuleCombinerPass(const TestModuleCombinerPass &) {}
void runOnOperation() override;

private:
OwningOpRef<spirv::ModuleOp> combinedModule;
};
} // namespace

Expand All @@ -46,10 +43,12 @@ void TestModuleCombinerPass::runOnOperation() {
<< " -> " << newSymbol << "\n";
};

combinedModule = spirv::combine(modules, combinedModuleBuilder, listener);
OwningOpRef<spirv::ModuleOp> combinedModule =
spirv::combine(modules, combinedModuleBuilder, listener);

for (spirv::ModuleOp module : modules)
module.erase();
combinedModule.release();
}

namespace mlir {
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "TestDialect.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SubElementInterfaces.td"

// All of the attributes will extend this class.
Expand Down Expand Up @@ -222,4 +223,29 @@ def TestAttrSelfTypeParameterFormat
let assemblyFormat = "`<` $a `>`";
}

// Test simple extern 1D vector using ElementsAttrInterface.
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
ElementsAttrInterface
]> {
let mnemonic = "e1di64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
ResourceHandleParameter<"TestExternalElementsDataHandle">:$handle
);
let extraClassDeclaration = [{
/// Return the elements referenced by this attribute.
llvm::ArrayRef<uint64_t> getElements() const;

/// The set of data types that can be iterated by this attribute.
using ContiguousIterableTypesT = std::tuple<uint64_t>;

/// Provide begin iterators for the various iterable types.
// * uint64_t
auto value_begin_impl(OverloadToken<uint64_t>) const {
return getElements().begin();
}
}];
let assemblyFormat = "`<` $handle `>`";
}

#endif // TEST_ATTRDEFS
8 changes: 8 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
return get(getContext(), first, second, third);
}

//===----------------------------------------------------------------------===//
// TestExtern1DI64ElementsAttr
//===----------------------------------------------------------------------===//

ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
return getHandle().getData()->getData();
}

//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "TestAttrInterfaces.h.inc"
#include "TestOpEnums.h.inc"

namespace test {
struct TestExternalElementsDataHandle;
} // namespace test

#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.h.inc"

Expand Down
100 changes: 100 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -43,6 +44,55 @@ void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}

//===----------------------------------------------------------------------===//
// External Elements Data
//===----------------------------------------------------------------------===//

ArrayRef<uint64_t> TestExternalElementsData::getData() const {
ArrayRef<char> data = AsmResourceBlob::getData();
return ArrayRef<uint64_t>((const uint64_t *)data.data(),
data.size() / sizeof(uint64_t));
}

TestExternalElementsData
TestExternalElementsData::allocate(size_t numElements) {
return TestExternalElementsData(
llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
[](const uint64_t *data, size_t) { delete[] data; },
/*dataIsMutable=*/true);
}

const TestExternalElementsData *
TestExternalElementsDataManager::getData(StringRef name) const {
auto it = dataMap.find(name);
return it != dataMap.end() ? &*it->second : nullptr;
}

std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
TestExternalElementsDataManager::insert(StringRef name) {
auto it = dataMap.try_emplace(name, nullptr);
if (it.second)
return it;

llvm::SmallString<32> nameStorage(name);
nameStorage.push_back('_');
size_t nameCounter = 1;
do {
nameStorage += std::to_string(nameCounter++);
auto it = dataMap.try_emplace(nameStorage, nullptr);
if (it.second)
return it;
nameStorage.resize(name.size() + 1);
} while (true);
}

void TestExternalElementsDataManager::setData(StringRef name,
TestExternalElementsData &&data) {
auto it = dataMap.find(name);
assert(it != dataMap.end() && "data not registered");
it->second = std::make_unique<TestExternalElementsData>(std::move(data));
}

//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
Expand All @@ -63,6 +113,10 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;

//===------------------------------------------------------------------===//
// Aliases
//===------------------------------------------------------------------===//

AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
StringAttr strAttr = attr.dyn_cast<StringAttr>();
if (!strAttr)
Expand Down Expand Up @@ -108,6 +162,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
}
return AliasResult::NoAlias;
}

//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//

std::string
getResourceKey(const AsmDialectResourceHandle &handle) const override {
return cast<TestExternalElementsDataHandle>(handle).getKey().str();
}

FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const final {
TestDialect *dialect = cast<TestDialect>(getDialect());
TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();

// Resolve the reference by inserting a new entry into the manager.
auto it = mgr.insert(key).first;
return TestExternalElementsDataHandle(&*it, dialect);
}

LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
TestDialect *dialect = cast<TestDialect>(getDialect());
TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();

// The resource entries are external constant data.
auto blobAllocFn = [](unsigned size, unsigned align) {
assert(align == alignof(uint64_t) && "unexpected data alignment");
return TestExternalElementsData::allocate(size / sizeof(uint64_t));
};
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
if (failed(blob))
return failure();

mgr.setData(entry.getKey(), std::move(*blob));
return success();
}

void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &provider) const final {
for (const AsmDialectResourceHandle &handle : referencedResources) {
const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
}
}
};

struct TestDialectFoldInterface : public DialectFoldInterface {
Expand Down
67 changes: 67 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
Expand All @@ -44,6 +45,72 @@ class DLTIDialect;
class RewritePatternSet;
} // namespace mlir

namespace test {
class TestDialect;

//===----------------------------------------------------------------------===//
// External Elements Data
//===----------------------------------------------------------------------===//

/// This class represents a single external elements instance. It keeps track of
/// the data, and deallocates when destructed.
class TestExternalElementsData : public mlir::AsmResourceBlob {
public:
using mlir::AsmResourceBlob::AsmResourceBlob;
TestExternalElementsData(mlir::AsmResourceBlob &&blob)
: mlir::AsmResourceBlob(std::move(blob)) {}

/// Return the data of this external elements instance.
llvm::ArrayRef<uint64_t> getData() const;

/// Allocate a new external elements instance with the given number of
/// elements.
static TestExternalElementsData allocate(size_t numElements);
};

/// A handle used to reference external elements instances.
struct TestExternalElementsDataHandle
: public mlir::AsmDialectResourceHandleBase<
TestExternalElementsDataHandle,
llvm::StringMapEntry<std::unique_ptr<TestExternalElementsData>>,
TestDialect> {
using AsmDialectResourceHandleBase::AsmDialectResourceHandleBase;

/// Return a key to use for this handle.
llvm::StringRef getKey() const { return getResource()->getKey(); }

/// Return the data referenced by this handle.
TestExternalElementsData *getData() const {
return getResource()->getValue().get();
}
};

/// This class acts as a manager for external elements data. It provides API
/// for creating and accessing registered elements data.
class TestExternalElementsDataManager {
using DataMap = llvm::StringMap<std::unique_ptr<TestExternalElementsData>>;

public:
/// Return the data registered for the given name, or nullptr if no data is
/// registered.
const TestExternalElementsData *getData(llvm::StringRef name) const;

/// Register an entry with the provided name, which may be modified if another
/// entry was already inserted with that name. Returns the inserted entry.
std::pair<DataMap::iterator, bool> insert(llvm::StringRef name);

/// Set the data for the given entry, which is expected to exist.
void setData(llvm::StringRef name, TestExternalElementsData &&data);

private:
llvm::StringMap<std::unique_ptr<TestExternalElementsData>> dataMap;
};
} // namespace test

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def Test_Dialect : Dialect {
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;

/// Returns the external elements data manager for this dialect.
TestExternalElementsDataManager &getExternalDataManager() {
return externalDataManager;
}

private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
Expand All @@ -49,6 +55,9 @@ def Test_Dialect : Dialect {
::llvm::SetVector<::mlir::Type> &stack) const;
void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
::llvm::SetVector<::mlir::Type> &stack) const;

/// An external data manager used to test external elements data.
TestExternalElementsDataManager externalDataManager;
}];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Support)
add_subdirectory(Rewrite)
Expand Down
13 changes: 13 additions & 0 deletions mlir/unittests/Parser/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_mlir_unittest(MLIRParserTests
ResourceTest.cpp

DEPENDS
MLIRTestInterfaceIncGen
)
target_include_directories(MLIRParserTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")

target_link_libraries(MLIRParserTests PRIVATE
MLIRIR
MLIRParser
MLIRTestDialect
)
75 changes: 75 additions & 0 deletions mlir/unittests/Parser/ResourceTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//===- ResourceTest.cpp -----------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "mlir/Parser/Parser.h"

#include "gmock/gmock.h"

using namespace mlir;

namespace {
TEST(MLIRParser, ResourceKeyConflict) {
std::string moduleStr = R"mlir(
"test.use1"() {attr = #test.e1di64_elements<blob1> : tensor<3xi64> } : () -> ()
{-#
dialect_resources: {
test: {
blob1: "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}
)mlir";
std::string moduleStr2 = R"mlir(
"test.use2"() {attr = #test.e1di64_elements<blob1> : tensor<3xi64> } : () -> ()
{-#
dialect_resources: {
test: {
blob1: "0x08000000040000000000000005000000000000000600000000000000"
}
}
#-}
)mlir";

MLIRContext context;
context.loadDialect<test::TestDialect>();

// Parse both modules into the same context so that we ensure the conflicting
// resources have been loaded.
OwningOpRef<ModuleOp> module1 =
parseSourceString<ModuleOp>(moduleStr, &context);
OwningOpRef<ModuleOp> module2 =
parseSourceString<ModuleOp>(moduleStr2, &context);
ASSERT_TRUE(module1 && module2);

// Merge the two modules so that we can test printing the remapped resources.
Block *block = module1->getBody();
block->getOperations().splice(block->end(),
module2->getBody()->getOperations());

// Check that conflicting resources were remapped.
std::string outputStr;
{
llvm::raw_string_ostream os(outputStr);
module1->print(os);
}
StringRef output(outputStr);
EXPECT_TRUE(
output.contains("\"test.use1\"() {attr = #test.e1di64_elements<blob1>"));
EXPECT_TRUE(output.contains(
"blob1: \"0x08000000010000000000000002000000000000000300000000000000\""));
EXPECT_TRUE(output.contains(
"\"test.use2\"() {attr = #test.e1di64_elements<blob1_1>"));
EXPECT_TRUE(output.contains(
"blob1_1: "
"\"0x08000000040000000000000005000000000000000600000000000000\""));
}
} // namespace