Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -395,31 +395,73 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//

// Transform op that tiles the targeted payload ops and greedily fuses their
// producers. Tile sizes and the loop interchange are each split into a static
// attribute (`static_*`) and a variadic list of dynamic operands, which is why
// the op carries AttrSizedOperandSegments and declares its own memory effects.
def FuseOp : Op<Transform_Dialect, "structured.fuse",
    [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Tiles the operations pointed to by the target handle and fuses their
    producers greedily using the options provided as attributes. Tile sizes
    and loop interchange permutation can be provided as either static
    attributes or dynamic values (transform parameters or payload handles).

    If `apply_cleanup` is true then slice canonicalization is applied between
    fusion steps. If `use_forall` is true then tiling method generates a
    `scf.forall` loop instead of `scf.for` loops.
  }];

  let arguments =
    (ins TransformHandleTypeInterface:$target,
         Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
         Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
         UnitAttr:$apply_cleanup,
         UnitAttr:$use_forall);
  let results = (outs TransformHandleTypeInterface:$transformed,
                      Variadic<TransformHandleTypeInterface>:$loops);

  // Builders come in two flavours (purely static integers vs. mixed
  // OpFoldResults), each with and without explicit loop handle types.
  let builders = [
    OpBuilder<(ins "TypeRange":$loopTypes,
                   "Value":$target,
                   "ArrayRef<int64_t>":$staticTileSizes,
                   "ArrayRef<int64_t>":$staticTileInterchange,
                   CArg<"bool", "false">:$applyCleanup,
                   CArg<"bool", "false">:$useForall)>,
    OpBuilder<(ins "TypeRange":$loopTypes,
                   "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedTileSizes,
                   "ArrayRef<OpFoldResult>":$mixedTileInterchange,
                   CArg<"bool", "false">:$applyCleanup,
                   CArg<"bool", "false">:$useForall)>,
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<int64_t>":$staticTileSizes,
                   "ArrayRef<int64_t>":$staticTileInterchange,
                   CArg<"bool", "false">:$applyCleanup,
                   CArg<"bool", "false">:$useForall)>,
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedTileSizes,
                   "ArrayRef<OpFoldResult>":$mixedTileInterchange,
                   CArg<"bool", "false">:$applyCleanup,
                   CArg<"bool", "false">:$useForall)>,
  ];

  // `tile_sizes` and `interchange` are optional clauses that may appear in any
  // order; `apply_cleanup` and `use_forall` are unit attributes and are printed
  // as part of the attribute dictionary.
  let assemblyFormat = [{
    $target oilist(
        `tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes) |
        `interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)
    )
    attr-dict `:` functional-type(operands, results)
  }];
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
    ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange();
  }];
}

//===----------------------------------------------------------------------===//
Expand Down
153 changes: 137 additions & 16 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//

/// Builds the op from purely static tile sizes and interchange with explicit
/// loop handle types. The integers are wrapped as OpFoldResults and handed to
/// the overload that accepts mixed static/dynamic lists.
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
                              TypeRange loopTypes, Value target,
                              ArrayRef<int64_t> staticTileSizes,
                              ArrayRef<int64_t> staticTileInterchange,
                              bool applyCleanup, bool useForall) {
  auto mixedTileSizes =
      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes));
  auto mixedTileInterchange =
      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange));
  build(builder, result, loopTypes, target, mixedTileSizes,
        mixedTileInterchange, applyCleanup, useForall);
}

/// Builds the op from purely static tile sizes and interchange without
/// explicit loop handle types. The integers are wrapped as OpFoldResults and
/// handed to the matching mixed static/dynamic overload.
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
                              Value target, ArrayRef<int64_t> staticTileSizes,
                              ArrayRef<int64_t> staticTileInterchange,
                              bool applyCleanup, bool useForall) {
  auto mixedTileSizes =
      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes));
  auto mixedTileInterchange =
      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange));
  build(builder, result, target, mixedTileSizes, mixedTileInterchange,
        applyCleanup, useForall);
}

