478 changes: 98 additions & 380 deletions mlir/include/mlir/IR/BuiltinAttributes.h

Large diffs are not rendered by default.

497 changes: 494 additions & 3 deletions mlir/include/mlir/IR/BuiltinAttributes.td

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2673,6 +2673,8 @@ class TypeDef<Dialect dialect, string name,
class AttrOrTypeParameter<string type, string desc> {
// Custom memory allocation code for storage constructor.
code allocator = ?;
// Custom comparator used to compare two instances for equality.
code comparator = ?;
// The C++ type of this parameter.
string cppType = type;
// One-line human-readable description of the argument.
Expand All @@ -2689,6 +2691,12 @@ class StringRefParameter<string desc = ""> :
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}

// For APFloats, which require comparison.
class APFloatParameter<string desc> :
AttrOrTypeParameter<"::llvm::APFloat", desc> {
let comparator = "$_lhs.bitwiseIsEqual($_rhs)";
}

// For standard ArrayRefs, which require allocation.
class ArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
Expand All @@ -2715,7 +2723,8 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
// This is a special parameter used for AttrDefs that represents a `mlir::Type`
// that is also used as the value `Type` of the attribute. Only one parameter
// of the attribute may be of this type.
class AttributeSelfTypeParameter<string desc> :
AttrOrTypeParameter<"::mlir::Type", desc> {}
class AttributeSelfTypeParameter<string desc,
string derivedType = "::mlir::Type"> :
AttrOrTypeParameter<derivedType, desc> {}

#endif // OP_BASE
3 changes: 3 additions & 0 deletions mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ class AttrOrTypeParameter {
// If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const;

// If specified, get the custom comparator code for this parameter.
Optional<StringRef> getComparator() const;

// Get the C++ type of this parameter.
StringRef getCppType() const;

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,7 @@ static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
// accept the string "elided". The first string must be a registered dialect
// name and the latter must be a hex constant.
static void printElidedElementsAttr(raw_ostream &os) {
os << R"(opaque<"", "0xDEADBEEF">)";
os << R"(opaque<"_", "0xDEADBEEF">)";
}

void ModulePrinter::printAttribute(Attribute attr,
Expand Down Expand Up @@ -1610,8 +1610,8 @@ void ModulePrinter::printAttribute(Attribute attr,
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
<< llvm::toHex(opaqueAttr.getValue()) << "\">";
}

} else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
Expand Down
196 changes: 14 additions & 182 deletions mlir/lib/IR/AttributeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,113 +27,6 @@
namespace mlir {
namespace detail {

/// An attribute representing a floating point value.
struct FloatAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
using KeyTy = std::pair<Type, APFloat>;

FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
size_t numObjects)
: AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}

/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key.first == getType() && key.second.bitwiseIsEqual(getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
}

/// Construct a key with a type and double.
static KeyTy getKey(Type type, double value) {
if (type.isF64())
return KeyTy(type, APFloat(value));

// This handles, e.g., F16 because there is no APFloat constructor for it.
bool unused;
APFloat val(value);
val.convert(type.cast<FloatType>().getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
return KeyTy(type, val);
}

/// Construct a new storage instance.
static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
const auto &apint = key.second.bitcastToAPInt();

// Here one word's bitwidth equals to that of uint64_t.
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());

auto byteSize =
FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
auto result = ::new (rawMem) FloatAttributeStorage(
key.second.getSemantics(), key.first, elements.size());
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return result;
}

/// Returns an APFloat representing the stored value.
APFloat getValue() const {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}

const llvm::fltSemantics &semantics;
size_t numObjects;
};

/// An attribute representing an integral value.
struct IntegerAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
using KeyTy = std::pair<Type, APInt>;

IntegerAttributeStorage(Type type, size_t numObjects)
: AttributeStorage(type), numObjects(numObjects) {
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
}

/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == KeyTy(getType(), getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
}

/// Construct a new storage instance.
static IntegerAttributeStorage *
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
Type type;
APInt value;
std::tie(type, value) = key;

auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
auto size =
IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return result;
}

/// Returns an APInt representing the stored value.
APInt getValue() const {
if (getType().isIndex())
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
return APInt(getType().getIntOrFloatBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}

size_t numObjects;
};

//===----------------------------------------------------------------------===//
// Elements Attributes
//===----------------------------------------------------------------------===//
Expand All @@ -158,10 +51,9 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
};

/// An attribute representing a reference to a dense vector or tensor object.
struct DenseIntOrFPElementsAttributeStorage
: public DenseElementsAttributeStorage {
DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}

