diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 995aab50f9ee81..84f5690a8649f1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -47,8 +47,8 @@ def Linalg_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps"; - using RegionBuilderFunType = - llvm::function_ref; + using RegionBuilderFunType = llvm::function_ref< + void(ImplicitLocOpBuilder &b, Block &, ArrayRef)>; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index f40f82ed7cd94a..00de7ba0265bd0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1025,7 +1025,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function)>", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 8bdec0971ee181..33e2422060f2f0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -83,8 +83,10 @@ def FillOp : LinalgStructured_Op<"fill", []> { extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); - static std::function + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs); + static std::function)> getRegionBuilder() { return ®ionBuilder; } @@ -254,7 +256,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> { library_call()->str() : "op_has_no_registered_library_name"; } - static std::function + static std::function)> getRegionBuilder() { return nullptr; } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 8862b6b154ea5b..bfb3313d1a21d6 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - fun(b, *body); + fun(b, *body, op->getAttrs()); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 4868fdb99341a2..87278dcba08961 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -49,7 +49,7 @@ using namespace mlir::linalg; template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, + TypeRange outputTypes, ArrayRef attrs, llvm::function_ref errorHandler = nullptr); /// Generic entry point to create both the region and the block of a LinalgOp. @@ -72,7 +72,8 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes); + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef attrs); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, @@ -375,7 +376,8 @@ class RegionBuilderHelper { //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { +void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs) { assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } @@ -384,16 +386,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value value, Value output) { build(builder, result, output.getType().dyn_cast(), value, output); - fillStructuredOpRegion(builder, *result.regions.front(), - TypeRange{value.getType()}, - TypeRange{output.getType()}, {}); + fillStructuredOpRegion( + builder, *result.regions.front(), TypeRange{value.getType()}, + TypeRange{output.getType()}, result.attributes.getAttrs(), {}); } ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, Type outputType) { OpBuilder opBuilder(parser.getContext()); fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, - TypeRange{outputType}); + TypeRange{outputType}, {}); return success(); } @@ -1820,7 +1822,7 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) { template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, + TypeRange outputTypes, ArrayRef attrs, llvm::function_ref errorHandler) { assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); @@ -1851,7 +1853,7 @@ static void fillStructuredOpRegion( opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - NamedStructuredOpType::regionBuilder(b, *body); + NamedStructuredOpType::regionBuilder(b, *body, attrs); // indexing_maps is an auto-generated method. @@ -1866,7 +1868,7 @@ void createAndFillStructuredOpRegion(OpBuilder &opBuilder, TypeRange outputTypes) { Region ®ion = *result.addRegion(); fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, + opBuilder, region, inputTypes, outputTypes, result.attributes.getAttrs(), [&](unsigned expected, unsigned actual) { assert(expected != actual && "incorrect number of arguments"); }); @@ -1929,14 +1931,15 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes) { + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef attrs) { ParseResult res = success(); OpBuilder opBuilder(parser.getContext()); // Resolve `captures` into `capturedValues` at parse time so we can build the // region with captures. SmallVector capturedValues; fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, + opBuilder, region, inputTypes, outputTypes, attrs, [&](unsigned expected, unsigned actual) { res = parser.emitError( parser.getCurrentLocation(), @@ -1973,7 +1976,8 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes)) + parser, *region, inputTypes, outputTypes, + result.attributes.getAttrs())) return failure(); result.addRegion(std::move(region)); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5b7d973429269d..f5834efe9cb5ab 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2555,11 +2555,13 @@ def TestLinalgConvOp : let extraClassDeclaration = [{ bool hasIndexSemantics() { return false; } - static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block) { + static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, + mlir::ArrayRef attrs) { b.create(block.getArguments().back()); } - static std::function + static std::function)> getRegionBuilder() { return ®ionBuilder; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index 347923825cee3d..ba44e1eeb62624 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -83,8 +83,8 @@ structured_op: !LinalgStructuredOpConfig # ODS-NEXT: TypeRange(inputs), # ODS-NEXT: TypeRange(outputs) -# IMPL-LABEL: void Test1Op::regionBuilder( -# IMPL: ImplicitLocOpBuilder &b, Block &block) +# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, +# IMPL-NEXT: Block &block, ArrayRef attrs) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); @@ -174,7 +174,8 @@ structured_op: !LinalgStructuredOpConfig # IMPL: auto attr = op->getAttrOfType("strides") # IMPL: "incorrect element type for index attribute 'strides'" # IMPL: "incorrect shape for index attribute 'strides'" -# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block) +# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, +# IMPL-NEXT: Block &block, ArrayRef attrs) # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && # IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index f1fac9f5786120..eb1af791f1cf90 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -523,8 +523,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); - static std::function + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + static std::function)> getRegionBuilder() {{ return regionBuilder; } @@ -952,7 +954,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ // {1}: Number of args // {2}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( -void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{ +void {0}::regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(block.getArgument(0).getContext(), block);