diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp index e1ff243bb1a47..7982b2956a642 100644 --- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp +++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp @@ -73,7 +73,7 @@ class RISCVVectorPeephole : public MachineFunctionPass { bool isAllOnesMask(const MachineInstr *MaskDef) const; std::optional getConstant(const MachineOperand &VL) const; bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const; - bool isKnownSameDefs(Register A, Register B) const; + Register lookThruCopies(Register Reg) const; }; } // namespace @@ -387,23 +387,18 @@ bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const { return true; } -bool RISCVVectorPeephole::isKnownSameDefs(Register A, Register B) const { - if (A.isPhysical() || B.isPhysical()) - return false; - - auto LookThruVirtRegCopies = [this](Register Reg) { - while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) { - if (!Def->isFullCopy()) - break; - Register Src = Def->getOperand(1).getReg(); - if (!Src.isVirtual()) - break; - Reg = Src; - } - return Reg; - }; - - return LookThruVirtRegCopies(A) == LookThruVirtRegCopies(B); +/// If \p Reg is defined by one or more COPYs of virtual registers, traverses +/// the chain and returns the root non-COPY source. +Register RISCVVectorPeephole::lookThruCopies(Register Reg) const { + while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) { + if (!Def->isFullCopy()) + break; + Register Src = Def->getOperand(1).getReg(); + if (!Src.isVirtual()) + break; + Reg = Src; + } + return Reg; } /// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the @@ -428,10 +423,11 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) { if (!TrueMaskedInfo || !hasSameEEW(MI, *True)) return false; - const MachineOperand &TrueMask = - True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs()); - const MachineOperand &MIMask = MI.getOperand(4); - if (!isKnownSameDefs(TrueMask.getReg(), MIMask.getReg())) + Register TrueMaskReg = lookThruCopies( + True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs()) + .getReg()); + Register MIMaskReg = lookThruCopies(MI.getOperand(4).getReg()); + if (!TrueMaskReg.isVirtual() || TrueMaskReg != MIMaskReg) return false; // Masked off lanes past TrueVL will come from False, and converting to vmv @@ -717,9 +713,9 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const { if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMERGE_VVM) return false; - Register PassthruReg = MI.getOperand(1).getReg(); - Register FalseReg = MI.getOperand(2).getReg(); - Register TrueReg = MI.getOperand(3).getReg(); + Register PassthruReg = lookThruCopies(MI.getOperand(1).getReg()); + Register FalseReg = lookThruCopies(MI.getOperand(2).getReg()); + Register TrueReg = lookThruCopies(MI.getOperand(3).getReg()); if (!TrueReg.isVirtual() || !MRI->hasOneUse(TrueReg)) return false; MachineInstr &True = *MRI->getUniqueVRegDef(TrueReg); @@ -740,16 +736,17 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const { // We require that either passthru and false are the same, or that passthru // is undefined. - if (PassthruReg && !isKnownSameDefs(PassthruReg, FalseReg)) + if (PassthruReg && !(PassthruReg.isVirtual() && PassthruReg == FalseReg)) return false; std::optional> NeedsCommute; // If True has a passthru operand then it needs to be the same as vmerge's // False, since False will be used for the result's passthru operand. - Register TruePassthru = True.getOperand(True.getNumExplicitDefs()).getReg(); + Register TruePassthru = + lookThruCopies(True.getOperand(True.getNumExplicitDefs()).getReg()); if (RISCVII::isFirstDefTiedToFirstUse(True.getDesc()) && TruePassthru && - !isKnownSameDefs(TruePassthru, FalseReg)) { + !(TruePassthru.isVirtual() && TruePassthru == FalseReg)) { // If True's passthru != False, check if it uses False in another operand // and try to commute it. int OtherIdx = True.findRegisterUseOperandIdx(FalseReg, TRI); diff --git a/llvm/test/CodeGen/RISCV/rvv/vmerge-peephole.mir b/llvm/test/CodeGen/RISCV/rvv/vmerge-peephole.mir index 374afa3aafdea..670e99c799af5 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vmerge-peephole.mir +++ b/llvm/test/CodeGen/RISCV/rvv/vmerge-peephole.mir @@ -116,3 +116,23 @@ body: | %vfmadd:vrnov0 = nofpexcept PseudoVFMADD_VV_M1_E32 %x, %y, %passthru, 7, -1, 5 /* e32 */, 3 /* ta, ma */, implicit $frm %vmerge:vrnov0 = PseudoVMERGE_VVM_M1 %passthru, %passthru, %vfmadd, %mask, %avl, 5 ... +--- +name: true_copy +body: | + bb.0: + liveins: $x8, $v0, $v8 + ; CHECK-LABEL: name: true_copy + ; CHECK: liveins: $x8, $v0, $v8 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: %avl:gprnox0 = COPY $x8 + ; CHECK-NEXT: %passthru:vrnov0 = COPY $v8 + ; CHECK-NEXT: %mask:vmv0 = COPY $v0 + ; CHECK-NEXT: %z:vrnov0 = PseudoVLE32_V_M1_MASK %passthru, $noreg, %mask, %avl, 5 /* e32 */, 0 /* tu, mu */ :: (load unknown-size, align 1) + ; CHECK-NEXT: %y:vrnov0 = COPY %z + %avl:gprnox0 = COPY $x8 + %passthru:vrnov0 = COPY $v8 + %x:vr = PseudoVLE32_V_M1 $noreg, $noreg, %avl, 5 /* e32 */, 2 /* tu, ma */ :: (load unknown-size) + %mask:vmv0 = COPY $v0 + %y:vrnov0 = COPY %x + %z:vrnov0 = PseudoVMERGE_VVM_M1 %passthru, %passthru, %y, %mask, %avl, 5 /* e32 */ +...