Skip to content

Commit 0adfedf

Browse files
authored
Merge branch 'main' into da-remove-split-iter
2 parents 947f6d4 + 9e2ca0d commit 0adfedf

File tree

10 files changed

+1641
-253
lines changed

10 files changed

+1641
-253
lines changed

.github/CODEOWNERS

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@
6060
/mlir/lib/Conversion/*ToROCDL @krzysz00 @kuhar
6161
/mlir/include/mlir/Dialect/LLVMIR/ROCDL* @krzysz00 @kuhar
6262

63+
# Arith dialect in MLIR.
64+
/mlir/include/mlir/Dialect/Arith @kuhar
65+
/mlir/lib/Dialect/Arith @kuhar
66+
/mlir/lib/Conversion/ArithTo* @kuhar
67+
6368
# XeGPU and XeVM dialects in MLIR.
6469
/mlir/include/mlir/Dialect/XeGPU @charithaintc @Jianhui-Li
6570
/mlir/lib/Dialect/XeGPU @charithaintc @Jianhui-Li

libc/include/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ add_header_macro(
200200
DEPENDS
201201
.llvm_libc_common_h
202202
.inttypes
203+
.llvm-libc-types.in_addr
204+
.llvm-libc-types.in_addr_t
203205
)
204206

205207
file(MAKE_DIRECTORY ${LIBC_INCLUDE_DIR}/netinet)

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
613613
getNumDynamicControlOperands() + getRank());
614614
}
615615

616+
BlockArgument getTiedBlockArgument(OpResult opResult) {
617+
assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
618+
return getBody()->getArgument(getRank() + opResult.getResultNumber());
619+
}
620+
616621
::mlir::Value getInductionVar(int64_t idx) {
617622
return getInductionVars()[idx];
618623
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
415415
/// tiled in a manner that is consistent for all the passed slices. Note that
416416
/// the method replaces the uses of `candidateSlices` with the tiled and fused
417417
/// consumer value but does not delete the slice operations.
418+
/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
419+
/// is to take the consumer operation, and find the slices to use for fusion
420+
/// by walking its operands to the `loops` and then into the body to get the
421+
/// slices used for fusion.
418422
struct SCFFuseConsumerOfSliceResult {
419423
// Original untiled consumer operands.
420424
SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
427431
ArrayRef<Operation *> candidateSlices,
428432
MutableArrayRef<LoopLikeOpInterface> loops);
429433

434+
/// Fuse the `consumer` operation into the loop nest provided by `loops`.
435+
/// The transformation looks for operands in the `consumer` that are defined
436+
/// by the outermost loop of the loop nest in `loops`. The nested loop is
437+
/// expected to have the structure of the loops generated through tiling.
438+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
439+
tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
440+
MutableArrayRef<LoopLikeOpInterface> loops);
441+
430442
/// Method to lower an `op` that implements the `TilingInterface` to
431443
/// loops/scalars.
432444
FailureOr<SmallVector<scf::ForOp>>

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 167 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
10921092
for (auto [outerLoop, innerLoop] :
10931093
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
10941094
// Again assume that all the outer loops are scf.for operations.
1095-
auto outerForLoop = cast<scf::ForOp>(outerLoop);
1095+
auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
10961096
auto outerLoopYield =
10971097
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
10981098
SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
21842184
return clonedSlices;
21852185
}
21862186

2187-
/// Implementation of fusing consumer of a single slice by computing the
2188-
/// slice of the consumer in-place for scf loop.
2189-
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2190-
mlir::scf::tileAndFuseConsumerOfSlices(
2191-
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2192-
MutableArrayRef<LoopLikeOpInterface> loops) {
2193-
if (candidateSlices.empty()) {
2194-
return rewriter.notifyMatchFailure(
2195-
rewriter.getUnknownLoc(),
2196-
"no candidate slices provided for consumer fusion");
2197-
}
2198-
// Return if `loops` is empty, return an error for now. Caller is expected
2199-
// to handle this case.
2200-
if (loops.empty()) {
2201-
return rewriter.notifyMatchFailure(
2202-
candidateSlices.front(),
2203-
"cannot call tile and fuse consumer with an empty loop nest");
2204-
}
2187+
static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2188+
tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
2189+
ArrayRef<OpOperand *> consumerOpOperands,
2190+
ArrayRef<Operation *> candidateSlices,
2191+
MutableArrayRef<LoopLikeOpInterface> loops) {
2192+
assert(!loops.empty() && "expected loops to be not empty");
22052193

2206-
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2207-
llvm::all_of(candidateSlices,
2208-
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2194+
// 1. Check assumption for loop with `reorderOperations` disabled.
2195+
if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
22092196
return rewriter.notifyMatchFailure(
2210-
candidateSlices.front(),
2211-
"candidates slices need to be all `tensor.extract_slice`s or "
2212-
"`tensor.parallel_insert_slice`s");
2213-
}
2214-
2215-
// 1. Get the consumer of scf.for for the result yielded by
2216-
// tensor.insert_slice/parallel_insert_slice.
2217-
SmallVector<OpOperand *> consumerOpOperands;
2218-
Operation *consumerOp;
2219-
{
2220-
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2221-
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2222-
if (failed(maybeConsumerOpOperand)) {
2223-
return rewriter.notifyMatchFailure(candidateSlices.front(),
2224-
"could not fetch consumer to fuse");
2225-
}
2226-
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2227-
consumerOp = consumerOpOperands.front()->getOwner();
2197+
loops.front(), "the first user of loop should not dominate any define "
2198+
"of consumer operand(s)");
22282199
}
22292200

22302201
LoopLikeOpInterface outerMostLoop = loops.front();
22312202
LoopLikeOpInterface innerMostLoop = loops.back();
22322203

2233-
// Check assumption for loop with `reorderOperations` disabled.
2234-
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2235-
return rewriter.notifyMatchFailure(
2236-
outerMostLoop, "the first user of loop should not dominate any define "
2237-
"of consumer operand(s)");
2238-
}
2239-
22402204
OpBuilder::InsertionGuard g(rewriter);
2241-
22422205
// 2. Check consumer is not using scf loop's output as init.
22432206
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
22442207
if (!dstOp)
@@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
24282391
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
24292392
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
24302393
});
2394+
auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
24312395
return scf::SCFFuseConsumerOfSliceResult{
2432-
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
2396+
std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
24332397
std::move(tileAndFuseResult->tiledOps)};
24342398
}
24352399

2400+
/// Implementation of fusing consumer of a single slice by computing the
2401+
/// slice of the consumer in-place for scf loop.
2402+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2403+
mlir::scf::tileAndFuseConsumerOfSlices(
2404+
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2405+
MutableArrayRef<LoopLikeOpInterface> loops) {
2406+
if (candidateSlices.empty()) {
2407+
return rewriter.notifyMatchFailure(
2408+
rewriter.getUnknownLoc(),
2409+
"no candidate slices provided for consumer fusion");
2410+
}
2411+
// Return if `loops` is empty, return an error for now. Caller is expected
2412+
// to handle this case.
2413+
if (loops.empty()) {
2414+
return rewriter.notifyMatchFailure(
2415+
candidateSlices.front(),
2416+
"cannot call tile and fuse consumer with an empty loop nest");
2417+
}
2418+
2419+
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2420+
llvm::all_of(candidateSlices,
2421+
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2422+
return rewriter.notifyMatchFailure(
2423+
candidateSlices.front(),
2424+
"candidates slices need to be all `tensor.extract_slice`s or "
2425+
"`tensor.parallel_insert_slice`s");
2426+
}
2427+
2428+
// Get the consumer of scf.for for the result yielded by
2429+
// tensor.insert_slice/parallel_insert_slice.
2430+
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2431+
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2432+
if (failed(maybeConsumerOpOperands)) {
2433+
return rewriter.notifyMatchFailure(candidateSlices.front(),
2434+
"could not fetch consumer to fuse");
2435+
}
2436+
Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
2437+
2438+
return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
2439+
maybeConsumerOpOperands.value(),
2440+
candidateSlices, loops);
2441+
}
2442+
2443+
/// For a given `result` of a `forallOp` return the
2444+
/// `tensor.parallel_insert_slice` op (or combining op) that is used to
2445+
/// construct this result.
2446+
static std::optional<Operation *>
2447+
getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
2448+
if (result.getOwner() != forallOp)
2449+
return std::nullopt;
2450+
BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
2451+
SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
2452+
// If the number of combining ops is not 1, then this is unexpected. Return
2453+
// nullopt.
2454+
if (combiningOps.size() != 1)
2455+
return std::nullopt;
2456+
return combiningOps[0];
2457+
}
2458+
2459+
/// For a given result of the loop nest that is a tiled loop nest, return the
2460+
/// insert slice-like op that is used for consumer fusion
2461+
static std::optional<Operation *>
2462+
getProducingInsertSliceLikeOp(OpResult result,
2463+
ArrayRef<LoopLikeOpInterface> loops) {
2464+
assert(!loops.empty() && "Expected loops to be not empty");
2465+
LoopLikeOpInterface outerMostLoop = loops.front();
2466+
if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
2467+
assert(loops.size() == 1 &&
2468+
"expected only a single loop when tiling using scf.forall");
2469+
return getProducingParallelInsertSlice(forallOp, result);
2470+
}
2471+
// Assume that the loop nest is a nested `scf.for` that is created through
2472+
// tiling and retrieve the `tensor.insert_slice` operation used to construct
2473+
// the result.
2474+
while (loops.size() != 1) {
2475+
LoopLikeOpInterface loop = loops.front();
2476+
if (result.getOwner() != loop)
2477+
return std::nullopt;
2478+
auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2479+
if (!forOp)
2480+
return std::nullopt;
2481+
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2482+
auto innerForResult =
2483+
dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
2484+
if (!innerForResult)
2485+
return std::nullopt;
2486+
result = innerForResult;
2487+
loops = loops.drop_front();
2488+
}
2489+
LoopLikeOpInterface loop = loops.front();
2490+
if (result.getOwner() != loop)
2491+
return std::nullopt;
2492+
auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2493+
if (!forOp)
2494+
return std::nullopt;
2495+
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2496+
auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
2497+
.getDefiningOp<tensor::InsertSliceOp>();
2498+
if (!insertSliceOp)
2499+
return std::nullopt;
2500+
return insertSliceOp;
2501+
}
2502+
2503+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2504+
mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
2505+
MutableArrayRef<LoopLikeOpInterface> loops) {
2506+
if (!isa<TilingInterface>(consumer)) {
2507+
return rewriter.notifyMatchFailure(
2508+
consumer, "unhandled consumer that does not implement TilingInterface");
2509+
}
2510+
2511+
// Return if `loops` is empty, return an error for now. Caller is expected
2512+
// to handle this case.
2513+
if (loops.empty()) {
2514+
return rewriter.notifyMatchFailure(
2515+
consumer, "cannot call tile and fuse consumer with an empty loop nest");
2516+
}
2517+
2518+
LoopLikeOpInterface outermostLoop = loops.front();
2519+
2520+
// Collect the operands of the consumer that come from the outermost loop of
2521+
// the loop nest.
2522+
SmallVector<OpOperand *> consumerFusableOperands;
2523+
for (OpOperand &opOperand : consumer->getOpOperands()) {
2524+
if (opOperand.get().getDefiningOp() == outermostLoop) {
2525+
consumerFusableOperands.push_back(&opOperand);
2526+
}
2527+
}
2528+
2529+
// Nothing to fuse. Just return an empty set.
2530+
if (consumerFusableOperands.empty()) {
2531+
return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2532+
SmallVector<OpOperand *>{},
2533+
SmallVector<Operation *>{}};
2534+
}
2535+
2536+
// Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
2537+
// for fusion.
2538+
SmallVector<Operation *> candidateSlices;
2539+
candidateSlices.reserve(consumerFusableOperands.size());
2540+
for (OpOperand *opOperand : consumerFusableOperands) {
2541+
std::optional<Operation *> slice =
2542+
getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
2543+
if (!slice) {
2544+
return rewriter.notifyMatchFailure(
2545+
consumer,
2546+
"couldnt find producing insert-slice like operation for operand");
2547+
}
2548+
candidateSlices.push_back(slice.value());
2549+
}
2550+
return tileAndFuseConsumerOfSlicesImpl(
2551+
rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2552+
}
2553+
24362554
//===----------------------------------------------------------------------===//
24372555
// lowerToLoopsUsingSCFForOp implementation.
24382556
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ module {
170170
// Fuse the consumer operation into the tiled loop.
171171
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
172172
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
173-
transform.test.fuse_consumer %slice_op in (%forall_op)
173+
transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
174174
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
175175
transform.yield
176176
}
@@ -231,7 +231,7 @@ module {
231231
// Fuse the consumer operation into the tiled loop.
232232
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
233233
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
234-
// Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
234+
// Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
235235
// is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
236236
// to fuse" error.
237237
transform.yield

0 commit comments

Comments
 (0)