diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8d66c9f317e1e..a8e15b2b1760f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2156,54 +2156,49 @@ bool TargetLowering::SimplifyDemandedBits(
     }
     break;
   }
-  case ISD::SMIN: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the signbit, then we can simplify to OR node.
-    // TODO: Extend this based on ComputeNumSignBits.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
-    break;
-  }
-  case ISD::SMAX: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the signbit, then we can simplify to AND node.
-    // TODO: Extend this based on ComputeNumSignBits.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
-    break;
-  }
-  case ISD::UMIN: {
-    SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the msb, then we can simplify to AND node.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
-    // Check if one arg is always less than (or equal) to the other arg.
-    KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
-    KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    Known = KnownBits::umin(Known0, Known1);
-    if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
-      return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
-    if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
-      return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
-    break;
-  }
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::UMIN:
   case ISD::UMAX: {
+    unsigned Opc = Op.getOpcode();
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
-    // If we're only wanting the msb, then we can simplify to OR node.
-    if (DemandedBits.isSignMask())
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
-    // Check if one arg is always greater than (or equal) to the other arg.
+
+    // If we're only demanding signbits, then we can simplify to OR/AND node.
+    unsigned BitOp =
+        (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
+    unsigned NumSignBits =
+        std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
+                 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
+    unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
+    if (NumSignBits >= NumDemandedUpperBits)
+      return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
+
+    // Check if one arg is always less/greater than (or equal) to the other arg.
     KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
     KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
-    Known = KnownBits::umax(Known0, Known1);
-    if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
-      return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
-    if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
-      return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+    switch (Opc) {
+    case ISD::SMIN:
+      // TODO: Add KnownBits::sle/slt handling.
+      break;
+    case ISD::SMAX:
+      // TODO: Add KnownBits::sge/sgt handling.
+      break;
+    case ISD::UMIN:
+      if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
+        return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
+      if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
+        return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
+      Known = KnownBits::umin(Known0, Known1);
+      break;
+    case ISD::UMAX:
+      if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
+        return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
+      if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
+        return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
+      Known = KnownBits::umax(Known0, Known1);
+      break;
+    }
     break;
   }
   case ISD::BITREVERSE: {
diff --git a/llvm/test/CodeGen/X86/known-signbits-vector.ll b/llvm/test/CodeGen/X86/known-signbits-vector.ll
index de7186584e67a..e500801b69c4d 100644
--- a/llvm/test/CodeGen/X86/known-signbits-vector.ll
+++ b/llvm/test/CodeGen/X86/known-signbits-vector.ll
@@ -483,28 +483,24 @@ define <4 x float> @signbits_ashr_sext_select_shuffle_sitofp(<4 x i64> %a0, <4 x
 define <4 x i32> @signbits_mask_ashr_smax(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_smax:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_smax:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_smax:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -521,28 +517,24 @@ declare <4 x i32> @llvm.smax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_smin(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_smin:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_smin:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_smin:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -559,28 +551,24 @@ declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_umax(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_umax:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_umax:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_umax:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpmaxud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
@@ -597,28 +585,24 @@ declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
 define <4 x i32> @signbits_mask_ashr_umin(<4 x i32> %a0, <4 x i32> %a1) {
 ; X86-LABEL: signbits_mask_ashr_umin:
 ; X86:       # %bb.0:
-; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X86-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X86-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X86-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X86-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X86-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X86-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
 ; X86-NEXT:    retl
 ;
 ; X64-AVX1-LABEL: signbits_mask_ashr_umin:
 ; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
-; X64-AVX1-NEXT:    vpsrad $25, %xmm1, %xmm1
-; X64-AVX1-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX1-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; X64-AVX1-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX1-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX1-NEXT:    retq
 ;
 ; X64-AVX2-LABEL: signbits_mask_ashr_umin:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm0, %xmm0
-; X64-AVX2-NEXT:    vpsravd %xmm2, %xmm1, %xmm1
-; X64-AVX2-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpand %xmm1, %xmm0, %xmm0
+; X64-AVX2-NEXT:    vpsrad $25, %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    retq
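
The simplification above rests on one fact: whenever both operands share at least as many sign bits as there are demanded upper bits, SMIN/UMAX agree with OR and SMAX/UMIN agree with AND on those bits. Below is a minimal standalone sketch, not part of the patch, that exhaustively checks this equivalence for 8-bit values; numSignBits is a hypothetical stand-in for SelectionDAG::ComputeNumSignBits.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for ComputeNumSignBits on an 8-bit value: the number
// of leading bits that are copies of the sign bit, including the sign bit.
static unsigned numSignBits(uint8_t V) {
  unsigned Sign = (V >> 7) & 1;
  unsigned N = 1;
  for (int Bit = 6; Bit >= 0 && (((V >> Bit) & 1) == Sign); --Bit)
    ++N;
  return N;
}

int main() {
  for (int A = 0; A < 256; ++A) {
    for (int B = 0; B < 256; ++B) {
      uint8_t UA = (uint8_t)A, UB = (uint8_t)B;
      int8_t SA = (int8_t)UA, SB = (int8_t)UB;
      unsigned SignBits = std::min(numSignBits(UA), numSignBits(UB));
      // For every demanded-upper-bits count covered by the shared sign bits,
      // min/max must match the corresponding bitwise op on those bits.
      for (unsigned Upper = 1; Upper <= SignBits; ++Upper) {
        uint8_t Mask = (uint8_t)(0xFF << (8 - Upper));
        uint8_t SMin = (uint8_t)std::min(SA, SB);
        uint8_t SMax = (uint8_t)std::max(SA, SB);
        uint8_t UMin = std::min(UA, UB);
        uint8_t UMax = std::max(UA, UB);
        assert((SMin & Mask) == ((UA | UB) & Mask)); // SMIN -> OR
        assert((SMax & Mask) == ((UA & UB) & Mask)); // SMAX -> AND
        assert((UMin & Mask) == ((UA & UB) & Mask)); // UMIN -> AND
        assert((UMax & Mask) == ((UA | UB) & Mask)); // UMAX -> OR
      }
    }
  }
  std::printf("min/max vs. OR/AND equivalence holds on all demanded sign bits\n");
  return 0;
}

Any C++11 compiler runs the 256x256 space instantly; the tests updated above are the vector form of the same pattern, where the ashr by 25 guarantees 26 sign bits and only the low bits of each lane are ultimately demanded.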