Skip to content

Commit

Permalink
[mlir][ods] Fix builder gen for VariadicRegion with inferred types
Browse files Browse the repository at this point in the history
Builders generated for ops with variadic regions and inferred return types were not being correctly generated (missing parameter).
  • Loading branch information
Mogball committed Apr 7, 2022
1 parent ee2d9b8 commit 2f78b43
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
16 changes: 16 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -347,6 +347,22 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}

def VariadicRegionInferredTypesOp : TEST_Op<"variadic_region_inferred",
[InferTypeOpInterface]> {
let regions = (region VariadicRegion<AnyRegion>:$bodies);
let results = (outs Variadic<AnyType>);

let extraClassDeclaration = [{
static mlir::LogicalResult inferReturnTypes(mlir::MLIRContext *context,
llvm::Optional<::mlir::Location> location, mlir::ValueRange operands,
mlir::DictionaryAttr attributes, mlir::RegionRange regions,
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({mlir::IntegerType::get(context, 16)});
return mlir::success();
}
}];
}

//===----------------------------------------------------------------------===//
// NoTerminator Operation
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 5 additions & 2 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Expand Up @@ -1373,13 +1373,16 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
}

void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");

auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
Expand Down

0 comments on commit 2f78b43

Please sign in to comment.