diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index fa37306a49990..71759fdde9af0 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13657,9 +13657,8 @@ struct NodeExtensionHelper { unsigned ScalarBits = VT.getScalarSizeInBits(); unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits(); - // Ensure the narrowing element type is legal - if (!Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType())) - break; + assert( + Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType())); // Ensure the extension's semantic is equivalent to rvv vzext or vsext. if (ScalarBits != NarrowScalarBits * 2) @@ -13732,14 +13731,11 @@ struct NodeExtensionHelper { } /// Check if \p Root supports any extension folding combines. - static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + static bool isSupportedRoot(const SDNode *Root) { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: case ISD::MUL: { - if (!TLI.isTypeLegal(Root->getValueType(0))) - return false; return Root->getValueType(0).isScalableVector(); } // Vector Widening Integer Add/Sub/Mul Instructions @@ -13756,7 +13752,7 @@ struct NodeExtensionHelper { case RISCVISD::FMUL_VL: case RISCVISD::VFWADD_W_VL: case RISCVISD::VFWSUB_W_VL: - return TLI.isTypeLegal(Root->getValueType(0)); + return true; default: return false; } @@ -13765,9 +13761,10 @@ struct NodeExtensionHelper { /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx). NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an " - "unsupported root"); + assert(isSupportedRoot(Root) && "Trying to build an helper with an " + "unsupported root"); assert(OperandIdx < 2 && "Requesting something else than LHS or RHS"); + assert(DAG.getTargetLoweringInfo().isTypeLegal(Root->getValueType(0))); OrigOperand = Root->getOperand(OperandIdx); unsigned Opc = Root->getOpcode(); @@ -13817,7 +13814,7 @@ struct NodeExtensionHelper { static std::pair getMaskAndVL(const SDNode *Root, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Unexpected root"); + assert(isSupportedRoot(Root) && "Unexpected root"); switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: @@ -14117,8 +14114,10 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { SelectionDAG &DAG = DCI.DAG; + if (DCI.isBeforeLegalize()) + return SDValue(); - if (!NodeExtensionHelper::isSupportedRoot(N, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(N)) return SDValue(); SmallVector Worklist; @@ -14129,7 +14128,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, while (!Worklist.empty()) { SDNode *Root = Worklist.pop_back_val(); - if (!NodeExtensionHelper::isSupportedRoot(Root, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(Root)) return SDValue(); NodeExtensionHelper LHS(N, 0, DAG, Subtarget);