Skip to content

Commit

Permalink
[ARM] Match MVE vqdmulh
Browse files Browse the repository at this point in the history
This adds ISel matching for a form of VQDMULH. There are several IR
patterns that we could match to that instruction; this one is for:

min(ashr(mul(sext(a), sext(b)), 7), 127)

This is what LLVM will optimize to once it has removed the max that
usually makes up the min/max saturation pattern, as in this case the
compare will always be false. An additional complication when matching
i32 patterns (which extend into an i64) is that the min will be a
vselect/setcc, as vmin is not supported for i64 vectors. The TableGen
patterns have also been updated to attempt to reuse the MVE_TwoOpPattern
patterns.

Differential Revision: https://reviews.llvm.org/D90096
  • Loading branch information
davemgreen committed Oct 30, 2020
1 parent 62286c5 commit d14db8c
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 663 deletions.
99 changes: 96 additions & 3 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Expand Up @@ -1718,6 +1718,7 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
case ARMISD::VCVTL: return "ARMISD::VCVTL";
case ARMISD::VMULLs: return "ARMISD::VMULLs";
case ARMISD::VMULLu: return "ARMISD::VMULLu";
case ARMISD::VQDMULH: return "ARMISD::VQDMULH";
case ARMISD::VADDVs: return "ARMISD::VADDVs";
case ARMISD::VADDVu: return "ARMISD::VADDVu";
case ARMISD::VADDVps: return "ARMISD::VADDVps";
Expand Down Expand Up @@ -12206,9 +12207,93 @@ static SDValue PerformSELECTCombine(SDNode *N,
return Reduction;
}

// A special combine for the vqdmulh family of instructions. This is one of the
// potential set of patterns that could match this instruction. The base
// pattern you would expect to be min(max(ashr(mul(mul(sext(x), 2), sext(y)),
// 16))). This matches the different min(max(ashr(mul(mul(sext(x), sext(y)),
// 2), 16))), which llvm will have optimized to
// min(ashr(mul(sext(x), sext(y)), 15))) as the max is unnecessary.
static SDValue PerformVQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  SDValue Shft;
  ConstantSDNode *Clamp;

  // Step 1: recognise the saturating min clamp. For i8/i16 elements this is
  // a plain SMIN node; for i32 elements (extended into an i64 vector) it is
  // instead a vselect(setcc(lt)) because smin is not legal for i64 vectors.
  if (N->getOpcode() == ISD::SMIN) {
    Shft = N->getOperand(0);
    Clamp = isConstOrConstSplat(N->getOperand(1));
  } else if (N->getOpcode() == ISD::VSELECT) {
    // Detect a SMIN, which for an i64 node will be a vselect/setcc, not a smin.
    // The select must be exactly select(setlt(x, C), x, C), i.e. both the
    // compare operands mirror the select's true/false values.
    SDValue Cmp = N->getOperand(0);
    if (Cmp.getOpcode() != ISD::SETCC ||
        cast<CondCodeSDNode>(Cmp.getOperand(2))->get() != ISD::SETLT ||
        Cmp.getOperand(0) != N->getOperand(1) ||
        Cmp.getOperand(1) != N->getOperand(2))
      return SDValue();
    Shft = N->getOperand(1);
    Clamp = isConstOrConstSplat(N->getOperand(2));
  } else
    return SDValue();

  if (!Clamp)
    return SDValue();

  // Step 2: the clamp constant (the signed-max of the narrow element type)
  // determines both the element type and the shift amount we must see below.
  MVT ScalarType;
  int ShftAmt = 0;
  switch (Clamp->getSExtValue()) {
  case (1 << 7) - 1:
    ScalarType = MVT::i8;
    ShftAmt = 7;
    break;
  case (1 << 15) - 1:
    ScalarType = MVT::i16;
    ShftAmt = 15;
    break;
  case (1ULL << 31) - 1:
    ScalarType = MVT::i32;
    ShftAmt = 31;
    break;
  default:
    return SDValue();
  }

  // Step 3: the clamped value must be ashr(mul, ShftAmt). The arithmetic
  // shift by (element bits - 1) is what makes the whole pattern equivalent
  // to vqdmulh's "double, take the high half" semantics.
  if (Shft.getOpcode() != ISD::SRA)
    return SDValue();
  ConstantSDNode *N1 = isConstOrConstSplat(Shft.getOperand(1));
  if (!N1 || N1->getSExtValue() != ShftAmt)
    return SDValue();

  SDValue Mul = Shft.getOperand(0);
  if (Mul.getOpcode() != ISD::MUL)
    return SDValue();

  // Step 4: both multiply operands must be sign-extended from the same
  // 128-bit MVE vector type, whose element type matches the clamp constant.
  SDValue Ext0 = Mul.getOperand(0);
  SDValue Ext1 = Mul.getOperand(1);
  if (Ext0.getOpcode() != ISD::SIGN_EXTEND ||
      Ext1.getOpcode() != ISD::SIGN_EXTEND)
    return SDValue();
  EVT VecVT = Ext0.getOperand(0).getValueType();
  if (VecVT != MVT::v4i32 && VecVT != MVT::v8i16 && VecVT != MVT::v16i8)
    return SDValue();
  // The result type must be at least twice the element width, so the
  // sign_extend emitted below reproduces the original node's type.
  if (Ext1.getOperand(0).getValueType() != VecVT ||
      VecVT.getScalarType() != ScalarType ||
      VT.getScalarSizeInBits() < ScalarType.getScalarSizeInBits() * 2)
    return SDValue();

  // Step 5: build VQDMULH on the narrow vectors and sign-extend the result
  // back to the type the original min/vselect produced.
  SDLoc DL(Mul);
  SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, VecVT, Ext0.getOperand(0),
                                Ext1.getOperand(0));
  return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, VQDMULH);
}

