Skip to content

Commit

Permalink
[AggressiveInstCombine] Add phi nodes support to TruncInstCombine
Browse files Browse the repository at this point in the history
Expand `TruncInstCombine` to handle loops by adding `phi` nodes
to expression graph.

Reviewed by: RKSimon, lebedev.ri

(recommit of fixed f84d732, reverted by 8ad6d5e after sanitizer breakage)

Differential Revision: https://reviews.llvm.org/D109817
  • Loading branch information
anton-afanasyev committed Feb 25, 2022
1 parent c2f501f commit 0dd8401
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 55 deletions.
Expand Up @@ -23,14 +23,14 @@
using namespace llvm;

//===----------------------------------------------------------------------===//
// TruncInstCombine - looks for expression dags dominated by trunc instructions
// and for each eligible dag, it will create a reduced bit-width expression and
// replace the old expression with this new one and remove the old one.
// Eligible expression dag is such that:
// TruncInstCombine - looks for expression graphs dominated by trunc
// instructions and for each eligible graph, it will create a reduced bit-width
// expression and replace the old expression with this new one and remove the
// old one. Eligible expression graph is such that:
// 1. Contains only supported instructions.
// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
// 3. Can be evaluated into type with reduced legal bit-width (or Trunc type).
// 4. All instructions in the dag must not have users outside the dag.
// 4. All instructions in the graph must not have users outside the graph.
// Only exception is for {ZExt, SExt}Inst with operand type equal to the
// new reduced type chosen in (3).
//
Expand Down Expand Up @@ -63,7 +63,7 @@ class TruncInstCombine {
/// Current processed TruncInst instruction.
TruncInst *CurrentTruncInst = nullptr;

/// Information per each instruction in the expression dag.
/// Information per each instruction in the expression graph.
struct Info {
/// Number of LSBs that are needed to generate a valid expression.
unsigned ValidBitWidth = 0;
Expand All @@ -72,10 +72,10 @@ class TruncInstCombine {
/// The reduced value generated to replace the old instruction.
Value *NewValue = nullptr;
};
/// An ordered map representing expression dag post-dominated by current
/// processed TruncInst. It maps each instruction in the dag to its Info
/// An ordered map representing expression graph post-dominated by current
/// processed TruncInst. It maps each instruction in the graph to its Info
/// structure. The map is ordered such that each instruction appears before
/// all other instructions in the dag that uses it.
/// all other instructions in the graph that use it.
MapVector<Instruction *, Info> InstInfoMap;

public:
Expand All @@ -87,11 +87,11 @@ class TruncInstCombine {
bool run(Function &F);

private:
/// Build expression dag dominated by the /p CurrentTruncInst and append it to
/// the InstInfoMap container.
/// Build expression graph dominated by the \p CurrentTruncInst and append it
/// to the InstInfoMap container.
///
/// \return true only if succeed to generate an eligible sub expression dag.
bool buildTruncExpressionDag();
/// \return true only if an eligible sub expression graph was generated.
bool buildTruncExpressionGraph();

/// Calculate the minimal allowed bit-width of the chain ending with the
/// currently visited truncate's operand.
Expand All @@ -100,12 +100,12 @@ class TruncInstCombine {
/// truncate's operand can be shrunk to.
unsigned getMinBitWidth();

/// Build an expression dag dominated by the current processed TruncInst and
/// Build an expression graph dominated by the currently processed TruncInst
/// and check if it is eligible to be reduced to a smaller type.
///
/// \return the scalar version of the new type to be used for the reduced
/// expression dag, or nullptr if the expression dag is not eligible
/// to be reduced.
/// expression graph, or nullptr if the expression graph is not
/// eligible to be reduced.
Type *getBestTruncatedType();

KnownBits computeKnownBits(const Value *V) const {
Expand All @@ -128,12 +128,12 @@ class TruncInstCombine {
/// \return the new reduced value.
Value *getReducedOperand(Value *V, Type *SclTy);

/// Create a new expression dag using the reduced /p SclTy type and replace
/// the old expression dag with it. Also erase all instructions in the old
/// dag, except those that are still needed outside the dag.
/// Create a new expression graph using the reduced \p SclTy type and replace
/// the old expression graph with it. Also erase all instructions in the old
/// graph, except those that are still needed outside the graph.
///
/// \param SclTy scalar version of new type to reduce expression dag into.
void ReduceExpressionDag(Type *SclTy);
/// \param SclTy scalar version of new type to reduce expression graph into.
void ReduceExpressionGraph(Type *SclTy);
};
} // end namespace llvm.

Expand Down
87 changes: 64 additions & 23 deletions llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
Expand Up @@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
// TruncInstCombine - looks for expression graphs post-dominated by TruncInst
// and for each eligible graph, it will create a reduced bit-width expression,
// replace the old expression with this new one and remove the old expression.
// Eligible expression graph is such that:
// 1. Contains only supported instructions.
// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
// 3. Can be evaluated into type with reduced legal bit-width.
// 4. All instructions in the dag must not have users outside the dag.
// 4. All instructions in the graph must not have users outside the graph.
// The only exception is for {ZExt, SExt}Inst with operand type equal to
// the new reduced type evaluated in (3).
//
Expand All @@ -39,14 +39,13 @@ using namespace llvm;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(
NumDAGsReduced,
"Number of truncations eliminated by reducing bit width of expression DAG");
STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "
"width of expression graph");
STATISTIC(NumInstrsReduced,
"Number of instructions whose bit width was reduced");

/// Given an instruction and a container, it fills all the relevant operands of
/// that instruction, with respect to the Trunc expression dag optimizaton.
/// that instruction, with respect to the Trunc expression graph optimization.
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
unsigned Opc = I->getOpcode();
switch (Opc) {
Expand Down Expand Up @@ -78,15 +77,19 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
Ops.push_back(I->getOperand(1));
Ops.push_back(I->getOperand(2));
break;
case Instruction::PHI:
for (Value *V : cast<PHINode>(I)->incoming_values())
Ops.push_back(V);
break;
default:
llvm_unreachable("Unreachable!");
}
}

bool TruncInstCombine::buildTruncExpressionDag() {
bool TruncInstCombine::buildTruncExpressionGraph() {
SmallVector<Value *, 8> Worklist;
SmallVector<Instruction *, 8> Stack;
// Clear old expression dag.
// Clear old instructions info.
InstInfoMap.clear();

Worklist.push_back(CurrentTruncInst->getOperand(0));
Expand Down Expand Up @@ -150,11 +153,19 @@ bool TruncInstCombine::buildTruncExpressionDag() {
append_range(Worklist, Operands);
break;
}
case Instruction::PHI: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
// Add only operands that are not in Stack, to prevent cycles.
for (auto *Op : Operands)
if (all_of(Stack, [Op](Value *V) { return Op != V; }))
Worklist.push_back(Op);
break;
}
default:
// TODO: Can handle more cases here:
// 1. shufflevector
// 2. sdiv, srem
// 3. phi node(and loop handling)
// ...
return false;
}
Expand Down Expand Up @@ -254,7 +265,7 @@ unsigned TruncInstCombine::getMinBitWidth() {
}

Type *TruncInstCombine::getBestTruncatedType() {
if (!buildTruncExpressionDag())
if (!buildTruncExpressionGraph())
return nullptr;

// We don't want to duplicate instructions, which isn't profitable. Thus, we
Expand Down Expand Up @@ -367,8 +378,10 @@ Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
return Entry.NewValue;
}

void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {
NumInstrsReduced += InstInfoMap.size();
// Pairs of old and new phi-nodes
SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes;
for (auto &Itr : InstInfoMap) { // Forward
Instruction *I = Itr.first;
TruncInstCombine::Info &NodeInfo = Itr.second;
Expand Down Expand Up @@ -451,6 +464,12 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
Res = Builder.CreateSelect(Op0, LHS, RHS);
break;
}
case Instruction::PHI: {
Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());
OldNewPHINodes.push_back(
std::make_pair(cast<PHINode>(I), cast<PHINode>(Res)));
break;
}
default:
llvm_unreachable("Unhandled instruction");
}
Expand All @@ -460,6 +479,14 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
ResI->takeName(I);
}

for (auto &Node : OldNewPHINodes) {
PHINode *OldPN = Node.first;
PHINode *NewPN = Node.second;
for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks()))
NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
std::get<1>(Incoming));
}

Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
Type *DstTy = CurrentTruncInst->getType();
if (Res->getType() != DstTy) {
Expand All @@ -470,17 +497,31 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
}
CurrentTruncInst->replaceAllUsesWith(Res);

// Erase old expression dag, which was replaced by the reduced expression dag.
// We iterate backward, which means we visit the instruction before we visit
// any of its operands, this way, when we get to the operand, we already
// removed the instructions (from the expression dag) that uses it.
// Erase old expression graph, which was replaced by the reduced expression
// graph.
CurrentTruncInst->eraseFromParent();
// First, erase the old phi-nodes, replacing their remaining uses with poison.
for (auto &Node : OldNewPHINodes) {
PHINode *OldPN = Node.first;
OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType()));
OldPN->eraseFromParent();
}
// Now the expression graph has been turned into a dag.
// We iterate backward, which means we visit each instruction before we
// visit any of its operands; this way, when we get to an operand, we have
// already removed the instructions (from the expression dag) that use it.
for (auto &I : llvm::reverse(InstInfoMap)) {
// Skip phi-nodes since they were erased before
if (isa<PHINode>(I.first))
continue;
// We still need to check that the instruction has no users before we erase
// it, because a {SExt, ZExt}Inst instruction might have other users that were
// not reduced; in such a case, we need to keep that instruction.
if (I.first->use_empty())
I.first->eraseFromParent();
else
assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) &&
"Only {SExt, ZExt}Inst might have unreduced users");
}
}

