diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index 7b1b5d9aee15d..f6b0571f5dac6 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -52,11 +52,21 @@ class LLVMContext; class MIR2VecVocabLegacyAnalysis; class TargetInstrInfo; +enum class MIR2VecKind { Symbolic }; + namespace mir2vec { + +// Forward declarations +class MIREmbedder; +class SymbolicMIREmbedder; + extern llvm::cl::OptionCategory MIR2VecCategory; extern cl::opt OpcWeight; using Embedding = ir2vec::Embedding; +using MachineInstEmbeddingsMap = DenseMap; +using MachineBlockEmbeddingsMap = + DenseMap; /// Class for storing and accessing the MIR2Vec vocabulary. /// The MIRVocabulary class manages seed embeddings for LLVM Machine IR @@ -107,19 +117,91 @@ class MIRVocabulary { const_iterator end() const { return Storage.end(); } - /// Total number of entries in the vocabulary - size_t getCanonicalSize() const { return Storage.size(); } - MIRVocabulary() = delete; /// Factory method to create MIRVocabulary from vocabulary map static Expected create(VocabMap &&Entries, const TargetInstrInfo &TII); + /// Create a dummy vocabulary for testing purposes. + static Expected + createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1); + + /// Total number of entries in the vocabulary + size_t getCanonicalSize() const { return Storage.size(); } + private: MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII); }; +/// Base class for MIR embedders +class MIREmbedder { +protected: + const MachineFunction &MF; + const MIRVocabulary &Vocab; + + /// Dimension of the embeddings; Captured from the vocabulary + const unsigned Dimension; + + /// Weight for opcode embeddings + const float OpcWeight; + + MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab) + : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()), + OpcWeight(mir2vec::OpcWeight) {} + + /// Function to compute embeddings. 
+ Embedding computeEmbeddings() const; + + /// Function to compute the embedding for a given machine basic block. + Embedding computeEmbeddings(const MachineBasicBlock &MBB) const; + + /// Function to compute the embedding for a given machine instruction. + /// Specific to the kind of embeddings being computed. + virtual Embedding computeEmbeddings(const MachineInstr &MI) const = 0; + +public: + virtual ~MIREmbedder() = default; + + /// Factory method to create an Embedder object of the specified kind + /// Returns nullptr if the requested kind is not supported. + static std::unique_ptr create(MIR2VecKind Mode, + const MachineFunction &MF, + const MIRVocabulary &Vocab); + + /// Computes and returns the embedding for a given machine instruction MI in + /// the machine function MF. + Embedding getMInstVector(const MachineInstr &MI) const { + return computeEmbeddings(MI); + } + + /// Computes and returns the embedding for a given machine basic block in the + /// machine function MF. + Embedding getMBBVector(const MachineBasicBlock &MBB) const { + return computeEmbeddings(MBB); + } + + /// Computes and returns the embedding for the current machine function. + Embedding getMFunctionVector() const { + // Currently, we always (re)compute the embeddings for the function. This is + // cheaper than caching the vector. + return computeEmbeddings(); + } +}; + +/// Class for computing Symbolic embeddings +/// Symbolic embeddings are constructed based on the entity-level +/// representations obtained from the MIR Vocabulary. 
+class SymbolicMIREmbedder : public MIREmbedder {
+private:
+  Embedding computeEmbeddings(const MachineInstr &MI) const override;
+
+public:
+  SymbolicMIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
+  static std::unique_ptr<SymbolicMIREmbedder>
+  create(const MachineFunction &MF, const MIRVocabulary &Vocab);
+};
+
 } // namespace mir2vec
 
 /// Pass to analyze and populate MIR2Vec vocabulary from a module
@@ -166,6 +248,33 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
   }
 };
 
