-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. #76695
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-backend-risc-v Author: Chia (sun-jacobi) ChangesThis patch is an extension of #72340 and D133739, supporting floating-point extension. Patch is 31.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76695.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 51580d15451ca2..cf48cd7c378a50 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,13 +1374,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
- ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
- ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+ ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+ ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
if (Subtarget.is64Bit())
setTargetDAGCombine(ISD::SRA);
if (Subtarget.hasStdExtFOrZfinx())
- setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+ setTargetDAGCombine(
+ {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMAXNUM, ISD::FMINNUM});
if (Subtarget.hasStdExtZbb())
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
@@ -12848,6 +12849,9 @@ namespace {
// apply a combine.
struct CombineResult;
+// Supported extension kind to be folded.
+enum class SupportExt { ZExt, SExt, FPExt };
+
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -12878,6 +12882,8 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
+ /// Records if this operand is like being floating-point extended.
+ bool SupportsFPExt;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -12899,8 +12905,10 @@ struct NodeExtensionHelper {
switch (OrigOperand.getOpcode()) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
default:
return OrigOperand;
@@ -12909,7 +12917,20 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
- return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+ return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+ OrigOperand.getOpcode() == RISCVISD::VFMV_V_F_VL;
+ }
+
+ /// Get the extended opcode.
+ unsigned getExtOpc(SupportExt Ext) const {
+ switch (Ext) {
+ case SupportExt::ZExt:
+ return RISCVISD::VZEXT_VL;
+ case SupportExt::SExt:
+ return RISCVISD::VSEXT_VL;
+ case SupportExt::FPExt:
+ return RISCVISD::FP_EXTEND_VL;
+ }
}
/// Get or create a value that can feed \p Root with the given extension \p
@@ -12917,8 +12938,8 @@ struct NodeExtensionHelper {
/// \see ::getSource().
SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
- std::optional<bool> SExt) const {
- if (!SExt.has_value())
+ std::optional<SupportExt> Ext) const {
+ if (!Ext.has_value())
return OrigOperand;
MVT NarrowVT = getNarrowType(Root);
@@ -12927,20 +12948,24 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
- unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-
+ unsigned OrigOpc = OrigOperand.getOpcode();
// If we need an extension, we should be changing the type.
SDLoc DL(Root);
auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
- switch (OrigOperand.getOpcode()) {
+ switch (OrigOpc) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ unsigned ExtOpc = getExtOpc(*Ext);
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+ }
+ case RISCVISD::VFMV_V_F_VL:
case RISCVISD::VMV_V_X_VL:
- return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
- DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+ return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
+ Source.getOperand(1), VL);
default:
// Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
// and that operand should already have the right NarrowVT so no
@@ -12959,62 +12984,157 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
- MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
- VT.getVectorElementCount());
- return NarrowVT;
+ // Determine the minimum narrow size.
+ unsigned MinSize = VT.isInteger() ? 8 : 32;
+
+ assert(NarrowSize >= MinSize &&
+ "Trying to extend something we can't represent");
+
+ MVT NarrowScalarVT = VT.isInteger() ? MVT::getIntegerVT(NarrowSize)
+ : MVT::getFloatingPointVT(NarrowSize);
+ MVT NarrowVectorVT =
+ MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount());
+ return NarrowVectorVT;
}
- /// Return the opcode required to materialize the folding of the sign
- /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
- /// both operands for \p Opcode.
- /// Put differently, get the opcode to materialize:
- /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
- /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
- /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+ /// Get full widening (2*SEW = SEW +/-/* SEW) signed integer add/sub/mul
+ /// opcode.
+ static unsigned getSignedFullWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
- case RISCVISD::VWADDU_W_VL:
- return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ return RISCVISD::VWADD_VL;
case ISD::MUL:
case RISCVISD::MUL_VL:
- return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return RISCVISD::VWMUL_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
+ return RISCVISD::VWSUB_VL;
+ default:
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
+ /// Get full widening (2*SEW = SEW +/-/* SEW) unsigned integer add/sub/mul
+ /// opcode.
+ static unsigned getUnsignedFullWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::ADD:
+ case RISCVISD::ADD_VL:
+ case RISCVISD::VWADDU_W_VL:
+ return RISCVISD::VWADDU_VL;
+ case ISD::MUL:
+ case RISCVISD::MUL_VL:
+ return RISCVISD::VWMULU_VL;
+ case ISD::SUB:
+ case RISCVISD::SUB_VL:
case RISCVISD::VWSUBU_W_VL:
- return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+ return RISCVISD::VWSUBU_VL;
default:
- llvm_unreachable("Unexpected opcode");
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
+ /// Get full widening (2*SEW = SEW +/-/* SEW) FP add/sub/mul opcode.
+ static unsigned getFloatFullWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::FADD:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::VFWADD_W_VL:
+ return RISCVISD::VFWADD_VL;
+ case ISD::FSUB:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return RISCVISD::VFWSUB_VL;
+ case ISD::FMUL:
+ case RISCVISD::FMUL_VL:
+ return RISCVISD::VFWMUL_VL;
+ default:
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
+ /// Return the opcode required to materialize the folding of the sign
+ /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+ /// both operands for \p Opcode.
+ /// Put differently, get the opcode to materialize:
+ /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+ /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+ /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
+ static unsigned getFullWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
+ switch (Ext) {
+ case SupportExt::SExt:
+ return getSignedFullWidenOpcode(OrigOpcode);
+ case SupportExt::ZExt:
+ return getUnsignedFullWidenOpcode(OrigOpcode);
+ case SupportExt::FPExt:
+ return getFloatFullWidenOpcode(OrigOpcode);
}
}
/// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
/// newOpcode(a, b).
- static unsigned getSUOpcode(unsigned Opcode) {
+ static unsigned getSignedUnsignedWidenOpcode(unsigned Opcode) {
assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
"SU is only supported for MUL");
return RISCVISD::VWMULSU_VL;
}
- /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
- /// newOpcode(a, b).
- static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) signed integer add/sub opcode.
+ static unsigned getSignedHalfWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::ADD:
+ case RISCVISD::ADD_VL:
+ return RISCVISD::VWADD_W_VL;
+ case ISD::SUB:
+ case RISCVISD::SUB_VL:
+ return RISCVISD::VWSUB_W_VL;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+ }
+
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) unsigned integer add/sub opcode.
+ static unsigned getUnsignedHalfWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
- return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+ return RISCVISD::VWADDU_W_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
- return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+ return RISCVISD::VWSUBU_W_VL;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+ }
+
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode.
+ static unsigned getFloatHalfWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case RISCVISD::FADD_VL:
+ return RISCVISD::VFWADD_W_VL;
+ case RISCVISD::FSUB_VL:
+ return RISCVISD::VFWSUB_W_VL;
default:
llvm_unreachable("Unexpected opcode");
}
}
+ /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
+ /// newOpcode(a, b).
+ static unsigned getHalfWidenOpcode(unsigned Opcode, SupportExt Ext) {
+ switch (Ext) {
+ case SupportExt::SExt:
+ return getSignedHalfWidenOpcode(Opcode);
+ case SupportExt::ZExt:
+ return getUnsignedHalfWidenOpcode(Opcode);
+ case SupportExt::FPExt:
+ return getFloatHalfWidenOpcode(Opcode);
+ }
+ }
+
using CombineToTry = std::function<std::optional<CombineResult>(
SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
@@ -13029,15 +13149,18 @@ struct NodeExtensionHelper {
const RISCVSubtarget &Subtarget) {
SupportsZExt = false;
SupportsSExt = false;
+ SupportsFPExt = false;
EnforceOneUse = true;
CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
switch (Opc) {
case ISD::ZERO_EXTEND:
- case ISD::SIGN_EXTEND: {
+ case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND: {
if (OrigOperand.getValueType().isVector()) {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
+ SupportsFPExt = Opc == ISD::FP_EXTEND;
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13054,6 +13177,12 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case RISCVISD::FP_EXTEND_VL:
+ SupportsFPExt = true;
+ Mask = OrigOperand.getOperand(1);
+ VL = OrigOperand.getOperand(2);
+ break;
+ case RISCVISD::VFMV_V_F_VL:
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
@@ -13080,6 +13209,11 @@ struct NodeExtensionHelper {
if (ScalarBits < EltBits)
break;
+ if (VT.isFloatingPoint()) {
+ SupportsFPExt = true;
+ break;
+ }
+
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
// If the narrow type cannot be expressed with a legal VMV,
// this is not a valid candidate.
@@ -13103,7 +13237,10 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
@@ -13116,6 +13253,11 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
return true;
default:
return false;
@@ -13138,10 +13280,15 @@ struct NodeExtensionHelper {
case RISCVISD::VWADDU_W_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
if (OperandIdx == 1) {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt = !SupportsZExt;
+ SupportsSExt =
+ Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+ SupportsFPExt =
+ Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
CheckMask = true;
// There's no existing extension here, so we don't have to worry about
@@ -13174,7 +13321,10 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13197,15 +13347,23 @@ struct NodeExtensionHelper {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FMUL:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
+ case ISD::FSUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
return false;
default:
llvm_unreachable("Unexpected opcode");
@@ -13227,22 +13385,23 @@ struct NodeExtensionHelper {
struct CombineResult {
/// Opcode to be generated when materializing the combine.
unsigned TargetOpcode;
- // No value means no extension is needed. If extension is needed, the value
- // indicates if it needs to be sign extended.
- std::optional<bool> SExtLHS;
- std::optional<bool> SExtRHS;
- /// Root of the combine.
- SDNode *Root;
/// LHS of the TargetOpcode.
NodeExtensionHelper LHS;
+ /// Extension of the LHS
+ std::optional<SupportExt> ExtLHS;
/// RHS of the TargetOpcode.
NodeExtensionHelper RHS;
+ /// Extension of the RHS
+ std::optional<SupportExt> ExtRHS;
+ /// Root of the combine.
+ SDNode *Root;
- CombineResult(unsigned TargetOpcode, SDNode *Root,
- const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
- const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
- : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
- Root(Root), LHS(LHS), RHS(RHS) {}
+ CombineResult(
+ unsigned TargetOpcode, SDNode *Root,
+ std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> LHS,
+ std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> RHS)
+ : TargetOpcode(TargetOpcode), LHS(LHS.first), ExtLHS(LHS.second),
+ RHS(RHS.first), ExtRHS(RHS.second), Root(Root) {}
/// Return a value that uses TargetOpcode and that can be used to replace
/// Root.
@@ -13259,12 +13418,15 @@ struct CombineResult {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
- LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
- RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+ LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtLHS),
+ RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtRHS),
Merge, Mask, VL);
}
};
@@ -13279,24 +13441,30 @@ struct CombineResult {
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, bool AllowSExt,
- bool AllowZExt, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+ SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+ bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+ assert((AllowSExt || AllowZExt || AllowFPExt) &&
+ "Forgot to set what you want?");
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+ Root->getOpcode(), SupportExt::ZExt),
+ Root, {LHS, SupportExt::ZExt},
+ {RHS, SupportExt::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/true, RHS,
- /*SExtRHS=*/true);
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+ Root->getOpcode(), SupportExt::SExt),
+ Root, {LHS, SupportExt::SExt},
+ {RHS, SupportExt::SExt});
+ if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+ Root->getOpcode(), SupportExt::FPExt),
+ Root, {LHS, SupportExt::FPExt},
+ {RHS, SupportExt::FPExt});
return std::nullopt;
}
@@ -13311,7 +13479,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*Allow...
[truncated]
|
…per::isSupportedRoot.
Turning to draft due to #72340 (comment) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This patch is an extension of #72340 and D133739, supporting floating-point extension.
Specifically, this patch works for the below optimization case:
Compiler Explorer
Source code
Before this patch
After this patch