Skip to content

Commit 9bd4b36

Browse files
committed
MIR2Vec embedding
1 parent 6a43c34 commit 9bd4b36

File tree

10 files changed

+863
-27
lines changed

10 files changed

+863
-27
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,21 @@ class LLVMContext;
5151
class MIR2VecVocabLegacyAnalysis;
5252
class TargetInstrInfo;
5353

54+
/// Kinds of MIR2Vec embeddings that can be requested from the embedder
/// factory.
enum class MIR2VecKind {
  /// Embeddings built from the vocabulary's entity-level representations.
  Symbolic
};
55+
5456
namespace mir2vec {
57+
58+
// Forward declarations
59+
class MIREmbedder;
60+
class SymbolicMIREmbedder;
61+
5562
extern llvm::cl::OptionCategory MIR2VecCategory;
5663
extern cl::opt<float> OpcWeight;
5764

5865
using Embedding = ir2vec::Embedding;
66+
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
67+
using MachineBlockEmbeddingsMap =
68+
DenseMap<const MachineBasicBlock *, Embedding>;
5969

6070
/// Class for storing and accessing the MIR2Vec vocabulary.
6171
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -132,6 +142,79 @@ class MIRVocabulary {
132142
assert(isValid() && "Invalid vocabulary");
133143
return Storage.size();
134144
}
145+
146+
/// Create a dummy vocabulary for testing purposes.
147+
static MIRVocabulary createDummyVocabForTest(const TargetInstrInfo &TII,
148+
unsigned Dim = 1);
149+
};
150+
151+
/// Base class for MIR embedders
152+
class MIREmbedder {
153+
protected:
154+
const MachineFunction &MF;
155+
const MIRVocabulary &Vocab;
156+
157+
/// Dimension of the embeddings; Captured from the vocabulary
158+
const unsigned Dimension;
159+
160+
/// Weight for opcode embeddings
161+
const float OpcWeight;
162+
163+
// Utility maps - these are used to store the vector representations of
164+
// instructions, basic blocks and functions.
165+
mutable Embedding MFuncVector;
166+
mutable MachineBlockEmbeddingsMap MBBVecMap;
167+
mutable MachineInstEmbeddingsMap MInstVecMap;
168+
169+
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
170+
171+
/// Function to compute embeddings. It generates embeddings for all
172+
/// the instructions and basic blocks in the function F.
173+
void computeEmbeddings() const;
174+
175+
/// Function to compute the embedding for a given basic block.
176+
/// Specific to the kind of embeddings being computed.
177+
virtual void computeEmbeddings(const MachineBasicBlock &MBB) const = 0;
178+
179+
public:
180+
virtual ~MIREmbedder() = default;
181+
182+
/// Factory method to create an Embedder object of the specified kind
183+
/// Returns nullptr if the requested kind is not supported.
184+
static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
185+
const MachineFunction &MF,
186+
const MIRVocabulary &Vocab);
187+
188+
/// Returns a map containing machine instructions and the corresponding
189+
/// embeddings for the machine function MF if it has been computed. If not, it
190+
/// computes the embeddings for MF and returns the map.
191+
const MachineInstEmbeddingsMap &getMInstVecMap() const;
192+
193+
/// Returns a map containing machine basic block and the corresponding
194+
/// embeddings for the machine function MF if it has been computed. If not, it
195+
/// computes the embeddings for MF and returns the map.
196+
const MachineBlockEmbeddingsMap &getMBBVecMap() const;
197+
198+
/// Returns the embedding for a given machine basic block in the machine
199+
/// function MF if it has been computed. If not, it computes the embedding for
200+
/// MBB and returns it.
201+
const Embedding &getMBBVector(const MachineBasicBlock &MBB) const;
202+
203+
/// Computes and returns the embedding for the current machine function.
204+
const Embedding &getMFunctionVector() const;
205+
};
206+
207+
/// Class for computing Symbolic embeddings
208+
/// Symbolic embeddings are constructed based on the entity-level
209+
/// representations obtained from the MIR Vocabulary.
210+
class SymbolicMIREmbedder : public MIREmbedder {
211+
private:
212+
void computeEmbeddings(const MachineBasicBlock &MBB) const override;
213+
214+
public:
215+
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab);
216+
static std::unique_ptr<SymbolicMIREmbedder>
217+
create(const MachineFunction &MF, const MIRVocabulary &Vocab);
135218
};
136219

