Skip to content

Commit

Permalink
[MLIR] Don't sort operand of commutative ops when comparing two ops a…
Browse files Browse the repository at this point in the history
…s there is a correctness issue

This feature was introduced in `D123492`.

Sorting the operands of commutative operations by comparing their pointers is incorrect when checking equivalence of ops in separate regions (where the lhs and rhs operands are marked as equivalent but are not the same value).

It was also discussed in `D123492` and `D129480` that the correct solution would be to stable-sort the operands during canonicalization (possibly based on some numbering within the region), but until that lands, reverting this change will unblock us and other users.

An example of a pass that might not work properly because of this is `DuplicateFunctionEliminationPass`.

Reviewed By: mehdi_amini, jpienaar

Differential Revision: https://reviews.llvm.org/D154699
  • Loading branch information
tomnatan30 authored and jpienaar committed Jul 14, 2023
1 parent 71a2545 commit 2109587
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 91 deletions.
3 changes: 3 additions & 0 deletions flang/test/Fir/commute.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
// RUN: fir-opt %s | tco | FileCheck %s
//
// XFAIL:*
// See: https://github.com/llvm/llvm-project/issues/63784

// CHECK-LABEL: define i32 @f1(i32 %0, i32 %1)
func.func @f1(%a : i32, %b : i32) -> i32 {
Expand Down
49 changes: 3 additions & 46 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,19 +661,10 @@ llvm::hash_code OperationEquivalence::computeHash(
hash = llvm::hash_combine(hash, op->getLoc());

// - Operands
ValueRange operands = op->getOperands();
SmallVector<Value> operandStorage;
if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
operandStorage.append(operands.begin(), operands.end());
llvm::sort(operandStorage, [](Value a, Value b) -> bool {
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
operands = operandStorage;
}
for (Value operand : operands)
for (Value operand : op->getOperands())
hash = llvm::hash_combine(hash, hashOperands(operand));

// - Operands
// - Results
for (Value result : op->getResults())
hash = llvm::hash_combine(hash, hashResults(result));
return hash;
Expand Down Expand Up @@ -784,41 +775,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
return false;

// 2. Compare operands.
ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
auto sortValues = [](ValueRange values) {
SmallVector<Value> sortedValues = llvm::to_vector(values);
llvm::sort(sortedValues, [](Value a, Value b) {
auto aArg = llvm::dyn_cast<BlockArgument>(a);
auto bArg = llvm::dyn_cast<BlockArgument>(b);

// Case 1. Both `a` and `b` are `BlockArgument`s.
if (aArg && bArg) {
if (aArg.getParentBlock() == bArg.getParentBlock())
return aArg.getArgNumber() < bArg.getArgNumber();
return aArg.getParentBlock() < bArg.getParentBlock();
}

// Case 2. One of them is a `BlockArgument` and the other is not. Treat
// `BlockArgument` as lesser.
if (aArg && !bArg)
return true;
if (bArg && !aArg)
return false;

// Case 3. Both are values.
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
return sortedValues;
};
lhsOperandStorage = sortValues(lhsOperands);
lhsOperands = lhsOperandStorage;
rhsOperandStorage = sortValues(rhsOperands);
rhsOperands = rhsOperandStorage;
}

for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
Value curArg = std::get<0>(operandPair);
Value otherArg = std::get<1>(operandPair);
if (curArg == otherArg)
Expand Down
19 changes: 10 additions & 9 deletions mlir/test/Dialect/Func/duplicate-function-elimination.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ func.func @user(%arg0: f32, %arg1: f32) -> f32 {

// CHECK: @add_lr
// CHECK-NOT: @also_add_lr
// CHECK-NOT: @add_rl
// CHECK: @add_rl
// CHECK-NOT: @also_add_rl
// CHECK: @user
// CHECK-4: call @add_lr
// CHECK-2: call @add_lr
// CHECK-2: call @add_rl

// -----

Expand Down Expand Up @@ -108,7 +109,7 @@ func.func @user(%pred : i1, %arg0: f32, %arg1: f32) -> f32 {

// -----

func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
-> f32 {
%0 = scf.if %p0 -> f32 {
%1 = scf.if %p1 -> f32 {
Expand Down Expand Up @@ -188,7 +189,7 @@ func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
return %0 : f32
}

func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
%odd: f32) -> f32 {
%0 = scf.if %p0 -> f32 {
%1 = scf.if %p1 -> f32 {
Expand Down Expand Up @@ -268,7 +269,7 @@ func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
return %0 : f32
}

func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
%odd: f32) -> f32 {
%0 = scf.if %p0 -> f32 {
%1 = scf.if %p1 -> f32 {
Expand Down Expand Up @@ -348,13 +349,13 @@ func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
return %0 : f32
}

func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
-> (f32, f32, f32) {
%0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
%0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
: (i1, i1, i1, i1, f32, f32) -> f32
%1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
%1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
: (i1, i1, i1, i1, f32, f32) -> f32
%2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
%2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
: (i1, i1, i1, i1, f32, f32) -> f32
return %0, %1, %2 : f32, f32, f32
}
Expand Down
38 changes: 2 additions & 36 deletions mlir/test/Transforms/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -311,18 +311,6 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
return %2 : i32
}

/// This test checks that identical commutative operations are gracefully
/// handled by the CSE pass.
// CHECK-LABEL: func @check_cummutative_cse
func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
// CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
// CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
%3 = arith.muli %1, %2 : i32
return %3 : i32
}

// Check that an operation with a single region can CSE.
func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
Expand Down Expand Up @@ -425,31 +413,9 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]

// Account for commutative ops within regions during CSE.
func.func @cse_single_block_with_commutative_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = test.cse_of_single_block_op inputs(%a, %b) {
^bb0(%arg0 : f32, %arg1 : f32):
%1 = arith.addf %arg0, %arg1 : f32
%2 = arith.mulf %1, %c : f32
test.region_yield %2 : f32
} : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = test.cse_of_single_block_op inputs(%a, %b) {
^bb0(%arg0 : f32, %arg1 : f32):
%1 = arith.addf %arg1, %arg0 : f32
%2 = arith.mulf %c, %1 : f32
test.region_yield %2 : f32
} : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @cse_single_block_with_commutative_ops
// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]

func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
%false_2 = arith.constant false
%true_5 = arith.constant true
%false_2 = arith.constant false
%true_5 = arith.constant true
%9 = test.cse_of_single_block_op inputs(%arg2) {
^bb0(%out: i1):
%true_144 = arith.constant true
Expand Down

0 comments on commit 2109587

Please sign in to comment.