Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISel] Sink getOperandsMapping call out of the switch in getInstrMapping. #72326

Merged
merged 1 commit into from
Nov 15, 2023

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Nov 14, 2023

Use a SmallVector of ValueMapping * that we populate in the switch for each register operand. The entry in the SmallVector defaults to nullptr for each operand so we don't need to write explicit nullptr in the cases.

After this we can generically fill in GPR for registers as a default case and remove some opcodes from the switch.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 14, 2023

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

Use a SmallVector of ValueMapping * that we populate in the switch for each register operand. The entry in the SmallVector defaults to nullptr for each operand so we don't need to write explicit nullptr in the cases.

After this we can generically fill in GPR for registers as a default case and remove some opcodes from the switch.


Full diff: https://github.com/llvm/llvm-project/pull/72326.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp (+55-54)
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index cb1da8ff11c08cb..be9c735c875466b 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -181,12 +181,8 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   const ValueMapping *GPRValueMapping =
       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
                                           : RISCV::GPRB32Idx];
-  const ValueMapping *OperandsMapping = GPRValueMapping;
 
   switch (Opc) {
-  case TargetOpcode::G_INVOKE_REGION_START:
-    OperandsMapping = getOperandsMapping({});
-    break;
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_SUB:
   case TargetOpcode::G_SHL:
@@ -215,14 +211,37 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_ZEXT:
   case TargetOpcode::G_SEXTLOAD:
   case TargetOpcode::G_ZEXTLOAD:
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
+                                 NumOperands);
+  case TargetOpcode::G_FADD:
+  case TargetOpcode::G_FSUB:
+  case TargetOpcode::G_FMUL:
+  case TargetOpcode::G_FDIV:
+  case TargetOpcode::G_FABS:
+  case TargetOpcode::G_FNEG:
+  case TargetOpcode::G_FSQRT:
+  case TargetOpcode::G_FMAXNUM:
+  case TargetOpcode::G_FMINNUM: {
+    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                                 getFPValueMapping(Ty.getSizeInBits()),
+                                 NumOperands);
+  }
+  }
+
+  SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
+
+  switch (Opc) {
+  case TargetOpcode::G_INVOKE_REGION_START:
     break;
   case TargetOpcode::G_LOAD: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 loads on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
@@ -237,26 +256,25 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
                  // instruction.
                  return onlyUsesFP(UseMI);
                })) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
 
     break;
   }
   case TargetOpcode::G_STORE: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 stores on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
     if (onlyDefinesFP(*DefMI)) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
     break;
   }
@@ -265,76 +283,60 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_GLOBAL_VALUE:
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_BRCOND:
-    OperandsMapping = getOperandsMapping({GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
     break;
   case TargetOpcode::G_BR:
-    OperandsMapping = getOperandsMapping({nullptr});
     break;
   case TargetOpcode::G_BRJT:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, nullptr, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
     break;
   case TargetOpcode::G_ICMP:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
   case TargetOpcode::G_SEXT_INREG:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     break;
   case TargetOpcode::G_SELECT:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, GPRValueMapping, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
-  case TargetOpcode::G_FADD:
-  case TargetOpcode::G_FSUB:
-  case TargetOpcode::G_FMUL:
-  case TargetOpcode::G_FDIV:
-  case TargetOpcode::G_FNEG:
-  case TargetOpcode::G_FABS:
-  case TargetOpcode::G_FSQRT:
-  case TargetOpcode::G_FMAXNUM:
-  case TargetOpcode::G_FMINNUM: {
-    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getFPValueMapping(Ty.getSizeInBits());
-    break;
-  }
   case TargetOpcode::G_FMA: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    const RegisterBankInfo::ValueMapping *FPValueMapping =
-        getFPValueMapping(Ty.getSizeInBits());
-    OperandsMapping = getOperandsMapping(
-        {FPValueMapping, FPValueMapping, FPValueMapping, FPValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = OpdsMapping[2] = OpdsMapping[3] = OpdsMapping[0];
     break;
   }
   case TargetOpcode::G_FPEXT:
   case TargetOpcode::G_FPTRUNC: {
     LLT ToTy = MRI.getType(MI.getOperand(0).getReg());
     LLT FromTy = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(ToTy.getSizeInBits()),
-                            getFPValueMapping(FromTy.getSizeInBits())});
+    OpdsMapping[0] = getFPValueMapping(ToTy.getSizeInBits());
+    OpdsMapping[1] = getFPValueMapping(FromTy.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FPTOSI:
   case TargetOpcode::G_FPTOUI: {
     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping,
-                            getFPValueMapping(Ty.getSizeInBits())});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_SITOFP:
   case TargetOpcode::G_UITOFP: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getOperandsMapping(
-        {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = GPRValueMapping;
     break;
   }
   case TargetOpcode::G_FCONSTANT: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(Ty.getSizeInBits()), nullptr});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FCMP: {
@@ -343,15 +345,14 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     unsigned Size = Ty.getSizeInBits();
     assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
 
-    auto *FPRValueMapping = getFPValueMapping(Size);
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, FPRValueMapping, FPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
     break;
   }
   default:
     return getInvalidInstructionMapping();
   }
 
-  return getInstructionMapping(DefaultMappingID, /*Cost=*/1, OperandsMapping,
-                               NumOperands);
+  return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                               getOperandsMapping(OpdsMapping), NumOperands);
 }

Copy link
Member

@mshockwave mshockwave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/ minor suggestions.

llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp Outdated Show resolved Hide resolved
@topperc topperc force-pushed the pr/regbank branch 2 times, most recently from 742aa46 to 55a05d8 Compare November 15, 2023 19:21
…nstrMapping.

Use a SmallVector of ValueMapping * that we populate in the switch for
each register operand. The entry in the SmallVector defaults to nullptr
foreach operand so we don't need to do explicit nullptr in the switch.

Afer this we can generically fill in GPR for registers as a default case and
remove some opcodes from the switch.
@topperc topperc merged commit 0b8379b into llvm:main Nov 15, 2023
2 of 3 checks passed
@topperc topperc deleted the pr/regbank branch November 18, 2023 19:02
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…nstrMapping. (llvm#72326)

Use a SmallVector of `ValueMapping *` that we populate in the switch for
each register operand. The entry in the SmallVector defaults to nullptr
for each operand so we don't need to write explicit `nullptr` in the
cases.

After this we can generically fill in GPR for registers as a default
case and remove some opcodes from the switch.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants