Skip to content

Commit

Permalink
[TableGen][GlobalISel] Account for HwMode in RegisterBank register sizes
Browse files Browse the repository at this point in the history
This patch adds logic for determining RegisterBank size to RegisterBankInfo, which allows accounting for the HwMode of the target. Individual RegisterBanks cannot be constructed with HwMode information as construction is generated by TableGen, but a RegisterBankInfo subclass can provide the HwMode as a constructor argument. The HwMode is used to select the appropriate RegisterBank size from an array relating sizes to RegisterBanks.

Targets simply need to provide the HwMode argument to the <target>GenRegisterBankInfo constructor. The RISC-V RegisterBankInfo constructor has been updated accordingly (plus an unused argument removed).

Reviewed By: simoncook, craig.topper

Differential Revision: https://reviews.llvm.org/D76007
  • Loading branch information
nitinjohnraj committed Jun 3, 2023
1 parent e501ed8 commit aa7eace
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 64 deletions.
10 changes: 3 additions & 7 deletions llvm/include/llvm/CodeGen/RegisterBank.h
Expand Up @@ -29,7 +29,6 @@ class RegisterBank {
private:
unsigned ID;
const char *Name;
unsigned Size;
BitVector ContainedRegClasses;

/// Sentinel value used to recognize register bank not properly
Expand All @@ -40,8 +39,8 @@ class RegisterBank {
friend RegisterBankInfo;

public:
RegisterBank(unsigned ID, const char *Name, unsigned Size,
const uint32_t *CoveredClasses, unsigned NumRegClasses);
RegisterBank(unsigned ID, const char *Name, const uint32_t *CoveredClasses,
unsigned NumRegClasses);

/// Get the identifier of this register bank.
unsigned getID() const { return ID; }
Expand All @@ -50,9 +49,6 @@ class RegisterBank {
/// Should be used only for debugging purposes.
const char *getName() const { return Name; }

/// Get the maximal size in bits that fits in this register bank.
unsigned getSize() const { return Size; }

/// Check whether this instance is ready to be used.
bool isValid() const;

Expand All @@ -62,7 +58,7 @@ class RegisterBank {
/// \note This method does not check anything when assertions are disabled.
///
/// \return True is the check was successful.
bool verify(const TargetRegisterInfo &TRI) const;
bool verify(const RegisterBankInfo &RBI, const TargetRegisterInfo &TRI) const;

/// Check whether this register bank covers \p RC.
/// In other words, check if this register bank fully covers
Expand Down
24 changes: 18 additions & 6 deletions llvm/include/llvm/CodeGen/RegisterBankInfo.h
Expand Up @@ -20,6 +20,7 @@
#include "llvm/ADT/iterator_range.h"
#include "llvm/CodeGen/LowLevelType.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/RegisterBank.h"
#include "llvm/Support/ErrorHandling.h"
#include <cassert>
#include <initializer_list>
Expand All @@ -30,7 +31,6 @@ namespace llvm {
class MachineInstr;
class MachineRegisterInfo;
class raw_ostream;
class RegisterBank;
class TargetInstrInfo;
class TargetRegisterClass;
class TargetRegisterInfo;
Expand Down Expand Up @@ -83,7 +83,7 @@ class RegisterBankInfo {
/// \note This method does not check anything when assertions are disabled.
///
/// \return True is the check was successful.
bool verify() const;
bool verify(const RegisterBankInfo &RBI) const;
};

/// Helper struct that represents how a value is mapped through
Expand Down Expand Up @@ -175,7 +175,7 @@ class RegisterBankInfo {
/// \note This method does not check anything when assertions are disabled.
///
/// \return True is the check was successful.
bool verify(unsigned MeaningfulBitWidth) const;
bool verify(const RegisterBankInfo &RBI, unsigned MeaningfulBitWidth) const;

/// Print this on dbgs() stream.
void dump() const;
Expand Down Expand Up @@ -384,11 +384,17 @@ class RegisterBankInfo {

protected:
/// Hold the set of supported register banks.
RegisterBank **RegBanks;
const RegisterBank **RegBanks;

/// Total number of register banks.
unsigned NumRegBanks;

/// Hold the sizes of the register banks for all HwModes.
const unsigned *Sizes;

/// Current HwMode for the target.
unsigned HwMode;

/// Keep dynamically allocated PartialMapping in a separate map.
/// This shouldn't be needed when everything gets TableGen'ed.
mutable DenseMap<unsigned, std::unique_ptr<const PartialMapping>>
Expand All @@ -415,7 +421,8 @@ class RegisterBankInfo {

/// Create a RegisterBankInfo that can accommodate up to \p NumRegBanks
/// RegisterBank instances.
RegisterBankInfo(RegisterBank **RegBanks, unsigned NumRegBanks);
RegisterBankInfo(const RegisterBank **RegBanks, unsigned NumRegBanks,
const unsigned *Sizes, unsigned HwMode);

/// This constructor is meaningless.
/// It just provides a default constructor that can be used at link time
Expand All @@ -428,7 +435,7 @@ class RegisterBankInfo {
}

/// Get the register bank identified by \p ID.
RegisterBank &getRegBank(unsigned ID) {
const RegisterBank &getRegBank(unsigned ID) {
assert(ID < getNumRegBanks() && "Accessing an unknown register bank");
return *RegBanks[ID];
}
Expand Down Expand Up @@ -576,6 +583,11 @@ class RegisterBankInfo {
return const_cast<RegisterBankInfo *>(this)->getRegBank(ID);
}

/// Get the maximum size in bits that fits in the given register bank.
unsigned getMaximumSize(unsigned RegBankID) const {
return Sizes[RegBankID + HwMode * NumRegBanks];
}

/// Get the register bank of \p Reg.
/// If Reg has not been assigned a register, a register class,
/// or a register bank, then this returns nullptr.
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/MachineVerifier.cpp
Expand Up @@ -2174,6 +2174,7 @@ MachineVerifier::visitMachineOperand(const MachineOperand *MO, unsigned MONum) {
}

const RegisterBank *RegBank = MRI->getRegBankOrNull(Reg);
const RegisterBankInfo *RBI = MF->getSubtarget().getRegBankInfo();

// If we're post-RegBankSelect, the gvreg must have a bank.
if (!RegBank && isFunctionRegBankSelected) {
Expand All @@ -2185,12 +2186,12 @@ MachineVerifier::visitMachineOperand(const MachineOperand *MO, unsigned MONum) {

// Make sure the register fits into its register bank if any.
if (RegBank && Ty.isValid() &&
RegBank->getSize() < Ty.getSizeInBits()) {
RBI->getMaximumSize(RegBank->getID()) < Ty.getSizeInBits()) {
report("Register bank is too small for virtual register", MO,
MONum);
errs() << "Register bank " << RegBank->getName() << " too small("
<< RegBank->getSize() << ") to fit " << Ty.getSizeInBits()
<< "-bits\n";
<< RBI->getMaximumSize(RegBank->getID()) << ") to fit "
<< Ty.getSizeInBits() << "-bits\n";
return;
}
}
Expand Down
18 changes: 10 additions & 8 deletions llvm/lib/CodeGen/RegisterBank.cpp
Expand Up @@ -11,6 +11,7 @@

#include "llvm/CodeGen/RegisterBank.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Debug.h"
Expand All @@ -21,15 +22,16 @@ using namespace llvm;

const unsigned RegisterBank::InvalidID = UINT_MAX;

RegisterBank::RegisterBank(
unsigned ID, const char *Name, unsigned Size,
const uint32_t *CoveredClasses, unsigned NumRegClasses)
: ID(ID), Name(Name), Size(Size) {
RegisterBank::RegisterBank(unsigned ID, const char *Name,
const uint32_t *CoveredClasses,
unsigned NumRegClasses)
: ID(ID), Name(Name) {
ContainedRegClasses.resize(NumRegClasses);
ContainedRegClasses.setBitsInMask(CoveredClasses);
}

bool RegisterBank::verify(const TargetRegisterInfo &TRI) const {
bool RegisterBank::verify(const RegisterBankInfo &RBI,
const TargetRegisterInfo &TRI) const {
assert(isValid() && "Invalid register bank");
for (unsigned RCId = 0, End = TRI.getNumRegClasses(); RCId != End; ++RCId) {
const TargetRegisterClass &RC = *TRI.getRegClass(RCId);
Expand All @@ -50,7 +52,7 @@ bool RegisterBank::verify(const TargetRegisterInfo &TRI) const {

// Verify that the Size of the register bank is big enough to cover
// all the register classes it covers.
assert(getSize() >= TRI.getRegSizeInBits(SubRC) &&
assert(RBI.getMaximumSize(getID()) >= TRI.getRegSizeInBits(SubRC) &&
"Size is not big enough for all the subclasses!");
assert(covers(SubRC) && "Not all subclasses are covered");
}
Expand All @@ -64,7 +66,7 @@ bool RegisterBank::covers(const TargetRegisterClass &RC) const {
}

bool RegisterBank::isValid() const {
return ID != InvalidID && Name != nullptr && Size != 0 &&
return ID != InvalidID && Name != nullptr &&
// A register bank that does not cover anything is useless.
!ContainedRegClasses.empty();
}
Expand All @@ -89,7 +91,7 @@ void RegisterBank::print(raw_ostream &OS, bool IsForDebug,
OS << getName();
if (!IsForDebug)
return;
OS << "(ID:" << getID() << ", Size:" << getSize() << ")\n"
OS << "(ID:" << getID() << ")\n"
<< "isValid:" << isValid() << '\n'
<< "Number of Covered register classes: " << ContainedRegClasses.count()
<< '\n';
Expand Down
26 changes: 16 additions & 10 deletions llvm/lib/CodeGen/RegisterBankInfo.cpp
Expand Up @@ -52,9 +52,11 @@ const unsigned RegisterBankInfo::InvalidMappingID = UINT_MAX - 1;
//------------------------------------------------------------------------------
// RegisterBankInfo implementation.
//------------------------------------------------------------------------------
RegisterBankInfo::RegisterBankInfo(RegisterBank **RegBanks,
unsigned NumRegBanks)
: RegBanks(RegBanks), NumRegBanks(NumRegBanks) {
RegisterBankInfo::RegisterBankInfo(const RegisterBank **RegBanks,
unsigned NumRegBanks, const unsigned *Sizes,
unsigned HwMode)
: RegBanks(RegBanks), NumRegBanks(NumRegBanks), Sizes(Sizes),
HwMode(HwMode) {
#ifndef NDEBUG
for (unsigned Idx = 0, End = getNumRegBanks(); Idx != End; ++Idx) {
assert(RegBanks[Idx] != nullptr && "Invalid RegisterBank");
Expand All @@ -70,7 +72,7 @@ bool RegisterBankInfo::verify(const TargetRegisterInfo &TRI) const {
assert(Idx == RegBank.getID() &&
"ID does not match the index in the array");
LLVM_DEBUG(dbgs() << "Verify " << RegBank << '\n');
assert(RegBank.verify(TRI) && "RegBank is invalid");
assert(RegBank.verify(*this, TRI) && "RegBank is invalid");
}
#endif // NDEBUG
return true;
Expand Down Expand Up @@ -516,12 +518,14 @@ LLVM_DUMP_METHOD void RegisterBankInfo::PartialMapping::dump() const {
}
#endif

bool RegisterBankInfo::PartialMapping::verify() const {
bool RegisterBankInfo::PartialMapping::verify(
const RegisterBankInfo &RBI) const {
assert(RegBank && "Register bank not set");
assert(Length && "Empty mapping");
assert((StartIdx <= getHighBitIdx()) && "Overflow, switch to APInt?");
// Check if the minimum width fits into RegBank.
assert(RegBank->getSize() >= Length && "Register bank too small for Mask");
assert(RBI.getMaximumSize(RegBank->getID()) >= Length &&
"Register bank too small for Mask");
return true;
}

Expand All @@ -546,13 +550,14 @@ bool RegisterBankInfo::ValueMapping::partsAllUniform() const {
return true;
}

bool RegisterBankInfo::ValueMapping::verify(unsigned MeaningfulBitWidth) const {
bool RegisterBankInfo::ValueMapping::verify(const RegisterBankInfo &RBI,
unsigned MeaningfulBitWidth) const {
assert(NumBreakDowns && "Value mapped nowhere?!");
unsigned OrigValueBitWidth = 0;
for (const RegisterBankInfo::PartialMapping &PartMap : *this) {
// Check that each register bank is big enough to hold the partial value:
// this check is done by PartialMapping::verify
assert(PartMap.verify() && "Partial mapping is invalid");
assert(PartMap.verify(RBI) && "Partial mapping is invalid");
// The original value should completely be mapped.
// Thus the maximum accessed index + 1 is the size of the original value.
OrigValueBitWidth =
Expand Down Expand Up @@ -626,8 +631,9 @@ bool RegisterBankInfo::InstructionMapping::verify(
(void)MOMapping;
// Register size in bits.
// This size must match what the mapping expects.
assert(MOMapping.verify(RBI->getSizeInBits(
Reg, MF.getRegInfo(), *MF.getSubtarget().getRegisterInfo())) &&
assert(MOMapping.verify(*RBI, RBI->getSizeInBits(
Reg, MF.getRegInfo(),
*MF.getSubtarget().getRegisterInfo())) &&
"Value mapping is invalid");
}
return true;
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
Expand Up @@ -71,20 +71,22 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
// GR64all + its subclasses.
assert(RBGPR.covers(*TRI.getRegClass(AArch64::GPR32RegClassID)) &&
"Subclass not added?");
assert(RBGPR.getSize() == 128 && "GPRs should hold up to 128-bit");
assert(getMaximumSize(RBGPR.getID()) == 128 &&
"GPRs should hold up to 128-bit");

// The FPR register bank is fully defined by all the registers in
// GR64all + its subclasses.
assert(RBFPR.covers(*TRI.getRegClass(AArch64::QQRegClassID)) &&
"Subclass not added?");
assert(RBFPR.covers(*TRI.getRegClass(AArch64::FPR64RegClassID)) &&
"Subclass not added?");
assert(RBFPR.getSize() == 512 &&
assert(getMaximumSize(RBFPR.getID()) == 512 &&
"FPRs should hold up to 512-bit via QQQQ sequence");

assert(RBCCR.covers(*TRI.getRegClass(AArch64::CCRRegClassID)) &&
"Class not added?");
assert(RBCCR.getSize() == 32 && "CCR should hold up to 32-bit");
assert(getMaximumSize(RBCCR.getID()) == 32 &&
"CCR should hold up to 32-bit");

// Check that the TableGen'ed like file is in sync we our expectations.
// First, the Idx.
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
Expand Up @@ -162,7 +162,8 @@ ARMRegisterBankInfo::ARMRegisterBankInfo(const TargetRegisterInfo &TRI) {
"Subclass not added?");
assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) &&
"Subclass not added?");
assert(RBGPR.getSize() == 32 && "GPRs should hold up to 32-bit");
assert(getMaximumSize(RBGPR.getID()) == 32 &&
"GPRs should hold up to 32-bit");

#ifndef NDEBUG
ARM::checkPartialMappings();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
Expand Up @@ -22,4 +22,5 @@

using namespace llvm;

RISCVRegisterBankInfo::RISCVRegisterBankInfo(const TargetRegisterInfo &TRI) {}
RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
: RISCVGenRegisterBankInfo(HwMode) {}
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
Expand Up @@ -31,7 +31,7 @@ class RISCVGenRegisterBankInfo : public RegisterBankInfo {
/// This class provides the information for the target register banks.
class RISCVRegisterBankInfo final : public RISCVGenRegisterBankInfo {
public:
RISCVRegisterBankInfo(const TargetRegisterInfo &TRI);
RISCVRegisterBankInfo(unsigned HwMode);
};
} // end namespace llvm
#endif
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVSubtarget.cpp
Expand Up @@ -86,7 +86,7 @@ RISCVSubtarget::RISCVSubtarget(const Triple &TT, StringRef CPU,
CallLoweringInfo.reset(new RISCVCallLowering(*getTargetLowering()));
Legalizer.reset(new RISCVLegalizerInfo(*this));

auto *RBI = new RISCVRegisterBankInfo(*getRegisterInfo());
auto *RBI = new RISCVRegisterBankInfo(getHwMode());
RegBankInfo.reset(RBI);
InstSelector.reset(createRISCVInstructionSelector(
*static_cast<const RISCVTargetMachine *>(&TM), *this, *RBI));
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86RegisterBankInfo.cpp
Expand Up @@ -36,7 +36,8 @@ X86RegisterBankInfo::X86RegisterBankInfo(const TargetRegisterInfo &TRI) {
// GR64 + its subclasses.
assert(RBGPR.covers(*TRI.getRegClass(X86::GR64RegClassID)) &&
"Subclass not added?");
assert(RBGPR.getSize() == 64 && "GPRs should hold up to 64-bit");
assert(getMaximumSize(RBGPR.getID()) == 64 &&
"GPRs should hold up to 64-bit");
}

const RegisterBank &
Expand Down

0 comments on commit aa7eace

Please sign in to comment.