+/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
+/// and instructions.
+class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
+      : MachineFunctionPass(ID), OS(OS) {}
+
+  /// Prints the function, basic-block and instruction embeddings of \p MF to
+  /// the stream supplied at construction.
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MIR2VecVocabLegacyAnalysis>();
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override {
+    return "MIR2Vec Embedder Printer Pass";
+  }
+};
+
+/// Create a machine pass that prints MIR2Vec embeddings
+MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 } // namespace llvm
 
 #endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 272b4acf950c5..7fae550d8d170 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
 LLVM_ABI MachineFunctionPass *
 createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
 
+/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
+/// machine functions, basic blocks and instructions.
+LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
+
 /// StackFramePrinter pass - This pass prints out the machine function's
 /// stack frame to the given stream as a debugging tool.
 LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index cd774e7888e64..d507ba267d791 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -222,6 +222,7 @@ LLVM_ABI void
 initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
 LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index c438eaeb29d1e..9795a0b707fd3 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
   initializeMachineUniformityAnalysisPassPass(Registry);
   initializeMIR2VecVocabLegacyAnalysisPass(Registry);
   initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
+  initializeMIR2VecPrinterLegacyPassPass(Registry);
   initializeMachineUniformityInfoPrinterPassPass(Registry);
   initializeMachineVerifierLegacyPassPass(Registry);
   initializeObjCARCContractLegacyPassPass(Registry);
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index e85976547a2c2..2df14a75bf623 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/IR/Module.h"
@@ -41,11 +42,18 @@ static cl::opt
 cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                          cl::desc("Weight for machine opcode embeddings"),
                          cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+    "mir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings for MIR")),
+    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+    cl::cat(MIR2VecCategory));
+
 } // namespace mir2vec
 } // namespace llvm
 
 //===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -191,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
                     << " unique base opcodes\n");
 }
 
+Expected<MIRVocabulary>
+MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+                                       unsigned Dim) {
+  assert(Dim > 0 && "Dimension must be greater than zero");
+
+  float DummyVal = 0.1f;
+
+  // Build a throwaway vocabulary only to obtain the canonical opcode names
+  MIRVocabulary TempVocab({}, TII);
+  TempVocab.buildCanonicalOpcodeMapping();
+
+  // Create dummy embeddings for all canonical opcode names
+  VocabMap DummyVocabMap;
+  for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) {
+    // Every component holds DummyVal, which grows by 0.1 per opcode so that
+    // consecutive entries receive different embeddings
+    DummyVocabMap[COpcodeName] = Embedding(Dim, DummyVal);
+    DummyVal += 0.1f;
+  }
+
+  // Create and return vocabulary with dummy embeddings
+  return MIRVocabulary::create(std::move(DummyVocabMap), TII);
+}
+
 //===----------------------------------------------------------------------===//
 // MIR2VecVocabLegacyAnalysis Implementation
 //===----------------------------------------------------------------------===//
