202 changes: 157 additions & 45 deletions llvm/lib/Transforms/IPO/MergeFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand All @@ -130,14 +131,50 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck(

namespace {

/// GlobalNumberState assigns an integer to each global value in the program,
/// which is used by the comparison routine to order references to globals. This
/// state must be preserved throughout the pass, because Functions and other
/// globals need to maintain their relative order. Globals are assigned a number
/// when they are first visited. This order is deterministic, and so the
/// assigned numbers are as well. When two functions are merged, neither number
/// is updated. If the symbols are weak, this would be incorrect. If they are
/// strong, then one will be replaced at all references to the other, and so
/// direct callsites will now see one or the other symbol, and no update is
/// necessary. Note that if we were guaranteed unique names, we could just
/// compare those, but this would not work for stripped bitcodes or for those
/// few symbols without a name.
class GlobalNumberState {
  // ValueMap configuration with RAUW-following switched off, so that the
  // number recorded for a global stays with that global when its uses are
  // replaced. Tracking such changes is unnecessary, and also problematic for
  // weak symbols (which may be overwritten).
  struct Config : ValueMapConfig<GlobalValue *> {
    enum { FollowRAUW = false };
  };
  typedef ValueMap<GlobalValue *, uint64_t, Config> ValueNumberMap;

  // Identifier recorded for every global requested so far.
  ValueNumberMap GlobalNumbers;
  // Serial number that the next previously-unseen global will receive.
  uint64_t NextNumber;

public:
  GlobalNumberState() : GlobalNumbers(), NextNumber(0) {}

  /// Return the number recorded for \p Global. The first request for a given
  /// global records the current serial number for it and advances the counter;
  /// later requests return the recorded value unchanged.
  uint64_t getNumber(GlobalValue *Global) {
    auto Lookup = GlobalNumbers.insert({Global, NextNumber});
    const bool IsNewEntry = Lookup.second;
    if (IsNewEntry)
      ++NextNumber;
    return Lookup.first->second;
  }
};

/// FunctionComparator - Compares two functions to determine whether or not
/// they will generate machine code with the same behaviour. DataLayout is
/// used if available. The comparator always fails conservatively (erring on the
/// side of claiming that two functions are different).
class FunctionComparator {
public:
FunctionComparator(const Function *F1, const Function *F2)
: FnL(F1), FnR(F2) {}
FunctionComparator(const Function *F1, const Function *F2,
GlobalNumberState* GN)
: FnL(F1), FnR(F2), GlobalNumbers(GN) {}

/// Test whether the two functions have equivalent behaviour.
int compare();
Expand All @@ -148,7 +185,7 @@ class FunctionComparator {

private:
/// Test whether two basic blocks have equivalent behaviour.
int compare(const BasicBlock *BBL, const BasicBlock *BBR);
int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR);

/// Constants comparison.
/// Its analog to lexicographical comparison between hypothetical numbers
Expand Down Expand Up @@ -254,6 +291,10 @@ class FunctionComparator {
/// If these properties are equal - compare their contents.
int cmpConstants(const Constant *L, const Constant *R);

/// Compares two global values by number. Uses the GlobalNumberState to
/// identify the same globals across function calls.
int cmpGlobalValues(GlobalValue *L, GlobalValue *R);

/// Assign or look up previously assigned numbers for the two values, and
/// return whether the numbers are equal. Numbers are assigned in the order
/// visited.
Expand Down Expand Up @@ -333,8 +374,9 @@ class FunctionComparator {
///
/// 1. If types are of different kind (different type IDs).
/// Return result of type IDs comparison, treating them as numbers.
/// 2. If types are vectors or integers, compare Type* values as numbers.
/// 3. Types has same ID, so check whether they belongs to the next group:
/// 2. If types are integers, check that they have the same width. If they
/// are vectors, check that they have the same count and subtype.
/// 3. Types have the same ID, so check whether they are one of:
/// * Void
/// * Float
/// * Double
Expand All @@ -343,8 +385,7 @@ class FunctionComparator {
/// * PPC_FP128
/// * Label
/// * Metadata
/// If so - return 0, yes - we can treat these types as equal only because
/// their IDs are same.
/// We can treat these types as equal whenever their IDs are same.
/// 4. If Left and Right are pointers, return result of address space
/// comparison (numbers comparison). We can treat pointer types of same
/// address space as equal.
Expand All @@ -359,7 +400,8 @@ class FunctionComparator {

int cmpAPInts(const APInt &L, const APInt &R) const;
int cmpAPFloats(const APFloat &L, const APFloat &R) const;
int cmpStrings(StringRef L, StringRef R) const;
int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
int cmpMem(StringRef L, StringRef R) const;
int cmpAttrs(const AttributeSet L, const AttributeSet R) const;

// The two functions undergoing comparison.
Expand Down Expand Up @@ -399,33 +441,28 @@ class FunctionComparator {
/// could be operands from further BBs we didn't scan yet.
/// So it's impossible to use dominance properties in general.
DenseMap<const Value*, int> sn_mapL, sn_mapR;

// The global state we will use
GlobalNumberState* GlobalNumbers;
};

/// Pairs a function with the hash computed for it when the node is created.
class FunctionNode {
  // Mutable because replaceBy() is a const member yet reassigns the handle.
  mutable AssertingVH<Function> F;
  // Computed once in the constructor; replaceBy() does not recompute it.
  FunctionComparator::FunctionHash Hash;

public:
  // Note the hash is recalculated potentially multiple times, but it is cheap.
  FunctionNode(Function *F)
      : F(F), Hash(FunctionComparator::functionHash(*F)) {}

  /// Return the function currently referenced by this node.
  Function *getFunc() const { return F; }

  /// Return the hash computed when the node was constructed.
  FunctionComparator::FunctionHash getHash() const { return Hash; }

  /// Replace the reference to the function F by the function G, assuming their
  /// implementations are equal.
  void replaceBy(Function *G) const {
    assert(!(*this < FunctionNode(G)) && !(FunctionNode(G) < *this) &&
           "The two functions must be equal");

    F = G;
  }

  /// Drop the reference to the function.
  void release() { F = nullptr; }

  /// Order first by hashes, then by full function comparison.
  // NOTE(review): this builds a FunctionComparator through the two-argument
  // constructor, which does not set the comparator's GlobalNumbers pointer.
  // Confirm whether this operator is still needed now that the function set
  // is ordered by a separate comparator object.
  bool operator<(const FunctionNode &RHS) const {
    if (Hash != RHS.Hash)
      return Hash < RHS.Hash;
    return (FunctionComparator(F, RHS.getFunc()).compare()) == -1;
  }
};
}

Expand All @@ -444,13 +481,17 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
}

/// Order two APFloats: first by the precision of their floating-point
/// semantics, then by their bit patterns as compared by cmpAPInts.
/// Returns the result of the first comparison that is non-zero, or the
/// cmpAPInts result if the precisions match.
int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
  // Compare a property of the semantics rather than the address of the
  // fltSemantics object, so the resulting order does not depend on where the
  // objects happen to live in memory.
  // TODO: This correctly handles all existing fltSemantics, because they all
  // have different precisions. This isn't very robust, however, if new types
  // with different exponent ranges are introduced.
  const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
  if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
                           APFloat::semanticsPrecision(SR)))
    return Res;
  return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
}

int FunctionComparator::cmpStrings(StringRef L, StringRef R) const {
int FunctionComparator::cmpMem(StringRef L, StringRef R) const {
// Prevent heavy comparison, compare sizes first.
if (int Res = cmpNumbers(L.size(), R.size()))
return Res;
Expand Down Expand Up @@ -556,9 +597,25 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
if (!L->isNullValue() && R->isNullValue())
return -1;

auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L));
auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R));
if (GlobalValueL && GlobalValueR) {
return cmpGlobalValues(GlobalValueL, GlobalValueR);
}

