Skip to content

Commit

Permalink
[mlir][Linalg] Retire LinalgStrategyTileAndFusePass and filter-based …
Browse files Browse the repository at this point in the history
…pattern.

Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

In the process, also retire `tileConsumerAndFuseProducers` that is now replaced by `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`.

Context: https://discourse.llvm.org/t/psa-retire-tileandfuselinalgops-method/63850

While performing this replacement, a change in behavior surfaced: the older `tileConsumerAndFuseProducers` would automatically split the parallel
and non-parallel dimensions, perform a first level of tile-and-fuse on the parallel dimensions only, and then introduce a
second level of tiling only (without fusion) on the reduction dimensions. The newer `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`, on the other hand,
does not perform this breakdown. As a consequence, the transform specification is updated so that it produces the same output.

Additionally, replace some uses of `unsigned` with `int64_t` where this is possible without pulling in larger interface changes (those are left for a future PR).

Context: https://www.youtube.com/watch?v=Puio5dly9N8

Lastly, tests that performed tile, fuse, and distribute on tensors are retired: the generated IR, which mixed scf.for, tensors, and
distributed processor ids, was racy at best.

Differential Revision: https://reviews.llvm.org/D135559
  • Loading branch information
nicolasvasilache committed Oct 10, 2022
1 parent b494a56 commit 7915027
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 432 deletions.
7 changes: 0 additions & 7 deletions mlir/include/mlir/Dialect/Linalg/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
//===----------------------------------------------------------------------===//
/// Linalg strategy passes.
//===----------------------------------------------------------------------===//
/// Create a LinalgStrategyTileAndFusePass.
std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgStrategyTileAndFusePass(
StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
const linalg::LinalgTransformationFilter &filter =
linalg::LinalgTransformationFilter());

/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
StringRef opName = "",
Expand Down
12 changes: 0 additions & 12 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,6 @@ def LinalgDetensorize : Pass<"linalg-detensorize", ""> {
];
}

// Declarative definition of the `linalg-strategy-tile-and-fuse-pass` pass,
// declared with "func::FuncOp" as its operation type. Instances are built
// through `mlir::createLinalgStrategyTileAndFusePass()`.
def LinalgStrategyTileAndFusePass
: Pass<"linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based tiling and fusion.";
let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
// Both anchor options are strings that default to the empty string.
let options = [
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
"Which func op is the anchor to latch on.">,
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
"Which linalg op within the func is the anchor to latch on.">,
];
}

def LinalgStrategyTilePass
: Pass<"linalg-strategy-tile-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based linalg tiling.";
Expand Down
33 changes: 0 additions & 33 deletions mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,6 @@ struct Transformation {
LinalgTransformationFilter::FilterFunction filter = nullptr;
};

/// A single codegen-strategy step that schedules one run of
/// LinalgStrategyTileAndFusePass.
struct TileAndFuse : public Transformation {
  /// Records the name of the op to latch on (`name`) and the tiling and
  /// fusion `options`; the optional filter function `f` is handed to the
  /// `Transformation` base.
  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
              LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(std::move(f)), opName(name),
        options(std::move(options)) {}

  /// Creates the tile-and-fuse strategy pass from the stored op name and
  /// options together with the filter `m`, then appends it to `pm`.
  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter m) const override {
    auto tileAndFusePass =
        createLinalgStrategyTileAndFusePass(opName, options, m);
    pm.addPass(std::move(tileAndFusePass));
  }

private:
  /// Name of the op passed to the pass constructor.
  std::string opName;
  /// Tiling and fusion options passed to the pass constructor.
  linalg::LinalgTilingAndFusionOptions options;
};

