11 changes: 7 additions & 4 deletions include/gc/Transforms/Passes.td
@@ -60,9 +60,9 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
let description = [{
The pass tries to fuse any MLIR operation which can be tiled. Moreover, this pass aims to support:
1. Matmul fusion with element-wise/reduce/broadcast ops.
2. Pre-op and post-op fusion.
3. Multi-consumer and multi-producer support.
4. Multiple level of nest loops and candidates.
2. Producer and consumer fusion.
3. Arbitrary topology, including residual patterns with multiple consumers.
4. Nested loop structures with multi-level candidates.
5. Flexible option to control the boundary of the iterative process.
6. Default tiling when no op is tiled before fusion.
7. Cost-model to determine whether to fuse or not.
@@ -74,8 +74,8 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
Option<"useCostModel", "use-cost-model", "bool",
/*default=*/"false",
"Decide if enable cost model to control iterative fusion.">,
Option<"defaultNDTile", "default-nd-tile", "unsigned",
/*default=*/"2",
"Set default amount of non-one dimensions in TileSize, such as 1, 2[default, a.k.a. 2D-Tile], etc.">,
ListOption<"defaultTileSize", "default-tile-size", "std::string",
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`">,
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`.">,
];
}
def DeepTileContractionOp
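Note on the new `default-tile-size` list option: each entry pairs an op name with explicit tile sizes. The option-parsing code is not part of this diff; the sketch below is a minimal, editorial illustration (the helper name `parseTileSizeEntry` and the plain-STL types are assumptions, not the pass's actual implementation) of how an entry such as `matmul:{32,32}` could be turned into the `OpTileSizeMap` used further down in IterativeTilingAndFusion.cpp.

// Editorial sketch, not part of the patch: parse one `default-tile-size`
// entry such as "matmul:{32,32}" into an (op name, sizes) pair.
#include <algorithm>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

static std::optional<std::pair<std::string, std::vector<int64_t>>>
parseTileSizeEntry(const std::string &entry) {
  size_t colon = entry.find(':');
  size_t lbrace = entry.find('{');
  size_t rbrace = entry.rfind('}');
  if (colon == std::string::npos || lbrace == std::string::npos ||
      rbrace == std::string::npos || lbrace > rbrace)
    return std::nullopt;
  std::string opName = entry.substr(0, colon); // e.g. "matmul"
  std::vector<int64_t> sizes;
  for (size_t pos = lbrace + 1; pos < rbrace;) {
    // Read one comma-separated integer inside the braces.
    size_t comma = std::min(entry.find(',', pos), rbrace);
    sizes.push_back(std::stoll(entry.substr(pos, comma - pos)));
    pos = comma + 1;
  }
  return std::make_pair(opName, sizes); // -> {"matmul", {32, 32}}
}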
72 changes: 45 additions & 27 deletions lib/gc/Transforms/IterativeTilingAndFusion.cpp
@@ -167,10 +167,11 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
tileSizesOnInnerDims =
llvm::to_vector(ArrayRef(tileSizes).take_back(innerTiles.size()));
} else {
// Upstream doesn't implement `getTiledImplementationFromOperandTile`
// interface of `packOp` so far. In another word, `packOp` could not be
// fused as consumer. As a result, just return failure currently.
return failure();
// Tile sizes come from an OpOperand, i.e. `packOp` is being fused as a
// consumer; gather the sizes on the packed inner dims by their position.
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
for (auto &pos : innerDimPos) {
tileSizesOnInnerDims.push_back(tileSizes[pos]);
}
}
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp)) {
innerTiles = unPackOp.getMixedTiles();
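The replaced branch above is the functional core of this PR: `tensor.pack` can now be fused as a consumer, because tile sizes arriving through an OpOperand are indexed in the source tensor's dimension order and can be gathered onto the packed inner dims via `getInnerDimsPos()`. A standalone worked example of that index-gather (editorial sketch with plain STL stand-ins for the MLIR types; the numbers match the `fuse_pack` test added below):

// Editorial sketch: for a pack with inner_dims_pos = [1, 2] and
// inner_tiles = [32, 32], fused as consumer of a [1, 32, 32] operand tile.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> tileSizes = {1, 32, 32}; // per source dimension
  std::vector<int64_t> innerDimPos = {1, 2};    // packOp.getInnerDimsPos()

  std::vector<int64_t> tileSizesOnInnerDims;
  for (int64_t pos : innerDimPos)
    tileSizesOnInnerDims.push_back(tileSizes[pos]);

  // The filter then compares these against the op's inner tiles: fusion is
  // accepted only when the operand tile covers whole 32x32 inner tiles.
  assert(tileSizesOnInnerDims == (std::vector<int64_t>{32, 32}));
}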
@@ -478,8 +479,8 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
return std::nullopt;

// c. Check whether the producer of the root source is tilable.
Operation *producer = realProducer->getDefiningOp<TilingInterface>();
if (!producer)
Operation *producerOp = realProducer->getDefiningOp<TilingInterface>();
if (!producerOp)
return std::nullopt;

CandidateDefOrUse defOrUse{*realProducer};
@@ -536,8 +537,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
for (auto useOperand : *realConsumers) {
// c. Check whether the consumer of the top-level result is tilable.
Operation *consumer = dyn_cast<TilingInterface>(useOperand->getOwner());
if (!consumer)
Operation *consumerOp = dyn_cast<TilingInterface>(useOperand->getOwner());
if (!consumerOp)
continue;

CandidateDefOrUse defOrUse{useOperand};
@@ -559,7 +560,7 @@
// f. Manually run CSE on the region that contains the original consumer op,
// to avoid conflicts when the subsequent `tileAndFuseConsumerOfSlice` walks
// the nested loops between the next candidate sliceOp and the tiled producer.
(void)mlir::simplifyRegions(rewriter, {*consumer->getParentRegion()});
(void)mlir::simplifyRegions(rewriter, {*consumerOp->getParentRegion()});
}
}
if (fusedResultList.empty())
@@ -647,11 +648,18 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {

using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;

struct defaultTileConfig {
// OpTy-to-TileSize mapping
OpTileSizeMap tsMap;
// Default number of non-one dimensions in the tile size (ND-tile)
unsigned ndTile;
};

/// Default tiling function, only effective for operations of a certain `OpTy`
static FailureOr<scf::SCFTilingResult>
defaultTilingOfType(RewriterBase &rewriter, Operation *op,
function_ref<bool(Operation *)> isaOpTy,
const OpTileSizeMap &tsMap) {
const defaultTileConfig &cfg) {
// a. Check <OpTy>
if (!isa<TilingInterface>(op) || !isaOpTy(op))
return failure();
@@ -672,18 +680,20 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
// Erase dialect name, such as Linalg or Tensor.
opName.erase(0, opName.find(".") + 1);

if (tsMap.count(opName)) {
SmallVector<int64_t> userDefaultTileSize = tsMap.find(opName)->second;
if (cfg.tsMap.count(opName)) {
SmallVector<int64_t> userDefaultTileSize = cfg.tsMap.find(opName)->second;
defaultTileSize =
getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
} else {
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
// Try tile sizes in order, from `32` down to `16`.
SmallVector<int64_t> tsOrder = {32, 16};
// Only 2D tile is expected.
int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
: 0;
// Record how many dims have been tiled, including fully tiled, i.e.
// tileSize == dimSize.
unsigned nonOneTileDims =
(isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
: 0;
// Reverse both the iterator types and the iteration domain, from inner to outer.
std::reverse(iteratorTypes.begin(), iteratorTypes.end());
std::reverse(iterationDomain.begin(), iterationDomain.end());
// All parallel iterators will be tiled by `32` or `16`. If specific sizes
// are needed, set the `defaultTileSize` option, e.g. `matmul:{64,64}`.
if (iterType == utils::IteratorType::parallel) {
Range curDomain = iterationDomain[en];
std::optional<int64_t> tripCount = mlir::constantTripCount(
curDomain.offset, curDomain.size, curDomain.stride);
if (tileDims >= 2 && en > 0) {
if (nonOneTileDims >= cfg.ndTile && en > 0) {
defaultTileSize[en] = rewriter.getIndexAttr(1);
continue;
} else if (tripCount) {
}
Range curDomain = iterationDomain[en];
if (std::optional<int64_t> tripCount = mlir::constantTripCount(
curDomain.offset, curDomain.size, curDomain.stride)) {
// Skip dummy tiling when the trip count is one.
if (tripCount == 1)
continue;
for (auto &ts : tsOrder) {
if (*tripCount % ts == 0 && *tripCount > ts) {
// If `tripCount` equals the tile size, do NOT explicitly tile the dim,
// to avoid introducing a non-zero offset.
if (*tripCount == ts)
break;
if (*tripCount % ts == 0) {
defaultTileSize[en] = rewriter.getIndexAttr(ts);
break;
}
}
}
tileDims++;
// Count the dim as tiled, falling back to fully tiled if no size was set.
nonOneTileDims++;
}
}
}
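To make the heuristic above concrete, here is a standalone model of the parallel-dimension tile selection. This is an editorial sketch under simplifying assumptions: every loop is parallel, trip counts are static, and `nonOneTileDims` starts at zero (the pass seeds it with the reduction-loop count for non-matmul linalg ops).

#include <cstdint>
#include <iostream>
#include <vector>

// 0 means "leave the dim untiled"; 1 means "tile by 1" (pinned outer dim).
static std::vector<int64_t>
pickDefaultTiles(const std::vector<int64_t> &tripCounts, unsigned ndTile) {
  std::vector<int64_t> tiles(tripCounts.size(), 0);
  const std::vector<int64_t> tsOrder = {32, 16}; // preferred tile sizes
  unsigned nonOneTileDims = 0;
  // Walk dimensions from innermost (en == 0) to outermost.
  for (size_t en = 0; en < tripCounts.size(); ++en) {
    size_t dim = tripCounts.size() - 1 - en;
    if (nonOneTileDims >= ndTile && en > 0) {
      tiles[dim] = 1; // budget exhausted: pin remaining outer dims to 1
      continue;
    }
    int64_t tc = tripCounts[dim];
    if (tc == 1)
      continue; // dummy dim, skip without spending budget
    for (int64_t ts : tsOrder) {
      if (tc == ts)
        break; // equal to the tile size: leave whole to avoid nonzero offsets
      if (tc % ts == 0) {
        tiles[dim] = ts;
        break;
      }
    }
    nonOneTileDims++; // counts fully tiled dims too
  }
  return tiles;
}

int main() {
  // e.g. the 1x32x4096 linalg.add from the tests with default-nd-tile=2:
  // 4096 -> tiled by 32; 32 == 32 -> left whole; the unit dim is pinned to 1.
  for (int64_t t : pickDefaultTiles({1, 32, 4096}, 2))
    std::cout << t << ' '; // prints: 1 0 32
  std::cout << '\n';
}

This matches the `fuse_pack` test below, whose forall runs to (1, 4096) with steps (1, 32).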
@@ -731,7 +749,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,

void iterativeTilingAndFusionUntilExhaustion(
RewriterBase &rewriter, func::FuncOp &f,
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
const CandidateSliceOptions &sliceOptions, const defaultTileConfig &cfg) {
// Collect untiled and tiled ops respectively
llvm::SetVector<Operation *> tiledOps, unTiledOps;

@@ -799,7 +817,7 @@ void iterativeTilingAndFusionUntilExhaustion(
for (auto &isaOpTy : priorityOpTypeOrder) {
for (auto &op : unTiledOps) {
FailureOr<scf::SCFTilingResult> tilingResult =
defaultTilingOfType(rewriter, op, isaOpTy, tsMap);
defaultTilingOfType(rewriter, op, isaOpTy, cfg);
if (succeeded(tilingResult)) {
tiledOps.insert(tilingResult->tiledOps[0]);
rewriter.replaceOp(op, tilingResult->replacements);
@@ -881,8 +899,8 @@ struct IterativeTilingAndFusion
// Get rewriter
IRRewriter rewriter(&ctx);
// Run iterative fusion
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions,
tsMap);
iterativeTilingAndFusionUntilExhaustion(
rewriter, func, sliceOptions, defaultTileConfig{tsMap, defaultNDTile});
}
};

33 changes: 28 additions & 5 deletions test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir
@@ -339,18 +339,41 @@ module {

module {
/// CHECK-LABEL: @not_fuse_pack
func.func @not_fuse_pack(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<1x32x4096xbf16>) -> tensor<1x1x128x32x32xbf16> {
%dest0 = tensor.empty() : tensor<1x32x4096xbf16>
func.func @not_fuse_pack(%arg0: tensor<1x35x4096xbf16>, %arg1: tensor<1x35x4096xbf16>) -> tensor<1x2x128x32x32xbf16> {
%dest0 = tensor.empty() : tensor<1x35x4096xbf16>
/// CHECK: scf.forall
/// CHECK: linalg.add
%add = linalg.add ins(%arg0, %arg1 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%dest0 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
%add = linalg.add ins(%arg0, %arg1 : tensor<1x35x4096xbf16>, tensor<1x35x4096xbf16>) outs(%dest0 : tensor<1x35x4096xbf16>) -> tensor<1x35x4096xbf16>
/// CHECK: }
%dest1 = tensor.empty() : tensor<1x1x128x32x32xbf16>
%dest1 = tensor.empty() : tensor<1x2x128x32x32xbf16>
%pad = arith.constant 0.000000e+00 : bf16
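/// Note: 35 is not a multiple of the inner tile 32, so the pack needs the
/// padding value above and cannot be tiled exactly; the fusion filter
/// (`exactTilingOnPackUnPackFilter`) therefore leaves it in its own loop.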
/// CHECK: %[[PACK_OUT:.*]] = scf.forall
/// CHECK: tensor.pack
%pack = tensor.pack %add outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x32x4096xbf16> -> tensor<1x1x128x32x32xbf16>
%pack = tensor.pack %add padding_value(%pad : bf16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x35x4096xbf16> -> tensor<1x2x128x32x32xbf16>
/// CHECK: }
/// CHECK: return %[[PACK_OUT]]
return %pack : tensor<1x2x128x32x32xbf16>
}
}

// -----

module {
/// CHECK-LABEL: @fuse_pack
func.func @fuse_pack(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<1x32x4096xbf16>) -> tensor<1x1x128x32x32xbf16> {
%dest0 = tensor.empty() : tensor<1x32x4096xbf16>
/// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) = (0, 0) to (1, 4096) step (1, 32)
/// CHECK: linalg.add
%add = linalg.add ins(%arg0, %arg1 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%dest0 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
%dest1 = tensor.empty() : tensor<1x1x128x32x32xbf16>
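/// Here every packed dim divides evenly by 32, so the operand-tile path added
/// in `exactTilingOnPackUnPackFilter` lets the pack fuse as a consumer into
/// the same forall loop as the add.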
/// CHECK-NEXT: affine.apply
/// CHECK-NEXT: tensor.extract_slice
/// CHECK-NEXT: tensor.pack
%pack = tensor.pack %add outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x32x4096xbf16> -> tensor<1x1x128x32x32xbf16>
/// CHECK: scf.forall.in_parallel
/// CHECK: tensor.parallel_insert_slice
/// CHECK: tensor.parallel_insert_slice
/// CHECK: return %[[FINAL_RESULT]]#1
return %pack : tensor<1x1x128x32x32xbf16>
}
}