Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SimplifyCFG] Convert switch to cmp/select sequence #82795

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Acim-Maravic
Copy link
Contributor

@Acim-Maravic Acim-Maravic commented Feb 23, 2024

I have added the option -switch-to-select, which, when passed, will trigger this optimization. I am not sure where this code should reside and if I am on the right track. What I am currently attempting is to add support for simple switches, i.e., switches where all values in the PHI node are constants. Now, I am trying to add support for more complex switch cases, i.e., where the values in the PHI node do not have to be constants.

When I am getting something like this:

switch i32 %cond.freeze1, label %21 [
i32 0, label %15
i32 1, label %16
i32 2, label %17
i32 3, label %19
]

15: ; preds = %6
br label %21

16: ; preds = %6
br label %21

17: ; preds = %6
%18 = fneg reassoc nnan nsz arcp contract afn float %.i0
br label %21

19: ; preds = %6
%20 = fneg reassoc nnan nsz arcp contract afn float %.i0
br label %21

21: ; preds = %6, %15, %16, %17, %19, %.entry
%samplePos.1.i0 = phi float [ 0.000000e+00, %.entry ], [ 0.000000e+00, %6 ], [ %20, %19 ], [ 1.000000e+00, %17 ], [ %.i0, %16 ], [ 1.000000e+00, %15 ]
%samplePos.1.i1 = phi float [ 0.000000e+00, %.entry ], [ 0.000000e+00, %6 ], [ %.i1, %19 ], [ %.i1, %17 ], [ %.i1, %16 ], [ 1.000000e+00, %15 ]
%samplePos.1.i2 = phi float [ 0.000000e+00, %.entry ], [ 0.000000e+00, %6 ], [ -1.000000e+00, %19 ], [ %18, %17 ], [ 1.000000e+00, %16 ], [ 1.000000e+00, %15 ]

Next step is to gather all blocks that can be safely merged be merged into 21 in order for me to create selects and to allow non constant values in helpers.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Acim Maravic (Acim-Maravic)

Changes

Next step is to gather all blocks that can be safely merged be merged into 21 in order for me to create selects and to allow non constant values in helpers.


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82795.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h (+5)
  • (modified) llvm/lib/Passes/PassBuilder.cpp (+2)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp (+6)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+55-8)
  • (added) llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll (+312)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h b/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
