158 changes: 88 additions & 70 deletions llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GCStrategy.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
Expand Down Expand Up @@ -125,6 +126,9 @@ static cl::opt<bool> RematDerivedAtUses("rs4gc-remat-derived-at-uses",
/// constant physical memory: llvm.invariant.start.
static void stripNonValidData(Module &M);

// Find the GC strategy for a function, or null if it doesn't have one.
static std::unique_ptr<GCStrategy> findGCStrategy(Function &F);

static bool shouldRewriteStatepointsIn(Function &F);

PreservedAnalyses RewriteStatepointsForGC::run(Module &M,
Expand Down Expand Up @@ -311,61 +315,60 @@ static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) {

/// Compute the live-in set for every basic block in the function
static void computeLiveInValues(DominatorTree &DT, Function &F,
GCPtrLivenessData &Data);
GCPtrLivenessData &Data, GCStrategy *GC);

/// Given results from the dataflow liveness computation, find the set of live
/// Values at a particular instruction.
static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data,
StatepointLiveSetTy &out);
StatepointLiveSetTy &out, GCStrategy *GC);

// TODO: Once we can get to the GCStrategy, this becomes
// std::optional<bool> isGCManagedPointer(const Type *Ty) const override {
static bool isGCPointerType(Type *T, GCStrategy *GC) {
assert(GC && "GC Strategy for isGCPointerType cannot be null");

static bool isGCPointerType(Type *T) {
if (auto *PT = dyn_cast<PointerType>(T))
// For the sake of this example GC, we arbitrarily pick addrspace(1) as our
// GC managed heap. We know that a pointer into this heap needs to be
// updated and that no other pointer does.
return PT->getAddressSpace() == 1;
return false;
if (!isa<PointerType>(T))
return false;

// conservative - same as StatepointLowering
return GC->isGCManagedPointer(T).value_or(true);
}

// Return true if this type is one which a) is a gc pointer or contains a GC
// pointer and b) is of a type this code expects to encounter as a live value.
// (The insertion code will assert that a type which matches (a) and not (b)
// is not encountered.)
static bool isHandledGCPointerType(Type *T) {
static bool isHandledGCPointerType(Type *T, GCStrategy *GC) {
// We fully support gc pointers
if (isGCPointerType(T))
if (isGCPointerType(T, GC))
return true;
// We partially support vectors of gc pointers. The code will assert if it
// can't handle something.
if (auto VT = dyn_cast<VectorType>(T))
if (isGCPointerType(VT->getElementType()))
if (isGCPointerType(VT->getElementType(), GC))
return true;
return false;
}

#ifndef NDEBUG
/// Returns true if this type contains a gc pointer whether we know how to
/// handle that type or not.
static bool containsGCPtrType(Type *Ty) {
if (isGCPointerType(Ty))
static bool containsGCPtrType(Type *Ty, GCStrategy *GC) {
if (isGCPointerType(Ty, GC))
return true;
if (VectorType *VT = dyn_cast<VectorType>(Ty))
return isGCPointerType(VT->getScalarType());
return isGCPointerType(VT->getScalarType(), GC);
if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
return containsGCPtrType(AT->getElementType());
return containsGCPtrType(AT->getElementType(), GC);
if (StructType *ST = dyn_cast<StructType>(Ty))
return llvm::any_of(ST->elements(), containsGCPtrType);
return llvm::any_of(ST->elements(),
[GC](Type *Ty) { return containsGCPtrType(Ty, GC); });
return false;
}

// Returns true if this is a type which a) is a gc pointer or contains a GC
// pointer and b) is of a type which the code doesn't expect (i.e. first class
// aggregates). Used to trip assertions.
static bool isUnhandledGCPointerType(Type *Ty) {
return containsGCPtrType(Ty) && !isHandledGCPointerType(Ty);
static bool isUnhandledGCPointerType(Type *Ty, GCStrategy *GC) {
return containsGCPtrType(Ty, GC) && !isHandledGCPointerType(Ty, GC);
}
#endif

Expand All @@ -382,9 +385,9 @@ static std::string suffixed_name_or(Value *V, StringRef Suffix,
// live. Values used by that instruction are considered live.
static void analyzeParsePointLiveness(
DominatorTree &DT, GCPtrLivenessData &OriginalLivenessData, CallBase *Call,
PartiallyConstructedSafepointRecord &Result) {
PartiallyConstructedSafepointRecord &Result, GCStrategy *GC) {
StatepointLiveSetTy LiveSet;
findLiveSetAtInst(Call, OriginalLivenessData, LiveSet);
findLiveSetAtInst(Call, OriginalLivenessData, LiveSet, GC);

if (PrintLiveSet) {
dbgs() << "Live Variables:\n";
Expand Down Expand Up @@ -1385,20 +1388,21 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache,
static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
CallBase *Call,
PartiallyConstructedSafepointRecord &result,
PointerToBaseTy &PointerToBase);
PointerToBaseTy &PointerToBase,
GCStrategy *GC);

static void recomputeLiveInValues(
Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate,
MutableArrayRef<struct PartiallyConstructedSafepointRecord> records,
PointerToBaseTy &PointerToBase) {
PointerToBaseTy &PointerToBase, GCStrategy *GC) {
// TODO-PERF: reuse the original liveness, then simply run the dataflow
// again. The old values are still live and will help it stabilize quickly.
GCPtrLivenessData RevisedLivenessData;
computeLiveInValues(DT, F, RevisedLivenessData);
computeLiveInValues(DT, F, RevisedLivenessData, GC);
for (size_t i = 0; i < records.size(); i++) {
struct PartiallyConstructedSafepointRecord &info = records[i];
recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info,
PointerToBase);
recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info, PointerToBase,
GC);
}
}

Expand Down Expand Up @@ -1522,7 +1526,7 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
ArrayRef<Value *> BasePtrs,
Instruction *StatepointToken,
IRBuilder<> &Builder) {
IRBuilder<> &Builder, GCStrategy *GC) {
if (LiveVariables.empty())
return;

Expand All @@ -1542,8 +1546,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
// towards a single unified pointer type anyways, we can just cast everything
// to an i8* of the right address space. A bitcast is added later to convert
// gc_relocate to the actual value's type.
auto getGCRelocateDecl = [&] (Type *Ty) {
assert(isHandledGCPointerType(Ty));
auto getGCRelocateDecl = [&](Type *Ty) {
assert(isHandledGCPointerType(Ty, GC));
auto AS = Ty->getScalarType()->getPointerAddressSpace();
Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
if (auto *VT = dyn_cast<VectorType>(Ty))
Expand Down Expand Up @@ -1668,7 +1672,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
const SmallVectorImpl<Value *> &LiveVariables,
PartiallyConstructedSafepointRecord &Result,
std::vector<DeferredReplacement> &Replacements,
const PointerToBaseTy &PointerToBase) {
const PointerToBaseTy &PointerToBase,
GCStrategy *GC) {
assert(BasePtrs.size() == LiveVariables.size());

// Then go ahead and use the builder do actually do the inserts. We insert
Expand Down Expand Up @@ -1901,7 +1906,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst();
Result.UnwindToken = ExceptionalToken;

CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder);
CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder, GC);

// Generate gc relocates and returns for normal block
BasicBlock *NormalDest = II->getNormalDest();
Expand Down Expand Up @@ -1947,7 +1952,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
Result.StatepointToken = Token;

// Second, create a gc.relocate for every live variable
CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder);
CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder, GC);
}

