Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. #76695

Closed
wants to merge 7 commits into from

Conversation

sun-jacobi
Copy link
Member

@sun-jacobi sun-jacobi commented Jan 2, 2024

This patch is an extension of #72340 and D133739, supporting floating-point extension.

Specifically, this patch works for the below optimization case:
Compiler Explorer

Source code

define void @vfwadd(ptr %dst1, ptr %dst2,  ptr %dst3, ptr %x, ptr %y, ptr %z) {
  %a = load <2 x float>, ptr %x, align 4
  %b = load <2 x float>, ptr %y, align 4
  %c = load <2 x float>, ptr %z, align 4
  %a2 = fpext <2 x float> %a to <2 x double>
  %b2 = fpext <2 x float> %b to <2 x double>
  %c2 = fpext <2 x float> %c to <2 x double>
  %add1 = fadd <2 x double> %a2, %b2
  %add2 = fadd <2 x double> %c2, %b2
  %add3 = fadd <2 x double> %c2, %a2
  store <2 x double> %add1, ptr %dst1
  store <2 x double> %add2, ptr %dst2   
  store <2 x double> %add3, ptr %dst3
  ret void
}

Before this patch

vfwadd:
        vsetivli        zero, 2, e32, mf2, ta, ma
        vle32.v v8, (a3)
        vle32.v v9, (a4)
        vle32.v v10, (a5)
        vfwcvt.f.f.v    v11, v8
        vfwcvt.f.f.v    v8, v9
        vfwcvt.f.f.v    v9, v10
        vsetvli zero, zero, e64, m1, ta, ma
        vfadd.vv        v10, v11, v8
        vfadd.vv        v8, v9, v8
        vfadd.vv        v9, v9, v11
        vse64.v v10, (a0)
        vse64.v v8, (a1)
        vse64.v v9, (a2)
        ret

After this patch

vfwadd:
	vsetivli	zero, 2, e32, mf2, ta, ma
	vle32.v	v8, (a3)
	vle32.v	v9, (a4)
	vle32.v	v10, (a5)
	vfwadd.vv	v11, v8, v9
	vfwadd.vv	v12, v10, v9
	vfwadd.vv	v9, v10, v8
	vse64.v	v11, (a0)
	vse64.v	v12, (a1)
	vse64.v	v9, (a2)
	ret

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 2, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

Changes

This patch is an extension of #72340 and D133739, supporting floating-point extension.


Patch is 31.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76695.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+288-157)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll (+4-10)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 51580d15451ca2..cf48cd7c378a50 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,13 +1374,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
-                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
   if (Subtarget.hasStdExtFOrZfinx())
