Skip to content

Commit

Permalink
[RISCV] Don't run combineBinOp_VLToVWBinOp_VL until after legalize ty…
Browse files Browse the repository at this point in the history
…pes. NFCI (#84125)

I noticed this from a discrepancy in fillUpExtensionSupport between how
we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND,
but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to #72340, combineBinOp_VLToVWBinOp_VL only ran after type
legalization because it only operated on _VL nodes. _VL nodes are only
emitted during op legalization, which takes place **after** type
legalization, which is presumably why the existing code didn't need to
check for legal types.

After #72340 we now handle generic ops like ISD::ADD that exist before
op legalization and thus **before** type legalization. This meant that
we needed to add extra checks that the narrow type was legal in #76785.

I think the easiest thing to do here is to just maintain the invariant
that the types are legal and only run the combine after type
legalization.
  • Loading branch information
lukel97 committed Mar 11, 2024
1 parent 718962f commit 58dd59a
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13657,9 +13657,8 @@ struct NodeExtensionHelper {
unsigned ScalarBits = VT.getScalarSizeInBits();
unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();

// Ensure the narrowing element type is legal
if (!Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType()))
break;
assert(
Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType()));

// Ensure the extension's semantic is equivalent to rvv vzext or vsext.
if (ScalarBits != NarrowScalarBits * 2)
Expand Down Expand Up @@ -13732,14 +13731,11 @@ struct NodeExtensionHelper {
}

/// Check if \p Root supports any extension folding combines.
static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
static bool isSupportedRoot(const SDNode *Root) {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL: {
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
return Root->getValueType(0).isScalableVector();
}
// Vector Widening Integer Add/Sub/Mul Instructions
Expand All @@ -13756,7 +13752,7 @@ struct NodeExtensionHelper {
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFWSUB_W_VL:
return TLI.isTypeLegal(Root->getValueType(0));
return true;
default:
return false;
}
Expand All @@ -13765,9 +13761,10 @@ struct NodeExtensionHelper {
/// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
"unsupported root");
assert(isSupportedRoot(Root) && "Trying to build an helper with an "
"unsupported root");
assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
assert(DAG.getTargetLoweringInfo().isTypeLegal(Root->getValueType(0)));
OrigOperand = Root->getOperand(OperandIdx);

unsigned Opc = Root->getOpcode();
Expand Down Expand Up @@ -13817,7 +13814,7 @@ struct NodeExtensionHelper {
static std::pair<SDValue, SDValue>
getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(isSupportedRoot(Root, DAG) && "Unexpected root");
assert(isSupportedRoot(Root) && "Unexpected root");
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
Expand Down Expand Up @@ -14117,8 +14114,10 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
if (DCI.isBeforeLegalize())
return SDValue();

if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
if (!NodeExtensionHelper::isSupportedRoot(N))
return SDValue();

SmallVector<SDNode *> Worklist;
Expand All @@ -14129,7 +14128,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,

while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
if (!NodeExtensionHelper::isSupportedRoot(Root))
return SDValue();

NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
Expand Down

0 comments on commit 58dd59a

Please sign in to comment.