diff --git a/llvm/include/llvm/Analysis/StackLifetime.h b/llvm/include/llvm/Analysis/StackLifetime.h index 577d9ba177ea2d..40c82681369611 100644 --- a/llvm/include/llvm/Analysis/StackLifetime.h +++ b/llvm/include/llvm/Analysis/StackLifetime.h @@ -13,6 +13,7 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/raw_ostream.h" @@ -78,8 +79,16 @@ class StackLifetime { bool test(unsigned Idx) const { return Bits.test(Idx); } }; + // Controls what is "alive" if control flow may reach the instruction + // with a different liveness of the alloca. + enum class LivenessType { + May, // May be alive on some path. + Must, // Must be alive on every path. + }; + private: const Function &F; + LivenessType Type; /// Maps active slots (per bit) for each basic block. using LivenessMap = DenseMap; @@ -124,7 +133,8 @@ class StackLifetime { void calculateLiveIntervals(); public: - StackLifetime(const Function &F, ArrayRef Allocas); + StackLifetime(const Function &F, ArrayRef Allocas, + LivenessType Type); void run(); std::vector getMarkers() const; @@ -168,10 +178,12 @@ inline raw_ostream &operator<<(raw_ostream &OS, /// Printer pass for testing. class StackLifetimePrinterPass : public PassInfoMixin { + StackLifetime::LivenessType Type; raw_ostream &OS; public: - explicit StackLifetimePrinterPass(raw_ostream &OS) : OS(OS) {} + StackLifetimePrinterPass(raw_ostream &OS, StackLifetime::LivenessType Type) + : Type(Type), OS(OS) {} PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; diff --git a/llvm/lib/Analysis/StackLifetime.cpp b/llvm/lib/Analysis/StackLifetime.cpp index efdfcf5c6c5cdb..9e4df6473edeae 100644 --- a/llvm/lib/Analysis/StackLifetime.cpp +++ b/llvm/lib/Analysis/StackLifetime.cpp @@ -166,7 +166,17 @@ void StackLifetime::calculateLocalLiveness() { // If a predecessor is unreachable, ignore it. if (I == BlockLiveness.end()) continue; - LocalLiveIn |= I->second.LiveOut; + switch (Type) { + case LivenessType::May: + LocalLiveIn |= I->second.LiveOut; + break; + case LivenessType::Must: + if (LocalLiveIn.empty()) + LocalLiveIn = I->second.LiveOut; + else + LocalLiveIn &= I->second.LiveOut; + break; + } } // Compute LiveOut by subtracting out lifetimes that end in this @@ -272,8 +282,9 @@ LLVM_DUMP_METHOD void StackLifetime::dumpLiveRanges() const { #endif StackLifetime::StackLifetime(const Function &F, - ArrayRef Allocas) - : F(F), Allocas(Allocas), NumAllocas(Allocas.size()) { + ArrayRef Allocas, + LivenessType Type) + : F(F), Type(Type), Allocas(Allocas), NumAllocas(Allocas.size()) { LLVM_DEBUG(dumpAllocas()); for (unsigned I = 0; I < NumAllocas; ++I) @@ -351,7 +362,7 @@ PreservedAnalyses StackLifetimePrinterPass::run(Function &F, for (auto &I : instructions(F)) if (const AllocaInst *AI = dyn_cast(&I)) Allocas.push_back(AI); - StackLifetime SL(F, Allocas); + StackLifetime SL(F, Allocas, Type); SL.run(); SL.print(OS); return PreservedAnalyses::all(); diff --git a/llvm/lib/CodeGen/SafeStack.cpp b/llvm/lib/CodeGen/SafeStack.cpp index 1481894186e419..55478c232dd708 100644 --- a/llvm/lib/CodeGen/SafeStack.cpp +++ b/llvm/lib/CodeGen/SafeStack.cpp @@ -497,7 +497,7 @@ Value *SafeStack::moveStaticAllocasToUnsafeStack( DIBuilder DIB(*F.getParent()); - StackLifetime SSC(F, StaticAllocas); + StackLifetime SSC(F, StaticAllocas, StackLifetime::LivenessType::May); static const StackLifetime::LiveRange NoColoringRange(1, true); if (ClColoring) SSC.run(); diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 49866f35e7fdf3..d7f2a64e94d8d8 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -1851,6 +1851,26 @@ Expected parseGVNOptions(StringRef Params) { return Result; } +Expected +parseStackLifetimeOptions(StringRef Params) { + StackLifetime::LivenessType Result = StackLifetime::LivenessType::May; + while (!Params.empty()) { + StringRef ParamName; + std::tie(ParamName, Params) = Params.split(';'); + + if (ParamName == "may") { + Result = StackLifetime::LivenessType::May; + } else if (ParamName == "must") { + Result = StackLifetime::LivenessType::Must; + } else { + return make_error( + formatv("invalid StackLifetime parameter '{0}' ", ParamName).str(), + inconvertibleErrorCode()); + } + } + return Result; +} + } // namespace /// Tests whether a pass name starts with a valid prefix for a default pipeline diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index a38c651c673cb7..1edb0fe5c23060 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -239,7 +239,6 @@ FUNCTION_PASS("print", PhiValuesPrinterPass(dbgs())) FUNCTION_PASS("print", RegionInfoPrinterPass(dbgs())) FUNCTION_PASS("print", ScalarEvolutionPrinterPass(dbgs())) FUNCTION_PASS("print", StackSafetyPrinterPass(dbgs())) -FUNCTION_PASS("print", StackLifetimePrinterPass(dbgs())) FUNCTION_PASS("reassociate", ReassociatePass()) FUNCTION_PASS("scalarizer", ScalarizerPass()) FUNCTION_PASS("sccp", SCCPPass()) @@ -302,6 +301,11 @@ FUNCTION_PASS_WITH_PARAMS("gvn", return GVN(Opts); }, parseGVNOptions) +FUNCTION_PASS_WITH_PARAMS("print", + [](StackLifetime::LivenessType Type) { + return StackLifetimePrinterPass(dbgs(), Type); + }, + parseStackLifetimeOptions) #undef FUNCTION_PASS_WITH_PARAMS #ifndef LOOP_ANALYSIS diff --git a/llvm/test/Analysis/StackSafetyAnalysis/lifetime.ll b/llvm/test/Analysis/StackSafetyAnalysis/lifetime.ll index 61d0dc2286f606..d2ac1abdce6f1e 100644 --- a/llvm/test/Analysis/StackSafetyAnalysis/lifetime.ll +++ b/llvm/test/Analysis/StackSafetyAnalysis/lifetime.ll @@ -1,4 +1,5 @@ -; RUN: opt -passes='print' -disable-output %s 2>&1 | FileCheck %s +; RUN: opt -passes='print' -disable-output %s 2>&1 | FileCheck %s --check-prefixes=CHECK,MAY +; RUN: opt -passes='print' -disable-output %s 2>&1 | FileCheck %s --check-prefixes=CHECK,MUST define void @f() { ; CHECK-LABEL: define void @f() @@ -710,7 +711,8 @@ entry: l2: ; preds = %l2, %entry ; CHECK: l2: -; CHECK-NEXT: Alive: +; MAY-NEXT: Alive: +; MUST-NEXT: Alive: <> call void @capture8(i8* %x) call void @llvm.lifetime.end.p0i8(i64 4, i8* %x) ; CHECK: call void @llvm.lifetime.end.p0i8(i64 4, i8* %x) @@ -758,6 +760,55 @@ l2: ; preds = %l2, %entry ; CHECK-NEXT: Alive: } +define void @if_must(i1 %a) { +; CHECK-LABEL: define void @if_must +entry: +; CHECK: entry: +; CHECK-NEXT: Alive: <> + %x = alloca i8, align 4 + %y = alloca i8, align 4 + + br i1 %a, label %if.then, label %if.else +; CHECK: br i1 %a +; CHECK-NEXT: Alive: <> + +if.then: +; CHECK: if.then: +; CHECK-NEXT: Alive: <> + call void @llvm.lifetime.start.p0i8(i64 4, i8* %y) +; CHECK: call void @llvm.lifetime.start.p0i8(i64 4, i8* %y) +; CHECK-NEXT: Alive: + + br label %if.end +; CHECK: br label %if.end +; CHECK-NEXT: Alive: + +if.else: +; CHECK: if.else: +; CHECK-NEXT: Alive: <> + call void @llvm.lifetime.start.p0i8(i64 4, i8* %y) +; CHECK: call void @llvm.lifetime.start.p0i8(i64 4, i8* %y) +; CHECK-NEXT: Alive: + + call void @llvm.lifetime.start.p0i8(i64 4, i8* %x) +; CHECK: call void @llvm.lifetime.start.p0i8(i64 4, i8* %x) +; CHECK-NEXT: Alive: + + br label %if.end +; CHECK: br label %if.end +; CHECK-NEXT: Alive: + +if.end: +; CHECK: if.end: +; MAY-NEXT: Alive: +; MUST-NEXT: Alive: + +ret void +; CHECK: ret void +; MAY-NEXT: Alive: +; MUST-NEXT: Alive: +} + declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) declare void @capture8(i8*)