static SDValue PerformVSELECTCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const ARMSubtarget *Subtarget) {
if (!Subtarget->hasMVEIntegerOps())
return SDValue();

if (SDValue V = PerformVQDMULHCombine(N, DCI.DAG))
return V;

// Transforms vselect(not(cond), lhs, rhs) into vselect(cond, rhs, lhs).
//
// We need to re-implement this optimization here as the implementation in the
Expand All @@ -12218,9 +12303,6 @@ static SDValue PerformVSELECTCombine(SDNode *N,
//
// Currently, this is only done for MVE, as it's the only target that benefits
// from this transformation (e.g. VPNOT+VPSEL becomes a single VPSEL).
if (!Subtarget->hasMVEIntegerOps())
return SDValue();

if (N->getOperand(0).getOpcode() != ISD::XOR)
return SDValue();
SDValue XOR = N->getOperand(0);
Expand Down Expand Up @@ -14582,6 +14664,14 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St,
return true;
};

// It may be preferable to keep the store unsplit as the trunc may end up
// being removed. Check that here.
if (Trunc.getOperand(0).getOpcode() == ISD::SMIN) {
if (SDValue U = PerformVQDMULHCombine(Trunc.getOperand(0).getNode(), DAG)) {
DAG.ReplaceAllUsesWith(Trunc.getOperand(0), U);
return SDValue();
}
}
if (auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Trunc->getOperand(0)))
if (isVMOVNOriginalMask(Shuffle->getMask(), false) ||
isVMOVNOriginalMask(Shuffle->getMask(), true))
Expand Down Expand Up @@ -15555,6 +15645,9 @@ static SDValue PerformMinMaxCombine(SDNode *N, SelectionDAG &DAG,
if (!ST->hasMVEIntegerOps())
return SDValue();

if (SDValue V = PerformVQDMULHCombine(N, DAG))
return V;

if (VT != MVT::v4i32 && VT != MVT::v8i16)
return SDValue();

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Expand Up @@ -216,6 +216,8 @@ class VectorType;
VMULLs, // ...signed
VMULLu, // ...unsigned

VQDMULH, // MVE vqdmulh instruction

// MVE reductions
VADDVs, // sign- or zero-extend the elements of a vector to i32,
VADDVu, // add them all together, and return an i32 of their sum
Expand Down
30 changes: 14 additions & 16 deletions llvm/lib/Target/ARM/ARMInstrMVE.td
Expand Up @@ -1955,28 +1955,26 @@ class MVE_VQxDMULH_Base<string iname, string suffix, bits<2> size, bit rounding,
let validForTailPredication = 1;
}

