Skip to content

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Oct 6, 2025

Implement MIR2Vec embedder for generating vector representations of Machine IR instructions, basic blocks, and functions. This patch introduces changes necessary to embed machine opcodes. Machine operands would be handled incrementally in the upcoming patches.

Copy link

github-actions bot commented Oct 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@svkeerthy svkeerthy changed the title MIR2Vec embedding [MIR2Vec] Add embedder for machine instructions Oct 6, 2025
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from b7a8ff7 to dc7d70a Compare October 6, 2025 21:30
@svkeerthy svkeerthy marked this pull request as ready for review October 6, 2025 21:31
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from dc7d70a to 5accfaf Compare October 6, 2025 22:38
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 0ebc739 to 5778be1 Compare October 6, 2025 22:38
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from 5accfaf to 7bdb214 Compare October 6, 2025 22:39
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 5778be1 to ffc59e7 Compare October 6, 2025 22:39
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from 7bdb214 to 699ab74 Compare October 6, 2025 23:29
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from ffc59e7 to 8a80321 Compare October 6, 2025 23:29
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from 699ab74 to e28efcc Compare October 6, 2025 23:35
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 8a80321 to 4864150 Compare October 6, 2025 23:35
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from e28efcc to 6fa7a86 Compare October 7, 2025 20:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 4864150 to 671c686 Compare October 7, 2025 20:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from 6fa7a86 to c58a415 Compare October 7, 2025 20:10
@svkeerthy svkeerthy requested a review from mtrofin October 7, 2025 20:12
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 671c686 to 6a43c34 Compare October 7, 2025 20:48
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from c58a415 to 9bd4b36 Compare October 7, 2025 20:48
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch 2 times, most recently from 8e371bd to 4d87a59 Compare October 7, 2025 22:29
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from 9bd4b36 to fa57365 Compare October 7, 2025 22:29
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-02-mirvocabulary_changes branch from 4d87a59 to 0afa19d Compare October 7, 2025 23:07
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from fa57365 to f70bce3 Compare October 7, 2025 23:07
// Create dummy embedding filled with DummyVal
Embedding DummyEmbedding(Dim, DummyVal);
DummyVocabMap[COpcodeName] = DummyEmbedding;
DummyVal += 0.1f;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could also, at the end, just DummyVal *= TempVocab.UnoqueBaseOpcodeNames.size()?

}

