diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 6200366cded29..e6adcde72ad66 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -140,6 +140,9 @@ static bool resultIsNotRead(Operation *op, std::vector &uses) { std::vector opUses; for (OpOperand &use : op->getUses()) { Operation *useOp = use.getOwner(); + // Use escaped the scope + if (useOp->mightHaveTrait()) + return false; if (isa(useOp) || (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 && !mlir::hasEffect(useOp)) || diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index 3b37c62fcb28e..6e130912c47e9 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -306,6 +306,23 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @dead_alloc_escaped +func.func @dead_alloc_escaped() -> memref<8x64xf32, 3> { + // CHECK: %{{.+}} = memref.alloc + %0 = memref.alloc() : memref<8x64xf32, 3> + return %0 : memref<8x64xf32, 3> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @dead_alloc func.func @dead_alloc() { // CHECK-NOT: %{{.+}} = memref.alloc