137220
} // namespace mir2vec
@@ -181,6 +264,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
181264
}
182265
};
183266

267+
/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
268+
/// and instructions
269+
class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
270+
raw_ostream &OS;
271+
272+
public:
273+
static char ID;
274+
explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
275+
: MachineFunctionPass(ID), OS(OS) {}
276+
277+
bool runOnMachineFunction(MachineFunction &MF) override;
278+
void getAnalysisUsage(AnalysisUsage &AU) const override {
279+
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
280+
AU.setPreservesAll();
281+
MachineFunctionPass::getAnalysisUsage(AU);
282+
}
283+
284+
StringRef getPassName() const override {
285+
return "MIR2Vec Embedder Printer Pass";
286+
}
287+
};
288+
289+
/// Create a machine pass that prints MIR2Vec embeddings
290+
MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
291+
184292
} // namespace llvm
185293

186294
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
9393
LLVM_ABI MachineFunctionPass *
9494
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
9595

96+
/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
97+
/// machine functions, basic blocks and instructions.
98+
LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
99+
96100
/// StackFramePrinter pass - This pass prints out the machine function's
97101
/// stack frame to the given stream as a debugging tool.
98102
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223223
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
224224
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
225+
LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
225226
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
226227
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
227228
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

llvm/lib/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
9898
initializeMachineUniformityAnalysisPassPass(Registry);
9999
initializeMIR2VecVocabLegacyAnalysisPass(Registry);
100100
initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
101+
initializeMIR2VecPrinterLegacyPassPass(Registry);
101102
initializeMachineUniformityInfoPrinterPassPass(Registry);
102103
initializeMachineVerifierLegacyPassPass(Registry);
103104
initializeObjCARCContractLegacyPassPass(Registry);

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 193 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ static cl::opt<std::string>
4141
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4242
cl::desc("Weight for machine opcode embeddings"),
4343
cl::cat(MIR2VecCategory));
44+
/// Command-line option `-mir2vec-kind` selecting which kind of MIR2Vec
/// embedding is generated. Defaults to the symbolic kind, which is currently
/// the only accepted value.
cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
    "mir2vec-kind", cl::Optional,
    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
                          "Generate symbolic embeddings for MIR")),
    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
    cl::cat(MIR2VecCategory));
50+
4451
} // namespace mir2vec
4552
} // namespace llvm
4653

4754
//===----------------------------------------------------------------------===//
48-
// Vocabulary Implementation
55+
// Vocabulary
4956
//===----------------------------------------------------------------------===//
5057

5158
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
190197
<< " unique base opcodes\n");
191198
}
192199

200+
/// Builds a vocabulary with synthetic embeddings, intended for tests.
///
/// Every canonical opcode name known for \p TII receives an embedding of
/// \p Dim elements, all set to one value. The value starts at 0.1 and grows by
/// 0.1 for each successive opcode name.
///
/// \param TII target instruction info used to enumerate opcode names.
/// \param Dim embedding dimension; must be greater than zero.
/// \returns a vocabulary populated with the synthetic embeddings.
MIRVocabulary MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
                                                     unsigned Dim) {
  assert(Dim > 0 && "Dimension must be greater than zero");

  // A scratch vocabulary is needed only to obtain the canonical opcode names.
  MIRVocabulary Scratch({}, &TII);
  Scratch.buildCanonicalOpcodeMapping();

  // One constant-filled embedding per canonical opcode name.
  VocabMap Entries;
  float Fill = 0.1f;
  for (const auto &OpcodeName : Scratch.UniqueBaseOpcodeNames) {
    Entries[OpcodeName] = Embedding(Dim, Fill);
    Fill += 0.1f;
  }

  return MIRVocabulary(std::move(Entries), &TII);
}
222+
193223
//===----------------------------------------------------------------------===//
194224
// MIR2VecVocabLegacyAnalysis Implementation
195225
//===----------------------------------------------------------------------===//
@@ -267,7 +297,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
267297
}
268298

