Skip to content

Commit

Permalink
[AArch64][SVE] Ensure PTEST operands have type nxv16i1
Browse files Browse the repository at this point in the history
Currently any legal predicate types will be pattern-matched when
creating a PTEST instruction. This could be a problem in future since
PTEST always uses the .B specifier for the operand, but it is not
always guaranteed that the extra lanes of unpacked types (e.g. nxv4i1)
are zero. This patch ensures the operands of PTEST are type nxv16i1,
where the undef lanes are set to zero.

Differential Revision: https://reviews.llvm.org/D129282/
  • Loading branch information
RosieSumpter committed Jul 12, 2022
1 parent 767b26a commit e5edc1b
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 49 deletions.
82 changes: 52 additions & 30 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -237,6 +237,39 @@ static bool isMergePassthruOpcode(unsigned Opc) {
}
}

// Returns true if inactive lanes are known to be zeroed by construction.
static bool isZeroingInactiveLanes(SDValue Op) {
switch (Op.getOpcode()) {
default:
// We guarantee i1 splat_vectors to zero the other lanes by
// implementing it with ptrue and possibly a punpklo for nxv1i1.
if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
return true;
return false;
case AArch64ISD::PTRUE:
case AArch64ISD::SETCC_MERGE_ZERO:
return true;
case ISD::INTRINSIC_WO_CHAIN:
switch (Op.getConstantOperandVal(0)) {
default:
return false;
case Intrinsic::aarch64_sve_ptrue:
case Intrinsic::aarch64_sve_pnext:
case Intrinsic::aarch64_sve_cmpeq_wide:
case Intrinsic::aarch64_sve_cmpne_wide:
case Intrinsic::aarch64_sve_cmpge_wide:
case Intrinsic::aarch64_sve_cmpgt_wide:
case Intrinsic::aarch64_sve_cmplt_wide:
case Intrinsic::aarch64_sve_cmple_wide:
case Intrinsic::aarch64_sve_cmphs_wide:
case Intrinsic::aarch64_sve_cmphi_wide:
case Intrinsic::aarch64_sve_cmplo_wide:
case Intrinsic::aarch64_sve_cmpls_wide:
return true;
}
}
}

AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI)
: TargetLowering(TM), Subtarget(&STI) {
Expand Down Expand Up @@ -4368,16 +4401,18 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
DAG.getTargetConstant(Pattern, DL, MVT::i32));
}

SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
SelectionDAG &DAG) const {
// Returns a safe bitcast between two scalable vector predicates, where
// any newly created lanes from a widening bitcast are defined as zero.
static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
SDLoc DL(Op);
EVT InVT = Op.getValueType();

assert(InVT.getVectorElementType() == MVT::i1 &&
VT.getVectorElementType() == MVT::i1 &&
"Expected a predicate-to-predicate bitcast");
assert(VT.isScalableVector() && isTypeLegal(VT) &&
InVT.isScalableVector() && isTypeLegal(InVT) &&
assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
InVT.isScalableVector() &&
DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
"Only expect to cast between legal scalable predicate types!");

// Return the operand if the cast isn't changing type,
Expand All @@ -4396,33 +4431,8 @@ SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,

// Check if the other lanes are already known to be zeroed by
// construction.
switch (Op.getOpcode()) {
default:
// We guarantee i1 splat_vectors to zero the other lanes by
// implementing it with ptrue and possibly a punpklo for nxv1i1.
if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
return Reinterpret;
break;
case AArch64ISD::SETCC_MERGE_ZERO:
if (isZeroingInactiveLanes(Op))
return Reinterpret;
case ISD::INTRINSIC_WO_CHAIN:
switch (Op.getConstantOperandVal(0)) {
default:
break;
case Intrinsic::aarch64_sve_ptrue:
case Intrinsic::aarch64_sve_cmpeq_wide:
case Intrinsic::aarch64_sve_cmpne_wide:
case Intrinsic::aarch64_sve_cmpge_wide:
case Intrinsic::aarch64_sve_cmpgt_wide:
case Intrinsic::aarch64_sve_cmplt_wide:
case Intrinsic::aarch64_sve_cmple_wide:
case Intrinsic::aarch64_sve_cmphs_wide:
case Intrinsic::aarch64_sve_cmphi_wide:
case Intrinsic::aarch64_sve_cmplo_wide:
case Intrinsic::aarch64_sve_cmpls_wide:
return Reinterpret;
}
}

// Zero the newly introduced lanes.
SDValue Mask = DAG.getConstant(1, DL, InVT);
Expand Down Expand Up @@ -16164,12 +16174,24 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
assert(Op.getValueType().isScalableVector() &&
TLI.isTypeLegal(Op.getValueType()) &&
"Expected legal scalable vector type!");
assert(Op.getValueType() == Pg.getValueType() &&
"Expected same type for PTEST operands");

// Ensure target specific opcodes are using legal type.
EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue TVal = DAG.getConstant(1, DL, OutVT);
SDValue FVal = DAG.getConstant(0, DL, OutVT);

// Ensure operands have type nxv16i1.
if (Op.getValueType() != MVT::nxv16i1) {
if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
isZeroingInactiveLanes(Op))
Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
else
Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
}

// Set condition code (CC) flags.
SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);

Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Expand Up @@ -1154,10 +1154,6 @@ class AArch64TargetLowering : public TargetLowering {
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;

// Returns a safe bitcast between two scalable vector predicates, where
// any newly created lanes from a widening bitcast are defined as zero.
SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;

bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
LLT Ty2) const override;
};
Expand Down
13 changes: 1 addition & 12 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -778,7 +778,7 @@ let Predicates = [HasSVEorSME] in {
defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;

def PTEST_PP : sve_int_ptest<0b010000, "ptest">;
def PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest>;
defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
Expand Down Expand Up @@ -2131,17 +2131,6 @@ let Predicates = [HasSVEorSME] in {
def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
}

def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;
def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
(PTEST_PP PPR:$pg, PPR:$src)>;

let AddedComplexity = 1 in {
class LD1RPat<ValueType vt, SDPatternOperator operator,
Instruction load, Instruction ptrue, ValueType index_vt, ComplexPattern CP, Operand immtype> :
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Expand Up @@ -650,11 +650,11 @@ multiclass sve_int_pfalse<bits<6> opc, string asm> {
def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
}

class sve_int_ptest<bits<6> opc, string asm>
class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
: I<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
asm, "\t$Pg, $Pn",
"",
[]>, Sched<[]> {
[(op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, Sched<[]> {
bits<4> Pg;
bits<4> Pn;
let Inst{31-24} = 0b00100101;
Expand Down
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-setcc.ll
Expand Up @@ -51,7 +51,10 @@ if.end:
define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
; CHECK-LABEL: sve_cmplt_setcc_hslo:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cmplt p1.h, p0/z, z0.h, #0
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: cmplt p2.h, p0/z, z0.h, #0
; CHECK-NEXT: and p1.b, p0/z, p0.b, p1.b
; CHECK-NEXT: ptest p1, p2.b
; CHECK-NEXT: b.hs .LBB2_2
; CHECK-NEXT: // %bb.1: // %if.then
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
Expand Down

0 comments on commit e5edc1b

Please sign in to comment.