Skip to content

Commit

Permalink
[RISCV][GISel] RegBank select and instruction select for vector G_ADD…
Browse files Browse the repository at this point in the history
…, G_SUB (#74114)

RegisterBank Selection for scalable vector G_ADD and G_SUB by creating
new mappings for different types of vector register banks.
Then implement Instruction Selection for the same operations by choosing
the correct RISC-V vector register class.
  • Loading branch information
jiahanxie353 committed Feb 1, 2024
1 parent 41be541 commit 10c2d5f
Show file tree
Hide file tree
Showing 7 changed files with 3,043 additions and 3 deletions.
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
}

const LLT Ty = MRI.getType(VReg);
if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
if (Ty.isValid() &&
TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
reportGISelFailure(
MF, TPC, MORE, "gisel-select",
"VReg's low-level type and register class have different sizes", *MI);
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,20 @@ const TargetRegisterClass *RISCVInstructionSelector::getRegClassForTypeOnBank(
return &RISCV::FPR64RegClass;
}

// TODO: Non-GPR register classes.
if (RB.getID() == RISCV::VRBRegBankID) {
if (Ty.getSizeInBits().getKnownMinValue() <= 64)
return &RISCV::VRRegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 128)
return &RISCV::VRM2RegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 256)
return &RISCV::VRM4RegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 512)
return &RISCV::VRM8RegClass;
}

return nullptr;
}

Expand Down
58 changes: 57 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,27 @@ namespace llvm {
namespace RISCV {

const RegisterBankInfo::PartialMapping PartMappings[] = {
// clang-format off
{0, 32, GPRBRegBank},
{0, 64, GPRBRegBank},
{0, 32, FPRBRegBank},
{0, 64, FPRBRegBank},
{0, 64, VRBRegBank},
{0, 128, VRBRegBank},
{0, 256, VRBRegBank},
{0, 512, VRBRegBank},
// clang-format on
};

enum PartialMappingIdx {
PMI_GPRB32 = 0,
PMI_GPRB64 = 1,
PMI_FPRB32 = 2,
PMI_FPRB64 = 3,
PMI_VRB64 = 4,
PMI_VRB128 = 5,
PMI_VRB256 = 6,
PMI_VRB512 = 7,
};

const RegisterBankInfo::ValueMapping ValueMappings[] = {
Expand All @@ -57,6 +67,22 @@ const RegisterBankInfo::ValueMapping ValueMappings[] = {
{&PartMappings[PMI_FPRB64], 1},
{&PartMappings[PMI_FPRB64], 1},
{&PartMappings[PMI_FPRB64], 1},
// Maximum 3 VR LMUL={1, MF2, MF4, MF8} operands.
{&PartMappings[PMI_VRB64], 1},
{&PartMappings[PMI_VRB64], 1},
{&PartMappings[PMI_VRB64], 1},
// Maximum 3 VR LMUL=2 operands.
{&PartMappings[PMI_VRB128], 1},
{&PartMappings[PMI_VRB128], 1},
{&PartMappings[PMI_VRB128], 1},
// Maximum 3 VR LMUL=4 operands.
{&PartMappings[PMI_VRB256], 1},
{&PartMappings[PMI_VRB256], 1},
{&PartMappings[PMI_VRB256], 1},
// Maximum 3 VR LMUL=8 operands.
{&PartMappings[PMI_VRB512], 1},
{&PartMappings[PMI_VRB512], 1},
{&PartMappings[PMI_VRB512], 1},
};

enum ValueMappingIdx {
Expand All @@ -65,6 +91,10 @@ enum ValueMappingIdx {
GPRB64Idx = 4,
FPRB32Idx = 7,
FPRB64Idx = 10,
VRB64Idx = 13,
VRB128Idx = 16,
VRB256Idx = 19,
VRB512Idx = 22,
};
} // namespace RISCV
} // namespace llvm
Expand Down Expand Up @@ -215,6 +245,23 @@ bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
[&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
}

static const RegisterBankInfo::ValueMapping *getVRBValueMapping(unsigned Size) {
unsigned Idx;

if (Size <= 64)
Idx = RISCV::VRB64Idx;
else if (Size == 128)
Idx = RISCV::VRB128Idx;
else if (Size == 256)
Idx = RISCV::VRB256Idx;
else if (Size == 512)
Idx = RISCV::VRB512Idx;
else
llvm::report_fatal_error("Invalid Size");

return &RISCV::ValueMappings[Idx];
}

const RegisterBankInfo::InstructionMapping &
RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
const unsigned Opc = MI.getOpcode();
Expand Down Expand Up @@ -242,7 +289,16 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {

switch (Opc) {
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_SUB: {
if (MRI.getType(MI.getOperand(0).getReg()).isVector()) {
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
return getInstructionMapping(
DefaultMappingID, /*Cost=*/1,
getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue()),
NumOperands);
}
}
LLVM_FALLTHROUGH;
case TargetOpcode::G_SHL:
case TargetOpcode::G_ASHR:
case TargetOpcode::G_LSHR:
Expand Down
Loading

0 comments on commit 10c2d5f

Please sign in to comment.