Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions llvm/include/llvm/CodeGen/MIR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,21 @@ class LLVMContext;
class MIR2VecVocabLegacyAnalysis;
class TargetInstrInfo;

/// Kinds of MIR2Vec embeddings that can be requested from
/// mir2vec::MIREmbedder::create. Symbolic is the only kind declared here.
enum class MIR2VecKind { Symbolic };

namespace mir2vec {

// Forward declarations
class MIREmbedder;
class SymbolicMIREmbedder;

extern llvm::cl::OptionCategory MIR2VecCategory;
extern cl::opt<float> OpcWeight;

using Embedding = ir2vec::Embedding;
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
using MachineBlockEmbeddingsMap =
DenseMap<const MachineBasicBlock *, Embedding>;

/// Class for storing and accessing the MIR2Vec vocabulary.
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
Expand Down Expand Up @@ -107,19 +117,91 @@ class MIRVocabulary {

const_iterator end() const { return Storage.end(); }

/// Total number of entries in the vocabulary
size_t getCanonicalSize() const { return Storage.size(); }

MIRVocabulary() = delete;

/// Factory method to create MIRVocabulary from vocabulary map
static Expected<MIRVocabulary> create(VocabMap &&Entries,
const TargetInstrInfo &TII);

/// Create a dummy vocabulary for testing purposes.
static Expected<MIRVocabulary>
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);

/// Total number of entries in the vocabulary
size_t getCanonicalSize() const { return Storage.size(); }

private:
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
};

/// Base class for MIR embedders.
/// An embedder is bound to one machine function and one vocabulary and exposes
/// instruction-, basic-block- and function-level embedding queries. Only the
/// per-instruction computation is left to subclasses.
/// The machine function and the vocabulary are held by reference, so both must
/// outlive the embedder.
class MIREmbedder {
protected:
/// Machine function this embedder operates on (not owned).
const MachineFunction &MF;
/// Vocabulary providing the seed embeddings (not owned).
const MIRVocabulary &Vocab;

/// Dimension of the embeddings; Captured from the vocabulary
const unsigned Dimension;

/// Weight for opcode embeddings; captured from the mir2vec::OpcWeight
/// command-line option when the embedder is constructed.
const float OpcWeight;

/// Protected so that embedders are only constructed through subclasses.
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(mir2vec::OpcWeight) {}

/// Function to compute the embedding of the whole machine function MF.
Embedding computeEmbeddings() const;

/// Function to compute the embedding for a given machine basic block.
Embedding computeEmbeddings(const MachineBasicBlock &MBB) const;

/// Function to compute the embedding for a given machine instruction.
/// Specific to the kind of embeddings being computed.
virtual Embedding computeEmbeddings(const MachineInstr &MI) const = 0;

public:
virtual ~MIREmbedder() = default;

/// Factory method to create an Embedder object of the specified kind
/// Returns nullptr if the requested kind is not supported.
static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
const MachineFunction &MF,
const MIRVocabulary &Vocab);

/// Computes and returns the embedding for a given machine instruction MI in
/// the machine function MF.
Embedding getMInstVector(const MachineInstr &MI) const {
return computeEmbeddings(MI);
}

/// Computes and returns the embedding for a given machine basic block in the
/// machine function MF.
Embedding getMBBVector(const MachineBasicBlock &MBB) const {
return computeEmbeddings(MBB);
}

/// Computes and returns the embedding for the current machine function.
Embedding getMFunctionVector() const {
// Currently, we always (re)compute the embeddings for the function. This is
// cheaper than caching the vector.
return computeEmbeddings();
}
};

/// Class for computing Symbolic embeddings
/// Symbolic embeddings are constructed based on the entity-level
/// representations obtained from the MIR Vocabulary.
class SymbolicMIREmbedder : public MIREmbedder {
private:
  /// Computes the symbolic embedding of a single machine instruction.
  Embedding computeEmbeddings(const MachineInstr &MI) const override;

public:
  /// Binds the embedder to the machine function \p MF and the vocabulary
  /// \p Vocab. Both are stored by reference in the base class and must
  /// outlive the embedder.
  SymbolicMIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);

  /// Creates a heap-allocated symbolic embedder for \p MF using \p Vocab.
  static std::unique_ptr<SymbolicMIREmbedder>
  create(const MachineFunction &MF, const MIRVocabulary &Vocab);
};

} // namespace mir2vec

/// Pass to analyze and populate MIR2Vec vocabulary from a module
Expand Down Expand Up @@ -166,6 +248,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
}
};

/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
/// and instructions
class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
/// Destination stream for the printed embeddings. Held by reference, so the
/// stream must outlive the pass.
raw_ostream &OS;

