
[RISCV][GISel] Add ISel support for SHXADD_UW and SLLI.UW #69972

Merged: 2 commits into llvm:main from gisel-riscv-zba on Oct 24, 2023

Conversation

mshockwave (Member)

This patch also includes:

  • Remove legacy non_imm12 PatLeaf from RISCVInstrInfoZb.td
  • Implement a custom GlobalISel operand renderer for TrailingZeros SDNodeXForm

llvmbot (Collaborator) commented Oct 23, 2023

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-globalisel

Author: Min-Yih Hsu (mshockwave)

Changes

This patch also includes:

  • Remove legacy non_imm12 PatLeaf from RISCVInstrInfoZb.td
  • Implement a custom GlobalISel operand renderer for TrailingZeros SDNodeXForm

Full diff: https://github.com/llvm/llvm-project/pull/69972.diff

5 Files Affected:

  • (modified) llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp (+60)
  • (modified) llvm/lib/Target/RISCV/RISCVGISel.td (+10)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td (+3-3)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZb.td (+10-19)
  • (modified) llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir (+75)
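
For reference, the Zba instructions this patch adds GlobalISel selection support for compute the following. This is a minimal standalone C++ sketch of the instruction semantics as described in the Zba spec, not code from this patch:

  #include <cstdint>

  // shNadd.uw rd, rs1, rs2  (N in {1,2,3}):
  //   rd = rs2 + (zero-extended low 32 bits of rs1, shifted left by N)
  uint64_t shNadd_uw(uint64_t rs1, uint64_t rs2, unsigned n) {
    return rs2 + ((rs1 & 0xFFFFFFFFULL) << n);
  }

  // slli.uw rd, rs1, imm:
  //   rd = zero-extended low 32 bits of rs1, shifted left by imm
  uint64_t slli_uw(uint64_t rs1, unsigned imm) {
    return (rs1 & 0xFFFFFFFFULL) << imm;
  }
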
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index a7f18c04a190790..0838d58220adb27 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -76,6 +76,13 @@ class RISCVInstructionSelector : public InstructionSelector {
     return selectSHXADDOp(Root, ShAmt);
   }
 
+  ComplexRendererFns selectSHXADD_UWOp(MachineOperand &Root,
+                                       unsigned ShAmt) const;
+  template <unsigned ShAmt>
+  ComplexRendererFns selectSHXADD_UWOp(MachineOperand &Root) const {
+    return selectSHXADD_UWOp(Root, ShAmt);
+  }
+
   // Custom renderers for tablegen
   void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
                     int OpIdx) const;
@@ -91,6 +98,9 @@ class RISCVInstructionSelector : public InstructionSelector {
                                 MachineRegisterInfo &MRI, RISCVCC::CondCode &CC,
                                 Register &LHS, Register &RHS) const;
 
+  void renderTrailingZeros(MachineInstrBuilder &MIB, const MachineInstr &MI,
+                           int OpIdx) const;
+
   const RISCVSubtarget &STI;
   const RISCVInstrInfo &TII;
   const RISCVRegisterInfo &TRI;
@@ -239,6 +249,47 @@ RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
   return std::nullopt;
 }
 
+InstructionSelector::ComplexRendererFns
+RISCVInstructionSelector::selectSHXADD_UWOp(MachineOperand &Root,
+                                            unsigned ShAmt) const {
+  using namespace llvm::MIPatternMatch;
+  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  if (!Root.isReg())
+    return std::nullopt;
+  Register RootReg = Root.getReg();
+
+  // Given (and (shl x, c2), mask) in which mask is a shifted mask with
+  // 32 - ShAmt leading zeros and c2 trailing zeros. We can use SLLI by
+  // c2 - ShAmt followed by SHXADD_UW with ShAmt for x amount.
+  APInt Mask, C2;
+  Register RegX;
+  if (mi_match(
+          RootReg, MRI,
+          m_OneNonDBGUse(m_GAnd(m_OneNonDBGUse(m_GShl(m_Reg(RegX), m_ICst(C2))),
+                                m_ICst(Mask))))) {
+    Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
+
+    if (Mask.isShiftedMask()) {
+      unsigned Leading = Mask.countl_zero();
+      unsigned Trailing = Mask.countr_zero();
+      if (Leading == 32 - ShAmt && C2 == Trailing && Trailing > ShAmt) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SLLI, {DstReg}, {RegX})
+              .addImm(C2.getLimitedValue() - ShAmt);
+          MIB.addReg(DstReg);
+        }}};
+      }
+    }
+  }
+
+  return std::nullopt;
+}
+
 InstructionSelector::ComplexRendererFns
 RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
   // TODO: Need to get the immediate from a G_PTR_ADD. Should this be done in
@@ -383,6 +434,15 @@ void RISCVInstructionSelector::renderImm(MachineInstrBuilder &MIB,
   MIB.addImm(CstVal);
 }
 
+void RISCVInstructionSelector::renderTrailingZeros(MachineInstrBuilder &MIB,
+                                                   const MachineInstr &MI,
+                                                   int OpIdx) const {
+  assert(MI.getOpcode() == TargetOpcode::G_CONSTANT && OpIdx == -1 &&
+         "Expected G_CONSTANT");
+  uint64_t C = MI.getOperand(1).getCImm()->getZExtValue();
+  MIB.addImm(llvm::countr_zero(C));
+}
+
 const TargetRegisterClass *RISCVInstructionSelector::getRegClassForTypeOnBank(
     LLT Ty, const RegisterBank &RB) const {
   if (RB.getID() == RISCV::GPRRegBankID) {
diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 8d0d088c1116238..12484384c18dd04 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -57,6 +57,9 @@ def as_i64imm : SDNodeXForm<imm, [{
 def gi_as_i64imm : GICustomOperandRenderer<"renderImm">,
   GISDNodeXFormEquiv<as_i64imm>;
 
+def gi_trailing_zero : GICustomOperandRenderer<"renderTrailingZeros">,
+  GISDNodeXFormEquiv<TrailingZeros>;
+
 // FIXME: This is labelled as handling 's32', however the ComplexPattern it
 // refers to handles both i32 and i64 based on the HwMode. Currently this LLT
 // parameter appears to be ignored so this pattern works for both, however we
@@ -73,6 +76,13 @@ def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
 def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
                    GIComplexPatternEquiv<sh3add_op>;
 
+def gi_sh1add_uw_op : GIComplexOperandMatcher<s32, "selectSHXADD_UWOp<1>">,
+                      GIComplexPatternEquiv<sh1add_uw_op>;
+def gi_sh2add_uw_op : GIComplexOperandMatcher<s32, "selectSHXADD_UWOp<2>">,
+                      GIComplexPatternEquiv<sh2add_uw_op>;
+def gi_sh3add_uw_op : GIComplexOperandMatcher<s32, "selectSHXADD_UWOp<3>">,
+                      GIComplexPatternEquiv<sh3add_uw_op>;
+
 // FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
 def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
           (ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index d20ed70e1a5290e..41e139e3c7a9ebe 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -540,11 +540,11 @@ def : Pat<(add (XLenVT GPR:$rs1), (shl GPR:$rs2, uimm2:$uimm2)),
           (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>;
 
 // Reuse complex patterns from StdExtZba
-def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add_non_imm12 sh1add_op:$rs1, (XLenVT GPR:$rs2)),
           (TH_ADDSL GPR:$rs2, sh1add_op:$rs1, 1)>;
-def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add_non_imm12 sh2add_op:$rs1, (XLenVT GPR:$rs2)),
           (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>;
-def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add_non_imm12 sh3add_op:$rs1, (XLenVT GPR:$rs2)),
           (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
 
 def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index a7572e908b56b89..4a62a61dadcf3bb 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -230,16 +230,9 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
                                    N->getValueType(0));
 }]>;
 
-// Pattern to exclude simm12 immediates from matching.
-// Note: this will be removed once the GISel complex patterns for
-// SHXADD_UW is landed.
-def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
-  auto *C = dyn_cast<ConstantSDNode>(N);
-  return !C || !isInt<12>(C->getSExtValue());
-}]>;
-
+// Pattern to exclude simm12 immediates from matching, namely `non_imm12`.
 // GISel currently doesn't support PatFrag for leaf nodes, so `non_imm12`
-// cannot be directly supported in GISel. To reuse patterns between the two
+// cannot be implemented in that way. To reuse patterns between the two
 // ISels, we instead create PatFrag on operators that use `non_imm12`.
 class binop_with_non_imm12<SDPatternOperator binop>
   : PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
@@ -264,12 +257,11 @@ class binop_with_non_imm12<SDPatternOperator binop>
 def add_non_imm12       : binop_with_non_imm12<add>;
 def or_is_add_non_imm12 : binop_with_non_imm12<or_is_add>;
 
-def Shifted32OnesMask : PatLeaf<(imm), [{
-  uint64_t Imm = N->getZExtValue();
-  if (!isShiftedMask_64(Imm))
+def Shifted32OnesMask : IntImmLeaf<XLenVT, [{
+  if (!Imm.isShiftedMask())
     return false;
 
-  unsigned TrailingZeros = llvm::countr_zero(Imm);
+  unsigned TrailingZeros = Imm.countr_zero();
   return TrailingZeros > 0 && TrailingZeros < 32 &&
          Imm == UINT64_C(0xFFFFFFFF) << TrailingZeros;
 }], TrailingZeros>;
@@ -776,12 +768,11 @@ def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (XLenV
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)),
-          (SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
-          (SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
-          (SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;
+foreach i = {1,2,3} in {
+  defvar pat = !cast<ComplexPattern>("sh"#i#"add_uw_op");
+  def : Pat<(i64 (add_non_imm12 pat:$rs1, (XLenVT GPR:$rs2))),
+            (!cast<Instruction>("SH"#i#"ADD_UW") pat:$rs1, GPR:$rs2)>;
+}
 
 def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFE), (XLenVT GPR:$rs2))),
           (SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
index 092a3305b3453d2..6dc3f8998e7d3b3 100644
--- a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
@@ -150,3 +150,78 @@ body:             |
     %6:gprb(s64) = G_ADD %5, %1
     $x10 = COPY %6(s64)
 ...
+---
+name:            shXadd_uw_complex_shl_and
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_uw_complex_shl_and
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SLLI:%[0-9]+]]:gpr = SLLI [[COPY]], 1
+    ; CHECK-NEXT: [[SH2ADD_UW:%[0-9]+]]:gpr = SH2ADD_UW [[SLLI]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD_UW]]
+    %0:gprb(s64) = COPY $x10
+    %1:gprb(s64) = COPY $x11
+
+    %2:gprb(s64) = G_CONSTANT i64 3
+    %3:gprb(s64) = G_SHL %0, %2
+    %4:gprb(s64) = G_CONSTANT i64 17179869183
+    %5:gprb(s64) = G_AND %3, %4
+
+    %6:gprb(s64) = G_ADD %5, %1
+    $x10 = COPY %6(s64)
+...
+---
+name:            slli_uw
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10
+
+    ; CHECK-LABEL: name: slli_uw
+    ; CHECK: liveins: $x10
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[SLLI_UW:%[0-9]+]]:gpr = SLLI_UW [[COPY]], 7
+    ; CHECK-NEXT: $x10 = COPY [[SLLI_UW]]
+    %0:gprb(s64) = COPY $x10
+
+    %1:gprb(s64) = G_CONSTANT i64 4294967295
+    %2:gprb(s64) = G_AND %0, %1
+    %3:gprb(s64) = G_CONSTANT i64 7
+    %4:gprb(s64) = G_SHL %2, %3
+
+    $x10 = COPY %4(s64)
+...
+---
+name:            slli_uw_complex
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10
+
+    ; CHECK-LABEL: name: slli_uw_complex
+    ; CHECK: liveins: $x10
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[SRLI:%[0-9]+]]:gpr = SRLI [[COPY]], 2
+    ; CHECK-NEXT: [[SLLI_UW:%[0-9]+]]:gpr = SLLI_UW [[SRLI]], 2
+    ; CHECK-NEXT: $x10 = COPY [[SLLI_UW]]
+    %0:gprb(s64) = COPY $x10
+
+    %1:gprb(s64) = G_CONSTANT i64 17179869180
+    %2:gprb(s64) = G_AND %0, %1
+
+    $x10 = COPY %2(s64)
+...
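
To make the shXadd_uw_complex_shl_and case above concrete, here is a standalone C++20 sketch (not LLVM code) of the arithmetic selectSHXADD_UWOp performs on that test's constants. Note that the real matcher is instantiated separately for ShAmt = 1, 2 and 3 rather than deriving ShAmt from the mask as done here:

  #include <bit>
  #include <cassert>
  #include <cstdint>

  int main() {
    uint64_t Mask = 0x3FFFFFFFFULL;   // G_CONSTANT i64 17179869183
    unsigned C2 = 3;                  // the G_SHL shift amount
    Mask &= ~((1ULL << C2) - 1);      // clear the low C2 bits, like the selector

    unsigned Leading = std::countl_zero(Mask);   // 30
    unsigned Trailing = std::countr_zero(Mask);  // 3
    unsigned ShAmt = 32 - Leading;               // 2, so SH2ADD_UW applies

    assert(C2 == Trailing && Trailing > ShAmt);
    // Selected sequence: SLLI x, C2 - ShAmt (= 1), then SH2ADD_UW of that
    // result with the other addend, matching the CHECK lines in the test.
    return 0;
  }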

Collaborator

I think we also need patterns for (shl (i64 (zext (i32 X))), C). SelectionDAG doesn't have that case since i32 isn't a legal type on RV64 there.
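
To spell out the equivalence behind this request (a standalone sketch, not part of the patch): on RV64, the masked-shift form the existing patterns handle and the zext-then-shift form both zero-extend the low 32 bits of X before shifting, which is exactly what slli.uw computes, so both shapes should select to it:

  #include <cassert>
  #include <cstdint>

  // (shl (and X, 0xFFFFFFFF), C)
  uint64_t viaAndMask(uint64_t x, unsigned c) {
    return (x & 0xFFFFFFFFULL) << c;
  }

  // (shl (i64 (zext (i32 X))), C)
  uint64_t viaZext(uint64_t x, unsigned c) {
    return static_cast<uint64_t>(static_cast<uint32_t>(x)) << c;
  }

  int main() {
    for (uint64_t X : {0x0ULL, 0x123456789ABCDEF0ULL, ~0ULL})
      for (unsigned C : {1u, 7u, 31u})
        assert(viaAndMask(X, C) == viaZext(X, C));
    return 0;
  }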

mshockwave (Member, Author)

Done

@topperc (Collaborator) left a comment

LGTM

@mshockwave mshockwave merged commit cdcaef8 into llvm:main Oct 24, 2023
2 of 3 checks passed
@mshockwave mshockwave deleted the gisel-riscv-zba branch October 24, 2023 23:26