def MVEvqdmulh : SDNode<"ARMISD::VQDMULH", SDTIntBinOp>;

multiclass MVE_VQxDMULH_m<string iname, MVEVectorVTInfo VTI,
SDNode unpred_op, Intrinsic pred_int,
SDNode Op, Intrinsic unpred_int, Intrinsic pred_int,
bit rounding> {
def "" : MVE_VQxDMULH_Base<iname, VTI.Suffix, VTI.Size, rounding>;
defvar Inst = !cast<Instruction>(NAME);
defm : MVE_TwoOpPattern<VTI, Op, pred_int, (? ), Inst>;

let Predicates = [HasMVEInt] in {
// Unpredicated multiply
def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn))),
// Extra unpredicated multiply intrinsic patterns
def : Pat<(VTI.Vec (unpred_int (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn))),
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;

// Predicated multiply
def : Pat<(VTI.Vec (pred_int (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn),
(VTI.Pred VCCR:$mask), (VTI.Vec MQPR:$inactive))),
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn),
ARMVCCThen, (VTI.Pred VCCR:$mask),
(VTI.Vec MQPR:$inactive)))>;
}
}

multiclass MVE_VQxDMULH<string iname, MVEVectorVTInfo VTI, bit rounding>
: MVE_VQxDMULH_m<iname, VTI, !if(rounding, int_arm_mve_vqrdmulh,
: MVE_VQxDMULH_m<iname, VTI, !if(rounding, null_frag,
MVEvqdmulh),
!if(rounding, int_arm_mve_vqrdmulh,
int_arm_mve_vqdmulh),
!if(rounding, int_arm_mve_qrdmulh_predicated,
int_arm_mve_qdmulh_predicated),
Expand Down Expand Up @@ -5492,18 +5490,18 @@ class MVE_VxxMUL_qr<string iname, string suffix,
}

multiclass MVE_VxxMUL_qr_m<string iname, MVEVectorVTInfo VTI, bit bit_28,
Intrinsic int_unpred, Intrinsic int_pred> {
PatFrag Op, Intrinsic int_unpred, Intrinsic int_pred> {
def "" : MVE_VxxMUL_qr<iname, VTI.Suffix, bit_28, VTI.Size>;
defm : MVE_vec_scalar_int_pat_m<!cast<Instruction>(NAME), VTI,
int_unpred, int_pred>;
defm : MVE_TwoOpPatternDup<VTI, Op, int_pred, (? ), !cast<Instruction>(NAME)>;
defm : MVE_vec_scalar_int_pat_m<!cast<Instruction>(NAME), VTI, int_unpred, int_pred>;
}

multiclass MVE_VQDMULH_qr_m<MVEVectorVTInfo VTI> :
MVE_VxxMUL_qr_m<"vqdmulh", VTI, 0b0,
MVE_VxxMUL_qr_m<"vqdmulh", VTI, 0b0, MVEvqdmulh,
int_arm_mve_vqdmulh, int_arm_mve_qdmulh_predicated>;

multiclass MVE_VQRDMULH_qr_m<MVEVectorVTInfo VTI> :
MVE_VxxMUL_qr_m<"vqrdmulh", VTI, 0b1,
MVE_VxxMUL_qr_m<"vqrdmulh", VTI, 0b1, null_frag,
int_arm_mve_vqrdmulh, int_arm_mve_qrdmulh_predicated>;

defm MVE_VQDMULH_qr_s8 : MVE_VQDMULH_qr_m<MVE_v16s8>;
Expand Down

0 comments on commit d14db8c

Please sign in to comment.