Expand All @@ -498,18 +539,18 @@ bool TruncInstCombine::run(Function &F) {
}

// Process all TruncInst in the Worklist, for each instruction:
// 1. Check if it dominates an eligible expression dag to be reduced.
// 2. Create a reduced expression dag and replace the old one with it.
// 1. Check if it dominates an eligible expression graph to be reduced.
// 2. Create a reduced expression graph and replace the old one with it.
while (!Worklist.empty()) {
CurrentTruncInst = Worklist.pop_back_val();

if (Type *NewDstSclTy = getBestTruncatedType()) {
LLVM_DEBUG(
dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
"dominated by: "
<< CurrentTruncInst << '\n');
ReduceExpressionDag(NewDstSclTy);
++NumDAGsReduced;
ReduceExpressionGraph(NewDstSclTy);
++NumExprsReduced;
MadeIRChange = true;
}
}
Expand Down
20 changes: 9 additions & 11 deletions llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll
Expand Up @@ -4,18 +4,17 @@
define i16 @trunc_phi(i8 %x) {
; CHECK-LABEL: @trunc_phi(
; CHECK-NEXT: LoopHeader:
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: Loop:
; CHECK-NEXT: [[ZEXT2:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[SHL:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[ZEXT2:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[SHL:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[J:%.*]] = phi i32 [ 0, [[LOOPHEADER]] ], [ [[I:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[SHL]] = shl i32 [[ZEXT2]], 1
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
; CHECK-NEXT: [[SHL]] = shl i16 [[ZEXT2]], 1
; CHECK-NEXT: [[I]] = add i32 [[J]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[I]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOPEND:%.*]], label [[LOOP]]
; CHECK: LoopEnd:
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: ret i16 [[SHL]]
;
LoopHeader:
%zext = zext i8 %x to i32
Expand All @@ -37,22 +36,21 @@ LoopEnd:
define i16 @trunc_phi2(i8 %x, i32 %sw) {
; CHECK-LABEL: @trunc_phi2(
; CHECK-NEXT: LoopHeader:
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
; CHECK-NEXT: switch i32 [[SW:%.*]], label [[LOOPEND:%.*]] [
; CHECK-NEXT: i32 0, label [[LOOP:%.*]]
; CHECK-NEXT: i32 1, label [[LOOP]]
; CHECK-NEXT: ]
; CHECK: Loop:
; CHECK-NEXT: [[ZEXT2:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[ZEXT]], [[LOOPHEADER]] ], [ [[SHL:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[ZEXT2:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[ZEXT]], [[LOOPHEADER]] ], [ [[SHL:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[J:%.*]] = phi i32 [ 0, [[LOOPHEADER]] ], [ 0, [[LOOPHEADER]] ], [ [[I:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[SHL]] = shl i32 [[ZEXT2]], 1
; CHECK-NEXT: [[SHL]] = shl i16 [[ZEXT2]], 1
; CHECK-NEXT: [[I]] = add i32 [[J]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[I]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOPEND]], label [[LOOP]]
; CHECK: LoopEnd:
; CHECK-NEXT: [[ZEXT3:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER]] ], [ [[ZEXT2]], [[LOOP]] ]
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ZEXT3]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[ZEXT3:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER]] ], [ [[ZEXT2]], [[LOOP]] ]
; CHECK-NEXT: ret i16 [[ZEXT3]]
;
LoopHeader:
%zext = zext i8 %x to i32
Expand Down

0 comments on commit 0dd8401

Please sign in to comment.