Skip to content

Commit

Permalink
[AArch64] Make nxv1i1 types a legal type for SVE.
Browse files Browse the repository at this point in the history
One motivation to add support for these types is the LD1Q/ST1Q
instructions in SME, for which we have defined a number of load/store
intrinsics which at the moment still take a `<vscale x 16 x i1>` predicate
regardless of their element type.

This patch adds basic support for the nxv1i1 type such that it can be passed/returned
from functions, as well as some basic support for existing tests that
result in an nxv1i1 type. It also adds support for splats.

Other operations (e.g. insert/extract subvector, logical ops, etc) will be
supported in follow-up patches.

Reviewed By: paulwalker-arm, efriedma

Differential Revision: https://reviews.llvm.org/D128665
  • Loading branch information
sdesmalen-arm committed Jul 1, 2022
1 parent 560e694 commit 690db16
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 14 deletions.
18 changes: 12 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Expand Up @@ -6653,18 +6653,18 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT,
EVT InVT = InOp.getValueType();
assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
"input and widen element type must match");
assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
assert(InVT.isScalableVector() == NVT.isScalableVector() &&
"cannot modify scalable vectors in this way");
SDLoc dl(InOp);

// Check if InOp already has the right width.
if (InVT == NVT)
return InOp;

unsigned InNumElts = InVT.getVectorNumElements();
unsigned WidenNumElts = NVT.getVectorNumElements();
if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) {
unsigned NumConcat = WidenNumElts / InNumElts;
ElementCount InEC = InVT.getVectorElementCount();
ElementCount WidenEC = NVT.getVectorElementCount();
if (WidenEC.hasKnownScalarFactor(InEC)) {
unsigned NumConcat = WidenEC.getKnownScalarFactor(InEC);
SmallVector<SDValue, 16> Ops(NumConcat);
SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, InVT) :
DAG.getUNDEF(InVT);
Expand All @@ -6675,10 +6675,16 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT,
return DAG.getNode(ISD::CONCAT_VECTORS, dl, NVT, Ops);
}

if (WidenNumElts < InNumElts && InNumElts % WidenNumElts)
if (InEC.hasKnownScalarFactor(WidenEC))
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp,
DAG.getVectorIdxConstant(0, dl));

assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
"Scalable vectors should have been handled already.");

unsigned InNumElts = InEC.getFixedValue();
unsigned WidenNumElts = WidenEC.getFixedValue();

// Fall back to extract and build.
SmallVector<SDValue, 16> Ops(WidenNumElts);
EVT EltVT = NVT.getVectorElementType();
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64CallingConvention.td
Expand Up @@ -82,9 +82,9 @@ def CC_AArch64_AAPCS : CallingConv<[
nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCPassIndirect<i64>>,

CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCAssignToReg<[P0, P1, P2, P3]>>,
CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCPassIndirect<i64>>,

// Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers,
Expand Down Expand Up @@ -149,7 +149,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[
nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,

CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCAssignToReg<[P0, P1, P2, P3]>>
]>;

Expand Down
14 changes: 10 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -292,6 +292,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

if (Subtarget->hasSVE() || Subtarget->hasSME()) {
// Add legal sve predicate types
addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
Expand Down Expand Up @@ -1156,7 +1157,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);

for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
for (auto VT :
{MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::SETCC, VT, Custom);
Expand Down Expand Up @@ -4676,7 +4678,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getOperand(2), Op.getOperand(3),
DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
Op.getOperand(1));

case Intrinsic::localaddress: {
const auto &MF = DAG.getMachineFunction();
const auto *RegInfo = Subtarget->getRegisterInfo();
Expand Down Expand Up @@ -10551,8 +10552,13 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
DAG.getValueType(MVT::i1));
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID,
DAG.getConstant(0, DL, MVT::i64), SplatVal);
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
if (VT == MVT::nxv1i1)
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1,
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID,
Zero, SplatVal),
Zero);
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
}

SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64RegisterInfo.td
Expand Up @@ -871,7 +871,7 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
// SVE predicate register classes.
// All predicate element counts (nxv16i1 down to the newly-legal nxv1i1)
// share the same physical P registers, so they are carried by a single
// register class; the spill slot size for a predicate is fixed by
// `let Size = 16` below.
class PPRClass<int lastreg> : RegisterClass<
"AArch64",
[ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
(sequence "P%u", 0, lastreg)> {
let Size = 16;
}
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -748,6 +748,11 @@ let Predicates = [HasSVEorSME] in {
defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>;
defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>;

// Define a pattern for `nxv1i1 splat_vector(1)`: materialize an all-active
// nxv2i1 predicate with PTRUE_D (immediate 31) and take its low half with
// PUNPKLO, yielding an all-active nxv1i1.
// This is done here instead of in ISelLowering so that PatFrags can still
// recognize the node as a splat.
def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>;

defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">;
defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">;
def MOVPRFX_ZZ : sve_int_bin_cons_misc_0_c<0b00000001, "movprfx", ZPRAny>;
Expand Down Expand Up @@ -1509,6 +1514,10 @@ let Predicates = [HasSVEorSME] in {
defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>;

// Extract lo/hi halves of legal predicate types.
// For nxv1i1 halves of an nxv2i1 source, punpklo/punpkhi select the
// low/high half of the predicate respectively (same scheme as the wider
// predicate extracts below).
def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 0))),
(PUNPKLO_PP PPR:$Ps)>;
def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 1))),
(PUNPKHI_PP PPR:$Ps)>;
def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))),
(PUNPKLO_PP PPR:$Ps)>;
def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))),
Expand Down Expand Up @@ -1599,6 +1608,8 @@ let Predicates = [HasSVEorSME] in {
(UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;

// Concatenate two predicates.
// The two half-width predicates are merged into one full-width predicate
// with UZP1 at the element size matching the result type (D for nxv2i1,
// S for nxv4i1).
def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)),
(UZP1_PPP_D $p1, $p2)>;
def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)),
(UZP1_PPP_S $p1, $p2)>;
def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)),
Expand Down Expand Up @@ -2298,15 +2309,23 @@ let Predicates = [HasSVEorSME] in {
// Predicate "casts" between any pair of nxvNi1 types (now including
// nxv1i1) are free: every predicate type lives in the PPR register class,
// so a reinterpret_cast lowers to a plain COPY_TO_REGCLASS.
def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv16i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv1i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv1i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv1i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv1i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;

// These allow casting from/to unpacked floating-point types.
def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Expand Up @@ -647,6 +647,7 @@ multiclass sve_int_pfalse<bits<6> opc, string asm> {
def : Pat<(nxv8i1 immAllZerosV), (!cast<Instruction>(NAME))>;
def : Pat<(nxv4i1 immAllZerosV), (!cast<Instruction>(NAME))>;
def : Pat<(nxv2i1 immAllZerosV), (!cast<Instruction>(NAME))>;
def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
}

class sve_int_ptest<bits<6> opc, string asm>
Expand Down Expand Up @@ -1681,6 +1682,7 @@ multiclass sve_int_pred_log<bits<4> opc, string asm, SDPatternOperator op,
def : SVE_3_Op_Pat<nxv8i1, op, nxv8i1, nxv8i1, nxv8i1, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv4i1, op, nxv4i1, nxv4i1, nxv4i1, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv2i1, op, nxv2i1, nxv2i1, nxv2i1, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv1i1, op, nxv1i1, nxv1i1, nxv1i1, !cast<Instruction>(NAME)>;
def : SVE_2_Op_AllActive_Pat<nxv16i1, op_nopred, nxv16i1, nxv16i1,
!cast<Instruction>(NAME), PTRUE_B>;
def : SVE_2_Op_AllActive_Pat<nxv8i1, op_nopred, nxv8i1, nxv8i1,
Expand Down
24 changes: 24 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll
Expand Up @@ -1078,3 +1078,27 @@ define <vscale x 2 x i1> @extract_nxv2i1_nxv16i1_all_zero() {

declare <vscale x 2 x float> @llvm.vector.extract.nxv2f32.nxv4f32(<vscale x 4 x float>, i64)
declare <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32>, i64)

;
; Extract nxv1i1 type from: nxv2i1
;

; Extracting the low nxv1i1 half of an nxv2i1 predicate should select to a
; single punpklo (matching the extract_subvector pattern with index 0).
define <vscale x 1 x i1> @extract_nxv1i1_nxv2i1_0(<vscale x 2 x i1> %in) {
; CHECK-LABEL: extract_nxv1i1_nxv2i1_0:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: ret
%res = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1> %in, i64 0)
ret <vscale x 1 x i1> %res
}

