147 changes: 4 additions & 143 deletions mlir/include/mlir/IR/Identifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,151 +9,12 @@
#ifndef MLIR_IR_IDENTIFIER_H
#define MLIR_IR_IDENTIFIER_H

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringMapEntry.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir {
class Dialect;
class MLIRContext;

/// This class represents a uniqued string owned by an MLIRContext. Strings
/// represented by this type cannot contain nul characters, and may not have a
/// zero length.
///
/// This is a POD type with pointer size, so it should be passed around by
/// value. The underlying data is owned by MLIRContext and is thus immortal for
/// almost all clients.
///
/// An Identifier may be prefixed with a dialect namespace followed by a single
/// dot `.`. This is particularly useful when used as a key in a NamedAttribute
/// to differentiate a dependent attribute (specific to an operation) from a
/// generic attribute defined by the dialect (in general applicable to multiple
/// operations).
class Identifier {
using EntryType =
llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;

public:
/// Return an identifier for the specified string.
static Identifier get(const Twine &string, MLIRContext *context);

Identifier(const Identifier &) = default;
Identifier &operator=(const Identifier &other) = default;

/// Return a StringRef for the string.
StringRef strref() const { return entry->first(); }

/// Identifiers implicitly convert to StringRefs.
operator StringRef() const { return strref(); }

/// Return an std::string.
std::string str() const { return strref().str(); }

/// Return a null terminated C string.
const char *c_str() const { return entry->getKeyData(); }

/// Return a pointer to the start of the string data.
const char *data() const { return entry->getKeyData(); }

/// Return the number of bytes in this string.
unsigned size() const { return entry->getKeyLength(); }

/// Return the dialect loaded in the context for this identifier or nullptr if
/// this identifier isn't prefixed with a loaded dialect. For example the
/// `llvm.fastmathflags` identifier would return the LLVM dialect here,
/// assuming it is loaded in the context.
Dialect *getDialect();

/// Return the current MLIRContext associated with this identifier.
MLIRContext *getContext();

const char *begin() const { return data(); }
const char *end() const { return entry->getKeyData() + size(); }

bool operator==(Identifier other) const { return entry == other.entry; }
bool operator!=(Identifier rhs) const { return !(*this == rhs); }

void print(raw_ostream &os) const;
void dump() const;

const void *getAsOpaquePointer() const {
return static_cast<const void *>(entry);
}
static Identifier getFromOpaquePointer(const void *entry) {
return Identifier(static_cast<const EntryType *>(entry));
}

/// Compare the underlying StringRef.
int compare(Identifier rhs) const { return strref().compare(rhs.strref()); }

private:
/// This contains the bytes of the string, which is guaranteed to be nul
/// terminated.
const EntryType *entry;
explicit Identifier(const EntryType *entry) : entry(entry) {}
};

inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
identifier.print(os);
return os;
}

// Identifier/Identifier equality comparisons are defined inline.
inline bool operator==(Identifier lhs, StringRef rhs) {
return lhs.strref() == rhs;
}
inline bool operator!=(Identifier lhs, StringRef rhs) { return !(lhs == rhs); }

inline bool operator==(StringRef lhs, Identifier rhs) {
return rhs.strref() == lhs;
}
inline bool operator!=(StringRef lhs, Identifier rhs) { return !(lhs == rhs); }

// Make identifiers hashable.
inline llvm::hash_code hash_value(Identifier arg) {
// Identifiers are uniqued, so we can just hash the pointer they contain.
return llvm::hash_value(arg.getAsOpaquePointer());
}
/// NOTICE: Identifier is deprecated and usages of it should be replaced with
/// StringAttr.
using Identifier = StringAttr;
} // end namespace mlir

namespace llvm {
// Identifiers hash just like pointers, there is no need to hash the bytes.
template <>
struct DenseMapInfo<mlir::Identifier> {
static mlir::Identifier getEmptyKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static mlir::Identifier getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::Identifier val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Identifier lhs, mlir::Identifier rhs) {
return lhs == rhs;
}
};

