Skip to content

Commit

Permalink
[SimplifyCFG] Prevent merging cbranch to cbranch if the branch probab…
Browse files Browse the repository at this point in the history
…ility from the first to second is too low. (#69375)

AMDGPU target has faced the situation which can be illustrated with the
following testcase:

define void @dont_merge_cbranches(i32 %V) {
  %divergent_cond = icmp ne i32 %V, 0
  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
  br i1 %uniform_cond, label %bb2, label %exit, !prof !0
bb2:
  br i1 %divergent_cond, label %bb3, label %exit
bb3:
  call void @bar( )
  br label %exit
exit:
  ret void
}
!0 = !{!"branch_weights", i32 1, i32 100000}

SimplifyCFG merges branches on %uniform_cond and %divergent_cond which is undesirable because the first branch to bb2 is taken extremely rare and the second branch is expensive. The merged branch becomes as expensive as the second.

This patch prevents such merging if the branch to the second branch is unlikely to happen.
  • Loading branch information
vpykhtin committed Nov 13, 2023
1 parent dde85f8 commit f054947
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
14 changes: 14 additions & 0 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4347,6 +4347,20 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
if (PBI->getSuccessor(PBIOp) == BB)
return false;

// If predecessor's branch probability to BB is too low don't merge branches.
SmallVector<uint32_t, 2> PredWeights;
if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
extractBranchWeights(*PBI, PredWeights) &&
(PredWeights[0] + PredWeights[1]) != 0) {

BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);

BranchProbability Likely = TTI.getPredictableBranchThreshold();
if (CommonDestProb >= Likely)
return false;
}

// Do not perform this transformation if it would require
// insertion of a large number of select instructions. For targets
// without predication/cmovs, this is a big pessimization.
Expand Down
84 changes: 84 additions & 0 deletions llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=simplifycfg -S | FileCheck %s

declare void @bar()
declare i1 @uniform_result(i1 %c)

define void @dont_merge_cbranches1(i32 %V) {
; CHECK-LABEL: @dont_merge_cbranches1(
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT: br i1 [[UNIFORM_COND]], label [[BB2:%.*]], label [[EXIT:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK: bb2:
; CHECK-NEXT: br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
; CHECK: bb3:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: ret void
;
%divergent_cond = icmp ne i32 %V, 0
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
br i1 %uniform_cond, label %bb2, label %exit, !prof !0
bb2:
br i1 %divergent_cond, label %bb3, label %exit
bb3:
call void @bar( )
br label %exit
exit:
ret void
}

define void @dont_merge_cbranches2(i32 %V) {
; CHECK-LABEL: @dont_merge_cbranches2(
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT: br i1 [[UNIFORM_COND]], label [[EXIT:%.*]], label [[BB2:%.*]], !prof [[PROF1:![0-9]+]]
; CHECK: bb2:
; CHECK-NEXT: br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
; CHECK: bb3:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: ret void
;
%divergent_cond = icmp ne i32 %V, 0
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
br i1 %uniform_cond, label %exit, label %bb2, !prof !1
bb2:
br i1 %divergent_cond, label %bb3, label %exit
bb3:
call void @bar( )
br label %exit
exit:
ret void
}

define void @merge_cbranches(i32 %V) {
; CHECK-LABEL: @merge_cbranches(
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
; CHECK-NEXT: [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
; CHECK-NEXT: [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
; CHECK-NEXT: br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF2:![0-9]+]]
; CHECK: bb3:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: ret void
;
%divergent_cond = icmp ne i32 %V, 0
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
br i1 %uniform_cond, label %exit, label %bb2, !prof !2
bb2:
br i1 %divergent_cond, label %bb3, label %exit
bb3:
call void @bar( )
br label %exit
exit:
ret void
}

!0 = !{!"branch_weights", i32 1, i32 1000}
!1 = !{!"branch_weights", i32 1000, i32 1}
!2 = !{!"branch_weights", i32 3, i32 2}

0 comments on commit f054947

Please sign in to comment.