Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen][TII] Allow reassociation on custom operand indices #88306

Merged
merged 2 commits into from
Apr 23, 2024

Conversation

mshockwave
Copy link
Member

This opens up a door for reusing reassociation optimizations on target-specific binary operations with non-standard operand list.

This is effectively a NFC.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 10, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Min-Yih Hsu (mshockwave)

Changes

This opens up a door for reusing reassociation optimizations on target-specific binary operations with non-standard operand list.

This is effectively a NFC.


Full diff: https://github.com/llvm/llvm-project/pull/88306.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetInstrInfo.h (+11)
  • (modified) llvm/lib/CodeGen/TargetInstrInfo.cpp (+100-45)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+4-4)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 9fd0ebe6956fbe..82c952b227557d 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -30,6 +30,7 @@
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <array>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
     return true;
   }
 
+  /// The returned array encodes the operand index for each parameter because
+  /// the operands may be commuted; the operand indices for associative
+  /// operations might also be target-specific. Each element specifies the index
+  /// of {Prev, A, B, X, Y}.
+  virtual void
+  getReassociateOperandIdx(const MachineInstr &Root,
+                           MachineCombinerPattern Pattern,
+                           std::array<unsigned, 5> &OperandIndices) const;
+
   /// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
   /// reduce critical path length.
   void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
                       MachineCombinerPattern Pattern,
                       SmallVectorImpl<MachineInstr *> &InsInstrs,
                       SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      ArrayRef<unsigned> OperandIndices,
                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
 
   /// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 9fbd516acea8e1..488922e3c1b720 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1051,13 +1051,34 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
   }
 }
 