// Replace an existing gc.statepoint with a new one and a set of gc.relocates
Expand All @@ -1959,7 +1964,7 @@ static void
makeStatepointExplicit(DominatorTree &DT, CallBase *Call,
PartiallyConstructedSafepointRecord &Result,
std::vector<DeferredReplacement> &Replacements,
const PointerToBaseTy &PointerToBase) {
const PointerToBaseTy &PointerToBase, GCStrategy *GC) {
const auto &LiveSet = Result.LiveSet;

// Convert to vector for efficient cross referencing.
Expand All @@ -1976,7 +1981,7 @@ makeStatepointExplicit(DominatorTree &DT, CallBase *Call,

// Do the actual rewriting and delete the old statepoint
makeStatepointExplicitImpl(Call, BaseVec, LiveVec, Result, Replacements,
PointerToBase);
PointerToBase, GC);
}

// Helper function for the relocationViaAlloca.
Expand Down Expand Up @@ -2277,12 +2282,13 @@ static void insertUseHolderAfter(CallBase *Call, const ArrayRef<Value *> Values,

static void findLiveReferences(
Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate,
MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) {
MutableArrayRef<struct PartiallyConstructedSafepointRecord> records,
GCStrategy *GC) {
GCPtrLivenessData OriginalLivenessData;
computeLiveInValues(DT, F, OriginalLivenessData);
computeLiveInValues(DT, F, OriginalLivenessData, GC);
for (size_t i = 0; i < records.size(); i++) {
struct PartiallyConstructedSafepointRecord &info = records[i];
analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info);
analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info, GC);
}
}

Expand Down Expand Up @@ -2684,6 +2690,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
SmallVectorImpl<CallBase *> &ToUpdate,
DefiningValueMapTy &DVCache,
IsKnownBaseMapTy &KnownBases) {
std::unique_ptr<GCStrategy> GC = findGCStrategy(F);

#ifndef NDEBUG
// Validate the input
std::set<CallBase *> Uniqued;
Expand Down Expand Up @@ -2718,9 +2726,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
SmallVector<Value *, 64> DeoptValues;

for (Value *Arg : GetDeoptBundleOperands(Call)) {
assert(!isUnhandledGCPointerType(Arg->getType()) &&
assert(!isUnhandledGCPointerType(Arg->getType(), GC.get()) &&
"support for FCA unimplemented");
if (isHandledGCPointerType(Arg->getType()))
if (isHandledGCPointerType(Arg->getType(), GC.get()))
DeoptValues.push_back(Arg);
}

Expand All @@ -2731,7 +2739,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,

// A) Identify all gc pointers which are statically live at the given call
// site.
findLiveReferences(F, DT, ToUpdate, Records);
findLiveReferences(F, DT, ToUpdate, Records, GC.get());

/// Global mapping from live pointers to a base-defining-value.
PointerToBaseTy PointerToBase;
Expand Down Expand Up @@ -2782,7 +2790,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// By selecting base pointers, we've effectively inserted new uses. Thus, we
// need to rerun liveness. We may *also* have inserted new defs, but that's
// not the key issue.
recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase);
recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase, GC.get());

if (PrintBasePointers) {
errs() << "Base Pairs: (w/Relocation)\n";
Expand Down Expand Up @@ -2842,7 +2850,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
// the old statepoint calls as we go.)
for (size_t i = 0; i < Records.size(); i++)
makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements,
PointerToBase);
PointerToBase, GC.get());

ToUpdate.clear(); // prevent accident use of invalid calls.

Expand Down Expand Up @@ -2899,7 +2907,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
#ifndef NDEBUG
// Validation check
for (auto *Ptr : Live)
assert(isHandledGCPointerType(Ptr->getType()) &&
assert(isHandledGCPointerType(Ptr->getType(), GC.get()) &&
"must be a gc pointer type");
#endif

Expand Down Expand Up @@ -3026,18 +3034,26 @@ static void stripNonValidDataFromBody(Function &F) {
}
}

/// Looks up the GC strategy for a given function, returning null if the
/// function doesn't have a GC tag. The strategy is stored in the cache.
static std::unique_ptr<GCStrategy> findGCStrategy(Function &F) {
if (!F.hasGC())
return nullptr;

return getGCStrategy(F.getGC());
}

/// Returns true if this function should be rewritten by this pass. The main
/// point of this function is as an extension point for custom logic.
static bool shouldRewriteStatepointsIn(Function &F) {
// TODO: This should check the GCStrategy
if (F.hasGC()) {
const auto &FunctionGCName = F.getGC();
const StringRef StatepointExampleName("statepoint-example");
const StringRef CoreCLRName("coreclr");
return (StatepointExampleName == FunctionGCName) ||
(CoreCLRName == FunctionGCName);
} else
if (!F.hasGC())
return false;

std::unique_ptr<GCStrategy> Strategy = findGCStrategy(F);

assert(Strategy && "GC strategy is required by function, but was not found");

return Strategy->useRS4GC();
}

