Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,22 @@ class Operator {
/// Requires: all result types are known.
const InferredResultType &getInferredResultType(int index) const;

/// Pair consisting kind of argument and index into operands or attributes.
struct OperandOrAttribute {
enum class Kind { Operand, Attribute };
OperandOrAttribute(Kind kind, int index) {
packed = (index << 1) | (kind == Kind::Attribute);
/// Pair consisting kind of argument and index into operands, attributes, or
/// properties.
struct OperandAttrOrProp {
enum class Kind { Operand = 0x0, Attribute = 0x1, Property = 0x2 };
OperandAttrOrProp(Kind kind, int index) {
packed = (index << 2) | static_cast<int>(kind);
}
int operandOrAttributeIndex() const { return (packed >> 1); }
Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
int operandOrAttributeIndex() const { return (packed >> 2); }
Kind kind() const { return static_cast<Kind>(packed & 0x3); }

private:
int packed;
};

/// Returns the OperandOrAttribute corresponding to the index.
OperandOrAttribute getArgToOperandOrAttribute(int index) const;
/// Returns the OperandAttrOrProp corresponding to the index.
OperandAttrOrProp getArgToOperandAttrOrProp(int index) const;

/// Returns the builders of this operation.
ArrayRef<Builder> getBuilders() const { return builders; }
Expand Down Expand Up @@ -405,8 +406,8 @@ class Operator {
/// The argument with the same type as the result.
SmallVector<InferredResultType> resultTypeMapping;

/// Map from argument to attribute or operand number.
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
/// Map from argument to attribute, property, or operand number.
SmallVector<OperandAttrOrProp, 4> attrPropOrOperandMapping;

/// The builders of this operator.
SmallVector<Builder> builders;
Expand Down
18 changes: 10 additions & 8 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
NamedTypeConstraint *operand =
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
Expand Down Expand Up @@ -663,15 +664,17 @@ void Operator::populateOpStructure() {
argDef = argDef->getValueAsDef("constraint");

if (argDef->isSubClassOf(typeConstraintClass)) {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Operand, operandIndex});
attrPropOrOperandMapping.push_back(
{OperandAttrOrProp::Kind::Operand, operandIndex});
arguments.emplace_back(&operands[operandIndex++]);
} else if (argDef->isSubClassOf(attrClass)) {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Attribute, attrIndex});
attrPropOrOperandMapping.push_back(
{OperandAttrOrProp::Kind::Attribute, attrIndex});
arguments.emplace_back(&attributes[attrIndex++]);
} else {
assert(argDef->isSubClassOf(propertyClass));
attrPropOrOperandMapping.push_back(
{OperandAttrOrProp::Kind::Property, propIndex});
arguments.emplace_back(&properties[propIndex++]);
}
}
Expand Down Expand Up @@ -867,9 +870,8 @@ auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
return VariableDecorator(cast<DefInit>(init)->getDef());
}

auto Operator::getArgToOperandOrAttribute(int index) const
-> OperandOrAttribute {
return attrOrOperandMapping[index];
auto Operator::getArgToOperandAttrOrProp(int index) const -> OperandAttrOrProp {
return attrPropOrOperandMapping[index];
}

std::string Operator::getGetterName(StringRef name) const {
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/mlir-tblgen/op-decl-and-defs.td
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,12 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;

// REDUCE_EXC-NOT: NS::AOp declarations
// REDUCE_EXC-LABEL: NS::BOp declarations

// CHECK-LABEL: _TypeInferredPropOp declarations
def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [
AllTypesMatch<["value", "result"]>
]> {
let arguments = (ins Property<"unsigned">:$prop, AnyType:$value);
let results = (outs AnyType:$result);
let hasCustomAssemblyFormat = 1;
}
20 changes: 11 additions & 9 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3849,9 +3849,9 @@ void OpEmitter::genTypeInterfaceMethods() {
const InferredResultType &infer = op.getInferredResultType(i);
if (!infer.isArg())
continue;
Operator::OperandOrAttribute arg =
op.getArgToOperandOrAttribute(infer.getIndex());
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
Operator::OperandAttrOrProp arg =
op.getArgToOperandAttrOrProp(infer.getIndex());
if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
maxAccessedIndex =
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
}
Expand All @@ -3877,17 +3877,16 @@ void OpEmitter::genTypeInterfaceMethods() {
if (infer.isArg()) {
// If this is an operand, just index into operand list to access the
// type.
Operator::OperandOrAttribute arg =
op.getArgToOperandOrAttribute(infer.getIndex());
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
Operator::OperandAttrOrProp arg =
op.getArgToOperandAttrOrProp(infer.getIndex());
if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
"].getType()")
.str();

// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
} else if (auto *attr = dyn_cast<NamedAttribute *>(
op.getArg(arg.operandOrAttributeIndex()))) {
body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
<< " = ";
if (op.getDialect().usePropertiesForAttributes()) {
Expand All @@ -3907,6 +3906,9 @@ void OpEmitter::genTypeInterfaceMethods() {
typeStr =
("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
.str();
} else {
llvm::PrintFatalError(&op.getDef(),
"Properties cannot be used for type inference");
}
} else if (std::optional<StringRef> builder =
op.getResult(infer.getResultIndex())
Expand Down