public:
/// Pass identification.
static char ID;
/// Creates a printer pass that writes to \p OS.
explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
: MachineFunctionPass(ID), OS(OS) {}

/// Prints the embeddings computed for \p MF to the output stream.
bool runOnMachineFunction(MachineFunction &MF) override;
/// Requires the MIR2Vec vocabulary analysis and declares that every analysis
/// is preserved.
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}

/// Human-readable name of this pass.
StringRef getPassName() const override {
return "MIR2Vec Embedder Printer Pass";
}
};

/// Create a machine pass that prints MIR2Vec embeddings
MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);

} // namespace llvm

#endif // LLVM_CODEGEN_MIR2VEC_H
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
LLVM_ABI MachineFunctionPass *
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);

/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
/// machine functions, basic blocks and instructions.
LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);

/// StackFramePrinter pass - This pass prints out the machine function's
/// stack frame to the given stream as a debugging tool.
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ LLVM_ABI void
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
initializeMachineUniformityAnalysisPassPass(Registry);
initializeMIR2VecVocabLegacyAnalysisPass(Registry);
initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
initializeMIR2VecPrinterLegacyPassPass(Registry);
initializeMachineUniformityInfoPrinterPassPass(Registry);
initializeMachineVerifierLegacyPassPass(Registry);
initializeObjCARCContractLegacyPassPass(Registry);
Expand Down
155 changes: 153 additions & 2 deletions llvm/lib/CodeGen/MIR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -41,11 +42,18 @@ static cl::opt<std::string>
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
cl::desc("Weight for machine opcode embeddings"),
cl::cat(MIR2VecCategory));
// Command-line option "mir2vec-kind": selects the kind of MIR2Vec embedding to
// generate. "symbolic" is the only accepted value and is also the default.
cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
"mir2vec-kind", cl::Optional,
cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
"Generate symbolic embeddings for MIR")),
cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
cl::cat(MIR2VecCategory));

} // namespace mir2vec
} // namespace llvm

//===----------------------------------------------------------------------===//
// Vocabulary Implementation
// Vocabulary
//===----------------------------------------------------------------------===//

MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
Expand Down Expand Up @@ -191,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
<< " unique base opcodes\n");
}

/// Builds a vocabulary with synthetic embeddings for testing.
/// Every canonical opcode name known for \p TII gets a \p Dim dimensional
/// embedding whose elements all hold the same value; that value starts at 0.1
/// and is increased by 0.1 for each successive name.
Expected<MIRVocabulary>
MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
                                       unsigned Dim) {
  assert(Dim > 0 && "Dimension must be greater than zero");

  // A scratch vocabulary is needed only to enumerate the canonical opcode
  // names for this target.
  MIRVocabulary Scratch({}, TII);
  Scratch.buildCanonicalOpcodeMapping();

  VocabMap Entries;
  float Fill = 0.1f;
  for (const auto &Name : Scratch.UniqueBaseOpcodeNames) {
    // Uniform embedding: every element carries the current fill value
    Entries[Name] = Embedding(Dim, Fill);
    Fill += 0.1f;
  }

  // Hand the synthetic entries to the regular factory
  return MIRVocabulary::create(std::move(Entries), TII);
}

//===----------------------------------------------------------------------===//
// MIR2VecVocabLegacyAnalysis Implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -261,7 +293,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
}

//===----------------------------------------------------------------------===//
// Printer Passes Implementation
// MIREmbedder and its subclasses
//===----------------------------------------------------------------------===//

/// Creates an embedder of kind \p Mode for \p MF backed by \p Vocab.
/// Returns nullptr when \p Mode does not name a supported kind.
std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
                                                 const MachineFunction &MF,
                                                 const MIRVocabulary &Vocab) {
  // Symbolic is the only kind that has an embedder implementation
  if (Mode == MIR2VecKind::Symbolic)
    return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);

  return nullptr;
}

/// Computes the embedding of \p MBB as the sum of the embeddings of its
/// non-debug instructions. If the subtarget provides no TargetInstrInfo, an
/// error is emitted on the function's context and a zero vector is returned.
Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
  Embedding BlockSum(Dimension, 0);

  // Bail out early when the subtarget has no instruction info
  const auto *InstrInfo = MF.getSubtarget().getInstrInfo();
  if (InstrInfo == nullptr) {
    MF.getFunction().getContext().emitError(
        "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
    return BlockSum;
  }

  // Accumulate every instruction except debug instructions
  for (const auto &Inst : MBB) {
    if (!Inst.isDebugInstr())
      BlockSum += computeEmbeddings(Inst);
  }

  return BlockSum;
}