@@ -261,7 +293,73
@@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+                                                 const MachineFunction &MF,
+                                                 const MIRVocabulary &Vocab) {
+  switch (Mode) {
+  case MIR2VecKind::Symbolic:
+    return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+  }
+  return nullptr;
+}
+
+Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
+  Embedding MBBVector(Dimension, 0);
+
+  // Require instruction info up front; without it no embedding is computed
+  const auto &Subtarget = MF.getSubtarget();
+  const auto *TII = Subtarget.getInstrInfo();
+  if (!TII) {
+    MF.getFunction().getContext().emitError(
+        "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+    return MBBVector;
+  }
+
+  // Sum the embeddings of the machine instructions in the basic block
+  for (const auto &MI : MBB) {
+    // Skip debug instructions and other metadata
+    if (MI.isDebugInstr())
+      continue;
+    MBBVector += computeEmbeddings(MI);
+  }
+
+  return MBBVector;
+}
+
+Embedding MIREmbedder::computeEmbeddings() const {
+  Embedding MFuncVector(Dimension, 0);
+
+  // Consider all reachable machine basic blocks in the function
+  for (const auto *MBB : depth_first(&MF))
+    MFuncVector += computeEmbeddings(*MBB);
+  return MFuncVector;
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+                                         const MIRVocabulary &Vocab)
+    : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+                            const MIRVocabulary &Vocab) {
+  return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
+  // Skip debug instructions and other metadata
+  if (MI.isDebugInstr())
+    return Embedding(Dimension, 0);
+
+  // TODO: Add operand/argument contributions
+
+  return Vocab[MI.getOpcode()];
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
 //===----------------------------------------------------------------------===//
 
 char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -300,3 +398,64 @@ MachineFunctionPass *
 llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
   return new MIR2VecVocabPrinterLegacyPass(OS);
 }
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                      "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+                    "MIR2Vec Embedder Printer Pass", false, true)
+
+/// Prints the function, per-block and per-instruction embeddings of \p MF to
+/// OS. Every path returns false.
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+  auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+  auto VocabOrErr =
+      Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+  // Report a vocabulary failure instead of asserting: with asserts compiled
+  // out, the failed result would be dereferenced below.
+  if (!VocabOrErr) {
+    OS << "MIR2Vec Embedder Printer: Failed to get vocabulary - "
+       << toString(VocabOrErr.takeError()) << "\n";
+    return false;
+  }
+  auto &MIRVocab = *VocabOrErr;
+
+  auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+  if (!Emb) {
+    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+       << "\n";
+    return false;
+  }
+
+  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+  OS << "Machine Function vector: ";
+  Emb->getMFunctionVector().print(OS);
+
+  OS << "Machine basic block vectors:\n";
+  for (const MachineBasicBlock &MBB : MF) {
+    OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+    Emb->getMBBVector(MBB).print(OS);
+  }
+
+  OS << "Machine instruction vectors:\n";
+  for (const MachineBasicBlock &MBB : MF) {
+    for (const MachineInstr &MI : MBB) {
+      // Skip debug instructions as they are not
+      // embedded
+      if (MI.isDebugInstr())
+        continue;
+
+      OS << "Machine instruction: ";
+      MI.print(OS);
+      Emb->getMInstVector(MI).print(OS);
+    }
+  }
+
+  return false;
+}
+
+MachineFunctionPass
*llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) { + return new MIR2VecPrinterLegacyPass(OS); +} diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json new file mode 100644 index 0000000000000..5de715bf80917 --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json @@ -0,0 +1,22 @@ +{ + "entities": { + "KILL": [0.1, 0.2, 0.3], + "MOV": [0.4, 0.5, 0.6], + "LEA": [0.7, 0.8, 0.9], + "RET": [1.0, 1.1, 1.2], + "ADD": [1.3, 1.4, 1.5], + "SUB": [1.6, 1.7, 1.8], + "IMUL": [1.9, 2.0, 2.1], + "AND": [2.2, 2.3, 2.4], + "OR": [2.5, 2.6, 2.7], + "XOR": [2.8, 2.9, 3.0], + "CMP": [3.1, 3.2, 3.3], + "TEST": [3.4, 3.5, 3.6], + "JMP": [3.7, 3.8, 3.9], + "CALL": [4.0, 4.1, 4.2], + "PUSH": [4.3, 4.4, 4.5], + "POP": [4.6, 4.7, 4.8], + "NOP": [4.9, 5.0, 5.1], + "COPY": [5.2, 5.3, 5.4] + } +} \ No newline at end of file diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir new file mode 100644 index 0000000000000..2accf476f7c4d --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir @@ -0,0 +1,144 @@ +# REQUIRES: x86_64-linux +# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + + define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) { + entry: + %retval = alloca i32, align 4 + %a.addr = alloca i32, align 4 + %b.addr = alloca i32, align 4 + store i32 %a, ptr %a.addr, align 4 + store i32 %b, ptr %b.addr, align 4 + %0 = load i32, ptr %a.addr, align 4 + %1 = load i32, ptr %b.addr, align 4 + %cmp = icmp sgt i32 %0, %1 + br i1 %cmp, label %if.then, label %if.else + + if.then: ; preds = %entry + %2 = load i32, ptr %b.addr, align 4 + store i32 %2, ptr %retval, align 4 + br label %return + + if.else: ; preds = %entry + %3 = load i32, ptr %a.addr, 
align 4 + store i32 %3, ptr %retval, align 4 + br label %return + + return: ; preds = %if.else, %if.then + %4 = load i32, ptr %retval, align 4 + ret i32 %4 + } +... +--- +name: abc +alignment: 16 +exposesReturnsTwice: false +legalized: false +regBankSelected: false +selected: false +failedISel: false +tracksRegLiveness: true +hasWinCFI: false +noPhis: false +isSSA: true +noVRegs: false +hasFakeUses: false +callsEHReturn: false +callsUnwindInit: false +hasEHContTarget: false +hasEHScopes: false +hasEHFunclets: false +isOutlined: false +debugInstrRef: true +failsVerification: false +tracksDebugUserValues: false +registers: + - { id: 0, class: gr32, preferred-register: '', flags: [ ] } + - { id: 1, class: gr32, preferred-register: '', flags: [ ] } + - { id: 2, class: gr32, preferred-register: '', flags: [ ] } + - { id: 3, class: gr32, preferred-register: '', flags: [ ] } + - { id: 4, class: gr32, preferred-register: '', flags: [ ] } + - { id: 5, class: gr32, preferred-register: '', flags: [ ] } +liveins: + - { reg: '$edi', virtual-reg: '%0' } + - { reg: '$esi', virtual-reg: '%1' } +frameInfo: + isFrameAddressTaken: false + isReturnAddressTaken: false + hasStackMap: false + hasPatchPoint: false + stackSize: 0 + offsetAdjustment: 0 + maxAlignment: 4 + adjustsStack: false + hasCalls: false + stackProtector: '' + functionContext: '' + maxCallFrameSize: 4294967295 + cvBytesOfCalleeSavedRegisters: 0 + hasOpaqueSPAdjustment: false + hasVAStart: false + hasMustTailInVarArgFunc: false + hasTailCall: false + isCalleeSavedInfoValid: false + localFrameSize: 0 +fixedStack: [] +stack: + - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, callee-saved-register: '', callee-saved-restored: true, + debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } + - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, callee-saved-register: '', callee-saved-restored: true, + 
debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } + - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, callee-saved-register: '', callee-saved-restored: true, + debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } +entry_values: [] +callSites: [] +debugValueSubstitutions: [] +constants: [] +machineFunctionInfo: + amxProgModel: None +body: | + bb.0.entry: + successors: %bb.1(0x40000000), %bb.2(0x40000000) + liveins: $edi, $esi + + %1:gr32 = COPY $esi + %0:gr32 = COPY $edi + MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr) + MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr) + %2:gr32 = SUB32rr %0, %1, implicit-def $eflags + JCC_1 %bb.2, 14, implicit $eflags + JMP_1 %bb.1 + + bb.1.if.then: + successors: %bb.3(0x80000000) + + %4:gr32 = MOV32rm %stack.2.b.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.b.addr) + MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %4 :: (store (s32) into %ir.retval) + JMP_1 %bb.3 + + bb.2.if.else: + successors: %bb.3(0x80000000) + + %3:gr32 = MOV32rm %stack.1.a.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.a.addr) + MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %3 :: (store (s32) into %ir.retval) + + bb.3.return: + %5:gr32 = MOV32rm %stack.0.retval, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.retval) + $eax = COPY %5 + RET 0, $eax +... 
+ +# CHECK: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: abc:entry: +# CHECK-NEXT: [ 16.50 17.10 17.70 ] +# CHECK-NEXT: Machine basic block: abc:if.then: +# CHECK-NEXT: [ 4.50 4.80 5.10 ] +# CHECK-NEXT: Machine basic block: abc:if.else: +# CHECK-NEXT: [ 0.80 1.00 1.20 ] +# CHECK-NEXT: Machine basic block: abc:return: +# CHECK-NEXT: [ 6.60 6.90 7.20 ] \ No newline at end of file diff --git a/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir new file mode 100644 index 0000000000000..44240affb2206 --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir @@ -0,0 +1,76 @@ +# REQUIRES: x86_64-linux +# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + + define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) { + entry: + %sum = add nsw i32 %a, %b + %result = mul nsw i32 %sum, 2 + ret i32 %result + } + + define dso_local void @simple_function() { + entry: + ret void + } +... 
+--- +name: add_function +alignment: 16 +tracksRegLiveness: true +registers: + - { id: 0, class: gr32 } + - { id: 1, class: gr32 } + - { id: 2, class: gr32 } + - { id: 3, class: gr32 } +liveins: + - { reg: '$edi', virtual-reg: '%0' } + - { reg: '$esi', virtual-reg: '%1' } +body: | + bb.0.entry: + liveins: $edi, $esi + + %1:gr32 = COPY $esi + %0:gr32 = COPY $edi + %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags + %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags + $eax = COPY %3 + RET 0, $eax + +--- +name: simple_function +alignment: 16 +tracksRegLiveness: true +body: | + bb.0.entry: + RET 0 + +# CHECK: MIR2Vec embeddings for machine function add_function: +# CHECK: Function vector: [ 19.20 19.80 20.40 ] +# CHECK-NEXT: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: add_function:entry: +# CHECK-NEXT: [ 19.20 19.80 20.40 ] +# CHECK-NEXT: Machine instruction vectors: +# CHECK-NEXT: Machine instruction: %1:gr32 = COPY $esi +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: %0:gr32 = COPY $edi +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: %2:gr32 = nsw ADD32rr %0:gr32(tied-def 0), %1:gr32, implicit-def dead $eflags +# CHECK-NEXT: [ 1.30 1.40 1.50 ] +# CHECK-NEXT: Machine instruction: %3:gr32 = ADD32rr %2:gr32(tied-def 0), %2:gr32, implicit-def dead $eflags +# CHECK-NEXT: [ 1.30 1.40 1.50 ] +# CHECK-NEXT: Machine instruction: $eax = COPY %3:gr32 +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: RET 0, $eax +# CHECK-NEXT: [ 1.00 1.10 1.20 ] + +# CHECK: MIR2Vec embeddings for machine function simple_function: +# CHECK-NEXT:Function vector: [ 1.00 1.10 1.20 ] +# CHECK-NEXT: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: simple_function:entry: +# CHECK-NEXT: [ 1.00 1.10 1.20 ] +# CHECK-NEXT: Machine instruction vectors: +# CHECK-NEXT: Machine instruction: RET 0 +# CHECK-NEXT: [ 1.00 1.10 1.20 ] \ No newline at end of file diff --git a/llvm/tools/llc/llc.cpp 
b/llvm/tools/llc/llc.cpp index f04b256e2e6c9..f4441ccb896b1 100644 --- a/llvm/tools/llc/llc.cpp +++ b/llvm/tools/llc/llc.cpp @@ -172,6 +172,11 @@ static cl::opt cl::desc("Print MIR2Vec vocabulary contents"), cl::init(false)); +static cl::opt + PrintMIR2Vec("print-mir2vec", cl::Hidden, + cl::desc("Print MIR2Vec embeddings for functions"), + cl::init(false)); + static cl::list IncludeDirs("I", cl::desc("include search path")); static cl::opt RemarksWithHotness( @@ -776,6 +781,11 @@ static int compileModule(char **argv, LLVMContext &Context) { PM.add(createMIR2VecVocabPrinterLegacyPass(errs())); } + // Add MIR2Vec printer if requested + if (PrintMIR2Vec) { + PM.add(createMIR2VecPrinterLegacyPass(errs())); + } + PM.add(createFreeMachineFunctionPass()); } else { if (Target->addPassesToEmitFile(PM, *OS, DwoOut ? &DwoOut->os() : nullptr, @@ -789,6 +799,11 @@ static int compileModule(char **argv, LLVMContext &Context) { if (PrintMIR2VecVocab) { PM.add(createMIR2VecVocabPrinterLegacyPass(errs())); } + + // Add MIR2Vec printer if requested + if (PrintMIR2Vec) { + PM.add(createMIR2VecPrinterLegacyPass(errs())); + } } Target->getObjFileLowering()->Initialize(MMIWP->getMMI().getContext(), diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 11222b4d02fa3..8cd9d5ac9f6be 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -82,6 +82,9 @@ class MIR2VecVocabTestFixture : public ::testing::Test { return; } + // Set the data layout to match the target machine + M->setDataLayout(TM->createDataLayout()); + // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = @@ -96,16 +99,27 @@ class MIR2VecVocabTestFixture : public ::testing::Test { } void TearDown() override { TII = nullptr; } -}; -// Function to find an opcode by name -static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { - for (unsigned Opcode 
= 1; Opcode < TII->getNumOpcodes(); ++Opcode) { - if (TII->getName(Opcode) == Name) - return Opcode; + // Find an opcode by name + int findOpcodeByName(StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found } - return -1; // Not found -} + + // Create a vocabulary with specific opcodes and embeddings + Expected + createTestVocab(std::initializer_list> opcodes, + unsigned dimension = 2) { + assert(TII && "TargetInstrInfo not initialized"); + VocabMap VMap; + for (const auto &[name, value] : opcodes) + VMap[name] = Embedding(dimension, value); + return MIRVocabulary::create(std::move(VMap), *TII); + } +}; TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices @@ -118,10 +132,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; Embedding Val = Embedding(64, 1.0f); - VMap["ADD"] = Val; - auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -156,16 +168,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + int Add32rrOpcode = findOpcodeByName("ADD32rr"); ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); - int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + int Sub32rrOpcode = findOpcodeByName("SUB32rr"); ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; EXPECT_TRUE( 
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); - int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + int Mov32rrOpcode = findOpcodeByName("MOV32rr"); ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; EXPECT_TRUE( TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); @@ -178,9 +190,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; - VMap["ADD"] = Embedding(64, 1.0f); - auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -202,11 +212,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - VocabMap VMap; - VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 - VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - - auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -243,4 +249,247 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { } } +// Fixture for embedding related tests +class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture
{ +protected: + std::unique_ptr MMI; + MachineFunction *MF = nullptr; + + void SetUp() override { + MIR2VecVocabTestFixture::SetUp(); + + // Create a dummy function for MachineFunction + FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); + Function *F = + Function::Create(FT, Function::ExternalLinkage, "test", M.get()); + + MMI = std::make_unique(TM.get()); + MF = &MMI->getOrCreateMachineFunction(*F); + } + + void TearDown() override { MIR2VecVocabTestFixture::TearDown(); } + + // Create a machine instruction + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) { + const MCInstrDesc &Desc = TII->get(Opcode); + // Create instruction - operands don't affect opcode-based embeddings + MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc); + return MI; + } + + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, + const char *OpcodeName) { + int Opcode = findOpcodeByName(OpcodeName); + if (Opcode == -1) + return nullptr; + return createMachineInstr(MBB, Opcode); + } + + void createMachineInstrs(MachineBasicBlock &MBB, + std::initializer_list Opcodes) { + for (const char *OpcodeName : Opcodes) { + MachineInstr *MI = createMachineInstr(MBB, OpcodeName); + ASSERT_TRUE(MI != nullptr); + } + } +}; + +// Test factory method for creating embedder +TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V); + EXPECT_NE(Emb, nullptr); +} + +TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) { + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Result = MIREmbedder::create(static_cast(-1), *MF, V); 
+ EXPECT_FALSE(static_cast(Result)); +} + +// Test SymbolicMIREmbedder with simple target opcodes +TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { + // Create a test vocabulary with specific values + auto VocabOrErr = createTestVocab( + { + {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0] + {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0] + {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0] + }, + 4); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + // Create a basic block using fixture's MF + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Use real X86 opcodes that should exist and not be pseudo + auto NoopInst = createMachineInstr(*MBB, "NOOP"); + ASSERT_TRUE(NoopInst != nullptr); + + auto RetInst = createMachineInstr(*MBB, "RET64"); + ASSERT_TRUE(RetInst != nullptr); + + auto TrapInst = createMachineInstr(*MBB, "TRAP"); + ASSERT_TRUE(TrapInst != nullptr); + + // Verify these are not pseudo instructions + ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction"; + ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction"; + ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction"; + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + auto NoopEmb = Embedder->getMInstVector(*NoopInst); + auto RetEmb = Embedder->getMInstVector(*RetInst); + auto TrapEmb = Embedder->getMInstVector(*TrapInst); + + // Verify embeddings match expected values (accounting for weight scaling) + float ExpectedWeight = mir2vec::OpcWeight; // Global weight from command line + EXPECT_TRUE(NoopEmb.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight))); + EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight))); + EXPECT_TRUE(TrapEmb.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight))); + + // Test 
basic block embedding (should be sum of instruction embeddings) + auto MBBVector = Embedder->getMBBVector(*MBB); + + // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] * + // weight = [6, 6, 6, 6] * weight + Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedMBBVector)); + + // Test function embedding (should equal MBB embedding since we have one MBB) + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector)); +} + +// Test embedder with multiple basic blocks +TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) { + // Create a test vocabulary + auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + // Create two basic blocks using fixture's MF + MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock(); + MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock(); + MF->push_back(MBB1); + MF->push_back(MBB2); + + createMachineInstrs(*MBB1, {"NOOP", "NOOP"}); + createMachineInstr(*MBB2, "TRAP"); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test basic block embeddings + auto MBB1Vector = Embedder->getMBBVector(*MBB1); + auto MBB2Vector = Embedder->getMBBVector(*MBB2); + + float ExpectedWeight = mir2vec::OpcWeight; + // BB1: NOOP + NOOP = 2 * ([1, 1] * weight) + Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector)); + + // BB2: TRAP = [2, 2] * weight + Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector)); + + // If both blocks were reachable this would be BB1 + BB2 = [4, 4] * weight. + // For now it should be just the first BB embedding, as the second BB + //
is unreachable + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector)); + + // Add a branch from BB1 to BB2 to make both reachable; now function embedding + // should be MBB1 + MBB2 + MBB1->addSuccessor(MBB2); + auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings + Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector; + EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector)); +} + +// Test embedder with empty basic block +TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { + + // Create an empty basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create embedder + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Embedder = SymbolicMIREmbedder::create(*MF, V); + ASSERT_TRUE(Embedder != nullptr); + + // Test that empty BB has zero embedding + auto MBBVector = Embedder->getMBBVector(*MBB); + Embedding ExpectedBBVector(2, 0.0f); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); + + // Function embedding should also be zero + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedBBVector)); +} + +// Test embedder with opcodes not in vocabulary +TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { + // Create a test vocabulary with limited entries + // SUB is intentionally not included + auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + // Create a basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Find opcodes + int AddOpcode = findOpcodeByName("ADD32rr"); + int SubOpcode = findOpcodeByName("SUB32rr"); 
+ + ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found"; + ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found"; + + // Create instructions + MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode); + MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + auto AddVector = Embedder->getMInstVector(*AddInstr); + auto SubVector = Embedder->getMInstVector(*SubInstr); + + float ExpectedWeight = mir2vec::OpcWeight; + // ADD should have the embedding from vocabulary + EXPECT_TRUE( + AddVector.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight))); + + // SUB should have zero embedding (not in vocabulary) + EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f))); + + // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0, + // 0.0] = [1.0, 1.0] * weight + const auto &MBBVector = Embedder->getMBBVector(*MBB); + Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); +} } // namespace