Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions llvm/include/llvm/Transforms/Utils/PredicateInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class PredicateBase {

/// Fetch condition in the form of PredicateConstraint, if possible.
LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
/// Fetch condition in the form of a ConstantRange, if possible.
LLVM_ABI std::optional<ConstantRange> getRangeConstraint() const;

protected:
PredicateBase(PredicateType PT, Value *Op, Value *Condition)
Expand Down Expand Up @@ -157,18 +159,22 @@ class PredicateBranch : public PredicateWithEdge {

class PredicateSwitch : public PredicateWithEdge {
public:
Value *CaseValue;
// This is the switch instruction.
SwitchInst *Switch;
using CaseValuesVec = SmallVector<ConstantInt *, 2>;
CaseValuesVec CaseValues;
bool IsDefault;
PredicateSwitch(Value *Op, BasicBlock *SwitchBB, BasicBlock *TargetBB,
Value *CaseValue, SwitchInst *SI)
ArrayRef<ConstantInt *> CaseValues, SwitchInst *SI,
bool IsDefault)
: PredicateWithEdge(PT_Switch, Op, SwitchBB, TargetBB,
SI->getCondition()),
CaseValue(CaseValue), Switch(SI) {}
CaseValues(CaseValues), IsDefault(IsDefault) {}
PredicateSwitch(Value *Op, BasicBlock *SwitchBB, BasicBlock *TargetBB,
CaseValuesVec &&CaseValues, SwitchInst *SI, bool IsDefault)
: PredicateWithEdge(PT_Switch, Op, SwitchBB, TargetBB,
SI->getCondition()),
CaseValues(CaseValues), IsDefault(IsDefault) {}
PredicateSwitch() = delete;
static bool classof(const PredicateBase *PB) {
return PB->Type == PT_Switch;
}
static bool classof(const PredicateBase *PB) { return PB->Type == PT_Switch; }
};

/// Encapsulates PredicateInfo, including all data associated with memory
Expand Down
123 changes: 107 additions & 16 deletions llvm/lib/Transforms/Utils/PredicateInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/IR/AssemblyAnnotationWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
Expand Down Expand Up @@ -442,6 +444,7 @@ void PredicateInfoBuilder::processBranch(
}
}
}

// Process a block terminating switch, and place relevant operations to be
// renamed into OpsToRename.
void PredicateInfoBuilder::processSwitch(
Expand All @@ -450,21 +453,41 @@ void PredicateInfoBuilder::processSwitch(
Value *Op = SI->getCondition();
if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse())
return;
using CaseValuesVec = PredicateSwitch::CaseValuesVec;

BasicBlock *DefaultDest = SI->getDefaultDest();
// Remember all cases for PT_Switch related to the default dest.
CaseValuesVec AllCases;
AllCases.reserve(SI->getNumCases());

// Remember how many outgoing edges there are to every successor.
SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;
for (BasicBlock *TargetBlock : successors(BranchBB))
++SwitchEdges[TargetBlock];
// For each successor, remember all its related case values.
SmallDenseMap<BasicBlock *, CaseValuesVec, 16> SwitchEdges;

// Now propagate info for each case value
for (auto C : SI->cases()) {
BasicBlock *TargetBlock = C.getCaseSuccessor();
if (SwitchEdges.lookup(TargetBlock) == 1) {
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
Op, SI->getParent(), TargetBlock, C.getCaseValue(), SI);
addInfoFor(OpsToRename, Op, PS);
}
/// TODO: Replace this if with an assertion if we can guarantee that
/// this function must be called after SimplifyCFG, as a canonical switch
/// should not have case dest being the default dest.
if (TargetBlock == DefaultDest)
continue;
// Only collect real case values
ConstantInt *CaseValue = C.getCaseValue();
AllCases.push_back(CaseValue);
SwitchEdges[TargetBlock].push_back(CaseValue);
}

// Now propagate info for each case successor
for (auto *CaseSucc : SwitchEdges.keys()) {
auto &CaseValues = SwitchEdges.at(CaseSucc);
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
Op, SI->getParent(), CaseSucc, std::move(CaseValues), SI, false);
addInfoFor(OpsToRename, Op, PS);
}

// Finally, propagate info for the default case
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
Op, SI->getParent(), DefaultDest, std::move(AllCases), SI, true);
addInfoFor(OpsToRename, Op, PS);
}