index 8008fc6e8422d3..cb3ef663408153 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
@@ -30,6 +30,7 @@ struct SimplifyCFGOptions {
   bool SinkCommonInsts = false;
   bool SimplifyCondBranch = true;
   bool SpeculateBlocks = true;
+  bool ConvertSwitchToSelect = false;
 
   AssumptionCache *AC = nullptr;
 
@@ -46,6 +47,10 @@ struct SimplifyCFGOptions {
     ConvertSwitchRangeToICmp = B;
     return *this;
   }
+  SimplifyCFGOptions &convertSwitchToSelect(bool B) {
+    ConvertSwitchToSelect = B;
+    return *this;
+  }
   SimplifyCFGOptions &convertSwitchToLookupTable(bool B) {
     ConvertSwitchToLookupTable = B;
     return *this;
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c934ec42f6eb15..360dc4da3ca91d 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -819,6 +819,8 @@ Expected<SimplifyCFGOptions> parseSimplifyCFGOptions(StringRef Params) {
       Result.forwardSwitchCondToPhi(Enable);
     } else if (ParamName == "switch-range-to-icmp") {
       Result.convertSwitchRangeToICmp(Enable);
+    } else if (ParamName == "switch-to-select") {
+      Result.convertSwitchToSelect(Enable);
     } else if (ParamName == "switch-to-lookup") {
       Result.convertSwitchToLookupTable(Enable);
     } else if (ParamName == "keep-loops") {
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 44511800ccff8d..a3783afdf3be68 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -548,6 +548,7 @@ FUNCTION_PASS_WITH_PARAMS(
     "no-forward-switch-cond;forward-switch-cond;no-switch-range-to-icmp;"
     "switch-range-to-icmp;no-switch-to-lookup;switch-to-lookup;no-keep-loops;"
     "keep-loops;no-hoist-common-insts;hoist-common-insts;no-sink-common-insts;"
+    "switch-to-select;"
     "sink-common-insts;bonus-inst-threshold=N")
 FUNCTION_PASS_WITH_PARAMS(
     "speculative-execution", "SpeculativeExecutionPass",
diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
index 7017f6adf3a2bb..eafedef2af7d2e 100644
--- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -61,6 +61,10 @@ static cl::opt<bool> UserSwitchRangeToICmp(
     cl::desc(
         "Convert switches into an integer range comparison (default = false)"));
 
+static cl::opt<bool> UserSwitchToSelect(
+    "switch-to-select", cl::Hidden, cl::init(false),
+    cl::desc("Convert switches into icmp + select (default = false)"));
+
 static cl::opt<bool> UserSwitchToLookup(
     "switch-to-lookup", cl::Hidden, cl::init(false),
     cl::desc("Convert switches to lookup tables (default = false)"));
@@ -323,6 +327,8 @@ static void applyCommandLineOverridesToOptions(SimplifyCFGOptions &Options) {
     Options.HoistCommonInsts = UserHoistCommonInsts;
   if (UserSinkCommonInsts.getNumOccurrences())
     Options.SinkCommonInsts = UserSinkCommonInsts;
+  if (UserSwitchToSelect.getNumOccurrences())
+    Options.ConvertSwitchToSelect = UserSwitchToSelect;
 }
 
 SimplifyCFGPass::SimplifyCFGPass() {
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 254795ec244534..3572c2fe68a538 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5916,7 +5916,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
       return false;
 
     // Only one value per case is permitted.
-    if (Results.size() > 1)
+    if (Results.size() > 3) // How many PHI instructions are hendled
       return false;
 
     // Add the case->result mapping to UniqueResults.
@@ -5953,12 +5953,31 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
   return true;
 }
 
+Value *createSelectChain(Value *Condition, Constant *DefaultResult,
+                         const SwitchCaseResultVectorTy &ResultVector,
+                         unsigned StartIndex, IRBuilder<> &Builder) {
+  if (StartIndex >= ResultVector.size() && DefaultResult) {
+    return DefaultResult;
+  }
+
+  ConstantInt *CurrentCase = ResultVector[StartIndex].second[0];
+  Value *ValueCompare =
+      Builder.CreateICmpEQ(Condition, CurrentCase, "switch.selectcmp");
+
+  Value *NextSelect = createSelectChain(Condition, DefaultResult, ResultVector,
+                                        StartIndex + 1, Builder);
+
+  return Builder.CreateSelect(ValueCompare, ResultVector[StartIndex].first,
+                              NextSelect, "switch.select");
+}
+
 // Helper function that checks if it is possible to transform a switch with only
 // two cases (or two cases + default) that produces a result into a select.
 // TODO: Handle switches with more than 2 cases that map to the same result.
 static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
                                  Constant *DefaultResult, Value *Condition,
-                                 IRBuilder<> &Builder) {
+                                 IRBuilder<> &Builder,
+                                 bool IsComplexSwitchTransform = false) {
   // If we are selecting between only two cases transform into a simple
   // select or a two-way select if default is possible.
   // Example:
@@ -5967,6 +5986,22 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
   //   case 20: return 2;   ---->  %2 = icmp eq i32 %a, 20
   //   default: return 4;          %3 = select i1 %2, i32 2, i32 %1
   // }
+
+  if (IsComplexSwitchTransform) {
+    bool IsSizeOkay = true;
+
+    for (int i = 0; i < ResultVector.size(); i++)
+      if (ResultVector[i].second.size() != 1)
+        IsSizeOkay = false;
+
+    if (IsSizeOkay && ResultVector.size() > 2) {
+      Value *FinalSelect =
+          createSelectChain(Condition, DefaultResult, ResultVector, 0, Builder);
+      if (FinalSelect)
+        return FinalSelect;
+    }
+  }
+
   if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 &&
       ResultVector[1].second.size() == 1) {
     ConstantInt *FirstCase = ResultVector[0].second[0];
@@ -6071,21 +6106,29 @@ static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI,
 /// switch with a select. Returns true if the fold was made.
 static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
                               DomTreeUpdater *DTU, const DataLayout &DL,
-                              const TargetTransformInfo &TTI) {
+                              const TargetTransformInfo &TTI,
+                              bool IsComplexSwitchTransform = false) {
   Value *const Cond = SI->getCondition();
   PHINode *PHI = nullptr;
   BasicBlock *CommonDest = nullptr;
   Constant *DefaultResult;
   SwitchCaseResultVectorTy UniqueResults;
   // Collect all the cases that will deliver the same value from the switch.
-  if (!initializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult,
-                             DL, TTI, /*MaxUniqueResults*/ 2))
+  if (!initializeUniqueCases(
+          SI, PHI, CommonDest, UniqueResults, DefaultResult, DL, TTI,
+          /*MaxUniqueResults*/ 7)) // I think that the next step is to expand
+                                   // this function to return a list of basic
+                                   // blocks that can be merged
     return false;
 
   assert(PHI != nullptr && "PHI for value select not found");
   Builder.SetInsertPoint(SI);
-  Value *SelectValue =
-      foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder);
+  Value *SelectValue = foldSwitchToSelect(
+      UniqueResults, DefaultResult, Cond, Builder,
+      IsComplexSwitchTransform); //  Afterwards this function should just merge
+                                 //  these blocks with the predaccessor of the
+                                 //  switch. Also, UniqueResults would no longer
+                                 //  be just constants
   if (!SelectValue)
     return false;
 
@@ -7028,7 +7071,11 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
     return requestResimplify();
 
-  if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
+  bool IsSwitchToSelect = false;
+  if (Options.ConvertSwitchToSelect)
+    IsSwitchToSelect = true;
+
+  if (trySwitchToSelect(SI, Builder, DTU, DL, TTI, IsSwitchToSelect))
     return requestResimplify();
 
   if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI))
