diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 975498f6d233c..aa4969003e788 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -4679,5 +4679,31 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { cast(TrueVal)->getParamAlign(0).valueOrOne(), CondVal, FalseVal)); + // Canonicalize sign function ashr pattern: select (icmp slt X, 1), ashr X, + // bitwidth-1, 1 -> scmp(X, 0) + // Also handles: select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0) + unsigned BitWidth = SI.getType()->getScalarSizeInBits(); + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; + + // Canonicalize sign function ashr patterns: + // select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0) + // select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0) + if (match(&SI, m_Select(m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)), + m_Value(TrueVal), m_Value(FalseVal))) && + ((Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_One()) && + match(TrueVal, + m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1))) && + match(FalseVal, m_One())) || + (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_Zero()) && + match(TrueVal, m_One()) && + match(FalseVal, + m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1)))))) { + + Function *Scmp = Intrinsic::getOrInsertDeclaration( + SI.getModule(), Intrinsic::scmp, {SI.getType(), SI.getType()}); + return CallInst::Create(Scmp, {CmpLHS, ConstantInt::get(SI.getType(), 0)}); + } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll index c0be5b986b7fd..2ae062cdc7033 100644 --- a/llvm/test/Transforms/InstCombine/scmp.ll +++ b/llvm/test/Transforms/InstCombine/scmp.ll @@ -519,9 +519,7 @@ define <3 x i2> @scmp_unary_shuffle_ops(<3 x i8> %x, <3 x i8> %y) { define i32 @scmp_sgt_slt(i32 %a) { ; CHECK-LABEL: define i32 @scmp_sgt_slt( ; CHECK-SAME: i32 [[A:%.*]]) { -; CHECK-NEXT: [[A_LOBIT:%.*]] = ashr i32 [[A]], 31 -; CHECK-NEXT: [[CMP_INV:%.*]] = icmp slt i32 [[A]], 1 -; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP_INV]], i32 [[A_LOBIT]], i32 1 +; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0) ; CHECK-NEXT: ret i32 [[RETVAL_0]] ; %cmp = icmp sgt i32 %a, 0 @@ -747,3 +745,55 @@ define i8 @scmp_from_select_eq_and_gt_neg3(i32 %x, i32 %y) { %r = select i1 %eq, i8 0, i8 %sel1 ret i8 %r } + +define i32 @scmp_ashr(i32 %a) { +; CHECK-LABEL: define i32 @scmp_ashr( +; CHECK-SAME: i32 [[A:%.*]]) { +; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0) +; CHECK-NEXT: ret i32 [[RETVAL_0]] +; + %a.lobit = ashr i32 %a, 31 + %cmp.inv = icmp slt i32 %a, 1 + %retval.0 = select i1 %cmp.inv, i32 %a.lobit, i32 1 + ret i32 %retval.0 +} + +; select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0) +define i8 @scmp_ashr_sgt_pattern(i8 %a) { +; CHECK-LABEL: define i8 @scmp_ashr_sgt_pattern( +; CHECK-SAME: i8 [[A:%.*]]) { +; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i8(i8 [[A]], i8 0) +; CHECK-NEXT: ret i8 [[R]] +; + %a.lobit = ashr i8 %a, 7 + %cmp = icmp sgt i8 %a, 0 + %retval = select i1 %cmp, i8 1, i8 %a.lobit + ret i8 %retval +} + +; select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0) +define i8 @scmp_ashr_slt_pattern(i8 %a) { +; CHECK-LABEL: define i8 @scmp_ashr_slt_pattern( +; CHECK-SAME: i8 [[A:%.*]]) { +; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i8(i8 [[A]], i8 0) +; CHECK-NEXT: ret i8 [[R]] +; + %a.lobit = ashr i8 %a, 7 + %cmp = icmp slt i8 %a, 1 + %retval = select i1 %cmp, i8 %a.lobit, i8 1 + ret i8 %retval +} + +define i8 @scmp_ashr_slt_pattern_neg(i8 %a) { +; CHECK-LABEL: define i8 @scmp_ashr_slt_pattern_neg( +; CHECK-SAME: i8 [[A:%.*]]) { +; CHECK-NEXT: [[A_LOBIT:%.*]] = ashr i8 [[A]], 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[A]], 1 +; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[A_LOBIT]], i8 1 +; CHECK-NEXT: ret i8 [[RETVAL]] +; + %a.lobit = ashr i8 %a, 4 + %cmp = icmp slt i8 %a, 1 + %retval = select i1 %cmp, i8 %a.lobit, i8 1 + ret i8 %retval +}