Skip to content

Commit

Permalink
[mlir] Expand operand adapter to take attributes
Browse files Browse the repository at this point in the history
* Enables using with more variadic sized operands;
* Generate convenience accessors for attributes;
  - The accessor are named the same as their name in ODS and returns attribute
    type (not convenience type) and no derived attributes.

This is first step to changing adapter to support verifying argument
constraints before the op is even created. This does not change the name of
adaptor nor does it require it except for ops with variadic operands to keep this change smaller.

Considered creating separate adapter but decided against that given operands also require attributes in general (and definitely for verification of operands and attributes).

Differential Revision: https://reviews.llvm.org/D80420
  • Loading branch information
jpienaar committed May 25, 2020
1 parent 838d122 commit 4b8632e
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 47 deletions.
5 changes: 0 additions & 5 deletions mlir/include/mlir/TableGen/OpClass.h
Expand Up @@ -145,10 +145,6 @@ class OpClass : public Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");

// Sets whether this OpClass should generate the using directive for its
// associate operand adaptor class.
void setHasOperandAdaptorClass(bool has);

// Adds an op trait.
void addTrait(Twine trait);

Expand All @@ -160,7 +156,6 @@ class OpClass : public Class {
StringRef extraClassDeclaration;
SmallVector<std::string, 4> traitsVec;
StringSet<> traitsSet;
bool hasOperandAdaptor;
};

} // namespace tblgen
Expand Down
10 changes: 2 additions & 8 deletions mlir/lib/TableGen/OpClass.cpp
Expand Up @@ -188,12 +188,7 @@ void tblgen::Class::writeDefTo(raw_ostream &os) const {
//===----------------------------------------------------------------------===//

tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration),
hasOperandAdaptor(true) {}

void tblgen::OpClass::setHasOperandAdaptorClass(bool has) {
hasOperandAdaptor = has;
}
: Class(name), extraClassDeclaration(extraClassDeclaration) {}

void tblgen::OpClass::addTrait(Twine trait) {
auto traitStr = trait.str();
Expand All @@ -207,8 +202,7 @@ void tblgen::OpClass::writeDeclTo(raw_ostream &os) const {
os << ", " << trait;
os << "> {\npublic:\n";
os << " using Op::Op;\n";
if (hasOperandAdaptor)
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";

bool hasPrivateMethod = false;
for (const auto &method : methods) {
Expand Down
30 changes: 28 additions & 2 deletions mlir/test/mlir-tblgen/op-decl.td
Expand Up @@ -50,12 +50,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {

// CHECK: class AOpOperandAdaptor {
// CHECK: public:
// CHECK: AOpOperandAdaptor(ArrayRef<Value> values);
// CHECK: AOpOperandAdaptor(ArrayRef<Value> values
// CHECK: ArrayRef<Value> getODSOperands(unsigned index);
// CHECK: Value a();
// CHECK: ArrayRef<Value> b();
// CHECK: IntegerAttr attr1();
// CHECL: FloatAttr attr2();
// CHECK: private:
// CHECK: ArrayRef<Value> tblgen_operands;
// CHECK: ArrayRef<Value> odsOperands;
// CHECK: };

// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
Expand Down Expand Up @@ -90,6 +92,29 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: void displayGraph();
// CHECK: };

// Check AttrSizedOperandSegments
// ---

def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
[AttrSizedOperandSegments]> {
let arguments = (ins
Variadic<I32>:$a,
Variadic<I32>:$b,
I32:$c,
Variadic<I32>:$d,
I32ElementsAttr:$operand_segment_sizes
);
}

// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
// CHECK-SAME: ArrayRef<Value> values
// CHECK-SAME: DictionaryAttr attrs
// CHECK: ArrayRef<Value> a();
// CHECK: ArrayRef<Value> b();
// CHECK: Value c();
// CHECK: ArrayRef<Value> d();
// CHECK: DenseIntElementsAttr operand_segment_sizes();

// Check op trait for different number of operands
// ---

Expand Down Expand Up @@ -150,3 +175,4 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;

// CHECK-LABEL: _BOp declarations
// CHECK: class _BOp : public Op<_BOp

2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/op-operand.td
Expand Up @@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK-LABEL: OpA definitions

// CHECK: OpAOperandAdaptor::OpAOperandAdaptor
// CHECK-NEXT: tblgen_operands = values
// CHECK-NEXT: odsOperands = values

// CHECK: void OpA::build
// CHECK: Value input
Expand Down
103 changes: 72 additions & 31 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Expand Up @@ -70,13 +70,19 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
const char *attrSizedSegmentValueRangeCalcCode = R"(
const char *adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<DenseIntElementsAttr>();
)";
const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
)";
const char *attrSizedSegmentValueRangeCalcCode = R"(
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
return {{start, size};
return {start, size};
)";

// The logic to build a range of either operand or result values.
Expand Down Expand Up @@ -496,15 +502,14 @@ static void
generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
int numVariadic, int numNonVariadic,
StringRef rangeSizeCall, bool hasAttrSegmentSize,
StringRef segmentSizeAttr, RangeT &&odsValues) {
StringRef sizeAttrInit, RangeT &&odsValues) {
auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
"unsigned index");

if (numVariadic == 0) {
method.body() << " return {index, 1};\n";
} else if (hasAttrSegmentSize) {
method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
segmentSizeAttr);
method.body() << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
} else {
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
Expand Down Expand Up @@ -532,6 +537,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef sizeAttrInit,
StringRef rangeType,
StringRef rangeBeginCall,
StringRef rangeSizeCall,
Expand Down Expand Up @@ -563,18 +569,17 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,

// First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
generateValueRangeStartAndEnd(
opClass, "getODSOperandIndexAndLength", numVariadicOperands,
numNormalOperands, rangeSizeCall, attrSizedOperands,
"operand_segment_sizes", const_cast<Operator &>(op).getOperands());
generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
numVariadicOperands, numNormalOperands,
rangeSizeCall, attrSizedOperands, sizeAttrInit,
const_cast<Operator &>(op).getOperands());

auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
"getODSOperandIndexAndLength(index)");

// Then we emit nicer named getter methods by redirecting to the "sink" getter
// method.

for (int i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
Expand All @@ -595,11 +600,11 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
}

void OpEmitter::genNamedOperandGetters() {
if (op.getTrait("OpTrait::AttrSizedOperandSegments"))
opClass.setHasOperandAdaptorClass(false);

generateNamedOperandGetters(
op, opClass, /*rangeType=*/"Operation::operand_range",
op, opClass,
/*sizeAttrInit=*/
formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
/*rangeType=*/"Operation::operand_range",
/*rangeBeginCall=*/"getOperation()->operand_begin()",
/*rangeSizeCall=*/"getOperation()->getNumOperands()",
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
Expand Down Expand Up @@ -656,7 +661,8 @@ void OpEmitter::genNamedResultGetters() {
generateValueRangeStartAndEnd(
opClass, "getODSResultIndexAndLength", numVariadicResults,
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
"result_segment_sizes", op.getResults());
formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
op.getResults());
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
"unsigned index");
m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
Expand Down Expand Up @@ -1840,15 +1846,56 @@ class OpOperandAdaptorEmitter {

OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
: adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
adapterClass.newField("ArrayRef<Value>", "tblgen_operands");
auto &constructor = adapterClass.newConstructor("ArrayRef<Value> values");
constructor.body() << " tblgen_operands = values;\n";

generateNamedOperandGetters(op, adapterClass,
adapterClass.newField("ArrayRef<Value>", "odsOperands");
adapterClass.newField("DictionaryAttr", "odsAttrs");
const auto *attrSizedOperands =
op.getTrait("OpTrait::AttrSizedOperandSegments");
auto &constructor = adapterClass.newConstructor(
attrSizedOperands
? "ArrayRef<Value> values, DictionaryAttr attrs"
: "ArrayRef<Value> values, DictionaryAttr attrs = nullptr");
constructor.body() << " odsOperands = values;\n";
constructor.body() << " odsAttrs = attrs;\n";

std::string sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
generateNamedOperandGetters(op, adapterClass, sizeAttrInit,
/*rangeType=*/"ArrayRef<Value>",
/*rangeBeginCall=*/"tblgen_operands.begin()",
/*rangeSizeCall=*/"tblgen_operands.size()",
/*getOperandCallPattern=*/"tblgen_operands[{0}]");
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");

FmtContext fctx;
fctx.withBuilder("mlir::Builder(odsAttrs.getContext())");

auto emitAttr = [&](StringRef name, Attribute attr) {
auto &body = adapterClass.newMethod(attr.getStorageType(), name).body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
<< "\n " << attr.getStorageType() << " attr = "
<< "odsAttrs.get(\"" << name << "\").";
if (attr.hasDefaultValue() || attr.isOptional())
body << "dyn_cast_or_null<";
else
body << "cast<";
body << attr.getStorageType() << ">();\n";

if (attr.hasDefaultValue()) {
// Use the default value if attribute is not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
std::string defaultValue = std::string(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " if (!attr)\n attr = " << defaultValue << ";\n";
}
body << " return attr;\n";
};

for (auto &namedAttr : op.getAttributes()) {
const auto &name = namedAttr.name;
const auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr())
emitAttr(name, attr);
}
}

void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
Expand All @@ -1873,19 +1920,13 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
}
for (auto *def : defs) {
Operator op(*def);
const auto *attrSizedOperands =
op.getTrait("OpTrait::AttrSizedOperandSegments");
if (emitDecl) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
// We cannot generate the operand adaptor class if operand getters depend
// on an attribute.
if (!attrSizedOperands)
OpOperandAdaptorEmitter::emitDecl(op, os);
OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os);
} else {
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
if (!attrSizedOperands)
OpOperandAdaptorEmitter::emitDef(op, os);
OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os);
}
}
Expand Down

0 comments on commit 4b8632e

Please sign in to comment.