if (int Res = cmpNumbers(L->getValueID(), R->getValueID()))
return Res;

if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
// This handles ConstantDataArray and ConstantDataVector. Note that we
// compare the two raw data arrays, which might differ depending on the host
// endianness. This isn't a problem though, because the endianness of a module
// will affect the order of the constants, but this order is the same
// for a given input module and host platform.
return cmpMem(SeqL->getRawDataValues(), SeqR->getRawDataValues());
}

switch (L->getValueID()) {
case Value::UndefValueVal: return TypesRes;
case Value::ConstantIntVal: {
Expand Down Expand Up @@ -627,12 +684,21 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
}
return 0;
}
case Value::FunctionVal:
case Value::GlobalVariableVal:
case Value::GlobalAliasVal:
default: // Unknown constant, cast L and R pointers to numbers and compare.
case Value::BlockAddressVal: {
// FIXME: This still uses a pointer comparison. It isn't clear how to remove
// this. This only affects programs which take BlockAddresses and store them
// as constants, which is limited to interpreters, etc.
return cmpNumbers((uint64_t)L, (uint64_t)R);
}
default: // Unknown constant, abort.
DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
llvm_unreachable("Constant ValueID not recognized.");
return -1;
}
}

/// Compare two global values by the numbers GlobalNumbers holds for them.
/// Returns the cmpNumbers result for (number of L, number of R).
int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue* R) {
  // getNumber() assigns a fresh number the first time it sees a global, so the
  // two lookups must happen in a fixed order. Passing both calls directly as
  // function arguments would leave their evaluation order unspecified and the
  // assigned numbers dependent on the host compiler.
  uint64_t LNumber = GlobalNumbers->getNumber(L);
  uint64_t RNumber = GlobalNumbers->getNumber(R);
  return cmpNumbers(LNumber, RNumber);
}