-    setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+    setTargetDAGCombine(
+        {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMAXNUM, ISD::FMINNUM});
 
   if (Subtarget.hasStdExtZbb())
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
@@ -12848,6 +12849,9 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
+// Supported extension kind to be folded.
+enum class SupportExt { ZExt, SExt, FPExt };
+
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -12878,6 +12882,8 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
+  /// Records if this operand is like being floating-point extended.
+  bool SupportsFPExt;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -12899,8 +12905,10 @@ struct NodeExtensionHelper {
     switch (OrigOperand.getOpcode()) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -12909,7 +12917,20 @@ struct NodeExtensionHelper {
 
   /// Check if this instance represents a splat.
   bool isSplat() const {
-    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+           OrigOperand.getOpcode() == RISCVISD::VFMV_V_F_VL;
+  }
+
+  /// Get the extended opcode.
+  unsigned getExtOpc(SupportExt Ext) const {
+    switch (Ext) {
+    case SupportExt::ZExt:
+      return RISCVISD::VZEXT_VL;
+    case SupportExt::SExt:
+      return RISCVISD::VSEXT_VL;
+    case SupportExt::FPExt:
+      return RISCVISD::FP_EXTEND_VL;
+    }
   }
 
   /// Get or create a value that can feed \p Root with the given extension \p
@@ -12917,8 +12938,8 @@ struct NodeExtensionHelper {
   /// \see ::getSource().
   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget,
-                                std::optional<bool> SExt) const {
-    if (!SExt.has_value())
+                                std::optional<SupportExt> Ext) const {
+    if (!Ext.has_value())
       return OrigOperand;
 
     MVT NarrowVT = getNarrowType(Root);
@@ -12927,20 +12948,24 @@ struct NodeExtensionHelper {
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-
+    unsigned OrigOpc = OrigOperand.getOpcode();
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
     auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    switch (OrigOperand.getOpcode()) {
+    switch (OrigOpc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      unsigned ExtOpc = getExtOpc(*Ext);
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    }
+    case RISCVISD::VFMV_V_F_VL:
     case RISCVISD::VMV_V_X_VL:
-      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
-                         DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+      return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
+                         Source.getOperand(1), VL);
     default:
       // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
       // and that operand should already have the right NarrowVT so no
@@ -12959,62 +12984,157 @@ struct NodeExtensionHelper {
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
-    return NarrowVT;
+    // Determine the minimum narrow size.
+    unsigned MinSize = VT.isInteger() ? 8 : 32;
+
+    assert(NarrowSize >= MinSize &&
+           "Trying to extend something we can't represent");
+
+    MVT NarrowScalarVT = VT.isInteger() ? MVT::getIntegerVT(NarrowSize)
+                                        : MVT::getFloatingPointVT(NarrowSize);
+    MVT NarrowVectorVT =
+        MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount());
+    return NarrowVectorVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
-  /// both operands for \p Opcode.
-  /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
-  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get full widening (2*SEW = SEW +/-/* SEW) signed integer add/sub/mul
+  /// opcode.
+  static unsigned getSignedFullWidenOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
-    case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return RISCVISD::VWADD_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return RISCVISD::VWMUL_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
+      return RISCVISD::VWSUB_VL;
+    default:
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Get full widening (2*SEW = SEW +/-/* SEW) unsigned integer add/sub/mul
+  /// opcode.
+  static unsigned getUnsignedFullWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return RISCVISD::VWADDU_VL;
+    case ISD::MUL:
+    case RISCVISD::MUL_VL:
+      return RISCVISD::VWMULU_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return RISCVISD::VWSUBU_VL;
     default:
-      llvm_unreachable("Unexpected opcode");
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Get full widening (2*SEW = SEW +/-/* SEW) FP add/sub/mul opcode.
+  static unsigned getFloatFullWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::FADD:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
+    case ISD::FSUB:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return RISCVISD::VFWSUB_VL;
+    case ISD::FMUL:
+    case RISCVISD::FMUL_VL:
+      return RISCVISD::VFWMUL_VL;
+    default:
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Return the opcode required to materialize the folding of the sign
+  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+  /// both operands for \p Opcode.
+  /// Put differently, get the opcode to materialize:
+  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
+  static unsigned getFullWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
+    switch (Ext) {
+    case SupportExt::SExt:
+      return getSignedFullWidenOpcode(OrigOpcode);
+    case SupportExt::ZExt:
+      return getUnsignedFullWidenOpcode(OrigOpcode);
+    case SupportExt::FPExt:
+      return getFloatFullWidenOpcode(OrigOpcode);
     }
   }
 
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
-  static unsigned getSUOpcode(unsigned Opcode) {
+  static unsigned getSignedUnsignedWidenOpcode(unsigned Opcode) {
     assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
            "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
-  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
-  /// newOpcode(a, b).
-  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) signed integer add/sub opcode.
+  static unsigned getSignedHalfWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+      return RISCVISD::VWADD_W_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
+      return RISCVISD::VWSUB_W_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) unsigned integer add/sub opcode.
+  static unsigned getUnsignedHalfWidenOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
-      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+      return RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
-      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+      return RISCVISD::VWSUBU_W_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode.
+  static unsigned getFloatHalfWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case RISCVISD::FADD_VL:
+      return RISCVISD::VFWADD_W_VL;
+    case RISCVISD::FSUB_VL:
+      return RISCVISD::VFWSUB_W_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
   }
 
+  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
+  /// newOpcode(a, b).
+  static unsigned getHalfWidenOpcode(unsigned Opcode, SupportExt Ext) {
+    switch (Ext) {
+    case SupportExt::SExt:
+      return getSignedHalfWidenOpcode(Opcode);
+    case SupportExt::ZExt:
+      return getUnsignedHalfWidenOpcode(Opcode);
+    case SupportExt::FPExt:
+      return getFloatHalfWidenOpcode(Opcode);
+    }
+  }
+
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
       const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
@@ -13029,15 +13149,18 @@ struct NodeExtensionHelper {
                               const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
+    SupportsFPExt = false;
     EnforceOneUse = true;
     CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
     switch (Opc) {
     case ISD::ZERO_EXTEND:
-    case ISD::SIGN_EXTEND: {
+    case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND: {
       if (OrigOperand.getValueType().isVector()) {
         SupportsZExt = Opc == ISD::ZERO_EXTEND;
         SupportsSExt = Opc == ISD::SIGN_EXTEND;
+        SupportsFPExt = Opc == ISD::FP_EXTEND;
         SDLoc DL(Root);
         MVT VT = Root->getSimpleValueType(0);
         std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13054,6 +13177,12 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case RISCVISD::FP_EXTEND_VL:
+      SupportsFPExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
+    case RISCVISD::VFMV_V_F_VL:
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
@@ -13080,6 +13209,11 @@ struct NodeExtensionHelper {
       if (ScalarBits < EltBits)
         break;
 
+      if (VT.isFloatingPoint()) {
+        SupportsFPExt = true;
+        break;
+      }
+
       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
       // If the narrow type cannot be expressed with a legal VMV,
       // this is not a valid candidate.
@@ -13103,7 +13237,10 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
@@ -13116,6 +13253,11 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return true;
     default:
       return false;
@@ -13138,10 +13280,15 @@ struct NodeExtensionHelper {
     case RISCVISD::VWADDU_W_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       if (OperandIdx == 1) {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt = !SupportsZExt;
+        SupportsSExt =
+            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+        SupportsFPExt =
+            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
@@ -13174,7 +13321,10 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13197,15 +13347,23 @@ struct NodeExtensionHelper {
     switch (N->getOpcode()) {
     case ISD::ADD:
     case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FMUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
+    case ISD::FSUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return false;
     default:
       llvm_unreachable("Unexpected opcode");
@@ -13227,22 +13385,23 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  // No value means no extension is needed. If extension is needed, the value
-  // indicates if it needs to be sign extended.
-  std::optional<bool> SExtLHS;
-  std::optional<bool> SExtRHS;
-  /// Root of the combine.
-  SDNode *Root;
   /// LHS of the TargetOpcode.
   NodeExtensionHelper LHS;
+  /// Extension of the LHS
+  std::optional<SupportExt> ExtLHS;
   /// RHS of the TargetOpcode.
   NodeExtensionHelper RHS;
+  /// Extension of the RHS
+  std::optional<SupportExt> ExtRHS;
+  /// Root of the combine.
+  SDNode *Root;
 
-  CombineResult(unsigned TargetOpcode, SDNode *Root,
-                const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
-                const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
-        Root(Root), LHS(LHS), RHS(RHS) {}
+  CombineResult(
+      unsigned TargetOpcode, SDNode *Root,
+      std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> LHS,
+      std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> RHS)
+      : TargetOpcode(TargetOpcode), LHS(LHS.first), ExtLHS(LHS.second),
+        RHS(RHS.first), ExtRHS(RHS.second), Root(Root) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -13259,12 +13418,15 @@ struct CombineResult {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtRHS),
                        Merge, Mask, VL);
   }
 };
@@ -13279,24 +13441,30 @@ struct CombineResult {
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
-                                 const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
-                                 const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+    SDNode *Root, const NodeExtensionHelper &LHS,
+    const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+    bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  assert((AllowSExt || AllowZExt || AllowFPExt) &&
+         "Forgot to set what you want?");
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::ZExt),
+                         Root, {LHS, SupportExt::ZExt},
+                         {RHS, SupportExt::ZExt});
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/true),
-                         Root, LHS, /*SExtLHS=*/true, RHS,
-                         /*SExtRHS=*/true);
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::SExt),
+                         Root, {LHS, SupportExt::SExt},
+                         {RHS, SupportExt::SExt});
+  if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::FPExt),
+                         Root, {LHS, SupportExt::FPExt},
+                         {RHS, SupportExt::FPExt});
   return std::nullopt;
 }
 
@@ -13311,7 +13479,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*Allow...
[truncated]

@sun-jacobi
Copy link
Member Author

sun-jacobi commented Jan 3, 2024

Turning to draft due to #72340 (comment)

@sun-jacobi sun-jacobi marked this pull request as draft January 3, 2024 05:38
@sun-jacobi sun-jacobi closed this Feb 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants