Skip to content

Commit

Permalink
[CodeGen][SVE] Lowering of shift operations with scalable types
Browse files Browse the repository at this point in the history
Summary:
Adds AArch64ISD nodes for:
 - SHL_PRED (logical shift left)
 - SHR_PRED (logical shift right)
 - SRA_PRED (arithmetic shift right)

Existing patterns for unpredicated left shift by immediate
have also been moved into the appropriate multiclasses
in SVEInstrFormats.td.

Reviewers: sdesmalen, efriedma, ctetreau, huihuiz, rengolin

Reviewed By: efriedma

Subscribers: huihuiz, tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, cfe-commits, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79478
  • Loading branch information
kmclaughlin-arm committed May 7, 2020
1 parent 54c927b commit 3bcd3dd
Show file tree
Hide file tree
Showing 7 changed files with 557 additions and 43 deletions.
24 changes: 24 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -887,6 +887,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::UMIN, VT, Custom);
setOperationAction(ISD::SMAX, VT, Custom);
setOperationAction(ISD::UMAX, VT, Custom);
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
}
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
Expand Down Expand Up @@ -1291,6 +1294,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
case AArch64ISD::UMIN_PRED: return "AArch64ISD::UMIN_PRED";
case AArch64ISD::SMAX_PRED: return "AArch64ISD::SMAX_PRED";
case AArch64ISD::UMAX_PRED: return "AArch64ISD::UMAX_PRED";
case AArch64ISD::SHL_PRED: return "AArch64ISD::SHL_PRED";
case AArch64ISD::SRL_PRED: return "AArch64ISD::SRL_PRED";
case AArch64ISD::SRA_PRED: return "AArch64ISD::SRA_PRED";
case AArch64ISD::ADC: return "AArch64ISD::ADC";
case AArch64ISD::SBC: return "AArch64ISD::SBC";
case AArch64ISD::ADDS: return "AArch64ISD::ADDS";
Expand Down Expand Up @@ -8599,6 +8605,9 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
llvm_unreachable("unexpected shift opcode");

case ISD::SHL:
if (VT.isScalableVector())
return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED);

if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
DAG.getConstant(Cnt, DL, MVT::i32));
Expand All @@ -8608,6 +8617,12 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
Op.getOperand(0), Op.getOperand(1));
case ISD::SRA:
case ISD::SRL:
if (VT.isScalableVector()) {
unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
: AArch64ISD::SRL_PRED;
return LowerToPredicatedOp(Op, DAG, Opc);
}