+void TargetInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    OperandIndices = {1, 1, 1, 2, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    OperandIndices = {2, 1, 2, 2, 1};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    OperandIndices = {1, 2, 1, 1, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    OperandIndices = {2, 2, 2, 1, 1};
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+}
+
 /// Attempt the reassociation transformation to reduce critical path length.
 /// See the above comments before getMachineCombinerPatterns().
 void TargetInstrInfo::reassociateOps(
-    MachineInstr &Root, MachineInstr &Prev,
-    MachineCombinerPattern Pattern,
+    MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
     SmallVectorImpl<MachineInstr *> &DelInstrs,
+    ArrayRef<unsigned> OperandIndices,
     DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
   MachineFunction *MF = Root.getMF();
   MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1065,29 +1086,10 @@ void TargetInstrInfo::reassociateOps(
   const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
   const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
 
-  // This array encodes the operand index for each parameter because the
-  // operands may be commuted. Each row corresponds to a pattern value,
-  // and each column specifies the index of A, B, X, Y.
-  unsigned OpIdx[4][4] = {
-    { 1, 1, 2, 2 },
-    { 1, 2, 2, 1 },
-    { 2, 1, 1, 2 },
-    { 2, 2, 1, 1 }
-  };
-
-  int Row;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
-  case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
-  case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
-  case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
-  default: llvm_unreachable("unexpected MachineCombinerPattern");
-  }
-
-  MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
-  MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
-  MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
-  MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+  MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+  MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+  MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+  MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
   MachineOperand &OpC = Root.getOperand(0);
 
   Register RegA = OpA.getReg();
@@ -1126,11 +1128,62 @@ void TargetInstrInfo::reassociateOps(
     std::swap(KillX, KillY);
   }
 
+  unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+  unsigned RootFirstOpIdx, RootSecondOpIdx;
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+
+  // Basically BuildMI but doesn't add implicit operands by default.
+  auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+                              const MCInstrDesc &MCID, Register DestReg) {
+    return MachineInstrBuilder(
+               MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+        .setPCSections(MIMD.getPCSections())
+        .addReg(DestReg, RegState::Define);
+  };
+
   // Create new instructions for insertion.
   MachineInstrBuilder MIB1 =
-      BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
-          .addReg(RegX, getKillRegState(KillX))
-          .addReg(RegY, getKillRegState(KillY));
+      buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+  for (const auto &MO : Prev.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand we'd already added.
+    if (Idx == 0)
+      continue;
+    if (Idx == PrevFirstOpIdx)
+      MIB1.addReg(RegX, getKillRegState(KillX));
+    else if (Idx == PrevSecondOpIdx)
+      MIB1.addReg(RegY, getKillRegState(KillY));
+    else
+      MIB1.add(MO);
+  }
+  MIB1.copyImplicitOps(Prev);
 
   if (SwapRootOperands) {
     std::swap(RegA, NewVR);
@@ -1138,9 +1191,20 @@ void TargetInstrInfo::reassociateOps(
   }
 
   MachineInstrBuilder MIB2 =
-      BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
-          .addReg(RegA, getKillRegState(KillA))
-          .addReg(NewVR, getKillRegState(KillNewVR));
+      buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+  for (const auto &MO : Root.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand.
+    if (Idx == 0)
+      continue;
+    if (Idx == RootFirstOpIdx)
+      MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+    else if (Idx == RootSecondOpIdx)
+      MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+    else
+      MIB2 = MIB2.add(MO);
+  }
+  MIB2.copyImplicitOps(Root);
 
   // Propagate FP flags from the original instructions.
   // But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
   MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
 
   // Select the previous instruction in the sequence based on the input pattern.
-  MachineInstr *Prev = nullptr;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY:
-  case MachineCombinerPattern::REASSOC_XA_BY:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
-    break;
-  case MachineCombinerPattern::REASSOC_AX_YB:
-  case MachineCombinerPattern::REASSOC_XA_YB:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
-    break;
-  default:
-    llvm_unreachable("Unknown pattern for machine combiner");
-  }
+  std::array<unsigned, 5> OpIdx;
+  getReassociateOperandIdx(Root, Pattern, OpIdx);
+  MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
 
   // Don't reassociate if Prev and Root are in different blocks.
   if (Prev->getParent() != Root.getParent())
     return;
 
-  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+                 InstIdxForVirtReg);
 }
 
 MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6b75efe684d913..5eeb0d7c27cb98 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
   MachineFunction &MF = *Root.getMF();
 
   for (auto *NewMI : InsInstrs) {
-    assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
-               NewMI->getOpcode(), RISCV::OpName::frm)) ==
-               NewMI->getNumOperands() &&
-           "Instruction has unexpected number of operands");
+    // We'd already added the FRM operand.
+    if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+            NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+      continue;
     MachineInstrBuilder MIB(MF, NewMI);
     MIB.add(FRM);
     if (FRM.getImm() == RISCVFPRndMode::DYN)

MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does indice 0 represent, I don't see it used but maybe I missed it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually used in TargetInstrInfo::genAlternativeCodeSequence, which is reassociateOps's caller, to get the Prev instruction from Root.

default:
llvm_unreachable("Unknown pattern for machine combiner");
}
std::array<unsigned, 5> OpIdx;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to OperandIndices to match what reassociateOps calls it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

llvm_unreachable("Unknown pattern for machine combiner");
}
std::array<unsigned, 5> OpIdx;
getReassociateOperandIdx(Root, Pattern, OpIdx);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getReassociateOperandIdx -> getReassociateOperandIndices?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@mshockwave
Copy link
Member Author

ping.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This opens up a door for reusing reassociation optimizations on target-specific
binary operations with non-standard operand list.

This is effectively a NFC.
@mshockwave mshockwave force-pushed the patch/reassoc-custom-operand-index branch from dc25a72 to e71afbe Compare April 23, 2024 17:45
@mshockwave
Copy link
Member Author

Forced push to rebase.

@mshockwave mshockwave merged commit 5fe93b0 into llvm:main Apr 23, 2024
3 of 4 checks passed
@mshockwave mshockwave deleted the patch/reassoc-custom-operand-index branch April 23, 2024 18:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants