diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 0d6ebc087e2f3..8728e666cd59d 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -395,31 +395,73 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp //===----------------------------------------------------------------------===// def FuseOp : Op, - ReportTrackingListenerFailuresOpTrait]> { + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tiles the operations pointed to by the target handle and fuses their - producers greedily using the options provided as attributes. + producers greedily using the options provided as attributes. Tile sizes + and loop interchange permutation can be provided as either static + attributes or dynamic values (transform parameters or payload handles). If `apply_cleanup` is true then slice canonicalization is applied between - fusion steps. + fusion steps. If `use_forall` is true then tiling method generates a + `scf.forall` loop instead of `scf.for` loops. 
}]; let arguments = (ins TransformHandleTypeInterface:$target, - DefaultValuedAttr:$tile_sizes, - DefaultValuedAttr:$tile_interchange, - DefaultValuedAttr:$apply_cleanup); + Variadic : $tile_sizes, + Variadic : $tile_interchange, + DefaultValuedOptionalAttr:$static_tile_sizes, + DefaultValuedOptionalAttr:$static_tile_interchange, + UnitAttr:$apply_cleanup, + UnitAttr:$use_forall); let results = (outs TransformHandleTypeInterface:$transformed, Variadic:$loops); + let builders = [ + OpBuilder<(ins "TypeRange":$loopTypes, + "Value":$target, + "ArrayRef":$staticTileSizes, + "ArrayRef":$staticTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "TypeRange":$loopTypes, + "Value":$target, + "ArrayRef":$mixedTileSizes, + "ArrayRef":$mixedTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "Value":$target, + "ArrayRef":$staticTileSizes, + "ArrayRef":$staticTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedTileSizes, + "ArrayRef":$mixedTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + ]; let assemblyFormat = [{ - $target ($tile_sizes^)? (`interchange` $tile_interchange^)? - (`apply_cleanup` `=` $apply_cleanup^)? 
attr-dict - `:` functional-type(operands, results) + $target oilist( + `tile_sizes` custom($tile_sizes, $static_tile_sizes) | + `interchange` custom($tile_interchange, $static_tile_interchange) + ) + attr-dict `:` functional-type(operands, results) }]; let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes(); + ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange(); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index dd9b4c2490ef4..d8f983f98ae77 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( // FuseOp //===----------------------------------------------------------------------===// +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef staticTileSizes, + ArrayRef staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, loopTypes, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef staticTileSizes, + ArrayRef staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, + /*target=*/target, + /*mixedTileSizes=*/ + 
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, + ArrayRef mixedTileSizes, + ArrayRef mixedTileInterchange, + bool applyCleanup, bool useForall) { + // Loop types are automatically splat by the callee, setting up one is + // enough. + SmallVector loopTypes(1, builder.getType()); + build(builder, result, loopTypes, target, mixedTileSizes, + mixedTileInterchange, applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef mixedTileSizes, + ArrayRef mixedTileInterchange, + bool applyCleanup, bool useForall) { + SmallVector staticTileSizes; + SmallVector dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); + SmallVector staticTileInterchange; + SmallVector dynamicTileInterchange; + dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange, + staticTileInterchange); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); + auto staticTileInterchangeAttr = + builder.getDenseI64ArrayAttr(staticTileInterchange); + unsigned numExpectedLoops = + useForall ? 
1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*transformed=*/target.getType(), + /*loops=*/resultTypes, + /*target=*/target, + /*tile_sizes=*/dynamicTileSizes, + /*tile_interchange=*/dynamicTileInterchange, + /*static_tile_sizes=*/staticTileSizesAttr, + /*static_tile_interchange=*/staticTileInterchangeAttr, + /*apply_cleanup=*/applyCleanup, + /*use_forall=*/useForall); +} + /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - SmallVector tileSizes = - extractFromIntegerArrayAttr(getTileSizes()); - SmallVector tileInterchange = - extractFromIntegerArrayAttr(getTileInterchange()); + auto transformOp = cast(getOperation()); + + SmallVector tileSizes; + DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileSizes(), tileSizes); + if (!status.succeeded()) + return status; + SmallVector tileInterchange; + status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileInterchange(), tileInterchange); + if (!status.succeeded()) + return status; scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; + bool useForall = getUseForall(); + tilingOptions.setLoopType(useForall + ? 
scf::SCFTilingOptions::LoopType::ForallOp + : scf::SCFTilingOptions::LoopType::ForOp); SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } + size_t numLoops = + useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector permutation = - extractFromIntegerArrayAttr(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterspace_rank = getStaticTileSizes().size(); + ArrayRef permutation = getStaticTileInterchange(); + if (permutation.size() > iterspace_rank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterspace_rank << "), found " << getTileInterchange(); + SmallVector seen(iterspace_rank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast(iterspace_rank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterspace_rank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + 
seen[v] = true; + } } - SmallVector sizes = - extractFromIntegerArrayAttr(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index e3bacb5777d9f..14c7380e432f0 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -144,9 +144,10 @@ def __init__( loop_types: Union[Type, Sequence[Type]], target: Union[Operation, Value, OpView], *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): @@ -157,9 +158,10 @@ def __init__( self, target: 
Union[Operation, Value, OpView], *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): @@ -170,17 +172,26 @@ def __init__( loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value], target_or_none: Optional[Union[Operation, Value, OpView]] = None, *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): tile_sizes = tile_sizes if tile_sizes else [] tile_interchange = tile_interchange if tile_interchange else [] - _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes) - _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange) - num_loops = sum(0 if v == 0 else 1 for v in tile_sizes) + ( + dynamic_tile_sizes, + static_tile_sizes, + _, + ) = _dispatch_dynamic_index_list(tile_sizes) + ( + dynamic_tile_interchange, + static_tile_interchange, + _, + ) = _dispatch_dynamic_index_list(tile_interchange) + num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0) if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops @@ -197,9 +208,12 @@ def __init__( target.type, loop_types, target, - tile_sizes=tile_sizes, - tile_interchange=tile_interchange, + tile_sizes=dynamic_tile_sizes, + tile_interchange=dynamic_tile_interchange, + static_tile_sizes=static_tile_sizes, + static_tile_interchange=static_tile_interchange, apply_cleanup=apply_cleanup, + use_forall=use_forall, loc=loc, ip=ip, ) diff --git 
a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 9a44f95afb586..7dc0a87bfa04c 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -18,7 +18,7 @@ func.func @fuse_unary(%arg0: tensor, %arg1: tensor) -> tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -48,7 +48,7 @@ func.func @fuse_unary(%arg0: tensor, %arg1: tensor) -> tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) transform.yield @@ -57,6 +57,60 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @fuse_unary_param +func.func @fuse_unary_param(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add 
ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.param.constant 32 : i32 -> !transform.param + %2 = transform.param.constant 1 : i32 -> !transform.param + %3, %loops:2 = transform.structured.fuse %0 tile_sizes [%1, 32] interchange [0, %2] + : (!transform.any_op, !transform.param, !transform.param) -> + (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_forall +func.func @fuse_unary_forall(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.forall + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop = transform.structured.fuse %0 tile_sizes [32, 32] {use_forall} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @interchange_reduction // CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { @@ -93,7 +147,7 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = 
transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [5, 0, 7] interchange [0, 2, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) %2, %loops_2 = transform.structured.tile_using_for %1 tile_sizes [0, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) @@ -121,7 +175,7 @@ func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [16, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -147,7 +201,7 @@ func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [3, 5, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -173,7 +227,7 @@ func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = 
transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]} + %1, %loops:3 = transform.structured.fuse %0 tile_sizes [3, 5, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -204,7 +258,7 @@ func.func @fuse_through_slice(%arg0: tensor, %arg1: tensor) -> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -238,7 +292,7 @@ func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -273,7 +327,7 @@ func.func @fuse_unrelated_slices(%arg0: tensor, %arg1: tensor) module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : 
(!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -299,7 +353,7 @@ func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> t module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -324,7 +378,7 @@ func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim(%0: tenso module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:2 = transform.structured.fuse %0 tile_sizes [1, 2, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -348,7 +402,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous(%0: te module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { 
%0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 2, 5] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -379,7 +433,7 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true : + %transformed, %loops:4 = transform.structured.fuse %0 tile_sizes [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -408,7 +462,7 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [0, 0, 1, 0] interchange [0, 1, 2, 3] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } @@ -433,7 +487,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false(%0: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = 
transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -456,7 +510,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } @@ -482,7 +536,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir index 04a99b5fd0d68..32fb0c9e41c39 100644 --- a/mlir/test/Dialect/Tensor/tiling.mlir +++ b/mlir/test/Dialect/Tensor/tiling.mlir @@ -149,7 +149,7 
@@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %copy [2, 3] + %a, %b, %c = transform.structured.fuse %copy tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 8a0390a4379cf..8116044594fca 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -17,7 +17,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %matmul [10, 20] + %a, %b, %c = transform.structured.fuse %matmul tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -69,7 +69,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %mm1, %mm2 = transform.split_handle 
%matmuls : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %mm2 [10] + %a, %b = transform.structured.fuse %mm2 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -188,7 +188,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -248,7 +248,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] interchange[1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -307,7 +307,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -367,7 +367,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = 
transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -423,7 +423,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %mm1, %mm2, %mm3 = transform.split_handle %matmuls : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %mm3 [10] + %a, %b = transform.structured.fuse %mm3 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -512,7 +512,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %generic1, %generic2, %generic3 = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %generic3 [10] + %a, %b = transform.structured.fuse %generic3 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -568,7 +568,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.structured.fuse %pad [8] + %a, %b = transform.structured.fuse %pad tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -614,7 +614,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.structured.fuse %matmul [0, 1, 0] + %a, %b = 
transform.structured.fuse %matmul tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -652,7 +652,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %loops:4 = transform.structured.fuse %generic {tile_sizes = [1, 16, 16, 16], tile_interchange = [0, 1, 2, 3], apply_cleanup = false} + %a, %loops:4 = transform.structured.fuse %generic tile_sizes [1, 16, 16, 16] interchange [0, 1, 2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 8785d6d360074..d6b70dc9d1978 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -109,11 +109,27 @@ def testFuseOpCompact(target): ) # CHECK-LABEL: TEST: testFuseOpCompact # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] - # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: interchange [0, 1] {apply_cleanup} # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +@run +@create_sequence +def testFuseOpCompactForall(target): + structured.FuseOp( + target, + tile_sizes=[4, 8], + apply_cleanup=True, + use_forall=True, + ) + # CHECK-LABEL: TEST: testFuseOpCompactForall + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: {apply_cleanup, use_forall} + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op) + 
+ @run @create_sequence def testFuseOpNoArg(target): @@ -124,6 +140,44 @@ def testFuseOpNoArg(target): # CHECK-SAME: (!transform.any_op) -> !transform.any_op +@run +@create_sequence +def testFuseOpParams(target): + structured.FuseOp( + target, + tile_sizes=[constant_param(4), Attribute.parse("8")], + tile_interchange=[constant_param(0), Attribute.parse("1")], + ) + # CHECK-LABEL: TEST: testFuseOpParams + # CHECK: transform.sequence + # CHECK-DAG: %[[P:.*]] = transform.param.constant 4 + # CHECK-DAG: %[[I:.*]] = transform.param.constant 0 + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[P]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.param, !transform.param) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpHandles(target): + size1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp( + target, + tile_sizes=[size1, 8], + tile_interchange=[ichange1, 1], + ) + # CHECK-LABEL: TEST: testFuseOpHandles + # CHECK: transform.sequence + # CHECK: %[[H:.*]] = transform.structured.match + # CHECK: %[[I:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[H]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + @run @create_sequence def testFuseOpAttributes(target): @@ -132,7 +186,7 @@ def testFuseOpAttributes(target): structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) # CHECK-LABEL: TEST: testFuseOpAttributes # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] # CHECK-SAME: interchange [0, 1] 
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)