/// Represent one application of LinalgStrategyTilePass.
struct Tile : public Transformation {
Tile(StringRef name, linalg::LinalgTilingOptions options,
Expand All @@ -66,22 +49,6 @@ struct Tile : public Transformation {

/// Codegen strategy controls how a Linalg op is progressively lowered.
struct CodegenStrategy {
/// Appends a `TileAndFuse` step for the op `opName` to the transformation
/// sequence, configured with the tiling and fusion `options` and the optional
/// filter function `f`. Returns this strategy so that calls can be chained.
CodegenStrategy &
tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options,
            const LinalgTransformationFilter::FilterFunction &f = nullptr) {
  auto tileAndFuseStep = std::make_unique<TileAndFuse>(opName, options, f);
  transformationSequence.emplace_back(std::move(tileAndFuseStep));
  return *this;
}
/// Same as `tileAndFuse`, but only when `b` is true; when `b` is false the
/// strategy is returned unchanged.
CodegenStrategy &
tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
              LinalgTransformationFilter::FilterFunction f = nullptr) {
  // Nothing to append when the condition does not hold.
  if (!b)
    return *this;
  return tileAndFuse(opName, std::move(options), std::move(f));
}
/// Append a pattern to add a level of tiling for Op `opName` with tiling
/// `options`.
CodegenStrategy &
Expand Down
36 changes: 0 additions & 36 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,42 +787,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
}
};

///
/// Rewrite pattern that tiles a Linalg op on tensors and fuses its producers.
///
/// When a `filter` is specified it controls LinalgTransformMarker matching and
/// update. See `tileConsumerAndFuseProducers` for more details.
struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
  /// Constructor for a pattern that may match any LinalgOp.
  LinalgTileAndFuseTensorOpsPattern(
      MLIRContext *context, LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Constructor for a pattern restricted to the LinalgOp named `opName`.
  LinalgTileAndFuseTensorOpsPattern(
      StringRef opName, MLIRContext *context,
      LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Variant of `matchAndRewrite` that hands back the significant transformed
  /// pieces of IR (the tile loop nest) instead of a bare success/failure.
  FailureOr<TileLoopNest>
  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;

  /// Standard pattern entry point: runs `returningMatchAndRewrite` and keeps
  /// only its success/failure state.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    return failure(failed(returningMatchAndRewrite(op, rewriter)));
  }

private:
  /// Handles the special attribute manipulations (LinalgTransformMarker).
  LinalgTransformationFilter filter;
  /// Tile sizes and interchange applied when tiling the root operation.
  LinalgTilingAndFusionOptions options;
};

///
/// Linalg generalization pattern.
///
Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,14 +445,6 @@ class TileLoopNest {
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};

/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
/// the tiling.
FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
const Optional<LinalgLoopDistributionOptions> &tileDistribution);

//===----------------------------------------------------------------------===//
// Generic op region utilities
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ struct SCFTilingOptions {
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

/// The interchange vector to reorder the tiled loops.
SmallVector<unsigned> interchangeVector = {};
SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
SmallVector<int64_t> interchangeVector = {};
SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
interchangeVector = llvm::to_vector(interchange);
return *this;
}
Expand Down
87 changes: 51 additions & 36 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
Expand Down Expand Up @@ -99,45 +100,63 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult
applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
unsigned numLoops,
transform::TransformResults &transformResults,
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
static LogicalResult applyTilingToAll(
Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
transform::TransformResults &transformResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) {
SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].reserve(payloadOps.size());

for (Operation *target : payloadOps) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
if (!linalgOp)
return transformOp->emitError("only LinalgOps are supported");

FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
if (failed(tiled))
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
if (!tilingInterfaceOp)
return transformOp->emitError("only TilingInterface ops are supported");

SimpleRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
applyFn(tilingInterfaceOp);
if (failed(tiledResults))
return failure();

tiledLinalgOps.push_back(tiled->op);
if (tiled->loops.size() != numLoops)
// Not enough loops were generated. This usually means that the input size
// was smaller than the tiling size.
// TODO: LinalgTilingPattern should return failure().
return failure();
// Perform the replacement of tiled and fused values.
SmallVector<Operation *> opsToReplace{target};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
SmallVector<Value> replacements;
replacements.reserve(toReplace->getNumResults());
for (OpResult res : toReplace->getResults()) {
auto it = tiledResults->replacements.find(res);
if (it == tiledResults->replacements.end())
replacements.push_back(res);
else
replacements.push_back(it->getSecond());
}
rewriter.replaceOp(toReplace, replacements);
}

