Skip to content

Commit 3cfdafc

Browse files
committed
yield fused producer if necessary
1 parent 9094fa6 commit 3cfdafc

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
24+
#include "mlir/Transforms/RegionUtils.h"
2425
#include "llvm/ADT/TypeSwitch.h"
2526
#include "llvm/Support/Debug.h"
2627
#include <optional>
@@ -255,6 +256,18 @@ SmallVector<LoopLikeOpInterface> mlir::scfX::getOuterNestLoopsWhile(
255256
return {nestLoops.rbegin(), nestLoops.rend()};
256257
}
257258

259+
/// A listener that watches which ops were erased.
260+
struct ErasedOpListener : public RewriterBase::Listener {
261+
private:
262+
/// Pointers to all erased operations and blocks.
263+
DenseSet<void *> erased;
264+
265+
public:
266+
ErasedOpListener() = default;
267+
void notifyOperationErased(Operation *op) override { erased.insert(op); }
268+
bool isErased(Operation *op) { return erased.count(op); }
269+
};
270+
258271
/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
259272
/// multi-level `extractSliceOp`. E.g.
260273
///
@@ -296,6 +309,51 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
296309
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
297310
if (!fuseProducerResult)
298311
return std::nullopt;
312+
313+
// Cache old listener.
314+
OpBuilder::Listener *oldListener = rewriter.getListener();
315+
// Set new listener.
316+
ErasedOpListener *newListener = new ErasedOpListener();
317+
rewriter.setListener(newListener);
318+
319+
auto producerOp =
320+
cast<TilingInterface>(fuseProducerResult->origProducer.getDefiningOp());
321+
unsigned resultNumber = fuseProducerResult->origProducer.getResultNumber();
322+
// cache candidate slice
323+
auto extractSliceOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
324+
SmallVector<OpFoldResult> offsets = extractSliceOp.getMixedOffsets(),
325+
sizes = extractSliceOp.getMixedSizes(),
326+
strides = extractSliceOp.getMixedStrides();
327+
(void)mlir::simplifyRegions(rewriter, {*producerOp->getParentRegion()});
328+
// If fused producer has multiple users.
329+
bool yieldReplacement = !newListener->isErased(producerOp);
330+
// Reset to old listener.
331+
rewriter.setListener(oldListener);
332+
// Delete new listener.
333+
delete newListener;
334+
335+
if (yieldReplacement) {
336+
// Manually clone new candidate slice.
337+
OpBuilder::InsertionGuard g(rewriter);
338+
rewriter.setInsertionPoint(fuseProducerResult->tiledOps[0]);
339+
auto clonedExtractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
340+
producerOp->getLoc(), producerOp->getResult(resultNumber), offsets,
341+
sizes, strides);
342+
// Yield replacement for fused producer in avoid of repeated computation.
343+
if (failed(scf::yieldReplacementForFusedProducer(
344+
rewriter, clonedExtractSliceOp, fuseProducerResult.value(),
345+
outerLoops)))
346+
return std::nullopt;
347+
348+
unsigned loopNumResults = outerLoops.front()->getNumResults(),
349+
producerNumResults = producerOp->getNumResults();
350+
// Replace other users of fused producer with new loop results.
351+
for (auto [index, result] : llvm::enumerate(producerOp->getResults())) {
352+
rewriter.replaceAllUsesWith(
353+
result, outerLoops.front()->getResult(loopNumResults -
354+
producerNumResults + index));
355+
}
356+
}
299357
}
300358
return fuseProducerResult;
301359
}

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,11 @@ module {
358358
// -----
359359

360360
module {
361-
// CHECK: func.func @fuse_generic_matmul(
362-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
363-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
364-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
365-
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
361+
/// CHECK-LABEL: @fuse_generic_matmul
362+
/// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
363+
/// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
364+
/// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
365+
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> {
366366
/// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
367367
%0 = tensor.empty() : tensor<2x2x16x16xf32>
368368
%pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
@@ -406,4 +406,33 @@ module {
406406
/// CHECK: return %[[FINAL_RESULT]]#1
407407
return %unpack : tensor<32x64xf32>
408408
}
409+
}
410+
411+
// -----
412+
413+
module {
414+
/// CHECK-LABEL: @yield_fused_producer
415+
func.func @yield_fused_producer(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) {
416+
/// CHECK: arith.constant
417+
%cst_0 = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32>
418+
/// CHECK-NEXT: tensor.empty
419+
%dest0 = tensor.empty() : tensor<16x32x32xf32>
420+
%0 = linalg.powf ins(%arg0, %cst_0 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%dest0 : tensor<16x32x32xf32>) -> tensor<16x32x32xf32>
421+
/// CHECK-NEXT: tensor.empty
422+
%dest1 = tensor.empty() : tensor<16x32xf32>
423+
/// CHECK-NEXT: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) to (16, 32) step (1, 16)
424+
/// CHECK-NEXT: tensor.extract_slice
425+
/// CHECK-NEXT: tensor.extract_slice
426+
/// CHECK-NEXT: tensor.extract_slice
427+
/// CHECK-NEXT: linalg.powf
428+
/// CHECK-NEXT: tensor.extract_slice
429+
/// CHECK-NEXT: linalg.reduce
430+
%1 = linalg.reduce { arith.addf } ins(%0 : tensor<16x32x32xf32>) outs(%dest1 : tensor<16x32xf32>) dimensions = [2]
431+
/// CHECK-NEXT: scf.forall.in_parallel
432+
/// CHECK-NEXT: tensor.parallel_insert_slice
433+
/// CHECK-NEXT: tensor.parallel_insert_slice
434+
/// CHECK-NEXT: }
435+
/// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0
436+
return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32>
437+
}
409438
}

0 commit comments

Comments
 (0)