Skip to content

Commit

Permalink
Set encoding hints on dispatches.
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhanW committed Apr 11, 2024
1 parent 096d872 commit b34f5a9
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Common/CPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand All @@ -26,6 +27,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "cpu-materialize-encoding"

namespace mlir::iree_compiler {

using namespace IREE::LinalgExt;
Expand Down Expand Up @@ -428,7 +431,8 @@ struct CPUMaterializeUpperBoundTileSizePass

FailureOr<MaterializeEncodingInfo>
materializeEncodingForTarget(RankedTensorType tensorType,
ExecutableTargetAttr targetAttr) {
ExecutableTargetAttr targetAttr,
std::optional<int64_t> alignment = std::nullopt) {
IREE::LinalgExt::EncodingAttr encoding =
tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
Expand Down Expand Up @@ -469,19 +473,36 @@ materializeEncodingForTarget(RankedTensorType tensorType,
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK =
chooseMatmulTile(enumeratedTileMxNxK, matmulNarrowM, matmulNarrowN);

SmallVector<int64_t> tileSizes = {chosenTileMxNxK.M, chosenTileMxNxK.N,
chosenTileMxNxK.K};
if (alignment.has_value() && llvm::any_of(tileSizes, [&](int64_t v) {
return v > alignment.value();
})) {
LLVM_DEBUG(llvm::dbgs() << "failed, because some of selected tile sizes (";
llvm::interleaveComma(tileSizes, llvm::dbgs());
llvm::dbgs() << ") are greater than alignment("
<< alignment.value() << ")\n";);
return failure();
}

// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its role in the matmul.
auto rank = tensorType.getRank();
return getEncodingInfoForMatmul(encoding, rank, chosenTileMxNxK);
}

static MaterializeEncodingFn
getMaterializeEncodingFn(ExecutableTargetAttr targetAttr) {
return
[targetAttr](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
return materializeEncodingForTarget(tensorType, targetAttr);
};
getMaterializeEncodingFn(ExecutableTargetAttr targetAttr,
IREE::Codegen::EncodingRoundDimsToAttr alignment) {
std::optional<int64_t> alignmentValue;
if (alignment) {
alignmentValue = alignment.getValue();
}
return [targetAttr, alignmentValue](RankedTensorType tensorType)
-> FailureOr<MaterializeEncodingInfo> {
return materializeEncodingForTarget(tensorType, targetAttr, alignmentValue);
};
}

// Like getMaterializeEncodingFn, but iterating over an array of targets and
Expand Down Expand Up @@ -556,9 +577,13 @@ void CPUMaterializeEncodingPass::runOnOperation() {
MLIRContext *context = &getContext();
auto operation = getOperation();
RewritePatternSet materializeEncodingPattern(context);
if (!targetAttr)
if (!targetAttr) {
targetAttr = ExecutableTargetAttr::lookup(operation);
auto materializeEncodingFn = getMaterializeEncodingFn(targetAttr);
}
auto alignment =
operation->getAttrOfType<IREE::Codegen::EncodingRoundDimsToAttr>(
IREE::Codegen::EncodingRoundDimsToAttr::getMnemonic());
auto materializeEncodingFn = getMaterializeEncodingFn(targetAttr, alignment);
if (!materializeEncodingFn) {
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ iree_compiler_cc_library(
"PassDetail.h",
"Passes.cpp",
"RegionOpUtils.cpp",
"SetEncodingHintOnDispatches.cpp",
"SplitReduction.cpp",
"TensorPadToTensorInsertSlice.cpp",
"TopLevelSCFToCFG.cpp",
Expand All @@ -72,6 +73,7 @@ iree_compiler_cc_library(
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ iree_cc_library(
"PassDetail.h"
"Passes.cpp"
"RegionOpUtils.cpp"
"SetEncodingHintOnDispatches.cpp"
"SplitReduction.cpp"
"TensorPadToTensorInsertSlice.cpp"
"TopLevelSCFToCFG.cpp"
Expand Down Expand Up @@ -97,6 +98,7 @@ iree_cc_library(
MLIRTransformDialectTransforms
MLIRTransformUtils
MLIRTransforms
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
})
////////////////////////////////////////////////////////////////////////
.addPass(createCaptureDynamicDimsPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass(createSetEncodingHintOnDispatchesPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)

// Initialize any empty tensors to zero.
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ createInsertDebugTargetAtOrdinalPass(std::string breakDebugTarget = "",
// Exports all functions and dispatch executables as `() -> ()` benchmark funcs.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createExportBenchmarkFuncsPass();

// Sets the encoding hints on dispatches and replaces upper_bound_tile_size ops
// with constants.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createSetEncodingHintOnDispatchesPass();

//===----------------------------------------------------------------------===//
// Optimizations
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,13 @@ def OutlineDispatchRegions :
let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineDispatchRegionsPass()";
}

// NOTE: adjacent string literals are concatenated verbatim, so every
// continued fragment of `summary` must carry its own separating space.
def SetEncodingHintOnDispatches :
    InterfacePass<"iree-flow-set-encoding-hint-on-dispatches", "mlir::FunctionOpInterface"> {
  let summary = "Set the encoding hints on dispatches and replace "
                "upper_bound_tile_size with such constants.";
  let constructor = "mlir::iree_compiler::IREE::Flow::createSetEncodingHintOnDispatchesPass()";
}

def SplitReduction :
Pass<"iree-flow-split-reduction-ops", ""> {
let summary = "Split reduction dimension to increase parallelism.";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <queue>

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::iree_compiler::IREE::Flow {

namespace {

/// Pass that visits every LinalgExt::UpperBoundTileSizeOp under the operation
/// it runs on; the actual work is implemented in runOnOperation().
class SetEncodingHintOnDispatchesPass
    : public SetEncodingHintOnDispatchesBase<SetEncodingHintOnDispatchesPass> {
public:
  /// Entry point invoked by the pass manager.
  void runOnOperation() override;
};

} // namespace

/// Replaces the results of `upperBoundTileSizeOp` with a fixed index constant
/// and tags every dispatch reached from it with an encoding hint attribute.
///
/// The transitive users of the op are visited breadth-first. Only two kinds of
/// users are accepted:
///   * `affine.apply` ops whose map is exactly `(s1 ceildiv s0) * s0`; these
///     are recorded and their own users are visited in turn;
///   * `DispatchWorkgroupsOp`s; these are recorded and not traversed further.
/// Any other user (or an `affine.apply` with a different map) makes the
/// function return failure. All failure exits happen before the IR is touched.
///
/// On success:
///   * every use of each result of `upperBoundTileSizeOp` is replaced by an
///     index constant of value 16;
///   * every recorded `affine.apply` result is replaced by the last operand of
///     its map;
///   * every recorded dispatch gets an `EncodingRoundDimsToAttr` (value 16)
///     stored under that attribute's mnemonic.
/// The op itself is not erased here.
///
/// NOTE(review): the walk does not check which operand position the tile size
/// occupies in each `affine.apply` -- confirm callers only produce the
/// expected form.
static LogicalResult lowerUpperBoundTileSizeOpToConstantsAndAttachEncodingHints(
    RewriterBase &rewriter,
    LinalgExt::UpperBoundTileSizeOp upperBoundTileSizeOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(upperBoundTileSizeOp);

  // The only map tolerated on the use chain: round s1 up to a multiple of s0.
  AffineExpr divisorSym, valueSym;
  bindSymbols(rewriter.getContext(), divisorSym, valueSym);
  AffineMap expectedRoundMap =
      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
                     valueSym.ceilDiv(divisorSym) * divisorSym);

  // Breadth-first traversal over the transitive users of the op.
  SmallVector<affine::AffineApplyOp> roundingOps;
  SmallVector<DispatchWorkgroupsOp> dispatchesToTag;
  std::queue<Operation *> worklist;
  worklist.push(upperBoundTileSizeOp);
  while (!worklist.empty()) {
    Operation *current = worklist.front();
    worklist.pop();
    for (Operation *user : current->getUsers()) {
      // Dispatches are leaves of the traversal.
      if (auto dispatchUser = dyn_cast<DispatchWorkgroupsOp>(user)) {
        dispatchesToTag.push_back(dispatchUser);
        continue;
      }
      // Everything else must be a rounding affine.apply; bail out otherwise.
      auto roundingUser = dyn_cast<affine::AffineApplyOp>(user);
      if (!roundingUser || roundingUser.getMap() != expectedRoundMap) {
        return failure();
      }
      roundingOps.push_back(roundingUser);
      worklist.push(roundingUser);
    }
  }

  constexpr int64_t kAlignment = 16;

  // Swap every result of the op for the constant alignment.
  Value alignmentCst = rewriter.createOrFold<arith::ConstantIndexOp>(
      upperBoundTileSizeOp.getLoc(), kAlignment);
  for (Value tileSizeResult : upperBoundTileSizeOp.getResults()) {
    rewriter.replaceAllUsesWith(tileSizeResult, alignmentCst);
  }

  // Forward the last map operand (the one bound to s1) in place of each
  // rounding op's result.
  for (affine::AffineApplyOp roundingOp : roundingOps) {
    Value unrounded = roundingOp.getMapOperands().back();
    rewriter.replaceAllUsesWith(roundingOp.getResult(), unrounded);
  }

  // Record the alignment on the dispatches so later stages can read it back.
  auto roundDimsToAttr = Codegen::EncodingRoundDimsToAttr::get(
      rewriter.getContext(), kAlignment);
  for (DispatchWorkgroupsOp dispatchOp : dispatchesToTag) {
    dispatchOp->setAttr(Codegen::EncodingRoundDimsToAttr::getMnemonic(),
                        roundDimsToAttr);
  }

  return success();
}
void SetEncodingHintOnDispatchesPass::runOnOperation() {
MLIRContext *ctx = &getContext();
Operation *funcOp = getOperation();
IRRewriter rewriter(ctx);
auto res =
funcOp->walk([&](LinalgExt::UpperBoundTileSizeOp op) -> WalkResult {
if (failed(lowerUpperBoundTileSizeOpToConstantsAndAttachEncodingHints(
rewriter, op))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) {
return signalPassFailure();
}
}

/// Factory for the pass; returns a freshly constructed
/// SetEncodingHintOnDispatchesPass owned by the caller.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createSetEncodingHintOnDispatchesPass() {
  std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> pass =
      std::make_unique<SetEncodingHintOnDispatchesPass>();
  return pass;
}

} // namespace mlir::iree_compiler::IREE::Flow
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
workgroupSizeVals.resize(3, targetBuilder.getIndexAttr(1));
workgroupSize = targetBuilder.getArrayAttr(workgroupSizeVals);
}
break;
break;
}
}

Expand Down Expand Up @@ -436,6 +436,13 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp,
auto variantFuncOp = cloneFuncWithInterface(sourceFuncOp, resourceMap,
variantLayoutAttr);
targetFuncOps[sourceFuncOp][variantOp] = variantFuncOp;

if (auto attr =
exportOp->getAttrOfType<Codegen::EncodingRoundDimsToAttr>(
Codegen::EncodingRoundDimsToAttr::getMnemonic())) {
variantFuncOp->setAttr(
Codegen::EncodingRoundDimsToAttr::getMnemonic(), attr);
}
}
}

Expand Down

0 comments on commit b34f5a9

Please sign in to comment.