Skip to content

Commit

Permalink
Modify the tile root fuse producer consumer for reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Aug 9, 2024
1 parent a034fea commit 55f1611
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,48 +32,116 @@ namespace mlir::iree_compiler {

namespace {

/// Implementation of tile root and fuse producers and consumers greedily.
static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
RewriterBase &rewriter, TilingInterface root,
const scf::SCFTileAndFuseOptions &options) {

// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
if (!root->getNumResults()) {
return rewriter.notifyMatchFailure(
root, "invalid pattern for op with no results");
/// Gathers into `result` the operation `rootOp` plus, transitively, every
/// operation that defines an operand of an already-gathered operation and
/// implements `TilingInterface`. The backward walk stops at values that have
/// no defining operation and at producers that do not implement
/// `TilingInterface`.
static void collectTiledAndFusedOps(Operation *rootOp,
                                    llvm::SmallDenseSet<Operation *> &result) {
  SmallVector<Operation *> pending{rootOp};
  result.insert(rootOp);
  while (!pending.empty()) {
    Operation *consumer = pending.pop_back_val();
    for (OpOperand &use : consumer->getOpOperands()) {
      Operation *definingOp = use.get().getDefiningOp();
      // Nothing to follow when the value has no defining op or the defining
      // op is not tilable.
      if (!definingOp || !isa<TilingInterface>(definingOp))
        continue;
      // `insert` reports whether the op was newly added; only unseen
      // producers are queued, so each op is visited at most once.
      if (result.insert(definingOp).second)
        pending.push_back(definingOp);
    }
  }
}

// 1. Tile root op and Fuse Producers.
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);
/// Tile the root operation and fuse the producers of the root operation.
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
/// operands. Returns the tiled operation to be used for fusing consumers.
FailureOr<Operation *>
tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {
mlir::DominanceInfo dominanceInfo(rootOp);
llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
collectTiledAndFusedOps(rootOp, tiledAndFusedOps);

llvm::DenseSet<Operation *> yieldReplacementsFor;
for (auto op : tiledAndFusedOps) {
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
return dominanceInfo.properlyDominates(rootOp, user);
})) {
yieldReplacementsFor.insert(op);
}
}

SmallVector<OpFoldResult> tileSizes =
getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
rootOp);

// Pad the tile sizes with zero.
auto zero = rewriter.getIndexAttr(0);
int64_t numLoops = rootOp.getLoopIteratorTypes().size();
if (tileSizes.size() > numLoops) {
return failure();
}
while (tileSizes.size() < numLoops) {
tileSizes.push_back(zero);
}

scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);

scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand) {
Operation *owner = originalProducer.getOwner();
bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
// Do not fuse destination operands.
bool shouldFuse =
!(onlyFuseProducerInputOperands && isDestinationOperand);
return std::make_tuple(shouldFuse, yieldProducerReplacement);
};
tileAndFuseOptions.setFusionControlFn(controlFn);

FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp,
tileAndFuseOptions);
if (failed(tiledResults)) {
return rewriter.notifyMatchFailure(
root, "failed to tile root and fuse producers");
return failure();
}

// 2. Replace the producers with the tiled version.
SmallVector<Operation *> opsToReplace = {root};
// Perform the replacement of tiled and fused values.
SmallVector<Operation *> opsToReplace{rootOp};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res)) {
rewriter.replaceAllUsesWith(res, replacement);
Operation *replacementOp = replacement.getDefiningOp();
rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
Operation *user = use.getOwner();
return dominanceInfo.properlyDominates(replacementOp, user);
});
}

if (toReplace->use_empty()) {
rewriter.eraseOp(toReplace);
}
}

// 3. Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
// `tensor.insert_slice` operations, whose outputs are the operands of the
// untiled operation. Create a worklist of these `tensor.insert_slice`
// operations. If the consumers of the source of the `tensor.insert_slices`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
return tiledResults->tiledAndFusedOps.front();
}

static LogicalResult fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {

// Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
// `tensor.insert_slice` operations, whose outputs are the operands of the
// untiled operation. Create a worklist of these `tensor.insert_slice`
// operations. If the consumers of the source of the `tensor.insert_slices`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
std::queue<tensor::InsertSliceOp> &candidates) {
for (auto *userOp : fusedOp->getResults().getUsers()) {
Expand All @@ -86,7 +154,7 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
// Collect the candidate slices which can be potential consumers that can be
// fused.
std::queue<tensor::InsertSliceOp> candidates;
addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates);
addCandidateSlices(tiledOp, candidates);

while (!candidates.empty()) {

Expand Down Expand Up @@ -115,39 +183,42 @@ static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
return success();
}

static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
TilingInterface rootOp,
int64_t tilingLevel) {
/// Implementation of tile root and fuse producers and consumers greedily.
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input operands
/// and disable consumer fusion.
static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
TilingInterface rootOp,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {

SmallVector<OpFoldResult> tileSizes =
getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
rootOp);
int64_t numLoops = rootOp.getLoopIteratorTypes().size();
if (tileSizes.size() > numLoops)
return failure();
FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);

scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);
if (failed(tiledOp))
return failure();

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);
if (!onlyFuseProducerInputOperands)
return fuseConsumers(rewriter, tiledOp.value());

return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
tileAndFuseOptions);
return success();
}

/// This pass starts with the first TilingInterface operation that has
/// lowering_config attribute, tiles the op and fuses its consumers and
/// producers recursively. The `tilingLevel` must be specified. It picks the
/// `tilingLevel`-th list as tiling sizes from lowering_config.
/// producers recursively. If the `onlyFuseProducerInputOperands` is set, it
/// only fuses producer input operands and disables consumer fusion. The
/// `tilingLevel` must be specified. It picks the `tilingLevel`-th list as
/// tiling sizes from lowering_config.
struct LLVMCPUTileRootAndFuseProducerConsumer
: impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer> {
using impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
LLVMCPUTileRootAndFuseProducerConsumer>::
LLVMCPUTileRootAndFuseProducerConsumerPassBase;
explicit LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
explicit LLVMCPUTileRootAndFuseProducerConsumer(
int64_t tilingLevel, bool onlyFuseProducerInputOperands) {
this->tilingLevel = tilingLevel;
this->onlyFuseProducerInputOperands = onlyFuseProducerInputOperands;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, affine::AffineDialect,
Expand Down Expand Up @@ -186,9 +257,9 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {
return signalPassFailure();
}

if (failed(tileRootAndFuseProducerConsumer(
if (failed(tileRootAndFuse(
rewriter, dyn_cast<TilingInterface>(rootOp.value()),
tilingLevel.getValue()))) {
tilingLevel.getValue(), onlyFuseProducerInputOperands.getValue()))) {
funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
<< " failed\n";
return signalPassFailure();
Expand All @@ -212,6 +283,12 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel);
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
tilingLevel, /*onlyFuseProducerInputOperands=*/false);
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) {
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
tilingLevel, /*onlyFuseProducerInputOperands=*/true);
}
} // namespace mlir::iree_compiler
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ void addConvTileAndDecomposeExpertPassPipeline(
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());

funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(
tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(
createLLVMCPUTileAndFusePass(tilingConfig.getVectorInnerParallelLevel()));
funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ createLLVMCPUTileAndFusePass(int64_t tilingLevel);
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUVerifyVectorSizeLegalityPass(
int64_t maxAllowedNumberOfNativeVectors);
Expand Down
23 changes: 15 additions & 8 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def LLVMCPUTilePass :
];
}

def LLVMCPUTileAndFuse :
def LLVMCPUTileAndFusePass :
InterfacePass<"iree-llvmcpu-tile-and-fuse", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile and fuse TilingInterface operations.";
let options = [
Expand All @@ -140,13 +140,20 @@ def LLVMCPUTileAndFuse :
];
}

def LLVMCPUTileRootAndFuseProducerConsumerPass :
InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops.";
let options = [
Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
"Use default tiling level used to retrieve the configuration from lowering_config">
];
// Pass definition registered under the command-line name
// `iree-llvmcpu-tile-root-and-fuse-producer-consumer`, operating on
// `mlir::FunctionOpInterface`.
//
// Options:
//   - `tiling-level` (int64_t, default -1): tiling level used to retrieve
//     the configuration from lowering_config.
//   - `only-fuse-producer-input-operands` (bool, default false): restricts
//     fusion to the producer's input operands.
def LLVMCPUTileRootAndFuseProducerConsumerPass
: InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer",
"mlir::FunctionOpInterface"> {
let summary = "Pass to tile root op and fuse with producer and consumer "
"TilingInterface ops.";
let options =
[Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
"Use default tiling level used to retrieve the configuration "
"from lowering_config">,
Option<"onlyFuseProducerInputOperands",
"only-fuse-producer-input-operands", "bool",
/*default=*/"false",
"Specifies if we only want to fuse producer's input operands. "
"This is helpful to tile&fuse in case of reduction dimensions.">];
}

def LLVMCPUVerifyVectorSizeLegalityPass :
Expand Down
Loading

0 comments on commit 55f1611

Please sign in to comment.