// Report back the relevant handles to the transform op.
tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
assert(tiledResults->loops.size() == numLoops &&
"Mismatched number of loops, tile and fuse transform should have "
"failed");
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].push_back(tiled->loops[i]);
loopOps[i].push_back(tiledResults->loops[i]);
}

transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

return success();
}

Expand Down Expand Up @@ -172,27 +191,23 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
DiagnosedSilenceableFailure
transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
LinalgTilingAndFusionOptions fusionOptions;
fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
SmallVector<int64_t> tileInterchange =
extractFromI64ArrayAttr(getTileInterchange());

scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
tilingOptions = tilingOptions.setTileSizes(tileSizes);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
LogicalResult result = applyTilingToAll(
getOperation(), state.getPayloadOps(getTarget()),
fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(linalgOp);
FailureOr<TileLoopNest> tileLoopNest =
pattern.returningMatchAndRewrite(linalgOp, rewriter);
if (failed(tileLoopNest))
return failure();

TiledLinalgOp tiledLinalgOp;
tiledLinalgOp.op = tileLoopNest->getRootOp();
tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
tileLoopNest->getLoopOps().end()};
return tiledLinalgOp;
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
return DiagnosedSilenceableFailure(result);
}
Expand Down
65 changes: 0 additions & 65 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,68 +414,3 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
}
return result;
}

//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//

/// Tiles `consumerOp` and greedily fuses its producers into the tile loop
/// nest, in two steps:
///   1. tile the leading parallel dimensions (in the order obtained after
///      applying `tileInterchange`) and fuse the producers of the output
///      operands of the tiled root op;
///   2. tile the remaining dimensions and fuse the producers of the input
///      operands of the tiled root op.
/// `tileDistribution` is forwarded unchanged to both tiling steps.
///
/// Asserts that `tileSizes` and `tileInterchange` have the same size and that
/// `tileInterchange` is a permutation. Returns failure if either tiling step
/// fails or if the resulting tile loop nest is empty; otherwise returns the
/// populated tile loop nest.
FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
assert(isPermutation(tileInterchange) &&
"expect tile interchange is a permutation");

// Create an empty tile loop nest.
TileLoopNest tileLoopNest(consumerOp);

// Search the number of outer parallel loops to separate them from possible
// inner reduction dimensions.
SmallVector<StringRef> iterTypes = consumerOp.getIteratorTypesArray();
applyPermutationToVector(iterTypes, tileInterchange);
// `split` is the position of the first non-parallel iterator in the permuted
// iterator types (or their count if all of them are parallel).
auto *it = find_if_not(iterTypes, isParallelIterator);
int64_t split = std::distance(iterTypes.begin(), it);

// Helper to fuse the producers greedily using a queue of fusion candidates.
// Candidates are taken from the back; when a fusion succeeds, the input and
// output operands of the fused producer become new candidates, and when it
// fails the candidate is simply dropped.
auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
while (!candidates.empty()) {
FailureOr<LinalgOp> fusedProducer =
tileLoopNest.fuseProducer(b, candidates.pop_back_val());
if (failed(fusedProducer))
continue;
candidates.append(fusedProducer->getInputAndOutputOperands());
}
};

// Perform tiling and fusion in two steps. We need to respect the loop
// interchange here; filter parallel dimensions based on their order *after*
// permutation but pass in the original configuration *before* permutation,
// given the tiling and interchange happen together.
// Entries left at zero in `outerTileSizes` / `innerTileSizes` belong to the
// other step.
SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
for (int64_t i : tileInterchange.take_front(split))
outerTileSizes[i] = tileSizes[i];
for (int64_t i : tileInterchange.drop_front(split))
innerTileSizes[i] = tileSizes[i];

// Tile the outer parallel loops and fuse the output operands.
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());

// Tile the remaining loops and fuse the input operands.
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());

// Exit if the tile loop nest is empty since all tile sizes are zero.
if (tileLoopNest.isEmpty())
return failure();

return tileLoopNest;
}

0 comments on commit 7915027

Please sign in to comment.