Skip to content

Conversation

@s-barannikov
Copy link
Contributor

@s-barannikov s-barannikov commented Nov 13, 2025

This changes MCRegUnit type from unsigned to enum class : unsigned and inserts necessary casts.
The added MCRegUnitToIndex functor is used with SparseSet, SparseMultiSet and IndexedMap in a few places.

MCRegUnit is opaque to users, so it didn't seem worth making it a full-fledged class like Register.

Static type checking has detected one issue in PrologueEpilogueInserter.cpp, where BitVector created for MCRegister is indexed by both MCRegister and MCRegUnit.

The number of casts could be reduced by using IndexedMap in more places and/or adding a BitVector adaptor, but the number of casts per file is still small and IndexedMap has limitations, so it didn't seem worth the effort.

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-llvm-regalloc

@llvm/pr-subscribers-backend-x86

Author: Sergei Barannikov (s-barannikov)

Changes

This changes MCRegUnit type from unsigned to enum class : unsigned and inserts necessary casts.
The added MCRegUnitToIndex functor is used with SparseSet and SparseMultiSet in a couple of places.

MCRegUnit is opaque to users, so it didn't seem worth making it a full-blown class like Register.

Static type checking has detected one issue in PrologueEpilogueInserter.cpp, where BitVector created for MCRegister is indexed by both MCRegister and MCRegUnit.

The number of casts could be reduced by using IndexedMap and/or adding a BitVector adaptor, but the number of casts per file is still small and IndexedMap has limitations, so it didn't seem worth the effort.


Patch is 40.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167943.diff

30 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/LiveIntervalUnion.h (+6-6)
  • (modified) llvm/include/llvm/CodeGen/LiveIntervals.h (+9-6)
  • (modified) llvm/include/llvm/CodeGen/LiveRegMatrix.h (+3-1)
  • (modified) llvm/include/llvm/CodeGen/LiveRegUnits.h (+6-6)
  • (modified) llvm/include/llvm/CodeGen/MachineTraceMetrics.h (+2-2)
  • (modified) llvm/include/llvm/CodeGen/RDFRegisters.h (+5-4)
  • (modified) llvm/include/llvm/CodeGen/ReachingDefAnalysis.h (+5-5)
  • (modified) llvm/include/llvm/CodeGen/Register.h (+7-2)
  • (modified) llvm/include/llvm/CodeGen/RegisterClassInfo.h (+1-1)
  • (modified) llvm/include/llvm/CodeGen/RegisterPressure.h (+3-3)
  • (modified) llvm/include/llvm/CodeGen/ScheduleDAGInstrs.h (+4-2)
  • (modified) llvm/include/llvm/MC/MCRegister.h (+9-1)
  • (modified) llvm/include/llvm/MC/MCRegisterInfo.h (+7-4)
  • (modified) llvm/lib/CodeGen/EarlyIfConversion.cpp (+3-3)
  • (modified) llvm/lib/CodeGen/InterferenceCache.cpp (+3-3)
  • (modified) llvm/lib/CodeGen/LiveIntervals.cpp (+6-4)
  • (modified) llvm/lib/CodeGen/LiveRegMatrix.cpp (+2-2)
  • (modified) llvm/lib/CodeGen/LiveRegUnits.cpp (+2-2)
  • (modified) llvm/lib/CodeGen/MachineCopyPropagation.cpp (+2-2)
  • (modified) llvm/lib/CodeGen/MachineInstrBundle.cpp (+2-2)
  • (modified) llvm/lib/CodeGen/MachineLICM.cpp (+15-12)
  • (modified) llvm/lib/CodeGen/PrologEpilogInserter.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/RDFRegisters.cpp (+17-17)
  • (modified) llvm/lib/CodeGen/ReachingDefAnalysis.cpp (+12-9)
  • (modified) llvm/lib/CodeGen/RegAllocFast.cpp (+7-5)
  • (modified) llvm/lib/CodeGen/RegAllocGreedy.cpp (+3-2)
  • (modified) llvm/lib/CodeGen/RegisterClassInfo.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/TargetRegisterInfo.cpp (+3-3)
  • (modified) llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (+2-1)
  • (modified) llvm/lib/Target/X86/X86FixupBWInsts.cpp (+2-1)
