Skip to content

Commit

Permalink
[AArch64][SVE] Improve code generation for VLS i1 masks
Browse files Browse the repository at this point in the history
This patch partially resolves an issue for VLS code generation
where a mask is generated from a smaller width integer comparison
than the instruction using the mask requires.

Rather than sign-extending a p register by converting it to a z
register, extending that, and converting it back, we now simply
unpack the p register directly.

A separate issue still causes poor code generation when the mask
generation would fit in a NEON register: in that case we use a NEON
comparison operation and then have to convert its result to a p register.
That issue will be resolved in a separate patch.

Reviewed By: peterwaller-arm

Differential Revision: https://reviews.llvm.org/D111221
  • Loading branch information
DavidTruby committed Dec 17, 2021
1 parent 9d29943 commit 7e44eb0
Show file tree
Hide file tree
Showing 5 changed files with 676 additions and 226 deletions.
61 changes: 61 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
Expand Down Expand Up @@ -15841,6 +15842,23 @@ static SDValue performVectorShiftCombine(SDNode *N,
return SDValue();
}

/// DAG combine for SUNPKLO nodes:
///   sunpklo(sext(pred)) -> sext(extract_low_half(pred))
///
/// This works together with performSetCCPunpkCombine so that a predicate does
/// not have to be moved into a standard register and back again.
///
/// \param N   The node being combined; its operand 0 is inspected.
/// \param DAG The DAG used to build the replacement nodes.
/// \returns The replacement value, or an empty SDValue if the pattern does not
///          match.
static SDValue performSunpkloCombine(SDNode *N, SelectionDAG &DAG) {
  // The unpacked value must itself be a sign extension...
  SDValue Src = N->getOperand(0);
  if (Src.getOpcode() != ISD::SIGN_EXTEND)
    return SDValue();

  // ...of something whose scalar element type is i1.
  SDValue Mask = Src->getOperand(0);
  if (Mask->getValueType(0).getScalarType() != MVT::i1)
    return SDValue();

  SDLoc DL(N);

  // Take the low half of the i1 vector (subvector starting at index 0), then
  // sign extend that half straight to the type N produces.
  auto HalfVT =
      Mask->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
  SDValue LowHalf = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Mask,
                                DAG.getVectorIdxConstant(0, DL));
  return DAG.getNode(ISD::SIGN_EXTEND, DL, N->getValueType(0), LowHalf);
}

/// Target-specific DAG combine function for post-increment LD1 (lane) and
/// post-increment LD1R.
static SDValue performPostLD1Combine(SDNode *N,
Expand Down Expand Up @@ -16518,6 +16536,44 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}

/// Fold a SETCC_MERGE_ZERO that merely re-tests the sign-extended low half of
/// another SETCC_MERGE_ZERO:
///
///   setcc_merge_zero
///       pred
///       (sign_extend (extract_subvector (setcc_merge_zero ... pred ...))), 0, ne
///   => extract_subvector (inner setcc_merge_zero)
///
/// \param N   The outer node; operands are (pred, lhs, rhs, condcode).
/// \param DAG Unused here; kept for signature parity with the other combines.
/// \returns The existing extract_subvector node when the fold applies,
///          otherwise an empty SDValue.
static SDValue performSetCCPunpkCombine(SDNode *N, SelectionDAG &DAG) {
  SDValue Pred = N->getOperand(0);
  SDValue LHS = N->getOperand(1);
  SDValue RHS = N->getOperand(2);
  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();

  // Outer comparison must be "sign_extend(...) != 0".
  if (Cond != ISD::SETNE)
    return SDValue();
  if (!isZerosVector(RHS.getNode()))
    return SDValue();
  if (LHS->getOpcode() != ISD::SIGN_EXTEND)
    return SDValue();

  // The extended value must be a subvector extract at index 0 that already has
  // the result type of N.
  SDValue Extract = LHS->getOperand(0);
  if (Extract->getOpcode() != ISD::EXTRACT_SUBVECTOR)
    return SDValue();
  if (Extract->getValueType(0) != N->getValueType(0))
    return SDValue();
  if (Extract->getConstantOperandVal(1) != 0)
    return SDValue();

  // ...and the vector being extracted from must be another SETCC_MERGE_ZERO.
  SDValue InnerSetCC = Extract->getOperand(0);
  if (InnerSetCC->getOpcode() != AArch64ISD::SETCC_MERGE_ZERO)
    return SDValue();

  // By this point we've effectively got
  // zero_inactive_lanes_and_trunc_i1(sext_i1(A)). If we can prove A's inactive
  // lanes are already zero then the trunc(sext()) sequence is redundant and we
  // can operate on A directly.
  SDValue InnerPred = InnerSetCC.getOperand(0);
  if (Pred.getOpcode() != AArch64ISD::PTRUE ||
      InnerPred.getOpcode() != AArch64ISD::PTRUE)
    return SDValue();

  // Both governing predicates are PTRUEs: accept only when they carry the same
  // pattern and that pattern lies in the vl1..vl256 range.
  const uint64_t OuterPattern = Pred.getConstantOperandVal(0);
  const uint64_t InnerPattern = InnerPred.getConstantOperandVal(0);
  if (OuterPattern != InnerPattern)
    return SDValue();
  if (OuterPattern < AArch64SVEPredPattern::vl1 ||
      OuterPattern > AArch64SVEPredPattern::vl256)
    return SDValue();

  return Extract;
}

static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
"Unexpected opcode!");
Expand All @@ -16536,6 +16592,9 @@ static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) {
LHS->getOperand(0)->getOperand(0) == Pred)
return LHS->getOperand(0);

if (SDValue V = performSetCCPunpkCombine(N, DAG))
return V;

return SDValue();
}

Expand Down Expand Up @@ -17479,6 +17538,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::VASHR:
case AArch64ISD::VLSHR:
return performVectorShiftCombine(N, *this, DCI);
case AArch64ISD::SUNPKLO:
return performSunpkloCombine(N, DAG);
case ISD::INSERT_VECTOR_ELT:
return performInsertVectorEltCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
Expand Down

0 comments on commit 7e44eb0

Please sign in to comment.