diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index cb61177bc7533..1a59f4c7d1acb 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -423,6 +423,12 @@ def FuseOp : Op : $tile_sizes, Variadic : $tile_interchange, + Optional : $packed_tile_sizes, DefaultValuedOptionalAttr:$static_tile_sizes, DefaultValuedOptionalAttr:$static_tile_interchange, UnitAttr:$apply_cleanup, @@ -465,7 +472,9 @@ def FuseOp : Op($tile_sizes, $static_tile_sizes) | + `tile_sizes` custom($packed_tile_sizes, + $tile_sizes, + $static_tile_sizes) | `interchange` custom($tile_interchange, $static_tile_interchange) ) attr-dict `:` functional-type(operands, results) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index baa57f8920094..f44693096b26b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -654,6 +654,7 @@ void transform::FuseOp::build(OpBuilder &builder, OperationState &result, /*target=*/target, /*tile_sizes=*/dynamicTileSizes, /*tile_interchange=*/dynamicTileInterchange, + /*packed_tile_sizes=*/Value(), /*static_tile_sizes=*/staticTileSizesAttr, /*static_tile_interchange=*/staticTileInterchangeAttr, /*apply_cleanup=*/applyCleanup, @@ -666,10 +667,12 @@ template static LogicalResult applyTilingToAll( RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, + bool packedResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); + size_t numTargets = llvm::range_size(payloadOps); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); @@ -704,8 +707,22 @@ static LogicalResult applyTilingToAll( } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); - for (unsigned int i = 0; i < numLoops; ++i) - transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + if (packedResults) { + // In case of packed results, all created loops are assigned to a single + // handle. Loops are returned in order of targets such as: + // %loops_handle = { + // target0:loop0, ..., target0:loopN, + // target1:loop0, ..., target1:loopN, + // ... } + SmallVector flattenedLoopOps; + for (unsigned int idx = 0; idx < numTargets; ++idx) + for (unsigned int i = 0; i < numLoops; ++i) + flattenedLoopOps.push_back(loopOps[i][idx]); + transformResults.set(transformOp->getOpResult(1), flattenedLoopOps); + } else { + for (unsigned int i = 0; i < numLoops; ++i) + transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + } return success(); } @@ -716,9 +733,13 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformState &state) { auto transformOp = cast(getOperation()); - SmallVector tileSizes; - DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( - state, transformOp, getMixedTileSizes(), tileSizes); + SmallVector mixedTileSizes; + DiagnosedSilenceableFailure status = + getPackedTileSizes() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); if (!status.succeeded()) return status; SmallVector tileInterchange; @@ -733,9 +754,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tilingOptions.setLoopType(useForall ? scf::SCFTilingOptions::LoopType::ForallOp : scf::SCFTilingOptions::LoopType::ForOp); - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); - tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); + tilingOptions = tilingOptions.setTileSizes(mixedTileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; @@ -748,11 +767,20 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } - size_t numLoops = - useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0); + size_t numLoops; + if (useForall) { + numLoops = 1; + } else { + numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) { + auto attr = dyn_cast(ofr); + if (!attr) + return true; + return cast(attr).getInt() != 0; + }); + } LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, - transformResults, + transformResults, /*packedResults=*/getPackedTileSizes() != nullptr, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -763,6 +791,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { + bool hasPackedTiles = getPackedTileSizes() != nullptr; + if (!getMixedTileSizes().empty() && hasPackedTiles) + return emitOpError( + "tile_sizes and packed_tile_sizes are mutually exclusive"); + auto iterspace_rank = getStaticTileSizes().size(); ArrayRef permutation = getStaticTileInterchange(); if (permutation.size() > iterspace_rank) @@ -782,8 +815,9 @@ LogicalResult transform::FuseOp::verify() { } ArrayRef sizes = getStaticTileSizes(); - size_t numExpectedLoops = - getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); + size_t numExpectedLoops = getUseForall() || hasPackedTiles + ? 1 + : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; @@ -803,6 +837,7 @@ void transform::FuseOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getPackedTileSizesMutable(), effects); onlyReadsHandle(getTileInterchangeMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 10abd06ff266e..7f1bd2183a0c5 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -240,6 +240,8 @@ def _dispatch_mixed_values( for size in values or []: if isinstance(size, int): static_values.append(size) + elif isinstance(size, IntegerAttr): + static_values.append(size.value) else: static_values.append(ShapedType.get_dynamic_size()) dynamic_values.append(size) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index d9ab504f0de54..a3c3057ddb834 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -183,15 +183,19 @@ def __init__( tile_interchange = tile_interchange if tile_interchange else [] ( dynamic_tile_sizes, + packed_tile_sizes, static_tile_sizes, - _, - ) = _dispatch_dynamic_index_list(tile_sizes) + ) = _dispatch_mixed_values(tile_sizes) ( dynamic_tile_interchange, static_tile_interchange, _, ) = _dispatch_dynamic_index_list(tile_interchange) - num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0) + num_loops = ( + 1 + if use_forall or packed_tile_sizes is not None + else sum(1 for v in static_tile_sizes if v != 0) + ) if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops @@ -210,6 +214,7 @@ def __init__( target, tile_sizes=dynamic_tile_sizes, tile_interchange=dynamic_tile_interchange, + packed_tile_sizes=packed_tile_sizes, static_tile_sizes=static_tile_sizes, static_tile_interchange=static_tile_interchange, apply_cleanup=apply_cleanup, diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index b05dc1f295a49..dab8491708104 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -112,6 +112,131 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes +func.func @fuse_unary_packed_tile_sizes(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:2 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_forall +func.func @fuse_unary_packed_tile_sizes_forall(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.forall + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) {use_forall} + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:1 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_multiple_targets +func.func @fuse_unary_packed_tile_sizes_multiple_targets( + %arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: scf.for + // CHECK: scf.for + // CHECK: linalg.add + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.add ins(%arg0, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:4 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_no_tiling_packed_tile_sizes +func.func @fuse_no_tiling_packed_tile_sizes(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK-NOT: scf.for + // CHECK: linalg.exp + // CHECK: %[[RES:.*]] = linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c0 = transform.param.constant 0 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c0, %c0 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @interchange_reduction // CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index e58b7646316fc..fcede61100e00 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -191,6 +191,33 @@ def testFuseOpAttributes(target): # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +@run +@create_sequence +def testFuseOpPackedTileSizes(target): + tiles = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp(target, tile_sizes=tiles) + # CHECK-LABEL: TEST: testFuseOpPackedTileSizes + # CHECK: transform.sequence + # CHECK: %[[T:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse + # CHECK-SAME: tile_sizes *(%[[T]]) + # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpPackedTileSizesForall(target): + tiles = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp(target, tile_sizes=tiles, use_forall=True) + # CHECK-LABEL: TEST: testFuseOpPackedTileSizesForall + # CHECK: transform.sequence + # CHECK: %[[T:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse + # CHECK-SAME: tile_sizes *(%[[T]]) + # CHECK-SAME: {use_forall} + # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + @run @create_sequence def testGeneralize(target):