[RISCV] Defer forming x0,x0 vsetvlis until after insertion #89089

Draft

lukel97 wants to merge 2 commits into main
Conversation

@lukel97 (Contributor) commented Apr 17, 2024

Currently we try to detect when the VL doesn't change between two vsetvlis in emitVSETVLIs, and insert a VL-preserving vsetvli x0,x0 then and there.
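For reference, this is how the x0,x0 form is currently materialized in situ (lifted from the hunk this patch deletes from insertVSETVLI):

```cpp
// The in-situ emission this patch removes: both the VL def and the AVL
// operand are x0, so only VTYPE is updated and the current VL is carried
// through as an implicit use.
BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETVLIX0))
    .addReg(RISCV::X0, RegState::Define | RegState::Dead)
    .addReg(RISCV::X0, RegState::Kill)
    .addImm(Info.encodeVTYPE())
    .addReg(RISCV::VL, RegState::Implicit);
```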

Doing it in situ has some drawbacks:

  • We lose information about what the VL is in the pseudo, which can prevent doLocalPostpass from coalescing vsetvlis further down the line
  • We have to handle emitting x0,x0 for vmv.x.s and friends by plumbing through a dummy NoRegister AVL
  • Other parts of the code need to be aware of x0,x0 vsetvlis and work around them. doLocalPostpass needs to handle them specifically, and callers of getInfoForVSETVLI need to check that they don't pass one in

This patch changes emitVSETVLIs to just emit regular vsetvlis, and adds a separate pass after doLocalPostpass to convert vsetvlis to x0,x0 when possible.
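The core of the conversion, condensed from convertToX0X0 in the diff: once the dataflow analysis proves the VL is unchanged and the vsetvli's VL output is dead, the instruction is rewritten in place:

```cpp
// Condensed from convertToX0X0 below: rewrite a vsetvli whose VL provably
// doesn't change, and whose VL def is dead, into the VL-preserving x0,x0 form.
if (HasSameVL && MI.getOperand(0).isDead()) {
  MI.setDesc(TII->get(RISCV::PseudoVSETVLIX0));
  MI.getOperand(0).ChangeToRegister(RISCV::X0, /*isDef=*/true);
  MI.getOperand(0).setIsDead(true);
  MI.getOperand(1).ChangeToRegister(RISCV::X0, /*isDef=*/false);
  MI.getOperand(1).setIsKill(true);
}
```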

By removing the edge cases needed to handle x0,x0s, we can unify how we check vsetvli compatibility between doLocalPostpass and emitVSETVLIs, and remove the duplicated logic in areCompatibleVTYPEs and canMutatePriorConfig.
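Concretely, both sites now reduce each vsetvli to a VSETVLIInfo and ask the same question (condensed from the reworked doLocalPostpass in the diff):

```cpp
// Condensed from the new doLocalPostpass: summarize both vsetvlis as
// VSETVLIInfo and reuse the isCompatible check from insertion time, instead
// of the bespoke canMutatePriorConfig/areCompatibleVTYPEs pair.
const VSETVLIInfo MIInfo = getInfoForVSETVLI(MI, *MRI);
const VSETVLIInfo NextMIInfo = getInfoForVSETVLI(*NextMI, *MRI);
bool CanCoalesce = NextMIInfo.isCompatible(Used, MIInfo, *MRI);
```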

We can also remove the dummy NoRegister AVL for vmv.x.s and tighten the invariant that the AVL must be either a virtual register or x0. (cc @BeMg: this may be useful for the post-ra patch)

There are some other changes that were difficult to split out:

  • In order to handle one specific case in saxpy_vec, we need to be able to detect when the VL doesn't change because the AVL is a PHI node whose incoming values are the output VLs of the last vsetvlis in each predecessor block. We generalize needVSETVLIPHI so we can reuse it for this (see the condensed snippet after this list)
  • To prevent regressions, we need to teach doLocalPostpass to coalesce vsetvlis where either vsetvli may write VL to a register
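The generalized needVSETVLIPHI is used like this in the new pass (condensed from convertToX0X0 in the diff):

```cpp
// Condensed from convertToX0X0: VL is also unchanged when the AVL comes from
// a PHI whose incoming values are the output VLs of the last vsetvlis in each
// predecessor block; the generalized needVSETVLIPHI detects exactly that.
DemandedFields Demanded;
Demanded.demandVL();
bool HasSameVL = Info.isCompatible(Demanded, MIInfo, *MRI);
HasSameVL |= !needVSETVLIPHI(MIInfo, MBB, Demanded);
```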

@llvmbot (Collaborator) commented Apr 17, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

Stacked on #89080


Patch is 128.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89089.diff

10 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp (+135-182)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-expandload-fp.ll (+18-18)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-expandload-int.ll (+12-12)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-llrint.ll (+1-1)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-lrint.ll (+1-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll (+161-271)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-scatter.ll (+6-6)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-unaligned.ll (+3-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.mir (+9-9)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.mir (+3-3)
diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
index fa37d1ccccd737..6bc31416faa226 100644
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -62,15 +62,6 @@ static bool isVectorConfigInstr(const MachineInstr &MI) {
          MI.getOpcode() == RISCV::PseudoVSETIVLI;
 }
 
-/// Return true if this is 'vsetvli x0, x0, vtype' which preserves
-/// VL and only sets VTYPE.
-static bool isVLPreservingConfig(const MachineInstr &MI) {
-  if (MI.getOpcode() != RISCV::PseudoVSETVLIX0)
-    return false;
-  assert(RISCV::X0 == MI.getOperand(1).getReg());
-  return RISCV::X0 == MI.getOperand(0).getReg();
-}
-
 static bool isFloatScalarMoveOrScalarSplatInstr(const MachineInstr &MI) {
   switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
   default:
@@ -299,51 +290,6 @@ inline raw_ostream &operator<<(raw_ostream &OS, const DemandedFields &DF) {
 }
 #endif
 
-/// Return true if moving from CurVType to NewVType is
-/// indistinguishable from the perspective of an instruction (or set
-/// of instructions) which use only the Used subfields and properties.
-static bool areCompatibleVTYPEs(uint64_t CurVType, uint64_t NewVType,
-                                const DemandedFields &Used) {
-  switch (Used.SEW) {
-  case DemandedFields::SEWNone:
-    break;
-  case DemandedFields::SEWEqual:
-    if (RISCVVType::getSEW(CurVType) != RISCVVType::getSEW(NewVType))
-      return false;
-    break;
-  case DemandedFields::SEWGreaterThanOrEqual:
-    if (RISCVVType::getSEW(NewVType) < RISCVVType::getSEW(CurVType))
-      return false;
-    break;
-  case DemandedFields::SEWGreaterThanOrEqualAndLessThan64:
-    if (RISCVVType::getSEW(NewVType) < RISCVVType::getSEW(CurVType) ||
-        RISCVVType::getSEW(NewVType) >= 64)
-      return false;
-    break;
-  }
-
-  if (Used.LMUL &&
-      RISCVVType::getVLMUL(CurVType) != RISCVVType::getVLMUL(NewVType))
-    return false;
-
-  if (Used.SEWLMULRatio) {
-    auto Ratio1 = RISCVVType::getSEWLMULRatio(RISCVVType::getSEW(CurVType),
-                                              RISCVVType::getVLMUL(CurVType));
-    auto Ratio2 = RISCVVType::getSEWLMULRatio(RISCVVType::getSEW(NewVType),
-                                              RISCVVType::getVLMUL(NewVType));
-    if (Ratio1 != Ratio2)
-      return false;
-  }
-
-  if (Used.TailPolicy && RISCVVType::isTailAgnostic(CurVType) !=
-                             RISCVVType::isTailAgnostic(NewVType))
-    return false;
-  if (Used.MaskPolicy && RISCVVType::isMaskAgnostic(CurVType) !=
-                             RISCVVType::isMaskAgnostic(NewVType))
-    return false;
-  return true;
-}
-
 /// Return the fields and properties demanded by the provided instruction.
 DemandedFields getDemanded(const MachineInstr &MI,
                            const MachineRegisterInfo *MRI,
@@ -468,7 +414,7 @@ class VSETVLIInfo {
   bool isUnknown() const { return State == Unknown; }
 
   void setAVLReg(Register Reg) {
-    assert(Reg.isVirtual() || Reg == RISCV::X0 || Reg == RISCV::NoRegister);
+    assert(Reg.isVirtual() || Reg == RISCV::X0);
     AVLReg = Reg;
     State = AVLIsReg;
   }
@@ -594,9 +540,43 @@ class VSETVLIInfo {
     return getSEWLMULRatio() == Other.getSEWLMULRatio();
   }
 
+  /// Return true if moving from Require to this is indistinguishable from the
+  /// perspective of an instruction (or set of instructions) which use only the
+  /// Used subfields and properties.
   bool hasCompatibleVTYPE(const DemandedFields &Used,
                           const VSETVLIInfo &Require) const {
-    return areCompatibleVTYPEs(Require.encodeVTYPE(), encodeVTYPE(), Used);
+    if (SEWLMULRatioOnly && (Used.SEW != DemandedFields::SEWNone ||
+                             Used.MaskPolicy || Used.TailPolicy || Used.LMUL))
+      return false;
+
+    switch (Used.SEW) {
+    case DemandedFields::SEWNone:
+      break;
+    case DemandedFields::SEWEqual:
+      if (Require.SEW != SEW)
+        return false;
+      break;
+    case DemandedFields::SEWGreaterThanOrEqual:
+      if (SEW < Require.SEW)
+        return false;
+      break;
+    case DemandedFields::SEWGreaterThanOrEqualAndLessThan64:
+      if (SEW < Require.SEW || SEW >= 64)
+        return false;
+      break;
+    }
+
+    if (Used.LMUL && Require.VLMul != VLMul)
+      return false;
+
+    if (Used.SEWLMULRatio && Require.getSEWLMULRatio() != getSEWLMULRatio())
+      return false;
+
+    if (Used.TailPolicy && Require.TailAgnostic != TailAgnostic)
+      return false;
+    if (Used.MaskPolicy && Require.MaskAgnostic != MaskAgnostic)
+      return false;
+    return true;
   }
 
   // Determine whether the vector instructions requirements represented by
@@ -612,11 +592,7 @@ class VSETVLIInfo {
     if (isUnknown() || Require.isUnknown())
       return false;
 
-    // If only our VLMAX ratio is valid, then this isn't compatible.
-    if (SEWLMULRatioOnly)
-      return false;
-
-    if (Used.VLAny && !hasSameAVL(Require))
+    if (Used.VLAny && !(hasSameAVL(Require) && hasSameVLMAX(Require)))
       return false;
 
     if (Used.VLZeroness && !hasEquallyZeroAVL(Require, MRI))
@@ -764,8 +740,8 @@ class RISCVInsertVSETVLI : public MachineFunctionPass {
 private:
   bool needVSETVLI(const MachineInstr &MI, const VSETVLIInfo &Require,
                    const VSETVLIInfo &CurInfo) const;
-  bool needVSETVLIPHI(const VSETVLIInfo &Require,
-                      const MachineBasicBlock &MBB) const;
+  bool needVSETVLIPHI(const VSETVLIInfo &Require, const MachineBasicBlock &MBB,
+                      const DemandedFields &Used) const;
   void insertVSETVLI(MachineBasicBlock &MBB, MachineInstr &MI,
                      const VSETVLIInfo &Info, const VSETVLIInfo &PrevInfo);
   void insertVSETVLI(MachineBasicBlock &MBB,
@@ -780,6 +756,7 @@ class RISCVInsertVSETVLI : public MachineFunctionPass {
   void emitVSETVLIs(MachineBasicBlock &MBB);
   void doLocalPostpass(MachineBasicBlock &MBB);
   void doPRE(MachineBasicBlock &MBB);
+  void convertToX0X0(MachineBasicBlock &MBB);
   void insertReadVL(MachineBasicBlock &MBB);
 };
 
@@ -792,7 +769,8 @@ INITIALIZE_PASS(RISCVInsertVSETVLI, DEBUG_TYPE, RISCV_INSERT_VSETVLI_NAME,
 
 // Return a VSETVLIInfo representing the changes made by this VSETVLI or
 // VSETIVLI instruction.
-static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI) {
+static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI,
+                                     const MachineRegisterInfo &MRI) {
   VSETVLIInfo NewInfo;
   if (MI.getOpcode() == RISCV::PseudoVSETIVLI) {
     NewInfo.setAVLImm(MI.getOperand(1).getImm());
@@ -806,6 +784,17 @@ static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI) {
   }
   NewInfo.setVTYPE(MI.getOperand(2).getImm());
 
+  // FIXME: store the def of AVL instead of the register in VSETVLIInfo so we
+  // don't need to peek through here with MRI.
+  if (NewInfo.hasAVLReg() && NewInfo.getAVLReg().isVirtual()) {
+    if (MachineInstr *AVLDef = MRI.getUniqueVRegDef(NewInfo.getAVLReg());
+        AVLDef && isVectorConfigInstr(*AVLDef)) {
+      VSETVLIInfo DefInfo = getInfoForVSETVLI(*AVLDef, MRI);
+      if (DefInfo.hasSameVLMAX(NewInfo))
+        NewInfo.setAVL(DefInfo);
+    }
+  }
+
   return NewInfo;
 }
 
@@ -878,7 +867,7 @@ static VSETVLIInfo computeInfoForInstr(const MachineInstr &MI, uint64_t TSFlags,
     }
   } else {
     assert(isScalarExtractInstr(MI));
-    InstrInfo.setAVLReg(RISCV::NoRegister);
+    InstrInfo.setAVLImm(1);
   }
 #ifndef NDEBUG
   if (std::optional<unsigned> EEW = getEEWForLoadStore(MI)) {
@@ -894,7 +883,7 @@ static VSETVLIInfo computeInfoForInstr(const MachineInstr &MI, uint64_t TSFlags,
   if (InstrInfo.hasAVLReg() && InstrInfo.getAVLReg().isVirtual()) {
     MachineInstr *DefMI = MRI->getVRegDef(InstrInfo.getAVLReg());
     if (DefMI && isVectorConfigInstr(*DefMI)) {
-      VSETVLIInfo DefInstrInfo = getInfoForVSETVLI(*DefMI);
+      VSETVLIInfo DefInstrInfo = getInfoForVSETVLI(*DefMI, *MRI);
       if (DefInstrInfo.hasSameVLMAX(InstrInfo) &&
           (DefInstrInfo.hasAVLImm() || DefInstrInfo.getAVLReg() == RISCV::X0)) {
         InstrInfo.setAVL(DefInstrInfo);
@@ -917,38 +906,6 @@ void RISCVInsertVSETVLI::insertVSETVLI(MachineBasicBlock &MBB,
                      const VSETVLIInfo &Info, const VSETVLIInfo &PrevInfo) {
 
   ++NumInsertedVSETVL;
-  if (PrevInfo.isValid() && !PrevInfo.isUnknown()) {
-    // Use X0, X0 form if the AVL is the same and the SEW+LMUL gives the same
-    // VLMAX.
-    if (Info.hasSameAVL(PrevInfo) && Info.hasSameVLMAX(PrevInfo)) {
-      BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETVLIX0))
-          .addReg(RISCV::X0, RegState::Define | RegState::Dead)
-          .addReg(RISCV::X0, RegState::Kill)
-          .addImm(Info.encodeVTYPE())
-          .addReg(RISCV::VL, RegState::Implicit);
-      return;
-    }
-
-    // If our AVL is a virtual register, it might be defined by a VSET(I)VLI. If
-    // it has the same VLMAX we want and the last VL/VTYPE we observed is the
-    // same, we can use the X0, X0 form.
-    if (Info.hasSameVLMAX(PrevInfo) && Info.hasAVLReg() &&
-        Info.getAVLReg().isVirtual()) {
-      if (MachineInstr *DefMI = MRI->getVRegDef(Info.getAVLReg())) {
-        if (isVectorConfigInstr(*DefMI)) {
-          VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI);
-          if (DefInfo.hasSameAVL(PrevInfo) && DefInfo.hasSameVLMAX(PrevInfo)) {
-            BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETVLIX0))
-                .addReg(RISCV::X0, RegState::Define | RegState::Dead)
-                .addReg(RISCV::X0, RegState::Kill)
-                .addImm(Info.encodeVTYPE())
-                .addReg(RISCV::VL, RegState::Implicit);
-            return;
-          }
-        }
-      }
-    }
-  }
 
   if (Info.hasAVLImm()) {
     BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETIVLI))
@@ -959,26 +916,6 @@ void RISCVInsertVSETVLI::insertVSETVLI(MachineBasicBlock &MBB,
   }
 
   Register AVLReg = Info.getAVLReg();
-  if (AVLReg == RISCV::NoRegister) {
-    // We can only use x0, x0 if there's no chance of the vtype change causing
-    // the previous vl to become invalid.
-    if (PrevInfo.isValid() && !PrevInfo.isUnknown() &&
-        Info.hasSameVLMAX(PrevInfo)) {
-      BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETVLIX0))
-          .addReg(RISCV::X0, RegState::Define | RegState::Dead)
-          .addReg(RISCV::X0, RegState::Kill)
-          .addImm(Info.encodeVTYPE())
-          .addReg(RISCV::VL, RegState::Implicit);
-      return;
-    }
-    // Otherwise use an AVL of 1 to avoid depending on previous vl.
-    BuildMI(MBB, InsertPt, DL, TII->get(RISCV::PseudoVSETIVLI))
-        .addReg(RISCV::X0, RegState::Define | RegState::Dead)
-        .addImm(1)
-        .addImm(Info.encodeVTYPE());
-    return;
-  }
-
   if (AVLReg.isVirtual())
     MRI->constrainRegClass(AVLReg, &RISCV::GPRNoX0RegClass);
 
@@ -1058,7 +995,7 @@ bool RISCVInsertVSETVLI::needVSETVLI(const MachineInstr &MI,
       CurInfo.hasCompatibleVTYPE(Used, Require)) {
     if (MachineInstr *DefMI = MRI->getVRegDef(Require.getAVLReg())) {
       if (isVectorConfigInstr(*DefMI)) {
-        VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI);
+        VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI, *MRI);
         if (DefInfo.hasSameAVL(CurInfo) && DefInfo.hasSameVLMAX(CurInfo))
           return false;
       }
@@ -1145,7 +1082,7 @@ void RISCVInsertVSETVLI::transferBefore(VSETVLIInfo &Info,
 void RISCVInsertVSETVLI::transferAfter(VSETVLIInfo &Info,
                                        const MachineInstr &MI) const {
   if (isVectorConfigInstr(MI)) {
-    Info = getInfoForVSETVLI(MI);
+    Info = getInfoForVSETVLI(MI, *MRI);
     return;
   }
 
@@ -1237,7 +1174,8 @@ void RISCVInsertVSETVLI::computeIncomingVLVTYPE(const MachineBasicBlock &MBB) {
 // be unneeded if the AVL is a phi node where all incoming values are VL
 // outputs from the last VSETVLI in their respective basic blocks.
 bool RISCVInsertVSETVLI::needVSETVLIPHI(const VSETVLIInfo &Require,
-                                        const MachineBasicBlock &MBB) const {
+                                        const MachineBasicBlock &MBB,
+                                        const DemandedFields &Used) const {
   if (DisableInsertVSETVLPHIOpt)
     return true;
 
@@ -1260,7 +1198,8 @@ bool RISCVInsertVSETVLI::needVSETVLIPHI(const VSETVLIInfo &Require,
     const BlockData &PBBInfo = BlockInfo[PBB->getNumber()];
     // If the exit from the predecessor has the VTYPE we are looking for
     // we might be able to avoid a VSETVLI.
-    if (PBBInfo.Exit.isUnknown() || !PBBInfo.Exit.hasSameVTYPE(Require))
+    if (PBBInfo.Exit.isUnknown() ||
+        !PBBInfo.Exit.hasCompatibleVTYPE(Used, Require))
       return true;
 
     // We need the PHI input to the be the output of a VSET(I)VLI.
@@ -1270,9 +1209,8 @@ bool RISCVInsertVSETVLI::needVSETVLIPHI(const VSETVLIInfo &Require,
 
     // We found a VSET(I)VLI make sure it matches the output of the
     // predecessor block.
-    VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI);
-    if (!DefInfo.hasSameAVL(PBBInfo.Exit) ||
-        !DefInfo.hasSameVTYPE(PBBInfo.Exit))
+    VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI, *MRI);
+    if (!DefInfo.isCompatible(Used, PBBInfo.Exit, *MRI))
       return true;
   }
 
@@ -1311,7 +1249,8 @@ void RISCVInsertVSETVLI::emitVSETVLIs(MachineBasicBlock &MBB) {
         // wouldn't be used and VL/VTYPE registers are correct.  Note that
         // we *do* need to model the state as if it changed as while the
         // register contents are unchanged, the abstract model can change.
-        if (!PrefixTransparent || needVSETVLIPHI(CurInfo, MBB))
+        if (!PrefixTransparent ||
+            needVSETVLIPHI(CurInfo, MBB, getDemanded(MI, MRI, ST)))
           insertVSETVLI(MBB, MI, CurInfo, PrevInfo);
         PrefixTransparent = false;
       }
@@ -1488,44 +1427,6 @@ static void doUnion(DemandedFields &A, DemandedFields B) {
   A.MaskPolicy |= B.MaskPolicy;
 }
 
-// Return true if we can mutate PrevMI to match MI without changing any the
-// fields which would be observed.
-static bool canMutatePriorConfig(const MachineInstr &PrevMI,
-                                 const MachineInstr &MI,
-                                 const DemandedFields &Used,
-                                 const MachineRegisterInfo &MRI) {
-  // If the VL values aren't equal, return false if either a) the former is
-  // demanded, or b) we can't rewrite the former to be the later for
-  // implementation reasons.
-  if (!isVLPreservingConfig(MI)) {
-    if (Used.VLAny)
-      return false;
-
-    if (Used.VLZeroness) {
-      if (isVLPreservingConfig(PrevMI))
-        return false;
-      if (!getInfoForVSETVLI(PrevMI).hasEquallyZeroAVL(getInfoForVSETVLI(MI),
-                                                       MRI))
-        return false;
-    }
-
-    auto &AVL = MI.getOperand(1);
-    auto &PrevAVL = PrevMI.getOperand(1);
-    assert(MRI.isSSA());
-
-    // If the AVL is a register, we need to make sure MI's AVL dominates PrevMI.
-    // For now just check that PrevMI uses the same virtual register.
-    if (AVL.isReg() && AVL.getReg() != RISCV::X0 &&
-        (!PrevAVL.isReg() || PrevAVL.getReg() != AVL.getReg()))
-      return false;
-  }
-
-  assert(PrevMI.getOperand(2).isImm() && MI.getOperand(2).isImm());
-  auto PriorVType = PrevMI.getOperand(2).getImm();
-  auto VType = MI.getOperand(2).getImm();
-  return areCompatibleVTYPEs(PriorVType, VType, Used);
-}
-
 void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
   MachineInstr *NextMI = nullptr;
   // We can have arbitrary code in successors, so VL and VTYPE
@@ -1556,25 +1457,42 @@ void RISCVInsertVSETVLI::doLocalPostpass(MachineBasicBlock &MBB) {
         continue;
       }
 
-      if (canMutatePriorConfig(MI, *NextMI, Used, *MRI)) {
-        if (!isVLPreservingConfig(*NextMI)) {
-          MI.getOperand(0).setReg(NextMI->getOperand(0).getReg());
-          MI.getOperand(0).setIsDead(false);
-          Register OldVLReg;
-          if (MI.getOperand(1).isReg())
-            OldVLReg = MI.getOperand(1).getReg();
-          if (NextMI->getOperand(1).isImm())
-            MI.getOperand(1).ChangeToImmediate(NextMI->getOperand(1).getImm());
-          else
-            MI.getOperand(1).ChangeToRegister(NextMI->getOperand(1).getReg(), false);
-          if (OldVLReg) {
-            MachineInstr *VLOpDef = MRI->getUniqueVRegDef(OldVLReg);
-            if (VLOpDef && TII->isAddImmediate(*VLOpDef, OldVLReg) &&
-                MRI->use_nodbg_empty(OldVLReg))
-              VLOpDef->eraseFromParent();
-          }
-          MI.setDesc(NextMI->getDesc());
+      const VSETVLIInfo MIInfo = getInfoForVSETVLI(MI, *MRI);
+      const VSETVLIInfo NextMIInfo = getInfoForVSETVLI(*NextMI, *MRI);
+
+      // If the new AVL is a register make sure it dominates PrevMI. For now
+      // just check that it's the same AVL used by PrevMI.
+      bool NewAVLDominates = true;
+      if (NextMIInfo.hasAVLReg() && NextMIInfo.getAVLReg().isVirtual())
+        NewAVLDominates = MIInfo.hasSameAVL(NextMIInfo);
+
+      // We are coalescing two vsetvlis into one, so at least one of the defs
+      // will need to be dead.
+      const MachineOperand *DefOp = nullptr;
+      if (MI.getOperand(0).isDead())
+        DefOp = &NextMI->getOperand(0);
+      else if (NextMI->getOperand(0).isDead())
+        DefOp = &MI.getOperand(0);
+
+      if (NextMIInfo.isCompatible(Used, MIInfo, *MRI) && NewAVLDominates &&
+          DefOp) {
+        MI.getOperand(0).setReg(DefOp->getReg());
+        MI.getOperand(0).setIsDead(DefOp->isDead());
+        Register OldVLReg;
+        if (MI.getOperand(1).isReg())
+          OldVLReg = MI.getOperand(1).getReg();
+        if (NextMIInfo.hasAVLImm())
+          MI.getOperand(1).ChangeToImmediate(NextMIInfo.getAVLImm());
+        else
+          MI.getOperand(1).ChangeToRegister(NextMIInfo.getAVLReg(), false);
+        if (OldVLReg) {
+          MachineInstr *VLOpDef = MRI->getUniqueVRegDef(OldVLReg);
+          if (VLOpDef && TII->isAddImmediate(*VLOpDef, OldVLReg) &&
+              MRI->use_nodbg_empty(OldVLReg))
+            VLOpDef->eraseFromParent();
         }
+        MI.setDesc(NextMI->getDesc());
+
         MI.getOperand(2).setImm(NextMI->getOperand(2).getImm());
         ToDelete.push_back(NextMI);
         // fallthrough
@@ -1603,6 +1521,34 @@ void RISCVInsertVSETVLI::insertReadVL(MachineBasicBlock &MBB) {
   }
 }
 
+void RISCVInsertVSETVLI::convertToX0X0(MachineBasicBlock &MBB) {
+  VSETVLIInfo Info = BlockInfo[MBB.getNumber()].Pred;
+  for (MachineInstr &MI : MBB) {
+    if (isVectorConfigInstr(MI)) {
+      VSETVLIInfo MIInfo = getInfoForVSETVLI(MI, *MRI);
+
+      // If VL doesn't change going from Info to MIInfo, then we can use x0,x0
+      DemandedFields Demanded;
+      Demanded.demandVL();
+      bool HasSameVL = Info.isCompatible(Demanded, MIInfo, *MRI);
+      // An AVL from a phi node where the incoming values are the output vls of
+      // the last vsetvlis in a block doesn't change the VL.
+      HasSameVL |= !needVSETVLIPHI(MIInfo, MBB, Demanded);
+
+      if (HasSameVL && MI.getOperand(0).isDead()) {
+        MI.setDesc(TII->get(RISCV::PseudoVSETVLIX0));
+        MI.getOperand(0).ChangeToRegister(RISCV::X0, /*isDef*/ true);
+        MI.getOperand(0).setIsDead(true);
+        MI.getOperand(1).ChangeToRegister(RISCV::X0, /*isDef*/ false);
+        MI.getOperand(1).setIsKill(true);
+        Info = MIInfo; // transferAfter can't handle x0,x0
+        continue;
+      }
+    }
+    transferAfter(Info, MI);
+  }
+}
+
 bool RISCVInsertVSETVLI::runOnMachineFunction(MachineFunction &MF) {
   // Skip if the vector extension is not enabled.
   ST = &MF.getSubtarget<RISCVSubtarget>();
@@ -1670,6 +1616,13 @@ bool RISCVInsertVSETVLI::runOnMachineFunction(MachineFunction &MF) {
   for (MachineBasicBlock &MBB : MF)
     doLocalPostpass(MBB);
 
+  // Find vset[i]vlis that don't change VL and replace them with vsetvli x0,x0.
+  // Defer this to the end rather than during vsetvli insertion so we don't lose
+  // any information about the AVL which may help us coalesce them in
+  // doLocalPostpass.
+  for (MachineBasicBlock &MBB : MF)
+    convertToX0X0(MBB);
+
   // Insert PseudoReadVL after VLEFF/VLSEGFF and replace it with the vl output
   // of VLEFF/VLSEGFF.
   f...
[truncated]

@@ -806,6 +784,17 @@ static VSETVLIInfo getInfoForVSETVLI(const MachineInstr &MI) {
}
NewInfo.setVTYPE(MI.getOperand(2).getImm());

// FIXME: store the def of AVL instead of the register in VSETVLIInfo so we
// don't need to peek through here with MRI.
lukel97 (Contributor Author):

@BeMg This is the part that I'm hoping we can get rid of if we track the definition instead of the register in VSETVLIInfo. There's also something similar in computeInfoForInstr where we check the AVL def.
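A rough sketch of the direction that FIXME points at, with hypothetical names (none of this is in the patch): VSETVLIInfo would carry the AVL's defining instruction rather than its register, turning the peek-through above into a direct comparison:

```cpp
// Hypothetical shape only, not part of this patch: store the AVL's defining
// MachineInstr in VSETVLIInfo instead of the Register, so callers no longer
// need MRI to look the def up.
class VSETVLIInfo {
  const MachineInstr *AVLDef = nullptr; // def of the AVL, when known

public:
  void setAVLDefMI(const MachineInstr *DefMI) { AVLDef = DefMI; }
  bool hasSameAVLDef(const VSETVLIInfo &Other) const {
    return AVLDef && AVLDef == Other.AVLDef;
  }
};
```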

@lukel97 (Contributor Author) commented Apr 25, 2024

Marking as a draft since I can't think of an easy way to rebase this after #88295: it requires using the dataflow analysis after register coalescing, which is now in a separate pass. Will revisit once #70549 lands.

lukel97 added a commit to lukel97/llvm-project that referenced this pull request May 21, 2024
We no longer need to separate the passes now that llvm#70549 is landed and this will unblock llvm#89089.

It's not strictly NFC because it will move coalescing before register allocation when -riscv-vsetvl-after-rvv-regalloc is disabled. But this makes it closer to the original behaviour.
lukel97 added a commit that referenced this pull request May 29, 2024
We no longer need to separate the passes now that #70549 is landed and
this will unblock #89089.

It's not strictly NFC because it will move coalescing before register
allocation when -riscv-vsetvl-after-rvv-regalloc is disabled. But this
makes it closer to the original behaviour.