-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Transforms/Util] Add SimplifySwitchVar pass #149937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This pass models the relationship between a switch variable and an associated constant offset (like from a getelementptr or an integer add). Find a linear function that represents this relationship and use the function to create a new, more generalized version of that operation. Similar to SimplifyIndVar, but instead of the correlation between a loop iteration and an induction variable, we look at the switch variable (i.e., the case value) and the offset.
Find the BB, where most of the cases (and perhaps the switch BB itself) meet. We can use that BB to find phi nodes with the case BBs of the switch statement later on. If we can generalize the instruction with the constant offset, and that instruction is also an incoming value of the phi node, we can replace the incoming values with a generalized version of the instruction. That can enable other optimizations like dead code elimination and simplifycfg to further clean up the IR.
Determine if a phi's incoming value comes from a case BB. If it does, then record the case value and the index in the phi node alongside the incoming value and collect them in a vector.
Return the common base value, type of operation and offset type of the incoming values. These parameters help with the creation of the generalized operation later on. Anything other than geps and integer adds are not supported.
Filter out invalid cases and return a vector of just valid ones. Valid cases have the same base value as the found base value and they have a constant offset. Collect the valid ones in a map of case values to offsets. Essentially, the map contains points with the case value as x-value and offset as y-value. We use that map to determin a linear function through these points in a later change.
Using the map as collection of points (where the x-value is the case value and the y-value is the offset), find a linear function via random sampling. If we cannot find a function that contains at least half of the points within 5 tries, abort. The number 5 was chosen because it gives enough confidence in most cases, see https://en.wikipedia.org/wiki/Random_sample_consensus#Parameters If 80% of the points are on a straight line, then the probability p of finding the line with two points is 0.8^2. The chance of not finding the line is (1 - p). The chance of not finding the line with 5 random tries is (1 - p)^5, which is ~0.6%. Use the same seed to be deterministic.
Filter out outlier of the function as we cannot replace them.
Create a getelementptr or an add using the found function and instruction parameters. Calculate the result of slope * switchvar + bias and use the result in the new instruction.
Use the newly created operation instead of the previously collected incoming values.
@llvm/pr-subscribers-llvm-transforms Author: Jessica Del (OutOfCache) ChangesThis pass models the relationship between a switch variable Similar to It can help eliminate the switch entirely in some cases Patch is 45.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149937.diff 6 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h b/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h
new file mode 100644
index 0000000000000..9d98a470a96d6
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/SimplifySwitchVar.h
@@ -0,0 +1,27 @@
+//===-- SimplifySwitchVar.h - Simplify Switch Variables ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// This file defines a pass for switch variable simplification.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
+#define LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class SimplifySwitchVarPass : public PassInfoMixin<SimplifySwitchVarPass> {
+public:
+ PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_SIMPLIFYSWITCHVAR_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index e15570c3f600e..4d40906e4a591 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -364,6 +364,7 @@
#include "llvm/Transforms/Utils/NameAnonGlobals.h"
#include "llvm/Transforms/Utils/PredicateInfo.h"
#include "llvm/Transforms/Utils/RelLookupTableConverter.h"
+#include "llvm/Transforms/Utils/SimplifySwitchVar.h"
#include "llvm/Transforms/Utils/StripGCRelocates.h"
#include "llvm/Transforms/Utils/StripNonLineTableDebugInfo.h"
#include "llvm/Transforms/Utils/SymbolRewriter.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index caa78b613b901..c60b6bebac609 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -537,6 +537,7 @@ FUNCTION_PASS("slp-vectorizer", SLPVectorizerPass())
FUNCTION_PASS("slsr", StraightLineStrengthReducePass())
FUNCTION_PASS("stack-protector", StackProtectorPass(TM))
FUNCTION_PASS("strip-gc-relocates", StripGCRelocates())
+FUNCTION_PASS("simplify-switch-var", SimplifySwitchVarPass())
FUNCTION_PASS("tailcallelim", TailCallElimPass())
FUNCTION_PASS("transform-warning", WarnMissedTransformationsPass())
FUNCTION_PASS("trigger-crash-function", TriggerCrashFunctionPass())
diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
index f7e66eca6aeb3..1138256c6851f 100644
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -84,6 +84,7 @@ add_llvm_component_library(LLVMTransformUtils
SizeOpts.cpp
SplitModule.cpp
StripNonLineTableDebugInfo.cpp
+ SimplifySwitchVar.cpp
SymbolRewriter.cpp
UnifyFunctionExitNodes.cpp
UnifyLoopExits.cpp
diff --git a/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp b/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp
new file mode 100644
index 0000000000000..e30ef6c12b20f
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/SimplifySwitchVar.cpp
@@ -0,0 +1,371 @@
+//===-- SimplifySwitchVar.cpp - Switch Variable simplification ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This file implements switch variable simplification. It looks for a
+/// linear relationship between the case value of a switch and the constant
+/// offset of an operation. Knowing this relationship, we can simplify
+/// multiple individual operations into a single, more generic one, which
+/// can help with further optimizations.
+///
+/// It is similar to SimplifyIndVar, but instead of looking at an
+/// induction variable and modeling its scalar evolution over
+/// multiple iterations, it analyzes the switch variable and
+/// models how it affects constant offsets.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/SimplifySwitchVar.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include <random>
+
+using namespace llvm;
+using namespace PatternMatch;
+
+/// Return the BB, where (most of) the cases meet.
+/// In that BB are phi nodes, that contain the case BBs.
+static BasicBlock *findMostCommonSuccessor(SwitchInst *Switch) {
+ uint64_t Max = 0;
+ BasicBlock *MostCommonSuccessor = nullptr;
+
+ for (auto &Case : Switch->cases()) {
+ auto *CaseBB = Case.getCaseSuccessor();
+ auto GetNumPredecessors = [](BasicBlock *BB) -> uint64_t {
+ return std::distance(predecessors(BB).begin(), predecessors(BB).end());
+ };
+
+ auto Length = GetNumPredecessors(CaseBB);
+
+ if (Length > Max) {
+ Max = Length;
+ MostCommonSuccessor = CaseBB;
+ }
+
+ for (auto *Successor : successors(CaseBB)) {
+ auto Length = GetNumPredecessors(Successor);
+ if (Length > Max) {
+ Max = Length;
+ MostCommonSuccessor = Successor;
+ }
+ }
+ }
+
+ return MostCommonSuccessor;
+}
+
+namespace {
+struct PhiCase {
+ int PhiIndex;
+ Value *IncomingValue;
+ ConstantInt *CaseValue;
+};
+} // namespace
+
+/// Collect the incoming value, index and associated case value from a phi node.
+/// Ignores incoming values, which do not come from a case BB from the switch.
+static SmallVector<PhiCase> collectPhiCases(PHINode &Phi, SwitchInst *Switch) {
+ SmallVector<PhiCase> PhiInputs;
+
+ for (auto *IncomingBlock : Phi.blocks()) {
+ auto *CaseVal = Switch->findCaseDest(IncomingBlock);
+ if (!CaseVal)
+ continue;
+
+ auto PhiIdx = Phi.getBasicBlockIndex(IncomingBlock);
+ PhiInputs.push_back({PhiIdx, Phi.getIncomingValue(PhiIdx), CaseVal});
+ }
+ return PhiInputs;
+}
+
+namespace {
+enum SupportedOp {
+ GetElementPtr,
+ IntegerAdd,
+ Unsupported,
+};
+
+struct NewInstParameters {
+ SupportedOp Op;
+ Value *BaseValue;
+ Type *OffsetTy;
+};
+} // namespace
+
+/// Find the common Base Value, Operation type and Index Type of the found phi
+/// incoming values.
+static NewInstParameters findInstParameters(SmallVector<PhiCase> &PhiCases) {
+ auto Op = SupportedOp::Unsupported;
+ DenseMap<Value *, uint64_t> BaseAddressCounts;
+ Type *OffsetTy = nullptr;
+
+ for (auto &Case : PhiCases) {
+ auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue);
+ bool IsAdd =
+ match(Case.IncomingValue, m_Add(m_Value(), m_AnyIntegralConstant()));
+
+ if (GEP) {
+ Op = SupportedOp::GetElementPtr;
+ BaseAddressCounts[GEP->getPointerOperand()] += 1;
+ OffsetTy = GEP->getOperand(GEP->getNumIndices())->getType();
+ continue;
+ }
+
+ if (IsAdd) {
+ Op = SupportedOp::IntegerAdd;
+ auto *Add = cast<Instruction>(Case.IncomingValue);
+ BaseAddressCounts[Add->getOperand(Add->getNumOperands() - 2)] += 1;
+ OffsetTy = Add->getOperand(Add->getNumOperands() - 1)->getType();
+ continue;
+ }
+
+ BaseAddressCounts[Case.IncomingValue] += 1;
+ }
+
+ unsigned Max = 0;
+ Value *BaseValue;
+ for (auto &Base : BaseAddressCounts) {
+ if (Base.second > Max) {
+ BaseValue = Base.first;
+ Max = Base.second;
+ }
+ }
+
+ return {Op, BaseValue, OffsetTy};
+}
+
+/// Collect valid cases.
+/// A case is valid if it uses the same base value (or is the base value like a
+/// pointer from an alloca) and it has a constant offset.
+static SmallVector<PhiCase>
+collectValidCases(SmallVector<PhiCase> &PhiCases,
+ NewInstParameters NewInstParameters,
+ DenseMap<int64_t, int64_t> &CaseOffsetMap) {
+ SmallVector<PhiCase> FilteredCases;
+ auto *BaseValue = NewInstParameters.BaseValue;
+ auto CurrentOp = NewInstParameters.Op;
+
+ switch (CurrentOp) {
+ case SupportedOp::GetElementPtr: {
+ for (auto &Case : PhiCases) {
+ auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue);
+
+ if (!GEP) {
+ if (Case.IncomingValue != BaseValue) {
+ continue;
+ }
+ CaseOffsetMap[Case.CaseValue->getSExtValue()] = 0;
+ FilteredCases.push_back(Case);
+ continue;
+ }
+
+ if (GEP->getPointerOperand() != BaseValue) {
+ continue;
+ }
+
+ auto &DL = GEP->getParent()->getDataLayout();
+ APInt Offset(DL.getTypeSizeInBits(GEP->getPointerOperandType()), 0);
+ if (!GEP->accumulateConstantOffset(GEP->getDataLayout(), Offset)) {
+ continue;
+ }
+ CaseOffsetMap[Case.CaseValue->getSExtValue()] = Offset.getSExtValue();
+ FilteredCases.push_back(Case);
+ }
+ break;
+ }
+ case SupportedOp::IntegerAdd: {
+ for (auto &Case : PhiCases) {
+ bool IsAdd =
+ match(Case.IncomingValue, m_Add(m_Value(), m_AnyIntegralConstant()));
+
+ if (!IsAdd) {
+ if (Case.IncomingValue != BaseValue) {
+ continue;
+ }
+ CaseOffsetMap[Case.CaseValue->getSExtValue()] = 0;
+ FilteredCases.push_back(Case);
+ continue;
+ }
+
+ auto *AddInst = dyn_cast<Instruction>(Case.IncomingValue);
+ if (AddInst->getOperand(0) != BaseValue) {
+ continue;
+ }
+ auto *Offset = cast<ConstantInt>(AddInst->getOperand(1));
+ if (!Offset) {
+ continue;
+ }
+
+ CaseOffsetMap[Case.CaseValue->getSExtValue()] = Offset->getSExtValue();
+ FilteredCases.push_back(Case);
+ continue;
+ }
+ break;
+ }
+ case SupportedOp::Unsupported: {
+ llvm_unreachable("Unsupported Operation for SimplifySwitchVar.");
+ }
+ }
+ return FilteredCases;
+}
+
+namespace {
+struct FuncParams {
+ int64_t Slope;
+ int64_t Bias;
+};
+} // namespace
+
+using RandomEngine = std::minstd_rand;
+using RandomDistribution = std::uniform_int_distribution<int>;
+/// Find the linear function that models the switch variable progression.
+/// Uses random sampling to find the best fit, even if outliers are present.
+/// Abort if there are too many outliers (> 50%)
+static std::optional<FuncParams>
+findLinearFunction(DenseMap<int64_t, int64_t> &Cases,
+ SmallVector<PhiCase> &PhiCases) {
+ RandomEngine Rand(0xdeadbeef);
+ RandomDistribution RandDist(0, Cases.size() - 1);
+
+ // Repeat the process at most 5 times, because if at least 80% of the points
+ // lie on the line, then we will find the line with 99.4% probability within 5
+ // tries. See https://en.wikipedia.org/wiki/Random_sample_consensus#Parameters
+ for (int I = 0; I < 5; ++I) {
+ auto Index0 = RandDist(Rand);
+ auto Index1 = RandDist(Rand);
+ while (Index0 == Index1) {
+ Index1 = RandDist(Rand);
+ }
+
+ auto X0 = PhiCases[Index0].CaseValue->getSExtValue();
+ auto X1 = PhiCases[Index1].CaseValue->getSExtValue();
+ auto Y0 = Cases[X0];
+ auto Y1 = Cases[X1];
+
+ int64_t Slope = (Y1 - Y0) / (X1 - X0);
+ int64_t Bias = (Y0 - (Slope * X0));
+
+ auto Count = llvm::count_if(Cases, [Bias, Slope](auto Case) {
+ return Slope * Case.first + Bias == Case.second;
+ });
+
+ float InlierRatio = (float)Count / (float)Cases.size();
+ if (InlierRatio > 0.5f) {
+ return std::optional<FuncParams>({Slope, Bias});
+ }
+ }
+
+ return std::nullopt;
+}
+
+/// Remove outlier cases, where the offset is not exactly a point on the
+/// calculated function.
+static SmallVector<PhiCase>
+removeOutliers(FuncParams F, SmallVector<PhiCase> PhiCases,
+ DenseMap<int64_t, int64_t> &CaseValueMap) {
+ SmallVector<PhiCase> FilteredCases;
+ for (auto &Case : PhiCases) {
+ auto CaseValue = Case.CaseValue->getSExtValue();
+ auto CaseOffset = CaseValueMap[CaseValue];
+ auto Result = F.Slope * CaseValue + F.Bias;
+
+ if (Result == CaseOffset)
+ FilteredCases.push_back(Case);
+ }
+ return FilteredCases;
+}
+
+/// Create the new generalized value that models all the found cases with the
+/// calculated function.
+/// Uses the slope and bias to modify the switch variable
+/// and uses the resulting value as argument to the new instruction.
+static Value *findNewValue(NewInstParameters InstParameters, FuncParams F,
+ IRBuilder<> &Builder, SwitchInst *Switch) {
+ auto *BaseValue = InstParameters.BaseValue;
+ auto Op = InstParameters.Op;
+ auto *OffsetTy = InstParameters.OffsetTy;
+
+ auto *SwitchVar = Switch->getCondition();
+ auto SwitchBitWidth = SwitchVar->getType()->getIntegerBitWidth();
+
+ Builder.SetInsertPoint(Switch);
+
+ auto *NewIdx =
+ Builder.CreateMul(SwitchVar, Builder.getIntN(SwitchBitWidth, F.Slope));
+ NewIdx = Builder.CreateAdd(NewIdx, Builder.getIntN(SwitchBitWidth, F.Bias));
+
+ switch (Op) {
+ case SupportedOp::GetElementPtr: {
+ return Builder.CreateGEP(Builder.getInt8Ty(), BaseValue,
+ Builder.CreateSExtOrTrunc(NewIdx, OffsetTy));
+ };
+ case SupportedOp::IntegerAdd: {
+ return Builder.CreateAdd(BaseValue,
+ Builder.CreateSExtOrTrunc(NewIdx, OffsetTy));
+ }
+ case SupportedOp::Unsupported: {
+ llvm_unreachable("Unsupported Operation for SimplifySwitchVar.");
+ return nullptr;
+ }
+ }
+}
+
+PreservedAnalyses SimplifySwitchVarPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ bool Changed = false;
+ IRBuilder<> Builder(F.getContext());
+ BasicBlock *MostCommonSuccessor;
+ // collect switch insts
+ for (auto &BB : F) {
+ if (auto *Switch = dyn_cast<SwitchInst>(BB.getTerminator())) {
+ // get the most common successor for the phi nodes
+ MostCommonSuccessor = findMostCommonSuccessor(Switch);
+
+ for (auto &Phi : MostCommonSuccessor->phis()) {
+ // filter out the phis, whose incoming blocks do not come from the
+ // switch
+ if (none_of(Phi.blocks(), [&Switch](BasicBlock *BB) {
+ return Switch->findCaseDest(BB) != nullptr;
+ }))
+ continue;
+ SmallVector<PhiCase> PhiCases = collectPhiCases(Phi, Switch);
+
+ auto InstParameters = findInstParameters(PhiCases);
+ if (InstParameters.Op == SupportedOp::Unsupported)
+ continue;
+
+ DenseMap<int64_t, int64_t> CaseOffsetMap;
+ PhiCases = collectValidCases(PhiCases, InstParameters, CaseOffsetMap);
+ if (CaseOffsetMap.size() < 2)
+ continue;
+
+ auto FuncParams = findLinearFunction(CaseOffsetMap, PhiCases);
+ if (!FuncParams.has_value())
+ continue;
+
+ auto F = FuncParams.value();
+
+ PhiCases = removeOutliers(F, PhiCases, CaseOffsetMap);
+
+ auto *NewValue = findNewValue(InstParameters, F, Builder, Switch);
+
+ if (!NewValue)
+ continue;
+
+ for (auto &Case : PhiCases) {
+ Phi.setIncomingValue(Case.PhiIndex, NewValue);
+ }
+ Changed = true;
+ }
+ }
+ }
+
+ if (!Changed)
+ return PreservedAnalyses::all();
+
+ return PreservedAnalyses::allInSet<CFGAnalyses>();
+}
diff --git a/llvm/test/Transforms/Util/simplify-switch-var.ll b/llvm/test/Transforms/Util/simplify-switch-var.ll
new file mode 100644
index 0000000000000..1ea944ac24607
--- /dev/null
+++ b/llvm/test/Transforms/Util/simplify-switch-var.ll
@@ -0,0 +1,860 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes="simplify-switch-var,adce,instcombine,simplifycfg<switch-range-to-icmp>" %s 2>&1 < %s | FileCheck %s
+
+;;;; ------------------------- getelementptr -------------------------
+;;;; -------------------------- valid cases --------------------------
+define i8 @gep_switch_consecutive_case_values(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_switch_consecutive_case_values(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT: [[_ENTRY:.*:]]
+; CHECK-NEXT: [[GEP_DEFAULT:%.*]] = getelementptr i8, ptr [[PTR]], i64 36
+; CHECK-NEXT: [[TMP0:%.*]] = shl i32 [[INDEX]], 4
+; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 4
+; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i32 [[INDEX]], 2
+; CHECK-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[SWITCH]], ptr [[TMP3]], ptr [[GEP_DEFAULT]]
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[SPEC_SELECT]], align 1
+; CHECK-NEXT: ret i8 [[LOAD]]
+;
+.entry:
+ %gep.0 = getelementptr i8, ptr %ptr, i64 4
+ %gep.1 = getelementptr i8, ptr %ptr, i64 20
+ %gep.default = getelementptr i8, ptr %ptr, i64 36
+ switch i32 %index, label %default [
+ i32 0, label %case.0
+ i32 1, label %case.1
+ ]
+
+case.1:
+ br label %default
+
+case.0:
+ br label %default
+
+default:
+ %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+ %load = load i8, ptr %.sink, align 1
+ ret i8 %load
+}
+
+define i8 @gep_switch_nonconsecutive_case_values(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_switch_nonconsecutive_case_values(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT: [[_ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = shl i32 [[INDEX]], 2
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT: switch i32 [[INDEX]], label %[[DEFAULT:.*]] [
+; CHECK-NEXT: i32 4, label %[[CASE_0:.*]]
+; CHECK-NEXT: i32 16, label %[[CASE_0]]
+; CHECK-NEXT: ]
+; CHECK: [[CASE_0]]:
+; CHECK-NEXT: br label %[[DEFAULT]]
+; CHECK: [[DEFAULT]]:
+; CHECK-NEXT: [[DOTSINK:%.*]] = phi ptr [ [[TMP3]], %[[CASE_0]] ], [ [[PTR]], [[DOTENTRY:%.*]] ]
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[DOTSINK]], align 1
+; CHECK-NEXT: ret i8 [[LOAD]]
+;
+.entry:
+ %gep.0 = getelementptr i8, ptr %ptr, i64 20
+ %gep.1 = getelementptr i8, ptr %ptr, i64 68
+ %gep.default = getelementptr i8, ptr %ptr, i64 0
+ switch i32 %index, label %default [
+ i32 4, label %case.0
+ i32 16, label %case.1
+ ]
+
+case.1:
+ br label %default
+
+case.0:
+ br label %default
+
+default:
+ %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+ %load = load i8, ptr %.sink, align 1
+ ret i8 %load
+}
+
+define i8 @negative_slope(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @negative_slope(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT: [[_ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = shl i32 [[INDEX]], 3
+; CHECK-NEXT: [[TMP1:%.*]] = sub i32 160, [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP2]]
+; CHECK-NEXT: switch i32 [[INDEX]], label %[[DEFAULT:.*]] [
+; CHECK-NEXT: i32 4, label %[[CASE_0:.*]]
+; CHECK-NEXT: i32 16, label %[[CASE_0]]
+; CHECK-NEXT: ]
+; CHECK: [[CASE_0]]:
+; CHECK-NEXT: br label %[[DEFAULT]]
+; CHECK: [[DEFAULT]]:
+; CHECK-NEXT: [[DOTSINK:%.*]] = phi ptr [ [[TMP3]], %[[CASE_0]] ], [ [[PTR]], [[DOTENTRY:%.*]] ]
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[DOTSINK]], align 1
+; CHECK-NEXT: ret i8 [[LOAD]]
+;
+.entry:
+ %gep.0 = getelementptr i8, ptr %ptr, i64 128
+ %gep.1 = getelementptr i8, ptr %ptr, i64 32
+ %gep.default = getelementptr i8, ptr %ptr, i64 0
+ switch i32 %index, label %default [
+ i32 4, label %case.0
+ i32 16, label %case.1
+ ]
+
+case.1:
+ br label %default
+
+case.0:
+ br label %default
+
+default:
+ %.sink = phi ptr [ %gep.0, %case.0 ], [ %gep.1, %case.1 ], [ %gep.default, %.entry ]
+ %load = load i8, ptr %.sink, align 1
+ ret i8 %load
+}
+
+define i8 @gep_i32_sourceelementtype(ptr %ptr, i32 %index) {
+; CHECK-LABEL: define i8 @gep_i32_sourceelementtype(
+; CHECK-SAME: ptr [[PTR:%.*]], i32 [[INDEX:%.*]]) {
+; CHECK-NEXT: [[_ENTRY:.*:]]
+; CHECK-NEXT: [[GEP_DEFAULT:%.*]] = getelementptr i8, ptr [[PTR]], i64 256
+; CHECK-NEXT: [[TMP0:%.*]] = shl i32 [[INDEX]], 7
+; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[TMP1]]
+; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i32 [[INDEX]], 2
+; CHECK-NEXT: [[SPEC_SEL...
[truncated]
|
|
||
for (auto &Case : PhiCases) { | ||
auto *GEP = dyn_cast<GetElementPtrInst>(Case.IncomingValue); | ||
bool IsAdd = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This match can be moved to after the if (GEP)
check, because the result will be unused if it's looking at an GEP.
auto GetNumPredecessors = [](BasicBlock *BB) -> uint64_t { | ||
return std::distance(predecessors(BB).begin(), predecessors(BB).end()); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pred_size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At a high level, I don't think that this is the correct approach to the problem. We should be sinking those GEPs such that there is one GEP with a phi operand, and then existing switch simplification will take care of the rest.
So I think we missed that a similar switch simplification already exists because there are fairly straightforward cases that are currently not simplified which are relevant for what we're targeting here. It may still be the case that the "correct" answer is:
I'm a little worried however that the relative tradeoff of when these transforms make sense is quite different across architectures. Do you have any thoughts about that? Here's a case that currently isn't simplified:
This one is pretty straightforward. The following is less straightforward:
(The calls stand in for arbitrary code.) In order to handle this case, we have to extend the lifetime of Going at it from the other end, do you think there's a reason why the existing instruction sinking prefers not to sink those longer sequences of instructions? That was really my main concern going into this. It's generally quite difficult to do any of this in a way that makes all targets happy. That's really the biggest reason why this ended up as a separate pass. |
This case is already simplified -- this is one of the most basic cases. The relevant code is behind
I think we'd transform that case without the entry edge. |
You are right, this would be simplified, on some targets like x86: ; opt -S -passes="simplifycfg<switch-to-lookup>" -mtriple=x86_64-unknown-unknown
define i32 @test(i32 %x) {
entry:
%0 = icmp ult i32 %x, 4
%switch.idx.mult = mul nsw i32 %x, 3
%spec.select = select i1 %0, i32 %switch.idx.mult, i32 12
ret i32 %spec.select
} To be honest, I was not aware that llvm did this simplification already, thanks for the pointer!
True, there the For this simplification to apply to the cases we are targeting, we need to
|
If we can use scalar loads it's probably not that bad, but requires benchmarking |
Can't comment on the AMDGPU question, but yes to the rest: The transform of switches to linear expressions (and other representations) is currently tightly bound to lookup table generation, but it really shouldn't be. While I'd expect some of the implementation to be shared, TTI disabling lookup tables, or use of no-jump-tables really shouldn't prevent the linear expression fold. Another benefit of separating these is that we can run parts of this earlier, pre-inlining (while lookup tables should always be generated late). I'd also love to see further generalization of that code to pick more complex mapping functions, ideally it would not be defeated by needing an extra zext somewhere. But just getting the existing SimplifyCFG code to work for AMDGPU seems like significant win independently of the more complex cases involving GEP sinking etc.
I think #128171 tries to implement the necessary GEP sinking support, but I haven't looked in detail. The general approach isn't really suitable for handling outliers. Are these important to your motivating cases? |
I agree now that I know this exists, thanks again for the help! How can we best achieve that? Currently, we first check if the decide if Then, in If, and only if we support creating a LUT, do we collect the cases and associated values. Then we create a SwitchLookupTable. In its constructor we decide which kind of "LUT" we want to create And finally in A nice solution would be something that completely separates the "true" LUT from the other kinds, because However, in all of these kinds we need information about the case values and the "looked up" values, We could keep the general structure more-or-less, but do the checks for LUTs much later in Is there a better way to achieve this? |
That approach sounds reasonable to me. |
I opened this PR to first refactor SimplifyCFG: #155602 A later PR will do the actual check for target support for LUT. |
…up` (#155602) This PR is the first part to solve the issue in #149937. The end goal is enabling more switch optimizations on targets that do not support lookup tables. SimplifyCFG has the ability to replace switches with either a few simple calculations, a single value, or a lookup table. However, it only considers these options if the target supports lookup tables, even if the final result is not a LUT, but a few simple instructions like muls, adds and shifts. To enable more targets to use these other kinds of optimization, this PR restructures the code in `switchToLookup`. Previously, code was generated even before choosing what kind of replacement to do. However, we need to know if we actually want to create a true LUT or not before generating anything. Then we can check for target support only if any LUT would be created. This PR moves the code so it first determines the replacement kind and then generates the instructions. A later PR will insert the target support check after determining the kind of replacement. If the result is not a LUT, then even targets without LUT support can replace the switch with something else.
This pass models the relationship between a switch variable
and an associated constant offset (like from a getelementptr
or an integer add). Find a linear function that represents
this relationship and use the function to create a new,
more generalized version of that operation.
Similar to
SimplifyIndVar
, but instead of the correlationbetween a loop iteration and an induction variable,
we look at the switch variable (i.e., the case value) and
the offset.
It can help eliminate the switch entirely in some cases
with the help of dead code elimination and
simplifycfg
.