@@ -41,11 +41,18 @@ static cl::opt<std::string>
41
41
// Command-line knob holding the weight for machine opcode embeddings. The
// option is optional and starts out at 1.0.
cl::opt<float> OpcWeight(" mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                         cl::desc(" Weight for machine opcode embeddings"),
                         cl::cat(MIR2VecCategory));
44
+ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind (
45
+ " mir2vec-kind" , cl::Optional,
46
+ cl::values (clEnumValN(MIR2VecKind::Symbolic, " symbolic" ,
47
+ " Generate symbolic embeddings for MIR" )),
48
+ cl::init(MIR2VecKind::Symbolic), cl::desc(" MIR2Vec embedding kind" ),
49
+ cl::cat(MIR2VecCategory));
50
+
44
51
} // namespace mir2vec
45
52
} // namespace llvm
46
53
47
54
// ===----------------------------------------------------------------------===//
48
- // Vocabulary Implementation
55
+ // Vocabulary
49
56
// ===----------------------------------------------------------------------===//
50
57
51
58
MIRVocabulary::MIRVocabulary (VocabMap &&OpcodeEntries,
@@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
190
197
<< " unique base opcodes\n " );
191
198
}
192
199
200
+ MIRVocabulary MIRVocabulary::createDummyVocabForTest (const TargetInstrInfo &TII,
201
+ unsigned Dim) {
202
+ assert (Dim > 0 && " Dimension must be greater than zero" );
203
+
204
+ float DummyVal = 0 .1f ;
205
+
206
+ // Create a temporary vocabulary instance to build canonical mapping
207
+ MIRVocabulary TempVocab ({}, &TII);
208
+ TempVocab.buildCanonicalOpcodeMapping ();
209
+
210
+ // Create dummy embeddings for all canonical opcode names
211
+ VocabMap DummyVocabMap;
212
+ for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames ) {
213
+ // Create dummy embedding filled with DummyVal
214
+ Embedding DummyEmbedding (Dim, DummyVal);
215
+ DummyVocabMap[COpcodeName] = DummyEmbedding;
216
+ DummyVal += 0 .1f ;
217
+ }
218
+
219
+ // Create and return vocabulary with dummy embeddings
220
+ return MIRVocabulary (std::move (DummyVocabMap), &TII);
221
+ }
222
+
193
223
// ===----------------------------------------------------------------------===//
194
224
// MIR2VecVocabLegacyAnalysis Implementation
195
225
// ===----------------------------------------------------------------------===//
@@ -267,7 +297,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
267
297
}
268
298
269
299
// ===----------------------------------------------------------------------===//
270
- // Printer Passes Implementation
300
+ // MIREmbedder and its subclasses
301
+ // ===----------------------------------------------------------------------===//
302
+
303
+ MIREmbedder::MIREmbedder (const MachineFunction &MF, const MIRVocabulary &Vocab)
304
+ : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
305
+ OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {}
306
+
307
+ std::unique_ptr<MIREmbedder> MIREmbedder::create (MIR2VecKind Mode,
308
+ const MachineFunction &MF,
309
+ const MIRVocabulary &Vocab) {
310
+ switch (Mode) {
311
+ case MIR2VecKind::Symbolic:
312
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
313
+ }
314
+ return nullptr ;
315
+ }
316
+
317
+ const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap () const {
318
+ if (MInstVecMap.empty ())
319
+ computeEmbeddings ();
320
+ return MInstVecMap;
321
+ }
322
+
323
+ const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap () const {
324
+ if (MBBVecMap.empty ())
325
+ computeEmbeddings ();
326
+ return MBBVecMap;
327
+ }
328
+
329
+ const Embedding &MIREmbedder::getMBBVector (const MachineBasicBlock &BB) const {
330
+ auto It = MBBVecMap.find (&BB);
331
+ if (It != MBBVecMap.end ())
332
+ return It->second ;
333
+ computeEmbeddings (BB);
334
+ return MBBVecMap[&BB];
335
+ }
336
+
337
+ const Embedding &MIREmbedder::getMFunctionVector () const {
338
+ // Currently, we always (re)compute the embeddings for the function.
339
+ // This is cheaper than caching the vector.
340
+ computeEmbeddings ();
341
+ return MFuncVector;
342
+ }
343
+
344
+ void MIREmbedder::computeEmbeddings () const {
345
+ // Reset function vector to zero before recomputing
346
+ MFuncVector = Embedding (Dimension, 0.0 );
347
+
348
+ // Consider all machine basic blocks in the function
349
+ for (const auto &MBB : MF) {
350
+ computeEmbeddings (MBB);
351
+ MFuncVector += MBBVecMap[&MBB];
352
+ }
353
+ }
354
+
355
+ SymbolicMIREmbedder::SymbolicMIREmbedder (const MachineFunction &MF,
356
+ const MIRVocabulary &Vocab)
357
+ : MIREmbedder(MF, Vocab) {}
358
+
359
+ std::unique_ptr<SymbolicMIREmbedder>
360
+ SymbolicMIREmbedder::create (const MachineFunction &MF,
361
+ const MIRVocabulary &Vocab) {
362
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
363
+ }
364
+
365
+ void SymbolicMIREmbedder::computeEmbeddings (
366
+ const MachineBasicBlock &MBB) const {
367
+ Embedding MBBVector (Dimension, 0 );
368
+
369
+ // Get instruction info for opcode name resolution
370
+ const auto &Subtarget = MF.getSubtarget ();
371
+ const auto *TII = Subtarget.getInstrInfo ();
372
+ if (!TII) {
373
+ MF.getFunction ().getContext ().emitError (
374
+ " MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" );
375
+ return ;
376
+ }
377
+
378
+ // Process each machine instruction in the basic block
379
+ for (const auto &MI : MBB) {
380
+ // Skip debug instructions and other metadata
381
+ if (MI.isDebugInstr ())
382
+ continue ;
383
+
384
+ // Todo: Add operand/argument contributions
385
+
386
+ // Store the instruction embedding
387
+ auto InstVector = Vocab[MI.getOpcode ()];
388
+ MInstVecMap[&MI] = InstVector;
389
+ MBBVector += InstVector;
390
+ }
391
+
392
+ // Store the basic block embedding
393
+ MBBVecMap[&MBB] = MBBVector;
394
+ }
395
+
396
+ // ===----------------------------------------------------------------------===//
397
+ // Printer Passes
271
398
// ===----------------------------------------------------------------------===//
272
399
273
400
char MIR2VecVocabPrinterLegacyPass::ID = 0 ;
@@ -304,3 +431,67 @@ MachineFunctionPass *
304
431
llvm::createMIR2VecVocabPrinterLegacyPass (raw_ostream &OS) {
305
432
return new MIR2VecVocabPrinterLegacyPass (OS);
306
433
}
434
+
435
+ char MIR2VecPrinterLegacyPass::ID = 0 ;
436
+ INITIALIZE_PASS_BEGIN (MIR2VecPrinterLegacyPass, " print-mir2vec" ,
437
+ " MIR2Vec Embedder Printer Pass" , false , true )
438
+ INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
439
+ INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
440
+ INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, " print-mir2vec" ,
441
+ " MIR2Vec Embedder Printer Pass" , false , true )
442
+
443
+ bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
444
+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
445
+ auto MIRVocab = Analysis.getMIR2VecVocabulary (*MF.getFunction ().getParent ());
446
+
447
+ if (!MIRVocab.isValid ()) {
448
+ OS << " MIR2Vec Embedder Printer: Invalid vocabulary for function "
449
+ << MF.getName () << " \n " ;
450
+ return false ;
451
+ }
452
+
453
+ auto Emb = mir2vec::MIREmbedder::create (MIR2VecEmbeddingKind, MF, MIRVocab);
454
+ if (!Emb) {
455
+ OS << " Error creating MIR2Vec embeddings for function " << MF.getName ()
456
+ << " \n " ;
457
+ return false ;
458
+ }
459
+
460
+ OS << " MIR2Vec embeddings for machine function " << MF.getName () << " :\n " ;
461
+ OS << " Machine Function vector: " ;
462
+ Emb->getMFunctionVector ().print (OS);
463
+
464
+ OS << " Machine basic block vectors:\n " ;
465
+ const auto &MBBMap = Emb->getMBBVecMap ();
466
+ for (const MachineBasicBlock &MBB : MF) {
467
+ auto It = MBBMap.find (&MBB);
468
+ if (It != MBBMap.end ()) {
469
+ OS << " Machine basic block: " << MBB.getFullName () << " :\n " ;
470
+ It->second .print (OS);
471
+ }
472
+ }
473
+
474
+ OS << " Machine instruction vectors:\n " ;
475
+ const auto &MInstMap = Emb->getMInstVecMap ();
476
+ for (const MachineBasicBlock &MBB : MF) {
477
+ for (const MachineInstr &MI : MBB) {
478
+ // Skip debug instructions as they are not
479
+ // embedded
480
+ if (MI.isDebugInstr ())
481
+ continue ;
482
+
483
+ auto It = MInstMap.find (&MI);
484
+ if (It != MInstMap.end ()) {
485
+ OS << " Machine instruction: " ;
486
+ MI.print (OS);
487
+ It->second .print (OS);
488
+ }
489
+ }
490
+ }
491
+
492
+ return false ;
493
+ }
494
+
495
+ MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass (raw_ostream &OS) {
496
+ return new MIR2VecPrinterLegacyPass (OS);
497
+ }
0 commit comments