Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/docs/MLGO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.

.. code-block:: c++

const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
ir2vec::Embedding FuncVector = Emb->getFunctionVector();

Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
Basic Blocks, and Functions. Appropriate getters are provided to access the
Expand Down
61 changes: 34 additions & 27 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,21 +533,20 @@ class Embedder {
/// in the IR instructions to generate the vector representation.
const float OpcWeight, TypeWeight, ArgWeight;

// Utility maps - these are used to store the vector representations of
// instructions, basic blocks and functions.
mutable Embedding FuncVector;
mutable BBEmbeddingsMap BBVecMap;
mutable InstEmbeddingsMap InstVecMap;

LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight),
ArgWeight(ir2vec::ArgWeight) {}

/// Function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F.
void computeEmbeddings() const;
/// Function to compute embeddings.
Embedding computeEmbeddings() const;

/// Function to compute the embedding for a given basic block.
Embedding computeEmbeddings(const BasicBlock &BB) const;

/// Function to compute the embedding for a given instruction.
/// Specific to the kind of embeddings being computed.
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
virtual Embedding computeEmbeddings(const Instruction &I) const = 0;

public:
virtual ~Embedder() = default;
Expand All @@ -556,31 +555,35 @@ class Embedder {
LLVM_ABI static std::unique_ptr<Embedder>
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);

/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
/// for the function and returns the map.
LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;

/// Returns a map containing basic block and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
/// for the function and returns the map.
LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
/// Computes and returns the embedding for a given instruction in the function
/// F
LLVM_ABI Embedding getInstVector(const Instruction &I) const {
return computeEmbeddings(I);
}

/// Returns the embedding for a given basic block in the function F if it has
/// been computed. If not, it computes the embedding for the basic block and
/// returns it.
LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
/// Computes and returns the embedding for a given basic block in the function
/// F
LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const {
return computeEmbeddings(BB);
}

/// Computes and returns the embedding for the current function.
LLVM_ABI const Embedding &getFunctionVector() const;
LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }

/// Invalidate embeddings if cached. The embeddings may not be relevant
/// anymore when the IR changes due to transformations. In such cases, the
/// cached embeddings should be invalidated to ensure
/// correctness/recomputation. This is a no-op for SymbolicEmbedder but
/// removes all the cached entries in FlowAwareEmbedder.
virtual void invalidateEmbeddings() { return; }
};

/// Class for computing the Symbolic embeddings of IR2Vec.
/// Symbolic embeddings are constructed based on the entity-level
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
void computeEmbeddings(const BasicBlock &BB) const override;
Embedding computeEmbeddings(const Instruction &I) const override;

public:
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
Expand All @@ -592,11 +595,15 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
/// embeddings, and additionally capture the flow information in the IR.
class LLVM_ABI FlowAwareEmbedder : public Embedder {
private:
void computeEmbeddings(const BasicBlock &BB) const override;
// FlowAware embeddings would benefit from caching instruction embeddings as
// they are reused while computing the embeddings of other instructions.
mutable InstEmbeddingsMap InstVecMap;
Embedding computeEmbeddings(const Instruction &I) const override;

public:
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
: Embedder(F, Vocab) {}
void invalidateEmbeddings() override { InstVecMap.clear(); }
};

} // namespace ir2vec
Expand Down
180 changes: 71 additions & 109 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
// Embedder and its subclasses
//===----------------------------------------------------------------------===//

Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
FuncVector(Embedding(Dimension)) {}

std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
switch (Mode) {
Expand All @@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
return nullptr;
}

const InstEmbeddingsMap &Embedder::getInstVecMap() const {
if (InstVecMap.empty())
computeEmbeddings();
return InstVecMap;
}

const BBEmbeddingsMap &Embedder::getBBVecMap() const {
if (BBVecMap.empty())
computeEmbeddings();
return BBVecMap;
}

const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
auto It = BBVecMap.find(&BB);
if (It != BBVecMap.end())
return It->second;
computeEmbeddings(BB);
return BBVecMap[&BB];
}
Embedding Embedder::computeEmbeddings() const {
Embedding FuncVector(Dimension, 0.0);

const Embedding &Embedder::getFunctionVector() const {
// Currently, we always (re)compute the embeddings for the function.
// This is cheaper than caching the vector.
computeEmbeddings();
return FuncVector;
}

void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
return;

FuncVector = Embedding(Dimension, 0.0);
return FuncVector;

// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
FuncVector += BBVecMap[BB];
}
for (const BasicBlock *BB : depth_first(&F))
FuncVector += computeEmbeddings(*BB);
return FuncVector;
}

void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);

// We consider only the non-debug and non-pseudo instructions
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands())
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
BBVecMap[&BB] = BBVector;
}