static void stripNonValidData(Module &M) {
Expand Down Expand Up @@ -3216,7 +3232,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
/// the live-out set of the basic block
static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
BasicBlock::reverse_iterator End,
SetVector<Value *> &LiveTmp) {
SetVector<Value *> &LiveTmp, GCStrategy *GC) {
for (auto &I : make_range(Begin, End)) {
// KILL/Def - Remove this definition from LiveIn
LiveTmp.remove(&I);
Expand All @@ -3228,9 +3244,9 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin,

// USE - Add to the LiveIn set for this instruction
for (Value *V : I.operands()) {
assert(!isUnhandledGCPointerType(V->getType()) &&
assert(!isUnhandledGCPointerType(V->getType(), GC) &&
"support for FCA unimplemented");
if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) {
if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V)) {
// The choice to exclude all things constant here is slightly subtle.
// There are two independent reasons:
// - We assume that things which are constant (from LLVM's definition)
Expand All @@ -3247,26 +3263,27 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
}
}

static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) {
static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp,
GCStrategy *GC) {
for (BasicBlock *Succ : successors(BB)) {
for (auto &I : *Succ) {
PHINode *PN = dyn_cast<PHINode>(&I);
if (!PN)
break;

Value *V = PN->getIncomingValueForBlock(BB);
assert(!isUnhandledGCPointerType(V->getType()) &&
assert(!isUnhandledGCPointerType(V->getType(), GC) &&
"support for FCA unimplemented");
if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V))
if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V))
LiveTmp.insert(V);
}
}
}

static SetVector<Value *> computeKillSet(BasicBlock *BB) {
static SetVector<Value *> computeKillSet(BasicBlock *BB, GCStrategy *GC) {
SetVector<Value *> KillSet;
for (Instruction &I : *BB)
if (isHandledGCPointerType(I.getType()))
if (isHandledGCPointerType(I.getType(), GC))
KillSet.insert(&I);
return KillSet;
}
Expand Down Expand Up @@ -3301,22 +3318,22 @@ static void checkBasicSSA(DominatorTree &DT, GCPtrLivenessData &Data,
#endif

static void computeLiveInValues(DominatorTree &DT, Function &F,
GCPtrLivenessData &Data) {
GCPtrLivenessData &Data, GCStrategy *GC) {
SmallSetVector<BasicBlock *, 32> Worklist;

// Seed the liveness for each individual block
for (BasicBlock &BB : F) {
Data.KillSet[&BB] = computeKillSet(&BB);
Data.KillSet[&BB] = computeKillSet(&BB, GC);
Data.LiveSet[&BB].clear();
computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB]);
computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB], GC);

#ifndef NDEBUG
for (Value *Kill : Data.KillSet[&BB])
assert(!Data.LiveSet[&BB].count(Kill) && "live set contains kill");
#endif

Data.LiveOut[&BB] = SetVector<Value *>();
computeLiveOutSeed(&BB, Data.LiveOut[&BB]);
computeLiveOutSeed(&BB, Data.LiveOut[&BB], GC);
Data.LiveIn[&BB] = Data.LiveSet[&BB];
Data.LiveIn[&BB].set_union(Data.LiveOut[&BB]);
Data.LiveIn[&BB].set_subtract(Data.KillSet[&BB]);
Expand Down Expand Up @@ -3368,7 +3385,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F,
}

static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
StatepointLiveSetTy &Out) {
StatepointLiveSetTy &Out, GCStrategy *GC) {
BasicBlock *BB = Inst->getParent();

// Note: The copy is intentional and required
Expand All @@ -3379,18 +3396,19 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
// call result is not live (normal), nor are it's arguments
// (unless they're used again later). This adjustment is
// specifically what we need to relocate
computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(),
LiveOut);
computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(), LiveOut,
GC);
LiveOut.remove(Inst);
Out.insert(LiveOut.begin(), LiveOut.end());
}

static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
CallBase *Call,
PartiallyConstructedSafepointRecord &Info,
PointerToBaseTy &PointerToBase) {
PointerToBaseTy &PointerToBase,
GCStrategy *GC) {
StatepointLiveSetTy Updated;
findLiveSetAtInst(Call, RevisedLivenessData, Updated);
findLiveSetAtInst(Call, RevisedLivenessData, Updated, GC);

// We may have base pointers which are now live that weren't before. We need
// to update the PointerToBase structure to reflect this.
Expand Down