Skip to content

Commit

Permalink
[RISCV][MachineCombiner] Add reassociation optimizations for RVV inst…
Browse files Browse the repository at this point in the history
…ructions (#88307)

This patch covers a really basic reassociation optimizations for VADD_VV and VMUL_VV.
  • Loading branch information
mshockwave committed Apr 25, 2024
1 parent 733b271 commit 5f67ce5
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 14 deletions.
266 changes: 266 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1633,8 +1633,230 @@ static bool isFMUL(unsigned Opc) {
}
}

bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
bool Invert) const {
#define OPCODE_LMUL_CASE(OPC) \
case RISCV::OPC##_M1: \
case RISCV::OPC##_M2: \
case RISCV::OPC##_M4: \
case RISCV::OPC##_M8: \
case RISCV::OPC##_MF2: \
case RISCV::OPC##_MF4: \
case RISCV::OPC##_MF8

#define OPCODE_LMUL_MASK_CASE(OPC) \
case RISCV::OPC##_M1_MASK: \
case RISCV::OPC##_M2_MASK: \
case RISCV::OPC##_M4_MASK: \
case RISCV::OPC##_M8_MASK: \
case RISCV::OPC##_MF2_MASK: \
case RISCV::OPC##_MF4_MASK: \
case RISCV::OPC##_MF8_MASK

unsigned Opcode = Inst.getOpcode();
if (Invert) {
if (auto InvOpcode = getInverseOpcode(Opcode))
Opcode = *InvOpcode;
else
return false;
}

// clang-format off
switch (Opcode) {
default:
return false;
OPCODE_LMUL_CASE(PseudoVADD_VV):
OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
OPCODE_LMUL_CASE(PseudoVMUL_VV):
OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
return true;
}
// clang-format on

#undef OPCODE_LMUL_MASK_CASE
#undef OPCODE_LMUL_CASE
}

bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
const MachineInstr &Prev) const {
if (!areOpcodesEqualOrInverse(Root.getOpcode(), Prev.getOpcode()))
return false;

assert(Root.getMF() == Prev.getMF());
const MachineRegisterInfo *MRI = &Root.getMF()->getRegInfo();
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();

// Make sure vtype operands are also the same.
const MCInstrDesc &Desc = get(Root.getOpcode());
const uint64_t TSFlags = Desc.TSFlags;

auto checkImmOperand = [&](unsigned OpIdx) {
return Root.getOperand(OpIdx).getImm() == Prev.getOperand(OpIdx).getImm();
};

auto checkRegOperand = [&](unsigned OpIdx) {
return Root.getOperand(OpIdx).getReg() == Prev.getOperand(OpIdx).getReg();
};

// PassThru
// TODO: Potentially we can loosen the condition to consider Root to be
// associable with Prev if Root has NoReg as passthru. In which case we
// also need to loosen the condition on vector policies between these.
if (!checkRegOperand(1))
return false;

// SEW
if (RISCVII::hasSEWOp(TSFlags) &&
!checkImmOperand(RISCVII::getSEWOpNum(Desc)))
return false;

// Mask
if (RISCVII::usesMaskPolicy(TSFlags)) {
const MachineBasicBlock *MBB = Root.getParent();
const MachineBasicBlock::const_reverse_iterator It1(&Root);
const MachineBasicBlock::const_reverse_iterator It2(&Prev);
Register MI1VReg;

bool SeenMI2 = false;
for (auto End = MBB->rend(), It = It1; It != End; ++It) {
if (It == It2) {
SeenMI2 = true;
if (!MI1VReg.isValid())
// There is no V0 def between Root and Prev; they're sharing the
// same V0.
break;
}

if (It->modifiesRegister(RISCV::V0, TRI)) {
Register SrcReg = It->getOperand(1).getReg();
// If it's not VReg it'll be more difficult to track its defs, so
// bailing out here just to be safe.
if (!SrcReg.isVirtual())
return false;

if (!MI1VReg.isValid()) {
// This is the V0 def for Root.
MI1VReg = SrcReg;
continue;
}

// Some random mask updates.
if (!SeenMI2)
continue;

// This is the V0 def for Prev; check if it's the same as that of
// Root.
if (MI1VReg != SrcReg)
return false;
else
break;
}
}

// If we haven't encountered Prev, it's likely that this function was
// called in a wrong way (e.g. Root is before Prev).
assert(SeenMI2 && "Prev is expected to appear before Root");
}

// Tail / Mask policies
if (RISCVII::hasVecPolicyOp(TSFlags) &&
!checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
return false;

// VL
if (RISCVII::hasVLOp(TSFlags)) {
unsigned OpIdx = RISCVII::getVLOpNum(Desc);
const MachineOperand &Op1 = Root.getOperand(OpIdx);
const MachineOperand &Op2 = Prev.getOperand(OpIdx);
if (Op1.getType() != Op2.getType())
return false;
switch (Op1.getType()) {
case MachineOperand::MO_Register:
if (Op1.getReg() != Op2.getReg())
return false;
break;
case MachineOperand::MO_Immediate:
if (Op1.getImm() != Op2.getImm())
return false;
break;
default:
llvm_unreachable("Unrecognized VL operand type");
}
}

// Rounding modes
if (RISCVII::hasRoundModeOp(TSFlags) &&
!checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
return false;

return true;
}

// Most of our RVV pseudos have passthru operand, so the real operands
// start from index = 2.
bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
bool &Commuted) const {
const MachineBasicBlock *MBB = Inst.getParent();
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
assert(RISCVII::isFirstDefTiedToFirstUse(get(Inst.getOpcode())) &&
"Expect the present of passthrough operand.");
MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());

// If only one operand has the same or inverse opcode and it's the second
// source operand, the operands must be commuted.
Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
areRVVInstsReassociable(Inst, *MI2);
if (Commuted)
std::swap(MI1, MI2);

return areRVVInstsReassociable(Inst, *MI1) &&
(isVectorAssociativeAndCommutative(*MI1) ||
isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
hasReassociableOperands(*MI1, MBB) &&
MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
}

bool RISCVInstrInfo::hasReassociableOperands(
const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
if (!isVectorAssociativeAndCommutative(Inst) &&
!isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
return TargetInstrInfo::hasReassociableOperands(Inst, MBB);

const MachineOperand &Op1 = Inst.getOperand(2);
const MachineOperand &Op2 = Inst.getOperand(3);
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();

// We need virtual register definitions for the operands that we will
// reassociate.
MachineInstr *MI1 = nullptr;
MachineInstr *MI2 = nullptr;
if (Op1.isReg() && Op1.getReg().isVirtual())
MI1 = MRI.getUniqueVRegDef(Op1.getReg());
if (Op2.isReg() && Op2.getReg().isVirtual())
MI2 = MRI.getUniqueVRegDef(Op2.getReg());

// And at least one operand must be defined in MBB.
return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
}

void RISCVInstrInfo::getReassociateOperandIndices(
const MachineInstr &Root, unsigned Pattern,
std::array<unsigned, 5> &OperandIndices) const {
TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
if (RISCV::getRVVMCOpcode(Root.getOpcode())) {
// Skip the passthrough operand, so increment all indices by one.
for (unsigned I = 0; I < 5; ++I)
++OperandIndices[I];
}
}

bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const {
if (isVectorAssociativeAndCommutative(Inst) ||
isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
return hasReassociableVectorSibling(Inst, Commuted);

if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
return false;

Expand All @@ -1654,6 +1876,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,

bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
bool Invert) const {
if (isVectorAssociativeAndCommutative(Inst, Invert))
return true;

unsigned Opc = Inst.getOpcode();
if (Invert) {
auto InverseOpcode = getInverseOpcode(Opc);
Expand Down Expand Up @@ -1706,6 +1931,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,

std::optional<unsigned>
RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
#define RVV_OPC_LMUL_CASE(OPC, INV) \
case RISCV::OPC##_M1: \
return RISCV::INV##_M1; \
case RISCV::OPC##_M2: \
return RISCV::INV##_M2; \
case RISCV::OPC##_M4: \
return RISCV::INV##_M4; \
case RISCV::OPC##_M8: \
return RISCV::INV##_M8; \
case RISCV::OPC##_MF2: \
return RISCV::INV##_MF2; \
case RISCV::OPC##_MF4: \
return RISCV::INV##_MF4; \
case RISCV::OPC##_MF8: \
return RISCV::INV##_MF8

#define RVV_OPC_LMUL_MASK_CASE(OPC, INV) \
case RISCV::OPC##_M1_MASK: \
return RISCV::INV##_M1_MASK; \
case RISCV::OPC##_M2_MASK: \
return RISCV::INV##_M2_MASK; \
case RISCV::OPC##_M4_MASK: \
return RISCV::INV##_M4_MASK; \
case RISCV::OPC##_M8_MASK: \
return RISCV::INV##_M8_MASK; \
case RISCV::OPC##_MF2_MASK: \
return RISCV::INV##_MF2_MASK; \
case RISCV::OPC##_MF4_MASK: \
return RISCV::INV##_MF4_MASK; \
case RISCV::OPC##_MF8_MASK: \
return RISCV::INV##_MF8_MASK

switch (Opcode) {
default:
return std::nullopt;
Expand All @@ -1729,7 +1986,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
return RISCV::SUBW;
case RISCV::SUBW:
return RISCV::ADDW;
// clang-format off
RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
// clang-format on
}

#undef RVV_OPC_LMUL_MASK_CASE
#undef RVV_OPC_LMUL_CASE
}

static bool canCombineFPFusedMultiply(const MachineInstr &Root,
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;

bool hasReassociableOperands(const MachineInstr &Inst,
const MachineBasicBlock *MBB) const override;

bool hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const override;

Expand All @@ -274,6 +277,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;

void getReassociateOperandIndices(
const MachineInstr &Root, unsigned Pattern,
std::array<unsigned, 5> &OperandIndices) const override;

ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
getSerializableMachineMemOperandTargetFlags() const override;

Expand All @@ -297,6 +304,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

private:
unsigned getInstBundleLength(const MachineInstr &MI) const;

bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
bool Invert = false) const;
bool areRVVInstsReassociable(const MachineInstr &MI1,
const MachineInstr &MI2) const;
bool hasReassociableVectorSibling(const MachineInstr &Inst,
bool &Commuted) const;
};

namespace RISCV {
Expand Down
Loading

0 comments on commit 5f67ce5

Please sign in to comment.