void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
for (const auto &I : BB.instructionsWithoutDebug())
BBVector += computeEmbeddings(I);
return BBVector;
}

Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
// Currently, we always (re)compute the embeddings for symbolic embedder.
// This is cheaper than caching the vectors.
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands())
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
return InstVector;
}

Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
// If we have already computed the embedding for this instruction, return it
auto It = InstVecMap.find(&I);
if (It != InstVecMap.end())
return It->second;

// We consider only the non-debug and non-pseudo instructions
for (const auto &I : BB.instructionsWithoutDebug()) {
// TODO: Handle call instructions differently.
// For now, we treat them like other instructions
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands()) {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
auto DefIt = InstVecMap.find(DefInst);
// Fixme (#159171): Ideally we should never miss an instruction
// embedding here.
// But when we have cyclic dependencies (e.g., phi
// nodes), we might miss the embedding. In such cases, we fall back to
// using the vocabulary embedding. This can be fixed by iterating to a
// fixed-point, or by using a simple solver for the set of simultaneous
// equations.
// Another case when we might miss an instruction embedding is when
// the operand instruction is in a different basic block that has not
// been processed yet. This can be fixed by processing the basic blocks
// in a topological order.
if (DefIt != InstVecMap.end())
ArgEmb += DefIt->second;
else
ArgEmb += Vocab[*Op];
}
// If the operand is not defined by an instruction, we use the vocabulary
else {
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
<< *Op << "=" << Vocab[*Op][0] << "\n");
// TODO: Handle call instructions differently.
// For now, we treat them like other instructions
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands()) {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
auto DefIt = InstVecMap.find(DefInst);
// Fixme (#159171): Ideally we should never miss an instruction
// embedding here.
// But when we have cyclic dependencies (e.g., phi
// nodes), we might miss the embedding. In such cases, we fall back to
// using the vocabulary embedding. This can be fixed by iterating to a
// fixed-point, or by using a simple solver for the set of simultaneous
// equations.
// Another case when we might miss an instruction embedding is when
// the operand instruction is in a different basic block that has not
// been processed yet. This can be fixed by processing the basic blocks
// in a topological order.
if (DefIt != InstVecMap.end())
ArgEmb += DefIt->second;
else
ArgEmb += Vocab[*Op];
}
}
// Create the instruction vector by combining opcode, type, and arguments
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
// Add compare predicate embedding as an additional operand if applicable
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
// If the operand is not defined by an instruction, we use the
// vocabulary
else {
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
<< *Op << "=" << Vocab[*Op][0] << "\n");
ArgEmb += Vocab[*Op];
}
}
BBVecMap[&BB] = BBVector;
// Create the instruction vector by combining opcode, type, and arguments
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
return InstVector;
}

// ==----------------------------------------------------------------------===//
Expand Down Expand Up @@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
Emb->getFunctionVector().print(OS);

OS << "Basic block vectors:\n";
const auto &BBMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
auto It = BBMap.find(&BB);
if (It != BBMap.end()) {
OS << "Basic block: " << BB.getName() << ":\n";
It->second.print(OS);
}
OS << "Basic block: " << BB.getName() << ":\n";
Emb->getBBVector(BB).print(OS);
}

OS << "Instruction vectors:\n";
const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
auto It = InstMap.find(&I);
if (It != InstMap.end()) {
OS << "Instruction: ";
I.print(OS);
It->second.print(OS);
}
OS << "Instruction: ";
I.print(OS);
Emb->getInstVector(I).print(OS);
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions llvm/test/Analysis/IR2Vec/unreachable.ll
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ return: ; preds = %if.else, %if.then
%4 = load i32, ptr %retval, align 4
ret i32 %4
}

; CHECK: Basic block vectors:
; We'll get individual basic block embeddings for all blocks in the function.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would removing caching change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coz, vocb printer iterates over all the basic blocks and invokes getBBVector() now. Either we can change the printer pass to iterate over blocks in depth_first order or change here.

; But unreachable blocks are not counted for computing the function embedding.
; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ]
; CHECK-NEXT: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: unreachable:
; CHECK-NEXT: [ 101.00 103.00 105.00 ]
; CHECK-NEXT: Basic block: return:
; CHECK-NEXT: [ 95.00 97.00 99.00 ]
16 changes: 4 additions & 12 deletions llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,25 +253,17 @@ class IR2VecTool {
break;
}
case BasicBlockLevel: {
const auto &BBVecMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
auto It = BBVecMap.find(&BB);
if (It != BBVecMap.end()) {
OS << BB.getName() << ":";
It->second.print(OS);
}
OS << BB.getName() << ":";
Emb->getBBVector(BB).print(OS);
}
break;
}
case InstructionLevel: {
const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
auto It = InstMap.find(&I);
if (It != InstMap.end()) {
I.print(OS);
It->second.print(OS);
}
I.print(OS);
Emb->getInstVector(I).print(OS);
}
}
break;
Expand Down
Loading
Loading