diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ac1f2d5b37745d..d1bac8a089388c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -950,6 +950,47 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +/// Try to match patterns with select and subtract as absolute difference. +static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + auto *TI = dyn_cast(TVal); + auto *FI = dyn_cast(FVal); + if (!TI || !FI) + return nullptr; + + // Normalize predicate to gt/lt rather than ge/le. + ICmpInst::Predicate Pred = Cmp->getStrictPredicate(); + Value *A = Cmp->getOperand(0); + Value *B = Cmp->getOperand(1); + + // Normalize "A - B" as the true value of the select. + if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) { + std::swap(FI, TI); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + // With any pair of no-wrap subtracts: + // (A > B) ? (A - B) : (B - A) --> abs(A - B) + if (Pred == CmpInst::ICMP_SGT && + match(TI, m_Sub(m_Specific(A), m_Specific(B))) && + match(FI, m_Sub(m_Specific(B), m_Specific(A))) && + (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) && + (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) { + // The remaining subtract is not "nuw" any more. + // If there's one use of the subtract (no other use than the use we are + // about to replace), then we know that the sub is "nsw" in this context + // even if it was only "nuw" before. If there's another use, then we can't + // add "nsw" to the existing instruction because it may not be safe in the + // other user's context. + TI->setHasNoUnsignedWrap(false); + if (!TI->hasNoSignedWrap()) + TI->setHasNoSignedWrap(TI->hasOneUse()); + return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue()); + } + + return nullptr; +} + /// Fold the following code sequence: /// \code /// int a = ctlz(x & -x); @@ -1790,6 +1831,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } diff --git a/llvm/test/Transforms/InstCombine/abs-1.ll b/llvm/test/Transforms/InstCombine/abs-1.ll index b0a1044902bc47..7355c560c820b2 100644 --- a/llvm/test/Transforms/InstCombine/abs-1.ll +++ b/llvm/test/Transforms/InstCombine/abs-1.ll @@ -679,6 +679,8 @@ define i8 @nabs_extra_use_icmp_sub(i8 %x) { ret i8 %s } +; TODO: negate-of-abs-diff + define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) { ; CHECK-LABEL: @nabs_diff_signed_slt( ; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]] @@ -694,6 +696,8 @@ define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) { ret i32 %cond } +; TODO: negate-of-abs-diff + define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) { ; CHECK-LABEL: @nabs_diff_signed_sle( ; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]] @@ -711,11 +715,9 @@ define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) { define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]]) -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -728,12 +730,11 @@ define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) { define i8 @abs_diff_signed_sge(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sge( -; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]] +; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]]) ; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]]) -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP_NOT]], i8 [[SUB_BA]], i8 [[SUB_AB]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sge i8 %a, %b @@ -745,6 +746,8 @@ define i8 @abs_diff_signed_sge(i8 %a, i8 %b) { ret i8 %cond } +; negative test - need nsw + define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) { ; CHECK-LABEL: @abs_diff_signed_slt_no_nsw( ; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]] @@ -760,12 +763,12 @@ define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) { ret i32 %cond } +; bonus nuw - it's fine to match the pattern, but nuw can't propagate + define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt_nsw_nuw( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw nsw i8 [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw nsw i8 [[A]], [[B]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -775,12 +778,12 @@ define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) { ret i8 %cond } +; this is absolute diff, but nuw can't propagate and nsw can be set. + define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt_nuw( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -790,13 +793,14 @@ define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) { ret i8 %cond } +; same as above + define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use1( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]] +; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]]) -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -807,13 +811,13 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) { ret i8 %cond } +; nuw can't propagate, and the extra use prevents applying nsw + define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use2( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]]) -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -824,14 +828,15 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) { ret i8 %cond } +; same as above + define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) { ; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use3( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]] +; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]]) -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A]], [[B]] ; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]]) -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]] +; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i8 [[COND]] ; %cmp = icmp sgt i8 %a, %b @@ -843,6 +848,8 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) { ret i8 %cond } +; negative test - wrong predicate + define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) { ; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred1( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]] @@ -858,6 +865,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) { ret i32 %cond } +; negative test - wrong predicate + define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) { ; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred2( ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]] @@ -873,6 +882,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) { ret i32 %cond } +; negative test - need common operands + define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) { ; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_op( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]] @@ -890,10 +901,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) { define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) { ; CHECK-LABEL: @abs_diff_signed_slt_swap( -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i32 [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A]], [[B]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB_BA]], i32 [[SUB_AB]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB_AB]], i1 true) ; CHECK-NEXT: ret i32 [[COND]] ; %cmp = icmp slt i32 %a, %b @@ -905,10 +914,8 @@ define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) { define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) { ; CHECK-LABEL: @abs_diff_signed_sle_swap( -; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw <2 x i8> [[B]], [[A]] -; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A]], [[B]] -; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP_NOT]], <2 x i8> [[SUB_AB]], <2 x i8> [[SUB_BA]] +; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[COND:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB_AB]], i1 true) ; CHECK-NEXT: ret <2 x i8> [[COND]] ; %cmp = icmp sle <2 x i8> %a, %b @@ -918,6 +925,8 @@ define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) { ret <2 x i8> %cond } +; TODO: negate-of-abs-diff + define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) { ; CHECK-LABEL: @nabs_diff_signed_sgt_swap( ; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] @@ -935,6 +944,8 @@ define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) { ret i8 %cond } +; TODO: negate-of-abs-diff, but too many uses? + define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) { ; CHECK-LABEL: @nabs_diff_signed_sge_swap( ; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]] @@ -954,6 +965,8 @@ define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) { ret i8 %cond } +; negative test - need nsw + define i32 @abs_diff_signed_slt_no_nsw_swap(i32 %a, i32 %b) { ; CHECK-LABEL: @abs_diff_signed_slt_no_nsw_swap( ; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]