Skip to content

Commit

Permalink
[mlir][ODS] Automatically create result_segment_sizes in builder
Browse files Browse the repository at this point in the history
When using multiple variadic results of differing sizes, using `AttrSizedResultSegments` is currently a requirement. Unlike `AttrSizedOperandSegments` however, it is not created within the default builders created by tablegen. Instead, one has to explicitly add `DenseI32ArrayAttr:$result_segments_sizes` as argument and then also explicitly specify all the sizes when using the builder from C++.

This patch fixes that redundancy, by making the builder generate the attribute in similar fashion as it already does for `AttrSizedOperandSegments`. The sizes required are simply gathered from the result type arguments of the builder.

Differential Revision: https://reviews.llvm.org/D132656
  • Loading branch information
zero9178 committed Aug 25, 2022
1 parent 525af9f commit 0393443
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
8 changes: 5 additions & 3 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -722,9 +722,6 @@ def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",

def AttrSizedResultOp : TEST_Op<"attr_sized_results",
[AttrSizedResultSegments]> {
let arguments = (ins
DenseI32ArrayAttr:$result_segment_sizes
);
let results = (outs
Variadic<I32>:$a,
Variadic<I32>:$b,
Expand All @@ -733,6 +730,11 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
);
}

def AttrSizedResultCompileTestOp : TEST_Op<"attr_sized_results_compile_test",
[AttrSizedResultSegments]> {
let results = (outs Variadic<I32>:$a, I32:$b, Optional<I32>:$c);
}

// This is used to test that the fallback for a custom op's parser and printer
// is the dialect parser and printer hooks.
def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/mlir-tblgen/op-result.td
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,10 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
// CHECK-NOT: }
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;

def OpM : NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> {
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Optional<AnyTensor>:$output3);
}

// CHECK-LABEL: OpM::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange output1, ::mlir::Type output2, /*optional*/::mlir::Type output3)
// CHECK: odsState.addAttribute(result_segment_sizesAttrName(odsState.name), odsBuilder.getDenseI32ArrayAttr({static_cast<int32_t>(output1.size()), 1, (output3 ? 1 : 0)}));
28 changes: 28 additions & 0 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,34 @@ void OpEmitter::genSeparateArgParamBuilder() {
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}

// Automatically create the 'result_segment_sizes' attribute using
// the length of the type ranges.
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
std::string getterName = op.getGetterName(resultSegmentAttrName);
body << " " << builderOpState << ".addAttribute(" << getterName
<< "AttrName(" << builderOpState << ".name), "
<< "odsBuilder.getDenseI32ArrayAttr({";

interleaveComma(
llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
const NamedTypeConstraint &result = op.getResult(i);
if (!result.isVariableLength()) {
body << "1";
} else if (result.isOptional()) {
body << "(" << resultNames[i] << " ? 1 : 0)";
} else {
// VariadicOfVariadic of results are currently unsupported in
// MLIR, hence it can only be a simple variadic.
// TODO: Add implementation for VariadicOfVariadic results here
// once supported.
assert(result.isVariadic());
body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
}
});
body << "}));\n";
}

return;
case TypeParamKind::Collective: {
int numResults = op.getNumResults();
Expand Down

0 comments on commit 0393443

Please sign in to comment.