diff --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp index c694b71da649b..6b74e56d6b3e9 100644 --- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp +++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp @@ -18,10 +18,8 @@ #include "BPF.h" #include "BPFCORE.h" #include "BPFTargetMachine.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -43,14 +41,12 @@ class BPFCheckAndAdjustIR final : public ModulePass { public: static char ID; BPFCheckAndAdjustIR() : ModulePass(ID) {} - virtual void getAnalysisUsage(AnalysisUsage &AU) const override; private: void checkIR(Module &M); bool adjustIR(Module &M); bool removePassThroughBuiltin(Module &M); bool removeCompareBuiltin(Module &M); - bool sinkMinMax(Module &M); }; } // End anonymous namespace @@ -165,208 +161,9 @@ bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) { return Changed; } -struct MinMaxSinkInfo { - ICmpInst *ICmp; - Value *Other; - ICmpInst::Predicate Predicate; - CallInst *MinMax; - ZExtInst *ZExt; - SExtInst *SExt; - - MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate) - : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr), - ZExt(nullptr), SExt(nullptr) {} -}; - -static bool sinkMinMaxInBB(BasicBlock &BB, - const std::function &Filter) { - // Check if V is: - // (fn %a %b) or (ext (fn %a %b)) - // Where: - // ext := sext | zext - // fn := smin | umin | smax | umax - auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) { - if (auto *ZExt = dyn_cast(V)) { - V = ZExt->getOperand(0); - Info.ZExt = ZExt; - } else if (auto *SExt = dyn_cast(V)) { - V = SExt->getOperand(0); - Info.SExt = SExt; - } - - auto *Call = dyn_cast(V); - if (!Call) - return false; - - auto *Called = dyn_cast(Call->getCalledOperand()); - if (!Called) - return false; - - switch (Called->getIntrinsicID()) { - case Intrinsic::smin: - case Intrinsic::umin: - case Intrinsic::smax: - case Intrinsic::umax: - break; - default: - return false; - } - - if (!Filter(Call)) - return false; - - Info.MinMax = Call; - - return true; - }; - - auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V, - MinMaxSinkInfo &Info) { - if (Info.SExt) { - if (Info.SExt->getType() == V->getType()) - return V; - return Builder.CreateSExt(V, Info.SExt->getType()); - } - if (Info.ZExt) { - if (Info.ZExt->getType() == V->getType()) - return V; - return Builder.CreateZExt(V, Info.ZExt->getType()); - } - return V; - }; - - bool Changed = false; - SmallVector SinkList; - - // Check BB for instructions like: - // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a) - // - // Where: - // fn := min | max | (sext (min ...)) | (sext (max ...)) - // - // Put such instructions to SinkList. - for (Instruction &I : BB) { - ICmpInst *ICmp = dyn_cast(&I); - if (!ICmp) - continue; - if (!ICmp->isRelational()) - continue; - MinMaxSinkInfo First(ICmp, ICmp->getOperand(1), - ICmpInst::getSwappedPredicate(ICmp->getPredicate())); - MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate()); - bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First); - bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second); - if (!(FirstMinMax ^ SecondMinMax)) - continue; - SinkList.push_back(FirstMinMax ? First : Second); - } - - // Iterate SinkList and replace each (icmp ...) with corresponding - // `x < a && x < b` or similar expression. - for (auto &Info : SinkList) { - ICmpInst *ICmp = Info.ICmp; - CallInst *MinMax = Info.MinMax; - Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID(); - ICmpInst::Predicate P = Info.Predicate; - if (ICmpInst::isSigned(P) && IID != Intrinsic::smin && - IID != Intrinsic::smax) - continue; - - IRBuilder<> Builder(ICmp); - Value *X = Info.Other; - Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info); - Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info); - bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin; - bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax; - bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P); - bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P); - assert(IsMin ^ IsMax); - assert(IsLess ^ IsGreater); - - Value *Replacement; - Value *LHS = Builder.CreateICmp(P, X, A); - Value *RHS = Builder.CreateICmp(P, X, B); - if ((IsLess && IsMin) || (IsGreater && IsMax)) - // x < min(a, b) -> x < a && x < b - // x > max(a, b) -> x > a && x > b - Replacement = Builder.CreateLogicalAnd(LHS, RHS); - else - // x > min(a, b) -> x > a || x > b - // x < max(a, b) -> x < a || x < b - Replacement = Builder.CreateLogicalOr(LHS, RHS); - - ICmp->replaceAllUsesWith(Replacement); - - Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax}; - for (Instruction *I : ToRemove) - if (I && I->use_empty()) { - I->dropAllReferences(); - I->removeFromParent(); - } - - Changed = true; - } - - return Changed; -} - -// Do the following transformation: -// -// x < min(a, b) -> x < a && x < b -// x > min(a, b) -> x > a || x > b -// x < max(a, b) -> x < a || x < b -// x > max(a, b) -> x > a && x > b -// -// Such patterns are introduced by LICM.cpp:hoistMinMax() -// transformation and might lead to BPF verification failures for -// older kernels. -// -// To minimize "collateral" changes only do it for icmp + min/max -// calls when icmp is inside a loop and min/max is outside of that -// loop. -// -// Verification failure happens when: -// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1; -// - verifier can recognize RHS as a constant scalar in some context; -// - verifier can't recognize RHS1 as a constant scalar in the same -// context; -// -// The "constant scalar" is not a compile time constant, but a register -// that holds a scalar value known to verifier at some point in time -// during abstract interpretation. -// -// See also: -// https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/ -bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) { - bool Changed = false; - - for (Function &F : M) { - if (F.isDeclaration()) - continue; - - LoopInfo &LI = getAnalysis(F).getLoopInfo(); - for (Loop *L : LI) - for (BasicBlock *BB : L->blocks()) { - // Filter out instructions coming from the same loop - Loop *BBLoop = LI.getLoopFor(BB); - auto OtherLoopFilter = [&](Instruction *I) { - return LI.getLoopFor(I->getParent()) != BBLoop; - }; - Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter); - } - } - - return Changed; -} - -void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired(); -} - bool BPFCheckAndAdjustIR::adjustIR(Module &M) { bool Changed = removePassThroughBuiltin(M); Changed = removeCompareBuiltin(M) || Changed; - Changed = sinkMinMax(M) || Changed; return Changed; } diff --git a/llvm/test/CodeGen/BPF/sink-min-max.ll b/llvm/test/CodeGen/BPF/sink-min-max.ll deleted file mode 100644 index 5ee080839985d..0000000000000 --- a/llvm/test/CodeGen/BPF/sink-min-max.ll +++ /dev/null @@ -1,258 +0,0 @@ -; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s | FileCheck %s - -; Test plan: -; @test1: x < umin(i64 a, i64 b) -; @test2: x < umax(i64 a, i64 b) -; @test3: x >= umin(i64 a, i64 b) -; @test4: x >= umax(i64 a, i64 b) -; @test5: umin(i64 a, i64 b) >= x -; @test6: x < smin(i64 a, i64 b) -; @test7: x < umin(i32 a, i32 b) -; @test8: x < zext i64 umin(i32 a, i32 b) -; @test9: x < sext i64 umin(i32 a, i32 b) -; @test10: check that umin belonging to the same loop is not touched -; @test11: check that nested loops are processed - -define i32 @test1(i64 %a, i64 %b, i64 %x) { -entry: - %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp ult i64 %x, %min - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test1 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = icmp ult i64 %x, %a -; CHECK-NEXT: %1 = icmp ult i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test2(i64 %a, i64 %b, i64 %x) { -entry: - %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp ult i64 %x, %max - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test2 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = icmp ult i64 %x, %a -; CHECK-NEXT: %1 = icmp ult i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1 -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test3(i64 %a, i64 %b, i64 %x) { -entry: - %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp uge i64 %x, %min - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test3 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = icmp uge i64 %x, %a -; CHECK-NEXT: %1 = icmp uge i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1 -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test4(i64 %a, i64 %b, i64 %x) { -entry: - %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp uge i64 %x, %max - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test4 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = icmp uge i64 %x, %a -; CHECK-NEXT: %1 = icmp uge i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test5(i64 %a, i64 %b, i64 %x) { -entry: - %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp uge i64 %min, %x - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test5 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK: %0 = icmp ule i64 %x, %a -; CHECK-NEXT: %1 = icmp ule i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test6(i64 %a, i64 %b, i64 %x) { -entry: - %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b) - br label %loop -loop: - %cmp = icmp slt i64 %x, %min - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test6 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK: %0 = icmp slt i64 %x, %a -; CHECK-NEXT: %1 = icmp slt i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test7(i32 %a, i32 %b, i32 %x) { -entry: - %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) - br label %loop -loop: - %cmp = icmp ult i32 %x, %min - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test7 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK: %0 = icmp ult i32 %x, %a -; CHECK-NEXT: %1 = icmp ult i32 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %loop, label %ret - -define i32 @test8(i32 %a, i32 %b, i64 %x) { -entry: - %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) - br label %loop -loop: - %ext = zext i32 %min to i64 - %cmp = icmp ult i64 %x, %ext - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test8 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = zext i32 %a to i64 -; CHECK-NEXT: %1 = zext i32 %b to i64 -; CHECK-NEXT: %2 = icmp ult i64 %x, %0 -; CHECK-NEXT: %3 = icmp ult i64 %x, %1 -; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false -; CHECK-NEXT: br i1 %4, label %loop, label %ret - -define i32 @test9(i32 %a, i32 %b, i64 %x) { -entry: - %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) - br label %loop -loop: - %ext = sext i32 %min to i64 - %cmp = icmp ult i64 %x, %ext - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test9 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %0 = sext i32 %a to i64 -; CHECK-NEXT: %1 = sext i32 %b to i64 -; CHECK-NEXT: %2 = icmp ult i64 %x, %0 -; CHECK-NEXT: %3 = icmp ult i64 %x, %1 -; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false -; CHECK-NEXT: br i1 %4, label %loop, label %ret - -; umin within the loop body is unchanged -define i32 @test10(i64 %a, i64 %b, i64 %x) { -entry: - br label %loop -loop: - %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) - %cmp = icmp ult i64 %x, %min - br i1 %cmp, label %loop, label %ret -ret: ret i32 0 -} - -; CHECK: @test10 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) -; CHECK-NEXT: %cmp = icmp ult i64 %x, %min -; CHECK-NEXT: br i1 %cmp, label %loop, label %ret - -; umin from outer loop body is processed -define i32 @test11(i64 %a, i64 %b, i64 %x) { -entry: - br label %loop - -loop: - %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) - br label %nested.loop -nested.loop: - %cmp = icmp ult i64 %x, %min - br i1 %cmp, label %nested.loop, label %loop - -ret: ret i32 0 -} - -; CHECK: @test11 -; CHECK-NEXT: entry: -; CHECK-NEXT: br label %loop -; CHECK-EMPTY: -; CHECK-NEXT: loop: -; CHECK-NEXT: br label %nested.loop -; CHECK-EMPTY: -; CHECK-NEXT: nested.loop: -; CHECK-NEXT: %0 = icmp ult i64 %x, %a -; CHECK-NEXT: %1 = icmp ult i64 %x, %b -; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false -; CHECK-NEXT: br i1 %2, label %nested.loop, label %loop - -declare i64 @llvm.umin.i64(i64, i64) -declare i64 @llvm.smin.i64(i64, i64) -declare i64 @llvm.umax.i64(i64, i64) -declare i64 @llvm.smax.i64(i64, i64) - -declare i32 @llvm.umin.i32(i32, i32) -declare i32 @llvm.smin.i32(i32, i32) -declare i32 @llvm.umax.i32(i32, i32) -declare i32 @llvm.smax.i32(i32, i32)