diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index 27930bbc651bd..8bd060ae8f485 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -3556,6 +3556,11 @@ class SwitchInstProfUpdateWrapper { /// correspondent branch weight. LLVM_ABI SwitchInst::CaseIt removeCase(SwitchInst::CaseIt I); + /// Replace the default destination by given case. Delegate the call to + /// the underlying SwitchInst::setDefaultDest and remove correspondent branch + /// weight. + LLVM_ABI void replaceDefaultDest(SwitchInst::CaseIt I); + /// Delegate the call to the underlying SwitchInst::addCase() and set the /// specified branch weight for the added case. LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W); diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 9060a89bb15ff..cb87691922a10 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -4171,6 +4171,16 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) { return SI.removeCase(I); } +void SwitchInstProfUpdateWrapper::replaceDefaultDest(SwitchInst::CaseIt I) { + auto *DestBlock = I->getCaseSuccessor(); + if (Weights) { + auto Weight = getSuccessorWeight(I->getCaseIndex() + 1); + (*Weights)[0] = Weight.value(); + } + + SI.setDefaultDest(DestBlock); +} + void SwitchInstProfUpdateWrapper::addCase( ConstantInt *OnVal, BasicBlock *Dest, SwitchInstProfUpdateWrapper::CaseWeightOpt W) { diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index d831c2737e5f8..1948948d4f129 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -7540,6 +7540,82 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform the switch when the condition is umin with a constant. +/// In that case, the default branch can be replaced by the constant's branch. 
+/// This method also removes dead cases when the simplification cannot replace +/// the default branch. +/// +/// For example: +/// switch(umin(a, 3)) { +/// case 0: +/// case 1: +/// case 2: +/// case 3: +/// case 4: +/// // ... +/// default: +/// unreachable +/// } +/// +/// Transforms into: +/// +/// switch(a) { +/// case 0: +/// case 1: +/// case 2: +/// default: +/// // This is case 3 +/// } +static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) { + Value *A; + ConstantInt *Constant; + + if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant)))) + return false; + + SmallVector Updates; + SwitchInstProfUpdateWrapper SIW(*SI); + BasicBlock *BB = SIW->getParent(); + + // Dead cases are removed even when the simplification fails. + // A case is dead when its value is higher than the Constant. + SmallVector DeadCases; + for (auto Case : SI->cases()) + if (Case.getCaseValue()->getValue().ugt(Constant->getValue())) + DeadCases.push_back(Case.getCaseValue()); + + for (ConstantInt *DeadCaseVal : DeadCases) { + SwitchInst::CaseIt DeadCase = SIW->findCaseValue(DeadCaseVal); + BasicBlock *DeadCaseBB = DeadCase->getCaseSuccessor(); + DeadCaseBB->removePredecessor(SIW->getParent()); + SIW.removeCase(DeadCase); + Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); + } + + auto Case = SI->findCaseValue(Constant); + // If the case value is not found, `findCaseValue` returns the default case. + // In this scenario, since there is no explicit `case 3:`, the simplification + // fails. The simplification also fails when the switch’s default destination + // is reachable. 
+ if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { + if (DTU) + DTU->applyUpdates(Updates); + return !Updates.empty(); + } + + BasicBlock *Unreachable = SI->getDefaultDest(); + SIW.replaceDefaultDest(Case); + SIW.removeCase(Case); + SIW->setCondition(A); + + Updates.push_back({DominatorTree::Delete, BB, Unreachable}); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + /// Tries to transform switch of powers of two to reduce switch range. /// For example, switch like: /// switch (C) { case 1: case 2: case 64: case 128: } @@ -7966,6 +8042,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (simplifyDuplicateSwitchArms(SI, DTU)) return requestResimplify(); + if (simplifySwitchWhenUMin(SI, DTU)) + return requestResimplify(); + return false; } diff --git a/llvm/test/Transforms/SimplifyCFG/switch-umin.ll b/llvm/test/Transforms/SimplifyCFG/switch-umin.ll new file mode 100644 index 0000000000000..44665365dc222 --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/switch-umin.ll @@ -0,0 +1,246 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 +; RUN: opt -S -passes=simplifycfg < %s | FileCheck %s + +declare void @a() +declare void @b() +declare void @c() +declare void @d() + +define void @switch_replace_default(i32 %x) { +; CHECK-LABEL: define void @switch_replace_default( +; CHECK-SAME: i32 [[X:%.*]]) { +; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3) +; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [ +; CHECK-NEXT: i32 0, label %[[CASE0:.*]] +; CHECK-NEXT: i32 1, label %[[CASE1:.*]] +; CHECK-NEXT: i32 2, label %[[CASE2:.*]] +; CHECK-NEXT: ], !prof [[PROF0:![0-9]+]] +; CHECK: [[COMMON_RET]]: +; CHECK-NEXT: ret void +; CHECK: [[CASE0]]: +; CHECK-NEXT: call void @a() +; CHECK-NEXT: br label %[[COMMON_RET]] +; CHECK: [[CASE1]]: +; CHECK-NEXT: call void @b() +; CHECK-NEXT: br label %[[COMMON_RET]] +; CHECK: [[CASE2]]: +; CHECK-NEXT: 
; The default destination is unreachable and an explicit `case 3` exists, so
; the switch is expected to test %x directly and to use the old `case 3`
; successor as its default. `case 4` is above the umin bound of 3 and is
; expected to be dropped from the switch.
define void @switch_replace_default_and_remove_dead_cases(i32 %x) {
; CHECK-LABEL: define void @switch_replace_default_and_remove_dead_cases(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
  %min = call i32 @llvm.umin.i32(i32 %x, i32 3)
  switch i32 %min, label %unreachable [
  i32 4, label %case4
  i32 1, label %case1
  i32 2, label %case2
  i32 3, label %case3
  ]

case4:
  call void @a()
  ret void

case1:
  call void @b()
  ret void

case2:
  call void @c()
  ret void

case3:
  ret void

unreachable:
  unreachable
}
; Negative test: the second umin operand is %y rather than a constant, so the
; switch is expected to keep %min as its condition, keep all four cases and
; keep its unreachable default.
define void @do_not_switch_replace_default(i32 %x, i32 %y) {
; CHECK-LABEL: define void @do_not_switch_replace_default(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: switch i32 [[MIN]], label %[[UNREACHABLE:.*]] [
; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE0]]:
; CHECK-NEXT: call void @a()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[UNREACHABLE]]:
; CHECK-NEXT: unreachable
;
  %min = call i32 @llvm.umin.i32(i32 %x, i32 %y)
  switch i32 %min, label %unreachable [
  i32 0, label %case0
  i32 1, label %case1
  i32 2, label %case2
  i32 3, label %case3
  ]

case0:
  call void @a()
  ret void

case1:
  call void @b()
  ret void

case2:
  call void @c()
  ret void

case3:
  ret void

unreachable:
  unreachable
}
; The default destination (%case0) is reachable, so the switch is expected to
; keep %min as its condition and keep its default. `case 4` is above the umin
; bound of 3 and is still expected to be dropped from the switch.
define void @do_not_replace_switch_default_but_remove_dead_cases(i32 %x) {
; CHECK-LABEL: define void @do_not_replace_switch_default_but_remove_dead_cases(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[MIN]], label %[[CASE0:.*]] [
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE0]]:
; CHECK-NEXT: call void @a()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
  %min = call i32 @llvm.umin.i32(i32 %x, i32 3)
  switch i32 %min, label %case0 [ ; default is reachable, so it is not replaced; only the dead `case 4` is removed
  i32 0, label %case0
  i32 1, label %case1
  i32 2, label %case2
  i32 3, label %case3
  i32 4, label %case4
  ]

case0:
  call void @a()
  ret void

case1:
  call void @b()
  ret void

case2:
  call void @c()
  ret void

case3:
  ret void

case4:
  call void @d()
  ret void

}


; Branch weights attached to the switch in @switch_replace_default, listed
; default first and then one entry per case. In the expected output the weight
; of the old `case 3` (5) sits in the default slot.
!0 = !{!"branch_weights", i32 1, i32 2, i32 3, i32 99, i32 5}
;.
; CHECK: [[PROF0]] = !{!"branch_weights", i32 5, i32 2, i32 3, i32 99}
;.