; Extracting the high nxv1i1 half of an nxv2i1 predicate should select to a
; single punpkhi (matching the extract_subvector pattern with index 1).
define <vscale x 1 x i1> @extract_nxv1i1_nxv2i1_1(<vscale x 2 x i1> %in) {
; CHECK-LABEL: extract_nxv1i1_nxv2i1_1:
; CHECK: // %bb.0:
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: ret
%res = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1> %in, i64 1)
ret <vscale x 1 x i1> %res
}

declare <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1>, i64)
3 changes: 3 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-select.ll
Expand Up @@ -187,6 +187,7 @@ define <vscale x 1 x i1> @select_nxv1i1(i1 %cond, <vscale x 1 x i1> %a, <vscal
; CHECK-NEXT: // kill: def $w0 killed $w0 def $x0
; CHECK-NEXT: sbfx x8, x0, #0, #1
; CHECK-NEXT: whilelo p2.d, xzr, x8
; CHECK-NEXT: punpklo p2.h, p2.b
; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
; CHECK-NEXT: ret
%res = select i1 %cond, <vscale x 1 x i1> %a, <vscale x 1 x i1> %b
Expand Down Expand Up @@ -225,6 +226,7 @@ define <vscale x 4 x i32> @sel_nxv4i32(<vscale x 4 x i1> %p, <vscale x 4 x i32>
define <vscale x 1 x i64> @sel_nxv1i64(<vscale x 1 x i1> %p, <vscale x 1 x i64> %dst, <vscale x 1 x i64> %a) {
; CHECK-LABEL: sel_nxv1i64:
; CHECK: // %bb.0:
; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d
; CHECK-NEXT: mov z0.d, p0/m, z1.d
; CHECK-NEXT: ret
%sel = select <vscale x 1 x i1> %p, <vscale x 1 x i64> %a, <vscale x 1 x i64> %dst
Expand Down Expand Up @@ -483,6 +485,7 @@ define <vscale x 1 x i1> @icmp_select_nxv1i1(<vscale x 1 x i1> %a, <vscale x 1 x
; CHECK-NEXT: cset w8, eq
; CHECK-NEXT: sbfx x8, x8, #0, #1
; CHECK-NEXT: whilelo p2.d, xzr, x8
; CHECK-NEXT: punpklo p2.h, p2.b
; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
; CHECK-NEXT: ret
%mask = icmp eq i64 %x0, 0
Expand Down
7 changes: 7 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-zeroinit.ll
Expand Up @@ -52,6 +52,13 @@ define <vscale x 8 x half> @test_zeroinit_8xf16() {
ret <vscale x 8 x half> zeroinitializer
}

; An nxv1i1 zeroinitializer should select to a single pfalse, like the
; other predicate widths in this file (via the sve_int_pfalse patterns).
define <vscale x 1 x i1> @test_zeroinit_1xi1() {
; CHECK-LABEL: test_zeroinit_1xi1
; CHECK: pfalse p0.b
; CHECK-NEXT: ret
ret <vscale x 1 x i1> zeroinitializer
}

define <vscale x 2 x i1> @test_zeroinit_2xi1() {
; CHECK-LABEL: test_zeroinit_2xi1
; CHECK: pfalse p0.b
Expand Down

0 comments on commit 690db16

Please sign in to comment.