diff --git a/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll b/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll
new file mode 100644
index 00000000000000..fa783283b881f1
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll
@@ -0,0 +1,312 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=simplifycfg -switch-to-select < %s | FileCheck -check-prefix=ALL %s
+
+
+define float @SimpleTestTwoCasesAndDefault(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestTwoCasesAndDefault(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 2.000000e+00, float 4.000000e+00
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECT2:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 1.000000e+00, float [[SWITCH_SELECT]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCH_SELECT2]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %1 = call float @llvm.amdgcn.interp.p2(float %0, float %PerspInterpCenter.i1, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %2 = fmul reassoc nnan nsz arcp contract afn float %1, 3.000000e00
+  %3 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %2)
+  %4 = fptosi float %3 to i32
+  %.fr = freeze i32 %4
+  %5 = icmp eq i32 %.fr, 1
+  br i1 %5, label %6, label %14
+
+6:                                                ; preds = %.entry
+  %7 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %8 = call float @llvm.amdgcn.interp.p2(float %7, float %PerspInterpCenter.i1, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %9 = fmul reassoc nnan nsz arcp contract afn float %8, 4.000000e00
+  %10 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %9)
+  %11 = fptosi float %10 to i32
+  %cond.freeze1 = freeze i32 %11
+  switch i32 %cond.freeze1, label %14 [
+  i32 0, label %12
+  i32 1, label %13
+  ]
+
+12:                                               ; preds = %6
+  br label %14
+
+13:                                               ; preds = %6
+  br label %14
+
+14:                                               ; preds = %12, %13, %6, %.entry
+  %samplePos.1.i0 = phi float [ 0.000000e00, %.entry ], [ 2.000000e00, %13 ], [ 1.000000e00, %12 ], [ 4.000000e00, %6 ]
+  ret float %samplePos.1.i0
+}
+
+define float @SimpleTestTwoCases(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestTwoCases(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 2.000000e+00, float 0.000000e+00
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECT2:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 1.000000e+00, float [[SWITCH_SELECT]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCH_SELECT2]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %1 = call float @llvm.amdgcn.interp.p2(float %0, float %PerspInterpCenter.i1, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %2 = fmul reassoc nnan nsz arcp contract afn float %1, 3.000000e00
+  %3 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %2)
+  %4 = fptosi float %3 to i32
+  %.fr = freeze i32 %4
+  %5 = icmp eq i32 %.fr, 1
+  br i1 %5, label %6, label %14
+
+6:                                                ; preds = %.entry
+  %7 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %8 = call float @llvm.amdgcn.interp.p2(float %7, float %PerspInterpCenter.i1, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %9 = fmul reassoc nnan nsz arcp contract afn float %8, 4.000000e00
+  %10 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %9)
+  %11 = fptosi float %10 to i32
+  %cond.freeze1 = freeze i32 %11
+  switch i32 %cond.freeze1, label %14 [
+  i32 0, label %12
+  i32 1, label %13
+  ]
+
+12:                                               ; preds = %6
+  br label %14
+
+13:                                               ; preds = %6
+  br label %14
+
+14:                                               ; preds = %6, %12, %13, %.entry
+  %samplePos.1.i0 = phi float [ 0.000000e00, %.entry ], [ 0.000000e00, %6 ], [ 2.000000e00, %13 ], [ 1.000000e00, %12 ]
+  ret float %samplePos.1.i0
+}
+
+
+
+define float @SimpleTestSwitch(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestSwitch(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECTCMP2:%.*]] = icmp eq i32 [[COND_FREEZE1]], 2
+; ALL-NEXT:    [[SWITCH_SELECTCMP3:%.*]] = icmp eq i32 [[COND_FREEZE1]], 3
+; ALL-NEXT:    [[SWITCHACIM_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP3]], float 4.000000e+00, float 0.000000e+00
+; ALL-NEXT:    [[SWITCHACIM_SELECT4:%.*]] = select i1 [[SWITCH_SELECTCMP2]], float 3.000000e+00, float [[SWITCHACIM_SELECT]]
+; ALL-NEXT:    [[SWITCHACIM_SELECT5:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 2.000000e+00, float [[SWITCHACIM_SELECT4]]
+; ALL-NEXT:    [[SWITCHACIM_SELECT6:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 1.000000e+00, float [[SWITCHACIM_SELECT5]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCHACIM_SELECT6]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask)...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Acim Maravic (Acim-Maravic)

Changes

Next step is to gather all blocks that can be safely merged be merged into 21 in order for me to create selects and to allow non constant values in helpers.


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82795.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h (+5)
  • (modified) llvm/lib/Passes/PassBuilder.cpp (+2)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp (+6)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+55-8)
  • (added) llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll (+312)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h b/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
index 8008fc6e8422d3..cb3ef663408153 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
@@ -30,6 +30,7 @@ struct SimplifyCFGOptions {
   bool SinkCommonInsts = false;
   bool SimplifyCondBranch = true;
   bool SpeculateBlocks = true;
+  bool ConvertSwitchToSelect = false;
 
   AssumptionCache *AC = nullptr;
 
@@ -46,6 +47,10 @@ struct SimplifyCFGOptions {
     ConvertSwitchRangeToICmp = B;
     return *this;
   }
+  SimplifyCFGOptions &convertSwitchToSelect(bool B) {
+    ConvertSwitchToSelect = B;
+    return *this;
+  }
   SimplifyCFGOptions &convertSwitchToLookupTable(bool B) {
     ConvertSwitchToLookupTable = B;
     return *this;
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c934ec42f6eb15..360dc4da3ca91d 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -819,6 +819,8 @@ Expected<SimplifyCFGOptions> parseSimplifyCFGOptions(StringRef Params) {
       Result.forwardSwitchCondToPhi(Enable);
     } else if (ParamName == "switch-range-to-icmp") {
       Result.convertSwitchRangeToICmp(Enable);
+    } else if (ParamName == "switch-to-select") {
+      Result.convertSwitchToSelect(Enable);
     } else if (ParamName == "switch-to-lookup") {
       Result.convertSwitchToLookupTable(Enable);
     } else if (ParamName == "keep-loops") {
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 44511800ccff8d..a3783afdf3be68 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -548,6 +548,7 @@ FUNCTION_PASS_WITH_PARAMS(
     "no-forward-switch-cond;forward-switch-cond;no-switch-range-to-icmp;"
     "switch-range-to-icmp;no-switch-to-lookup;switch-to-lookup;no-keep-loops;"
     "keep-loops;no-hoist-common-insts;hoist-common-insts;no-sink-common-insts;"
+    "switch-to-select;"
     "sink-common-insts;bonus-inst-threshold=N")
 FUNCTION_PASS_WITH_PARAMS(
     "speculative-execution", "SpeculativeExecutionPass",
diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
index 7017f6adf3a2bb..eafedef2af7d2e 100644
--- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -61,6 +61,10 @@ static cl::opt<bool> UserSwitchRangeToICmp(
     cl::desc(
         "Convert switches into an integer range comparison (default = false)"));
 
+static cl::opt<bool> UserSwitchToSelect(
+    "switch-to-select", cl::Hidden, cl::init(false),
+    cl::desc("Convert switches into icmp + select (default = false)"));
+
 static cl::opt<bool> UserSwitchToLookup(
     "switch-to-lookup", cl::Hidden, cl::init(false),
     cl::desc("Convert switches to lookup tables (default = false)"));
@@ -323,6 +327,8 @@ static void applyCommandLineOverridesToOptions(SimplifyCFGOptions &Options) {
     Options.HoistCommonInsts = UserHoistCommonInsts;
   if (UserSinkCommonInsts.getNumOccurrences())
     Options.SinkCommonInsts = UserSinkCommonInsts;
+  if (UserSwitchToSelect.getNumOccurrences())
+    Options.ConvertSwitchToSelect = UserSwitchToSelect;
 }
 
 SimplifyCFGPass::SimplifyCFGPass() {
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 254795ec244534..3572c2fe68a538 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5916,7 +5916,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
       return false;
 
     // Only one value per case is permitted.
-    if (Results.size() > 1)
+    if (Results.size() > 3) // How many PHI instructions are hendled
       return false;
 
     // Add the case->result mapping to UniqueResults.
@@ -5953,12 +5953,31 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
   return true;
 }
 
+Value *createSelectChain(Value *Condition, Constant *DefaultResult,
+                         const SwitchCaseResultVectorTy &ResultVector,
+                         unsigned StartIndex, IRBuilder<> &Builder) {
+  if (StartIndex >= ResultVector.size() && DefaultResult) {
+    return DefaultResult;
+  }
+
+  ConstantInt *CurrentCase = ResultVector[StartIndex].second[0];
+  Value *ValueCompare =
+      Builder.CreateICmpEQ(Condition, CurrentCase, "switch.selectcmp");
+
+  Value *NextSelect = createSelectChain(Condition, DefaultResult, ResultVector,
+                                        StartIndex + 1, Builder);
+
+  return Builder.CreateSelect(ValueCompare, ResultVector[StartIndex].first,
+                              NextSelect, "switch.select");
+}
+
 // Helper function that checks if it is possible to transform a switch with only
 // two cases (or two cases + default) that produces a result into a select.
 // TODO: Handle switches with more than 2 cases that map to the same result.
 static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
                                  Constant *DefaultResult, Value *Condition,
-                                 IRBuilder<> &Builder) {
+                                 IRBuilder<> &Builder,
+                                 bool IsComplexSwitchTransform = false) {
   // If we are selecting between only two cases transform into a simple
   // select or a two-way select if default is possible.
   // Example:
@@ -5967,6 +5986,22 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
   //   case 20: return 2;   ---->  %2 = icmp eq i32 %a, 20
   //   default: return 4;          %3 = select i1 %2, i32 2, i32 %1
   // }
+
+  if (IsComplexSwitchTransform) {
+    bool IsSizeOkay = true;
+
+    for (int i = 0; i < ResultVector.size(); i++)
+      if (ResultVector[i].second.size() != 1)
+        IsSizeOkay = false;
+
+    if (IsSizeOkay && ResultVector.size() > 2) {
+      Value *FinalSelect =
+          createSelectChain(Condition, DefaultResult, ResultVector, 0, Builder);
+      if (FinalSelect)
+        return FinalSelect;
+    }
+  }
+
   if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 &&
       ResultVector[1].second.size() == 1) {
     ConstantInt *FirstCase = ResultVector[0].second[0];
@@ -6071,21 +6106,29 @@ static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI,
 /// switch with a select. Returns true if the fold was made.
 static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
                               DomTreeUpdater *DTU, const DataLayout &DL,
-                              const TargetTransformInfo &TTI) {
+                              const TargetTransformInfo &TTI,
+                              bool IsComplexSwitchTransform = false) {
   Value *const Cond = SI->getCondition();
   PHINode *PHI = nullptr;
   BasicBlock *CommonDest = nullptr;
   Constant *DefaultResult;
   SwitchCaseResultVectorTy UniqueResults;
   // Collect all the cases that will deliver the same value from the switch.
-  if (!initializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult,
-                             DL, TTI, /*MaxUniqueResults*/ 2))
+  if (!initializeUniqueCases(
+          SI, PHI, CommonDest, UniqueResults, DefaultResult, DL, TTI,
+          /*MaxUniqueResults*/ 7)) // I think that the next step is to expand
+                                   // this function to return a list of basic
+                                   // blocks that can be merged
     return false;
 
   assert(PHI != nullptr && "PHI for value select not found");
   Builder.SetInsertPoint(SI);
-  Value *SelectValue =
-      foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder);
+  Value *SelectValue = foldSwitchToSelect(
+      UniqueResults, DefaultResult, Cond, Builder,
+      IsComplexSwitchTransform); //  Afterwards this function should just merge
+                                 //  these blocks with the predaccessor of the
+                                 //  switch. Also, UniqueResults would no longer
+                                 //  be just constants
   if (!SelectValue)
     return false;
 
@@ -7028,7 +7071,11 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
     return requestResimplify();
 
-  if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
+  bool IsSwitchToSelect = false;
+  if (Options.ConvertSwitchToSelect)
+    IsSwitchToSelect = true;
+
+  if (trySwitchToSelect(SI, Builder, DTU, DL, TTI, IsSwitchToSelect))
     return requestResimplify();
 
   if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI))
diff --git a/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll b/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll
new file mode 100644
index 00000000000000..fa783283b881f1
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/AMDGPU/switch-to-select.ll
@@ -0,0 +1,312 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=simplifycfg -switch-to-select < %s | FileCheck -check-prefix=ALL %s
+
+
+define float @SimpleTestTwoCasesAndDefault(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestTwoCasesAndDefault(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 2.000000e+00, float 4.000000e+00
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECT2:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 1.000000e+00, float [[SWITCH_SELECT]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCH_SELECT2]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %1 = call float @llvm.amdgcn.interp.p2(float %0, float %PerspInterpCenter.i1, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %2 = fmul reassoc nnan nsz arcp contract afn float %1, 3.000000e00
+  %3 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %2)
+  %4 = fptosi float %3 to i32
+  %.fr = freeze i32 %4
+  %5 = icmp eq i32 %.fr, 1
+  br i1 %5, label %6, label %14
+
+6:                                                ; preds = %.entry
+  %7 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %8 = call float @llvm.amdgcn.interp.p2(float %7, float %PerspInterpCenter.i1, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %9 = fmul reassoc nnan nsz arcp contract afn float %8, 4.000000e00
+  %10 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %9)
+  %11 = fptosi float %10 to i32
+  %cond.freeze1 = freeze i32 %11
+  switch i32 %cond.freeze1, label %14 [
+  i32 0, label %12
+  i32 1, label %13
+  ]
+
+12:                                               ; preds = %6
+  br label %14
+
+13:                                               ; preds = %6
+  br label %14
+
+14:                                               ; preds = %12, %13, %6, %.entry
+  %samplePos.1.i0 = phi float [ 0.000000e00, %.entry ], [ 2.000000e00, %13 ], [ 1.000000e00, %12 ], [ 4.000000e00, %6 ]
+  ret float %samplePos.1.i0
+}
+
+define float @SimpleTestTwoCases(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestTwoCases(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 2.000000e+00, float 0.000000e+00
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECT2:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 1.000000e+00, float [[SWITCH_SELECT]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCH_SELECT2]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %1 = call float @llvm.amdgcn.interp.p2(float %0, float %PerspInterpCenter.i1, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
+  %2 = fmul reassoc nnan nsz arcp contract afn float %1, 3.000000e00
+  %3 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %2)
+  %4 = fptosi float %3 to i32
+  %.fr = freeze i32 %4
+  %5 = icmp eq i32 %.fr, 1
+  br i1 %5, label %6, label %14
+
+6:                                                ; preds = %.entry
+  %7 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %8 = call float @llvm.amdgcn.interp.p2(float %7, float %PerspInterpCenter.i1, i32 immarg 0, i32 immarg 0, i32 %PrimMask) #1
+  %9 = fmul reassoc nnan nsz arcp contract afn float %8, 4.000000e00
+  %10 = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float %9)
+  %11 = fptosi float %10 to i32
+  %cond.freeze1 = freeze i32 %11
+  switch i32 %cond.freeze1, label %14 [
+  i32 0, label %12
+  i32 1, label %13
+  ]
+
+12:                                               ; preds = %6
+  br label %14
+
+13:                                               ; preds = %6
+  br label %14
+
+14:                                               ; preds = %6, %12, %13, %.entry
+  %samplePos.1.i0 = phi float [ 0.000000e00, %.entry ], [ 0.000000e00, %6 ], [ 2.000000e00, %13 ], [ 1.000000e00, %12 ]
+  ret float %samplePos.1.i0
+}
+
+
+
+define float @SimpleTestSwitch(<2 x float> noundef %PerspInterpCenter, i32 inreg noundef %PrimMask) {
+; ALL-LABEL: @SimpleTestSwitch(
+; ALL-NEXT:  .entry:
+; ALL-NEXT:    [[PERSPINTERPCENTER_I1:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER:%.*]], i64 1
+; ALL-NEXT:    [[PERSPINTERPCENTER_I0:%.*]] = extractelement <2 x float> [[PERSPINTERPCENTER]], i64 0
+; ALL-NEXT:    [[TMP0:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK:%.*]])
+; ALL-NEXT:    [[TMP1:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP0]], float [[PERSPINTERPCENTER_I1]], i32 immarg 1, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP2:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP1]], 3.000000e+00
+; ALL-NEXT:    [[TMP3:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP2]])
+; ALL-NEXT:    [[TMP4:%.*]] = fptosi float [[TMP3]] to i32
+; ALL-NEXT:    [[DOTFR:%.*]] = freeze i32 [[TMP4]]
+; ALL-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[DOTFR]], 1
+; ALL-NEXT:    br i1 [[TMP5]], label [[TMP6:%.*]], label [[TMP12:%.*]]
+; ALL:       6:
+; ALL-NEXT:    [[TMP7:%.*]] = call float @llvm.amdgcn.interp.p1(float [[PERSPINTERPCENTER_I0]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP8:%.*]] = call float @llvm.amdgcn.interp.p2(float [[TMP7]], float [[PERSPINTERPCENTER_I1]], i32 immarg 0, i32 immarg 0, i32 [[PRIMMASK]])
+; ALL-NEXT:    [[TMP9:%.*]] = fmul reassoc nnan nsz arcp contract afn float [[TMP8]], 4.000000e+00
+; ALL-NEXT:    [[TMP10:%.*]] = call reassoc nnan nsz arcp contract afn float @llvm.floor.f32(float [[TMP9]])
+; ALL-NEXT:    [[TMP11:%.*]] = fptosi float [[TMP10]] to i32
+; ALL-NEXT:    [[COND_FREEZE1:%.*]] = freeze i32 [[TMP11]]
+; ALL-NEXT:    [[SWITCH_SELECTCMP:%.*]] = icmp eq i32 [[COND_FREEZE1]], 0
+; ALL-NEXT:    [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
+; ALL-NEXT:    [[SWITCH_SELECTCMP2:%.*]] = icmp eq i32 [[COND_FREEZE1]], 2
+; ALL-NEXT:    [[SWITCH_SELECTCMP3:%.*]] = icmp eq i32 [[COND_FREEZE1]], 3
+; ALL-NEXT:    [[SWITCHACIM_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP3]], float 4.000000e+00, float 0.000000e+00
+; ALL-NEXT:    [[SWITCHACIM_SELECT4:%.*]] = select i1 [[SWITCH_SELECTCMP2]], float 3.000000e+00, float [[SWITCHACIM_SELECT]]
+; ALL-NEXT:    [[SWITCHACIM_SELECT5:%.*]] = select i1 [[SWITCH_SELECTCMP1]], float 2.000000e+00, float [[SWITCHACIM_SELECT4]]
+; ALL-NEXT:    [[SWITCHACIM_SELECT6:%.*]] = select i1 [[SWITCH_SELECTCMP]], float 1.000000e+00, float [[SWITCHACIM_SELECT5]]
+; ALL-NEXT:    br label [[TMP12]]
+; ALL:       12:
+; ALL-NEXT:    [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ [[SWITCHACIM_SELECT6]], [[TMP6]] ]
+; ALL-NEXT:    ret float [[SAMPLEPOS_1_I0]]
+;
+.entry:
+  %PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
+  %PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
+  %0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask)...
[truncated]

@jayfoad jayfoad requested a review from nikic February 26, 2024 10:31
@jayfoad
Copy link
Contributor

jayfoad commented Feb 26, 2024

Needs a better title e.g. "[SimplifyCFG] Convert switch to cmp/select sequence"

@jayfoad
Copy link
Contributor

jayfoad commented Feb 26, 2024

The general idea of doing this in SimplifyCFG seems reasonable to me. However I am not too familiar with the existing code in SimplifyCFG.

Comment on lines 64 to 67
static cl::opt<bool> UserSwitchToSelect(
"switch-to-select", cl::Hidden, cl::init(false),
cl::desc("Convert switches into icmp + select (default = false)"));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need new cl::opts for this? Can just be a pass parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for testing purposes, I guess in one point we will instantiate a pass with enable by default...

@@ -5916,7 +5916,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
return false;

// Only one value per case is permitted.
if (Results.size() > 1)
if (Results.size() > 3) // How many PHI instructions are hendled
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: handled

@@ -5916,7 +5916,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
return false;

// Only one value per case is permitted.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment needs update?

@@ -0,0 +1,312 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't there be a non-AMDGPU test case since this is a general optimization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will do it that way.

if (IsSizeOkay && ResultVector.size() > 2) {
Value *FinalSelect =
createSelectChain(Condition, DefaultResult, ResultVector, 0, Builder);
if (FinalSelect)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createSelectChain always returns a non-nullptr value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
bool IsSwitchToSelect = false;
if (Options.ConvertSwitchToSelect)
IsSwitchToSelect = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsSwitchToSelect = Options.ConvertSwitchToSelect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

; ALL-NEXT: [[SWITCH_SELECTCMP1:%.*]] = icmp eq i32 [[COND_FREEZE1]], 1
; ALL-NEXT: [[SWITCH_SELECTCMP2:%.*]] = icmp eq i32 [[COND_FREEZE1]], 2
; ALL-NEXT: [[SWITCH_SELECTCMP3:%.*]] = icmp eq i32 [[COND_FREEZE1]], 3
; ALL-NEXT: [[SWITCHACIM_SELECT:%.*]] = select i1 [[SWITCH_SELECTCMP3]], float 4.000000e+00, float 0.000000e+00
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the name Switchacim intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have forgotten to update tests, it was just for debugging.

; ALL-NEXT: [[TMP20:%.*]] = fneg reassoc nnan nsz arcp contract afn float [[DOTI0]]
; ALL-NEXT: br label [[TMP21]]
; ALL: 21:
; ALL-NEXT: [[SAMPLEPOS_1_I0:%.*]] = phi float [ 0.000000e+00, [[DOTENTRY:%.*]] ], [ 0.000000e+00, [[TMP6]] ], [ [[TMP20]], [[TMP19]] ], [ 1.000000e+00, [[TMP17]] ], [ [[DOTI0]], [[TMP16]] ], [ 1.000000e+00, [[TMP15]] ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this is not working on partially constant phi nodes, right? It would be a great addition since we'd be able to generate selects out of code like shown in this test, and I guess this is something that will often appear in practice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, it should be working with non-constant PHI Nodes, and with multiple PHI nodes.

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change the title? I have no idea what it means as-is

.entry:
%PerspInterpCenter.i1 = extractelement <2 x float> %PerspInterpCenter, i64 1
%PerspInterpCenter.i0 = extractelement <2 x float> %PerspInterpCenter, i64 0
%0 = call float @llvm.amdgcn.interp.p1(float %PerspInterpCenter.i0, i32 immarg 1, i32 immarg 0, i32 %PrimMask) #1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use named values in tests

@Acim-Maravic Acim-Maravic changed the title Imprecise switch case [SimplifyCFG] Convert switch to cmp/select sequence Mar 26, 2024
@Acim-Maravic
Copy link
Contributor Author

Fixed support for handling of multiple PHI nodes. Reworked implementation to try to carve out cases with constant values instead of giving up if not all cases are constant.

Copy link
Member

@XChy XChy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, it's hard for me to understand the details of your code. Could you please add more comments to explain? And I think it's not always profitable to convert a big switch into a long select sequence. We need to add a threshold for it.

Pred = CaseDest;
CaseDest = I.getSuccessor(0);
} else {
// samo jedna instrukcija se hendluje po bloku
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When can this condition happen?

// Find the default result value.
SmallVector<std::pair<PHINode *, Value *>, 1> DefaultResults;
BasicBlock *DefaultDest = SI->getDefaultDest();
getCaseResultsWithoutConstants(SI, SI->getDefaultDest(), &CommonDest,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we care about the return value here?

// The first field contains the value that the switch produces when a certain
// case group is selected, and the second field is a vector containing the
// cases composing the case group.
using SwitchCaseResultVectorTy2 =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont really understand this name. Do you mean ResultToCasesTy?

Comment on lines +6218 to +6234
if (I.isTerminator()) {
// If the terminator is a simple branch, continue to the next block.
if (I.getNumSuccessors() != 1 || I.isSpecialTerminator())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (I.isTerminator()) {
// If the terminator is a simple branch, continue to the next block.
if (I.getNumSuccessors() != 1 || I.isSpecialTerminator())
if (BranchInst* BI = dyn_cast<BranchInst>(&I)) {
// If the terminator is a simple branch, continue to the next block.
if (!BI->isUnconditional())

return: ; preds = %sw.bb1, %sw.bb2, %sw.bb3, %sw.bb4, %sw, %.entry
%samplePos.1.i0 = phi float [ 0.000000e+00, %.entry ], [ %20, %sw.bb4 ], [ -1.000000e+00, %sw.bb3 ], [ %.i0, %sw.bb2 ], [ 1.000000e+00, %sw.bb1 ], [ 4.000000e+00, %sw ]
ret float %samplePos.1.i0
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Negative tests? For example,

  • No common dest for cases.
  • The switch jumps direct into the dest block.
  • Loop with switch.

SmallVector<std::pair<Value *, SmallVector<Value *, 4>>, 2>;

using PHINodeToCaseEntryValueMapTy = std::map<
llvm::PHINode *,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need llvm::


return SelectedValue;

} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No else after return

if (Index != -1) {
Value *Val = PHI->getIncomingValue(Index);
if (isa<Constant>(Val)) {
PHI->setIncomingValue(Index, SelectValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks for repeated predecessors

@@ -552,6 +552,7 @@ FUNCTION_PASS_WITH_PARAMS(
"no-forward-switch-cond;forward-switch-cond;no-switch-range-to-icmp;"
"switch-range-to-icmp;no-switch-to-lookup;switch-to-lookup;no-keep-loops;"
"keep-loops;no-hoist-common-insts;hoist-common-insts;no-sink-common-insts;"
"switch-to-select"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing ; at the end?

}
}
UniqueResults.push_back(
std::make_pair(Result, SmallVector<Value *, 4>(1, CaseVal)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use {} as initializer list?

bool IsDefault = false) {
BasicBlock *Pred = SI->getParent();
int NumOfInsts = 0;
// Check if there is only one instruction per block, excepet from default
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: excepet

if (Instruction *I = dyn_cast<Instruction>(User))
if (I->getParent() == CaseDest)
continue;
if (PHINode *Phi = dyn_cast<PHINode>(User))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could have some parentheses { and }, and you can pull the nested if into the top if via ;

Res.push_back(std::make_pair(&PHI, Val));
}

return Res.size() > 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!Res.empty()

SwitchCaseResultVectorTy2 UniqueResults;
PHINodeToCaseEntryValueMapTy Map;

if (!initializeuniqueCasesWithoutConstants(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!initializeuniqueCasesWithoutConstants(
if (!initializeUniqueCasesWithoutConstants(

return true;
};

if (isUnableToOptimize(Map))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the lambda required? There's only a single use of it

Value *value = ValuePair.second;
if (constantInt == CaseVal)
if (!isa<Constant>(value))
IsSafeToRemove = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the loop break then?

Copy link

github-actions bot commented May 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants