Skip to content

Commit 0f75ff9

Browse files
committed
yield fused producer if necessary
1 parent 43177a3 commit 0f75ff9

File tree

2 files changed

+98
-5
lines changed

2 files changed

+98
-5
lines changed

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
24+
#include "mlir/Transforms/RegionUtils.h"
2425
#include "llvm/ADT/TypeSwitch.h"
2526
#include "llvm/Support/Debug.h"
2627
#include <optional>
@@ -255,6 +256,18 @@ SmallVector<LoopLikeOpInterface> mlir::scfX::getOuterNestLoopsWhile(
255256
return {nestLoops.rbegin(), nestLoops.rend()};
256257
}
257258

259+
/// A listener that watches which ops were erased.
260+
struct ErasedOpListener : public RewriterBase::Listener {
261+
private:
262+
/// Pointers to all erased operations and blocks.
263+
DenseSet<void *> erased;
264+
265+
public:
266+
ErasedOpListener() = default;
267+
void notifyOperationErased(Operation *op) override { erased.insert(op); }
268+
bool isErased(Operation *op) { return erased.count(op); }
269+
};
270+
258271
/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
259272
/// multi-level `extractSliceOp`. E.g.
260273
///
@@ -296,6 +309,57 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
296309
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
297310
if (!fuseProducerResult)
298311
return std::nullopt;
312+
313+
// Cache old listener.
314+
OpBuilder::Listener *oldListener = rewriter.getListener();
315+
// Set new listener.
316+
ErasedOpListener *newListener = new ErasedOpListener();
317+
rewriter.setListener(newListener);
318+
319+
auto producerOp =
320+
cast<TilingInterface>(fuseProducerResult->origProducer.getDefiningOp());
321+
unsigned resultNumber = fuseProducerResult->origProducer.getResultNumber();
322+
// cache candidate slice
323+
auto extractSliceOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
324+
SmallVector<OpFoldResult> offsets = extractSliceOp.getMixedOffsets(),
325+
sizes = extractSliceOp.getMixedSizes(),
326+
strides = extractSliceOp.getMixedStrides();
327+
// Explicitly execute DCE.
328+
(void)mlir::simplifyRegions(rewriter, {*producerOp->getParentRegion()});
329+
// If fused producer has multiple users.
330+
bool yieldReplacement = !newListener->isErased(producerOp);
331+
// Reset to old listener.
332+
rewriter.setListener(oldListener);
333+
// Delete new listener.
334+
delete newListener;
335+
336+
if (yieldReplacement) {
337+
OpBuilder::InsertionGuard g(rewriter);
338+
// Set insertPoint right before tiled op.
339+
rewriter.setInsertionPoint(fuseProducerResult->tiledOps[0]);
340+
// Manually clone new candidate slice.
341+
auto clonedExtractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
342+
producerOp->getLoc(), producerOp->getResult(resultNumber), offsets,
343+
sizes, strides);
344+
// Yield replacement for fused producer in avoid of repeated computation.
345+
if (failed(scf::yieldReplacementForFusedProducer(
346+
rewriter, clonedExtractSliceOp, fuseProducerResult.value(),
347+
outerLoops)))
348+
return std::nullopt;
349+
// Erase cloned candidate slice.
350+
rewriter.eraseOp(clonedExtractSliceOp);
351+
352+
unsigned loopNumResults = outerLoops.front()->getNumResults(),
353+
producerNumResults = producerOp->getNumResults();
354+
// Replace other users of fused producer with new loop results.
355+
for (auto &&[index, result] : llvm::enumerate(producerOp->getResults())) {
356+
rewriter.replaceAllUsesWith(
357+
result, outerLoops.front()->getResult(loopNumResults -
358+
producerNumResults + index));
359+
}
360+
// Erase fused producer op.
361+
rewriter.eraseOp(producerOp);
362+
}
299363
}
300364
return fuseProducerResult;
301365
}

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,11 @@ module {
381381
// -----
382382

383383
module {
384-
// CHECK: func.func @fuse_generic_matmul(
385-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
386-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
387-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
388-
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
384+
/// CHECK-LABEL: @fuse_generic_matmul
385+
/// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
386+
/// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
387+
/// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
388+
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> {
389389
/// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
390390
%0 = tensor.empty() : tensor<2x2x16x16xf32>
391391
%pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
@@ -429,4 +429,33 @@ module {
429429
/// CHECK: return %[[FINAL_RESULT]]#1
430430
return %unpack : tensor<32x64xf32>
431431
}
432+
}
433+
434+
// -----
435+
436+
module {
437+
/// CHECK-LABEL: @yield_fused_producer
438+
func.func @yield_fused_producer(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) {
439+
/// CHECK: arith.constant
440+
%cst_0 = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32>
441+
/// CHECK-NEXT: tensor.empty
442+
%dest0 = tensor.empty() : tensor<16x32x32xf32>
443+
%0 = linalg.powf ins(%arg0, %cst_0 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%dest0 : tensor<16x32x32xf32>) -> tensor<16x32x32xf32>
444+
/// CHECK-NEXT: tensor.empty
445+
%dest1 = tensor.empty() : tensor<16x32xf32>
446+
/// CHECK-NEXT: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) to (16, 32) step (1, 16)
447+
/// CHECK-NEXT: tensor.extract_slice
448+
/// CHECK-NEXT: tensor.extract_slice
449+
/// CHECK-NEXT: tensor.extract_slice
450+
/// CHECK-NEXT: linalg.powf
451+
/// CHECK-NEXT: tensor.extract_slice
452+
/// CHECK-NEXT: linalg.reduce
453+
%1 = linalg.reduce { arith.addf } ins(%0 : tensor<16x32x32xf32>) outs(%dest1 : tensor<16x32xf32>) dimensions = [2]
454+
/// CHECK-NEXT: scf.forall.in_parallel
455+
/// CHECK-NEXT: tensor.parallel_insert_slice
456+
/// CHECK-NEXT: tensor.parallel_insert_slice
457+
/// CHECK-NEXT: }
458+
/// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0
459+
return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32>
460+
}
432461
}

0 commit comments

Comments
 (0)