Skip to content

Commit

Permalink
[mlir] Allow Tile transform op to take dynamic sizes
Browse files Browse the repository at this point in the history
Extend the definition of the Tile structured transform op to enable it
accepting handles to operations that produce tile sizes at runtime. This is
useful by itself and prepares for more advanced tiling strategies. Note that
the changes are relevant only to the transform dialect, the tiling
transformation itself already supports dynamic sizes.

Depends On D129216

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129217
  • Loading branch information
ftynse committed Jul 12, 2022
1 parent 7b69843 commit 4e4a4c0
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 67 deletions.
53 changes: 42 additions & 11 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Expand Up @@ -396,28 +396,59 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",

def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Indicates that the given `target` op should be tiled with the options
provided as attributes. This transform generates a loop nest with a smaller
("tiled") target operation in its body. Currently limited to LinalgOps.

`sizes` are the tile sizes. A tile size of `0` indicates that the
respective dimension should not be tiled. No loop will be generated for such
dimensions. If all tile sizes are `0`, this transform is effectively a
no-op.
Indicates that the given `target` op should be tiled with the given sizes.
This transform generates a loop nest with a smaller ("tiled") target
operation in its body. Currently limited to LinalgOps.

Tile sizes may be known at transformation time, in which case they are
expected to be provided in the `static_size` attribute, or not, in which
case the tile value must be computed by the payload IR and the handle to the
operation computing it must be provided through `dynamic_sizes`. When the
sizes are not known statically, the corresponding entry in the
`static_sizes` attribute must be set to `ShapedType::kDynamicSize`. Only
the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should
be as many handles as `ShapedType::kDynamicSize` values in the
`static_sizes` attribute. A static size of `0` indicates that the dimension
should not be tiled. No loop will be generated for such dimensions. If all
tile sizes are `0`, this transform is effectively a no-op.

This op returns handles to the tiled op (in the generated loop nest) and the
generated loops. The number of loops is the number of non-zero tile sizes.
generated loops. The number of loops is the number of tile sizes that are
statically known to be non-zero.

#### Return modes

On success, the resulting handles are associated with co-indexed lists of
tiled operations and loops around them.

This operation only supports Linalg ops and produces a silenceable failure
if the input contains any non-Linalg ops. The ops preceding it in the list
associated with the `target` handle will have been tiled.

This operation produces a silenceable failure if the `dynamic_sizes` handles
are associated with lists of payload operations of a size different than
that of the list associated with the `target` handle.

If the internal implementation of tiling for any of the operations fails,
produces a definite failure.
}];

let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
Variadic<PDL_Operation>:$dynamic_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Returns the list of tile sizes, which may be static (Attribute) or
/// dynamic (Value).
SmallVector<OpFoldResult> getMixedSizes();
}];
}

def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Expand Up @@ -210,6 +210,11 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
LinalgOp op, ValueRange operands,
ValueRange results);

/// Turns an OpFoldResult into a value, creating an index-typed constant if
/// necessary.
Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
OpFoldResult opFoldResult);

/// Creates an extract_slice/subview op for a single `valueToTile` with
/// `builder`. This new operation extracts a tile of `valueToTile`, starting
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
Expand Down
163 changes: 134 additions & 29 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
Expand Down Expand Up @@ -103,16 +104,10 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult
applyTilingToAll(Operation *transformOp, Value target,
ArrayRef<int64_t> tileSizes,
applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
unsigned numLoops,
transform::TransformResults &transformResults,
transform::TransformState &state,
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
// Number of loops: Number of tiles sizes that are not zero.
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
// All payload ops. These should all be LinalgOps for now.
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);

SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
for (unsigned int i = 0; i < numLoops; ++i)
Expand Down Expand Up @@ -178,8 +173,9 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
fusionOptions.tileInterchange = extractI64Array(getTileInterchange());

LogicalResult result = applyTilingToAll(
getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
getOperation(), state.getPayloadOps(getTarget()),
fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(linalgOp);
Expand All @@ -194,8 +190,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
tileLoopNest->getLoopOps().end()};
return tiledLinalgOp;
});
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
return DiagnosedSilenceableFailure(result);
}

ParseResult transform::FuseOp::parse(OpAsmParser &parser,
Expand Down Expand Up @@ -603,32 +598,141 @@ DiagnosedSilenceableFailure
transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());

ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
dynamicSizeProducers.reserve(getDynamicSizes().size());
for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
dynamicSizeProducers.push_back(
state.getPayloadOps(dynamicSizeProducerHandle));

if (dynamicSizeProducers.back().size() != targets.size()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "expected as many dynamic size-producing operations ("
<< dynamicSizeProducers.back().size() << ") as target ops ("
<< targets.size() << ")";
diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
return diag;
}

if (!tileSizes.empty())
tilingOptions.setTileSizes(tileSizes);
tilingOptions.setInterchange(extractUIntArray(getInterchange()));
LinalgTilingPattern pattern(getContext(), tilingOptions);
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
op->getResult(0).getType().isa<IndexType>())
continue;
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "expected sizes to be produced by ops "
"with a single index-type result";
diag.attachNote(op->getLoc()) << "size producer op";
diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
return diag;
}
}

LogicalResult result = applyTilingToAll(
getOperation(), getTarget(), tileSizes, transformResults, state,
[&](LinalgOp linalgOp) {
SimpleRewriter rewriter(linalgOp.getContext());
return pattern.returningMatchAndRewrite(linalgOp, rewriter);
});
return DiagnosedSilenceableFailure(result);
SmallVector<Operation *> tiled;
SmallVector<SmallVector<Operation *, 4>, 4> loops;
loops.resize(getLoops().size());
for (auto &en : llvm::enumerate(targets)) {
auto linalgOp = dyn_cast<LinalgOp>(en.value());
if (!linalgOp) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "only linalg ops are supported";
diag.attachNote(en.value()->getLoc()) << "target op";
return diag;
}

unsigned index = en.index();
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction(
[&, index](OpBuilder &b, Operation *) {
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt()));
} else {
sizes.push_back(
dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
}
}
return sizes;
});
}

tilingOptions.setInterchange(extractUIntArray(getInterchange()));
LinalgTilingPattern pattern(getContext(), tilingOptions);
SimpleRewriter rewriter(linalgOp.getContext());
FailureOr<TiledLinalgOp> tiledOp =
pattern.returningMatchAndRewrite(linalgOp, rewriter);
if (failed(tiledOp))
return DiagnosedSilenceableFailure::definiteFailure();

tiled.push_back(tiledOp->op);
for (const auto &en2 : llvm::enumerate(tiledOp->loops))
loops[en2.index()].push_back(en2.value());
}

transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
for (const auto &en : llvm::enumerate(loops))
transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());

return DiagnosedSilenceableFailure::success();
}

SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
Builder builder(getContext());
for (int64_t size : tileSizes) {
if (size == ShapedType::kDynamicSize) {
results.push_back(dynamic[dynamicPos++]);
} else {
results.push_back(builder.getIndexAttr(size));
}
}
return results;
}

ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseTileLikeOp(parser, result,
TileOp::getSizesAttrName(result.name).getValue());
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
ArrayAttr staticSizes;
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
return ParseResult::failure();

result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}

void TileOp::print(OpAsmPrinter &p) {
p << ' ';
p << getTarget();
p.printOptionalAttrDict((*this)->getAttrs());
p << ' ' << getTarget();
printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
getStaticSizes());
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
}

void transform::TileOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
onlyReadsHandle(getDynamicSizes(), effects);
producesHandle(getTiledLinalgOp(), effects);
producesHandle(getLoops(), effects);
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -678,6 +782,7 @@ class LinalgTransformDialectExtension
LinalgTransformDialectExtension> {
public:
LinalgTransformDialectExtension() {
declareDependentDialect<arith::ArithmeticDialect>();
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<scf::SCFDialect>();
declareDependentDialect<vector::VectorDialect>();
Expand Down
10 changes: 0 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/Split.cpp
Expand Up @@ -15,16 +15,6 @@
using namespace mlir;
using namespace mlir::linalg;

/// Turns an OpFoldResult into a value, creating an index-typed constant if
/// necessary.
static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
OpFoldResult opFoldResult) {
if (opFoldResult.is<Value>())
return opFoldResult.get<Value>();
auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
}

/// Extract the slices of `operands` supplied to the given operation `op` such
/// that they are sufficient to execute the op for the subset of its iteration
/// space defined by `splitIterationSpace`. The subset is a part of the original
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Expand Up @@ -993,6 +993,14 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
return tensorResults;
}

Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
OpFoldResult opFoldResult) {
if (auto value = opFoldResult.dyn_cast<Value>())
return value;
auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
}

SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
LinalgOp linalgOp,
ArrayRef<Value> valuesToTile,
Expand Down
28 changes: 25 additions & 3 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
Expand Up @@ -191,18 +191,40 @@ class TileOp:
def __init__(self,
target: Union[Operation, Value],
*,
sizes: OptionalIntList = None,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
Value]], ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
sizes_attr = _get_int_array_attr(sizes)
i64_type = IntegerType.get_signless(64)

if sizes is None:
sizes = []

static_sizes = []
dynamic_sizes = []
if isinstance(sizes, ArrayAttr):
sizes_attr = sizes
else:
for size in sizes:
if isinstance(size, int):
static_sizes.append(IntegerAttr.get(i64_type, size))
elif isinstance(size, IntegerAttr):
static_sizes.append(size)
else:
static_sizes.append(
IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = ArrayAttr.get(static_sizes)

num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
super().__init__(
pdl_operation_type, [pdl_operation_type] * num_loops,
_get_op_result_or_value(target),
sizes=sizes_attr,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
interchange=_get_int_array_attr(interchange) if interchange else None,
loc=loc,
ip=ip)
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
Expand Up @@ -23,7 +23,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
%1, %loops = transform.structured.tile %0 [10, 0, 0]
%2 = transform.structured.scalarize %1
}
}

0 comments on commit 4e4a4c0

Please sign in to comment.