/// Builds the op from mixed static/dynamic tile sizes and interchange when the
/// caller does not specify loop handle types: every loop result is typed as
/// `transform::AnyOpType`.
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
                              Value target,
                              ArrayRef<OpFoldResult> mixedTileSizes,
                              ArrayRef<OpFoldResult> mixedTileInterchange,
                              bool applyCleanup, bool useForall) {
  // A single loop type suffices here: the callee automatically splats it
  // across all expected loop results.
  SmallVector<Type> splatLoopType;
  splatLoopType.push_back(builder.getType<transform::AnyOpType>());
  build(builder, result, splatLoopType, target, mixedTileSizes,
        mixedTileInterchange, applyCleanup, useForall);
}

/// Builds the op from mixed static/dynamic tile sizes and interchange with
/// caller-provided loop handle types. `loopTypes` must contain either a single
/// type (replicated for every expected loop) or exactly one type per loop.
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
                              TypeRange loopTypes, Value target,
                              ArrayRef<OpFoldResult> mixedTileSizes,
                              ArrayRef<OpFoldResult> mixedTileInterchange,
                              bool applyCleanup, bool useForall) {
  // Separate each mixed list into its dynamic operands and static integers.
  SmallVector<Value> dynamicTileSizes;
  SmallVector<int64_t> staticTileSizes;
  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
  SmallVector<Value> dynamicTileInterchange;
  SmallVector<int64_t> staticTileInterchange;
  dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
                             staticTileInterchange);

  // With `useForall` a single loop handle is produced; otherwise one handle
  // per static tile-size entry that is not zero.
  unsigned numExpectedLoops = 1;
  if (!useForall)
    numExpectedLoops =
        staticTileSizes.size() - llvm::count(staticTileSizes, 0);
  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
         "expected one loop type or as many as loops");

  // Either take the loop types verbatim or replicate the single one given.
  SmallVector<Type> resultTypes;
  resultTypes.reserve(numExpectedLoops);
  if (loopTypes.size() != 1)
    llvm::append_range(resultTypes, loopTypes);
  else
    resultTypes.append(numExpectedLoops, loopTypes[0]);

  // Delegate to the default builder, which sets up the proper operand segment
  // sizes attribute for the multiple variadic operands. In the absence of
  // this, horrible bugs ensue.
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  auto staticTileInterchangeAttr =
      builder.getDenseI64ArrayAttr(staticTileInterchange);
  build(builder, result, /*transformed=*/target.getType(),
        /*loops=*/resultTypes,
        /*target=*/target,
        /*tile_sizes=*/dynamicTileSizes,
        /*tile_interchange=*/dynamicTileInterchange,
        /*static_tile_sizes=*/staticTileSizesAttr,
        /*static_tile_interchange=*/staticTileInterchangeAttr,
        /*apply_cleanup=*/applyCleanup,
        /*use_forall=*/useForall);
}

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
Expand Down Expand Up @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
SmallVector<int64_t> tileSizes =
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
SmallVector<int64_t> tileInterchange =
extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
auto transformOp = cast<TransformOpInterface>(getOperation());

SmallVector<int64_t> tileSizes;
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
state, transformOp, getMixedTileSizes(), tileSizes);
if (!status.succeeded())
return status;
SmallVector<int64_t> tileInterchange;
status = reifyMixedParamAndHandleResults(
state, transformOp, getMixedTileInterchange(), tileInterchange);
if (!status.succeeded())
return status;

scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
bool useForall = getUseForall();
tilingOptions.setLoopType(useForall
? scf::SCFTilingOptions::LoopType::ForallOp
: scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
Expand All @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}

size_t numLoops =
useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
Expand All @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}

