Skip to content

Commit

Permalink
[mlir][Transforms] teach CSE about recursive memory effects
Browse files Browse the repository at this point in the history
Add support for reasoning about operations with recursive memory effects
to CSE. The recursive effects are gathered by a helper function. I
decided to allow returning duplicates from the helper function because
there's no benefit to spending the computation time to remove them in
the existing use case.

Differential Revision: https://reviews.llvm.org/D156805
  • Loading branch information
tblah committed Aug 10, 2023
1 parent e6d5dcf commit dea33c8
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 13 deletions.
11 changes: 11 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,17 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
/// std::nullopt indicates that an operation did not have a memory effect
/// interface (and did not have RecursiveMemoryEffects either), so no result
/// could be obtained. An empty vector indicates that there were no memory
/// effects found (but every operation implemented the memory effect interface
/// or has RecursiveMemoryEffects). If the vector contains multiple effects,
/// these effects may be duplicates; no deduplication is performed.
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
getEffectsRecursively(Operation *rootOp);

/// Returns true if the given operation is speculatable, i.e. has no undefined
/// behavior or other side effects.
///
Expand Down
33 changes: 33 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,39 @@ bool mlir::isMemoryEffectFree(Operation *op) {
return true;
}

// Gathers the memory effects of `rootOp` and, for every visited operation
// carrying the HasRecursiveMemoryEffects trait, the effects of the operations
// nested in its regions. Duplicate effects are not filtered out of the
// returned vector. Yields std::nullopt as soon as a visited operation neither
// implements MemoryEffectOpInterface nor has the recursive trait.
std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
mlir::getEffectsRecursively(Operation *rootOp) {
  SmallVector<MemoryEffects::EffectInstance> collected;

  // Explicit LIFO worklist seeded with the root operation.
  SmallVector<Operation *> worklist;
  worklist.push_back(rootOp);

  while (!worklist.empty()) {
    Operation *current = worklist.pop_back_val();
    const bool isRecursive =
        current->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
    auto memInterface = dyn_cast<MemoryEffectOpInterface>(current);

    // Without the interface and without the recursive trait there is nothing
    // to query: the effects of this operation are unknown.
    if (!memInterface && !isRecursive)
      return std::nullopt;

    // Record the operation's own declared effects, if it declares any.
    if (memInterface)
      memInterface.getEffects(collected);

    if (!isRecursive)
      continue;

    // Recursive effects: schedule every directly nested operation so that its
    // effects are folded into the result as well.
    for (Region &region : current->getRegions())
      for (Block &block : region)
        for (Operation &child : block)
          worklist.push_back(&child);
  }
  return collected;
}

bool mlir::isSpeculatable(Operation *op) {
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
if (!conditionallySpeculatable)
Expand Down
20 changes: 13 additions & 7 deletions mlir/lib/Transforms/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,23 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
}
}
while (nextOp && nextOp != toOp) {
auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
// TODO: Do we need to handle other effects generically?
// If the operation does not implement the MemoryEffectOpInterface we
// conservatively assumes it writes.
if ((nextOpMemEffects &&
nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
!nextOpMemEffects) {
std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
getEffectsRecursively(nextOp);
if (!effects) {
// TODO: Do we need to handle other effects generically?
// If the operation does not implement the MemoryEffectOpInterface we
// conservatively assume it writes.
result.first->second =
std::make_pair(nextOp, MemoryEffects::Write::get());
return true;
}

for (const MemoryEffects::EffectInstance &effect : *effects) {
if (isa<MemoryEffects::Write>(effect.getEffect())) {
result.first->second = {nextOp, MemoryEffects::Write::get()};
return true;
}
}
nextOp = nextOp->getNextNode();
}
result.first->second = std::make_pair(toOp, nullptr);
Expand Down
5 changes: 2 additions & 3 deletions mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK: scf.yield %[[VAL_145]]
// CHECK: }
// CHECK: %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
// CHECK: %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
// CHECK: %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
// CHECK: %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
// CHECK: %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
Expand Down Expand Up @@ -529,4 +528,4 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,7 @@
// CHECK: scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
// CHECK: }
// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
// CHECK: %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK: %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
// CHECK: memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK: memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK: scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
// CHECK: }
// CHECK: %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
Expand Down
61 changes: 61 additions & 0 deletions mlir/test/Transforms/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,64 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
// CHECK: }
// CHECK-NOT: scf.if
// CHECK: return %[[if]], %[[if]]

// Two identical memory reads are separated by an scf.if whose regions hold
// only constants and yields. The expected output returns the first read's
// value twice, i.e. the second read has been eliminated.
// CHECK-LABEL: @cse_recursive_effects_success
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
%0 = "test.op_with_memread"() : () -> (i32)

// An op with recursive memory effects whose regions contain no side effects.
%true = arith.constant true
// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
// CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
%1 = scf.if %true -> (i32) {
%c42 = arith.constant 42 : i32
scf.yield %c42 : i32
// CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
// CHECK-NEXT: scf.yield %[[C42]]
// CHECK-NEXT: } else {
} else {
%c24 = arith.constant 24 : i32
scf.yield %c24 : i32
// CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
// CHECK-NEXT: scf.yield %[[C24]]
// CHECK-NEXT: }
}

// %2 can be removed: nothing between the two reads writes to memory.
// CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
%2 = "test.op_with_memread"() : () -> (i32)
return %0, %2, %1 : i32, i32, i32
}

// Two identical memory reads are separated by an scf.if whose "then" region
// contains a memory write. The expected output keeps both reads.
// CHECK-LABEL: @cse_recursive_effects_failure
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
%0 = "test.op_with_memread"() : () -> (i32)

// An op with recursive memory effects whose region contains a write effect.
%true = arith.constant true
// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
// CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
%1 = scf.if %true -> (i32) {
"test.op_with_memwrite"() : () -> ()
// CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
%c42 = arith.constant 42 : i32
scf.yield %c42 : i32
// CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
// CHECK-NEXT: scf.yield %[[C42]]
// CHECK-NEXT: } else {
} else {
%c24 = arith.constant 24 : i32
scf.yield %c24 : i32
// CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
// CHECK-NEXT: scf.yield %[[C24]]
// CHECK-NEXT: }
}

// %2 cannot be removed because of the write nested inside the scf.if.
// CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
// CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
%2 = "test.op_with_memread"() : () -> (i32)
return %0, %2, %1 : i32, i32, i32
}

0 comments on commit dea33c8

Please sign in to comment.