const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const {
if (MInstVecMap.empty())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the caching necessary - didn't we find on the IR side that caching actually hurts performance?

Base automatically changed from users/svkeerthy/10-02-mirvocabulary_changes to main October 7, 2025 23:44
@svkeerthy svkeerthy force-pushed the users/svkeerthy/10-06-mir2vec_embedding branch from f70bce3 to c1747fb Compare October 7, 2025 23:46
@llvmbot llvmbot added the mlgo label Oct 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 7, 2025

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Implement MIR2Vec embedder for generating vector representations of Machine IR instructions, basic blocks, and functions. This patch introduces changes necessary to embed machine opcodes. Machine operands would be handled incrementally in the upcoming patches.


Patch is 39.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162161.diff

10 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/MIR2Vec.h (+108)
  • (modified) llvm/include/llvm/CodeGen/Passes.h (+4)
  • (modified) llvm/include/llvm/InitializePasses.h (+1)
  • (modified) llvm/lib/CodeGen/CodeGen.cpp (+1)
  • (modified) llvm/lib/CodeGen/MIR2Vec.cpp (+193-2)
  • (added) llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json (+22)
  • (added) llvm/test/CodeGen/MIR2Vec/if-else.mir (+144)
  • (added) llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir (+76)
  • (modified) llvm/tools/llc/llc.cpp (+15)
  • (modified) llvm/unittests/CodeGen/MIR2VecTest.cpp (+299-25)
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..ebafe4ccddff3 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -51,11 +51,21 @@ class LLVMContext;
 class MIR2VecVocabLegacyAnalysis;
 class TargetInstrInfo;
 
+enum class MIR2VecKind { Symbolic };
+
 namespace mir2vec {
+
+// Forward declarations
+class MIREmbedder;
+class SymbolicMIREmbedder;
+
 extern llvm::cl::OptionCategory MIR2VecCategory;
 extern cl::opt<float> OpcWeight;
 
 using Embedding = ir2vec::Embedding;
+using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
+using MachineBlockEmbeddingsMap =
+    DenseMap<const MachineBasicBlock *, Embedding>;
 
 /// Class for storing and accessing the MIR2Vec vocabulary.
 /// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -132,6 +142,79 @@ class MIRVocabulary {
     assert(isValid() && "Invalid vocabulary");
     return Storage.size();
   }
+
+  /// Create a dummy vocabulary for testing purposes.
+  static MIRVocabulary createDummyVocabForTest(const TargetInstrInfo &TII,
+                                               unsigned Dim = 1);
+};
+
+/// Base class for MIR embedders
+class MIREmbedder {
+protected:
+  const MachineFunction &MF;
+  const MIRVocabulary &Vocab;
+
+  /// Dimension of the embeddings; Captured from the vocabulary
+  const unsigned Dimension;
+
+  /// Weight for opcode embeddings
+  const float OpcWeight;
+
+  // Utility maps - these are used to store the vector representations of
+  // instructions, basic blocks and functions.
+  mutable Embedding MFuncVector;
+  mutable MachineBlockEmbeddingsMap MBBVecMap;
+  mutable MachineInstEmbeddingsMap MInstVecMap;
+
+  MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
+
+  /// Function to compute embeddings. It generates embeddings for all
+  /// the instructions and basic blocks in the function F.
+  void computeEmbeddings() const;
+
+  /// Function to compute the embedding for a given basic block.
+  /// Specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings(const MachineBasicBlock &MBB) const = 0;
+
+public:
+  virtual ~MIREmbedder() = default;
+
+  /// Factory method to create an Embedder object of the specified kind
+  /// Returns nullptr if the requested kind is not supported.
+  static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
+                                             const MachineFunction &MF,
+                                             const MIRVocabulary &Vocab);
+
+  /// Returns a map containing machine instructions and the corresponding
+  /// embeddings for the machine function MF if it has been computed. If not, it
+  /// computes the embeddings for MF and returns the map.
+  const MachineInstEmbeddingsMap &getMInstVecMap() const;
+
+  /// Returns a map containing machine basic block and the corresponding
+  /// embeddings for the machine function MF if it has been computed. If not, it
+  /// computes the embeddings for MF and returns the map.
+  const MachineBlockEmbeddingsMap &getMBBVecMap() const;
+
+  /// Returns the embedding for a given machine basic block in the machine
+  /// function MF if it has been computed. If not, it computes the embedding for
+  /// MBB and returns it.
+  const Embedding &getMBBVector(const MachineBasicBlock &MBB) const;
+
+  /// Computes and returns the embedding for the current machine function.
+  const Embedding &getMFunctionVector() const;
+};
+
+/// Class for computing Symbolic embeddings
+/// Symbolic embeddings are constructed based on the entity-level
+/// representations obtained from the MIR Vocabulary.
+class SymbolicMIREmbedder : public MIREmbedder {
+private:
+  void computeEmbeddings(const MachineBasicBlock &MBB) const override;
+
+public:
+  SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab);
+  static std::unique_ptr<SymbolicMIREmbedder>
+  create(const MachineFunction &MF, const MIRVocabulary &Vocab);
 };
 
 } // namespace mir2vec
@@ -181,6 +264,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
   }
 };
 
+/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
+/// and instructions
+class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
+      : MachineFunctionPass(ID), OS(OS) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MIR2VecVocabLegacyAnalysis>();
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override {
+    return "MIR2Vec Embedder Printer Pass";
+  }
+};
+
+/// Create a machine pass that prints MIR2Vec embeddings
+MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 272b4acf950c5..7fae550d8d170 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
 LLVM_ABI MachineFunctionPass *
 createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
 
+/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
+/// machine functions, basic blocks and instructions.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 /// StackFramePrinter pass - This pass prints out the machine function's
 /// stack frame to the given stream as a debugging tool.
 LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index cd774e7888e64..d507ba267d791 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -222,6 +222,7 @@ LLVM_ABI void
 initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index c438eaeb29d1e..9795a0b707fd3 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
   initializeMachineUniformityAnalysisPassPass(Registry);
   initializeMIR2VecVocabLegacyAnalysisPass(Registry);
   initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
+  initializeMIR2VecPrinterLegacyPassPass(Registry);
   initializeMachineUniformityInfoPrinterPassPass(Registry);
   initializeMachineVerifierLegacyPassPass(Registry);
   initializeObjCARCContractLegacyPassPass(Registry);
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..ffd076331c7d2 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -41,11 +41,18 @@ static cl::opt<std::string>
 cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                          cl::desc("Weight for machine opcode embeddings"),
                          cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+    "mir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings for MIR")),
+    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+    cl::cat(MIR2VecCategory));
+
 } // namespace mir2vec
 } // namespace llvm
 
 //===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
                     << " unique base opcodes\n");
 }
 