/// Computes the embedding of the machine function as the sum of the embeddings
/// of the basic blocks visited by a depth-first traversal of the function.
Embedding MIREmbedder::computeEmbeddings() const {
  Embedding FunctionSum(Dimension, 0);

  // The depth-first walk from the function visits only reachable blocks
  for (const auto *Block : depth_first(&MF)) {
    FunctionSum += computeEmbeddings(*Block);
  }

  return FunctionSum;
}

// Forwards the machine function and the vocabulary to the MIREmbedder base;
// the symbolic embedder adds no state of its own.
SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
const MIRVocabulary &Vocab)
: MIREmbedder(MF, Vocab) {}

// Convenience factory: returns a heap-allocated symbolic embedder for MF that
// uses Vocab.
std::unique_ptr<SymbolicMIREmbedder>
SymbolicMIREmbedder::create(const MachineFunction &MF,
const MIRVocabulary &Vocab) {
return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
}

/// Returns the vocabulary entry for the opcode of \p MI. Debug instructions
/// contribute a zero vector instead.
Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
  // Todo: Add operand/argument contributions
  if (!MI.isDebugInstr())
    return Vocab[MI.getOpcode()];

  // Debug instructions and other metadata carry no embedding
  return Embedding(Dimension, 0);
}

//===----------------------------------------------------------------------===//
// Printer Passes
//===----------------------------------------------------------------------===//

char MIR2VecVocabPrinterLegacyPass::ID = 0;
Expand Down Expand Up @@ -300,3 +398,56 @@ MachineFunctionPass *
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
return new MIR2VecVocabPrinterLegacyPass(OS);
}

// Pass identification and legacy pass manager registration for the embedding
// printer. The pass is registered under the argument "print-mir2vec" and
// declares MIR2VecVocabLegacyAnalysis and MachineModuleInfoWrapperPass as
// initialization dependencies.
char MIR2VecPrinterLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
"MIR2Vec Embedder Printer Pass", false, true)
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
"MIR2Vec Embedder Printer Pass", false, true)

/// Prints the MIR2Vec embeddings of \p MF to the pass's output stream: first
/// the function vector, then one vector per basic block, then one vector per
/// non-debug instruction.
/// If the vocabulary cannot be obtained or no embedder can be created, a
/// diagnostic line is printed instead. Always returns false because the
/// machine function is never modified.
bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
  auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
  auto VocabOrErr =
      Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
  // An assert is not enough here: with assertions compiled out, a failed
  // Expected would be dereferenced and its error would never be consumed.
  if (!VocabOrErr) {
    OS << "MIR2Vec Embedder Printer: Failed to get vocabulary - "
       << toString(VocabOrErr.takeError()) << "\n";
    return false;
  }
  auto &MIRVocab = *VocabOrErr;

  auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
  if (!Emb) {
    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
       << "\n";
    return false;
  }

  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
  OS << "Machine Function vector: ";
  Emb->getMFunctionVector().print(OS);

  OS << "Machine basic block vectors:\n";
  for (const MachineBasicBlock &MBB : MF) {
    OS << "Machine basic block: " << MBB.getFullName() << ":\n";
    Emb->getMBBVector(MBB).print(OS);
  }

  OS << "Machine instruction vectors:\n";
  for (const MachineBasicBlock &MBB : MF) {
    for (const MachineInstr &MI : MBB) {
      // Skip debug instructions as they are not
      // embedded
      if (MI.isDebugInstr())
        continue;

      OS << "Machine instruction: ";
      MI.print(OS);
      Emb->getMInstVector(MI).print(OS);
    }
  }

  return false;
}

// Creates a new MIR2Vec embedding printer pass that writes to OS. The pass is
// returned as a raw pointer allocated with new.
MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
return new MIR2VecPrinterLegacyPass(OS);
}
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"entities": {
"KILL": [0.1, 0.2, 0.3],
"MOV": [0.4, 0.5, 0.6],
"LEA": [0.7, 0.8, 0.9],
"RET": [1.0, 1.1, 1.2],
"ADD": [1.3, 1.4, 1.5],
"SUB": [1.6, 1.7, 1.8],
"IMUL": [1.9, 2.0, 2.1],
"AND": [2.2, 2.3, 2.4],
"OR": [2.5, 2.6, 2.7],
"XOR": [2.8, 2.9, 3.0],
"CMP": [3.1, 3.2, 3.3],
"TEST": [3.4, 3.5, 3.6],
"JMP": [3.7, 3.8, 3.9],
"CALL": [4.0, 4.1, 4.2],
"PUSH": [4.3, 4.4, 4.5],
"POP": [4.6, 4.7, 4.8],
"NOP": [4.9, 5.0, 5.1],
"COPY": [5.2, 5.3, 5.4]
}
}
Loading
Loading