269299
//===----------------------------------------------------------------------===//
270-
// Printer Passes Implementation
300+
// MIREmbedder and its subclasses
301+
//===----------------------------------------------------------------------===//
302+
303+
/// Binds the embedder to \p MF and \p Vocab, captures the embedding dimension
/// from the vocabulary and the opcode weight from the command-line option,
/// and sizes the function vector to that dimension.
MIREmbedder::MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
    : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
      OpcWeight(::OpcWeight), MFuncVector(Dimension) {}
306+
307+
/// Factory for embedders.
///
/// \param Mode  the kind of embedding requested.
/// \param MF    the machine function to embed.
/// \param Vocab the vocabulary to derive embeddings from.
/// \returns an embedder of the requested kind, or nullptr if \p Mode does not
///          name a supported kind.
std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
                                                 const MachineFunction &MF,
                                                 const MIRVocabulary &Vocab) {
  // No default label: every enumerator must be handled explicitly here.
  switch (Mode) {
  case MIR2VecKind::Symbolic:
    return SymbolicMIREmbedder::create(MF, Vocab);
  }
  return nullptr;
}
316+
317+
/// Returns the map from machine instructions of MF to their embeddings.
///
/// The embeddings for the whole function are (re)computed first unless the
/// caches already cover every basic block of MF.
const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const {
  // getMBBVector() may have filled the caches for a single block only, so a
  // non-empty instruction map does not mean the whole function was processed.
  // The caches are treated as complete only once every block of MF has an
  // entry in the block map.
  // NOTE(review): this relies on the per-block computeEmbeddings() filling the
  // instruction map and the block map together, as SymbolicMIREmbedder does.
  if (MInstVecMap.empty() || MBBVecMap.size() != MF.size())
    computeEmbeddings();
  return MInstVecMap;
}
322+
323+
/// Returns the map from machine basic blocks of MF to their embeddings.
///
/// The embeddings for the whole function are (re)computed first unless the
/// block map already has one entry per basic block of MF.
const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap() const {
  // getMBBVector() may have cached a subset of the blocks, so checking for a
  // non-empty map alone could hand back a partial result. Compare the number
  // of cached blocks against the number of blocks in MF instead.
  if (MBBVecMap.empty() || MBBVecMap.size() != MF.size())
    computeEmbeddings();
  return MBBVecMap;
}
328+
329+
/// Returns the embedding of \p BB, computing it on first request and serving
/// it from the block map afterwards.
const Embedding &MIREmbedder::getMBBVector(const MachineBasicBlock &BB) const {
  const auto Cached = MBBVecMap.find(&BB);
  if (Cached == MBBVecMap.end()) {
    // Not computed yet: run the kind-specific per-block computation, which
    // stores its result in the block map.
    computeEmbeddings(BB);
    return MBBVecMap[&BB];
  }
  return Cached->second;
}
336+
337+
/// Returns the embedding of the machine function MF.
///
/// The function vector is not cached: every call recomputes the embeddings
/// of all basic blocks and re-accumulates them.
const Embedding &MIREmbedder::getMFunctionVector() const {
  computeEmbeddings();
  return MFuncVector;
}
343+
344+
void MIREmbedder::computeEmbeddings() const {
345+
// Reset function vector to zero before recomputing
346+
MFuncVector = Embedding(Dimension, 0.0);
347+
348+
// Consider all machine basic blocks in the function
349+
for (const auto &MBB : MF) {
350+
computeEmbeddings(MBB);
351+
MFuncVector += MBBVecMap[&MBB];
352+
}
353+
}
354+
355+
/// Constructs a Symbolic embedder; all state lives in the MIREmbedder base.
SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
                                         const MIRVocabulary &Vocab)
    : MIREmbedder(MF, Vocab) {}