struct KeyTy {
Expand Down Expand Up @@ -287,7 +179,7 @@ struct DenseIntOrFPElementsAttributeStorage
}

/// Construct a new storage instance.
static DenseIntOrFPElementsAttributeStorage *
static DenseIntOrFPElementsAttrStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator with a
// 64-bit alignment.
Expand All @@ -303,19 +195,18 @@ struct DenseIntOrFPElementsAttributeStorage
copy = ArrayRef<char>(rawData, data.size());
}

return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat);
}

ArrayRef<char> data;
};

/// An attribute representing a reference to a dense vector or tensor object
/// containing strings.
struct DenseStringElementsAttributeStorage
: public DenseElementsAttributeStorage {
DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
bool isSplat = false)
struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data,
bool isSplat = false)
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}

struct KeyTy {
Expand Down Expand Up @@ -385,14 +276,14 @@ struct DenseStringElementsAttributeStorage
}

/// Construct a new storage instance.
static DenseStringElementsAttributeStorage *
static DenseStringElementsAttrStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator with a
// 64-bit alignment.
ArrayRef<StringRef> copy, data = key.data;
if (data.empty()) {
return new (allocator.allocate<DenseStringElementsAttributeStorage>())
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
return new (allocator.allocate<DenseStringElementsAttrStorage>())
DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
}

int numEntries = key.isSplat ? 1 : data.size();
Expand Down Expand Up @@ -421,72 +312,13 @@ struct DenseStringElementsAttributeStorage
copy =
ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);

return new (allocator.allocate<DenseStringElementsAttributeStorage>())
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
return new (allocator.allocate<DenseStringElementsAttrStorage>())
DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
}

ArrayRef<StringRef> data;
};

/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Type, Dialect *, StringRef>;

OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
: AttributeStorage(type), dialect(dialect), bytes(bytes) {}

/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == std::make_tuple(getType(), dialect, bytes);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
std::get<2>(key));
}

/// Construct a new storage instance.
static OpaqueElementsAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// TODO: Provide a way to avoid copying content of large opaque
// tensors This will likely require a new reference attribute kind.
return new (allocator.allocate<OpaqueElementsAttributeStorage>())
OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
allocator.copyInto(std::get<2>(key)));
}

Dialect *dialect;
StringRef bytes;
};

/// An attribute representing a reference to a sparse vector or tensor object.
struct SparseElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;

SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
DenseElementsAttr values)
: AttributeStorage(type), indices(indices), values(values) {}

/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == std::make_tuple(getType(), indices, values);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
std::get<2>(key));
}

/// Construct a new storage instance.
static SparseElementsAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
return new (allocator.allocate<SparseElementsAttributeStorage>())
SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
std::get<2>(key));
}

DenseIntElementsAttr indices;
DenseElementsAttr values;
};
} // namespace detail
} // namespace mlir

Expand Down
158 changes: 34 additions & 124 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,26 +202,6 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
// FloatAttr
//===----------------------------------------------------------------------===//

FloatAttr FloatAttr::get(Type type, double value) {
return Base::get(type.getContext(), type, value);
}

FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, double value) {
return Base::getChecked(emitError, type.getContext(), type, value);
}

FloatAttr FloatAttr::get(Type type, const APFloat &value) {
return Base::get(type.getContext(), type, value);
}

FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value) {
return Base::getChecked(emitError, type.getContext(), type, value);
}

APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }

double FloatAttr::getValueAsDouble() const {
return getValueAsDouble(getValue());
}
Expand All @@ -234,25 +214,11 @@ double FloatAttr::getValueAsDouble(APFloat value) {
return value.convertToDouble();
}

/// Verify construction invariants.
static LogicalResult
verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
Type type) {
if (!type.isa<FloatType>())
return emitError() << "expected floating point type";
return success();
}

LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, double value) {
return verifyFloatTypeInvariants(emitError, type);
}

LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value) {
Type type, APFloat value) {
// Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(emitError, type)))
return failure();
if (!type.isa<FloatType>())
return emitError() << "expected floating point type";

// Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
Expand All @@ -279,72 +245,47 @@ StringRef SymbolRefAttr::getLeafReference() const {
// IntegerAttr
//===----------------------------------------------------------------------===//

IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
return BoolAttr::get(type.getContext(), value.getBoolValue());
return Base::get(type.getContext(), type, value);
}

IntegerAttr IntegerAttr::get(Type type, int64_t value) {
// This uses 64 bit APInts by default for index type.
if (type.isIndex())
return get(type, APInt(IndexType::kInternalStorageBitWidth, value));

auto intType = type.cast<IntegerType>();
return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
}

APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }

int64_t IntegerAttr::getInt() const {
assert((getImpl()->getType().isIndex() ||
getImpl()->getType().isSignlessInteger()) &&
assert((getType().isIndex() || getType().isSignlessInteger()) &&
"must be signless integer");
return getValue().getSExtValue();
}

int64_t IntegerAttr::getSInt() const {
assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
assert(getType().isSignedInteger() && "must be signed integer");
return getValue().getSExtValue();
}

uint64_t IntegerAttr::getUInt() const {
assert(getImpl()->getType().isUnsignedInteger() &&
"must be unsigned integer");
assert(getType().isUnsignedInteger() && "must be unsigned integer");
return getValue().getZExtValue();
}

static LogicalResult
verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
Type type) {
if (type.isa<IntegerType, IndexType>())
return success();
return emitError() << "expected integer or index type";
}

LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, int64_t value) {
return verifyIntegerTypeInvariants(emitError, type);
}

LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APInt &value) {
if (failed(verifyIntegerTypeInvariants(emitError, type)))
return failure();
if (auto integerType = type.dyn_cast<IntegerType>())
Type type, APInt value) {
if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
if (integerType.getWidth() != value.getBitWidth())
return emitError() << "integer type bit width (" << integerType.getWidth()
<< ") doesn't match value bit width ("
<< value.getBitWidth() << ")";
return success();
return success();
}
if (type.isa<IndexType>())
return success();
return emitError() << "expected integer or index type";
}

BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
return attr.cast<BoolAttr>();
}

//===----------------------------------------------------------------------===//
// BoolAttr

bool BoolAttr::getValue() const {
auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
return storage->getValue().getBoolValue();
auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
return storage->value.getBoolValue();
}

bool BoolAttr::classof(Attribute attr) {
Expand Down Expand Up @@ -987,11 +928,11 @@ auto DenseElementsAttr::getComplexFloatValues() const

/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
}

ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
}

/// Return a new DenseElementsAttr that has the same data as the current
Expand Down Expand Up @@ -1021,15 +962,6 @@ DenseElementsAttr DenseElementsAttr::mapValues(
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
}

//===----------------------------------------------------------------------===//
// DenseStringElementsAttr
//===----------------------------------------------------------------------===//

DenseStringElementsAttr
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
return Base::get(type.getContext(), type, values, (values.size() == 1));
}

//===----------------------------------------------------------------------===//
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1254,59 +1186,37 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//

OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
StringRef bytes) {
assert(TensorType::isValidElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
return Base::get(type.getContext(), type, dialect, bytes);
}

StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }

/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
return Attribute();
}

Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }

bool OpaqueElementsAttr::decode(ElementsAttr &result) {
auto *d = getDialect();
if (!d)
Dialect *dialect = getDialect().getDialect();
if (!dialect)
return true;
auto *interface =
d->getRegisteredInterface<DialectDecodeAttributesInterface>();
dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
if (!interface)
return true;
return failed(interface->decode(*this, result));
}

LogicalResult
OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef value,
ShapedType type) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitError() << "invalid dialect namespace '" << dialect << "'";
return success();
}

//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//

SparseElementsAttr SparseElementsAttr::get(ShapedType type,
DenseElementsAttr indices,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), type,
indices.cast<DenseIntElementsAttr>(), values);
}

DenseIntElementsAttr SparseElementsAttr::getIndices() const {
return getImpl()->indices;
}

DenseElementsAttr SparseElementsAttr::getValues() const {
return getImpl()->values;
}

/// Return the value of the element at the given index.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
Expand Down
12 changes: 4 additions & 8 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,13 @@ MLIRContext::MLIRContext(const DialectRegistry &registry)
//// Attributes.
//// Note: These must be registered after the types as they may generate one
//// of the above types internally.
/// Unknown Location Attribute.
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
/// Bool Attributes.
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
this, impl->int1Ty, APInt(/*numBits=*/1, false))
.cast<BoolAttr>();
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
this, impl->int1Ty, APInt(/*numBits=*/1, true))
.cast<BoolAttr>();
impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
/// Unit Attribute.
impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
/// Unknown Location Attribute.
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
/// The empty dictionary attribute.
impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);

Expand Down
13 changes: 2 additions & 11 deletions mlir/lib/Parser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,16 +862,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
if (getToken().isNot(Token::string))
return (emitError("expected dialect namespace"), nullptr);

auto name = getToken().getStringValue();
// Lazy load a dialect in the context if there is a possible namespace.
Dialect *dialect = builder.getContext()->getOrLoadDialect(name);