/// cmpType - compares two types,
Expand Down Expand Up @@ -660,10 +726,15 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
llvm_unreachable("Unknown type!");
// Fall through in Release mode.
case Type::IntegerTyID:
case Type::VectorTyID:
// TyL == TyR would have returned true earlier.
return cmpNumbers((uint64_t)TyL, (uint64_t)TyR);

return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(),
cast<IntegerType>(TyR)->getBitWidth());
case Type::VectorTyID: {
VectorType *VTyL = cast<VectorType>(TyL), *VTyR = cast<VectorType>(TyR);
if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements()))
return Res;
return cmpTypes(VTyL->getElementType(), VTyR->getElementType());
}
// TyL == TyR would have returned true earlier, because types are uniqued.
case Type::VoidTyID:
case Type::FloatTyID:
case Type::DoubleTyID:
Expand Down Expand Up @@ -895,9 +966,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
GEPR->accumulateConstantOffset(DL, OffsetR))
return cmpAPInts(OffsetL, OffsetR);

if (int Res = cmpNumbers((uint64_t)GEPL->getPointerOperand()->getType(),
(uint64_t)GEPR->getPointerOperand()->getType()))
if (int Res = cmpTypes(GEPL->getPointerOperand()->getType(),
GEPR->getPointerOperand()->getType()))
return Res;

if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands()))
Expand All @@ -911,6 +981,28 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
return 0;
}

int FunctionComparator::cmpInlineAsm(const InlineAsm *L,
const InlineAsm *R) const {
// InlineAsm's are uniqued. If they are the same pointer, obviously they are
// the same, otherwise compare the fields.
if (L == R)
return 0;
if (int Res = cmpTypes(L->getFunctionType(), R->getFunctionType()))
return Res;
if (int Res = cmpMem(L->getAsmString(), R->getAsmString()))
return Res;
if (int Res = cmpMem(L->getConstraintString(), R->getConstraintString()))
return Res;
if (int Res = cmpNumbers(L->hasSideEffects(), R->hasSideEffects()))
return Res;
if (int Res = cmpNumbers(L->isAlignStack(), R->isAlignStack()))
return Res;
if (int Res = cmpNumbers(L->getDialect(), R->getDialect()))
return Res;
llvm_unreachable("InlineAsm blocks were not uniqued.");
return 0;
}

/// Compare two values used by the two functions under pair-wise comparison. If
/// this is the first time the values are seen, they're added to the mapping so
/// that we will detect mismatches on next use.
Expand Down Expand Up @@ -945,7 +1037,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) {
const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R);

