diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5c1f154efa991..3efd09aeae879 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1633,8 +1633,230 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode;
+    else
+      return false;
+  }
+
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false;
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
+                                             const MachineInstr &Prev) const {
+  if (!areOpcodesEqualOrInverse(Root.getOpcode(), Prev.getOpcode()))
+    return false;
+
+  assert(Root.getMF() == Prev.getMF());
+  const MachineRegisterInfo *MRI = &Root.getMF()->getRegInfo();
+  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
+
+  // Make sure vtype operands are also the same.
+  const MCInstrDesc &Desc = get(Root.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+
+  auto checkImmOperand = [&](unsigned OpIdx) {
+    return Root.getOperand(OpIdx).getImm() == Prev.getOperand(OpIdx).getImm();
+  };
+
+  auto checkRegOperand = [&](unsigned OpIdx) {
+    return Root.getOperand(OpIdx).getReg() == Prev.getOperand(OpIdx).getReg();
+  };
+
+  // PassThru
+  // TODO: Potentially we can loosen the condition to consider Root to be
+  // associable with Prev if Root has NoReg as passthru. In which case we
+  // also need to loosen the condition on vector policies between these.
+  if (!checkRegOperand(1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask
+  if (RISCVII::usesMaskPolicy(TSFlags)) {
+    const MachineBasicBlock *MBB = Root.getParent();
+    const MachineBasicBlock::const_reverse_iterator It1(&Root);
+    const MachineBasicBlock::const_reverse_iterator It2(&Prev);
+    Register MI1VReg;
+
+    bool SeenMI2 = false;
+    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
+      if (It == It2) {
+        SeenMI2 = true;
+        if (!MI1VReg.isValid())
+          // There is no V0 def between Root and Prev; they're sharing the
+          // same V0.
+          break;
+      }
+
+      if (It->modifiesRegister(RISCV::V0, TRI)) {
+        Register SrcReg = It->getOperand(1).getReg();
+        // If it's not VReg it'll be more difficult to track its defs, so
+        // bailing out here just to be safe.
+        if (!SrcReg.isVirtual())
+          return false;
+
+        if (!MI1VReg.isValid()) {
+          // This is the V0 def for Root.
+          MI1VReg = SrcReg;
+          continue;
+        }
+
+        // Some random mask updates.
+        if (!SeenMI2)
+          continue;
+
+        // This is the V0 def for Prev; check if it's the same as that of
+        // Root.
+        if (MI1VReg != SrcReg)
+          return false;
+        else
+          break;
+      }
+    }
+
+    // If we haven't encountered Prev, it's likely that this function was
+    // called in a wrong way (e.g. Root is before Prev).
+    assert(SeenMI2 && "Prev is expected to appear before Root");
+  }
+
+  // Tail / Mask policies
+  if (RISCVII::hasVecPolicyOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
+    return false;
+
+  // VL
+  if (RISCVII::hasVLOp(TSFlags)) {
+    unsigned OpIdx = RISCVII::getVLOpNum(Desc);
+    const MachineOperand &Op1 = Root.getOperand(OpIdx);
+    const MachineOperand &Op2 = Prev.getOperand(OpIdx);
+    if (Op1.getType() != Op2.getType())
+      return false;
+    switch (Op1.getType()) {
+    case MachineOperand::MO_Register:
+      if (Op1.getReg() != Op2.getReg())
+        return false;
+      break;
+    case MachineOperand::MO_Immediate:
+      if (Op1.getImm() != Op2.getImm())
+        return false;
+      break;
+    default:
+      llvm_unreachable("Unrecognized VL operand type");
+    }
+  }
+
+  // Rounding modes
+  if (RISCVII::hasRoundModeOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
+    return false;
+
+  return true;
+}
+
+// Most of our RVV pseudos have passthru operand, so the real operands
+// start from index = 2.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+                                                  bool &Commuted) const {
+  const MachineBasicBlock *MBB = Inst.getParent();
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+  assert(RISCVII::isFirstDefTiedToFirstUse(get(Inst.getOpcode())) &&
+         "Expect the presence of passthru operand.");
+  MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+  MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+  // If only one operand has the same or inverse opcode and it's the second
+  // source operand, the operands must be commuted.
+  Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
+             areRVVInstsReassociable(Inst, *MI2);
+  if (Commuted)
+    std::swap(MI1, MI2);
+
+  return areRVVInstsReassociable(Inst, *MI1) &&
+         (isVectorAssociativeAndCommutative(*MI1) ||
+          isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
+         hasReassociableOperands(*MI1, MBB) &&
+         MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+    const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+  if (!isVectorAssociativeAndCommutative(Inst) &&
+      !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+
+  const MachineOperand &Op1 = Inst.getOperand(2);
+  const MachineOperand &Op2 = Inst.getOperand(3);
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+  // We need virtual register definitions for the operands that we will
+  // reassociate.
+  MachineInstr *MI1 = nullptr;
+  MachineInstr *MI2 = nullptr;
+  if (Op1.isReg() && Op1.getReg().isVirtual())
+    MI1 = MRI.getUniqueVRegDef(Op1.getReg());
+  if (Op2.isReg() && Op2.getReg().isVirtual())
+    MI2 = MRI.getUniqueVRegDef(Op2.getReg());
+
+  // And at least one operand must be defined in MBB.
+  return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
+}
+
+void RISCVInstrInfo::getReassociateOperandIndices(
+    const MachineInstr &Root, unsigned Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
+  if (RISCV::getRVVMCOpcode(Root.getOpcode())) {
+    // Skip the passthrough operand, so increment all indices by one.
+    for (unsigned I = 0; I < 5; ++I)
+      ++OperandIndices[I];
+  }
+}
+
 bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
                                             bool &Commuted) const {
+  if (isVectorAssociativeAndCommutative(Inst) ||
+      isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return hasReassociableVectorSibling(Inst, Commuted);
+
   if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
     return false;
 
@@ -1654,6 +1876,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
 
 bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                  bool Invert) const {
+  if (isVectorAssociativeAndCommutative(Inst, Invert))
+    return true;
+
   unsigned Opc = Inst.getOpcode();
   if (Invert) {
     auto InverseOpcode = getInverseOpcode(Opc);
@@ -1706,6 +1931,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
 
 std::optional<unsigned>
 RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+#define RVV_OPC_LMUL_CASE(OPC, INV)                                            \
+  case RISCV::OPC##_M1:                                                        \
+    return RISCV::INV##_M1;                                                    \
+  case RISCV::OPC##_M2:                                                        \
+    return RISCV::INV##_M2;                                                    \
+  case RISCV::OPC##_M4:                                                        \
+    return RISCV::INV##_M4;                                                    \
+  case RISCV::OPC##_M8:                                                        \
+    return RISCV::INV##_M8;                                                    \
+  case RISCV::OPC##_MF2:                                                       \
+    return RISCV::INV##_MF2;                                                   \
+  case RISCV::OPC##_MF4:                                                       \
+    return RISCV::INV##_MF4;                                                   \
+  case RISCV::OPC##_MF8:                                                       \
+    return RISCV::INV##_MF8
+
+#define RVV_OPC_LMUL_MASK_CASE(OPC, INV)                                       \
+  case RISCV::OPC##_M1_MASK:                                                   \
+    return RISCV::INV##_M1_MASK;                                               \
+  case RISCV::OPC##_M2_MASK:                                                   \
+    return RISCV::INV##_M2_MASK;                                               \
+  case RISCV::OPC##_M4_MASK:                                                   \
+    return RISCV::INV##_M4_MASK;                                               \
+  case RISCV::OPC##_M8_MASK:                                                   \
+    return RISCV::INV##_M8_MASK;                                               \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+    return RISCV::INV##_MF2_MASK;                                              \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+    return RISCV::INV##_MF4_MASK;                                              \
+  case RISCV::OPC##_MF8_MASK:                                                  \
+    return RISCV::INV##_MF8_MASK
+
   switch (Opcode) {
   default:
     return std::nullopt;
@@ -1729,7 +1986,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
     return RISCV::SUBW;
   case RISCV::SUBW:
     return RISCV::ADDW;
+  // clang-format off
+  RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
+  RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+  RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
+  // clang-format on
   }
+
+#undef RVV_OPC_LMUL_MASK_CASE
+#undef RVV_OPC_LMUL_CASE
 }
 
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 3b03d5efde6ef..170f813eb10d7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -266,6 +266,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
       SmallVectorImpl<MachineInstr *> &DelInstrs,
       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
 
+  bool hasReassociableOperands(const MachineInstr &Inst,
+                               const MachineBasicBlock *MBB) const override;
+
   bool hasReassociableSibling(const MachineInstr &Inst,
                               bool &Commuted) const override;
 
@@ -274,6 +277,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
+  void getReassociateOperandIndices(
+      const MachineInstr &Root, unsigned Pattern,
+      std::array<unsigned, 5> &OperandIndices) const override;
+
   ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
   getSerializableMachineMemOperandTargetFlags() const override;
 
@@ -297,6 +304,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
 private:
   unsigned getInstBundleLength(const MachineInstr &MI) const;
+
+  bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
+                                         bool Invert = false) const;
+  bool areRVVInstsReassociable(const MachineInstr &MI1,
+                               const MachineInstr &MI2) const;
+  bool hasReassociableVectorSibling(const MachineInstr &Inst,
+                                    bool &Commuted) const;
 };
 
 namespace RISCV {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
index 3cb6f3c35286c..6435c1c14e061 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll
@@ -31,7 +31,7 @@ define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vadd.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -61,7 +61,7 @@ define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vsub.vv v9, v8, v9
-; CHECK-NEXT:    vadd.vv v9, v8, v9
+; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vadd.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -91,7 +91,7 @@ define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmul.vv v9, v8, v9
-; CHECK-NEXT:    vmul.vv v9, v8, v9
+; CHECK-NEXT:    vmul.vv v8, v8, v8
 ; CHECK-NEXT:    vmul.vv v8, v8, v9
 ; CHECK-NEXT:    ret
 entry:
@@ -124,8 +124,8 @@ define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v8
+; CHECK-NEXT:    vadd.vv v8, v9, v10
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
@@ -187,8 +187,8 @@ define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %
 ; CHECK-NEXT:    vmv1r.v v10, v8
 ; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vadd.vv v9, v8, v8, v0.t
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -215,15 +215,16 @@ entry:
   ret <vscale x 1 x i8> %c
 }
 
-define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m) nounwind {
+define <vscale x 1 x i8> @vadd_vv_mask_negative(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i32 %2, <vscale x 1 x i1> %m, <vscale x 1 x i1> %m2) nounwind {
 ; CHECK-LABEL: vadd_vv_mask_negative:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, mu
-; CHECK-NEXT:    vmv1r.v v10, v8
-; CHECK-NEXT:    vadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v11, v8
+; CHECK-NEXT:    vadd.vv v11, v8, v9, v0.t
 ; CHECK-NEXT:    vmv1r.v v9, v8
-; CHECK-NEXT:    vadd.vv v9, v8, v10, v0.t
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vadd.vv v9, v8, v11, v0.t
+; CHECK-NEXT:    vmv1r.v v0, v10
+; CHECK-NEXT:    vadd.vv v8, v8, v9, v0.t
 ; CHECK-NEXT:    ret
 entry:
   %a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
@@ -240,8 +241,6 @@ entry:
     <vscale x 1 x i1> %m,
     i32 %2, i32 1)
-  %splat = insertelement <vscale x 1 x i1> poison, i1 1, i32 0
-  %m2 = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
   %c = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(
     <vscale x 1 x i8> %0,
     <vscale x 1 x i8> %0,