// Build predicate info for our function
Expand Down Expand Up @@ -500,8 +523,8 @@ void PredicateInfoBuilder::buildPredicateInfo() {
// Given the renaming stack, make all the operands currently on the stack real
// by inserting them into the IR. Return the last operation's value.
Value *PredicateInfoBuilder::materializeStack(unsigned int &Counter,
ValueDFSStack &RenameStack,
Value *OrigOp) {
ValueDFSStack &RenameStack,
Value *OrigOp) {
// Find the first thing we have to materialize
auto RevIter = RenameStack.rbegin();
for (; RevIter != RenameStack.rend(); ++RevIter)
Expand Down Expand Up @@ -601,7 +624,8 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl<Value *> &OpsToRename) {
// block, and handle it specially. We know that it goes last, and only
// dominate phi uses.
auto BlockEdge = getBlockEdge(PossibleCopy);
if (!BlockEdge.second->getSinglePredecessor()) {
// We use unique predecessor to identify the mult-cases dest in switch
if (!BlockEdge.second->getUniquePredecessor()) {
VD.LocalNum = LN_Last;
auto *DomNode = DT.getNode(BlockEdge.first);
if (DomNode) {
Expand Down Expand Up @@ -759,8 +783,63 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
// TODO: Make this an assertion once RenamedOp is fully accurate.
return std::nullopt;
}
const auto &PS = *cast<PredicateSwitch>(this);
unsigned NumCases = PS.CaseValues.size();
assert(NumCases != 0 && "PT_Switch with no cases is invalid");
// PT_Switch with >1 cases is too complex to derive a PredicateConstraint.
if (NumCases > 1)
return std::nullopt;
// If we have a single case, we can derive a predicate constraint.
return {
{PS.IsDefault ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ, PS.CaseValues[0]}};
}
llvm_unreachable("Unknown predicate type");
}

return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
std::optional<ConstantRange> PredicateBase::getRangeConstraint() const {
switch (Type) {
case PT_Assume:
case PT_Branch: {
// For PT_Assume/PT_Branch, we derive the condition constant range from
// its predicate constraint.
const std::optional<PredicateConstraint> &Constraint = getConstraint();
if (!Constraint)
return std::nullopt;
CmpInst::Predicate Pred = Constraint->Predicate;
Value *OtherOp = Constraint->OtherOp;
const APInt *IntOp;
// If the other operand is not a constant integer, we can't derive a
// constant range.
if (!match(OtherOp, m_APInt(IntOp)))
return std::nullopt;
return {ConstantRange::makeExactICmpRegion(Pred, *IntOp)};
}
case PT_Switch:
// For PT_Switch, we directly derive the constant range from its case
// values.
if (Condition != RenamedOp) {
// TODO: Make this an assertion once RenamedOp is fully accurate.
return std::nullopt;
}

const auto &PS = *cast<PredicateSwitch>(this);
assert(!PS.CaseValues.empty() && "SwitchInfo with no cases is invalid");

unsigned BitWidth = PS.Condition->getType()->getScalarSizeInBits();

// For case values, CR = emptyset ∪ {case1, case2,..., caseN}
// For default, CR = fullset ∩ ~{case1} ∩ ~{case2} ∩ ... ∩ ~{caseN}
bool IsDefault = PS.IsDefault;
ConstantRange CR = IsDefault ? ConstantRange::getFull(BitWidth)
: ConstantRange::getEmpty(BitWidth);
for (ConstantInt *Case : PS.CaseValues) {
assert(Case && "CaseValue in switch should not be null");
CR = IsDefault
? CR.intersectWith(ConstantRange(Case->getValue()).inverse())
: CR.unionWith(Case->getValue());
}

return {CR};
}
llvm_unreachable("Unknown predicate type");
}
Expand Down Expand Up @@ -818,8 +897,20 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
PB->To->printAsOperand(OS);
OS << "]";
} else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) {
OS << "; switch predicate info { CaseValue: " << *PS->CaseValue
<< " Edge: [";
OS << "; switch predicate info { ";
if (PS->IsDefault) {
OS << "Case: default";
} else if (PS->CaseValues.size() == 1) {
OS << "CaseValue: " << *PS->CaseValues[0];
} else {
auto CaseValues =
llvm::map_range(PS->CaseValues, [](ConstantInt *Case) {
return std::to_string(Case->getSExtValue());
});
OS << "CaseValues: " << *PS->Condition->getType() << " [ "
<< join(CaseValues, ", ") << " ]";
}
OS << " Edge: [";
PS->From->printAsOperand(OS);
OS << ",";
PS->To->printAsOperand(OS);
Expand Down
83 changes: 54 additions & 29 deletions llvm/lib/Transforms/Utils/SCCPSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2021,33 +2021,13 @@ void SCCPInstVisitor::handleCallArguments(CallBase &CB) {

void SCCPInstVisitor::handlePredicate(Instruction *I, Value *CopyOf,
const PredicateBase *PI) {
const std::optional<ConstantRange> &RangeConstraint =
PI->getRangeConstraint();
ValueLatticeElement CopyOfVal = getValueState(CopyOf);
const std::optional<PredicateConstraint> &Constraint = PI->getConstraint();
if (!Constraint) {
mergeInValue(ValueState[I], I, CopyOfVal);
return;
}

CmpInst::Predicate Pred = Constraint->Predicate;
Value *OtherOp = Constraint->OtherOp;

// Wait until OtherOp is resolved.
if (getValueState(OtherOp).isUnknown()) {
addAdditionalUser(OtherOp, I);
return;
}

ValueLatticeElement CondVal = getValueState(OtherOp);
ValueLatticeElement &IV = ValueState[I];
if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
auto ImposedCR =
ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));

// Get the range imposed by the condition.
if (CondVal.isConstantRange())
ImposedCR = ConstantRange::makeAllowedICmpRegion(
Pred, CondVal.getConstantRange());

auto MergeInValueWithImposedCR = [this, I, CopyOfVal,
CopyOf](ValueLatticeElement &IV,
ConstantRange ImposedCR) {
// Combine range info for the original value with the new range from the
// condition.
auto CopyOfCR = CopyOfVal.asConstantRange(CopyOf->getType(),
Expand All @@ -2067,18 +2047,63 @@ void SCCPInstVisitor::handlePredicate(Instruction *I, Value *CopyOf,
// unless we have conditions that are always true/false (e.g. icmp ule
// i32, %a, i32_max). For the latter overdefined/empty range will be
// inferred, but the branch will get folded accordingly anyways.
addAdditionalUser(OtherOp, I);
mergeInValue(
IV, I, ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef*/ false));
};

if (RangeConstraint) {
// If we can derive a constant range directly from the predicate info,
// simply merge it into the lattice value.
// In such case, the relevant operands must be constants, and thus we do not
// need addAdditionalUser for such operands.
MergeInValueWithImposedCR(ValueState[I], *RangeConstraint);
return;
}

// If we can't simply get the constant range directly from the predicate info,
// then fallback to PredicateConstraint and let SCCPSolver resolve the
// possible Imposed CR.

const std::optional<PredicateConstraint> &Constraint = PI->getConstraint();
if (!Constraint) {
mergeInValue(ValueState[I], I, CopyOfVal);
return;
}

CmpInst::Predicate Pred = Constraint->Predicate;
Value *OtherOp = Constraint->OtherOp;

// Wait until OtherOp is resolved.
if (getValueState(OtherOp).isUnknown()) {
addAdditionalUser(OtherOp, I);
return;
} else if (Pred == CmpInst::ICMP_EQ &&
(CondVal.isConstant() || CondVal.isNotConstant())) {
}

ValueLatticeElement CondVal = getValueState(OtherOp);
ValueLatticeElement &IV = ValueState[I];
if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
// Get the range imposed by the condition.
auto ImposedCR =
CondVal.isConstantRange()
? ConstantRange::makeAllowedICmpRegion(Pred,
CondVal.getConstantRange())
: ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));

addAdditionalUser(OtherOp, I);
MergeInValueWithImposedCR(IV, ImposedCR);
return;
}

if (Pred == CmpInst::ICMP_EQ &&
(CondVal.isConstant() || CondVal.isNotConstant())) {
// For non-integer values or integer constant expressions, only
// propagate equal constants or not-constants.
addAdditionalUser(OtherOp, I);
mergeInValue(IV, I, CondVal);
return;
} else if (Pred == CmpInst::ICMP_NE && CondVal.isConstant()) {
}

if (Pred == CmpInst::ICMP_NE && CondVal.isConstant()) {
// Propagate inequalities.
addAdditionalUser(OtherOp, I);
mergeInValue(IV, I, ValueLatticeElement::getNot(CondVal.getConstant()));
Expand Down
Loading