Skip to content

Commit

Permalink
[RISCV] Teach doPeepholeMaskedRVV to handle FMA instructions.
Browse files Browse the repository at this point in the history
This lets us remove some isel patterns.

Reviewed By: fakepaper56

Differential Revision: https://reviews.llvm.org/D150463
  • Loading branch information
topperc committed May 13, 2023
1 parent 245cb1f commit 98f59b2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 46 deletions.
37 changes: 21 additions & 16 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3157,37 +3157,42 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
const RISCVInstrInfo &TII = *Subtarget->getInstrInfo();
const MCInstrDesc &MaskedMCID = TII.get(N->getMachineOpcode());

bool IsTA = true;
bool UseTUPseudo = false;
if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
RISCVII::TAIL_AGNOSTIC)) {
// Keep the true-masked instruction when there is no unmasked TU
// instruction
if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef())
return false;
// We can't use TA if the tie-operand is not IMPLICIT_DEF
if (!N->getOperand(0).isUndef())
IsTA = false;
// Some operations are their own TU.
if (I->UnmaskedTUPseudo == I->UnmaskedPseudo) {
UseTUPseudo = true;
} else {
TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
RISCVII::TAIL_AGNOSTIC)) {
// We can't use TA if the tie-operand is not IMPLICIT_DEF
if (!N->getOperand(0).isUndef()) {
// Keep the true-masked instruction when there is no unmasked TU
// instruction
if (I->UnmaskedTUPseudo == I->MaskedPseudo)
return false;
UseTUPseudo = true;
}
}
}
}

unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo;
unsigned Opc = UseTUPseudo ? I->UnmaskedTUPseudo : I->UnmaskedPseudo;

// Check that we're dropping the mask operand and any policy operand
// when we transform to this unmasked pseudo. Additionally, if this
// instruction is tail agnostic, the unmasked instruction should not have a
// merge op.
uint64_t TSFlags = TII.get(Opc).TSFlags;
assert((IsTA != RISCVII::hasMergeOp(TSFlags)) &&
assert((UseTUPseudo == RISCVII::hasMergeOp(TSFlags)) &&
RISCVII::hasDummyMaskOp(TSFlags) &&
!RISCVII::hasVecPolicyOp(TSFlags) &&
"Unexpected pseudo to transform to");
(void)TSFlags;

SmallVector<SDValue, 8> Ops;
// Skip the merge operand at index 0 if IsTA
for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) {
// Skip the merge operand at index 0 if !UseTUPseudo.
for (unsigned I = !UseTUPseudo, E = N->getNumOperands(); I != E; I++) {
// Skip the mask, the policy, and the Glue.
SDValue Op = N->getOperand(I);
if (I == MaskOpIdx || I == TailPolicyOpIdx ||
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,12 @@ def RISCVVIntrinsicsTable : GenericTable {
let PrimaryKeyName = "getRISCVVIntrinsicInfo";
}

class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true> {
class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true, bit IsTernary = false> {
Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
Pseudo UnmaskedTUPseudo = !if(HasTU, !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo);
Pseudo UnmaskedTUPseudo = !cond(HasTU : !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")),
IsTernary : UnmaskedPseudo,
true : MaskedPseudo);
bits<4> MaskOpIdx = MaskIdx;
}

Expand Down Expand Up @@ -3192,7 +3194,8 @@ multiclass VPseudoTernaryWithPolicy<VReg RetClass,
let VLMul = MInfo.value in {
let isCommutable = Commutable in
def "_" # MInfo.MX : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
RISCVMaskedPseudo</*MaskOpIdx*/ 3, /*HasTU*/ false, /*IsTernary*/true>;
}
}

Expand Down
27 changes: 0 additions & 27 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -1459,26 +1459,13 @@ multiclass VPatNarrowShiftSplat_WX_WI<SDNode op, string instruction_name> {
multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name> {
foreach vti = AllFloatVectors in {
defvar suffix = vti.LMul.MX;
def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
vti.RegClass:$rs2, (vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_VV_"# suffix)
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
vti.RegClass:$rs2, (vti.Mask V0),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
vti.RegClass:$rd, vti.RegClass:$rs2,
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
vti.RegClass:$rd, vti.RegClass:$rs2,
(vti.Mask V0),
Expand All @@ -1492,27 +1479,13 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
foreach vti = AllFloatVectors in {
defvar suffix = vti.LMul.MX;
def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix)
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
Expand Down

0 comments on commit 98f59b2

Please sign in to comment.