Reassociate: add global reassociation algorithm (#6598)
This PR pulls the upstream change, Reassociate: add global reassociation
algorithm
(llvm/llvm-project@b8a330c),
into DXC with minimal changes.

For the code below:
  foo = (a * b) * c
  bar = (a * d) * c

As the upstream change states, the algorithm can identify a*c as a common
factor and eliminate the redundant computation.
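
A minimal C++ sketch of the rewrite this enables (the function and variable
names here are hypothetical; the pass itself operates on LLVM IR, not C++
source):

  // Hypothetical C++ analogue of the transformation; names are illustrative.
  int foo_xor_bar(int a, int b, int c, int d) {
    // Before: foo = (a * b) * c and bar = (a * d) * c share no subexpression.
    // Grouping the popular pair (a, c) in both expressions lets common
    // subexpression elimination compute the product once.
    int ac  = a * c;   // shared factor, computed a single time
    int foo = ac * b;  // equivalent to (a * b) * c
    int bar = ac * d;  // equivalent to (a * d) * c
    return foo ^ bar;
  }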

This is part 1 of the fix for #6593.

(cherry picked from commit 6f9c107)
lizhengxing committed May 21, 2024
1 parent fd7e54b commit 26ba5bc
Showing 2 changed files with 137 additions and 2 deletions.
124 changes: 122 additions & 2 deletions lib/Transforms/Scalar/Reassociate.cpp
@@ -20,11 +20,11 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
@@ -37,6 +37,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
using namespace llvm;
@@ -161,6 +162,13 @@ namespace {
DenseMap<BasicBlock*, unsigned> RankMap;
DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
SetVector<AssertingVH<Instruction> > RedoInsts;

// Arbitrary, but prevents quadratic behavior.
static const unsigned GlobalReassociateLimit = 10;
static const unsigned NumBinaryOps =
Instruction::BinaryOpsEnd - Instruction::BinaryOpsBegin;
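// For each binary opcode, maps a canonically ordered (Op0, Op1) operand pair
// to the number of reassociable expressions in the function that contain both
// operands. Populated once per function by BuildPairMap.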
DenseMap<std::pair<Value *, Value *>, unsigned> PairMap[NumBinaryOps];

bool MadeChange;
public:
static char ID; // Pass identification, replacement for typeid
@@ -196,6 +204,7 @@ namespace {
void EraseInst(Instruction *I);
void OptimizeInst(Instruction *I);
Instruction *canonicalizeNegConstExpr(Instruction *I);
void BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT);
};
}

@@ -2234,18 +2243,127 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) {
return;
}

if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) {
// Find the pair with the highest count in the pairmap and move it to the
// back of the list so that it can later be CSE'd.
// example:
// a*b*c*d*e
// if c*e is the most "popular" pair, we can express this as
// (((c*e)*d)*b)*a
unsigned Max = 1;
unsigned BestRank = 0;
std::pair<unsigned, unsigned> BestPair;
unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
for (unsigned i = 0; i < Ops.size() - 1; ++i)
for (unsigned j = i + 1; j < Ops.size(); ++j) {
unsigned Score = 0;
Value *Op0 = Ops[i].Op;
Value *Op1 = Ops[j].Op;
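// Order the pair the same way BuildPairMap stored it (std::less on the
// pointers), so the lookup below hits the canonical key.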
if (std::less<Value *>()(Op1, Op0))
std::swap(Op0, Op1);
auto it = PairMap[Idx].find({Op0, Op1});
if (it != PairMap[Idx].end())
Score += it->second;

unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
if (Score > Max || (Score == Max && MaxRank < BestRank)) {
BestPair = {i, j};
Max = Score;
BestRank = MaxRank;
}
}
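// Max starts at 1, so this reordering only fires when the best pair also
// occurs in at least one other reassociable expression in the function.
// Moving that pair to the back of Ops makes it the innermost node of the
// rewritten tree (the (((c*e)*d)*b)*a shape from the comment above), so a
// later CSE pass can merge the identical innermost instructions.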
if (Max > 1) {
auto Op0 = Ops[BestPair.first];
auto Op1 = Ops[BestPair.second];
Ops.erase(&Ops[BestPair.second]);
Ops.erase(&Ops[BestPair.first]);
Ops.push_back(Op0);
Ops.push_back(Op1);
}
}
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops);
}

void Reassociate::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
// Make a "pairmap" of how often each operand pair occurs.
for (BasicBlock *BI : RPOT) {
for (Instruction &I : *BI) {
if (!I.isAssociative())
continue;

// Ignore nodes that aren't at the root of trees.
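// (an interior node is revisited when the root of its tree is processed,
// so counting its operand pairs here would count them twice).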
if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode())
continue;

// Collect all operands in a single reassociable expression.
// Since Reassociate has already been run once, we can assume things
// are already canonical according to Reassociation's regime.
SmallVector<Value *, 8> Worklist = {I.getOperand(0), I.getOperand(1)};
SmallVector<Value *, 8> Ops;
while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) {
Value *Op = Worklist.pop_back_val();
Instruction *OpI = dyn_cast<Instruction>(Op);
if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) {
Ops.push_back(Op);
continue;
}
// Be paranoid about self-referencing expressions in unreachable code.
if (OpI->getOperand(0) != OpI)
Worklist.push_back(OpI->getOperand(0));
if (OpI->getOperand(1) != OpI)
Worklist.push_back(OpI->getOperand(1));
}
// Skip extremely long expressions.
if (Ops.size() > GlobalReassociateLimit)
continue;

// Add all pairwise combinations of operands to the pair map.
unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin;
SmallSet<std::pair<Value *, Value *>, 32> Visited;
for (unsigned i = 0; i < Ops.size() - 1; ++i) {
for (unsigned j = i + 1; j < Ops.size(); ++j) {
// Canonicalize operand orderings.
Value *Op0 = Ops[i];
Value *Op1 = Ops[j];
if (std::less<Value *>()(Op1, Op0))
std::swap(Op0, Op1);
if (!Visited.insert({Op0, Op1}).second)
continue;
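// The first expression to contain this pair inserts a count of 1; each
// additional expression containing it bumps the existing entry.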
auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1});
if (!res.second)
++res.first->second;
}
}
}
}
}

bool Reassociate::runOnFunction(Function &F) {
if (skipOptnoneFunction(F))
return false;

// Calculate the rank map for F
BuildRankMap(F);

// Build the pair map before running reassociate.
// Technically this would be more accurate if we did it after one round
// of reassociation, but in practice it doesn't seem to help much on
// real-world code, so don't waste the compile time running reassociate
// twice.
// If a user wants, they could explicitly run reassociate twice in their
// pass pipeline for further potential gains.
// It might also be possible to update the pair map during runtime, but the
// overhead of that may be large if there are many reassociable chains.
// TODO: RPOT
// Get the function's basic blocks in Reverse Post Order. This order is used by
// BuildRankMap to pre-calculate ranks correctly. It also excludes dead basic
// blocks (it has been seen that the analysis in this pass could hang when
// analysing dead basic blocks).
ReversePostOrderTraversal<Function *> RPOT(&F);
BuildPairMap(RPOT);

MadeChange = false;
for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
// Optimize every instruction in the basic block.
@@ -2268,9 +2386,11 @@ bool Reassociate::runOnFunction(Function &F) {
}
}

// We are done with the rank map.
// We are done with the rank map and pair map.
RankMap.clear();
ValueRankMap.clear();
for (auto &Entry : PairMap)
Entry.clear();

return MadeChange;
}
15 changes: 15 additions & 0 deletions test/Transforms/Reassociate/basictest.ll
@@ -221,3 +221,18 @@ define i32 @test15(i32 %X1, i32 %X2, i32 %X3) {
; CHECK-LABEL: @test15
; CHECK: and i1 %A, %B
}
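
; test17 below exercises the new global reassociation: the chains feeding %C
; and %D are rewritten around their shared pair %X3 * %X4, so that product can
; be CSE'd into a single multiply, as the CHECK lines verify.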

; CHECK-LABEL: @test17
; CHECK: %[[A:.*]] = mul i32 %X4, %X3
; CHECK-NEXT: %[[C:.*]] = mul i32 %[[A]], %X1
; CHECK-NEXT: %[[D:.*]] = mul i32 %[[A]], %X2
; CHECK-NEXT: %[[E:.*]] = xor i32 %[[C]], %[[D]]
; CHECK-NEXT: ret i32 %[[E]]
define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
%A = mul i32 %X3, %X1
%B = mul i32 %X3, %X2
%C = mul i32 %A, %X4
%D = mul i32 %B, %X4
%E = xor i32 %C, %D
ret i32 %E
}
