OutOfCache
Contributor

This pass models the relationship between a switch variable
and an associated constant offset (like from a getelementptr
or an integer add). Find a linear function that represents
this relationship and use the function to create a new,
more generalized version of that operation.

Similar to SimplifyIndVar, but instead of the correlation
between a loop iteration and an induction variable,
we look at the switch variable (i.e., the case value) and
the offset.

It can help eliminate the switch entirely in some cases
with the help of dead code elimination and
simplifycfg.

Find the BB where most of the cases (and perhaps the
switch BB itself) meet.

Later on, we can use that BB to find phi nodes whose incoming
blocks are the case BBs of the switch statement. If we can
generalize the instruction with the constant offset, and that
instruction is also an incoming value of the phi node, we can
replace the incoming values with a generalized version of the
instruction.

That can enable other optimizations like dead code elimination
and simplifycfg to further clean up the IR.
Determine if a phi's incoming value comes from a case BB.
If it does, then record the case value and the index in
the phi node alongside the incoming value and collect
them in a vector.
Return the common base value, type of operation and offset type
of the incoming values.

These parameters help with the creation of the generalized
operation later on.

Anything other than GEPs and integer adds is not supported.
Filter out invalid cases and return a vector of just valid ones.
Valid cases have the same base value as the found base value
and they have a constant offset.

Collect the valid ones in a map of case values to offsets.
Essentially, the map contains points with the case value
as x-value and offset as y-value.

We use that map to determine a linear function through
these points in a later change.
Using the map as a collection of points (where the x-value is the
case value and the y-value is the offset), find a linear function
via random sampling.

If we cannot find a function that contains at least half of the
points within 5 tries, abort.

The number 5 was chosen because it gives enough confidence in
most cases, see
https://en.wikipedia.org/wiki/Random_sample_consensus#Parameters

If 80% of the points are on a straight line, then the
probability p of finding the line with two points is
0.8^2.
The chance of not finding the line is (1 - p).
The chance of not finding the line with 5 random
tries is (1 - p)^5, which is ~0.6%.

Use the same seed to be deterministic.
Filter out outliers of the function, as we cannot replace them.
Create a getelementptr or an add using the found
function and instruction parameters.

Calculate the result of slope * switchvar + bias and
use the result in the new instruction.
Use the newly created operation instead of the
previously collected incoming values.
@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jessica Del (OutOfCache)

Changes

This pass models the relationship between a switch variable
and an associated constant offset (like from a getelementptr
or an integer add). Find a linear function that represents
this relationship and use the function to create a new,
more generalized version of that operation.

Similar to SimplifyIndVar, but instead of the correlation
between a loop iteration and an induction variable,
we look at the switch variable (i.e., the case value) and
the offset.

It can help eliminate the switch entirely in some cases
with the help of dead code elimination and
simplifycfg.


Patch is 45.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149937.diff

6 Files Affected:

  • (added) llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h (+27)
  • (modified) llvm/lib/Passes/PassBuilder.cpp (+1)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Transforms/Utils/CMakeLists.txt (+1)
  • (added) llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp (+371)
  • (added) llvm/test/Transforms/Util/simplify-switch-var.ll (+860)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h b/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h
new file mode 100644
index 0000000000000..9d98a470a96d6
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h
@@ -0,0 +1,27 @@
+//===-- SimplifySwitchVar.h - Simplify Switch Variables ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// This file defines a pass for switch variable simplification.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
+#define LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class SimplifySwitchVarPass : public PassInfoMixin<SimplifySwitchVarPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index e15570c3f600e..4d40906e4a591 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -364,6 +364,7 @@
 #include "llvm/Transforms/Utils/NameAnonGlobals.h"
 #include "llvm/Transforms/Utils/PredicateInfo.h"
 #include "llvm/Transforms/Utils/RelLookupTableConverter.h"
+#include "llvm/Transforms/Utils/SimplifySwitchVar.h"
 #include "llvm/Transforms/Utils/StripGCRelocates.h"
 #include "llvm/Transforms/Utils/StripNonLineTableDebugInfo.h"
 #include "llvm/Transforms/Utils/SymbolRewriter.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index caa78b613b901..c60b6bebac609 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -537,6 +537,7 @@ FUNCTION_PASS("slp-vectorizer", SLPVectorizerPass())
 FUNCTION_PASS("slsr", StraightLineStrengthReducePass())
 FUNCTION_PASS("stack-protector", StackProtectorPass(TM))
 FUNCTION_PASS("strip-gc-relocates", StripGCRelocates())
+FUNCTION_PASS("simplify-switch-var", SimplifySwitchVarPass())
 FUNCTION_PASS("tailcallelim", TailCallElimPass())
 FUNCTION_PASS("transform-warning", WarnMissedTransformationsPass())
 FUNCTION_PASS("trigger-crash-function", TriggerCrashFunctionPass())
diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
index f7e66eca6aeb3..1138256c6851f 100644
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -84,6 +84,7 @@ add_llvm_component_library(LLVMTransformUtils
   SizeOpts.cpp
   SplitModule.cpp
   StripNonLineTableDebugInfo.cpp
+  SimplifySwitchVar.cpp
   SymbolRewriter.cpp
   UnifyFunctionExitNodes.cpp
   UnifyLoopExits.cpp
diff --git a/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp b/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp
new file mode 100644
index 0000000000000..e30ef6c12b20f
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp
@@ -0,0 +1,371 @@
+//===-- SimplifySwitchVar.cpp - Switch Variable simplification ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This file implements switch variable simplification. It looks for a
+/// linear relationship between the case value of a switch and the constant
+/// offset of an operation. Knowing this relationship, we can simplify
+/// multiple individual operations into a single, more generic one, which
+/// can help with further optimizations.
+///
+/// It is similar to SimplifyIndVar, but instead of looking at an
+/// induction variable and modeling its scalar evolution over
+/// multiple iterations, it analyzes the switch variable and
+/// models how it affects constant offsets.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/SimplifySwitchVar.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include <random>
+
+using namespace llvm;
+using namespace PatternMatch;
+
+/// Return the BB, where (most of) the cases meet.
+/// In that BB are phi nodes, that contain the case BBs.
+static BasicBlock *findMostCommonSuccessor(SwitchInst *Switch) {
+  uint64_t Max = 0;
+  BasicBlock *MostCommonSuccessor = nullptr;
+
+  for (auto &Case : Switch->cases()) {
+    auto *CaseBB = Case.getCaseSuccessor();
+    auto GetNumPredecessors = [](BasicBlock *BB) -> uint64_t {
+      return std::distance(predecessors(BB).begin(), predecessors(BB).end());
+    };
+
+    auto Length = GetNumPredecessors(CaseBB);
+
+    if (Length > Max) {
+      Max = Length;
+      MostCommonSuccessor = CaseBB;
+    }
+
+    for (auto *Successor : successors(CaseBB)) {
+      auto Length = GetNumPredecessors(Successor);
+      if (Length > Max) {
+        Max = Length;
+        MostCommonSuccessor = Successor;
+      }
+    }
+  }
+
+  return MostCommonSuccessor;
+}
+
+namespace {
+struct PhiCase {
+  int PhiIndex;
+  Value *IncomingValue;
+  ConstantInt *CaseValue;
+};
+} // namespace
+
+/// Collect the incoming value, index and associated case value from a phi node.
+/// Ignores incoming values, which do not come from a case BB from the switch.
+static SmallVector<PhiCase> collectPhiCases(PHINode &Phi, SwitchInst *Switch) {
+  SmallVector<PhiCase> PhiInputs;
+
+  for (auto *IncomingBlock : Phi.blocks()) {
+    auto *CaseVal = Switch->findCaseDest(IncomingBlock);
+    if (!CaseVal)
+      continue;
+
+    auto PhiIdx = Phi.getBasicBlockIndex(IncomingBlock);
+    PhiInputs.push_back({PhiIdx, Phi.getIncomingValue(PhiIdx), CaseVal});
+  }
+  return PhiInputs;
+}
+
+namespace {
+enum SupportedOp {
+  GetElementPtr,
+  IntegerAdd,
+  Unsupported,
+};
+
+struct NewInstParameters {
+  SupportedOp Op;
+  Value *BaseValue;
+  Type *OffsetTy;
+};
+} // namespace
+
+/// Find the common Base Value, Operation type and Index Type of the found phi
+/// incoming values.
+static NewInstParameters findInstParameters(SmallVector<PhiCase> &PhiCases) {
+  auto Op = SupportedOp::Unsupported;
+  DenseMap<Value *, uint64_t> BaseAddressCounts;
+  Type *OffsetTy = nullptr;
+
+  for (auto &Case : PhiCases) {
+    auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue);
+    bool IsAdd =
+        match(Case.IncomingValue, m_Add(m_Value(), m_AnyIntegralConstant()));
+
+    if (GEP) {
+      Op = SupportedOp::GetElementPtr;
+      BaseAddressCounts[GEP->getPointerOperand()] += 1;
+      OffsetTy = GEP->getOperand(GEP->getNumIndices())->getType();
+      continue;
+    }
+
+    if (IsAdd) {
+      Op = SupportedOp::IntegerAdd;
+      auto *Add = cast<Instruction>(Case.IncomingValue);
+      BaseAddressCounts[Add->getOperand(Add->getNumOperands() - 2)] += 1;
+      OffsetTy = Add->getOperand(Add->getNumOperands() - 1)->getType();
+      continue;
+    }
+
+    BaseAddressCounts[Case.IncomingValue] += 1;
+  }
+
+  unsigned Max = 0;
+  Value *BaseValue;
+  for (auto &Base : BaseAddressCounts) {
+    if (Base.second > Max) {
+      BaseValue = Base.first;
+      Max = Base.second;
+    }
+  }
+
+  return {Op, BaseValue, OffsetTy};
+}
+
+/// Collect valid cases.
+/// A case is valid if it uses the same base value (or is the base value like a
+/// pointer from an alloca) and it has a constant offset.
+static SmallVector<PhiCase>
+collectValidCases(SmallVector<PhiCase> &PhiCases,
+                  NewInstParameters NewInstParameters,
+                  DenseMap<int64_t, int64_t> &CaseOffsetMap) {
+  SmallVector<PhiCase> FilteredCases;
+  auto *BaseValue = NewInstParameters.BaseValue;
+  auto CurrentOp = NewInstParameters.Op;
+
+  switch (CurrentOp) {
+  case SupportedOp::GetElementPtr: {
+    for (auto &Case : PhiCases) {
+      auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue);
+
+      if (!GEP) {
+        if (Case.IncomingValue != BaseValue) {
+          continue;
+        }
+        CaseOffsetMap[Case.CaseValue->getSExtValue()] = 0;
+        FilteredCases.push_back(Case);
+        continue;
+      }
+
+      if (GEP->getPointerOperand() != BaseValue) {
+        continue;
+      }
+
+      auto &DL = GEP->getParent()->getDataLayout();
+      APInt Offset(DL.getTypeSizeInBits(GEP->getPointerOperandType()), 0);
+      if (!GEP->accumulateConstantOffset(GEP->getDataLayout(), Offset)) {
+        continue;
+      }
+      CaseOffsetMap[Case.CaseValue->getSExtValue()] = Offset.getSExtValue();
+      FilteredCases.push_back(Case);
+    }
+    break;
+  }
+  case SupportedOp::IntegerAdd: {
+    for (auto &Case : PhiCases) {
+      bool IsAdd =
+          match(Case.IncomingValue, m_Add(m_Value(), m_AnyIntegralConstant()));
+
+      if (!IsAdd) {
+        if (Case.IncomingValue != BaseValue) {
+          continue;
+        }
+        CaseOffsetMap[Case.CaseValue->getSExtValue()] = 0;
+        FilteredCases.push_back(Case);
+        continue;
+      }
+
+      auto *AddInst = dyn_cast<Instruction>(Case.IncomingValue);
+      if (AddInst->getOperand(0) != BaseValue) {
+        continue;
+      }
+      auto *Offset = cast<ConstantInt>(AddInst->getOperand(1));
+      if (!Offset) {
+        continue;
+      }
+
+      CaseOffsetMap[Case.CaseValue->getSExtValue()] = Offset->getSExtValue();
+      FilteredCases.push_back(Case);
+      continue;
+    }
+    break;
+  }
+  case SupportedOp::Unsupported: {
+    llvm_unreachable("Unsupported Operation for SimplifySwitchVar.");
+  }
+  }
+  return FilteredCases;
+}
+
+namespace {
+struct FuncParams {
+  int64_t Slope;
+  int64_t Bias;
+};
+} // namespace
+
+using RandomEngine = std::minstd_rand;
+using RandomDistribution = std::uniform_int_distribution<int>;
+/// Find the linear function that models the switch variable progression.
+/// Uses random sampling to find the best fit, even if outliers are present.
+/// Abort if there are too many outliers (> 50%)
+static std::optional<FuncParams>
+findLinearFunction(DenseMap<int64_t, int64_t> &Cases,
+                   SmallVector<PhiCase> &PhiCases) {
+  RandomEngine Rand(0xdeadbeef);
+  RandomDistribution RandDist(0, Cases.size() - 1);
+
+  // Repeat the process at most 5 times, because if at least 80% of the points
+  // lie on the line, then we will find the line with 99.4% probability within 5
+  // tries. See https://en.wikipedia.org/wiki/Random_sample_consensus#Parameters
+  for (int I = 0; I < 5; ++I) {
+    auto Index0 = RandDist(Rand);
+    auto Index1 = RandDist(Rand);
+    while (Index0 == Index1) {
+      Index1 = RandDist(Rand);
+    }
+
+    auto X0 = PhiCases[Index0].CaseValue->getSExtValue();
+    auto X1 = PhiCases[Index1].CaseValue->getSExtValue();
+    auto Y0 = Cases[X0];
+    auto Y1 = Cases[X1];
+
+    int64_t Slope = (Y1 - Y0) / (X1 - X0);
+    int64_t Bias = (Y0 - (Slope * X0));
+
+    auto Count = llvm::count_if(Cases, [Bias, Slope](auto Case) {
+      return Slope * Case.first + Bias == Case.second;
+    });
+
+    float InlierRatio = (float)Count / (float)Cases.size();
+    if (InlierRatio > 0.5f) {
+      return std::optional<FuncParams>({Slope, Bias});
+    }
+  }
+
+  return std::nullopt;
+}
+
+/// Remove outlier cases, where the offset is not exactly a point on the
+/// calculated function.
+static SmallVector<PhiCase>
+removeOutliers(FuncParams F, SmallVector<PhiCase> PhiCases,
+               DenseMap<int64_t, int64_t> &CaseValueMap) {
+  SmallVector<PhiCase> FilteredCases;
+  for (auto &Case : PhiCases) {
+    auto CaseValue = Case.CaseValue->getSExtValue();
+    auto CaseOffset = CaseValueMap[CaseValue];
+    auto Result = F.Slope * CaseValue + F.Bias;
+
+    if (Result == CaseOffset)
+      FilteredCases.push_back(Case);
+  }
+  return FilteredCases;
+}
+
+/// Create the new generalized value that models all the found cases with the
+/// calculated function.
+/// Uses the slope and bias to modify the switch variable
+/// and uses the resulting value as argument to the new instruction.
+static Value *findNewValue(NewInstParameters InstParameters, FuncParams F,
+                           IRBuilder<> &Builder, SwitchInst *Switch) {
+  auto *BaseValue = InstParameters.BaseValue;
+  auto Op = InstParameters.Op;
+  auto *OffsetTy = InstParameters.OffsetTy;
+
+  auto *SwitchVar = Switch->getCondition();
+  auto SwitchBitWidth = SwitchVar->getType()->getIntegerBitWidth();
+
+  Builder.SetInsertPoint(Switch);
+
+  auto *NewIdx =
+      Builder.CreateMul(SwitchVar, Builder.getIntN(SwitchBitWidth, F.Slope));
+  NewIdx = Builder.CreateAdd(NewIdx, Builder.getIntN(SwitchBitWidth, F.Bias));
+
+  switch (Op) {
+  case SupportedOp::GetElementPtr: {
+    return Builder.CreateGEP(Builder.getInt8Ty(), BaseValue,
+                             Builder.CreateSExtOrTrunc(NewIdx, OffsetTy));
+  };
+  case SupportedOp::IntegerAdd: {
+    return Builder.CreateAdd(BaseValue,
+                             Builder.CreateSExtOrTrunc(NewIdx, OffsetTy));
+  }
+  case SupportedOp::Unsupported: {
+    llvm_unreachable("Unsupported Operation for SimplifySwitchVar.");
+    return nullptr;
+  }
+  }
+}
+
+PreservedAnalyses SimplifySwitchVarPass::run(Function &F,
+                                             FunctionAnalysisManager &AM) {
+  bool Changed = false;
+  IRBuilder<> Builder(F.getContext());
+  BasicBlock *MostCommonSuccessor;
+  // collect switch insts
+  for (auto &BB : F) {
+    if (auto *Switch = dyn_cast<SwitchInst>(BB.getTerminator())) {
+      // get the most common successor for the phi nodes
+      MostCommonSuccessor = findMostCommonSuccessor(Switch);
+
+      for (auto &Phi : MostCommonSuccessor->phis()) {
+        // filter out the phis, whose incoming blocks do not come from the
+        // switch
+        if (none_of(Phi.blocks(), [&Switch](BasicBlock *BB) {
+              return Switch->findCaseDest(BB) != nullptr;
+            }))
+          continue;
+        SmallVector<PhiCase> PhiCases = collectPhiCases(Phi, Switch);
+
+        auto InstParameters = findInstParameters(PhiCases);
+        if (InstParameters.Op == SupportedOp::Unsupported)
+          continue;
+
+        DenseMap<int64_t, int64_t> CaseOffsetMap;
+        PhiCases = collectValidCases(PhiCases, InstParameters, CaseOffsetMap);
+        if (CaseOffsetMap.size() < 2)
+          continue;
+
+        auto FuncParams = findLinearFunction(CaseOffsetMap, PhiCases);
+        if (!FuncParams.has_value())
+          continue;
+
+        auto F = FuncParams.value();
+
+        PhiCases = removeOutliers(F, PhiCases, CaseOffsetMap);
+
+        auto *NewValue = findNewValue(InstParameters, F, Builder, Switch);
+
+        if (!NewValue)
+          continue;
+
+        for (auto &Case : PhiCases) {
+          Phi.setIncomingValue(Case.PhiIndex, NewValue);
+        }
+        Changed = true;
+      }
+    }
+  }
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  return PreservedAnalyses::allInSet<CFGAnalyses>();
+}
diff --git a/llvm/test/Transforms/Util/simplify-switch-var.ll b/llvm/test/Transforms/Util/simplify-switch-var.ll
new file mode 100644
index 0000000000000..1ea944ac24607
--- /dev/null
+++ b/llvm/test/Transforms/Util/simplify-switch-var.ll
@@ -0,0 +1,860 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes="simplify-switch-var,adce,instcombine,simplifycfg<switch-range-to-icmp>" %s 2>&1 < %s | FileCheck %s
+
+;;;; ------------------------- getelementptr -------------------------
+;;;; -------------------------- valid cases --------------------------
+define i8 @gep_switch_consecutive_case_values(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_switch_consecutive_case_values(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT:  [[_ENTRY:.*:]]
+; CHECK-NEXT:    [[GEP_DEFAULT:%.*]] = getelementptr i8, ptr [[PTR]], i64 36
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 4
+; CHECK-NEXT:    [[SWITCH:%.*]] = icmp ult i32 [[INDEX]], 2
+; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[SWITCH]], ptr [[TMP3]], ptr [[GEP_DEFAULT]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[SPEC_SELECT]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+.entry:
+  %gep.0 = getelementptr i8, ptr %ptr, i64 4
+  %gep.1 = getelementptr i8, ptr %ptr, i64 20
+  %gep.default = getelementptr i8, ptr %ptr, i64 36
+  switch i32 %index, label %default [
+  i32 0, label %case.0
+  i32 1, label %case.1
+  ]
+
+case.1:
+  br label %default
+
+case.0:
+  br label %default
+
+default:
+  %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+  %load = load i8, ptr %.sink, align 1
+  ret i8 %load
+}
+
+define i8 @gep_switch_nonconsecutive_case_values(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_switch_nonconsecutive_case_values(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT:  [[_ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT:    switch i32 [[INDEX]], label %[[DEFAULT:.*]] [
+; CHECK-NEXT:      i32 4, label %[[CASE_0:.*]]
+; CHECK-NEXT:      i32 16, label %[[CASE_0]]
+; CHECK-NEXT:    ]
+; CHECK:       [[CASE_0]]:
+; CHECK-NEXT:    br label %[[DEFAULT]]
+; CHECK:       [[DEFAULT]]:
+; CHECK-NEXT:    [[DOTSINK:%.*]] = phi ptr [ [[TMP3]], %[[CASE_0]] ], [ [[PTR]], [[DOTENTRY:%.*]] ]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[DOTSINK]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+.entry:
+  %gep.0 = getelementptr i8, ptr %ptr, i64 20
+  %gep.1 = getelementptr i8, ptr %ptr, i64 68
+  %gep.default = getelementptr i8, ptr %ptr, i64 0
+  switch i32 %index, label %default [
+  i32 4, label %case.0
+  i32 16, label %case.1
+  ]
+
+case.1:
+  br label %default
+
+case.0:
+  br label %default
+
+default:
+  %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+  %load = load i8, ptr %.sink, align 1
+  ret i8 %load
+}
+
+define i8 @negative_slope(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @negative_slope(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT:  [[_ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[INDEX]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 160, [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT:    switch i32 [[INDEX]], label %[[DEFAULT:.*]] [
+; CHECK-NEXT:      i32 4, label %[[CASE_0:.*]]
+; CHECK-NEXT:      i32 16, label %[[CASE_0]]
+; CHECK-NEXT:    ]
+; CHECK:       [[CASE_0]]:
+; CHECK-NEXT:    br label %[[DEFAULT]]
+; CHECK:       [[DEFAULT]]:
+; CHECK-NEXT:    [[DOTSINK:%.*]] = phi ptr [ [[TMP3]], %[[CASE_0]] ], [ [[PTR]], [[DOTENTRY:%.*]] ]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr [[DOTSINK]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+.entry:
+  %gep.0 = getelementptr i8, ptr %ptr, i64 128
+  %gep.1 = getelementptr i8, ptr %ptr, i64 32
+  %gep.default = getelementptr i8, ptr %ptr, i64 0
+  switch i32 %index, label %default [
+  i32 4, label %case.0
+  i32 16, label %case.1
+  ]
+
+case.1:
+  br label %default
+
+case.0:
+  br label %default
+
+default:
+  %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+  %load = load i8, ptr %.sink, align 1
+  ret i8 %load
+}
+
+define i8 @gep_i32_sourceelementtype(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_i32_sourceelementtype(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT:  [[_ENTRY:.*:]]
+; CHECK-NEXT:    [[GEP_DEFAULT:%.*]] = getelementptr i8, ptr [[PTR]], i64 256
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[INDEX]], 7
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP1]]
+; CHECK-NEXT:    [[SWITCH:%.*]] = icmp ult i32 [[INDEX]], 2
+; CHECK-NEXT:    [[SPEC_SEL...
[truncated]

@OutOfCache OutOfCache marked this pull request as ready for review July 25, 2025 13:14

for (auto &Case : PhiCases) {
auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue);
bool IsAdd =
Contributor

This match can be moved to after the if (GEP) check, because the result will be unused if it's looking at a GEP.

Comment on lines +38 to +40
auto GetNumPredecessors = [](BasicBlock *BB) -> uint64_t {
return std::distance(predecessors(BB).begin(), predecessors(BB).end());
};
Contributor

pred_size?

Contributor

@nikic nikic left a comment

At a high level, I don't think that this is the correct approach to the problem. We should be sinking those GEPs such that there is one GEP with a phi operand, and then existing switch simplification will take care of the rest.

@nhaehnle
Collaborator

At a high level, I don't think that this is the correct approach to the problem. We should be sinking those GEPs such that there is one GEP with a phi operand, and then existing switch simplification will take care of the rest.

So I think we missed that a similar switch simplification already exists because there are fairly straightforward cases that are currently not simplified which are relevant for what we're targeting here.

It may still be the case that the "correct" answer is:

  • Generalize the existing simplification so that it can handle these cases
  • Strengthen the sinking of getelementptrs

I'm a little worried however that the relative tradeoff of when these transforms make sense is quite different across architectures. Do you have any thoughts about that?

Here's a case that currently isn't simplified:

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  br label %end
case1:
  br label %end
case2:
  br label %end
case3:
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 3, %case1 ], [ 6, %case2 ], [ 9, %case3 ], [ 12, %entry ]
  ret i32 %idx
}

This one is pretty straightforward.

The following is less straightforward:

declare void @foo1()
declare void @foo2()
declare void @foo3()
declare void @foo4()

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  call void @foo1()
  br label %end
case1:
  call void @foo2()
  br label %end
case2:
  call void @foo3()
  br label %end
case3:
  call void @foo4()
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 1, %case1 ], [ 2, %case2 ], [ 3, %case3 ], [ 4, %entry ]
  ret i32 %idx
}

(The calls stand in for arbitrary code.)

In order to handle this case, we have to extend the lifetime of %idx or %x. This is very typically worth it in the case we're targeting (on AMDGPU, which has lots of registers), but may be less of a slam dunk on other architectures.

Going at it from the other end, do you think there's a reason why the existing instruction sinking prefers not to sink those longer sequences of instructions? That was really my main concern going into this.

It's generally quite difficult to do any of this in a way that makes all targets happy. That's really the biggest reason why this ended up as a separate pass.

@nikic
Contributor

nikic commented Jul 25, 2025

At a high level, I don't think that this is the correct approach to the problem. We should be sinking those GEPs such that there is one GEP with a phi operand, and then existing switch simplification will take care of the rest.

So I think we missed that a similar switch simplification already exists because there are fairly straightforward cases that are currently not simplified which are relevant for what we're targeting here.

It may still be the case that the "correct" answer is:

* Generalize the existing simplification so that it can handle these cases

* Strengthen the sinking of `getelementptr`s

I'm a little worried however that the relative tradeoff of when these transforms make sense is quite different across architectures. Do you have any thoughts about that?

Here's a case that currently isn't simplified:

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  br label %end
case1:
  br label %end
case2:
  br label %end
case3:
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 3, %case1 ], [ 6, %case2 ], [ 9, %case3 ], [ 12, %entry ]
  ret i32 %idx
}

This one is pretty straightforward.

This case is already simplified -- this is one of the most basic cases. The relevant code is behind fitsInLegalInteger, so you'll have to specify a triple or a data layout to get it to actually work.

The following is less straightforward:

declare void @foo1()
declare void @foo2()
declare void @foo3()
declare void @foo4()

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  call void @foo1()
  br label %end
case1:
  call void @foo2()
  br label %end
case2:
  call void @foo3()
  br label %end
case3:
  call void @foo4()
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 1, %case1 ], [ 2, %case2 ], [ 3, %case3 ], [ 4, %entry ]
  ret i32 %idx
}

(The calls stand in for arbitrary code.)

In order to handle this case, we have to extend the lifetime of %idx or %x. This is very typically worth it in the case we're targeting (on AMDGPU, which has lots of registers), but may be less of a slam dunk on other architectures.

Going at it from the other end, do you think there's a reason why the existing instruction sinking prefers not to sink those longer sequences of instructions? That was really my main concern going into this.

It's generally quite difficult to do any of this in a way that makes all targets happy. That's really the biggest reason why this ended up as a separate pass.

I think we'd transform that case without the entry edge.

@OutOfCache
Contributor Author

Here's a case that currently isn't simplified:

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  br label %end
case1:
  br label %end
case2:
  br label %end
case3:
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 3, %case1 ], [ 6, %case2 ], [ 9, %case3 ], [ 12, %entry ]
  ret i32 %idx
}

This one is pretty straightforward.

This case is already simplified -- this is one of the most basic cases. The relevant code is behind fitsInLegalInteger, so you'll have to specify a triple or a data layout to get it to actually work.

You are right, this would be simplified on some targets, like x86:

; opt -S  -passes="simplifycfg<switch-to-lookup>" -mtriple=x86_64-unknown-unknown
define i32 @test(i32 %x) {
entry:
  %0 = icmp ult i32 %x, 4
  %switch.idx.mult = mul nsw i32 %x, 3
  %spec.select = select i1 %0, i32 %switch.idx.mult, i32 12
  ret i32 %spec.select
}

To be honest, I was not aware that LLVM did this simplification already, thanks for the pointer!
However, it only happens if we simplify the switch into a lookup table in SimplifyCFG. Correct me if I am wrong, but I think we can't do lookup tables efficiently on AMDGPU due to high memory latencies. Therefore, we would need to separate the linear transformation logic from the LUT creation. Then we also need to think about what to do if we have outliers, but this is less important for the case we are looking at right now.
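For illustration, the linear-map detection underlying this transform boils down to a fit like the following. This is a Python sketch with illustrative names, not the actual SimplifyCFG code: given (case value, result) pairs, derive a candidate multiplier and offset from the first two pairs and verify them against the rest.

```python
# Hedged sketch: detect that a switch's results are a linear function of
# the case values, analogous to SimplifyCFG's linear-map lookup-table kind.
# Function and variable names are illustrative, not LLVM API.

def fit_linear_map(pairs):
    """Given (case_value, result) pairs with distinct case values, return
    (multiplier, offset) if every pair satisfies
    result == multiplier * case + offset, else None."""
    if len(pairs) < 2:
        return None
    (c0, r0), (c1, r1) = pairs[0], pairs[1]
    # The candidate slope must be an integer for a mul/add rewrite.
    mult, rem = divmod(r1 - r0, c1 - c0)
    if rem != 0:
        return None
    off = r0 - mult * c0
    # Verify the candidate against every case; any outlier defeats the fit.
    if all(r == mult * c + off for c, r in pairs):
        return (mult, off)
    return None
```

On the first example above, the pairs (0, 0), (1, 3), (2, 6), (3, 9) fit with multiplier 3 and offset 0, which is exactly the mul nsw i32 %x, 3 that SimplifyCFG emits; a single outlier would make the fit fail, which is why outlier handling needs separate thought.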

The following is less straightforward:

declare void @foo1()
declare void @foo2()
declare void @foo3()
declare void @foo4()

define i32 @test(i32 %x) {
entry:
  switch i32 %x, label %end [
    i32 0, label %case0
    i32 1, label %case1
    i32 2, label %case2
    i32 3, label %case3
  ]

case0:
  call void @foo1()
  br label %end
case1:
  call void @foo2()
  br label %end
case2:
  call void @foo3()
  br label %end
case3:
  call void @foo4()
  br label %end

end:
  %idx = phi i32 [ 0, %case0 ], [ 1, %case1 ], [ 2, %case2 ], [ 3, %case3 ], [ 4, %entry ]
  ret i32 %idx
}

I think we'd transform that case without the entry edge.

True, in that case the phi is replaced with ret i32 %x, thanks to a transformation in InstCombinePHI. However, it does not support linear transformations like the one in the first example, and it also does not handle default cases or outliers.

For this simplification to apply to the cases we are targeting, we need to

  1. improve the sinking of getelementptrs, so that there is only one gep in the final block and the phi only contains the offset. How should we sink if we have outliers and not every predecessor has a gep? Or if one of the geps is further away and used in multiple switches as base pointer, but is not duplicated into the relevant basic blocks?
  2. recognize base pointers, even if they don't come from a getelementptr with 0 offset, so we can sink them
  3. use the same logic in SimplifyCFG to find the linear transformations of the offsets, but without the LUT creation

@arsenm
Contributor

arsenm commented Jul 30, 2025

Correct me if I am wrong, but I think we can't do lookup tables efficiently on AMDGPU due to high memory latencies.

If we can use scalar loads it's probably not that bad, but requires benchmarking

@nikic
Contributor

nikic commented Jul 30, 2025

However, it only happens if we simplify the switch into a lookup table in SimplifyCFG. Correct me if I am wrong, but I think we can't do lookup tables efficiently on AMDGPU due to high memory latencies. Therefore, we would need to separate the linear transformation logic from the LUT creation. Then we also need to think about what to do if we have outliers, but this is less important for the case we are looking at right now.

Can't comment on the AMDGPU question, but yes to the rest: The transform of switches to linear expressions (and other representations) is currently tightly bound to lookup table generation, but it really shouldn't be. While I'd expect some of the implementation to be shared, TTI disabling lookup tables, or use of no-jump-tables really shouldn't prevent the linear expression fold. Another benefit of separating these is that we can run parts of this earlier, pre-inlining (while lookup tables should always be generated late).

I'd also love to see further generalization of that code to pick more complex mapping functions, ideally it would not be defeated by needing an extra zext somewhere.

But just getting the existing SimplifyCFG code to work for AMDGPU seems like significant win independently of the more complex cases involving GEP sinking etc.

For this simplification to apply to the cases we are targeting, we need to

  1. improve the sinking of getelementptrs, so that there is only one gep in the final block and the phi only contains the offset. How should we sink if we have outliers and not every predecessor has a gep? Or if one of the geps is further away and used in multiple switches as base pointer, but is not duplicated into the relevant basic blocks?

  2. recognize base pointers, even if they don't come from a getelementptr with 0 offset, so we can sink them

  3. use the same logic in SimplifyCFG to find the linear transformations of the offsets, but without the LUT creation

I think #128171 tries to implement the necessary GEP sinking support, but I haven't looked in detail.

The general approach isn't really suitable for handling outliers. Are these important to your motivating cases?

@OutOfCache
Contributor Author

Can't comment on the AMDGPU question, but yes to the rest: The transform of switches to linear expressions (and other representations) is currently tightly bound to lookup table generation, but it really shouldn't be. While I'd expect some of the implementation to be shared, TTI disabling lookup tables, or use of no-jump-tables really shouldn't prevent the linear expression fold. Another benefit of separating these is that we can run parts of this earlier, pre-inlining (while lookup tables should always be generated late).

I agree now that I know this exists, thanks again for the help! How can we best achieve that?

Currently, we first check whether Options.ConvertSwitchToLookupTable is set
before we consider creating a LUT. This delays the LUT creation until late in the pipeline.

Then, in switchToLookupTable, there is a check for TTI support first, with an early exit if the target does not support lookup tables.

Only if we support creating a LUT do we collect the cases and associated values.

Then we create a SwitchLookupTable. In its constructor we decide which kind of "LUT" we want to create
(although only ArrayKind seems like a regular LUT while the others are optimized special cases).

And finally in buildLookup there is the actual creation of whatever kind of LUT we decided on.

A nice solution would be something that completely separates the "true" LUT from the other kinds, because
we can handle them differently (we can generate them earlier, we can support them on targets that don't allow LUTs).

However, for all of these kinds we need information about the case values and the "looked up" values,
which is currently gathered in switchToLookupTable, while the decision about the kind of LUT lives in the
SwitchLookupTable class and is made later on.

We could keep the general structure more-or-less, but do the checks for LUTs much later in switchToLookupTable, i.e., collect the case-value mappings first, then create a SwitchLookupTable to determine what kind to generate.
Then we could check if we really want to create an ArrayKind LUT, and only then do the checks (whether Options.ConvertSwitchToLookupTable is set, TTI support and no-jump-tables attribute).
If we don't, we can generate code for the other kinds already.
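As a sketch of that restructured flow -- in Python, with hypothetical names loosely mirroring the kinds in SwitchLookupTable, not actual SimplifyCFG code -- the key change is that the replacement kind is chosen before any code is generated, and only the true array LUT is gated on target support:

```python
# Hedged sketch of the restructuring discussed above: pick the replacement
# kind first, then gate only the true in-memory LUT on target support.
# Kind names and helpers are illustrative, not LLVM API.

from enum import Enum, auto

class Kind(Enum):
    SINGLE_VALUE = auto()  # every case yields the same value
    LINEAR_MAP = auto()    # value = mult * case + offset
    ARRAY = auto()         # a real in-memory lookup table

def fits_linear(pairs):
    """Check whether (case, result) pairs lie on one integer line."""
    (c0, r0), (c1, r1) = pairs[0], pairs[1]
    mult, rem = divmod(r1 - r0, c1 - c0)
    off = r0 - mult * c0
    return rem == 0 and all(r == mult * c + off for c, r in pairs)

def choose_kind(pairs):
    """Decide the replacement kind from the case-value mappings alone.
    (A real implementation would also consider a bitmap kind, etc.)"""
    if len({r for _, r in pairs}) == 1:
        return Kind.SINGLE_VALUE
    if len(pairs) >= 2 and fits_linear(pairs):
        return Kind.LINEAR_MAP
    return Kind.ARRAY

def should_transform(pairs, target_supports_luts):
    """Only the ARRAY kind needs target LUT support; the cheap kinds
    (single value, linear map) can be emitted on any target."""
    kind = choose_kind(pairs)
    return kind is not Kind.ARRAY or target_supports_luts
```

With this shape, the Options.ConvertSwitchToLookupTable, TTI, and no-jump-tables checks would only sit on the ARRAY path, so the single-value and linear-map rewrites stay available on targets like AMDGPU that disable lookup tables.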

Is there a better way to achieve this?

@nikic
Contributor

nikic commented Aug 6, 2025

We could keep the general structure more-or-less, but do the checks for LUTs much later in switchToLookupTable, i.e., collect the case-value mappings first, then create a SwitchLookupTable to determine what kind to generate. Then we could check if we really want to create an ArrayKind LUT, and only then do the checks (whether Options.ConvertSwitchToLookupTable is set, TTI support and no-jump-tables attribute). If we don't, we can generate code for the other kinds already.

Is there a better way to achieve this?

That approach sounds reasonable to me.

@OutOfCache
Contributor Author

I opened this PR to first refactor SimplifyCFG: #155602 A later PR will do the actual check for target support for LUTs.

@OutOfCache OutOfCache closed this Aug 27, 2025
OutOfCache added a commit that referenced this pull request Sep 1, 2025
…up` (#155602)

This PR is the first part to solve the issue in #149937.

The end goal is enabling more switch optimizations on targets that do
not support lookup tables.

SimplifyCFG has the ability to replace switches with either a few simple
calculations, a single value, or a lookup table.
However, it only considers these options if the target supports lookup
tables, even if the final result is not a LUT, but a few simple
instructions like muls, adds and shifts.

To enable more targets to use these other kinds of optimization, this PR
restructures the code in `switchToLookup`.
Previously, code was generated even before choosing what kind of
replacement to do. However, we need to know if we actually want to
create a true LUT or not before generating anything. Then we can check
for target support only if any LUT would be created.

This PR moves the code so it first determines the replacement kind and
then generates the instructions.

A later PR will insert the target support check after determining the
kind of replacement. If the result is not a LUT, then even targets
without LUT support can replace the switch with something else.