[mlir] Transform op for multitile size generation
Introduce a structured transform op that emits IR computing the multi-tile
sizes with requested parameters (target size and divisor) for the given
structured op. The sizes may fold to arithmetic constant operations when the
shape is constant. These operations may then be used to call the existing
tiling transformation with a single non-zero dynamic size (i.e. perform
strip-mining) for each of the dimensions separately, thus achieving multi-size
tiling with optional loop interchange. A separate test exercises the entire
script.

Depends On D129217

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129287
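
As an illustration of the arithmetic described above, here is a minimal sketch in plain Python integers. It follows the formulas documented for `computeMultiTileSizes` in `Transforms.h` and the split-then-tile composition from the op description, but it is not the MLIR implementation. The names `MultiSizeSpec`, `compute_multi_tile_sizes` and `tile_ranges` are hypothetical, and the explicit guard for `d == 0` is an assumption added so the sketch fails cleanly instead of dividing by zero.

```python
from typing import List, NamedTuple, Tuple


class MultiSizeSpec(NamedTuple):
    """Plain-integer stand-in for MultiSizeSpecification (hypothetical name)."""
    low_tile_size: int
    high_tile_size: int
    low_trip_count: int
    high_trip_count: int


def compute_multi_tile_sizes(trip_count: int, target_size: int,
                             divisor: int = 1) -> MultiSizeSpec:
    """Mirror the documented formulas using integer floor division."""
    if trip_count <= 0 or target_size <= 0 or divisor <= 0:
        raise ValueError("expected strictly positive sizes and divisor")
    b = trip_count // divisor                    # trip count in units of divisor
    t = (target_size + divisor - 1) // divisor   # target size in units of divisor
    d = (b + t - 1) // t                         # total number of tiles
    if d == 0:
        # Assumption: reject dimensions smaller than the divisor up front.
        raise ValueError("could not compute multi-size tile shapes")
    s = (b // d) * divisor                       # low tile size
    v = b % d                                    # number of high tiles
    u = d - v                                    # number of low tiles
    spec = MultiSizeSpec(s, s + divisor, u, v)
    # Same coverage check that the emitted assertion performs at runtime.
    covered = (spec.low_tile_size * spec.low_trip_count +
               spec.high_tile_size * spec.high_trip_count)
    if covered != trip_count:
        raise ValueError("could not compute multi-size tile shapes")
    return spec


def tile_ranges(spec: MultiSizeSpec) -> Tuple[int, List[Tuple[int, int]]]:
    """Return the split point and the (offset, size) of every tile."""
    split_point = spec.low_tile_size * spec.low_trip_count
    low = [(i * spec.low_tile_size, spec.low_tile_size)
           for i in range(spec.low_trip_count)]
    high = [(split_point + i * spec.high_tile_size, spec.high_tile_size)
            for i in range(spec.high_trip_count)]
    return split_point, low + high


if __name__ == "__main__":
    spec = compute_multi_tile_sizes(54, target_size=12, divisor=2)
    print(spec, tile_ranges(spec))
```

For a dimension of size 54 with target size 12 and divisor 2, this yields tile sizes 10 and 12 with 3 and 2 tiles respectively and a split point of 30, matching the example in the op description. A dimension of size 15 with target size 8 and divisor 8 fails the coverage check.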
ftynse committed Jul 12, 2022
1 parent cc30972 commit 3963b4d
Showing 10 changed files with 488 additions and 0 deletions.
@@ -127,6 +127,71 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}

def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Emits the IR computing the tile sizes `s1` and `s2` such that:

- there exists a combination of `n` tiles of size `s1` and `m` tiles of
size `s2` that covers the entirety of the iteration space `dimension` of
the target structured op;
- `s1` and `s2` are less than or equal to `target_size`;
- `s1` and `s2` are divisible by `divisor`.

For example, for a dimension of size 54 with target size 12 and divisor 2,
this can emit the IR computing the tile size 10, used for 3 tiles, and 12,
used for 2 tiles, totaling 10*3 + 12*2 = 54. Note that when the divisor does
not divide the original dimension size, it is impossible to compute such
tile sizes. An assertion is emitted to guard against this in the dynamic
case.

Expects the target size and the divisor to be strictly positive. Folds the
IR as much as possible, normally obtaining constant sizes and numbers of
tiles for a statically known dimension.

This does *not* consume the target handle and produces three handles each
pointing to single-result index-typed operations (which may be arithmetic
constant operations) defining the two respective tile sizes and the product
of the first tile size with the number of tiles of that size (useful for
splitting the iteration space).

This operation composes with the regular tiling when applied per-dimension:

```mlir
%sz1, %sz2, %split = structured.multitile_sizes %target
{ target_size = 10, dimension = 1 }
%low, %high = structured.split %target after %split { dimension = 1 }
%tiled_low = structured.tile %low [0, %sz1]
%tiled_high = structured.tile %high [0, %sz2]
%common = merge_handles %tiled_low, %tiled_high

%sz3, %sz4, %split2 = structured.multitile_sizes %target
{ target_size = 42, dimension = 0 }
%sz3r, %sz4r, %splitr = replicate num(%common) %sz3, %sz4, %split2
structured.split %common after %splitr { dimension = 0 }
// ...
```
}];

let arguments = (ins PDL_Operation:$target,
I64Attr:$dimension,
I64Attr:$target_size,
DefaultValuedAttr<I64Attr, "1">:$divisor);
let results = (outs PDL_Operation:$low_size,
PDL_Operation:$high_size,
PDL_Operation:$split_point);
let assemblyFormat = "$target attr-dict";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVector<::mlir::Operation *> &results,
TransformState &state);
}];
}


def PadOp : Op<Transform_Dialect, "structured.pad",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
42 changes: 42 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -479,6 +479,48 @@ std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ValueRange allShapeSizes, ValueRange allTileSizes);

/// A description of a multi-size tiling comprising tile sizes and numbers of
/// tiles, expressed as Values which may or may not be constant. Multi-size
/// currently means two-size.
struct MultiSizeSpecification {
/// Tile sizes.
Value lowTileSize, highTileSize;
/// Number of tiles associated with each size.
Value lowTripCount, highTripCount;
};

/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such that
/// there exist numbers of tiles with these sizes that fully cover the given
/// iteration space `dimension` of the structured `op`.
///
/// The computation is as follows:
///
/// b = originalTripCount floordiv sizeDivisor
/// t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
/// d = (b + t - 1) floordiv t
/// s = (b floordiv d) * sizeDivisor
/// v = b % d
/// u = d - v
///
/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
///
/// s * u + (s + sizeDivisor) * v == original size,
/// where s mod sizeDivisor = 0.
///
/// Expects all values to be positive. In some cases with the target tile size
/// sufficiently close to the dimension shape and non-unit divisor, it is
/// impossible to compute such sizes. If `emitAssertions` is set, also emit the
/// assertion that size computation succeeded.
///
/// Returns the specification consisting of both tile values and the number of
/// tiles of each size.
FailureOr<MultiSizeSpecification>
computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
OpFoldResult targetSize, OpFoldResult divisor,
bool emitAssertions = true);

/// All indices returned by IndexOp should be invariant with respect to tiling.
/// Therefore, if an operation is tiled, we have to transform the indices
/// accordingly, i.e. offset them by the values of the corresponding induction
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -8,11 +8,14 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
MLIRLinalgTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithmeticDialect
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransforms
MLIRParser
MLIRPDLDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRVectorDialect
52 changes: 52 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -8,12 +8,14 @@

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -276,6 +278,55 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}

//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
OpBuilder builder(target.getContext());
builder.setInsertionPoint(target);
OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
OpFoldResult divisor = builder.getIndexAttr(getDivisor());
FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
builder, target, getDimension(), targetSize, divisor);
if (failed(spec)) {
return emitSilenceableError() << "could not generate tile size computation";
}

Operation *splitPoint =
builder
.createOrFold<arith::MulIOp>(target.getLoc(), spec->lowTileSize,
spec->lowTripCount)
.getDefiningOp();
Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
Operation *highTileSize = spec->highTileSize.getDefiningOp();
assert(lowTileSize && highTileSize && splitPoint &&
"tile sizes are not produced by operations");
results.reserve(results.size() + 3);
results.push_back(lowTileSize);
results.push_back(highTileSize);
results.push_back(splitPoint);
return DiagnosedSilenceableFailure::success();
}

void transform::MultiTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
transform::TransformMappingResource::get());
for (Value result : getResults()) {
effects.emplace_back(MemoryEffects::Allocate::get(), result,
transform::TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Write::get(), result,
transform::TransformMappingResource::get());
}

effects.emplace_back(MemoryEffects::Read::get(),
transform::PayloadIRResource::get());
effects.emplace_back(MemoryEffects::Write::get(),
transform::PayloadIRResource::get());
}

//===---------------------------------------------------------------------===//
// PadOp
//===---------------------------------------------------------------------===//
@@ -782,6 +833,7 @@ class LinalgTransformDialectExtension
LinalgTransformDialectExtension> {
public:
LinalgTransformDialectExtension() {
declareDependentDialect<AffineDialect>();
declareDependentDialect<arith::ArithmeticDialect>();
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<scf::SCFDialect>();
87 changes: 87 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -82,6 +83,92 @@ void mlir::linalg::transformIndexOps(
addTileLoopIvsToIndexOpResults(b, op, allIvs);
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
if (auto attr = value.dyn_cast<Attribute>()) {
assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;
}

Value zero = b.create<arith::ConstantIndexOp>(0);
Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
value.get<Value>(), zero);
b.create<cf::AssertOp>(
condition,
b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
unsigned dimension, OpFoldResult targetSize,
OpFoldResult divisor, bool emitAssertions) {
// Bail out on dimension overflow.
if (dimension >= op.getNumLoops())
return failure();

// The code below works only on values.
ImplicitLocOpBuilder b(op.getLoc(), builder);
if (emitAssertions) {
emitIsPositiveIndexAssertion(b, targetSize);
emitIsPositiveIndexAssertion(b, divisor);
}
Value targetSizeValue = materializeOpFoldResult(b, targetSize);
Value divisorValue = materializeOpFoldResult(b, divisor);

// Find the trip count of the iteration space dimension for which the tile
// sizes are computed.
// TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
// littering by useless constant materialization.
SmallVector<Value, 4> allShapes =
op.createFlatListOfOperandDims(b, b.getLoc());
AffineMap shapesToLoops = op.getShapesToLoopsMap();
SmallVector<Value, 4> loopRanges =
applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
Value tripCount = loopRanges[dimension];

// Compute the tile sizes and the respective numbers of tiles.
AffineExpr s0 = b.getAffineSymbolExpr(0);
AffineExpr s1 = b.getAffineSymbolExpr(1);
AffineExpr s2 = b.getAffineSymbolExpr(2);
auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
return makeComposedAffineApply(b, b.getLoc(), expr, values);
};
Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
Value v = apply(s0 % s1, {a, d});
Value u = apply(s0 - s1, {d, v});

MultiSizeSpecification spec;
spec.lowTileSize = s;
spec.highTileSize = apply(s0 + s1, {s, divisorValue});
spec.lowTripCount = u;
spec.highTripCount = v;

// If requested, emit the check that the tile sizes are computed correctly.
// For example, for an iteration dimension of size 15 with target size 8 and
// divisor 8, it is impossible to find two tile sizes both divisible by 8 that
// fully cover the original space dimension.
if (emitAssertions) {
AffineExpr s3 = builder.getAffineSymbolExpr(3);
Value coveredSize =
apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
spec.highTileSize, spec.highTripCount});
Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
coveredSize, tripCount);
b.create<cf::AssertOp>(
equals, builder.getStringAttr(
"could not compute dynamic multi-size tile shapes"));
}

return spec;
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
23 changes: 23 additions & 0 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -110,6 +110,29 @@ def __init__(self,
ip=ip)


class MultiTileSizesOp:
"""Specialization for MultiTileSizesOp class."""

def __init__(self,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
target_size: Union[int, IntegerAttr],
divisor: Optional[Union[int, IntegerAttr]] = None,
loc=None,
ip=None):
super().__init__(
pdl.OperationType.get(),
pdl.OperationType.get(),
pdl.OperationType.get(),
_get_op_result_or_value(target),
dimension=_get_int64_attr(dimension),
target_size=_get_int64_attr(target_size),
divisor=_get_int64_attr(divisor if divisor else 1),
loc=loc,
ip=ip)


class PadOp:
"""Specialization for PadOp class."""