if (InlineAsmL && InlineAsmR)
return cmpNumbers((uint64_t)L, (uint64_t)R);
return cmpInlineAsm(InlineAsmL, InlineAsmR);
if (InlineAsmL)
return 1;
if (InlineAsmR)
Expand All @@ -957,7 +1049,8 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) {
return cmpNumbers(LeftSN.first->second, RightSN.first->second);
}
// Test whether two basic blocks have equivalent behaviour.
int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) {
int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
const BasicBlock *BBR) {
BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end();
BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end();

Expand Down Expand Up @@ -1020,15 +1113,15 @@ int FunctionComparator::compare() {
return Res;

if (FnL->hasGC()) {
if (int Res = cmpNumbers((uint64_t)FnL->getGC(), (uint64_t)FnR->getGC()))
if (int Res = cmpMem(FnL->getGC(), FnR->getGC()))
return Res;
}

if (int Res = cmpNumbers(FnL->hasSection(), FnR->hasSection()))
return Res;

if (FnL->hasSection()) {
if (int Res = cmpStrings(FnL->getSection(), FnR->getSection()))
if (int Res = cmpMem(FnL->getSection(), FnR->getSection()))
return Res;
}

Expand Down Expand Up @@ -1074,7 +1167,7 @@ int FunctionComparator::compare() {
if (int Res = cmpValues(BBL, BBR))
return Res;

if (int Res = compare(BBL, BBR))
if (int Res = cmpBasicBlocks(BBL, BBR))
return Res;

const TerminatorInst *TermL = BBL->getTerminator();
Expand Down Expand Up @@ -1129,7 +1222,7 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) {
SmallVector<const BasicBlock *, 8> BBs;
SmallSet<const BasicBlock *, 16> VisitedBBs;

// Walk the blocks in the same order as FunctionComparator::compare(),
// Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(),
// accumulating the hash of the function "structure." (BB and opcode sequence)
BBs.push_back(&F.getEntryBlock());
VisitedBBs.insert(BBs[0]);
Expand Down Expand Up @@ -1163,14 +1256,31 @@ class MergeFunctions : public ModulePass {
public:
static char ID;
MergeFunctions()
: ModulePass(ID), HasGlobalAliases(false) {
: ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)),
HasGlobalAliases(false) {
initializeMergeFunctionsPass(*PassRegistry::getPassRegistry());
}

bool runOnModule(Module &M) override;

private:
typedef std::set<FunctionNode> FnTreeType;
// The function comparison operator is provided here so that FunctionNodes do
// not need to become larger with another pointer.
class FunctionNodeCmp {
GlobalNumberState* GlobalNumbers;
public:
FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}
bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const {
// Order first by hashes, then full function comparison.
if (LHS.getHash() != RHS.getHash())
return LHS.getHash() < RHS.getHash();
FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
return FCmp.compare() == -1;
}
};
typedef std::set<FunctionNode, FunctionNodeCmp> FnTreeType;

GlobalNumberState GlobalNumbers;

/// A work queue of functions that may have been modified and should be
/// analyzed again.
Expand Down Expand Up @@ -1245,8 +1355,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
for (std::vector<WeakVH>::iterator J = I; J != E && j < Max; ++J, ++j) {
Function *F1 = cast<Function>(*I);
Function *F2 = cast<Function>(*J);
int Res1 = FunctionComparator(F1, F2).compare();
int Res2 = FunctionComparator(F2, F1).compare();
int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare();
int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare();

// If F1 <= F2, then F2 >= F1, otherwise report failure.
if (Res1 != -Res2) {
Expand All @@ -1267,8 +1377,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
continue;

Function *F3 = cast<Function>(*K);
int Res3 = FunctionComparator(F1, F3).compare();
int Res4 = FunctionComparator(F2, F3).compare();
int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare();
int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare();

bool Transitive = true;

Expand Down Expand Up @@ -1556,6 +1666,8 @@ void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
(!F->mayBeOverridden() && !G->mayBeOverridden())) &&
"Only change functions if both are strong or both are weak");
(void)F;
assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
"The two functions must be equal");

IterToF->replaceBy(G);
}
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/Transforms/MergeFunc/constant-entire-value.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: opt -S -mergefunc < %s | FileCheck %s

; RUN: opt -S -mergefunc < %s | FileCheck -check-prefix=NOPLUS %s

; This makes sure that zeros in constants don't cause problems with
; string-based memory comparisons.
; Same instruction sequence as @add; only the last element of the constant
; array differs (2 here), and that element follows a zero element.
define internal i32 @sum(i32 %x, i32 %y) {
; CHECK-LABEL: @sum
%sum = add i32 %x, %y
%1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 2 ], 2
%sum2 = add i32 %sum, %1
%sum3 = add i32 %sum2, %y
ret i32 %sum3
}

; Same instruction sequence as @sum; only the last element of the constant
; array differs (1 here), and that element follows a zero element.
define internal i32 @add(i32 %x, i32 %y) {
; CHECK-LABEL: @add
%sum = add i32 %x, %y
%1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 1 ], 2
%sum2 = add i32 %sum, %1
%sum3 = add i32 %sum2, %y
ret i32 %sum3
}

; Body is identical to @next, including the constant array. The second RUN
; line verifies that this function does not appear in the output.
define internal i32 @plus(i32 %x, i32 %y) {
; NOPLUS-NOT: @plus
%sum = add i32 %x, %y
%1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 5 ], 2
%sum2 = add i32 %sum, %1
%sum3 = add i32 %sum2, %y
ret i32 %sum3
}

; Body is identical to @plus, including the constant array. The first RUN
; line verifies that this function is still present in the output.
define internal i32 @next(i32 %x, i32 %y) {
; CHECK-LABEL: @next
%sum = add i32 %x, %y
%1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 5 ], 2
%sum2 = add i32 %sum, %1
%sum3 = add i32 %sum2, %y
ret i32 %sum3
}