358+
359+
std::unique_ptr<SymbolicMIREmbedder>
360+
SymbolicMIREmbedder::create(const MachineFunction &MF,
361+
const MIRVocabulary &Vocab) {
362+
return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
363+
}
364+
365+
void SymbolicMIREmbedder::computeEmbeddings(
366+
const MachineBasicBlock &MBB) const {
367+
Embedding MBBVector(Dimension, 0);
368+
369+
// Get instruction info for opcode name resolution
370+
const auto &Subtarget = MF.getSubtarget();
371+
const auto *TII = Subtarget.getInstrInfo();
372+
if (!TII) {
373+
MF.getFunction().getContext().emitError(
374+
"MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
375+
return;
376+
}
377+
378+
// Process each machine instruction in the basic block
379+
for (const auto &MI : MBB) {
380+
// Skip debug instructions and other metadata
381+
if (MI.isDebugInstr())
382+
continue;
383+
384+
// Todo: Add operand/argument contributions
385+
386+
// Store the instruction embedding
387+
auto InstVector = Vocab[MI.getOpcode()];
388+
MInstVecMap[&MI] = InstVector;
389+
MBBVector += InstVector;
390+
}
391+
392+
// Store the basic block embedding
393+
MBBVecMap[&MBB] = MBBVector;
394+
}
395+
396+
//===----------------------------------------------------------------------===//
397+
// Printer Passes
271398
//===----------------------------------------------------------------------===//
272399

273400
char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -304,3 +431,67 @@ MachineFunctionPass *
304431
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
305432
return new MIR2VecVocabPrinterLegacyPass(OS);
306433
}
434+
435+
char MIR2VecPrinterLegacyPass::ID = 0;
436+
INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
437+
"MIR2Vec Embedder Printer Pass", false, true)
438+
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
439+
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
440+
INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
441+
"MIR2Vec Embedder Printer Pass", false, true)
442+
443+
/// Prints the MIR2Vec embeddings of \p MF to the pass's output stream: first
/// the function vector, then one vector per basic block, then one vector per
/// non-debug instruction. Diagnostics for an invalid vocabulary or a failed
/// embedder creation are written to the same stream.
///
/// \returns false in all cases.
bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
  auto &VocabAnalysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
  auto Vocabulary =
      VocabAnalysis.getMIR2VecVocabulary(*MF.getFunction().getParent());

  if (!Vocabulary.isValid()) {
    OS << "MIR2Vec Embedder Printer: Invalid vocabulary for function "
       << MF.getName() << "\n";
    return false;
  }

  auto Embedder =
      mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, Vocabulary);
  if (!Embedder) {
    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
       << "\n";
    return false;
  }

  // Function-level vector.
  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
  OS << "Machine Function vector: ";
  Embedder->getMFunctionVector().print(OS);

  // Block-level vectors, printed in the function's block order rather than
  // in map iteration order.
  OS << "Machine basic block vectors:\n";
  const auto &BlockVectors = Embedder->getMBBVecMap();
  for (const MachineBasicBlock &MBB : MF) {
    const auto Entry = BlockVectors.find(&MBB);
    if (Entry == BlockVectors.end())
      continue;
    OS << "Machine basic block: " << MBB.getFullName() << ":\n";
    Entry->second.print(OS);
  }

  // Instruction-level vectors, again in program order.
  OS << "Machine instruction vectors:\n";
  const auto &InstVectors = Embedder->getMInstVecMap();
  for (const MachineBasicBlock &MBB : MF) {
    for (const MachineInstr &MI : MBB) {
      // Debug instructions are not embedded, so there is nothing to print.
      if (MI.isDebugInstr())
        continue;

      const auto Entry = InstVectors.find(&MI);
      if (Entry == InstVectors.end())
        continue;
      OS << "Machine instruction: ";
      MI.print(OS);
      Entry->second.print(OS);
    }
  }

  return false;
}
494+
495+
/// Creates a machine function pass that prints MIR2Vec embeddings to \p OS.
/// The returned pass is heap-allocated and handed to the caller as a raw
/// pointer.
MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
  MachineFunctionPass *Printer = new MIR2VecPrinterLegacyPass(OS);
  return Printer;
}

0 commit comments

Comments
 (0)