Skip to content

Commit

Permalink
[RISCV] Refactor selectVSplat. NFCI
Browse files Browse the repository at this point in the history
This patch shares the logic between the various splat ComplexPatterns to help
the diff in some upcoming patches.

It's worth noting that the uimm splat pattern now takes into account the
implicit truncation + sign extend semantics of vmv_v_x_vl, but that doesn't
seem to affect the result since it always took the sext value anyway.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D158741
  • Loading branch information
lukel97 committed Aug 29, 2023
1 parent 998c323 commit 18c7bf0
Showing 1 changed file with 33 additions and 39 deletions.
72 changes: 33 additions & 39 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2955,27 +2955,35 @@ bool RISCVDAGToDAGISel::selectVLOp(SDValue N, SDValue &VL) {
return true;
}

static SDValue findVSplat(SDValue N) {
SDValue Splat = N;
if (Splat.getOpcode() != RISCVISD::VMV_V_X_VL ||
!Splat.getOperand(0).isUndef())
return SDValue();
assert(Splat.getNumOperands() == 3 && "Unexpected number of operands");
return Splat;
}

bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef())
SDValue Splat = findVSplat(N);
if (!Splat)
return false;
assert(N.getNumOperands() == 3 && "Unexpected number of operands");
SplatVal = N.getOperand(1);

SplatVal = Splat.getOperand(1);
return true;
}

using ValidateFn = bool (*)(int64_t);

static bool selectVSplatSimmHelper(SDValue N, SDValue &SplatVal,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
ValidateFn ValidateImm) {
if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef() ||
!isa<ConstantSDNode>(N.getOperand(1)))
static bool selectVSplatImmHelper(SDValue N, SDValue &SplatVal,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
std::function<bool(int64_t)> ValidateImm) {
SDValue Splat = findVSplat(N);
if (!Splat || !isa<ConstantSDNode>(Splat.getOperand(1)))
return false;
assert(N.getNumOperands() == 3 && "Unexpected number of operands");

int64_t SplatImm =
cast<ConstantSDNode>(N.getOperand(1))->getSExtValue();
const unsigned SplatEltSize = Splat.getScalarValueSizeInBits();
assert(Subtarget.getXLenVT() == Splat.getOperand(1).getSimpleValueType() &&
"Unexpected splat operand type");

// The semantics of RISCVISD::VMV_V_X_VL is that when the operand
// type is wider than the resulting vector element type: an implicit
Expand All @@ -2984,55 +2992,41 @@ static bool selectVSplatSimmHelper(SDValue N, SDValue &SplatVal,
// any zero-extended immediate.
// For example, we wish to match (i8 -1) -> (XLenVT 255) as a simm5 by first
// sign-extending to (XLenVT -1).
MVT XLenVT = Subtarget.getXLenVT();
assert(XLenVT == N.getOperand(1).getSimpleValueType() &&
"Unexpected splat operand type");
MVT EltVT = N.getSimpleValueType().getVectorElementType();
if (EltVT.bitsLT(XLenVT))
SplatImm = SignExtend64(SplatImm, EltVT.getSizeInBits());
APInt SplatConst = Splat.getConstantOperandAPInt(1).sextOrTrunc(SplatEltSize);

int64_t SplatImm = SplatConst.getSExtValue();

if (!ValidateImm(SplatImm))
return false;

SplatVal = DAG.getTargetConstant(SplatImm, SDLoc(N), XLenVT);
SplatVal = DAG.getTargetConstant(SplatImm, SDLoc(N), Subtarget.getXLenVT());
return true;
}

bool RISCVDAGToDAGISel::selectVSplatSimm5(SDValue N, SDValue &SplatVal) {
return selectVSplatSimmHelper(N, SplatVal, *CurDAG, *Subtarget,
[](int64_t Imm) { return isInt<5>(Imm); });
return selectVSplatImmHelper(N, SplatVal, *CurDAG, *Subtarget,
[](int64_t Imm) { return isInt<5>(Imm); });
}

bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal) {
return selectVSplatSimmHelper(
return selectVSplatImmHelper(
N, SplatVal, *CurDAG, *Subtarget,
[](int64_t Imm) { return (isInt<5>(Imm) && Imm != -16) || Imm == 16; });
}

bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1NonZero(SDValue N,
SDValue &SplatVal) {
return selectVSplatSimmHelper(
return selectVSplatImmHelper(
N, SplatVal, *CurDAG, *Subtarget, [](int64_t Imm) {
return Imm != 0 && ((isInt<5>(Imm) && Imm != -16) || Imm == 16);
});
}

bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
SDValue &SplatVal) {
if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef() ||
!isa<ConstantSDNode>(N.getOperand(1)))
return false;

int64_t SplatImm =
cast<ConstantSDNode>(N.getOperand(1))->getSExtValue();

if (!isUIntN(Bits, SplatImm))
return false;

SplatVal =
CurDAG->getTargetConstant(SplatImm, SDLoc(N), Subtarget->getXLenVT());

return true;
return selectVSplatImmHelper(
N, SplatVal, *CurDAG, *Subtarget,
[Bits](int64_t Imm) { return isUIntN(Bits, Imm); });
}

bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
Expand Down

0 comments on commit 18c7bf0

Please sign in to comment.