Skip to content

Commit

Permalink
[RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (#…
Browse files Browse the repository at this point in the history
…93574)

I plan to add other combines on TRUNCATE_VECTOR_VL.
  • Loading branch information
topperc authored May 28, 2024
1 parent e3f74d4 commit 060b302
Showing 1 changed file with 53 additions and 50 deletions.
103 changes: 53 additions & 50 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}

static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
// This would be benefit for the cases where X and Y are both the same value
// type of low precision vectors. Since the truncate would be lowered into
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
// restriction, such pattern would be expanded into a series of "vsetvli"
// and "vnsrl" instructions later to reach this point.
auto IsTruncNode = [](SDValue V) {
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
return false;
SDValue VL = V.getOperand(2);
auto *C = dyn_cast<ConstantSDNode>(VL);
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
(isa<RegisterSDNode>(VL) &&
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
};

SDValue Op = N->getOperand(0);

// We need to first find the inner level of TRUNCATE_VECTOR_VL node
// to distinguish such pattern.
while (IsTruncNode(Op)) {
if (!Op.hasOneUse())
return SDValue();
Op = Op.getOperand(0);
}

if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
return SDValue();

SDValue N0 = Op.getOperand(0);
SDValue N1 = Op.getOperand(1);
if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
return SDValue();

SDValue N00 = N0.getOperand(0);
SDValue N10 = N1.getOperand(0);
if (!N00.getValueType().isVector() ||
N00.getValueType() != N10.getValueType() ||
N->getValueType(0) != N10.getValueType())
return SDValue();

unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
SDValue SMin =
DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
}

SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
Expand Down Expand Up @@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL: {
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
// This would be benefit for the cases where X and Y are both the same value
// type of low precision vectors. Since the truncate would be lowered into
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
// restriction, such pattern would be expanded into a series of "vsetvli"
// and "vnsrl" instructions later to reach this point.
auto IsTruncNode = [](SDValue V) {
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
return false;
SDValue VL = V.getOperand(2);
auto *C = dyn_cast<ConstantSDNode>(VL);
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
(isa<RegisterSDNode>(VL) &&
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
IsVLMAXForVMSET;
};

SDValue Op = N->getOperand(0);

// We need to first find the inner level of TRUNCATE_VECTOR_VL node
// to distinguish such pattern.
while (IsTruncNode(Op)) {
if (!Op.hasOneUse())
return SDValue();
Op = Op.getOperand(0);
}

if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
SDValue N0 = Op.getOperand(0);
SDValue N1 = Op.getOperand(1);
if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
SDValue N00 = N0.getOperand(0);
SDValue N10 = N1.getOperand(0);
if (N00.getValueType().isVector() &&
N00.getValueType() == N10.getValueType() &&
N->getValueType(0) == N10.getValueType()) {
unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
SDValue SMin = DAG.getNode(
ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
}
}
}
break;
}
case RISCVISD::TRUNCATE_VECTOR_VL:
return combineTruncOfSraSext(N, DAG);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
Expand Down

0 comments on commit 060b302

Please sign in to comment.