[mlir][memref] Look through view-like ops when detecting dead stores

resultIsNotRead() previously followed only memref.subview users when deciding
whether an allocation is never read. This patch follows any view-like op
instead, and adds tests in which the only use of an alloc is a
vector.transfer_write through a subview, an expand_shape, or a collapse_shape.

NOTE(review): the copy of this patch under review had its C++ template
arguments stripped and its lines collapsed. The arguments below
(Operation *, memref::DeallocOp, MemoryEffects::Read, memref::SubViewOp,
ViewLikeOpInterface) and the indentation are reconstructed; confirm them
against the tree before applying.

diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index e6adcde72ad66..e5486988947c6 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -133,7 +133,7 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 }
 
 /// Returns true if all the uses of op are not read/load.
-/// There can be SubviewOp users as long as all its users are also
+/// There can be view-like-op users as long as all its users are also
 /// StoreOp/transfer_write. If return true it also fills out the uses, if it
 /// returns false uses is unchanged.
 static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
@@ -146,7 +146,7 @@ static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
     if (isa<memref::DeallocOp>(useOp) ||
         (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
          !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
-        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
+        (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
       opUses.push_back(useOp);
       continue;
     }
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index 6e130912c47e9..7fc84d419f18d 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -395,6 +395,73 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @dead_store_through_subview
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_subview(%arg: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
+  %subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]}
+    : vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @dead_store_through_expand
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_expand(%arg: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
+  %expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32>
+  vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @dead_store_through_collapse
+// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
+// CHECK-NOT: memref.alloc()
+// CHECK-NOT: vector.transfer_write
+func.func @dead_store_through_collapse(%arg: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32>
+  %collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32>
+  vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @lower_to_llvm
 // CHECK-NOT: memref.alloc
 // CHECK: llvm.call @malloc