diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index b426f1a7b3791..c9727a3e5a8db 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13552,7 +13552,7 @@ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 }; /// NodeExtensionHelper for `a` and one for `b`. /// /// This class abstracts away how the extension is materialized and -/// how its Mask, VL, number of users affect the combines. +/// how its number of users affects the combines. /// /// In particular: /// - VWADD_W is conceptually == add(op0, sext(op1)) @@ -13576,15 +13576,6 @@ struct NodeExtensionHelper { /// This boolean captures whether we care if this operand would still be /// around after the folding happens. bool EnforceOneUse; - /// Records if this operand's mask needs to match the mask of the operation - /// that it will fold into. - bool CheckMask; - /// Value of the Mask for this operand. - /// It may be SDValue(). - SDValue Mask; - /// Value of the vector length operand. - /// It may be SDValue(). - SDValue VL; /// Original value that this NodeExtensionHelper represents. SDValue OrigOperand; @@ -13789,8 +13780,10 @@ struct NodeExtensionHelper { SupportsZExt = false; SupportsSExt = false; SupportsFPExt = false; EnforceOneUse = true; - CheckMask = true; unsigned Opc = OrigOperand.getOpcode(); + // For the nodes we handle below, we end up using their inputs directly: see + // getSource(). However, since they either don't have a passthru or we check + // that their passthru is undef, we can safely ignore their mask and VL. 
switch (Opc) { case ISD::ZERO_EXTEND: case ISD::SIGN_EXTEND: { @@ -13806,32 +13799,21 @@ struct NodeExtensionHelper { SupportsZExt = Opc == ISD::ZERO_EXTEND; SupportsSExt = Opc == ISD::SIGN_EXTEND; - - SDLoc DL(Root); - std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget); break; } case RISCVISD::VZEXT_VL: SupportsZExt = true; - Mask = OrigOperand.getOperand(1); - VL = OrigOperand.getOperand(2); break; case RISCVISD::VSEXT_VL: SupportsSExt = true; - Mask = OrigOperand.getOperand(1); - VL = OrigOperand.getOperand(2); break; case RISCVISD::FP_EXTEND_VL: SupportsFPExt = true; - Mask = OrigOperand.getOperand(1); - VL = OrigOperand.getOperand(2); break; case RISCVISD::VMV_V_X_VL: { // Historically, we didn't care about splat values not disappearing during // combines. EnforceOneUse = false; - CheckMask = false; - VL = OrigOperand.getOperand(2); // The operand is a splat of a scalar. @@ -13930,8 +13912,6 @@ struct NodeExtensionHelper { Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL; SupportsFPExt = Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL; - std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget); - CheckMask = true; // There's no existing extension here, so we don't have to worry about // making sure it gets removed. EnforceOneUse = false; @@ -13944,16 +13924,6 @@ struct NodeExtensionHelper { } } - /// Check if this operand is compatible with the given vector length \p VL. - bool isVLCompatible(SDValue VL) const { - return this->VL != SDValue() && this->VL == VL; - } - - /// Check if this operand is compatible with the given \p Mask. - bool isMaskCompatible(SDValue Mask) const { - return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask); - } - /// Helper function to get the Mask and VL from \p Root. static std::pair getMaskAndVL(const SDNode *Root, SelectionDAG &DAG, @@ -13973,13 +13943,6 @@ struct NodeExtensionHelper { } } - /// Check if the Mask and VL of this operand are compatible with \p Root. 
- bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) const { - auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget); - return isMaskCompatible(Mask) && isVLCompatible(VL); - } - /// Helper function to check if \p N is commutative with respect to the /// foldings that are supported by this class. static bool isCommutative(const SDNode *N) { @@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, uint8_t AllowExtMask, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || - !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) - return std::nullopt; if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt) return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()), Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS, @@ -14120,9 +14080,6 @@ static std::optional canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) - return std::nullopt; - if (RHS.SupportsFPExt) return CombineResult( NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt), @@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS, if (!LHS.SupportsSExt || !RHS.SupportsZExt) return std::nullopt; - if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || - !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) - return std::nullopt; return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()), Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS, /*RHSExt=*/{ExtKind::ZExt}); diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll index a0b7726d3cb5e..433f5d2717e48 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll @@ -41,3 +41,61 @@ 
declare @llvm.vp.sext.nxv2i32.nxv2i8(, @llvm.vp.zext.nxv2i32.nxv2i8(, , i32) declare @llvm.vp.add.nxv2i32(, , , i32) declare @llvm.vp.merge.nxv2i32(, , , i32) + +define @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16( %x, %y, %m, i32 signext %evl) { +; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16: +; CHECK: # %bb.0: +; CHECK-NEXT: slli a0, a0, 32 +; CHECK-NEXT: srli a0, a0, 32 +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t +; CHECK-NEXT: vmv1r.v v8, v10 +; CHECK-NEXT: ret + %x.sext = call @llvm.vp.sext.nxv2i32.nxv2i16( %x, %m, i32 %evl) + %y.sext = call @llvm.vp.sext.nxv2i32.nxv2i16( %y, %m, i32 %evl) + %add = call @llvm.vp.add.nxv2i32( %x.sext, %y.sext, %m, i32 %evl) + ret %add +} + +define @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16( %x, %y, %m, i32 signext %evl) { +; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16: +; CHECK: # %bb.0: +; CHECK-NEXT: slli a0, a0, 32 +; CHECK-NEXT: srli a0, a0, 32 +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t +; CHECK-NEXT: vmv1r.v v8, v10 +; CHECK-NEXT: ret + %x.sext = call @llvm.vp.sext.nxv2i32.nxv2i16( %x, %m, i32 %evl) + %y.sext = sext %y to + %add = call @llvm.vp.add.nxv2i32( %x.sext, %y.sext, %m, i32 %evl) + ret %add +} + +define @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16( %x, %y, %m, i32 signext %evl) { +; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16: +; CHECK: # %bb.0: +; CHECK-NEXT: slli a0, a0, 32 +; CHECK-NEXT: srli a0, a0, 32 +; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t +; CHECK-NEXT: vmv1r.v v8, v10 +; CHECK-NEXT: ret + %x.sext = sext %x to + %y.sext = sext %y to + %add = call @llvm.vp.add.nxv2i32( %x.sext, %y.sext, %m, i32 %evl) + ret %add +} + +define @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16( %x, %y, %m, i32 signext %evl) { +; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.vv v10, v8, v9 +; 
CHECK-NEXT: vmv1r.v v8, v10 +; CHECK-NEXT: ret + %x.sext = call @llvm.vp.sext.nxv2i32.nxv2i16( %x, %m, i32 %evl) + %y.sext = call @llvm.vp.sext.nxv2i32.nxv2i16( %y, %m, i32 %evl) + %add = add %x.sext, %y.sext + ret %add +}