-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Transform] FuseOp: accept transform params, add use_forall argument #161883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Tuomas Kärnä (tkarna) ChangesChanges to linalg
Patch is 41.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161883.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0d6ebc087e2f3..40588afa6477a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -395,31 +395,72 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//
def FuseOp : Op<Transform_Dialect, "structured.fuse",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>,
- ReportTrackingListenerFailuresOpTrait]> {
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Tiles the operations pointed to by the target handle and fuses their
producers greedily using the options provided as attributes.
If `apply_cleanup` is true then slice canonicalization is applied between
- fusion steps.
+ fusion steps. If `use_forall` is true then tiling method generates a
+ `scf.forall` loop instead of `scf.for` loops.
}];
let arguments =
(ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
- DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
+ DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
+ DefaultValuedAttr<BoolAttr, "false">:$use_forall);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);
+ let builders = [
+ OpBuilder<(ins "TypeRange":$loopTypes,
+ "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ "ArrayRef<int64_t>":$staticTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "TypeRange":$loopTypes,
+ "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedTileSizes,
+ "ArrayRef<OpFoldResult>":$mixedTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ "ArrayRef<int64_t>":$staticTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedTileSizes,
+ "ArrayRef<OpFoldResult>":$mixedTileInterchange,
+ CArg<"bool", "false">:$applyCleanup,
+ CArg<"bool", "false">:$useForall)>,
+ ];
let assemblyFormat = [{
- $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
- (`apply_cleanup` `=` $apply_cleanup^)? attr-dict
+ $target
+ (`tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes)^)?
+ (`interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)^)?
+ (`apply_cleanup` `=` $apply_cleanup^)?
+ (`use_forall` `=` $use_forall^)? attr-dict
`:` functional-type(operands, results)
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
+ ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange();
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2490ef4..0d365f29a51a3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result, loopTypes,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ // Loop types are automaticaly splat by the callee, setting up one is
+ // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+ build(builder, result, loopTypes, target, mixedTileSizes,
+ mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+ SmallVector<int64_t> staticTileInterchange;
+ SmallVector<Value> dynamicTileInterchange;
+ dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+ staticTileInterchange);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileInterchangeAttr =
+ builder.getDenseI64ArrayAttr(staticTileInterchange);
+ unsigned numExpectedLoops =
+ useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(numExpectedLoops);
+ assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+ "expected one loop type or as many as loops");
+ if (loopTypes.size() == 1)
+ resultTypes.append(numExpectedLoops, loopTypes[0]);
+ else
+ llvm::append_range(resultTypes, loopTypes);
+ build(builder, result, /*transformed=*/target.getType(),
+ /*loops=*/resultTypes,
+ /*target=*/target,
+ /*tile_sizes=*/dynamicTileSizes,
+ /*tile_interchange=*/dynamicTileInterchange,
+ /*static_tile_sizes=*/staticTileSizesAttr,
+ /*static_tile_interchange=*/staticTileInterchangeAttr,
+ /*apply_cleanup=*/applyCleanup,
+ /*use_forall=*/useForall);
+}
+
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
@@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- SmallVector<int64_t> tileInterchange =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int64_t> tileSizes;
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileSizes(), tileSizes);
+ if (!status.succeeded())
+ return status;
+ SmallVector<int64_t> tileInterchange;
+ status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileInterchange(), tileInterchange);
+ if (!status.succeeded())
+ return status;
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
+ bool useForall = getUseForall();
+ tilingOptions.setLoopType(useForall
+ ? scf::SCFTilingOptions::LoopType::ForallOp
+ : scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
+ size_t numLoops =
+ useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
- rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
+ transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -665,24 +759,69 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+ ArrayRef<int64_t> permutation = getStaticTileInterchange();
+ if (!llvm::any_of(permutation,
+ [](int64_t v) { return ShapedType::isDynamic(v); })) {
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ return emitOpError() << "expects interchange to be a permutation, found "
+ << getTileInterchange();
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
return success();
}
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ ValueRange dynamicValues = getTileSizes();
+ ArrayRef<int64_t> staticValues = getStaticTileSizes();
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (size == ShapedType::kDynamic) {
+ results.push_back(dynamicValues[dynamicPos++]);
+ } else {
+ results.push_back(builder.getIndexAttr(size));
+ }
+ }
+ return results;
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ ValueRange dynamicValues = getTileInterchange();
+ ArrayRef<int64_t> staticValues = getStaticTileInterchange();
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (size == ShapedType::kDynamic) {
+ results.push_back(dynamicValues[dynamicPos++]);
+ } else {
+ results.push_back(builder.getIndexAttr(size));
+ }
+ }
+ return results;
+}
+
+void transform::FuseOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getTileInterchangeMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index e3bacb5777d9f..d3fe3d5f085bf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -144,9 +144,10 @@ def __init__(
loop_types: Union[Type, Sequence[Type]],
target: Union[Operation, Value, OpView],
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
@@ -157,9 +158,10 @@ def __init__(
self,
target: Union[Operation, Value, OpView],
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
@@ -170,17 +172,26 @@ def __init__(
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
- tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- tile_interchange: OptionalIntList = None,
+ tile_sizes: Optional[MixedValues] = None,
+ tile_interchange: Optional[MixedValues] = None,
apply_cleanup: Optional[bool] = False,
+ use_forall: Optional[bool] = False,
loc=None,
ip=None,
):
tile_sizes = tile_sizes if tile_sizes else []
tile_interchange = tile_interchange if tile_interchange else []
- _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
- _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
- num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
+ (
+ dynamic_tile_sizes,
+ static_tile_sizes,
+ _,
+ ) = _dispatch_dynamic_index_list(tile_sizes)
+ (
+ dynamic_tile_interchange,
+ static_tile_interchange,
+ _,
+ ) = _dispatch_dynamic_index_list(tile_interchange)
+ num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in static_tile_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -197,9 +208,12 @@ def __init__(
target.type,
loop_types,
target,
- tile_sizes=tile_sizes,
- tile_interchange=tile_interchange,
+ tile_sizes=dynamic_tile_sizes,
+ tile_interchange=dynamic_tile_interchange,
+ static_tile_sizes=static_tile_sizes,
+ static_tile_interchange=static_tile_interchange,
apply_cleanup=apply_cleanup,
+ use_forall=use_forall,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 9a44f95afb586..d472f75bfcb9a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -18,7 +18,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
@@ -48,7 +48,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -57,6 +57,60 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @fuse_unary_param
+func.func @fuse_unary_param(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: %[[RES:.*]] = scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
+ // CHECK: return %[[RES]]
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.param.constant 32 : i32 -> !transform.param<i32>
+ %2 = transform.param.constant 1 : i32 -> !transform.param<i32>
+ %3, %loops:2 = transform.structured.fuse %0 tile_sizes [%1, 32] interchange [0, %2]
+ : (!transform.any_op, !transform.param<i32>, !transform.param<i32>) ->
+ (!transfor...
[truncated]
|
FYI @adam-smnk @fschlimb |
✅ With the latest revision this PR passed the Python code formatter. |
97f53b9
to
af17ff1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tkarna!
LGTM, though would like to resolve a couple questions before it goes in. Nothing major.
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Changes to linalg
structured.fuse
transform op:use_forall
boolean argument which generates a tiledscf.forall
loop instead ofscf.for
loops.tile_sizes
can now be any parameter or handle.tile_interchange
can now be any parameter or handle.transform.structured.fuse %0 [4, 8] ...
transform.structured.fuse %0 tile_sizes [4, 8] ...
UnitAttrs
and should be set via the op attr-dict:{apply_cleanup, use_forall}