// TODO: Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
if (!dialect)
return (emitError("no registered dialect with namespace '" + name + "'"),
nullptr);
std::string name = getToken().getStringValue();
consumeToken(Token::string);

if (parseToken(Token::comma, "expected ','"))
Expand All @@ -888,7 +879,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
std::string data;
if (parseElementAttrHexValues(*this, hexTok, data))
return nullptr;
return OpaqueElementsAttr::get(dialect, type, data);
return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data);
}

/// Shaped type for elements attribute.
Expand Down
30 changes: 11 additions & 19 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
}
builders.emplace_back(builder);
}
} else if (skipDefaultBuilders()) {
PrintFatalError(
def->getLoc(),
"default builders are skipped and no custom builders provided");
}
}

Expand Down Expand Up @@ -177,22 +173,18 @@ Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsOptionalString("allocator");
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}

if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
llvm::RecordVal *code = param->getDef()->getValue("allocator");
if (!code)
return Optional<StringRef>();
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
return ci->getValue();
if (isa<llvm::UnsetInit>(code->getValue()))
return Optional<StringRef>();

llvm::PrintFatalError(
param->getDef()->getLoc(),
"Record `" + def->getArgName(index)->getValue() +
"', field `printer' does not have a code initializer!");
}

Optional<StringRef> AttrOrTypeParameter::getComparator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsOptionalString("comparator");
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
fprintf(stderr, "\n");
// clang-format off
// CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
// CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
// clang-format on

