diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index c8f0806e27a62..f1c3d717f1fa9 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1819,7 +1819,7 @@ def TileUsingForOp : Op:$dynamic_sizes, DefaultValuedOptionalAttr:$static_sizes, - DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$interchange, DefaultValuedOptionalAttr:$scalable_sizes); let results = (outs TransformHandleTypeInterface:$tiled_linalg_op, Variadic:$loops); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 73de3f22d896f..de4965f937162 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2477,7 +2477,7 @@ void transform::TileUsingForOp::build( /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, - /*interchange=*/builder.getI64ArrayAttr(interchange), + /*interchange=*/builder.getDenseI64ArrayAttr(interchange), /*scalable_sizes=*/expandedScalableSizes); } @@ -2611,8 +2611,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, }); } - tilingOptions.setInterchange( - extractFromIntegerArrayAttr(getInterchange())); + tilingOptions.setInterchange(getInterchange()); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2649,6 +2648,33 @@ SmallVector transform::TileUsingForOp::getMixedSizes() { return results; } +// We want to parse `DenseI64ArrayAttr` using the short form without the +// `array` prefix to be consistent in the IR with `parseDynamicIndexList`. +ParseResult parseOptionalInterchange(OpAsmParser &parser, + OperationState &result) { + if (succeeded(parser.parseOptionalLBrace())) { + if (failed(parser.parseKeyword("interchange"))) + return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; + if (failed(parser.parseEqual())) + return parser.emitError(parser.getNameLoc()) << "expect `=`"; + result.addAttribute("interchange", + DenseI64ArrayAttr::parse(parser, Type{})); + if (failed(parser.parseRBrace())) + return parser.emitError(parser.getNameLoc()) << "expect `}`"; + } + return success(); +} + +void printOptionalInterchange(OpAsmPrinter &p, + ArrayRef interchangeVals) { + if (!interchangeVals.empty()) { + p << " {interchange = ["; + llvm::interleaveComma(interchangeVals, p, + [&](int64_t integer) { p << integer; }); + p << "]}"; + } +} + ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; @@ -2660,7 +2686,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser, if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) || - parser.parseOptionalAttrDict(result.attributes) || + parseOptionalInterchange(parser, result) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2694,10 +2720,7 @@ void TileUsingForOp::print(OpAsmPrinter &p) { printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); - p.printOptionalAttrDict( - (*this)->getAttrs(), - /*elidedAttrs=*/{getScalableSizesAttrName(getOperation()->getName()), - getStaticSizesAttrName(getOperation()->getName())}); + printOptionalInterchange(p, getInterchange()); p << " : "; p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); } diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir index 4d7c514dcca62..e9f044be5b4ed 100644 --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -6,14 +6,6 @@ transform.sequence failures(propagate) { %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) } -// check that the Attributes of `tile_using_for` are preserved through printing -// and parsing. -transform.sequence failures(propagate) { -^bb1(%arg0: !transform.any_op): - // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile %arg0 [2, 0, 3] {interchange = [2, 1], test_attr1 = 1 : i64, test_attr2} - %0, %1:2 = transform.structured.tile_using_for %arg0 [2, 0, 3] {test_attr1 = 1 : i64, interchange = [2, 1], test_attr2}: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) -} - transform.sequence failures(propagate) { ^bb1(%arg0: !transform.any_op): %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op