+MIRVocabulary MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+                                                     unsigned Dim) {
+  assert(Dim > 0 && "Dimension must be greater than zero");
+
+  float DummyVal = 0.1f;
+
+  // Create a temporary vocabulary instance to build canonical mapping
+  MIRVocabulary TempVocab({}, &TII);
+  TempVocab.buildCanonicalOpcodeMapping();
+
+  // Create dummy embeddings for all canonical opcode names
+  VocabMap DummyVocabMap;
+  for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) {
+    // Create dummy embedding filled with DummyVal
+    Embedding DummyEmbedding(Dim, DummyVal);
+    DummyVocabMap[COpcodeName] = DummyEmbedding;
+    DummyVal += 0.1f;
+  }
+
+  // Create and return vocabulary with dummy embeddings
+  return MIRVocabulary(std::move(DummyVocabMap), &TII);
+}
+
 //===----------------------------------------------------------------------===//
 // MIR2VecVocabLegacyAnalysis Implementation
 //===----------------------------------------------------------------------===//
@@ -267,7 +297,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+MIREmbedder::MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
+    : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
+      OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {}
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+                                                 const MachineFunction &MF,
+                                                 const MIRVocabulary &Vocab) {
+  switch (Mode) {
+  case MIR2VecKind::Symbolic:
+    return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+  }
+  return nullptr;
+}
+
+const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const {
+  if (MInstVecMap.empty())
+    computeEmbeddings();
+  return MInstVecMap;
+}
+
+const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap() const {
+  if (MBBVecMap.empty())
+    computeEmbeddings();
+  return MBBVecMap;
+}
+
+const Embedding &MIREmbedder::getMBBVector(const MachineBasicBlock &BB) const {
+  auto It = MBBVecMap.find(&BB);
+  if (It != MBBVecMap.end())
+    return It->second;
+  computeEmbeddings(BB);
+  return MBBVecMap[&BB];
+}
+
+const Embedding &MIREmbedder::getMFunctionVector() const {
+  // Currently, we always (re)compute the embeddings for the function.
+  // This is cheaper than caching the vector.
+  computeEmbeddings();
+  return MFuncVector;
+}
+
+void MIREmbedder::computeEmbeddings() const {
+  // Reset function vector to zero before recomputing
+  MFuncVector = Embedding(Dimension, 0.0);
+
+  // Consider all machine basic blocks in the function
+  for (const auto &MBB : MF) {
+    computeEmbeddings(MBB);
+    MFuncVector += MBBVecMap[&MBB];
+  }
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+                                         const MIRVocabulary &Vocab)
+    : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+                            const MIRVocabulary &Vocab) {
+  return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+void SymbolicMIREmbedder::computeEmbeddings(
+    const MachineBasicBlock &MBB) const {
+  Embedding MBBVector(Dimension, 0);
+
+  // Get instruction info for opcode name resolution
+  const auto &Subtarget = MF.getSubtarget();
+  const auto *TII = Subtarget.getInstrInfo();
+  if (!TII) {
+    MF.getFunction().getContext().emitError(
+        "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+    return;
+  }
+
+  // Process each machine instruction in the basic block
+  for (const auto &MI : MBB) {
+    // Skip debug instructions and other metadata
+    if (MI.isDebugInstr())
+      continue;
+
+    // Todo: Add operand/argument contributions
+
+    // Store the instruction embedding
+    auto InstVector = Vocab[MI.getOpcode()];
+    MInstVecMap[&MI] = InstVector;
+    MBBVector += InstVector;
+  }
+
+  // Store the basic block embedding
+  MBBVecMap[&MBB] = MBBVector;
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
 //===----------------------------------------------------------------------===//
 
 char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -304,3 +431,67 @@ MachineFunctionPass *
 llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
   return new MIR2VecVocabPrinterLegacyPass(OS);
 }
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                      "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                    "MIR2Vec Embedder Printer Pass", false, true)
+
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+  auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+  auto MIRVocab = Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+
+  if (!MIRVocab.isValid()) {
+    OS << "MIR2Vec Embedder Printer: Invalid vocabulary for function "
+       << MF.getName() << "\n";
+    return false;
+  }
+
+  auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+  if (!Emb) {
+    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+       << "\n";
+    return false;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+  OS << "Machine Function vector: ";
+  Emb->getMFunctionVector().print(OS);
+
+  OS << "Machine basic block vectors:\n";
+  const auto &MBBMap = Emb->getMBBVecMap();
+  for (const MachineBasicBlock &MBB : MF) {
+    auto It = MBBMap.find(&MBB);
+    if (It != MBBMap.end()) {
+      OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+      It->second.print(OS);
+    }
+  }
+
+  OS << "Machine instruction vectors:\n";
+  const auto &MInstMap = Emb->getMInstVecMap();
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions as they are not
+      // embedded
+      if (MI.isDebugInstr())
+        continue;
+
+      auto It = MInstMap.find(&MI);
+      if (It != MInstMap.end()) {
+        OS << "Machine instruction: ";
+        MI.print(OS);
+        It->second.print(OS);
+      }
+    }
+  }
+
+  return false;
+}
+
+MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
+  return new MIR2VecPrinterLegacyPass(OS);
+}
diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
new file mode 100644
index 0000000000000..5de715bf80917
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json
@@ -0,0 +1,22 @@
+{
+  "entities": {
+    "KILL": [0.1, 0.2, 0.3],
+    "MOV": [0.4, 0.5, 0.6],
+    "LEA": [0.7, 0.8, 0.9],
+    "RET": [1.0, 1.1, 1.2],
+    "ADD": [1.3, 1.4, 1.5],
+    "SUB": [1.6, 1.7, 1.8],
+    "IMUL": [1.9, 2.0, 2.1],
+    "AND": [2.2, 2.3, 2.4],
+    "OR": [2.5, 2.6, 2.7],
+    "XOR": [2.8, 2.9, 3.0],
+    "CMP": [3.1, 3.2, 3.3],
+    "TEST": [3.4, 3.5, 3.6],
+    "JMP": [3.7, 3.8, 3.9],
+    "CALL": [4.0, 4.1, 4.2],
+    "PUSH": [4.3, 4.4, 4.5],
+    "POP": [4.6, 4.7, 4.8],
+    "NOP": [4.9, 5.0, 5.1],
+    "COPY": [5.2, 5.3, 5.4]
+  }
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir
new file mode 100644
index 0000000000000..2accf476f7c4d
--- /dev/null
+++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir
@@ -0,0 +1,144 @@
+# REQUIRES: x86_64-linux
+# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  
+  define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %retval = alloca i32, align 4
+    %a.addr = alloca i32, align 4
+    %b.addr = alloca i32, align 4
+    store i32 %a, ptr %a.addr, align 4
+    store i32 %b, ptr %b.addr, align 4
+    %0 = load i32, ptr %a.addr, align 4
+    %1 = load i32, ptr %b.addr, align 4
+    %cmp = icmp sgt i32 %0, %1
+    br i1 %cmp, label %if.then, label %if.else
+  
+  if.then:                                          ; preds = %entry
+    %2 = load i32, ptr %b.addr, align 4
+    store i32 %2, ptr %retval, align 4
+    br label %return
+  
+  if.else:                                          ; preds = %entry
+    %3 = load i32, ptr %a.addr, align 4
+    store i32 %3, ptr %retval, align 4
+    br label %return
+  
+  return:                                           ; preds = %if.else, %if.then
+    %4 = load i32, ptr %retval, align 4
+    ret i32 %4
+  }
+...
+---
+name:            abc
+alignment:       16
+exposesReturnsTwice: false
+legalized:       false
+regBankSelected: false
+selected:        false
+failedISel:      false
+tracksRegLiveness: true
+hasWinCFI:       false
+noPhis:          false
+isSSA:           true
+noVRegs:         false
+hasFakeUses:     false
+callsEHReturn:   false
+callsUnwindInit: false
+hasEHContTarget: false
+hasEHScopes:     false
+hasEHFunclets:   false
+isOutlined:      false
+debugInstrRef:   true
+failsVerification: false
+tracksDebugUserValues: false
+registers:
+  - { id: 0, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 1, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 2, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 3, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 4, class: gr32, preferred-register: '', flags: [  ] }
+  - { id: 5, class: gr32, preferred-register: '', flags: [  ] }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+frameInfo:
+  isFrameAddressTaken: false
+  isReturnAddressTaken: false
+  hasStackMap:     false
+  hasPatchPoint:   false
+  stackSize:       0
+  offsetAdjustment: 0
+  maxAlignment:    4
+  adjustsStack:    false
+  hasCalls:        false
+  stackProtector:  ''
+  functionContext: ''
+  maxCallFrameSize: 4294967295
+  cvBytesOfCalleeSavedRegisters: 0
+  hasOpaqueSPAdjustment: false
+  hasVAStart:      false
+  hasMustTailInVarArgFunc: false
+  hasTailCall:     false
+  isCalleeSavedInfoValid: false
+  localFrameSize:  0
+fixedStack:      []
+stack:
+  - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+  - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4, 
+      stack-id: default, callee-saved-register: '', callee-saved-restored: true, 
+      debug-info-variable: '', debug-info-expression: '', debug-info-location: '' }
+entry_values:    []
+callSites:       []
+debugValueSubstitutions: []
+constants:       []
+machineFunctionInfo:
+  amxProgModel:    None
+body:             |
+  bb.0.entry:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+    liveins: $edi, $esi
+  
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr)
+    MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr)
+    %2:gr32 = SUB32rr %0, %1, implicit-def $eflags
+    JCC_1 %bb.2, 14, impli...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants