diff --git a/llvm/include/llvm/Transforms/Scalar/SROA.h b/llvm/include/llvm/Transforms/Scalar/SROA.h index 85b75c4d4640ee..26348da2202116 100644 --- a/llvm/include/llvm/Transforms/Scalar/SROA.h +++ b/llvm/include/llvm/Transforms/Scalar/SROA.h @@ -27,6 +27,7 @@ namespace llvm { class AllocaInst; class LoadInst; +class StoreInst; class AssumptionCache; class DominatorTree; class DomTreeUpdater; @@ -46,7 +47,7 @@ class Partition; class SROALegacyPass; class SelectHandSpeculativity { - unsigned char Storage = 0; + unsigned char Storage = 0; // None are speculatable by default. using TrueVal = Bitfield::Element; // Low 0'th bit. using FalseVal = Bitfield::Element; // Low 1'th bit. public: @@ -64,7 +65,10 @@ static_assert(sizeof(SelectHandSpeculativity) == sizeof(unsigned char)); using PossiblySpeculatableLoad = PointerIntPair; -using PossiblySpeculatableLoads = SmallVector; +using UnspeculatableStore = StoreInst *; +using RewriteableMemOp = + std::variant; +using RewriteableMemOps = SmallVector; } // end namespace sroa @@ -130,8 +134,7 @@ class SROAPass : public PassInfoMixin { /// A worklist of select instructions to rewrite prior to promoting /// allocas. - SmallMapVector - SelectsToRewrite; + SmallMapVector SelectsToRewrite; /// Select instructions that use an alloca and are subsequently loaded can be /// rewritten to load both input pointers and then select between the result, @@ -149,7 +152,7 @@ class SROAPass : public PassInfoMixin { /// or if we are allowed to perform CFG modifications. /// If found an intervening bitcast with a single use of the load, /// allow the promotion. 
- static std::optional + static std::optional isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG); public: diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index cf988e02beb214..96afc1af7beb81 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -108,6 +108,9 @@ STATISTIC(NumPromoted, "Number of allocas promoted to SSA values"); STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion"); STATISTIC(NumLoadsPredicated, "Number of loads rewritten into predicated loads to allow promotion"); +STATISTIC( + NumStoresPredicated, + "Number of stores rewritten into predicated stores to allow promotion"); STATISTIC(NumDeleted, "Number of instructions deleted"); STATISTIC(NumVectorized, "Number of vectorized aggregates"); @@ -1353,17 +1356,25 @@ isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) { return Spec; } -std::optional +std::optional SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { - PossiblySpeculatableLoads Loads; + RewriteableMemOps Ops; for (User *U : SI.users()) { - LoadInst *LI; - BitCastInst *BC = dyn_cast(U); - if (BC && BC->hasOneUse()) - LI = dyn_cast(*BC->user_begin()); - else - LI = dyn_cast(U); + if (auto *BC = dyn_cast(U); BC && BC->hasOneUse()) + U = *BC->user_begin(); + + if (auto *Store = dyn_cast(U)) { + // Note that atomic stores can be transformed; atomic semantics do not + // have any meaning for a local alloca. Stores are not speculatable, + // however, so if we can't turn it into a predicated store, we are done. + if (Store->isVolatile() || PreserveCFG) + return {}; // Give up on this `select`. + Ops.emplace_back(Store); + continue; + } + + auto *LI = dyn_cast(U); // Note that atomic loads can be transformed; // atomic semantics do not have any meaning for a local alloca. 
@@ -1371,13 +1382,12 @@ SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { return {}; // Give up on this `select`. PossiblySpeculatableLoad Load(LI); - if (!LI->isSimple()) { // If the `load` is not simple, we can't speculatively execute it, // but we could handle this via a CFG modification. But can we? if (PreserveCFG) return {}; // Give up on this `select`. - Loads.emplace_back(Load); + Ops.emplace_back(Load); continue; } @@ -1387,10 +1397,10 @@ SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { return {}; // Give up on this `select`. Load.setInt(Spec); - Loads.emplace_back(Load); + Ops.emplace_back(Load); } - return Loads; + return Ops; } static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI, @@ -1430,18 +1440,20 @@ static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI, LI.replaceAllUsesWith(V); } -static void rewriteLoadOfSelect(SelectInst &SI, LoadInst &LI, - sroa::SelectHandSpeculativity Spec, - DomTreeUpdater &DTU) { - LLVM_DEBUG(dbgs() << " original load: " << SI << "\n"); - BasicBlock *Head = LI.getParent(); +template +static void rewriteMemOpOfSelect(SelectInst &SI, T &I, + sroa::SelectHandSpeculativity Spec, + DomTreeUpdater &DTU) { + assert((isa(I) || isa(I)) && "Only for load and store!"); + LLVM_DEBUG(dbgs() << " original mem op: " << I << "\n"); + BasicBlock *Head = I.getParent(); Instruction *ThenTerm = nullptr; Instruction *ElseTerm = nullptr; if (Spec.areNoneSpeculatable()) - SplitBlockAndInsertIfThenElse(SI.getCondition(), &LI, &ThenTerm, &ElseTerm, + SplitBlockAndInsertIfThenElse(SI.getCondition(), &I, &ThenTerm, &ElseTerm, SI.getMetadata(LLVMContext::MD_prof), &DTU); else { - SplitBlockAndInsertIfThen(SI.getCondition(), &LI, /*Unreachable=*/false, + SplitBlockAndInsertIfThen(SI.getCondition(), &I, /*Unreachable=*/false, SI.getMetadata(LLVMContext::MD_prof), &DTU, /*LI=*/nullptr, /*ThenBlock=*/nullptr); if (Spec.isSpeculatable(/*isTrueVal=*/true)) @@ -1449,46 +1461,75 @@ static void 
rewriteLoadOfSelect(SelectInst &SI, LoadInst &LI, } auto *HeadBI = cast(Head->getTerminator()); Spec = {}; // Do not use `Spec` beyond this point. - BasicBlock *Tail = LI.getParent(); + BasicBlock *Tail = I.getParent(); Tail->setName(Head->getName() + ".cont"); - auto *PN = PHINode::Create(LI.getType(), 2, "", &LI); + PHINode *PN; + if (isa(I)) + PN = PHINode::Create(I.getType(), 2, "", &I); for (BasicBlock *SuccBB : successors(Head)) { bool IsThen = SuccBB == HeadBI->getSuccessor(0); int SuccIdx = IsThen ? 0 : 1; - auto *NewLoadBB = SuccBB == Tail ? Head : SuccBB; - if (NewLoadBB != Head) { - NewLoadBB->setName(Head->getName() + (IsThen ? ".then" : ".else")); - ++NumLoadsPredicated; + auto *NewMemOpBB = SuccBB == Tail ? Head : SuccBB; + if (NewMemOpBB != Head) { + NewMemOpBB->setName(Head->getName() + (IsThen ? ".then" : ".else")); + if (isa(I)) + ++NumLoadsPredicated; + else + ++NumStoresPredicated; } else ++NumLoadsSpeculated; - auto *CondLoad = cast(LI.clone()); - CondLoad->insertBefore(NewLoadBB->getTerminator()); - CondLoad->setOperand(0, SI.getOperand(1 + SuccIdx)); - CondLoad->setName(LI.getName() + (IsThen ? ".then" : ".else") + ".val"); - PN->addIncoming(CondLoad, NewLoadBB); + auto &CondMemOp = cast(*I.clone()); + CondMemOp.insertBefore(NewMemOpBB->getTerminator()); + CondMemOp.setOperand(I.getPointerOperandIndex(), + SI.getOperand(1 + SuccIdx)); + if (isa(I)) { + CondMemOp.setName(I.getName() + (IsThen ? 
".then" : ".else") + ".val"); + PN->addIncoming(&CondMemOp, NewMemOpBB); + } else + LLVM_DEBUG(dbgs() << " to: " << CondMemOp << "\n"); } - PN->takeName(&LI); - LLVM_DEBUG(dbgs() << " to: " << *PN << "\n"); - LI.replaceAllUsesWith(PN); + if (isa(I)) { + PN->takeName(&I); + LLVM_DEBUG(dbgs() << " to: " << *PN << "\n"); + I.replaceAllUsesWith(PN); + } +} + +static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I, + sroa::SelectHandSpeculativity Spec, + DomTreeUpdater &DTU) { + if (auto *LI = dyn_cast(&I)) + rewriteMemOpOfSelect(SelInst, *LI, Spec, DTU); + else if (auto *SI = dyn_cast(&I)) + rewriteMemOpOfSelect(SelInst, *SI, Spec, DTU); + else + llvm_unreachable_internal("Only for load and store."); } -static bool rewriteSelectInstLoads(SelectInst &SI, - const sroa::PossiblySpeculatableLoads &Loads, - IRBuilderTy &IRB, DomTreeUpdater *DTU) { +static bool rewriteSelectInstMemOps(SelectInst &SI, + const sroa::RewriteableMemOps &Ops, + IRBuilderTy &IRB, DomTreeUpdater *DTU) { bool CFGChanged = false; LLVM_DEBUG(dbgs() << " original select: " << SI << "\n"); - for (const PossiblySpeculatableLoad &Load : Loads) { - LoadInst *LI = Load.getPointer(); - sroa::SelectHandSpeculativity Spec = Load.getInt(); + for (const RewriteableMemOp &Op : Ops) { + sroa::SelectHandSpeculativity Spec; + Instruction *I; + if (auto *const *US = std::get_if(&Op)) { + I = *US; + } else { + auto PSL = std::get(Op); + I = PSL.getPointer(); + Spec = PSL.getInt(); + } if (Spec.areAllSpeculatable()) { - speculateSelectInstLoads(SI, *LI, IRB); + speculateSelectInstLoads(SI, cast(*I), IRB); } else { assert(DTU && "Should not get here when not allowed to modify the CFG!"); - rewriteLoadOfSelect(SI, *LI, Spec, *DTU); + rewriteMemOpOfSelect(SI, *I, Spec, *DTU); CFGChanged = true; } - LI->eraseFromParent(); + I->eraseFromParent(); } for (User *U : make_early_inc_range(SI.users())) @@ -4490,20 +4531,20 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, break; } - 
SmallVector, 2> + SmallVector, 2> NewSelectsToRewrite; NewSelectsToRewrite.reserve(SelectUsers.size()); for (SelectInst *Sel : SelectUsers) { - std::optional Loads = + std::optional Ops = isSafeSelectToSpeculate(*Sel, PreserveCFG); - if (!Loads) { + if (!Ops) { Promotable = false; PHIUsers.clear(); SelectUsers.clear(); NewSelectsToRewrite.clear(); break; } - NewSelectsToRewrite.emplace_back(std::make_pair(Sel, *Loads)); + NewSelectsToRewrite.emplace_back(std::make_pair(Sel, *Ops)); } if (Promotable) { @@ -4809,7 +4850,7 @@ SROAPass::runOnAlloca(AllocaInst &AI) { while (!RemainingSelectsToRewrite.empty()) { const auto [K, V] = RemainingSelectsToRewrite.pop_back_val(); CFGChanged |= - rewriteSelectInstLoads(*K, V, IRB, PreserveCFG ? nullptr : DTU); + rewriteSelectInstMemOps(*K, V, IRB, PreserveCFG ? nullptr : DTU); } return {Changed, CFGChanged}; diff --git a/llvm/test/Transforms/SROA/select-store.ll b/llvm/test/Transforms/SROA/select-store.ll index 830b0cf67924b5..33090c1d3b248c 100644 --- a/llvm/test/Transforms/SROA/select-store.ll +++ b/llvm/test/Transforms/SROA/select-store.ll @@ -6,15 +6,28 @@ declare i8 @gen.i8() declare ptr @gen.ptr() define i8 @store(i8 %init, i1 %cond, ptr dereferenceable(4) %escape) { -; CHECK-LABEL: @store( -; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP:%.*]] = alloca i8, align 4 -; CHECK-NEXT: store i8 [[INIT:%.*]], ptr [[TMP]], align 4 -; CHECK-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() -; CHECK-NEXT: [[ADDR:%.*]] = select i1 [[COND:%.*]], ptr [[TMP]], ptr [[ESCAPE:%.*]] -; CHECK-NEXT: store i8 [[REINIT]], ptr [[ADDR]], align 4 -; CHECK-NEXT: [[TMP_0_RES:%.*]] = load i8, ptr [[TMP]], align 4 -; CHECK-NEXT: ret i8 [[TMP_0_RES]] +; CHECK-PRESERVE-CFG-LABEL: @store( +; CHECK-PRESERVE-CFG-NEXT: entry: +; CHECK-PRESERVE-CFG-NEXT: [[TMP:%.*]] = alloca i8, align 4 +; CHECK-PRESERVE-CFG-NEXT: store i8 [[INIT:%.*]], ptr [[TMP]], align 4 +; CHECK-PRESERVE-CFG-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() +; CHECK-PRESERVE-CFG-NEXT: [[ADDR:%.*]] = select i1 
[[COND:%.*]], ptr [[TMP]], ptr [[ESCAPE:%.*]] +; CHECK-PRESERVE-CFG-NEXT: store i8 [[REINIT]], ptr [[ADDR]], align 4 +; CHECK-PRESERVE-CFG-NEXT: [[TMP_0_RES:%.*]] = load i8, ptr [[TMP]], align 4 +; CHECK-PRESERVE-CFG-NEXT: ret i8 [[TMP_0_RES]] +; +; CHECK-MODIFY-CFG-LABEL: @store( +; CHECK-MODIFY-CFG-NEXT: entry: +; CHECK-MODIFY-CFG-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() +; CHECK-MODIFY-CFG-NEXT: br i1 [[COND:%.*]], label [[ENTRY_THEN:%.*]], label [[ENTRY_ELSE:%.*]] +; CHECK-MODIFY-CFG: entry.then: +; CHECK-MODIFY-CFG-NEXT: br label [[ENTRY_CONT:%.*]] +; CHECK-MODIFY-CFG: entry.else: +; CHECK-MODIFY-CFG-NEXT: store i8 [[REINIT]], ptr [[ESCAPE:%.*]], align 4 +; CHECK-MODIFY-CFG-NEXT: br label [[ENTRY_CONT]] +; CHECK-MODIFY-CFG: entry.cont: +; CHECK-MODIFY-CFG-NEXT: [[TMP_0:%.*]] = phi i8 [ [[REINIT]], [[ENTRY_THEN]] ], [ [[INIT:%.*]], [[ENTRY_ELSE]] ] +; CHECK-MODIFY-CFG-NEXT: ret i8 [[TMP_0]] ; entry: %tmp = alloca i8, align 4 @@ -48,15 +61,28 @@ entry: } define i8 @store_atomic_unord(i8 %init, i1 %cond, ptr dereferenceable(4) %escape) { -; CHECK-LABEL: @store_atomic_unord( -; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP:%.*]] = alloca i8, align 4 -; CHECK-NEXT: store i8 [[INIT:%.*]], ptr [[TMP]], align 4 -; CHECK-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() -; CHECK-NEXT: [[ADDR:%.*]] = select i1 [[COND:%.*]], ptr [[TMP]], ptr [[ESCAPE:%.*]] -; CHECK-NEXT: store atomic i8 [[REINIT]], ptr [[ADDR]] unordered, align 4 -; CHECK-NEXT: [[TMP_0_RES:%.*]] = load i8, ptr [[TMP]], align 4 -; CHECK-NEXT: ret i8 [[TMP_0_RES]] +; CHECK-PRESERVE-CFG-LABEL: @store_atomic_unord( +; CHECK-PRESERVE-CFG-NEXT: entry: +; CHECK-PRESERVE-CFG-NEXT: [[TMP:%.*]] = alloca i8, align 4 +; CHECK-PRESERVE-CFG-NEXT: store i8 [[INIT:%.*]], ptr [[TMP]], align 4 +; CHECK-PRESERVE-CFG-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() +; CHECK-PRESERVE-CFG-NEXT: [[ADDR:%.*]] = select i1 [[COND:%.*]], ptr [[TMP]], ptr [[ESCAPE:%.*]] +; CHECK-PRESERVE-CFG-NEXT: store atomic i8 [[REINIT]], ptr [[ADDR]] unordered, align 
4 +; CHECK-PRESERVE-CFG-NEXT: [[TMP_0_RES:%.*]] = load i8, ptr [[TMP]], align 4 +; CHECK-PRESERVE-CFG-NEXT: ret i8 [[TMP_0_RES]] +; +; CHECK-MODIFY-CFG-LABEL: @store_atomic_unord( +; CHECK-MODIFY-CFG-NEXT: entry: +; CHECK-MODIFY-CFG-NEXT: [[REINIT:%.*]] = call i8 @gen.i8() +; CHECK-MODIFY-CFG-NEXT: br i1 [[COND:%.*]], label [[ENTRY_THEN:%.*]], label [[ENTRY_ELSE:%.*]] +; CHECK-MODIFY-CFG: entry.then: +; CHECK-MODIFY-CFG-NEXT: br label [[ENTRY_CONT:%.*]] +; CHECK-MODIFY-CFG: entry.else: +; CHECK-MODIFY-CFG-NEXT: store atomic i8 [[REINIT]], ptr [[ESCAPE:%.*]] unordered, align 4 +; CHECK-MODIFY-CFG-NEXT: br label [[ENTRY_CONT]] +; CHECK-MODIFY-CFG: entry.cont: +; CHECK-MODIFY-CFG-NEXT: [[TMP_0:%.*]] = phi i8 [ [[REINIT]], [[ENTRY_THEN]] ], [ [[INIT:%.*]], [[ENTRY_ELSE]] ] +; CHECK-MODIFY-CFG-NEXT: ret i8 [[TMP_0]] ; entry: %tmp = alloca i8, align 4 @@ -88,6 +114,3 @@ entry: %res = load i8, ptr %tmp, align 4 ret i8 %res } -;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: -; CHECK-MODIFY-CFG: {{.*}} -; CHECK-PRESERVE-CFG: {{.*}}