// Right shift immediate
if (isVShiftRImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) {
unsigned Opc =
Expand Down Expand Up @@ -11463,6 +11478,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_sve_umax:
return DAG.getNode(AArch64ISD::UMAX_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_lsl:
return DAG.getNode(AArch64ISD::SHL_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_lsr:
return DAG.getNode(AArch64ISD::SRL_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_asr:
return DAG.getNode(AArch64ISD::SRA_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_fadda:
return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
case Intrinsic::aarch64_sve_faddv:
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Expand Up @@ -59,6 +59,9 @@ enum NodeType : unsigned {
UMIN_PRED,
SMAX_PRED,
UMAX_PRED,
SHL_PRED,
SRL_PRED,
SRA_PRED,

// Arithmetic instructions which write flags.
ADDS,
Expand Down
48 changes: 18 additions & 30 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//

def SVE8BitLslImm : ComplexPattern<i32, 2, "SelectSVE8BitLslImm", [imm]>;
def SVELShiftImm64 : ComplexPattern<i32, 1, "SelectSVEShiftImm64<0, 64>", []>;

// Contiguous loads - node definitions
//
Expand Down Expand Up @@ -154,12 +153,15 @@ def SDT_AArch64Arith : SDTypeProfile<1, 3, [
SDTCVecEltisVT<1,i1>, SDTCisSameAs<2,3>
]>;

def AArch64sdiv_pred : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
def AArch64udiv_pred : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
def AArch64smin_pred : SDNode<"AArch64ISD::SMIN_PRED", SDT_AArch64Arith>;
def AArch64umin_pred : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64smax_pred : SDNode<"AArch64ISD::SMAX_PRED", SDT_AArch64Arith>;
def AArch64umax_pred : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64sdiv_pred : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
def AArch64udiv_pred : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
def AArch64smin_pred : SDNode<"AArch64ISD::SMIN_PRED", SDT_AArch64Arith>;
def AArch64umin_pred : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64smax_pred : SDNode<"AArch64ISD::SMAX_PRED", SDT_AArch64Arith>;
def AArch64umax_pred : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64lsl_pred : SDNode<"AArch64ISD::SHL_PRED", SDT_AArch64Arith>;
def AArch64lsr_pred : SDNode<"AArch64ISD::SRL_PRED", SDT_AArch64Arith>;
def AArch64asr_pred : SDNode<"AArch64ISD::SRA_PRED", SDT_AArch64Arith>;

def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
Expand Down Expand Up @@ -1158,23 +1160,9 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
defm INDEX_II : sve_int_index_ii<"index", index_vector>;

// Unpredicated shifts
defm ASR_ZZI : sve_int_bin_cons_shift_imm_right<0b00, "asr", sra>;
defm LSR_ZZI : sve_int_bin_cons_shift_imm_right<0b01, "lsr", srl>;
defm LSL_ZZI : sve_int_bin_cons_shift_imm_left< 0b11, "lsl", shl>;

// Patterns for unpredicated left shift by immediate
def : Pat<(nxv16i8 (shl (nxv16i8 ZPR:$Zs1),
(nxv16i8 (AArch64dup (vecshiftL8:$imm))))),
(LSL_ZZI_B ZPR:$Zs1, vecshiftL8:$imm)>;
def : Pat<(nxv8i16 (shl (nxv8i16 ZPR:$Zs1),
(nxv8i16 (AArch64dup (vecshiftL16:$imm))))),
(LSL_ZZI_H ZPR:$Zs1, vecshiftL16:$imm)>;
def : Pat<(nxv4i32 (shl (nxv4i32 ZPR:$Zs1),
(nxv4i32 (AArch64dup (vecshiftL32:$imm))))),
(LSL_ZZI_S ZPR:$Zs1, vecshiftL32:$imm)>;
def : Pat<(nxv2i64 (shl (nxv2i64 ZPR:$Zs1),
(nxv2i64 (AArch64dup (i64 (SVELShiftImm64 i32:$imm)))))),
(LSL_ZZI_D ZPR:$Zs1, vecshiftL64:$imm)>;
defm ASR_ZZI : sve_int_bin_cons_shift_imm_right<0b00, "asr", AArch64asr_pred>;
defm LSR_ZZI : sve_int_bin_cons_shift_imm_right<0b01, "lsr", AArch64lsr_pred>;
defm LSL_ZZI : sve_int_bin_cons_shift_imm_left< 0b11, "lsl", AArch64lsl_pred>;

defm ASR_WIDE_ZZZ : sve_int_bin_cons_shift_wide<0b00, "asr">;
defm LSR_WIDE_ZZZ : sve_int_bin_cons_shift_wide<0b01, "lsr">;
Expand All @@ -1186,14 +1174,14 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
defm LSL_ZPmI : sve_int_bin_pred_shift_imm_left< 0b0011, "lsl">;
defm ASRD_ZPmI : sve_int_bin_pred_shift_imm_right<0b0100, "asrd", "ASRD_ZPZI", int_aarch64_sve_asrd>;

defm ASR_ZPZZ : sve_int_bin_pred_zx<int_aarch64_sve_asr>;
defm LSR_ZPZZ : sve_int_bin_pred_zx<int_aarch64_sve_lsr>;
defm LSL_ZPZZ : sve_int_bin_pred_zx<int_aarch64_sve_lsl>;
defm ASR_ZPZZ : sve_int_bin_pred_zx<AArch64asr_pred>;
defm LSR_ZPZZ : sve_int_bin_pred_zx<AArch64lsr_pred>;
defm LSL_ZPZZ : sve_int_bin_pred_zx<AArch64lsl_pred>;
defm ASRD_ZPZI : sve_int_bin_pred_shift_0_right_zx<int_aarch64_sve_asrd>;

defm ASR_ZPmZ : sve_int_bin_pred_shift<0b000, "asr", "ASR_ZPZZ", int_aarch64_sve_asr, "ASRR_ZPmZ", 1>;
defm LSR_ZPmZ : sve_int_bin_pred_shift<0b001, "lsr", "LSR_ZPZZ", int_aarch64_sve_lsr, "LSRR_ZPmZ", 1>;
defm LSL_ZPmZ : sve_int_bin_pred_shift<0b011, "lsl", "LSL_ZPZZ", int_aarch64_sve_lsl, "LSLR_ZPmZ", 1>;
defm ASR_ZPmZ : sve_int_bin_pred_shift<0b000, "asr", "ASR_ZPZZ", AArch64asr_pred, "ASRR_ZPmZ", 1>;
defm LSR_ZPmZ : sve_int_bin_pred_shift<0b001, "lsr", "LSR_ZPZZ", AArch64lsr_pred, "LSRR_ZPmZ", 1>;
defm LSL_ZPmZ : sve_int_bin_pred_shift<0b011, "lsl", "LSL_ZPZZ", AArch64lsl_pred, "LSLR_ZPmZ", 1>;
defm ASRR_ZPmZ : sve_int_bin_pred_shift<0b100, "asrr", "ASRR_ZPZZ", null_frag, "ASR_ZPmZ", 0>;
defm LSRR_ZPmZ : sve_int_bin_pred_shift<0b101, "lsrr", "LSRR_ZPZZ", null_frag, "LSR_ZPmZ", 0>;
defm LSLR_ZPmZ : sve_int_bin_pred_shift<0b111, "lslr", "LSLR_ZPZZ", null_frag, "LSL_ZPmZ", 0>;
Expand Down
42 changes: 29 additions & 13 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Expand Up @@ -215,6 +215,8 @@ def SVELogicalImm64Pat : ComplexPattern<i64, 1, "SelectSVELogicalImm<MVT::i64>",
def SVEArithUImmPat : ComplexPattern<i32, 1, "SelectSVEArithImm", []>;
def SVEArithSImmPat : ComplexPattern<i32, 1, "SelectSVESignedArithImm", []>;

def SVEShiftImm64 : ComplexPattern<i32, 1, "SelectSVEShiftImm64<0, 64>", []>;

class SVEExactFPImm<string Suffix, string ValA, string ValB> : AsmOperandClass {
let Name = "SVEExactFPImmOperand" # Suffix;
let DiagnosticType = "Invalid" # Name;
Expand Down Expand Up @@ -324,6 +326,11 @@ class SVE_1_Op_Imm_Arith_Pat<ValueType vt, SDPatternOperator op, ZPRRegOp zprty,
: Pat<(vt (op (vt zprty:$Op1), (vt (AArch64dup (it (cpx i32:$imm)))))),
(inst $Op1, i32:$imm)>;

class SVE_1_Op_Imm_Shift_Pred_Pat<ValueType vt, ValueType pt, SDPatternOperator op,
ZPRRegOp zprty, Operand ImmTy, Instruction inst>
: Pat<(vt (op (pt (AArch64ptrue 31)), (vt zprty:$Op1), (vt (AArch64dup (ImmTy:$imm))))),
(inst $Op1, ImmTy:$imm)>;

class SVE_1_Op_Imm_Arith_Pred_Pat<ValueType vt, ValueType pt, SDPatternOperator op,
ZPRRegOp zprty, ValueType it, ComplexPattern cpx, Instruction inst>
: Pat<(vt (op (pt (AArch64ptrue 31)), (vt zprty:$Op1), (vt (AArch64dup (it (cpx i32:$imm)))))),
Expand Down Expand Up @@ -4952,12 +4959,11 @@ multiclass sve_int_bin_cons_shift_wide<bits<2> opc, string asm> {
}

class sve_int_bin_cons_shift_imm<bits<4> tsz8_64, bits<2> opc, string asm,
ZPRRegOp zprty, Operand immtype, ValueType vt,
SDPatternOperator op>
ZPRRegOp zprty, Operand immtype>
: I<(outs zprty:$Zd), (ins zprty:$Zn, immtype:$imm),
asm, "\t$Zd, $Zn, $imm",
"",
[(set (vt zprty:$Zd), (op (vt zprty:$Zn), immtype:$imm))]>, Sched<[]> {
[]>, Sched<[]> {
bits<5> Zd;
bits<5> Zn;
bits<6> imm;
Expand All @@ -4973,33 +4979,43 @@ class sve_int_bin_cons_shift_imm<bits<4> tsz8_64, bits<2> opc, string asm,
}

multiclass sve_int_bin_cons_shift_imm_left<bits<2> opc, string asm,
SDPatternOperator op> {
def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8, nxv16i8, op>;
def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16, nxv8i16, op> {
SDPatternOperator op> {
def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8>;
def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16> {
let Inst{19} = imm{3};
}
def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32, nxv4i32, op> {
def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32> {
let Inst{20-19} = imm{4-3};
}
def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64, nxv2i64, op> {
def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64> {
let Inst{22} = imm{5};
let Inst{20-19} = imm{4-3};
}

def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv16i8, nxv16i1, op, ZPR8, vecshiftL8, !cast<Instruction>(NAME # _B)>;
def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv8i16, nxv8i1, op, ZPR16, vecshiftL16, !cast<Instruction>(NAME # _H)>;
def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv4i32, nxv4i1, op, ZPR32, vecshiftL32, !cast<Instruction>(NAME # _S)>;
def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv2i64, nxv2i1, op, ZPR64, i64, SVEShiftImm64, !cast<Instruction>(NAME # _D)>;
}

multiclass sve_int_bin_cons_shift_imm_right<bits<2> opc, string asm,
SDPatternOperator op> {
def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8, nxv16i8, op>;
def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16, nxv8i16, op> {
SDPatternOperator op> {
def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8>;
def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16> {
let Inst{19} = imm{3};
}
def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32, nxv4i32, op> {
def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32> {
let Inst{20-19} = imm{4-3};
}
def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64, nxv2i64, op> {
def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64> {
let Inst{22} = imm{5};
let Inst{20-19} = imm{4-3};
}

def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv16i8, nxv16i1, op, ZPR8, vecshiftR8, !cast<Instruction>(NAME # _B)>;
def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv8i16, nxv8i1, op, ZPR16, vecshiftR16, !cast<Instruction>(NAME # _H)>;
def : SVE_1_Op_Imm_Shift_Pred_Pat<nxv4i32, nxv4i1, op, ZPR32, vecshiftR32, !cast<Instruction>(NAME # _S)>;
def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv2i64, nxv2i1, op, ZPR64, i64, SVEShiftImm64, !cast<Instruction>(NAME # _D)>;
}
//===----------------------------------------------------------------------===//
// SVE Memory - Store Group
Expand Down

0 comments on commit 3bcd3dd

Please sign in to comment.