mlirOpPrintingFlagsDestroy(flags);
Expand Down
13 changes: 3 additions & 10 deletions mlir/test/IR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -766,21 +766,14 @@ func @elementsattr_malformed_opaque() -> () {
func @elementsattr_malformed_opaque1() -> () {
^bb0:
"foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
"foo"(){bar = opaque<"_", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
}
// -----
func @elementsattr_malformed_opaque2() -> () {
^bb0:
"foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
}
// -----
func @elementsattr_malformed_opaque3() -> () {
^bb0:
"foo"(){bar = opaque<"t", "0xabc"> : tensor<1xi8>} : () -> () // expected-error {{no registered dialect with namespace 't'}}
"foo"(){bar = opaque<"_", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
}
// -----
Expand Down Expand Up @@ -881,7 +874,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined s
func @complex_loops() {
affine.for %i1 = 1 to 100 {
// expected-error @+1 {{expected '"' in string literal}}
"opaqueIntTensor"(){bar = opaque<"", "0x686]> : tensor<2x1x4xi32>} : () -> ()
"opaqueIntTensor"(){bar = opaque<"_", "0x686]> : tensor<2x1x4xi32>} : () -> ()
// -----
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/IR/pretty-attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
// tensor which passes don't look at directly, this isn't an issue.
// RUN: mlir-opt %s -mlir-elide-elementsattrs-if-larger=2 | mlir-opt

// CHECK: opaque<"", "0xDEADBEEF"> : tensor<3xi32>
// CHECK: opaque<"_", "0xDEADBEEF"> : tensor<3xi32>
"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()

// CHECK: dense<[1, 2]> : tensor<2xi32>
"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()

// CHECK: opaque<"", "0xDEADBEEF"> : vector<1x1x1xf16>
// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x1xf16>
"test.sparse_attr"() {foo.sparse_attr = sparse<[[1, 2, 3]], -2.0> : vector<1x1x1xf16>} : () -> ()

// CHECK: opaque<"", "0xDEADBEEF"> : tensor<100xf32>
"test.opaque_attr"() {foo.opaque_attr = opaque<"", "0xEBFE"> : tensor<100xf32> } : () -> ()
// CHECK: opaque<"_", "0xDEADBEEF"> : tensor<100xf32>
"test.opaque_attr"() {foo.opaque_attr = opaque<"_", "0xEBFE"> : tensor<100xf32> } : () -> ()

// CHECK: dense<1> : tensor<3xi32>
"test.dense_splat"() {foo.dense_attr = dense<1> : tensor<3xi32>} : () -> ()
24 changes: 17 additions & 7 deletions mlir/test/mlir-tblgen/attrdefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
ins
"int":$widthOfSomething,
"::mlir::test::SimpleTypeA": $exampleTdType,
"SomeCppStruct": $exampleCppType,
APFloatParameter<"">: $apFloat,
ArrayRefParameter<"int", "Matrix dimensions">:$dims,
AttributeSelfTypeParameter<"">:$inner
);

let genVerifyDecl = 1;

// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return ::llvm::StringLiteral("cmpnd_a");
// DECL: }
Expand All @@ -71,21 +71,31 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
// DECL: ::llvm::APFloat getApFloat() const;

// Check that AttributeSelfTypeParameter is handled properly.
// DEF-LABEL: struct CompoundAAttrStorage
// DEF: CompoundAAttrStorage (
// DEF-NEXT: : ::mlir::AttributeStorage(inner),

// DEF: bool operator==(const KeyTy &key) const {
// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType());
// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
// DEF-NEXT: return false;
// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
// DEF-NEXT: return false;
// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
// DEF-NEXT: return false;
// DEF-NEXT: if (!(dims == std::get<3>(key)))
// DEF-NEXT: return false;
// DEF-NEXT: if (!(getType() == std::get<4>(key)))
// DEF-NEXT: return false;
// DEF-NEXT: return true;

// DEF: static CompoundAAttrStorage *construct
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner);
// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);

// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); }
// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType().cast<::mlir::Type>(); }
}

def C_IndexAttr : TestAttr<"Index"> {
Expand Down
69 changes: 39 additions & 30 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,22 +432,16 @@ static ::mlir::LogicalResult generated{0}Printer(
/// {1}: Storage class c++ name.
/// {2}: Parameters parameters.
/// {3}: Parameter initializer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
/// {6}: The name of the base value type, e.g. Attribute or Type.
/// {4}: Parameter types.
/// {5}: The name of the base value type, e.g. Attribute or Type.
static const char *const defStorageClassBeginStr = R"(
namespace {0} {{
struct {1} : public ::mlir::{6}Storage {{
struct {1} : public ::mlir::{5}Storage {{
{1} ({2})
: {3} {{ }
/// The hash key is a tuple of the parameter types.
using KeyTy = std::tuple<{5}>;
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {{
return key == KeyTy({4});
}
using KeyTy = std::tuple<{4}>;
)";

/// The storage class' constructor template.
Expand Down Expand Up @@ -555,23 +549,34 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
});
}

// Construct the parameter list that is used when a concrete instance of the
// storage exists.
auto nonStaticParameterNames = llvm::map_range(params, [](const auto &param) {
return isa<AttributeSelfTypeParameter>(param) ? "getType()"
: param.getName();
});

// 1) Emit most of the storage class up until the hashKey body.
// * Emit most of the storage class up until the hashKey body.
os << formatv(
defStorageClassBeginStr, def.getStorageNamespace(),
def.getStorageClassName(),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
params, /*prependComma=*/false),
paramInitializer, llvm::join(nonStaticParameterNames, ", "),
parameterTypeList, valueType);
paramInitializer, parameterTypeList, valueType);

// * Emit the comparison method.
os << " bool operator==(const KeyTy &key) const {\n";
for (auto it : llvm::enumerate(params)) {
os << " if (!(";

// Build the comparator context.
bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
FmtContext context;
context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
.addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");

// Use the parameter specified comparator if possible, otherwise default to
// operator==.
Optional<StringRef> comparator = it.value().getComparator();
os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context);
os << "))\n return false;\n";
}
os << " return true;\n }\n";

// 2) Emit the haskKey method.
// * Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";

// Extract each parameter from the key.
Expand All @@ -581,7 +586,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
[&](unsigned it) { os << "std::get<" << it << ">(key)"; });
os << ");\n }\n";

// 3) Emit the construct method.
// * Emit the construct method.

// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
Expand Down Expand Up @@ -611,7 +616,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
llvm::join(parameterNames, ", "));
}

// 4) Emit the parameters as storage class members.
// * Emit the parameters as storage class members.
for (const AttrOrTypeParameter &parameter : params) {
// Attribute value types are not stored as fields in the storage.
if (!isa<AttributeSelfTypeParameter>(parameter))
Expand Down Expand Up @@ -771,15 +776,19 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
// Generate accessor definitions only if we also generate the storage class.
// Otherwise, let the user define the exact accessor definition.
if (def.genAccessors() && def.genStorageClass()) {
for (const AttrOrTypeParameter &parameter : parameters) {
StringRef paramStorageName = isa<AttributeSelfTypeParameter>(parameter)
? "getType()"
: parameter.getName();

SmallString<16> name = parameter.getName();
for (const AttrOrTypeParameter &param : parameters) {
SmallString<32> paramStorageName;
if (isa<AttributeSelfTypeParameter>(param)) {
Twine("getType().cast<" + param.getCppType() + ">()")
.toVector(paramStorageName);
} else {
paramStorageName = param.getName();
}

SmallString<16> name = param.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
parameter.getCppType(), name, paramStorageName,
param.getCppType(), name, paramStorageName,
def.getCppClassName());
}
}
Expand Down