/// The pointer inside of an identifier comes from a StringMap, so its alignment
/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
/// steal the low bits.
template <>
struct PointerLikeTypeTraits<mlir::Identifier> {
public:
static inline void *getAsVoidPointer(mlir::Identifier i) {
return const_cast<void *>(i.getAsOpaquePointer());
}
static inline mlir::Identifier getFromVoidPointer(void *p) {
return mlir::Identifier::getFromOpaquePointer(p);
}
static constexpr int NumLowBitsAvailable = 2;
};

} // end namespace llvm
#endif
1 change: 0 additions & 1 deletion mlir/include/mlir/IR/Location.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

namespace mlir {

class Identifier;
class Location;
class WalkResult;

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class OperationName {
Dialect *getDialect() const {
if (const auto *abstractOp = getAbstractOperation())
return &abstractOp->dialect;
return representation.get<Identifier>().getDialect();
return representation.get<Identifier>().getReferencedDialect();
}

/// Return the operation name with dialect name stripped, if it has one.
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/IR/StorageUniquerSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {

/// Get an instance of the concrete type from a void pointer.
static ConcreteT getFromOpaquePointer(const void *ptr) {
return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
: nullptr;
return ConcreteT((const typename BaseT::ImplType *)ptr);
}

protected:
Expand Down
2 changes: 0 additions & 2 deletions mlir/include/mlir/IR/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#include "llvm/ADT/StringMap.h"

namespace mlir {
class Identifier;
class Operation;

/// This class allows for representing and managing the symbol table used by
/// operations with the 'SymbolTable' trait. Inserting into and erasing from
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,

// Make Type hashable.
inline ::llvm::hash_code hash_value(Type arg) {
return ::llvm::hash_value(arg.impl);
return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
}

template <typename U> bool Type::isa() const {
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ class Any;

namespace mlir {
class AnalysisManager;
class Identifier;
class MLIRContext;
class Operation;
class Pass;
class PassInstrumentation;
class PassInstrumentor;
class StringAttr;

// TODO: Remove this when all usages have been replaced with StringAttr.
using Identifier = StringAttr;

namespace detail {
struct OpPassManagerImpl;
Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Support/StorageUniquer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,13 @@ class StorageUniquer {
/// Copy the provided string into memory managed by our bump pointer
/// allocator.
StringRef copyInto(StringRef str) {
auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
return StringRef(result.data(), str.size());
if (str.empty())
return StringRef();

char *result = allocator.Allocate<char>(str.size() + 1);
std::uninitialized_copy(str.begin(), str.end(), result);
result[str.size()] = 0;
return StringRef(result, str.size());
}

/// Allocate an instance of the provided type.
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Support/TypeID.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ class TypeID {
TypeID() : TypeID(get<void>()) {}

/// Comparison operations.
bool operator==(const TypeID &other) const {
inline bool operator==(const TypeID &other) const {
return storage == other.storage;
}
bool operator!=(const TypeID &other) const { return !(*this == other); }
inline bool operator!=(const TypeID &other) const {
return !(*this == other);
}

/// Construct a type info object for the given type T.
template <typename T>
Expand Down Expand Up @@ -94,7 +96,7 @@ class TypeID {

/// Enable hashing TypeID.
inline ::llvm::hash_code hash_value(TypeID id) {
return llvm::hash_value(id.storage);
return DenseMapInfo<const TypeID::Storage *>::getHashValue(id.storage);
}

namespace detail {
Expand Down Expand Up @@ -166,11 +168,11 @@ TypeID TypeID::get() {

namespace llvm {
template <> struct DenseMapInfo<mlir::TypeID> {
static mlir::TypeID getEmptyKey() {
static inline mlir::TypeID getEmptyKey() {
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::TypeID::getFromOpaquePointer(pointer);
}
static mlir::TypeID getTombstoneKey() {
static inline mlir::TypeID getTombstoneKey() {
void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::TypeID::getFromOpaquePointer(pointer);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class LLVMTranslationInterface
amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.first.getDialect())) {
getInterfaceFor(attribute.first.getReferencedDialect())) {
return iface->amendOperation(op, attribute, moduleTranslation);
}
return success();
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1845,7 +1845,8 @@ class PyOpAttributeMap {
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(
namedAttr.attribute,
std::string(mlirIdentifierStr(namedAttr.name).data));
std::string(mlirIdentifierStr(namedAttr.name).data,
mlirIdentifierStr(namedAttr.name).length));
}

void dunderSetItem(const std::string &name, PyAttribute attr) {
Expand Down Expand Up @@ -2601,7 +2602,8 @@ void mlir::python::populateIRCore(py::module &m) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(
mlirIdentifierStr(self.namedAttr.name).data);
py::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length));
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/CAPI/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}

MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
}

MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(str), unwrap(type)));
return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
}

MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,

MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
MlirOperation operation) {
return wrap(unwrap(symbolTable)->insert(unwrap(operation)));
return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
}

void mlirSymbolTableErase(MlirSymbolTable symbolTable,
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
} else {
auto id = entry.getKey().get<Identifier>();
if (!ids.insert(id).second)
return emitError() << "repeated layout entry key: " << id;
return emitError() << "repeated layout entry key: " << id.getValue();
}
}
return success();
Expand Down Expand Up @@ -221,7 +221,7 @@ combineOneSpec(DataLayoutSpecInterface spec,

for (const auto &kvp : newEntriesForID) {
Identifier id = kvp.second.getKey().get<Identifier>();
Dialect *dialect = id.getDialect();
Dialect *dialect = id.getReferencedDialect();
if (!entriesForID.count(id)) {
entriesForID[id] = kvp.second;
continue;
Expand Down Expand Up @@ -377,6 +377,6 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return success();
}

return op->emitError() << "attribute '" << attr.first
return op->emitError() << "attribute '" << attr.first.getValue()
<< "' not supported by dialect";
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
// Copy over unknown attributes. They might be load bearing for some flow.
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
for (NamedAttribute kv : genericOp->getAttrs()) {
if (!llvm::is_contained(odsAttrs, kv.first.c_str())) {
if (!llvm::is_contained(odsAttrs, kv.first.getValue())) {
newOp->setAttr(kv.first, kv.second);
}
}
Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@
using namespace mlir;
using namespace mlir::detail;

void Identifier::print(raw_ostream &os) const { os << str(); }

void Identifier::dump() const { print(llvm::errs()); }

void OperationName::print(raw_ostream &os) const { os << getStringRef(); }

void OperationName::dump() const { print(llvm::errs()); }
Expand Down Expand Up @@ -1339,7 +1335,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
})
.Case<FileLineColLoc>([&](FileLineColLoc loc) {
if (pretty) {
os << loc.getFilename();
os << loc.getFilename().getValue();
} else {
os << "\"";
printEscapedString(loc.getFilename(), os);
Expand Down Expand Up @@ -1693,7 +1689,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
<< llvm::toHex(opaqueAttr.getValue()) << "\">";
}

Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/IR/AttributeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,41 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
ArrayRef<StringRef> data;
};

//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//

struct StringAttrStorage : public AttributeStorage {
StringAttrStorage(StringRef value, Type type)
: AttributeStorage(type), value(value), referencedDialect(nullptr) {}

/// The hash key is a tuple of the parameter types.
using KeyTy = std::pair<StringRef, Type>;
bool operator==(const KeyTy &key) const {
return value == key.first && getType() == key.second;
}
static ::llvm::hash_code hashKey(const KeyTy &key) {
return DenseMapInfo<KeyTy>::getHashValue(key);
}

/// Define a construction method for creating a new instance of this
/// storage.
static StringAttrStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<StringAttrStorage>())
StringAttrStorage(allocator.copyInto(key.first), key.second);
}

/// Initialize the storage given an MLIRContext.
void initialize(MLIRContext *context);

/// The raw string value.
StringRef value;
/// If the string value contains a dialect namespace prefix (e.g.
/// dialect.blah), this is the dialect referenced.
Dialect *referencedDialect;
};

} // namespace detail
} // namespace mlir

Expand Down
27 changes: 2 additions & 25 deletions mlir/lib/IR/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,10 @@
using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// AttributeStorage
//===----------------------------------------------------------------------===//

AttributeStorage::AttributeStorage(Type type)
: type(type.getAsOpaquePointer()) {}
AttributeStorage::AttributeStorage() : type(nullptr) {}

Type AttributeStorage::getType() const {
return Type::getFromOpaquePointer(type);
}
void AttributeStorage::setType(Type newType) {
type = newType.getAsOpaquePointer();
}

//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//

/// Return the type of this attribute.
Type Attribute::getType() const { return impl->getType(); }

/// Return the context this attribute belongs to.
MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }

Expand All @@ -42,13 +24,8 @@ MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
//===----------------------------------------------------------------------===//

bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
return strcmp(lhs.first.data(), rhs.first.data()) < 0;
return lhs.first.compare(rhs.first) < 0;
}
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
// This is correct even when attr.first.data()[name.size()] is not a zero
// string terminator, because we only care about a less than comparison.
// This can't use memcmp, because it doesn't guarantee that it will stop
// reading both buffers if one is shorter than the other, even if there is
// a difference.
return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
return lhs.first.getValue().compare(rhs) < 0;
}
8 changes: 7 additions & 1 deletion mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ StringAttr StringAttr::get(const Twine &twine, Type type) {
return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
}

StringRef StringAttr::getValue() const { return getImpl()->value; }

Dialect *StringAttr::getReferencedDialect() const {
return getImpl()->referencedDialect;
}

//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1250,7 +1256,7 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//

bool OpaqueElementsAttr::decode(ElementsAttr &result) {
Dialect *dialect = getDialect().getDialect();
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
auto *interface =
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/IR/BuiltinDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ static LogicalResult verify(ModuleOp op) {
attr.first.strref()))
return op.emitOpError() << "can only contain attributes with "
"dialect-prefixed names, found: '"
<< attr.first << "'";
<< attr.first.getValue() << "'";
}

// Check that there is at most one data layout spec attribute.
Expand All @@ -266,7 +266,8 @@ static LogicalResult verify(ModuleOp op) {
op.emitOpError() << "expects at most one data layout attribute";
diag.attachNote() << "'" << layoutSpecAttrName
<< "' is a data layout attribute";
diag.attachNote() << "'" << na.first << "' is a data layout attribute";
diag.attachNote() << "'" << na.first.getValue()
<< "' is a data layout attribute";
}
layoutSpecAttrName = na.first.strref();
layoutSpec = spec;
Expand Down
10 changes: 3 additions & 7 deletions mlir/lib/IR/Diagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -109,11 +108,8 @@ Diagnostic &Diagnostic::operator<<(Twine &&val) {
return *this;
}

/// Stream in an Identifier.
Diagnostic &Diagnostic::operator<<(Identifier val) {
// An identifier is stored in the context, so we don't need to worry about the
// lifetime of its data.
arguments.push_back(DiagnosticArgument(val.strref()));
Diagnostic &Diagnostic::operator<<(StringAttr val) {
arguments.push_back(DiagnosticArgument(val));
return *this;
}

Expand Down Expand Up @@ -469,7 +465,7 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
// the constructor of SMDiagnostic that takes a location.
std::string locStr;
llvm::raw_string_ostream locOS(locStr);
locOS << fileLoc->getFilename() << ":" << fileLoc->getLine() << ":"
locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
<< fileLoc->getColumn();
llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
diag.print(nullptr, os);
Expand Down
117 changes: 35 additions & 82 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
Expand All @@ -33,6 +32,7 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -227,14 +227,6 @@ class MLIRContextImpl {
/// An action manager for use within the context.
DebugActionManager debugActionManager;

//===--------------------------------------------------------------------===//
// Identifier uniquing
//===--------------------------------------------------------------------===//

// Identifier allocator and mutex for thread safety.
llvm::BumpPtrAllocator identifierAllocator;
llvm::sys::SmartRWMutex<true> identifierMutex;

//===--------------------------------------------------------------------===//
// Diagnostics
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -289,12 +281,6 @@ class MLIRContextImpl {
/// operations.
llvm::StringMap<AbstractOperation> registeredOperations;

/// Identifiers are uniqued by string value and use the internal string set
/// for storage.
llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
llvm::BumpPtrAllocator &>
identifiers;

/// An allocator used for AbstractAttribute and AbstractType objects.
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

Expand Down Expand Up @@ -349,10 +335,15 @@ class MLIRContextImpl {
DictionaryAttr emptyDictionaryAttr;
StringAttr emptyStringAttr;

/// Map of string attributes that may reference a dialect, that are awaiting
/// that dialect to be loaded.
llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
DenseMap<StringRef, SmallVector<StringAttrStorage *>>
dialectReferencingStrAttrs;

public:
MLIRContextImpl(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled),
identifiers(identifierAllocator) {
: threadingIsEnabled(threadingIsEnabled) {
if (threadingIsEnabled) {
ownedThreadPool = std::make_unique<llvm::ThreadPool>();
threadPool = ownedThreadPool.get();
Expand Down Expand Up @@ -541,12 +532,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
// Refresh all the identifiers dialect field, this catches cases where a
// dialect may be loaded after identifier prefixed with this dialect name
// were already created.
llvm::SmallString<32> dialectPrefix(dialectNamespace);
dialectPrefix.push_back('.');
for (auto &identifierEntry : impl.identifiers)
if (identifierEntry.second.is<MLIRContext *>() &&
identifierEntry.first().startswith(dialectPrefix))
identifierEntry.second = dialect.get();
auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
for (StringAttrStorage *storage : stringAttrsIt->second)
storage->referencedDialect = dialect.get();
impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
}

// Actually register the interfaces with delayed registration.
impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
Expand Down Expand Up @@ -784,7 +775,8 @@ void AbstractOperation::insert(
MutableArrayRef<Identifier> cachedAttrNames;
if (!attrNames.empty()) {
cachedAttrNames = MutableArrayRef<Identifier>(
impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
impl.abstractDialectSymbolAllocator.Allocate<Identifier>(
attrNames.size()),
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
Expand Down Expand Up @@ -840,63 +832,6 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
return it->second;
}

//===----------------------------------------------------------------------===//
// Identifier uniquing
//===----------------------------------------------------------------------===//

/// Return an identifier for the specified string.
Identifier Identifier::get(const Twine &string, MLIRContext *context) {
SmallString<32> tempStr;
StringRef str = string.toStringRef(tempStr);

// Check invariants after seeing if we already have something in the
// identifier table - if we already had it in the table, then it already
// passed invariant checks.
assert(!str.empty() && "Cannot create an empty identifier");
assert(!str.contains('\0') &&
"Cannot create an identifier with a nul character");

auto getDialectOrContext = [&]() {
PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
auto dialectNamePair = str.split('.');
if (!dialectNamePair.first.empty())
if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
dialectOrContext = dialect;
return dialectOrContext;
};

auto &impl = context->getImpl();
if (!context->isMultithreadingEnabled()) {
auto insertedIt = impl.identifiers.insert({str, nullptr});
if (insertedIt.second)
insertedIt.first->second = getDialectOrContext();
return Identifier(&*insertedIt.first);
}

// Check for an existing identifier in read-only mode.
{
llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.find(str);
if (it != impl.identifiers.end())
return Identifier(&*it);
}

// Acquire a writer-lock so that we can safely create the new instance.
llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
return Identifier(&*it);
}

Dialect *Identifier::getDialect() {
return entry->second.dyn_cast<Dialect *>();
}

MLIRContext *Identifier::getContext() {
if (Dialect *dialect = getDialect())
return dialect->getContext();
return entry->second.get<MLIRContext *>();
}

//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -995,7 +930,7 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
MLIRContext *ctx,
TypeID attrID) {
storage->initialize(AbstractAttribute::lookup(attrID, ctx));
storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));

// If the attribute did not provide a type, then default to NoneType.
if (!storage->getType())
Expand All @@ -1019,6 +954,24 @@ DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
return context->getImpl().emptyDictionaryAttr;
}

void StringAttrStorage::initialize(MLIRContext *context) {
// Check for a dialect namespace prefix, if there isn't one we don't need to
// do any additional initialization.
auto dialectNamePair = value.split('.');
if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
return;

// If one exists, we check to see if this dialect is loaded. If it is, we set
// the dialect now, if it isn't we record this storage for initialization
// later if the dialect ever gets loaded.
if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
return;

MLIRContextImpl &impl = context->getImpl();
llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
}

/// Return an empty string.
StringAttr StringAttr::get(MLIRContext *context) {
return context->getImpl().emptyStringAttr;
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {

void NamedAttrList::push_back(NamedAttribute newAttribute) {
assert(newAttribute.second && "unexpected null attribute");
if (isSorted())
dictionarySorted.setInt(
attrs.empty() ||
strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0);
if (isSorted()) {
dictionarySorted.setInt(attrs.empty() ||
attrs.back().first.compare(newAttribute.first) < 0);
}
dictionarySorted.setPointer(nullptr);
attrs.push_back(newAttribute);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ LogicalResult OperationVerifier::verifyOperation(
/// Verify that all of the attributes are okay.
for (auto attr : op.getAttrs()) {
// Check for any optional dialect specific attributes.
if (auto *dialect = attr.first.getDialect())
if (auto *dialect = attr.first.getReferencedDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr)))
return failure();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Interfaces/DataLayoutInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,

for (const auto &kvp : ids) {
Identifier identifier = kvp.second.getKey().get<Identifier>();
Dialect *dialect = identifier.getDialect();
Dialect *dialect = identifier.getReferencedDialect();

// Ignore attributes that belong to an unknown dialect, the dialect may
// actually implement the relevant interface but we don't know about that.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Parser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("expected attribute name");
if (!seenKeys.insert(*nameId).second)
return emitError("duplicate key '")
<< *nameId << "' in dictionary attribute";
<< nameId->getValue() << "' in dictionary attribute";
consumeToken();

// Lazy load a dialect in the context if there is a possible namespace.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
Optional<NamedAttribute> duplicate = opState.attributes.findDuplicate();
if (duplicate)
return emitError(getNameLoc(), "attribute '")
<< duplicate->first
<< duplicate->first.getValue()
<< "' occurs more than once in the attribute list";
return success();
}
Expand Down
17 changes: 10 additions & 7 deletions mlir/lib/Support/StorageUniquer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,26 @@ class ParametricStorageUniquer {
};

/// Storage info for derived TypeStorage objects.
struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
static HashedStorage getEmptyKey() {
struct StorageKeyInfo {
static inline HashedStorage getEmptyKey() {
return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
}
static HashedStorage getTombstoneKey() {
static inline HashedStorage getTombstoneKey() {
return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
}

static unsigned getHashValue(const HashedStorage &key) {
static inline unsigned getHashValue(const HashedStorage &key) {
return key.hashValue;
}
static inline unsigned getHashValue(const LookupKey &key) {
return key.hashValue;
}
static unsigned getHashValue(LookupKey key) { return key.hashValue; }

static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
static inline bool isEqual(const HashedStorage &lhs,
const HashedStorage &rhs) {
return lhs.storage == rhs.storage;
}
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
static inline bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
// Invoke the equality function on the lookup key.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
if (llvm::is_contained(exclude, attr.first.strref()))
return success();
os << "/* " << attr.first << " */";
os << "/* " << attr.first.getValue() << " */";
if (failed(emitAttribute(op.getLoc(), attr.second)))
return failure();
return success();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/ViewOpGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
if (printAttrs) {
os << "\n";
for (const NamedAttribute &attr : op->getAttrs()) {
os << '\n' << attr.first << ": ";
os << '\n' << attr.first.getValue() << ": ";
emitMlirAttr(os, attr.second);
}
}
Expand Down
8 changes: 3 additions & 5 deletions mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
.setDistributionOptions(cyclicNprocsEqNiters),
.setDistributionOptions(cyclicNprocsEqNiters),
LinalgTransformationFilter(
Identifier::get("tensors_distribute1", context),
Identifier::get("tensors_after_distribute1", context)));
Expand All @@ -508,8 +508,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
SmallVector<RewritePatternSet, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
Expand All @@ -519,8 +518,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
}
{
// Canonicalization patterns
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ void TestDerivedAttributeDriver::runOnFunction() {
if (!dAttr)
return;
for (auto d : dAttr)
dOp.emitRemark() << d.first << " = " << d.second;
dOp.emitRemark() << d.first.getValue() << " = " << d.second;
});
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/lib/IR/TestPrintNesting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ struct TestPrintNestingPass
if (!op->getAttrs().empty()) {
printIndent() << op->getAttrs().size() << " attributes:\n";
for (NamedAttribute attr : op->getAttrs())
printIndent() << " - '" << attr.first << "' : '" << attr.second
<< "'\n";
printIndent() << " - '" << attr.first.getValue() << "' : '"
<< attr.second << "'\n";
}

// Recurse into each of the regions attached to the operation.
Expand Down