diff --git a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h
index cc0f2a45bb182..240fa114cf179 100644
--- a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h
+++ b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h
@@ -191,14 +191,14 @@ class LiveIntervalUnion {
 
     void clear();
 
-    LiveIntervalUnion& operator[](unsigned idx) {
-      assert(idx <  Size && "idx out of bounds");
-      return LIUs[idx];
+    LiveIntervalUnion &operator[](MCRegUnit Unit) {
+      assert(static_cast<unsigned>(Unit) < Size && "Unit out of bounds");
+      return LIUs[static_cast<unsigned>(Unit)];
     }
 
-    const LiveIntervalUnion& operator[](unsigned Idx) const {
-      assert(Idx < Size && "Idx out of bounds");
-      return LIUs[Idx];
+    const LiveIntervalUnion &operator[](MCRegUnit Unit) const {
+      assert(static_cast<unsigned>(Unit) < Size && "Unit out of bounds");
+      return LIUs[static_cast<unsigned>(Unit)];
     }
   };
 };
diff --git a/llvm/include/llvm/CodeGen/LiveIntervals.h b/llvm/include/llvm/CodeGen/LiveIntervals.h
index 32027766e7093..b618e0b778ae8 100644
--- a/llvm/include/llvm/CodeGen/LiveIntervals.h
+++ b/llvm/include/llvm/CodeGen/LiveIntervals.h
@@ -413,11 +413,12 @@ class LiveIntervals {
   /// Return the live range for register unit \p Unit. It will be computed if
   /// it doesn't exist.
   LiveRange &getRegUnit(MCRegUnit Unit) {
-    LiveRange *LR = RegUnitRanges[Unit];
+    LiveRange *LR = RegUnitRanges[static_cast<unsigned>(Unit)];
     if (!LR) {
       // Compute missing ranges on demand.
       // Use segment set to speed-up initial computation of the live range.
-      RegUnitRanges[Unit] = LR = new LiveRange(UseSegmentSetForPhysRegs);
+      RegUnitRanges[static_cast<unsigned>(Unit)] = LR =
+          new LiveRange(UseSegmentSetForPhysRegs);
       computeRegUnitRange(*LR, Unit);
     }
     return *LR;
@@ -425,17 +426,19 @@ class LiveIntervals {
 
   /// Return the live range for register unit \p Unit if it has already been
   /// computed, or nullptr if it hasn't been computed yet.
-  LiveRange *getCachedRegUnit(MCRegUnit Unit) { return RegUnitRanges[Unit]; }
+  LiveRange *getCachedRegUnit(MCRegUnit Unit) {
+    return RegUnitRanges[static_cast<unsigned>(Unit)];
+  }
 
   const LiveRange *getCachedRegUnit(MCRegUnit Unit) const {
-    return RegUnitRanges[Unit];
+    return RegUnitRanges[static_cast<unsigned>(Unit)];
   }
 
   /// Remove computed live range for register unit \p Unit. Subsequent uses
   /// should rely on on-demand recomputation.
   void removeRegUnit(MCRegUnit Unit) {
-    delete RegUnitRanges[Unit];
-    RegUnitRanges[Unit] = nullptr;
+    delete RegUnitRanges[static_cast<unsigned>(Unit)];
+    RegUnitRanges[static_cast<unsigned>(Unit)] = nullptr;
   }
 
   /// Remove associated live ranges for the register units associated with \p
diff --git a/llvm/include/llvm/CodeGen/LiveRegMatrix.h b/llvm/include/llvm/CodeGen/LiveRegMatrix.h
index 0bc243271bb73..35add577d071a 100644
--- a/llvm/include/llvm/CodeGen/LiveRegMatrix.h
+++ b/llvm/include/llvm/CodeGen/LiveRegMatrix.h
@@ -165,7 +165,9 @@ class LiveRegMatrix {
 
   /// Directly access the live interval unions per regunit.
   /// This returns an array indexed by the regunit number.
-  LiveIntervalUnion *getLiveUnions() { return &Matrix[0]; }
+  LiveIntervalUnion *getLiveUnions() {
+    return &Matrix[static_cast<MCRegUnit>(0)];
+  }
 
   Register getOneVReg(unsigned PhysReg) const;
 };
diff --git a/llvm/include/llvm/CodeGen/LiveRegUnits.h b/llvm/include/llvm/CodeGen/LiveRegUnits.h
index 37c31cc6f4ac5..0ff5273929671 100644
--- a/llvm/include/llvm/CodeGen/LiveRegUnits.h
+++ b/llvm/include/llvm/CodeGen/LiveRegUnits.h
@@ -86,23 +86,23 @@ class LiveRegUnits {
   /// Adds register units covered by physical register \p Reg.
   void addReg(MCRegister Reg) {
     for (MCRegUnit Unit : TRI->regunits(Reg))
-      Units.set(Unit);
+      Units.set(static_cast<unsigned>(Unit));
   }
 
   /// Adds register units covered by physical register \p Reg that are
   /// part of the lanemask \p Mask.
   void addRegMasked(MCRegister Reg, LaneBitmask Mask) {
-    for (MCRegUnitMaskIterator Unit(Reg, TRI); Unit.isValid(); ++Unit) {
-      LaneBitmask UnitMask = (*Unit).second;
+    for (MCRegUnitMaskIterator I(Reg, TRI); I.isValid(); ++I) {
+      auto [Unit, UnitMask] = *I;
       if ((UnitMask & Mask).any())
-        Units.set((*Unit).first);
+        Units.set(static_cast<unsigned>(Unit));
     }
   }
 
   /// Removes all register units covered by physical register \p Reg.
   void removeReg(MCRegister Reg) {
     for (MCRegUnit Unit : TRI->regunits(Reg))
-      Units.reset(Unit);
+      Units.reset(static_cast<unsigned>(Unit));
   }
 
   /// Removes register units not preserved by the regmask \p RegMask.
@@ -116,7 +116,7 @@ class LiveRegUnits {
   /// Returns true if no part of physical register \p Reg is live.
   bool available(MCRegister Reg) const {
     for (MCRegUnit Unit : TRI->regunits(Reg)) {
-      if (Units.test(Unit))
+      if (Units.test(static_cast<unsigned>(Unit)))
         return false;
     }
     return true;
diff --git a/llvm/include/llvm/CodeGen/MachineTraceMetrics.h b/llvm/include/llvm/CodeGen/MachineTraceMetrics.h
index d1be0ee3dfff9..b29984cd95a4b 100644
--- a/llvm/include/llvm/CodeGen/MachineTraceMetrics.h
+++ b/llvm/include/llvm/CodeGen/MachineTraceMetrics.h
@@ -78,12 +78,12 @@ struct LiveRegUnit {
   const MachineInstr *MI = nullptr;
   unsigned Op = 0;
 
-  unsigned getSparseSetIndex() const { return RegUnit; }
+  unsigned getSparseSetIndex() const { return static_cast<unsigned>(RegUnit); }
 
   explicit LiveRegUnit(MCRegUnit RU) : RegUnit(RU) {}
 };
 
-using LiveRegUnitSet = SparseSet<LiveRegUnit>;
+using LiveRegUnitSet = SparseSet<LiveRegUnit, MCRegUnit, MCRegUnitToIndex>;
 
 /// Strategies for selecting traces.
 enum class MachineTraceStrategy {
diff --git a/llvm/include/llvm/CodeGen/RDFRegisters.h b/llvm/include/llvm/CodeGen/RDFRegisters.h
index 4c15bf534d55f..4f742bd663e0a 100644
--- a/llvm/include/llvm/CodeGen/RDFRegisters.h
+++ b/llvm/include/llvm/CodeGen/RDFRegisters.h
@@ -153,8 +153,9 @@ struct PhysicalRegisterInfo {
   // Returns the set of aliased physical registers.
   std::set<RegisterId> getAliasSet(RegisterId Reg) const;
 
-  RegisterRef getRefForUnit(uint32_t U) const {
-    return RegisterRef(UnitInfos[U].Reg, UnitInfos[U].Mask);
+  RegisterRef getRefForUnit(MCRegUnit U) const {
+    return RegisterRef(UnitInfos[static_cast<unsigned>(U)].Reg,
+                       UnitInfos[static_cast<unsigned>(U)].Mask);
   }
 
   const BitVector &getMaskUnits(RegisterId MaskId) const {
@@ -163,8 +164,8 @@ struct PhysicalRegisterInfo {
 
   std::set<RegisterId> getUnits(RegisterRef RR) const;
 
-  const BitVector &getUnitAliases(uint32_t U) const {
-    return AliasInfos[U].Regs;
+  const BitVector &getUnitAliases(MCRegUnit U) const {
+    return AliasInfos[static_cast<unsigned>(U)].Regs;
   }
 
   RegisterRef mapTo(RegisterRef RR, unsigned R) const;
diff --git a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
index 2893e5ce6647e..863c3b39229b9 100644
--- a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
+++ b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
@@ -78,17 +78,17 @@ class MBBReachingDefsInfo {
   }
 
   void append(unsigned MBBNumber, MCRegUnit Unit, int Def) {
-    AllReachingDefs[MBBNumber][Unit].push_back(Def);
+    AllReachingDefs[MBBNumber][static_cast<unsigned>(Unit)].push_back(Def);
   }
 
   void prepend(unsigned MBBNumber, MCRegUnit Unit, int Def) {
-    auto &Defs = AllReachingDefs[MBBNumber][Unit];
+    auto &Defs = AllReachingDefs[MBBNumber][static_cast<unsigned>(Unit)];
     Defs.insert(Defs.begin(), Def);
   }
 
   void replaceFront(unsigned MBBNumber, MCRegUnit Unit, int Def) {
-    assert(!AllReachingDefs[MBBNumber][Unit].empty());
-    *AllReachingDefs[MBBNumber][Unit].begin() = Def;
+    assert(!AllReachingDefs[MBBNumber][static_cast<unsigned>(Unit)].empty());
+    *AllReachingDefs[MBBNumber][static_cast<unsigned>(Unit)].begin() = Def;
   }
 
   void clear() { AllReachingDefs.clear(); }
@@ -97,7 +97,7 @@ class MBBReachingDefsInfo {
     if (AllReachingDefs[MBBNumber].empty())
       // Block IDs are not necessarily dense.
       return ArrayRef<ReachingDef>();
-    return AllReachingDefs[MBBNumber][Unit];
+    return AllReachingDefs[MBBNumber][static_cast<unsigned>(Unit)];
   }
 
 private:
diff --git a/llvm/include/llvm/CodeGen/Register.h b/llvm/include/llvm/CodeGen/Register.h
index 5e1e12942a019..f375af5808d1c 100644
--- a/llvm/include/llvm/CodeGen/Register.h
+++ b/llvm/include/llvm/CodeGen/Register.h
@@ -182,20 +182,25 @@ class VirtRegOrUnit {
   unsigned VRegOrUnit;
 
 public:
-  constexpr explicit VirtRegOrUnit(MCRegUnit Unit) : VRegOrUnit(Unit) {
+  constexpr explicit VirtRegOrUnit(MCRegUnit Unit)
+      : VRegOrUnit(static_cast<unsigned>(Unit)) {
     assert(!Register::isVirtualRegister(VRegOrUnit));
   }
+
   constexpr explicit VirtRegOrUnit(Register Reg) : VRegOrUnit(Reg.id()) {
     assert(Reg.isVirtual());
   }
 
+  // Catches implicit conversions to Register.
+  template <typename T> explicit VirtRegOrUnit(T) = delete;
+
   constexpr bool isVirtualReg() const {
     return Register::isVirtualRegister(VRegOrUnit);
   }
 
   constexpr MCRegUnit asMCRegUnit() const {
     assert(!isVirtualReg() && "Not a register unit");
-    return VRegOrUnit;
+    return static_cast<MCRegUnit>(VRegOrUnit);
   }
 
   constexpr Register asVirtualReg() const {
diff --git a/llvm/include/llvm/CodeGen/RegisterClassInfo.h b/llvm/include/llvm/CodeGen/RegisterClassInfo.h
index 078ae80915fed..124c7aff8c76d 100644
--- a/llvm/include/llvm/CodeGen/RegisterClassInfo.h
+++ b/llvm/include/llvm/CodeGen/RegisterClassInfo.h
@@ -123,7 +123,7 @@ class RegisterClassInfo {
   MCRegister getLastCalleeSavedAlias(MCRegister PhysReg) const {
     MCRegister CSR;
     for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
-      CSR = CalleeSavedAliases[Unit];
+      CSR = CalleeSavedAliases[static_cast<unsigned>(Unit)];
       if (CSR)
         break;
     }
diff --git a/llvm/include/llvm/CodeGen/RegisterPressure.h b/llvm/include/llvm/CodeGen/RegisterPressure.h
index 20a7e4fa2e9de..7485be6dcb351 100644
--- a/llvm/include/llvm/CodeGen/RegisterPressure.h
+++ b/llvm/include/llvm/CodeGen/RegisterPressure.h
@@ -282,14 +282,14 @@ class LiveRegSet {
   unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
     if (VRegOrUnit.isVirtualReg())
       return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
-    assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
-    return VRegOrUnit.asMCRegUnit();
+    assert(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()) < NumRegUnits);
+    return static_cast<unsigned>(VRegOrUnit.asMCRegUnit());
   }
 
   VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
     if (SparseIndex >= NumRegUnits)
       return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
-    return VirtRegOrUnit(SparseIndex);
+    return VirtRegOrUnit(static_cast<MCRegUnit>(SparseIndex));
   }
 
 public:
diff --git a/llvm/include/llvm/CodeGen/ScheduleDAGInstrs.h b/llvm/include/llvm/CodeGen/ScheduleDAGInstrs.h
index 059a3444c609c..8b3907629c00b 100644
--- a/llvm/include/llvm/CodeGen/ScheduleDAGInstrs.h
+++ b/llvm/include/llvm/CodeGen/ScheduleDAGInstrs.h
@@ -82,14 +82,16 @@ namespace llvm {
     PhysRegSUOper(SUnit *su, int op, MCRegUnit R)
         : SU(su), OpIdx(op), RegUnit(R) {}
 
-    unsigned getSparseSetIndex() const { return RegUnit; }
+    unsigned getSparseSetIndex() const {
+      return static_cast<unsigned>(RegUnit);
+    }
   };
 
   /// Use a SparseMultiSet to track physical registers. Storage is only
   /// allocated once for the pass. It can be cleared in constant time and reused
   /// without any frees.
   using RegUnit2SUnitsMap =
-      SparseMultiSet<PhysRegSUOper, unsigned, identity, uint16_t>;
+      SparseMultiSet<PhysRegSUOper, MCRegUnit, MCRegUnitToIndex, uint16_t>;
 
   /// Track local uses of virtual registers. These uses are gathered by the DAG
   /// builder and may be consulted by the scheduler to avoid iterating an entire
diff --git a/llvm/include/llvm/MC/MCRegister.h b/llvm/include/llvm/MC/MCRegister.h
index 388cb5958f32e..c6cde36478c1d 100644
--- a/llvm/include/llvm/MC/MCRegister.h
+++ b/llvm/include/llvm/MC/MCRegister.h
@@ -27,7 +27,15 @@ using MCPhysReg = uint16_t;
 /// A target with a complicated sub-register structure will typically have many
 /// fewer register units than actual registers. MCRI::getNumRegUnits() returns
 /// the number of register units in the target.
-using MCRegUnit = unsigned;
+enum class MCRegUnit : unsigned;
+
+struct MCRegUnitToIndex {
+  using argument_type = MCRegUnit;
+
+  unsigned operator()(MCRegUnit Unit) const {
+    return static_cast<unsigned>(Unit);
+  }
+};
 
 /// Wrapper class representing physical registers. Should be passed by value.
 class MCRegister {
diff --git a/llvm/include/llvm/MC/MCRegisterInfo.h b/llvm/include/llvm/MC/MCRegisterInfo.h
index f1caa077a6d7b..6e36e580358e7 100644
--- a/llvm/include/llvm/MC/MCRegisterInfo.h
+++ b/llvm/include/llvm/MC/MCRegisterInfo.h
@@ -724,9 +724,10 @@ class MCRegUnitRootIterator {
   MCRegUnitRootIterator() = default;
 
   MCRegUnitRootIterator(MCRegUnit RegUnit, const MCRegisterInfo *MCRI) {
-    assert(RegUnit < MCRI->getNumRegUnits() && "Invalid register unit");
-    Reg0 = MCRI->RegUnitRoots[RegUnit][0];
-    Reg1 = MCRI->RegUnitRoots[RegUnit][1];
+    assert(static_cast<unsigned>(RegUnit) < MCRI->getNumRegUnits() &&
+           "Invalid register unit");
+    Reg0 = MCRI->RegUnitRoots[static_cast<unsigned>(RegUnit)][0];
+    Reg1 = MCRI->RegUnitRoots[static_cast<unsigned>(RegUnit)][1];
   }
 
   /// Dereference to get the current root register.
@@ -803,7 +804,9 @@ MCRegisterInfo::sub_and_superregs_inclusive(MCRegister Reg) const {
 }
 
 inline iota_range<MCRegUnit> MCRegisterInfo::regunits() const {
-  return seq(getNumRegUnits());
+  return enum_seq(static_cast<MCRegUnit>(0),
+                  static_cast<MCRegUnit>(getNumRegUnits()),
+                  force_iteration_on_noniterable_enum);
 }
 
 inline iterator_range<MCRegUnitIterator>
diff --git a/llvm/lib/CodeGen/EarlyIfConversion.cpp b/llvm/lib/CodeGen/EarlyIfConversion.cpp
index 55caa6e8a8f95..28993c47c094d 100644
--- a/llvm/lib/CodeGen/EarlyIfConversion.cpp
+++ b/llvm/lib/CodeGen/EarlyIfConversion.cpp
@@ -134,7 +134,7 @@ class SSAIfConv {
   BitVector ClobberedRegUnits;
 
   // Scratch pad for findInsertionPoint.
-  SparseSet<MCRegUnit> LiveRegUnits;
+  SparseSet<MCRegUnit, MCRegUnit, MCRegUnitToIndex> LiveRegUnits;
 
   /// Insertion point in Head for speculatively executed instructions form TBB
   /// and FBB.
@@ -271,7 +271,7 @@ bool SSAIfConv::InstrDependenciesAllowIfConv(MachineInstr *I) {
     // Remember clobbered regunits.
     if (MO.isDef() && Reg.isPhysical())
       for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
-        ClobberedRegUnits.set(Unit);
+        ClobberedRegUnits.set(static_cast<unsigned>(Unit));
 
     if (!MO.readsReg() || !Reg.isVirtual())
       continue;
@@ -409,7 +409,7 @@ bool SSAIfConv::findInsertionPoint() {
     // Anything read by I is live before I.
     while (!Reads.empty())
       for (MCRegUnit Unit : TRI->regunits(Reads.pop_back_val()))
-        if (ClobberedRegUnits.test(Unit))
+        if (ClobberedRegUnits.test(static_cast<unsigned>(Unit)))
           LiveRegUnits.insert(Unit);
 
     // We can't insert before a terminator.
diff --git a/llvm/lib/CodeGen/InterferenceCache.cpp b/llvm/lib/CodeGen/InterferenceCache.cpp
index ebdf0506bb22f..466070b312b2d 100644
--- a/llvm/lib/CodeGen/InterferenceCache.cpp
+++ b/llvm/lib/CodeGen/InterferenceCache.cpp
@@ -93,7 +93,7 @@ void InterferenceCache::Entry::revalidate(LiveIntervalUnion *LIUArray,
   PrevPos = SlotIndex();
   unsigned i = 0;
   for (MCRegUnit Unit : TRI->regunits(PhysReg))
-    RegUnits[i++].VirtTag = LIUArray[Unit].getTag();
+    RegUnits[i++].VirtTag = LIUArray[static_cast<unsigned>(Unit)].getTag();
 }
 
 void InterferenceCache::Entry::reset(MCRegister physReg,
@@ -110,7 +110,7 @@ void InterferenceCache::Entry::reset(MCRegister physReg,
   PrevPos = SlotIndex();
   RegUnits.clear();
   for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
-    RegUnits.push_back(LIUArray[Unit]);
+    RegUnits.push_back(LIUArray[static_cast<unsigned>(Unit)]);
     RegUnits.back().Fixed = &LIS->getRegUnit(Unit);
   }
 }
@@ -121,7 +121,7 @@ bool InterferenceCache::Entry::valid(LiveIntervalUnion *LIUArray,
   for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
     if (i == e)
       return false;
-    if (LIUArray[Unit].changedSince(RegUnits[i].VirtTag))
+    if (LIUArray[static_cast<unsigned>(Unit)].changedSince(RegUnits[i].VirtTag))
       return false;
     ++i;
   }
diff --git a/llvm/lib/CodeGen/LiveIntervals.cpp b/llvm/lib/CodeGen/LiveIntervals.cpp
index b600e0411bc48..2e8756565c8f7 100644
--- a/llvm/lib/CodeGen/LiveIntervals.cpp
+++ b/llvm/lib/CodeGen/LiveIntervals.cpp
@@ -184,7 +184,8 @@ void LiveIntervals::print(raw_ostream &OS) const {
   // Dump the regunits.
   for (unsigned Unit = 0, UnitE = RegUnitRanges.size(); Unit != UnitE; ++Unit)
     if (LiveRange *LR = RegUnitRanges[Unit])
-      OS << printRegUnit(Unit, TRI) << ' ' << *LR << '\n';
+      OS << printRegUnit(static_cast<MCRegUnit>(Unit), TRI) << ' ' << *LR
+         << '\n';
 
   // Dump the virtregs.
   for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
@@ -367,10 +368,11 @@ void LiveIntervals::computeLiveInRegUnits() {
     LLVM_DEBUG(dbgs() << Begin << "\t" << printMBBReference(MBB));
     for (const auto &LI : MBB.liveins()) {
       for (MCRegUnit Unit : TRI->regunits(LI.PhysReg)) {
-        LiveRange *LR = RegUnitRanges[Unit];
+        LiveRange *LR = RegUnitRanges[static_cast<unsigned>(Unit)];
         if (!LR) {
           // Use segment set to speed-up initial computation of the live range.
-          LR = RegUnitRanges[Unit] = new LiveRange(UseSegmentSetForPhysRegs);
+          LR = RegUnitRanges[static_cast<unsigned>(Unit)] =
+              new LiveRange(UseSegmentSetForPhysRegs);
           NewRanges.push_back(Unit);
         }
         VNInfo *VNI = LR->createDeadDef(Begin, getVNInfoAllocator());
@@ -384,7 +386,7 @@ void LiveIntervals::computeLiveInRegUnits() {
 
   // Compute the 'normal' part of the ranges.
   for (MCRegUnit Unit : NewRanges)
-    computeRegUnitRange(*RegUnitRanges[Unit], Unit);
+    computeRegUnitRange(*RegUnitRanges[static_cast<unsigned>(Unit)], Unit);
 }
 
 static void createSegmentsForValues(LiveRange &LR,
diff --git a/llvm/lib/CodeGen/LiveRegMatrix.cpp b/llvm/lib/CodeGen/LiveRegMatrix.cpp
index e3ee8dc325933..e7238008d2c69 100644
--- a/llvm/lib/CodeGen/LiveRegMatrix.cpp
+++ b/llvm/lib/CodeGen/LiveRegMatrix.cpp
@@ -76,7 +76,7 @@ void LiveRegMatrixWrapperLegacy::releaseMemory() { LRM.releaseMemory(); }
 
 void LiveRegMatrix::releaseMemory() {
   for (unsigned i = 0, e = Matrix.size(); i != e; ++i) {
-    Matrix[i].clear();
+    Matrix[static_cast<MCRegUnit>(i)].clear();
     // No need to clear Queries here, since LiveIntervalUnion::Query doesn't
     // have anything important to clear and LiveRegMatrix's runOnFunction()
     // does a std::unique_ptr::reset anyways.
@@ -185,7 +185,7 @@ bool LiveRegMatrix::checkRegUnitInterference(const LiveInterval &VirtReg,
 
 LiveIntervalUnion::Query &LiveRegMatrix::query(const LiveRange &LR,
                                                MCRegUnit RegUnit) {
-  LiveIntervalUnion::Query &Q = Queries[RegUnit];
+  LiveIntervalUnion::Query &Q = Queries[static_cast<unsigned>(RegUnit)];
   Q.init(UserTag, LR, Matrix[RegUnit]);
   return Q;
 }
diff --git a/llvm/lib/CodeGen/LiveRegUnits.cpp b/llvm/lib/CodeGen/LiveRegUnits.cpp
index 3e7052a9b6245..348ccd85f4c45 100644
--- a/llvm/lib/CodeGen/LiveRegUnits.cpp
+++ b/llvm/lib/CodeGen/LiveRegUnits.cpp
@@ -23,7 +23,7 @@ void LiveRegUnits::removeRegsNotPreserved(const uint32_t *RegMask) {
   for (MCRegUnit U : TRI->regunits()) {
     for (MCRegUnitRootIterator RootReg(U, TRI); RootReg.isValid(); ++RootReg) {
       if (MachineOperand::clobbersPhysReg(RegMask, *RootReg)) {
-        Units.reset(U);
+        Units.reset(static_cast<unsigned>(U));
         break;
       }
     }
@@ -34,7 +34,7 @@ void LiveRegUnits::addRegsInMask(const uint32_t *RegMask) {
   for (MCRegUnit U : TRI->regunits()) {
     for (MCRegUnitRoot...
[truncated]

@jayfoad
Copy link
Contributor

jayfoad commented Nov 14, 2025

This changes MCRegUnit type from unsigned to enum class : unsigned

Just to be clear, it's a enum with no enumerators, right? We don't have names for the individual regunits.

@s-barannikov
Copy link
Contributor Author

s-barannikov commented Nov 14, 2025

This changes MCRegUnit type from unsigned to enum class : unsigned

Just to be clear, it's a enum with no enumerators, right? We don't have names for the individual regunits.

Right, that's what I meant by "opaque to users".

# Conflicts:
#	llvm/include/llvm/CodeGen/RDFRegisters.h
#	llvm/lib/CodeGen/RDFRegisters.cpp
#	llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
@s-barannikov s-barannikov merged commit 97a60aa into llvm:main Nov 16, 2025
10 checks passed
@s-barannikov s-barannikov deleted the regunit-enum-class branch November 16, 2025 17:46
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 16, 2025
This changes `MCRegUnit` type from `unsigned` to `enum class : unsigned`
and inserts necessary casts.
The added `MCRegUnitToIndex` functor is used with `SparseSet`,
`SparseMultiSet` and `IndexedMap` in a few places.

`MCRegUnit` is opaque to users, so it didn't seem worth making it a
full-fledged class like `Register`.

Static type checking has detected one issue in
`PrologueEpilogueInserter.cpp`, where `BitVector` created for
`MCRegister` is indexed by both `MCRegister` and `MCRegUnit`.

The number of casts could be reduced by using `IndexedMap` in more
places and/or adding a `BitVector` adaptor, but the number of casts *per
file* is still small and `IndexedMap` has limitations, so it didn't seem
worth the effort.

Pull Request: llvm/llvm-project#167943
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants