diff --git a/llvm/docs/CommandGuide/llvm-ir2vec.rst b/llvm/docs/CommandGuide/llvm-ir2vec.rst index fc590a6180316..55fe75d2084b1 100644 --- a/llvm/docs/CommandGuide/llvm-ir2vec.rst +++ b/llvm/docs/CommandGuide/llvm-ir2vec.rst @@ -1,5 +1,5 @@ -llvm-ir2vec - IR2Vec Embedding Generation Tool -============================================== +llvm-ir2vec - IR2Vec and MIR2Vec Embedding Generation Tool +=========================================================== .. program:: llvm-ir2vec @@ -11,9 +11,9 @@ SYNOPSIS DESCRIPTION ----------- -:program:`llvm-ir2vec` is a standalone command-line tool for IR2Vec. It -generates IR2Vec embeddings for LLVM IR and supports triplet generation -for vocabulary training. +:program:`llvm-ir2vec` is a standalone command-line tool for IR2Vec and MIR2Vec. +It generates embeddings for both LLVM IR and Machine IR (MIR) and supports +triplet generation for vocabulary training. The tool provides three main subcommands: @@ -23,23 +23,33 @@ The tool provides three main subcommands: 2. **entities**: Generates entity mapping files (entity2id.txt) for vocabulary training. -3. **embeddings**: Generates IR2Vec embeddings using a trained vocabulary +3. **embeddings**: Generates IR2Vec or MIR2Vec embeddings using a trained vocabulary at different granularity levels (instruction, basic block, or function). +The tool supports two operation modes: + +* **LLVM IR mode** (``--mode=llvm``): Process LLVM IR bitcode files and generate + IR2Vec embeddings +* **Machine IR mode** (``--mode=mir``): Process Machine IR (.mir) files and generate + MIR2Vec embeddings + The tool is designed to facilitate machine learning applications that work with -LLVM IR by converting the IR into numerical representations that can be used by -ML models. The `triplets` subcommand generates numeric IDs directly instead of string -triplets, streamlining the training data preparation workflow. +LLVM IR or Machine IR by converting them into numerical representations that can +be used by ML models. The `triplets` subcommand generates numeric IDs directly +instead of string triplets, streamlining the training data preparation workflow. .. note:: - For information about using IR2Vec programmatically within LLVM passes and - the C++ API, see the `IR2Vec Embeddings `_ + For information about using IR2Vec and MIR2Vec programmatically within LLVM + passes and the C++ API, see the `IR2Vec Embeddings `_ section in the MLGO documentation. OPERATION MODES --------------- +The tool operates in two modes: **LLVM IR mode** and **Machine IR mode**. The mode +is selected using the ``--mode`` option (default: ``llvm``). + Triplet Generation and Entity Mapping Modes are used for preparing vocabulary and training data for knowledge graph embeddings. The Embedding Mode is used for generating embeddings from LLVM IR using a pre-trained vocabulary. @@ -89,18 +99,31 @@ Embedding Generation ~~~~~~~~~~~~~~~~~~~~ With the `embeddings` subcommand, :program:`llvm-ir2vec` uses a pre-trained vocabulary to -generate numerical embeddings for LLVM IR at different levels of granularity. +generate numerical embeddings for LLVM IR or Machine IR at different levels of granularity. + +Example Usage for LLVM IR: + +.. code-block:: bash + + llvm-ir2vec embeddings --mode=llvm --ir2vec-vocab-path=vocab.json --ir2vec-kind=symbolic --level=func input.bc -o embeddings.txt -Example Usage: +Example Usage for Machine IR: .. code-block:: bash - llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json --ir2vec-kind=symbolic --level=func input.bc -o embeddings.txt + llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=vocab.json --level=func input.mir -o embeddings.txt OPTIONS ------- -Global options: +Common options (applicable to both LLVM IR and Machine IR modes): + +.. option:: --mode= + + Specify the operation mode. Valid values are: + + * ``llvm`` - Process LLVM IR bitcode files (default) + * ``mir`` - Process Machine IR (.mir) files .. option:: -o @@ -116,8 +139,8 @@ Subcommand-specific options: .. option:: - The input LLVM IR or bitcode file to process. This positional argument is - required for the `embeddings` subcommand. + The input LLVM IR/bitcode file (.ll/.bc) or Machine IR file (.mir) to process. + This positional argument is required for the `embeddings` subcommand. .. option:: --level= @@ -131,6 +154,8 @@ Subcommand-specific options: Process only the specified function instead of all functions in the module. +**IR2Vec-specific options** (for ``--mode=llvm``): + .. option:: --ir2vec-kind= Specify the kind of IR2Vec embeddings to generate. Valid values are: @@ -143,8 +168,8 @@ Subcommand-specific options: .. option:: --ir2vec-vocab-path= - Specify the path to the vocabulary file (required for embedding generation). - The vocabulary file should be in JSON format and contain the trained + Specify the path to the IR2Vec vocabulary file (required for LLVM IR embedding + generation). The vocabulary file should be in JSON format and contain the trained vocabulary for embedding generation. See `llvm/lib/Analysis/models` for pre-trained vocabulary files. @@ -163,6 +188,35 @@ Subcommand-specific options: Specify the weight for argument embeddings (default: 0.2). This controls the relative importance of operand information in the final embedding. +**MIR2Vec-specific options** (for ``--mode=mir``): + +.. option:: --mir2vec-vocab-path= + + Specify the path to the MIR2Vec vocabulary file (required for Machine IR + embedding generation). The vocabulary file should be in JSON format and + contain the trained vocabulary for embedding generation. + +.. option:: --mir2vec-kind= + + Specify the kind of MIR2Vec embeddings to generate. Valid values are: + + * ``symbolic`` - Generate symbolic embeddings (default) + +.. option:: --mir2vec-opc-weight= + + Specify the weight for machine opcode embeddings (default: 1.0). This controls + the relative importance of machine instruction opcodes in the final embedding. + +.. option:: --mir2vec-common-operand-weight= + + Specify the weight for common operand embeddings (default: 1.0). This controls + the relative importance of common operand types in the final embedding. + +.. option:: --mir2vec-reg-operand-weight= + + Specify the weight for register operand embeddings (default: 1.0). This controls + the relative importance of register operands in the final embedding. + **triplets** subcommand: @@ -240,3 +294,6 @@ SEE ALSO For more information about the IR2Vec algorithm and approach, see: `IR2Vec: LLVM IR Based Scalable Program Embeddings `_. + +For more information about the MIR2Vec algorithm and approach, see: +`RL4ReAl: Reinforcement Learning for Register Allocation `_. diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index 953e590a6d64f..f47d9abb042d8 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -7,9 +7,20 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file defines the MIR2Vec vocabulary -/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder -/// interface for generating Machine IR embeddings, and related utilities. +/// This file defines the MIR2Vec framework for generating Machine IR +/// embeddings. +/// +/// Architecture Overview: +/// ---------------------- +/// 1. MIR2VecVocabProvider - Core vocabulary loading logic (no PM dependency) +/// - Can be used standalone or wrapped by the pass manager +/// - Requires MachineModuleInfo with parsed machine functions +/// +/// 2. MIR2VecVocabLegacyAnalysis - Pass manager wrapper (ImmutablePass) +/// - Integrated and used by llc -print-mir2vec +/// +/// 3. MIREmbedder - Generates embeddings from vocabulary +/// - SymbolicMIREmbedder: MIR2Vec embedding implementation /// /// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the /// LLVM Machine IR as embeddings which can be used as input to machine learning @@ -306,26 +317,58 @@ class SymbolicMIREmbedder : public MIREmbedder { } // namespace mir2vec +/// MIR2Vec vocabulary provider used by pass managers and standalone tools. +/// This class encapsulates the core vocabulary loading logic and can be used +/// independently of the pass manager infrastructure. For pass-based usage, +/// see MIR2VecVocabLegacyAnalysis. +/// +/// Note: This provider pattern makes new PM migration straightforward when +/// needed. A new PM analysis wrapper can be added that delegates to this +/// provider, similar to how MIR2VecVocabLegacyAnalysis currently wraps it. +class MIR2VecVocabProvider { + using VocabMap = std::map; + +public: + MIR2VecVocabProvider(const MachineModuleInfo &MMI) : MMI(MMI) {} + + Expected getVocabulary(const Module &M); + +private: + Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab, + VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap); + const MachineModuleInfo &MMI; +}; + /// Pass to analyze and populate MIR2Vec vocabulary from a module class MIR2VecVocabLegacyAnalysis : public ImmutablePass { using VocabVector = std::vector; using VocabMap = std::map; - std::optional Vocab; StringRef getPassName() const override; - Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab, - VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap); protected: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.setPreservesAll(); } + std::unique_ptr Provider; public: static char ID; MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {} - Expected getMIR2VecVocabulary(const Module &M); + + Expected getMIR2VecVocabulary(const Module &M) { + MachineModuleInfo &MMI = + getAnalysis().getMMI(); + if (!Provider) + Provider = std::make_unique(MMI); + return Provider->getVocabulary(M); + } + + MIR2VecVocabProvider &getProvider() { + assert(Provider && "Provider not initialized"); + return *Provider; + } }; /// This pass prints the embeddings in the MIR2Vec vocabulary diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 716221101af9f..69c1e28e55e3b 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -412,24 +412,39 @@ Expected MIRVocabulary::createDummyVocabForTest( } //===----------------------------------------------------------------------===// -// MIR2VecVocabLegacyAnalysis Implementation +// MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis //===----------------------------------------------------------------------===// -char MIR2VecVocabLegacyAnalysis::ID = 0; -INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", - "MIR2Vec Vocabulary Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) -INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", - "MIR2Vec Vocabulary Analysis", false, true) +Expected +MIR2VecVocabProvider::getVocabulary(const Module &M) { + VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap; -StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { - return "MIR2Vec Vocabulary Analysis"; + if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap, + VirtRegVocabMap)) + return std::move(Err); + + for (const auto &F : M) { + if (F.isDeclaration()) + continue; + + if (auto *MF = MMI.getMachineFunction(F)) { + auto &Subtarget = MF->getSubtarget(); + if (const auto *TII = Subtarget.getInstrInfo()) + if (const auto *TRI = Subtarget.getRegisterInfo()) + return mir2vec::MIRVocabulary::create( + std::move(OpcVocab), std::move(CommonOperandVocab), + std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *TII, *TRI, + MF->getRegInfo()); + } + } + return createStringError(errc::invalid_argument, + "No machine functions found in module"); } -Error MIR2VecVocabLegacyAnalysis::readVocabulary(VocabMap &OpcodeVocab, - VocabMap &CommonOperandVocab, - VocabMap &PhyRegVocabMap, - VocabMap &VirtRegVocabMap) { +Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab, + VocabMap &CommonOperandVocab, + VocabMap &PhyRegVocabMap, + VocabMap &VirtRegVocabMap) { if (VocabFile.empty()) return createStringError( errc::invalid_argument, @@ -478,49 +493,15 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary(VocabMap &OpcodeVocab, return Error::success(); } -Expected -MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { - if (Vocab.has_value()) - return std::move(Vocab.value()); - - VocabMap OpcMap, CommonOperandMap, PhyRegMap, VirtRegMap; - if (Error Err = - readVocabulary(OpcMap, CommonOperandMap, PhyRegMap, VirtRegMap)) - return std::move(Err); - - // Get machine module info to access machine functions and target info - MachineModuleInfo &MMI = getAnalysis().getMMI(); - - // Find first available machine function to get target instruction info - for (const auto &F : M) { - if (F.isDeclaration()) - continue; - - if (auto *MF = MMI.getMachineFunction(F)) { - auto &Subtarget = MF->getSubtarget(); - const TargetInstrInfo *TII = Subtarget.getInstrInfo(); - if (!TII) { - return createStringError(errc::invalid_argument, - "No TargetInstrInfo available; cannot create " - "MIR2Vec vocabulary"); - } - - const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); - if (!TRI) { - return createStringError(errc::invalid_argument, - "No TargetRegisterInfo available; cannot " - "create MIR2Vec vocabulary"); - } - - return mir2vec::MIRVocabulary::create( - std::move(OpcMap), std::move(CommonOperandMap), std::move(PhyRegMap), - std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo()); - } - } +char MIR2VecVocabLegacyAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) - // No machine functions available - return error - return createStringError(errc::invalid_argument, - "No machine functions found in module"); +StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { + return "MIR2Vec Vocabulary Analysis"; } //===----------------------------------------------------------------------===// diff --git a/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir b/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir new file mode 100644 index 0000000000000..e5f78bfd2090e --- /dev/null +++ b/llvm/test/tools/llvm-ir2vec/embeddings-symbolic.mir @@ -0,0 +1,92 @@ +# REQUIRES: x86_64-linux +# RUN: llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT +# RUN: llvm-ir2vec embeddings --mode=mir --level=func --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL +# RUN: llvm-ir2vec embeddings --mode=mir --level=func --function=add_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ADD +# RUN: not llvm-ir2vec embeddings --mode=mir --level=func --function=missing_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-MISSING +# RUN: llvm-ir2vec embeddings --mode=mir --level=bb --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL +# RUN: llvm-ir2vec embeddings --mode=mir --level=inst --function=add_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + target triple = "x86_64-unknown-linux-gnu" + + define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) { + entry: + %sum = add nsw i32 %a, %b + %result = mul nsw i32 %sum, 2 + ret i32 %result + } + + define dso_local void @simple_function() { + entry: + ret void + } +... +--- +name: add_function +alignment: 16 +tracksRegLiveness: true +registers: + - { id: 0, class: gr32 } + - { id: 1, class: gr32 } + - { id: 2, class: gr32 } + - { id: 3, class: gr32 } +liveins: + - { reg: '$edi', virtual-reg: '%0' } + - { reg: '$esi', virtual-reg: '%1' } +body: | + bb.0.entry: + liveins: $edi, $esi + + %1:gr32 = COPY $esi + %0:gr32 = COPY $edi + %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags + %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags + $eax = COPY %3 + RET 0, $eax + +--- +name: simple_function +alignment: 16 +tracksRegLiveness: true +body: | + bb.0.entry: + RET 0 + +# CHECK-DEFAULT: MIR2Vec embeddings for machine function add_function: +# CHECK-DEFAULT-NEXT: Function vector: [ 26.50 27.10 27.70 ] +# CHECK-DEFAULT: MIR2Vec embeddings for machine function simple_function: +# CHECK-DEFAULT-NEXT: Function vector: [ 1.10 1.20 1.30 ] + +# CHECK-FUNC-LEVEL: MIR2Vec embeddings for machine function add_function: +# CHECK-FUNC-LEVEL-NEXT: Function vector: [ 26.50 27.10 27.70 ] +# CHECK-FUNC-LEVEL: MIR2Vec embeddings for machine function simple_function: +# CHECK-FUNC-LEVEL-NEXT: Function vector: [ 1.10 1.20 1.30 ] + +# CHECK-FUNC-LEVEL-ADD: MIR2Vec embeddings for machine function add_function: +# CHECK-FUNC-LEVEL-ADD-NEXT: Function vector: [ 26.50 27.10 27.70 ] +# CHECK-FUNC-LEVEL-ADD-NOT: simple_function + +# CHECK-FUNC-MISSING: Error: Function 'missing_function' not found + +# CHECK-BB-LEVEL: MIR2Vec embeddings for machine function add_function: +# CHECK-BB-LEVEL-NEXT: Basic block vectors: +# CHECK-BB-LEVEL-NEXT: MBB entry: [ 26.50 27.10 27.70 ] +# CHECK-BB-LEVEL: MIR2Vec embeddings for machine function simple_function: +# CHECK-BB-LEVEL-NEXT: Basic block vectors: +# CHECK-BB-LEVEL-NEXT: MBB entry: [ 1.10 1.20 1.30 ] + +# CHECK-INST-LEVEL: MIR2Vec embeddings for machine function add_function: +# CHECK-INST-LEVEL-NEXT: Instruction vectors: +# CHECK-INST-LEVEL: %1:gr32 = COPY $esi +# CHECK-INST-LEVEL-NEXT: -> [ 6.00 6.10 6.20 ] +# CHECK-INST-LEVEL-NEXT: %0:gr32 = COPY $edi +# CHECK-INST-LEVEL-NEXT: -> [ 6.00 6.10 6.20 ] +# CHECK-INST-LEVEL: %2:gr32 = nsw ADD32rr +# CHECK-INST-LEVEL: -> [ 3.70 3.80 3.90 ] +# CHECK-INST-LEVEL: %3:gr32 = ADD32rr +# CHECK-INST-LEVEL: -> [ 3.70 3.80 3.90 ] +# CHECK-INST-LEVEL: $eax = COPY %3:gr32 +# CHECK-INST-LEVEL-NEXT: -> [ 6.00 6.10 6.20 ] +# CHECK-INST-LEVEL: RET 0, $eax +# CHECK-INST-LEVEL-NEXT: -> [ 1.10 1.20 1.30 ] diff --git a/llvm/test/tools/llvm-ir2vec/error-handling.mir b/llvm/test/tools/llvm-ir2vec/error-handling.mir new file mode 100644 index 0000000000000..154078c18d647 --- /dev/null +++ b/llvm/test/tools/llvm-ir2vec/error-handling.mir @@ -0,0 +1,41 @@ +# REQUIRES: x86_64-linux +# Test error handling and input validation for llvm-ir2vec tool in MIR mode + +# RUN: not llvm-ir2vec embeddings --mode=mir %s 2>&1 | FileCheck %s -check-prefix=CHECK-NO-VOCAB +# RUN: not llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/nonexistent-vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-VOCAB-NOT-FOUND +# RUN: not llvm-ir2vec embeddings --mode=mir --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-INVALID-VOCAB +# RUN: not llvm-ir2vec embeddings --mode=mir --function=nonexistent_function --mir2vec-vocab-path=%S/../../CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-NOT-FOUND + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + target triple = "x86_64-unknown-linux-gnu" + + define dso_local noundef i32 @test_function(i32 noundef %a) { + entry: + ret i32 %a + } +... +--- +name: test_function +alignment: 16 +tracksRegLiveness: true +registers: + - { id: 0, class: gr32 } +liveins: + - { reg: '$edi', virtual-reg: '%0' } +body: | + bb.0.entry: + liveins: $edi + + %0:gr32 = COPY $edi + $eax = COPY %0 + RET 0, $eax + +# CHECK-NO-VOCAB: Error: Failed to load MIR2Vec vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path + +# CHECK-VOCAB-NOT-FOUND: Error: Failed to load MIR2Vec vocabulary +# CHECK-VOCAB-NOT-FOUND: No such file or directory + +# CHECK-INVALID-VOCAB: Error: Failed to load MIR2Vec vocabulary - Missing 'Opcodes' section in vocabulary file + +# CHECK-FUNC-NOT-FOUND: Error: Function 'nonexistent_function' not found diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt index a4cf9690e86b5..e680144452136 100644 --- a/llvm/tools/llvm-ir2vec/CMakeLists.txt +++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt @@ -1,10 +1,24 @@ set(LLVM_LINK_COMPONENTS + # Core LLVM components for IR processing Analysis Core IRReader Support + + # Machine IR components (for -mode=mir) + CodeGen + MIRParser + + # Target initialization (required for MIR parsing) + AllTargetsAsmParsers + AllTargetsCodeGens + AllTargetsDescs + AllTargetsInfos ) add_llvm_tool(llvm-ir2vec llvm-ir2vec.cpp + + DEPENDS + intrinsics_gen ) diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 1031932116c1e..c41cf20539c0d 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -1,4 +1,4 @@ -//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===// +//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,9 +7,13 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file implements the IR2Vec embedding generation tool. +/// This file implements the IR2Vec and MIR2Vec embedding generation tool. /// -/// This tool provides three main subcommands: +/// This tool supports two modes: +/// - LLVM IR mode (-mode=llvm): Process LLVM IR +/// - Machine IR mode (-mode=mir): Process Machine IR +/// +/// Available subcommands: /// /// 1. Triplet Generation (triplets): /// Generates numeric triplets (head, tail, relation) for vocabulary @@ -23,16 +27,24 @@ /// Usage: llvm-ir2vec entities input.bc -o entity2id.txt /// /// 3. Embedding Generation (embeddings): -/// Generates IR2Vec embeddings using a trained vocabulary. -/// Usage: llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json -/// --ir2vec-kind= --level= input.bc -o embeddings.txt -/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware +/// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary. +/// +/// For LLVM IR: +/// llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json +/// --ir2vec-kind= --level= input.bc -o embeddings.txt +/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware +/// +/// For Machine IR: +/// llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json +/// --level= input.mir -o embeddings.txt +/// /// Levels: --level=inst (instructions), --level=bb (basic blocks), -/// --level=func (functions) (See IR2Vec.cpp for more embedding generation -/// options) +/// --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding +/// generation options) /// //===----------------------------------------------------------------------===// +#include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/IR2Vec.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -50,10 +62,36 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/CodeGen/MIR2Vec.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/WithColor.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" + #define DEBUG_TYPE "ir2vec" namespace llvm { -namespace ir2vec { + +// Common option category for options shared between IR2Vec and MIR2Vec +static cl::OptionCategory CommonCategory("Common Options", + "Options applicable to both IR2Vec " + "and MIR2Vec modes"); + +enum IRKind { + LLVMIR = 0, ///< LLVM IR + MIR ///< Machine IR +}; + +static cl::opt + IRMode("mode", cl::desc("Tool operation mode:"), + cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"), + clEnumValN(MIR, "mir", "Process Machine IR")), + cl::init(LLVMIR), cl::cat(CommonCategory)); // Subcommands static cl::SubCommand @@ -70,18 +108,18 @@ static cl::opt InputFilename(cl::Positional, cl::desc(""), cl::init("-"), cl::sub(TripletsSubCmd), - cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory)); + cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); static cl::opt OutputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-"), - cl::cat(ir2vec::IR2VecCategory)); + cl::cat(CommonCategory)); // Embedding-specific options static cl::opt FunctionName("function", cl::desc("Process specific function only"), cl::value_desc("name"), cl::Optional, cl::init(""), - cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory)); + cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); enum EmbeddingLevel { InstructionLevel, // Generate instruction-level embeddings @@ -98,9 +136,9 @@ static cl::opt clEnumValN(FunctionLevel, "func", "Generate function-level embeddings")), cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd), - cl::cat(ir2vec::IR2VecCategory)); + cl::cat(CommonCategory)); -namespace { +namespace ir2vec { /// Relation types for triplet generation enum RelationType { @@ -300,20 +338,116 @@ Error processModule(Module &M, raw_ostream &OS) { } return Error::success(); } -} // namespace } // namespace ir2vec + +namespace mir2vec { + +/// Helper class for MIR2Vec embedding generation +class MIR2VecTool { +private: + MachineModuleInfo &MMI; + std::unique_ptr Vocab; + +public: + explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} + + /// Initialize the MIR2Vec vocabulary + bool initializeVocabulary(const Module &M) { + MIR2VecVocabProvider Provider(MMI); + auto VocabOrErr = Provider.getVocabulary(M); + if (!VocabOrErr) { + errs() << "Error: Failed to load MIR2Vec vocabulary - " + << toString(VocabOrErr.takeError()) << "\n"; + return false; + } + Vocab = std::make_unique(std::move(*VocabOrErr)); + return true; + } + + /// Generate embeddings for all machine functions in the module + void generateEmbeddings(const Module &M, raw_ostream &OS) const { + if (!Vocab) { + OS << "Error: Vocabulary not initialized.\n"; + return; + } + + for (const Function &F : M) { + if (F.isDeclaration()) + continue; + + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) { + errs() << "Warning: No MachineFunction for " << F.getName() << "\n"; + continue; + } + + generateEmbeddings(*MF, OS); + } + } + + /// Generate embeddings for a specific machine function + void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const { + if (!Vocab) { + OS << "Error: Vocabulary not initialized.\n"; + return; + } + + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab); + if (!Emb) { + errs() << "Error: Failed to create embedder for " << MF.getName() << "\n"; + return; + } + + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; + + // Generate embeddings based on the specified level + switch (Level) { + case FunctionLevel: { + OS << "Function vector: "; + Emb->getMFunctionVector().print(OS); + break; + } + case BasicBlockLevel: { + OS << "Basic block vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + OS << "MBB " << MBB.getName() << ": "; + Emb->getMBBVector(MBB).print(OS); + } + break; + } + case InstructionLevel: { + OS << "Instruction vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + OS << MI << " -> "; + Emb->getMInstVector(MI).print(OS); + } + } + break; + } + } + } + + const MIRVocabulary *getVocabulary() const { return Vocab.get(); } +}; + +} // namespace mir2vec + } // namespace llvm int main(int argc, char **argv) { using namespace llvm; using namespace llvm::ir2vec; + using namespace llvm::mir2vec; InitLLVM X(argc, argv); - cl::HideUnrelatedOptions(ir2vec::IR2VecCategory); + // Show Common, IR2Vec and MIR2Vec option categories + cl::HideUnrelatedOptions(ArrayRef{ + &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory}); cl::ParseCommandLineOptions( argc, argv, - "IR2Vec - Embedding Generation Tool\n" - "Generates embeddings for a given LLVM IR and " + "IR2Vec/MIR2Vec - Embedding Generation Tool\n" + "Generates embeddings for a given LLVM IR or MIR and " "supports triplet generation for vocabulary " "training and embedding generation.\n\n" "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more " @@ -326,26 +460,110 @@ int main(int argc, char **argv) { return 1; } - if (EntitiesSubCmd) { - // Just dump entity mappings without processing any IR - IR2VecTool::generateEntityMappings(OS); + if (IRMode == IRKind::LLVMIR) { + if (EntitiesSubCmd) { + // Just dump entity mappings without processing any IR + IR2VecTool::generateEntityMappings(OS); + return 0; + } + + // Parse the input LLVM IR file or stdin + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr M = parseIRFile(InputFilename, Err, Context); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + + if (Error Err = processModule(*M, OS)) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { + errs() << "Error: " << EIB.message() << "\n"; + }); + return 1; + } return 0; } + if (IRMode == IRKind::MIR) { + // Initialize targets for Machine IR processing + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmParsers(); + InitializeAllAsmPrinters(); + static codegen::RegisterCodeGenFlags CGF; + + // Parse MIR input file + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr TM; + + auto MIR = createMIRParserFromFile(InputFilename, Err, Context); + if (!MIR) { + Err.print(argv[0], WithColor::error(errs(), argv[0])); + return 1; + } - // Parse the input LLVM IR file or stdin - SMDiagnostic Err; - LLVMContext Context; - std::unique_ptr M = parseIRFile(InputFilename, Err, Context); - if (!M) { - Err.print(argv[0], errs()); - return 1; - } + auto SetDataLayout = [&](StringRef DataLayoutTargetTriple, + StringRef OldDLStr) -> std::optional { + std::string IRTargetTriple = DataLayoutTargetTriple.str(); + Triple TheTriple = Triple(IRTargetTriple); + if (TheTriple.getTriple().empty()) + TheTriple.setTriple(sys::getDefaultTargetTriple()); + auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str()); + if (!TMOrErr) { + Err.print(argv[0], WithColor::error(errs(), argv[0])); + exit(1); + } + TM = std::move(*TMOrErr); + return TM->createDataLayout().getStringRepresentation(); + }; + + std::unique_ptr M = MIR->parseIRModule(SetDataLayout); + if (!M) { + Err.print(argv[0], WithColor::error(errs(), argv[0])); + return 1; + } - if (Error Err = processModule(*M, OS)) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { - errs() << "Error: " << EIB.message() << "\n"; - }); - return 1; + // Parse machine functions + auto MMI = std::make_unique(TM.get()); + if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) { + Err.print(argv[0], WithColor::error(errs(), argv[0])); + return 1; + } + + // Create MIR2Vec tool and initialize vocabulary + MIR2VecTool Tool(*MMI); + if (!Tool.initializeVocabulary(*M)) + return 1; + + LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n" + << "Vocabulary dimension: " + << Tool.getVocabulary()->getDimension() << "\n" + << "Vocabulary size: " + << Tool.getVocabulary()->getCanonicalSize() << "\n"); + + // Generate embeddings based on subcommand + if (!FunctionName.empty()) { + // Process single function + Function *F = M->getFunction(FunctionName); + if (!F) { + errs() << "Error: Function '" << FunctionName << "' not found\n"; + return 1; + } + + MachineFunction *MF = MMI->getMachineFunction(*F); + if (!MF) { + errs() << "Error: No MachineFunction for " << FunctionName << "\n"; + return 1; + } + + Tool.generateEmbeddings(*MF, OS); + } else { + // Process all functions + Tool.generateEmbeddings(*M, OS); + } + + return 0; } return 0;