/// Verifies the statically known part of the tiling configuration:
///  - the static interchange is no longer than the static tile-size list,
///  - every non-dynamic interchange entry lies in [0, rank) and is unique,
///  - the number of loop results matches the expected loop count (one when
///    `use_forall` is set, otherwise the number of non-zero static tile-size
///    entries).
/// Entries marked dynamic in the static interchange are skipped here.
LogicalResult transform::FuseOp::verify() {
  ArrayRef<int64_t> sizes = getStaticTileSizes();
  ArrayRef<int64_t> permutation = getStaticTileInterchange();
  size_t iterspaceRank = sizes.size();

  // NOTE(review): the trailing operand of this diagnostic streams the dynamic
  // interchange operands (`getTileInterchange()`), not the static
  // `permutation` whose length is being checked; confirm that is intended.
  if (permutation.size() > iterspaceRank)
    return emitOpError()
           << "interchange length exceeds iteration space dimensions ("
           << iterspaceRank << "), found " << getTileInterchange();

  // Track which dimensions have already been named to detect duplicates.
  SmallVector<bool> seen(iterspaceRank, false);
  for (int64_t v : permutation) {
    if (!ShapedType::isDynamic(v)) {
      if (v < 0 || v >= static_cast<int64_t>(iterspaceRank))
        return emitOpError() << "expects interchange values to be in range [0, "
                             << iterspaceRank << "), found: " << v;
      if (seen[v])
        return emitOpError() << "found duplicate interchange value: " << v;
      seen[v] = true;
    }
  }

  // The first result is the transformed op; the remaining ones are loops.
  size_t numExpectedLoops =
      getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
  if (numExpectedLoops != getNumResults() - 1)
    return emitOpError() << "expects " << numExpectedLoops << " loop results";

  return success();
}

/// Returns the tile sizes as one list of OpFoldResults, assembled by
/// `getMixedValues` from the static attribute and the dynamic operands.
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
  auto staticSizes = getStaticTileSizes();
  auto dynamicSizes = getTileSizes();
  return getMixedValues(staticSizes, dynamicSizes, getContext());
}

/// Returns the loop interchange as one list of OpFoldResults, assembled by
/// `getMixedValues` from the static attribute and the dynamic operands.
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
  auto staticInterchange = getStaticTileInterchange();
  auto dynamicInterchange = getTileInterchange();
  return getMixedValues(staticInterchange, dynamicInterchange, getContext());
}

/// Populates `effects` with the memory effects of this transform op.
void transform::FuseOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// The target handle is registered as consumed.
consumesHandle(getTargetMutable(), effects);
// The dynamic tile-size and interchange operands are registered as read-only.
onlyReadsHandle(getTileSizesMutable(), effects);
onlyReadsHandle(getTileInterchangeMutable(), effects);
// Every op result is registered as a produced handle.
producesHandle(getOperation()->getOpResults(), effects);
// The op is also registered as modifying the payload.
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
Expand Down
42 changes: 28 additions & 14 deletions mlir/python/mlir/dialects/transform/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def __init__(
loop_types: Union[Type, Sequence[Type]],
target: Union[Operation, Value, OpView],
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
tile_sizes: Optional[MixedValues] = None,
tile_interchange: Optional[MixedValues] = None,
apply_cleanup: bool = False,
use_forall: bool = False,
loc=None,
ip=None,
):
Expand All @@ -157,9 +158,10 @@ def __init__(
self,
target: Union[Operation, Value, OpView],
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
tile_sizes: Optional[MixedValues] = None,
tile_interchange: Optional[MixedValues] = None,
apply_cleanup: bool = False,
use_forall: bool = False,
loc=None,
ip=None,
):
Expand All @@ -170,17 +172,26 @@ def __init__(
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
tile_sizes: Optional[MixedValues] = None,
tile_interchange: Optional[MixedValues] = None,
apply_cleanup: bool = False,
use_forall: bool = False,
loc=None,
ip=None,
):
tile_sizes = tile_sizes if tile_sizes else []
tile_interchange = tile_interchange if tile_interchange else []
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
(
dynamic_tile_sizes,
static_tile_sizes,
_,
) = _dispatch_dynamic_index_list(tile_sizes)
(
dynamic_tile_interchange,
static_tile_interchange,
_,
) = _dispatch_dynamic_index_list(tile_interchange)
num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0)

if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
Expand All @@ -197,9 +208,12 @@ def __init__(
target.type,
loop_types,
target,
tile_sizes=tile_sizes,
tile_interchange=tile_interchange,
tile_sizes=dynamic_tile_sizes,
tile_interchange=dynamic_tile_interchange,
static_tile_sizes=static_tile_sizes,
static_tile_interchange=static_tile_interchange,
apply_cleanup=apply_cleanup,
use_forall=use_forall,
loc=loc,
ip=ip,
)
Expand Down
Loading