diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 860384f954536..ce45f847ccaed 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using namespace mlir; @@ -37,40 +38,101 @@ static bool overrideBuffer(Operation *op, Value buffer) { return copyOp.getTarget() == buffer; } -/// Replace the uses of `oldOp` with the given `val` and for subview uses +/// Replace the uses of `oldOp` with the given `val` and for view-like uses /// propagate the type change. Changing the memref type may require propagating -/// it through subview ops so we cannot just do a replaceAllUse but need to -/// propagate the type change and erase old subview ops. -static void replaceUsesAndPropagateType(RewriterBase &rewriter, - Operation *oldOp, Value val) { +/// it through view-like ops (subview, expand_shape, collapse_shape, cast) so +/// we need to propagate the type change and erase old view ops. +/// +/// Only view-like ops whose result type can be recomputed from the new source +/// type and existing op attributes are handled here. Other ops fall back to +/// operand replacement without type propagation. +static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter, + Operation *oldOp, Value val) { + SmallVector opsToErase; // Iterate with early_inc to erase current user inside the loop. for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { Operation *user = use.getOwner(); - if (auto subviewUse = dyn_cast(user)) { - // `subview(old_op)` is replaced by a new `subview(val)`. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(subviewUse); - MemRefType newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), cast(val.getType()), - subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), - subviewUse.getStaticStrides()); - Value newSubview = memref::SubViewOp::create( - rewriter, subviewUse->getLoc(), newType, val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); - - // Ouch recursion ... is this really necessary? - replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); - - // Safe to erase. - rewriter.eraseOp(subviewUse); - continue; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(user); + MemRefType srcType = cast(val.getType()); + + // Try to create a new view-like op with updated result type. + // Each view-like op has its own method to compute the result type. + bool typeInferenceFailed = false; + Value replacement = + llvm::TypeSwitch(user) + .Case([&](memref::SubViewOp subview) -> Value { + MemRefType newType = + memref::SubViewOp::inferRankReducedResultType( + subview.getType().getShape(), srcType, + subview.getStaticOffsets(), subview.getStaticSizes(), + subview.getStaticStrides()); + return memref::SubViewOp::create( + rewriter, subview->getLoc(), newType, val, + subview.getMixedOffsets(), subview.getMixedSizes(), + subview.getMixedStrides()); + }) + .Case([&](memref::ExpandShapeOp expand) -> Value { + FailureOr newType = + memref::ExpandShapeOp::computeExpandedType( + srcType, expand.getResultType().getShape(), + expand.getReassociationIndices()); + if (failed(newType)) { + typeInferenceFailed = true; + return Value(); + } + return memref::ExpandShapeOp::create( + rewriter, expand->getLoc(), *newType, val, + expand.getReassociationIndices(), + expand.getMixedOutputShape()); + }) + .Case([&](memref::CollapseShapeOp collapse) -> Value { + FailureOr newType = + memref::CollapseShapeOp::computeCollapsedType( + srcType, collapse.getReassociationIndices()); + if (failed(newType)) { + typeInferenceFailed = true; + return Value(); + } + return memref::CollapseShapeOp::create( + rewriter, collapse->getLoc(), *newType, val, + collapse.getReassociationIndices()); + }) + .Case([&](memref::CastOp cast) -> Value { + if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) { + typeInferenceFailed = true; + return Value(); + } + return memref::CastOp::create(rewriter, cast->getLoc(), + cast.getType(), val); + }) + .Default([&](Operation *) -> Value { return Value(); }); + + if (typeInferenceFailed) { + user->emitOpError( + "failed to compute view-like result type after multi-buffering"); + return failure(); + } + + if (replacement) { + // Recursively propagate through view-like ops and mark old op for + // erasure. + if (failed(replaceUsesAndPropagateType(rewriter, user, replacement))) + return failure(); + opsToErase.push_back(user); + } else { + // Not a view-like op: just replace operand. + rewriter.startOpModification(user); + use.set(val); + rewriter.finalizeOpModification(user); } - // Non-subview: replace with new value. - rewriter.startOpModification(user); - use.set(val); - rewriter.finalizeOpModification(user); } + + for (Operation *op : opsToErase) { + rewriter.eraseOp(op); + } + + return success(); } // Transformation to do multi-buffering/array expansion to remove dependencies @@ -216,7 +278,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, } // 6. RAUW with the particular slice, taking modular rotation into account. - replaceUsesAndPropagateType(rewriter, allocOp, subview); + if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview))) + return failure(); // 7. Finally, erase the old allocOp. rewriter.eraseOp(allocOp); diff --git a/mlir/test/Dialect/MemRef/multibuffer.mlir b/mlir/test/Dialect/MemRef/multibuffer.mlir index 4ab7d993e6fd1..b004ebfa1abd0 100644 --- a/mlir/test/Dialect/MemRef/multibuffer.mlir +++ b/mlir/test/Dialect/MemRef/multibuffer.mlir @@ -104,3 +104,123 @@ func.func @multi_buffer_negative(%a: memref<1024x1024xf32>) { return } +// ----- + +// Test that multi-buffering correctly propagates strided layout through expand_shape. + +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)> + +// CHECK-LABEL: func @multi_buffer_expand_shape +func.func @multi_buffer_expand_shape(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32> + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %{{.*}} + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]]) +// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>> + %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2] + : memref<4x128xf32> into memref<2x2x64x2xf32> +// CHECK: "some_use"(%[[EXPANDED]]) : (memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>) -> () + "some_use"(%expanded) : (memref<2x2x64x2xf32>) -> () + } + return +} + +// ----- + +// Test that multi-buffering correctly propagates strided layout through collapse_shape. + +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)> + +// CHECK-LABEL: func @multi_buffer_collapse_shape +func.func @multi_buffer_collapse_shape(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32> + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %{{.*}} + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]]) +// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[SV]] {{\[\[}}0, 1]] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<512xf32, strided<[1], offset: ?>> + %collapsed = memref.collapse_shape %0 [[0, 1]] + : memref<4x128xf32> into memref<512xf32> +// CHECK: "some_use"(%[[COLLAPSED]]) : (memref<512xf32, strided<[1], offset: ?>>) -> () + "some_use"(%collapsed) : (memref<512xf32>) -> () + } + return +} + +// ----- + +// Test that multi-buffering correctly propagates through memref.cast. + +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)> + +// CHECK-LABEL: func @multi_buffer_cast +func.func @multi_buffer_cast(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32> + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %{{.*}} + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]]) +// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: %[[CAST:.*]] = memref.cast %[[SV]] : memref<4x128xf32, strided<[128, 1], offset: ?>> to memref + %casted = memref.cast %0 : memref<4x128xf32> to memref +// CHECK: "some_use"(%[[CAST]]) : (memref) -> () + "some_use"(%casted) : (memref) -> () + } + return +} + +// ----- + +// Test that multi-buffering correctly propagates through chained view-like ops. + +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)> + +// CHECK-LABEL: func @multi_buffer_chained_view_ops +func.func @multi_buffer_chained_view_ops(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32> + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %{{.*}} + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]]) +// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>> + %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2] + : memref<4x128xf32> into memref<2x2x64x2xf32> +// CHECK: %[[CAST:.*]] = memref.cast %[[EXPANDED]] : memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>> to memref + %casted = memref.cast %expanded : memref<2x2x64x2xf32> to memref +// CHECK: "some_use"(%[[CAST]]) : (memref) -> () + "some_use"(%casted) : (memref) -> () + } + return +}