Skip to content
Permalink
Browse files

[ARM] MVE integer compares and selects

This adds the very basics for MVE vector predication, adding integer VCMP and
VSEL instruction support. This is done through predicate registers (MVT::v16i1,
MVT::v8i1, MVT::v4i1), but otherwise using same mechanics as NEON to custom
lower setcc's through ARMISD::VCXX nodes (VCEQ, VCGT, VCEQZ, etc).

An extra VCNE was added, as this can be handled sensibly by MVE's expanded
number of VCMP condition codes. (There are also VCLE and VCLT which are added
later).

VPSEL is also added here, simply selecting on the vselect.

Original code by David Sherwood.

Differential Revision: https://reviews.llvm.org/D65051

llvm-svn: 366885
  • Loading branch information...
DMG862 committed Jul 24, 2019
1 parent d22f877 commit b9d96ceca0c54496d93f0d58bef2968cde5b1edd
@@ -258,6 +258,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::UMIN, VT, Legal);
setOperationAction(ISD::UMAX, VT, Legal);
setOperationAction(ISD::ABS, VT, Legal);
setOperationAction(ISD::SETCC, VT, Custom);

// No native support for these.
setOperationAction(ISD::UDIV, VT, Expand);
@@ -334,6 +335,12 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal);
setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal);
setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal);

// Predicate types
const MVT pTypes[] = {MVT::v16i1, MVT::v8i1, MVT::v4i1};
for (auto VT : pTypes) {
addRegisterClass(VT, &ARM::VCCRRegClass);
}
}

ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
@@ -1500,6 +1507,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {

case ARMISD::VCEQ: return "ARMISD::VCEQ";
case ARMISD::VCEQZ: return "ARMISD::VCEQZ";
case ARMISD::VCNE: return "ARMISD::VCNE";
case ARMISD::VCNEZ: return "ARMISD::VCNEZ";
case ARMISD::VCGE: return "ARMISD::VCGE";
case ARMISD::VCGEZ: return "ARMISD::VCGEZ";
case ARMISD::VCLEZ: return "ARMISD::VCLEZ";
@@ -1601,6 +1610,11 @@ EVT ARMTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &,
EVT VT) const {
if (!VT.isVector())
return getPointerTy(DL);

// MVE has a predicate register.
if (Subtarget->hasMVEIntegerOps() &&
(VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8))
return MVT::getVectorVT(MVT::i1, VT.getVectorElementCount());
return VT.changeVectorElementTypeToInteger();
}

@@ -5849,7 +5863,8 @@ static SDValue Expand64BitShift(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi);
}

static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) {
static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG,
const ARMSubtarget *ST) {
SDValue TmpOp0, TmpOp1;
bool Invert = false;
bool Swap = false;
@@ -5858,11 +5873,23 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) {
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
SDValue CC = Op.getOperand(2);
EVT CmpVT = Op0.getValueType().changeVectorElementTypeToInteger();
EVT VT = Op.getValueType();
ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get();
SDLoc dl(Op);

EVT CmpVT;
if (ST->hasNEON())
CmpVT = Op0.getValueType().changeVectorElementTypeToInteger();
else {
assert(ST->hasMVEIntegerOps() &&
"No hardware support for integer vector comparison!");

if (Op.getValueType().getVectorElementType() != MVT::i1)
return SDValue();

CmpVT = VT;
}

if (Op0.getValueType().getVectorElementType() == MVT::i64 &&
(SetCCOpcode == ISD::SETEQ || SetCCOpcode == ISD::SETNE)) {
// Special-case integer 64-bit equality comparisons. They aren't legal,
@@ -5930,7 +5957,12 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) {
// Integer comparisons.
switch (SetCCOpcode) {
default: llvm_unreachable("Illegal integer comparison");
case ISD::SETNE: Invert = true; LLVM_FALLTHROUGH;
case ISD::SETNE:
if (ST->hasMVEIntegerOps()) {
Opc = ARMISD::VCNE; break;
} else {
Invert = true; LLVM_FALLTHROUGH;
}
case ISD::SETEQ: Opc = ARMISD::VCEQ; break;
case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH;
case ISD::SETGT: Opc = ARMISD::VCGT; break;
@@ -5943,7 +5975,7 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) {
}

// Detect VTST (Vector Test Bits) = icmp ne (and (op0, op1), zero).
if (Opc == ARMISD::VCEQ) {
if (ST->hasNEON() && Opc == ARMISD::VCEQ) {
SDValue AndOp;
if (ISD::isBuildVectorAllZeros(Op1.getNode()))
AndOp = Op0;
@@ -5982,6 +6014,9 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) {
SDValue Result;
if (SingleOp.getNode()) {
switch (Opc) {
case ARMISD::VCNE:
assert(ST->hasMVEIntegerOps() && "Unexpected DAG node");
Result = DAG.getNode(ARMISD::VCNEZ, dl, CmpVT, SingleOp); break;
case ARMISD::VCEQ:
Result = DAG.getNode(ARMISD::VCEQZ, dl, CmpVT, SingleOp); break;
case ARMISD::VCGE:
@@ -8436,7 +8471,7 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::CTTZ:
case ISD::CTTZ_ZERO_UNDEF: return LowerCTTZ(Op.getNode(), DAG, Subtarget);
case ISD::CTPOP: return LowerCTPOP(Op.getNode(), DAG, Subtarget);
case ISD::SETCC: return LowerVSETCC(Op, DAG);
case ISD::SETCC: return LowerVSETCC(Op, DAG, Subtarget);
case ISD::SETCCCARRY: return LowerSETCCCARRY(Op, DAG);
case ISD::ConstantFP: return LowerConstantFP(Op, DAG, Subtarget);
case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG, Subtarget);

if (!Subtarget->hasMVEIntegerOps())
return false;

// These are for predicates
if ((Ty == MVT::v16i1 || Ty == MVT::v8i1 || Ty == MVT::v4i1)) {
if (Fast)
*Fast = true;
return true;
}

if (Ty != MVT::v16i8 && Ty != MVT::v8i16 && Ty != MVT::v8f16 &&
Ty != MVT::v4i32 && Ty != MVT::v4f32 && Ty != MVT::v2i64 &&
Ty != MVT::v2f64 &&
@@ -131,6 +131,8 @@ class VectorType;

VCEQ, // Vector compare equal.
VCEQZ, // Vector compare equal to zero.
VCNE, // Vector compare not equal (MVE)
VCNEZ, // Vector compare not equal to zero (MVE)
VCGE, // Vector compare greater than or equal.
VCGEZ, // Vector compare greater than or equal to zero.
VCLEZ, // Vector compare less than or equal to zero.
@@ -265,9 +265,26 @@ def ARMvshruImm : SDNode<"ARMISD::VSHRuIMM", SDTARMVSHIMM>;
def ARMvshls : SDNode<"ARMISD::VSHLs", SDTARMVSH>;
def ARMvshlu : SDNode<"ARMISD::VSHLu", SDTARMVSH>;

def SDTARMVCMP : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisSameAs<1, 2>]>;
def SDTARMVCMPZ : SDTypeProfile<1, 1, []>;

def ARMvceq : SDNode<"ARMISD::VCEQ", SDTARMVCMP>;
def ARMvceqz : SDNode<"ARMISD::VCEQZ", SDTARMVCMPZ>;
def ARMvcne : SDNode<"ARMISD::VCNE", SDTARMVCMP>;
def ARMvcnez : SDNode<"ARMISD::VCNEZ", SDTARMVCMPZ>;
def ARMvcge : SDNode<"ARMISD::VCGE", SDTARMVCMP>;
def ARMvcgez : SDNode<"ARMISD::VCGEZ", SDTARMVCMPZ>;
def ARMvclez : SDNode<"ARMISD::VCLEZ", SDTARMVCMPZ>;
def ARMvcgeu : SDNode<"ARMISD::VCGEU", SDTARMVCMP>;
def ARMvcgt : SDNode<"ARMISD::VCGT", SDTARMVCMP>;
def ARMvcgtz : SDNode<"ARMISD::VCGTZ", SDTARMVCMPZ>;
def ARMvcltz : SDNode<"ARMISD::VCLTZ", SDTARMVCMPZ>;
def ARMvcgtu : SDNode<"ARMISD::VCGTU", SDTARMVCMP>;

def ARMWLS : SDNode<"ARMISD::WLS", SDT_ARMLoLoop, [SDNPHasChain]>;
def ARMLE : SDNode<"ARMISD::LE", SDT_ARMLoLoop, [SDNPHasChain]>;
def ARMLoopDec : SDNode<"ARMISD::LOOP_DEC", SDTIntBinOp, [SDNPHasChain]>;

//===----------------------------------------------------------------------===//
// ARM Flag Definitions.

@@ -2982,6 +2982,40 @@ def MVE_VCMPs8r : MVE_VCMPqrs<"s8", 0b00>;
def MVE_VCMPs16r : MVE_VCMPqrs<"s16", 0b01>;
def MVE_VCMPs32r : MVE_VCMPqrs<"s32", 0b10>;

multiclass unpred_vcmp_z<SDPatternOperator opnode, string suffix, int fc> {
def i8 : Pat<(v16i1 (opnode (v16i8 MQPR:$v1))),
(v16i1 (!cast<Instruction>("MVE_VCMP"#suffix#"8r") (v16i8 MQPR:$v1), ZR, fc))>;
def i16 : Pat<(v8i1 (opnode (v8i16 MQPR:$v1))),
(v8i1 (!cast<Instruction>("MVE_VCMP"#suffix#"16r") (v8i16 MQPR:$v1), ZR, fc))>;
def i32 : Pat<(v4i1 (opnode (v4i32 MQPR:$v1))),
(v4i1 (!cast<Instruction>("MVE_VCMP"#suffix#"32r") (v4i32 MQPR:$v1), ZR, fc))>;
}

multiclass unpred_vcmp_r<SDPatternOperator opnode, string suffix, int fc> {
def i8 : Pat<(v16i1 (opnode (v16i8 MQPR:$v1), (v16i8 MQPR:$v2))),
(v16i1 (!cast<Instruction>("MVE_VCMP"#suffix#"8") (v16i8 MQPR:$v1), (v16i8 MQPR:$v2), fc))>;
def i16 : Pat<(v8i1 (opnode (v8i16 MQPR:$v1), (v8i16 MQPR:$v2))),
(v8i1 (!cast<Instruction>("MVE_VCMP"#suffix#"16") (v8i16 MQPR:$v1), (v8i16 MQPR:$v2), fc))>;
def i32 : Pat<(v4i1 (opnode (v4i32 MQPR:$v1), (v4i32 MQPR:$v2))),
(v4i1 (!cast<Instruction>("MVE_VCMP"#suffix#"32") (v4i32 MQPR:$v1), (v4i32 MQPR:$v2), fc))>;
}

let Predicates = [HasMVEInt] in {
defm MVE_VCEQZ : unpred_vcmp_z<ARMvceqz, "i", 0>;
defm MVE_VCNEZ : unpred_vcmp_z<ARMvcnez, "i", 1>;
defm MVE_VCLEZ : unpred_vcmp_z<ARMvclez, "s", 13>;
defm MVE_VCGTZ : unpred_vcmp_z<ARMvcgtz, "s", 12>;
defm MVE_VCLTZ : unpred_vcmp_z<ARMvcltz, "s", 11>;
defm MVE_VCGEZ : unpred_vcmp_z<ARMvcgez, "s", 10>;

defm MVE_VCEQ : unpred_vcmp_r<ARMvceq, "i", 0>;
defm MVE_VCNE : unpred_vcmp_r<ARMvcne, "i", 1>;
defm MVE_VCGT : unpred_vcmp_r<ARMvcgt, "s", 12>;
defm MVE_VCGE : unpred_vcmp_r<ARMvcge, "s", 10>;
defm MVE_VCGTU : unpred_vcmp_r<ARMvcgtu, "u", 8>;
defm MVE_VCGEU : unpred_vcmp_r<ARMvcgeu, "u", 2>;
}

// end of MVE compares

// start of MVE_qDest_qSrc
@@ -4369,6 +4403,15 @@ foreach suffix = ["s8", "s16", "s32", "u8", "u16", "u32",
def : MVEInstAlias<"vpsel${vp}." # suffix # "\t$Qd, $Qn, $Qm",
(MVE_VPSEL MQPR:$Qd, MQPR:$Qn, MQPR:$Qm, vpred_n:$vp)>;

let Predicates = [HasMVEInt] in {
def : Pat<(v16i8 (vselect (v16i1 VCCR:$pred), (v16i8 MQPR:$v1), (v16i8 MQPR:$v2))),
(v16i8 (MVE_VPSEL MQPR:$v1, MQPR:$v2, 0, VCCR:$pred))>;
def : Pat<(v8i16 (vselect (v8i1 VCCR:$pred), (v8i16 MQPR:$v1), (v8i16 MQPR:$v2))),
(v8i16 (MVE_VPSEL MQPR:$v1, MQPR:$v2, 0, VCCR:$pred))>;
def : Pat<(v4i32 (vselect (v4i1 VCCR:$pred), (v4i32 MQPR:$v1), (v4i32 MQPR:$v2))),
(v4i32 (MVE_VPSEL MQPR:$v1, MQPR:$v2, 0, VCCR:$pred))>;
}

def MVE_VPNOT : MVE_p<(outs), (ins), NoItinerary,
"vpnot", "", "", vpred_n, "", []> {
let Inst{31-0} = 0b11111110001100010000111101001101;
@@ -478,19 +478,6 @@ def non_word_alignedstore : PatFrag<(ops node:$val, node:$ptr),
// NEON-specific DAG Nodes.
//===----------------------------------------------------------------------===//

def SDTARMVCMP : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisSameAs<1, 2>]>;
def SDTARMVCMPZ : SDTypeProfile<1, 1, []>;

def NEONvceq : SDNode<"ARMISD::VCEQ", SDTARMVCMP>;
def NEONvceqz : SDNode<"ARMISD::VCEQZ", SDTARMVCMPZ>;
def NEONvcge : SDNode<"ARMISD::VCGE", SDTARMVCMP>;
def NEONvcgez : SDNode<"ARMISD::VCGEZ", SDTARMVCMPZ>;
def NEONvclez : SDNode<"ARMISD::VCLEZ", SDTARMVCMPZ>;
def NEONvcgeu : SDNode<"ARMISD::VCGEU", SDTARMVCMP>;
def NEONvcgt : SDNode<"ARMISD::VCGT", SDTARMVCMP>;
def NEONvcgtz : SDNode<"ARMISD::VCGTZ", SDTARMVCMPZ>;
def NEONvcltz : SDNode<"ARMISD::VCLTZ", SDTARMVCMPZ>;
def NEONvcgtu : SDNode<"ARMISD::VCGTU", SDTARMVCMP>;
def NEONvtst : SDNode<"ARMISD::VTST", SDTARMVCMP>;

// Types for vector shift by immediates. The "SHX" version is for long and
@@ -5027,66 +5014,66 @@ def : Pat<(v2i32 (trunc (ARMvshruImm (sub (v2i64 QPR:$Vn), QPR:$Vm), 32))),

// VCEQ : Vector Compare Equal
defm VCEQ : N3V_QHS<1, 0, 0b1000, 1, IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q,
IIC_VSUBi4Q, "vceq", "i", NEONvceq, 1>;
IIC_VSUBi4Q, "vceq", "i", ARMvceq, 1>;
def VCEQfd : N3VD<0,0,0b00,0b1110,0, IIC_VBIND, "vceq", "f32", v2i32, v2f32,
NEONvceq, 1>;
ARMvceq, 1>;
def VCEQfq : N3VQ<0,0,0b00,0b1110,0, IIC_VBINQ, "vceq", "f32", v4i32, v4f32,
NEONvceq, 1>;
ARMvceq, 1>;
def VCEQhd : N3VD<0,0,0b01,0b1110,0, IIC_VBIND, "vceq", "f16", v4i16, v4f16,
NEONvceq, 1>,
ARMvceq, 1>,
Requires<[HasNEON, HasFullFP16]>;
def VCEQhq : N3VQ<0,0,0b01,0b1110,0, IIC_VBINQ, "vceq", "f16", v8i16, v8f16,
NEONvceq, 1>,
ARMvceq, 1>,
Requires<[HasNEON, HasFullFP16]>;

let TwoOperandAliasConstraint = "$Vm = $Vd" in
defm VCEQz : N2V_QHS_cmp<0b11, 0b11, 0b01, 0b00010, 0, "vceq", "i",
"$Vd, $Vm, #0", NEONvceqz>;
"$Vd, $Vm, #0", ARMvceqz>;

// VCGE : Vector Compare Greater Than or Equal
defm VCGEs : N3V_QHS<0, 0, 0b0011, 1, IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q,
IIC_VSUBi4Q, "vcge", "s", NEONvcge, 0>;
IIC_VSUBi4Q, "vcge", "s", ARMvcge, 0>;
defm VCGEu : N3V_QHS<1, 0, 0b0011, 1, IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q,
IIC_VSUBi4Q, "vcge", "u", NEONvcgeu, 0>;
IIC_VSUBi4Q, "vcge", "u", ARMvcgeu, 0>;
def VCGEfd : N3VD<1,0,0b00,0b1110,0, IIC_VBIND, "vcge", "f32", v2i32, v2f32,
NEONvcge, 0>;
ARMvcge, 0>;
def VCGEfq : N3VQ<1,0,0b00,0b1110,0, IIC_VBINQ, "vcge", "f32", v4i32, v4f32,
NEONvcge, 0>;
ARMvcge, 0>;
def VCGEhd : N3VD<1,0,0b01,0b1110,0, IIC_VBIND, "vcge", "f16", v4i16, v4f16,
NEONvcge, 0>,
ARMvcge, 0>,
Requires<[HasNEON, HasFullFP16]>;
def VCGEhq : N3VQ<1,0,0b01,0b1110,0, IIC_VBINQ, "vcge", "f16", v8i16, v8f16,
NEONvcge, 0>,
ARMvcge, 0>,
Requires<[HasNEON, HasFullFP16]>;

let TwoOperandAliasConstraint = "$Vm = $Vd" in {
defm VCGEz : N2V_QHS_cmp<0b11, 0b11, 0b01, 0b00001, 0, "vcge", "s",
"$Vd, $Vm, #0", NEONvcgez>;
"$Vd, $Vm, #0", ARMvcgez>;
defm VCLEz : N2V_QHS_cmp<0b11, 0b11, 0b01, 0b00011, 0, "vcle", "s",
"$Vd, $Vm, #0", NEONvclez>;
"$Vd, $Vm, #0", ARMvclez>;
}

// VCGT : Vector Compare Greater Than
defm VCGTs : N3V_QHS<0, 0, 0b0011, 0, IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q,
IIC_VSUBi4Q, "vcgt", "s", NEONvcgt, 0>;
IIC_VSUBi4Q, "vcgt", "s", ARMvcgt, 0>;
defm VCGTu : N3V_QHS<1, 0, 0b0011, 0, IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q,
IIC_VSUBi4Q, "vcgt", "u", NEONvcgtu, 0>;
IIC_VSUBi4Q, "vcgt", "u", ARMvcgtu, 0>;
def VCGTfd : N3VD<1,0,0b10,0b1110,0, IIC_VBIND, "vcgt", "f32", v2i32, v2f32,
NEONvcgt, 0>;
ARMvcgt, 0>;
def VCGTfq : N3VQ<1,0,0b10,0b1110,0, IIC_VBINQ, "vcgt", "f32", v4i32, v4f32,
NEONvcgt, 0>;
ARMvcgt, 0>;
def VCGThd : N3VD<1,0,0b11,0b1110,0, IIC_VBIND, "vcgt", "f16", v4i16, v4f16,
NEONvcgt, 0>,
ARMvcgt, 0>,
Requires<[HasNEON, HasFullFP16]>;
def VCGThq : N3VQ<1,0,0b11,0b1110,0, IIC_VBINQ, "vcgt", "f16", v8i16, v8f16,
NEONvcgt, 0>,
ARMvcgt, 0>,
Requires<[HasNEON, HasFullFP16]>;

let TwoOperandAliasConstraint = "$Vm = $Vd" in {
defm VCGTz : N2V_QHS_cmp<0b11, 0b11, 0b01, 0b00000, 0, "vcgt", "s",
"$Vd, $Vm, #0", NEONvcgtz>;
"$Vd, $Vm, #0", ARMvcgtz>;
defm VCLTz : N2V_QHS_cmp<0b11, 0b11, 0b01, 0b00100, 0, "vclt", "s",
"$Vd, $Vm, #0", NEONvcltz>;
"$Vd, $Vm, #0", ARMvcltz>;
}

// VACGE : Vector Absolute Compare Greater Than or Equal (aka VCAGE)

0 comments on commit b9d96ce

Please sign in to comment.
You can’t perform that action at this time.