diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 0efa6f424a278a..468c9b824f61b2 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -387,13 +387,8 @@ static void analyzeParsePointLiveness( Result.LiveSet = LiveSet; } -// Returns true is V is a knownBaseResult. static bool isKnownBaseResult(Value *V); -// Returns true if V is a BaseResult that already exists in the IR, i.e. it is -// not created by the findBasePointers algorithm. -static bool isOriginalBaseResult(Value *V); - namespace { /// A single base defining value - An immediate base defining value for an @@ -638,20 +633,15 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) { return Def; } -/// This value is a base pointer that is not generated by RS4GC, i.e. it already -/// exists in the code. -static bool isOriginalBaseResult(Value *V) { - // no recursion possible - return !isa(V) && !isa(V) && - !isa(V) && !isa(V) && - !isa(V); -} - /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV, /// is it known to be a base pointer? Or do we need to continue searching. static bool isKnownBaseResult(Value *V) { - if (isOriginalBaseResult(V)) + if (!isa(V) && !isa(V) && + !isa(V) && !isa(V) && + !isa(V)) { + // no recursion possible return true; + } if (isa(V) && cast(V)->getMetadata("is_base_value")) { // This is a previously inserted base phi or select. We know @@ -663,12 +653,6 @@ static bool isKnownBaseResult(Value *V) { return false; } -// Returns true if First and Second values are both scalar or both vector. -static bool areBothVectorOrScalar(Value *First, Value *Second) { - return isa(First->getType()) == - isa(Second->getType()); -} - namespace { /// Models the state of a single base defining value in the findBasePointer @@ -778,7 +762,7 @@ static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) { static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Value *Def = findBaseOrBDV(I, Cache); - if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I)) + if (isKnownBaseResult(Def)) return Def; // Here's the rough algorithm: @@ -826,16 +810,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { States.insert({Def, BDVState()}); while (!Worklist.empty()) { Value *Current = Worklist.pop_back_val(); - assert(!isOriginalBaseResult(Current) && "why did it get added?"); + assert(!isKnownBaseResult(Current) && "why did it get added?"); auto visitIncomingValue = [&](Value *InVal) { Value *Base = findBaseOrBDV(InVal, Cache); - if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal)) + if (isKnownBaseResult(Base)) // Known bases won't need new instructions introduced and can be - // ignored safely. However, this can only be done when InVal and Base - // are both scalar or both vector. Otherwise, we need to find a - // correct BDV for InVal, by creating an entry in the lattice - // (States). + // ignored safely return; assert(isExpectedBDVType(Base) && "the only non-base values " "we see should be base defining values"); @@ -872,10 +853,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Return a phi state for a base defining value. We'll generate a new // base state for known bases and expect to find a cached state otherwise. - auto GetStateForBDV = [&](Value *BaseValue, Value *Input) { - if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input)) - return BDVState(BaseValue); - auto I = States.find(BaseValue); + auto getStateForBDV = [&](Value *baseValue) { + if (isKnownBaseResult(baseValue)) + return BDVState(baseValue); + auto I = States.find(baseValue); assert(I != States.end() && "lookup failed!"); return I->second; }; @@ -892,18 +873,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // much faster. for (auto Pair : States) { Value *BDV = Pair.first; - // Only values that do not have known bases or those that have differing - // type (scalar versus vector) from a possible known base should be in the - // lattice. - assert((!isKnownBaseResult(BDV) || - !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) && - "why did it get added?"); + assert(!isKnownBaseResult(BDV) && "why did it get added?"); // Given an input value for the current instruction, return a BDVState // instance which represents the BDV of that value. auto getStateForInput = [&](Value *V) mutable { Value *BDV = findBaseOrBDV(V, Cache); - return GetStateForBDV(BDV, V); + return getStateForBDV(BDV); }; BDVState NewState; @@ -950,41 +926,41 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { } #endif - // Handle all instructions that have a vector BDV, but the instruction itself - // is of scalar type. + // Handle extractelement instructions and their uses. for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - auto *BaseValue = State.getBaseValue(); - // Only values that do not have known bases or those that have differing - // type (scalar versus vector) from a possible known base should be in the - // lattice. - assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) && - "why did it get added?"); + assert(!isKnownBaseResult(I) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); - if (!State.isBase() || !isa(BaseValue->getType())) - continue; // extractelement instructions are a bit special in that we may need to // insert an extract even when we know an exact base for the instruction. // The problem is that we need to convert from a vector base to a scalar // base for the particular indice we're interested in. - if (isa(I)) { - auto *EE = cast(I); - // TODO: In many cases, the new instruction is just EE itself. We should - // exploit this, but can't do it here since it would break the invariant - // about the BDV not being known to be a base. - auto *BaseInst = ExtractElementInst::Create( - State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); - BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); - States[I] = BDVState(BDVState::Base, BaseInst); - } else if (!isa(I->getType())) { - // We need to handle cases that have a vector base but the instruction is - // a scalar type (these could be phis or selects or any instruction that - // are of scalar type, but the base can be a vector type). We - // conservatively set this as conflict. Setting the base value for these - // conflicts is handled in the next loop which traverses States. - States[I] = BDVState(BDVState::Conflict); + if (!State.isBase() || !isa(I) || + !isa(State.getBaseValue()->getType())) + continue; + auto *EE = cast(I); + // TODO: In many cases, the new instruction is just EE itself. We should + // exploit this, but can't do it here since it would break the invariant + // about the BDV not being known to be a base. + auto *BaseInst = ExtractElementInst::Create( + State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); + BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); + States[I] = BDVState(BDVState::Base, BaseInst); + + // We need to handle uses of the extractelement that have the same vector + // base as well but the use is a scalar type. Since we cannot reuse the + // same BaseInst above (may not satisfy property that base pointer should + // always dominate derived pointer), we conservatively set this as conflict. + // Setting the base value for these conflicts is handled in the next loop + // which traverses States. + for (User *U : I->users()) { + auto *UseI = dyn_cast(U); + if (!UseI || !States.count(UseI)) + continue; + if (!isa(UseI->getType()) && States[UseI] == State) + States[UseI] = BDVState(BDVState::Conflict); } } @@ -993,11 +969,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - // Only values that do not have known bases or those that have differing - // type (scalar versus vector) from a possible known base should be in the - // lattice. - assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) && - "why did it get added?"); + assert(!isKnownBaseResult(I) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); // Since we're joining a vector and scalar base, they can never be the @@ -1058,7 +1030,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { Value *BDV = findBaseOrBDV(Input, Cache); Value *Base = nullptr; - if (isKnownBaseResult(BDV) && areBothVectorOrScalar(BDV, Input)) { + if (isKnownBaseResult(BDV)) { Base = BDV; } else { // Either conflict or base. @@ -1079,12 +1051,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Instruction *BDV = cast(Pair.first); BDVState State = Pair.second; - // Only values that do not have known bases or those that have differing - // type (scalar versus vector) from a possible known base should be in the - // lattice. - assert((!isKnownBaseResult(BDV) || - !areBothVectorOrScalar(BDV, State.getBaseValue())) && - "why did it get added?"); + assert(!isKnownBaseResult(BDV) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); if (!State.isConflict()) continue; @@ -1174,11 +1141,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto *BDV = Pair.first; Value *Base = Pair.second.getBaseValue(); assert(BDV && Base); - // Only values that do not have known bases or those that have differing - // type (scalar versus vector) from a possible known base should be in the - // lattice. - assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) && - "why did it get added?"); + assert(!isKnownBaseResult(BDV) && "why did it get added?"); LLVM_DEBUG( dbgs() << "Updating base value cache" diff --git a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll index a4290ef53f1b2d..34af81cd7337e6 100644 --- a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll +++ b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll @@ -192,75 +192,5 @@ latch: ; preds = %bb25, %bb7 br label %header } -; Uses of extractelement that are of scalar type should not have the BDV -; incorrectly identified as a vector type. -define void @widget() gc "statepoint-example" { -; CHECK-LABEL: @widget( -; CHECK-NEXT: bb6: -; CHECK-NEXT: [[BASE_EE:%.*]] = extractelement <2 x i8 addrspace(1)*> zeroinitializer, i32 1, !is_base_value !0 -; CHECK-NEXT: [[TMP:%.*]] = extractelement <2 x i8 addrspace(1)*> undef, i32 1 -; CHECK-NEXT: br i1 undef, label [[BB7:%.*]], label [[BB9:%.*]] -; CHECK: bb7: -; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 -; CHECK-NEXT: br label [[BB11:%.*]] -; CHECK: bb9: -; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 -; CHECK-NEXT: br i1 undef, label [[BB11]], label [[BB15:%.*]] -; CHECK: bb11: -; CHECK-NEXT: [[TMP12_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB7]] ], [ [[BASE_EE]], [[BB9]] ], !is_base_value !0 -; CHECK-NEXT: [[TMP12:%.*]] = phi i8 addrspace(1)* [ [[TMP8]], [[BB7]] ], [ [[TMP10]], [[BB9]] ] -; CHECK-NEXT: [[STATEPOINT_TOKEN:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP12_BASE]], i8 addrspace(1)* [[TMP12]]) -; CHECK-NEXT: [[TMP12_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 8) -; CHECK-NEXT: [[TMP12_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 9) -; CHECK-NEXT: br label [[BB15]] -; CHECK: bb15: -; CHECK-NEXT: [[TMP16_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB9]] ], [ [[TMP12_BASE_RELOCATED]], [[BB11]] ], !is_base_value !0 -; CHECK-NEXT: [[TMP16:%.*]] = phi i8 addrspace(1)* [ [[TMP10]], [[BB9]] ], [ [[TMP12_RELOCATED]], [[BB11]] ] -; CHECK-NEXT: br i1 undef, label [[BB17:%.*]], label [[BB20:%.*]] -; CHECK: bb17: -; CHECK-NEXT: [[STATEPOINT_TOKEN1:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP16_BASE]], i8 addrspace(1)* [[TMP16]]) -; CHECK-NEXT: [[TMP16_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 8) -; CHECK-NEXT: [[TMP16_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 9) -; CHECK-NEXT: br label [[BB20]] -; CHECK: bb20: -; CHECK-NEXT: [[DOT05:%.*]] = phi i8 addrspace(1)* [ [[TMP16_BASE_RELOCATED]], [[BB17]] ], [ [[TMP16_BASE]], [[BB15]] ] -; CHECK-NEXT: [[DOT0:%.*]] = phi i8 addrspace(1)* [ [[TMP16_RELOCATED]], [[BB17]] ], [ [[TMP16]], [[BB15]] ] -; CHECK-NEXT: [[STATEPOINT_TOKEN2:%.*]] = call token (i64, i32, void (i8 addrspace(1)*)*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidp1i8f(i64 2882400000, i32 0, void (i8 addrspace(1)*)* @foo, i32 1, i32 0, i8 addrspace(1)* [[DOT0]], i32 0, i32 0, i8 addrspace(1)* [[DOT05]], i8 addrspace(1)* [[DOT0]]) -; CHECK-NEXT: [[TMP16_BASE_RELOCATED3:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 8) -; CHECK-NEXT: [[TMP16_RELOCATED4:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 9) -; CHECK-NEXT: ret void -; -bb6: ; preds = %bb3 - %tmp = extractelement <2 x i8 addrspace(1)*> undef, i32 1 - br i1 undef, label %bb7, label %bb9 - -bb7: ; preds = %bb6 - %tmp8 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 - br label %bb11 - -bb9: ; preds = %bb6, %bb6 - %tmp10 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 - br i1 undef, label %bb11, label %bb15 - -bb11: ; preds = %bb9, %bb7 - %tmp12 = phi i8 addrspace(1)* [ %tmp8, %bb7 ], [ %tmp10, %bb9 ] - call void @snork() [ "deopt"(i32 undef) ] - br label %bb15 - -bb15: ; preds = %bb11, %bb9, %bb9 - %tmp16 = phi i8 addrspace(1)* [ %tmp10, %bb9 ], [ %tmp12, %bb11 ] - br i1 undef, label %bb17, label %bb20 - -bb17: ; preds = %bb15 - call void @snork() [ "deopt"(i32 undef) ] - br label %bb20 - -bb20: ; preds = %bb17, %bb15, %bb15 - call void @foo(i8 addrspace(1)* %tmp16) - ret void -} - -declare void @snork() -declare void @foo(i8 addrspace(1)*) declare void @spam() declare <2 x i8 addrspace(1)*> @baz()