-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MIR2Vec] Handle Operands #163281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
svkeerthy
wants to merge
1
commit into
main
Choose a base branch
from
users/svkeerthy/10-13-handle_operands
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,422
−166
Open
[MIR2Vec] Handle Operands #163281
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,8 @@ | |
#include "llvm/CodeGen/MachineFunctionPass.h" | ||
#include "llvm/CodeGen/MachineInstr.h" | ||
#include "llvm/CodeGen/MachineModuleInfo.h" | ||
#include "llvm/CodeGen/MachineOperand.h" | ||
#include "llvm/CodeGen/MachineRegisterInfo.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/Pass.h" | ||
#include "llvm/Support/CommandLine.h" | ||
|
@@ -61,7 +63,7 @@ class MIREmbedder; | |
class SymbolicMIREmbedder; | ||
|
||
extern llvm::cl::OptionCategory MIR2VecCategory; | ||
extern cl::opt<float> OpcWeight; | ||
extern cl::opt<float> OpcWeight, CommonOperandWeight, RegOperandWeight; | ||
|
||
using Embedding = ir2vec::Embedding; | ||
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>; | ||
|
@@ -74,31 +76,114 @@ class MIRVocabulary { | |
friend class llvm::MIR2VecVocabLegacyAnalysis; | ||
using VocabMap = std::map<std::string, ir2vec::Embedding>; | ||
|
||
private: | ||
// Define vocabulary layout - adapted for MIR | ||
// MIRVocabulary Layout: | ||
// +-------------------+-----------------------------------------------------+ | ||
// | Entity Type | Description | | ||
// +-------------------+-----------------------------------------------------+ | ||
// | 1. Opcodes | Target specific opcodes derived from TII, grouped | | ||
// | | by instruction semantics. | | ||
// | 2. Common Operands| All common operand types, except register operands, | | ||
// | | defined by MachineOperand::MachineOperandType enum. | | ||
// | 3. Physical | Register classes defined by the target, specialized | | ||
// | Reg classes | by physical registers. | | ||
// | 4. Virtual | Register classes defined by the target, specialized | | ||
// | Reg classes | by virtual and physical registers. | | ||
// +-------------------+-----------------------------------------------------+ | ||
|
||
/// Layout information for the MIR vocabulary. Defines the starting index | ||
/// and size of each section in the vocabulary. | ||
struct { | ||
size_t OpcodeBase = 0; | ||
size_t OperandBase = 0; | ||
size_t CommonOperandBase = 0; | ||
size_t PhyRegBase = 0; | ||
size_t VirtRegBase = 0; | ||
size_t TotalEntries = 0; | ||
} Layout; | ||
|
||
enum class Section : unsigned { Opcodes = 0, MaxSections }; | ||
enum class Section : unsigned { | ||
Opcodes = 0, | ||
CommonOperands = 1, | ||
PhyRegisters = 2, | ||
VirtRegisters = 3, | ||
MaxSections | ||
}; | ||
|
||
ir2vec::VocabStorage Storage; | ||
mutable std::set<std::string> UniqueBaseOpcodeNames; | ||
std::set<std::string> UniqueBaseOpcodeNames; | ||
SmallVector<std::string, 24> RegisterOperandNames; | ||
|
||
// Some instructions have optional register operands that may be NoRegister. | ||
// We return a zero vector in such cases. | ||
Embedding ZeroEmbedding; | ||
|
||
// Names of the common (non-register) operand kinds, one vocabulary entry each. | ||
// MO_Register is handled separately in the register operand sections, so it | ||
// is not listed here. MO_DbgInstrRef is also left out for now. | ||
static constexpr StringLiteral CommonOperandNames[] = { | ||
"Immediate", "CImmediate", "FPImmediate", "MBB", | ||
"FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex", | ||
"ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask", | ||
"RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex", | ||
"IntrinsicID", "Predicate", "ShuffleMask"}; | ||
// Keep the table in sync with MachineOperand::MachineOperandType. The text is | ||
// passed as the static_assert message (not and-ed into the condition) so the | ||
// compiler reports it verbatim when the check fails. | ||
static_assert(std::size(CommonOperandNames) == MachineOperand::MO_Last - 1, | ||
"Common operand names size changed, update accordingly"); | ||
|
||
const TargetInstrInfo &TII; | ||
void generateStorage(const VocabMap &OpcodeMap); | ||
const TargetRegisterInfo &TRI; | ||
const MachineRegisterInfo &MRI; | ||
|
||
void generateStorage(const VocabMap &OpcodeMap, | ||
const VocabMap &CommonOperandMap, | ||
const VocabMap &PhyRegMap, const VocabMap &VirtRegMap); | ||
void buildCanonicalOpcodeMapping(); | ||
void buildRegisterOperandMapping(); | ||
|
||
/// Get canonical index for a machine opcode | ||
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const; | ||
|
||
/// Get index for a common (non-register) machine operand | ||
unsigned | ||
getCommonOperandIndex(MachineOperand::MachineOperandType OperandType) const; | ||
|
||
/// Get index for a register machine operand | ||
unsigned getRegisterOperandIndex(Register Reg) const; | ||
|
||
// Accessors for operand types | ||
/// Returns the embedding stored in the CommonOperands section for the given | ||
/// operand type; the slot is resolved through getCommonOperandIndex(). | ||
const Embedding & | ||
operator[](MachineOperand::MachineOperandType OperandType) const { | ||
  const unsigned Slot = getCommonOperandIndex(OperandType); | ||
  constexpr unsigned SectionIdx = | ||
      static_cast<unsigned>(Section::CommonOperands); | ||
  return Storage[SectionIdx][Slot]; | ||
} | ||
|
||
/// Returns the embedding for a register operand, taken from the PhyRegisters | ||
/// or VirtRegisters section depending on Reg.isPhysical(). | ||
const Embedding &operator[](Register Reg) const { | ||
  // Two cases currently map to the zero vector: | ||
  //  * Reg is NoRegister (0), which happens for optional operands. | ||
  //  * Reg is a stack slot. | ||
  // TODO: Implement proper stack slot handling for MIR2Vec embeddings. | ||
  // Stack slots represent frame indices and should have their own | ||
  // embedding strategy rather than the zero-vector placeholder used here. | ||
  // Consider: 1) Separate vocabulary section for stack slots | ||
  // 2) Stack slot size/alignment based embeddings | ||
  // 3) Frame index based categorization | ||
  if (!Reg.isValid() || Reg.isStack()) | ||
    return ZeroEmbedding; | ||
  | ||
  const unsigned Slot = getRegisterOperandIndex(Reg); | ||
  Section RegSection = Section::VirtRegisters; | ||
  if (Reg.isPhysical()) | ||
    RegSection = Section::PhyRegisters; | ||
  return Storage[static_cast<unsigned>(RegSection)][Slot]; | ||
} | ||
|
||
public: | ||
/// Static method for extracting base opcode names (public for testing) | ||
static std::string extractBaseOpcodeName(StringRef InstrName); | ||
|
||
/// Get canonical index for base name (public for testing) | ||
/// Get indices from opcode or operand names. These are public for testing. | ||
/// String based lookups are inefficient and should be avoided in general. | ||
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const; | ||
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const; | ||
unsigned getCanonicalIndexForRegisterClass(StringRef RegName, | ||
bool IsPhysical = true) const; | ||
|
||
/// Get the string key for a vocabulary entry at the given position | ||
std::string getStringKey(unsigned Pos) const; | ||
|
@@ -111,6 +196,14 @@ class MIRVocabulary { | |
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex]; | ||
} | ||
|
||
/// Returns the embedding for a machine operand. Register operands are looked | ||
/// up by their register; every other operand is looked up by its operand type. | ||
/// The operand is taken by const reference so a lookup does not copy it. | ||
const Embedding &operator[](const MachineOperand &Operand) const { | ||
  const auto OperandType = Operand.getType(); | ||
  if (OperandType == MachineOperand::MO_Register) | ||
    return operator[](Operand.getReg()); | ||
  return operator[](OperandType); | ||
} | ||
|
||
// Iterator access | ||
using const_iterator = ir2vec::VocabStorage::const_iterator; | ||
const_iterator begin() const { return Storage.begin(); } | ||
|
@@ -120,18 +213,25 @@ class MIRVocabulary { | |
MIRVocabulary() = delete; | ||
|
||
/// Factory method to create MIRVocabulary from vocabulary map | ||
static Expected<MIRVocabulary> create(VocabMap &&Entries, | ||
const TargetInstrInfo &TII); | ||
static Expected<MIRVocabulary> | ||
create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, | ||
VocabMap &&VirtRegMap, const TargetInstrInfo &TII, | ||
const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI); | ||
|
||
/// Create a dummy vocabulary for testing purposes. | ||
static Expected<MIRVocabulary> | ||
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1); | ||
createDummyVocabForTest(const TargetInstrInfo &TII, | ||
const TargetRegisterInfo &TRI, | ||
const MachineRegisterInfo &MRI, unsigned Dim = 1); | ||
|
||
/// Total number of entries in the vocabulary | ||
size_t getCanonicalSize() const { return Storage.size(); } | ||
|
||
private: | ||
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII); | ||
MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, | ||
VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, | ||
const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, | ||
const MachineRegisterInfo &MRI); | ||
}; | ||
|
||
/// Base class for MIR embedders | ||
|
@@ -144,11 +244,13 @@ class MIREmbedder { | |
const unsigned Dimension; | ||
|
||
/// Weight for opcode embeddings | ||
const float OpcWeight; | ||
const float OpcWeight, CommonOperandWeight, RegOperandWeight; | ||
|
||
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab) | ||
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()), | ||
OpcWeight(mir2vec::OpcWeight) {} | ||
OpcWeight(mir2vec::OpcWeight), | ||
CommonOperandWeight(mir2vec::CommonOperandWeight), | ||
RegOperandWeight(mir2vec::RegOperandWeight) {} | ||
|
||
/// Function to compute embeddings. | ||
Embedding computeEmbeddings() const; | ||
|
@@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder { | |
class MIR2VecVocabLegacyAnalysis : public ImmutablePass { | ||
using VocabVector = std::vector<mir2vec::Embedding>; | ||
using VocabMap = std::map<std::string, mir2vec::Embedding>; | ||
VocabMap StrVocabMap; | ||
VocabVector Vocab; | ||
std::optional<mir2vec::MIRVocabulary> Vocab; | ||
|
||
StringRef getPassName() const override; | ||
Error readVocabulary(); | ||
Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab, | ||
VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap); | ||
|
||
protected: | ||
void getAnalysisUsage(AnalysisUsage &AU) const override { | ||
|
@@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS); | |
|
||
} // namespace llvm | ||
|
||
#endif // LLVM_CODEGEN_MIR2VEC_H | ||
#endif // LLVM_CODEGEN_MIR2VEC_H | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Re: line +381] spurious change, or fixing an existing "no end of file newline" case? See this comment inline on Graphite. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixing existing case |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.