diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 9572f9d702e1b..6794e9c18e79c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1955,6 +1955,82 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
   return nullptr;
 }
 
+// Transform
+//
+// select(icmp(eq, X, Y), Z, select(icmp(ult, X, Y), -1, 1))
+// ->
+// select(icmp(eq, X, Y), Z, llvm.ucmp(freeze(X), freeze(Y)))
+//
+// or
+//
+// select(icmp(eq, X, Y), Z, select(icmp(slt, X, Y), -1, 1))
+// ->
+// select(icmp(eq, X, Y), Z, llvm.scmp(freeze(X), freeze(Y)))
+static Value *foldSelectToIntrinsicCmp(SelectInst &SI, const ICmpInst *ICI,
+                                       Value *TrueVal, Value *FalseVal,
+                                       InstCombiner::BuilderTy &Builder) {
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+
+  if (Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  CmpPredicate IPred;
+  Value *X = ICI->getOperand(0);
+  Value *Y = ICI->getOperand(1);
+  if (match(FalseVal, m_Select(m_ICmp(IPred, m_Specific(X), m_Specific(Y)),
+                               m_AllOnes(), m_One())) &&
+      (IPred == ICmpInst::ICMP_ULT || IPred == ICmpInst::ICMP_SLT)) {
+
+    // icmp(ult, ptr %X, ptr %Y) -> cannot be folded because
+    // there is no intrinsic for a pointer comparison.
+    if (!X->getType()->isIntegerTy() || !Y->getType()->isIntegerTy())
+      return nullptr;
+
+    Builder.SetInsertPoint(&SI);
+    auto IID = IPred == ICmpInst::ICMP_ULT ? Intrinsic::ucmp : Intrinsic::scmp;
+
+    // Edge case: if Z is the constant 0 then the select can be folded
+    // to just the intrinsic comparison.
+    if (match(TrueVal, m_Zero()))
+      return Builder.CreateIntrinsic(SI.getType(), IID, {X, Y});
+
+    // X and Y now feed both the retained icmp and the new intrinsic; freeze
+    // them so a poison/undef input cannot observably take two values.
+    Value *FrozenX = Builder.CreateFreeze(X, X->getName() + ".frz");
+    Value *FrozenY = Builder.CreateFreeze(Y, Y->getName() + ".frz");
+    Value *Cmp =
+        Builder.CreateIntrinsic(FalseVal->getType(), IID, {FrozenX, FrozenY});
+    return Builder.CreateSelect(SI.getCondition(), TrueVal, Cmp, "select.cmp");
+  }
+
+  return nullptr;
+}
+
+// Transform
+// select(icmp(eq, X, Y), 0, llvm.cmp(X, Y))
+// ->
+// llvm.cmp(X, Y)
+static Value *foldIntrinsicCmp(SelectInst &SI, const ICmpInst *ICI,
+                               Value *TrueVal, Value *FalseVal,
+                               InstCombiner::BuilderTy &Builder) {
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+
+  if (Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  Value *X = ICI->getOperand(0);
+  Value *Y = ICI->getOperand(1);
+
+  auto UCmp = m_Intrinsic<Intrinsic::ucmp>(m_Specific(X), m_Specific(Y));
+  auto SCmp = m_Intrinsic<Intrinsic::scmp>(m_Specific(X), m_Specific(Y));
+  // When X == Y the intrinsic already yields 0, so the select is redundant.
+  if (match(TrueVal, m_Zero()) &&
+      (match(FalseVal, UCmp) || match(FalseVal, SCmp)))
+    return FalseVal;
+
+  return nullptr;
+}
+
 /// Fold `X Pred C1 ? X BOp C2 : C1 BOp C2` to `min/max(X, C1) BOp C2`.
 /// This allows for better canonicalization.
 Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp,
@@ -2186,6 +2262,12 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = foldSelectWithConstOpToBinOp(ICI, TrueVal, FalseVal))
     return replaceInstUsesWith(SI, V);
 
+  if (Value *V = foldSelectToIntrinsicCmp(SI, ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
+  if (Value *V = foldIntrinsicCmp(SI, ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   return Changed ? 
&SI : nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/select-cmp.ll b/llvm/test/Transforms/InstCombine/select-cmp.ll
index b1bd7a0ecc8ac..37dc84fe33111 100644
--- a/llvm/test/Transforms/InstCombine/select-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-cmp.ll
@@ -808,5 +808,176 @@ define i1 @icmp_lt_slt(i1 %c, i32 %arg) {
   ret i1 %select
 }
 
+define i16 @icmp_fold_to_llvm_ucmp_when_eq(i16 %x, i16 %y) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_eq(
+; CHECK-NEXT:    [[Y_FRZ:%.*]] = freeze i16 [[Y:%.*]]
+; CHECK-NEXT:    [[X_FRZ:%.*]] = freeze i16 [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i16 [[X_FRZ]], [[Y_FRZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.ucmp.i16.i16(i16 [[X_FRZ]], i16 [[Y_FRZ]])
+; CHECK-NEXT:    [[SELECT_UCMP:%.*]] = select i1 [[TMP1]], i16 42, i16 [[TMP2]]
+; CHECK-NEXT:    ret i16 [[SELECT_UCMP]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 42, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_when_ult_and_Z_zero(i16 %x, i16 %y) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_ult_and_Z_zero(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.ucmp.i16.i16(i16 [[X:%.*]], i16 [[Y:%.*]])
+; CHECK-NEXT:    ret i16 [[TMP1]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 0, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_when_slt_and_Z_zero(i16 %x, i16 %y) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_slt_and_Z_zero(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.scmp.i16.i16(i16 [[X:%.*]], i16 [[Y:%.*]])
+; CHECK-NEXT:    ret i16 [[TMP1]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp slt i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 0, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_when_cmp_slt(i16 %x, i16 %y) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_cmp_slt(
+; CHECK-NEXT:    [[Y_FRZ:%.*]] = freeze i16 [[Y:%.*]]
+; CHECK-NEXT:    [[X_FRZ:%.*]] = freeze i16 [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i16 [[X_FRZ]], [[Y_FRZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.scmp.i16.i16(i16 [[X_FRZ]], i16 [[Y_FRZ]])
+; CHECK-NEXT:    [[SELECT_UCMP:%.*]] = select i1 [[TMP1]], i16 42, i16 [[TMP2]]
+; CHECK-NEXT:    ret i16 [[SELECT_UCMP]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp slt i16 %x, %y ; "slt" instead of "ult": the fold must emit llvm.scmp
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 42, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_when_value(i16 %x, i16 %y, i16 %Z) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_value(
+; CHECK-NEXT:    [[Y_FRZ:%.*]] = freeze i16 [[Y:%.*]]
+; CHECK-NEXT:    [[X_FRZ:%.*]] = freeze i16 [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i16 [[X_FRZ]], [[Y_FRZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.ucmp.i16.i16(i16 [[X_FRZ]], i16 [[Y_FRZ]])
+; CHECK-NEXT:    [[SELECT_UCMP:%.*]] = select i1 [[TMP1]], i16 [[Z:%.*]], i16 [[TMP2]]
+; CHECK-NEXT:    ret i16 [[SELECT_UCMP]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 %Z, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_when_ne(i16 %x, i16 %y) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_when_ne(
+; CHECK-NEXT:    [[Y_FRZ:%.*]] = freeze i16 [[Y:%.*]]
+; CHECK-NEXT:    [[X_FRZ:%.*]] = freeze i16 [[X:%.*]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i16 [[X_FRZ]], [[Y_FRZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.ucmp.i16.i16(i16 [[X_FRZ]], i16 [[Y_FRZ]])
+; CHECK-NEXT:    [[SELECT_UCMP:%.*]] = select i1 [[DOTNOT]], i16 42, i16 [[TMP1]]
+; CHECK-NEXT:    ret i16 [[SELECT_UCMP]]
+;
+  %3 = icmp ne i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 1
+  %6 = select i1 %3, i16 %5, i16 42
+  ret i16 %6
+}
+
+define i32 @icmp_fold_to_llvm_ucmp_mixed_types(i16 %0, i16 %1) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_mixed_types(
+; CHECK-NEXT:    [[DOTFRZ1:%.*]] = freeze i16 [[TMP1:%.*]]
+; CHECK-NEXT:    [[DOTFRZ:%.*]] = freeze i16 [[TMP0:%.*]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i16 [[DOTFRZ]], [[DOTFRZ1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.ucmp.i32.i16(i16 [[DOTFRZ]], i16 [[DOTFRZ1]])
+; CHECK-NEXT:    [[SELECT_UCMP:%.*]] = select i1 [[DOTNOT]], i32 1, i32 [[TMP3]]
+; CHECK-NEXT:    ret i32 [[SELECT_UCMP]]
+;
+  %.not = icmp eq i16 %0, %1
+  %3 = icmp ult i16 %0, %1
+  %4 = select i1 %3, i32 -1, i32 1
+  %.1 = select i1 %.not, i32 1, i32 %4
+  ret i32 %.1
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_negative_test_invalid_constant_1(i16 %x, i16 %y, i16 %Z) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_negative_test_invalid_constant_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i16 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i16 [[Z:%.*]], i16 1
+; CHECK-NEXT:    ret i16 [[TMP2]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 1, i16 1 ; both arms are 1, not the required (-1, 1) pair
+  %6 = select i1 %3, i16 %Z, i16 %5
+  ret i16 %6
+}
+
+define i16 @icmp_fold_to_llvm_ucmp_negative_test_invalid_constant_2(i16 %x, i16 %y, i16 %Z) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_negative_test_invalid_constant_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i16 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i16 [[Z:%.*]], i16 -1
+; CHECK-NEXT:    ret i16 [[TMP2]]
+;
+  %3 = icmp eq i16 %x, %y
+  %4 = icmp ult i16 %x, %y
+  %5 = select i1 %4, i16 -1, i16 -1 ; both arms are -1, not the required (-1, 1) pair
+  %6 = select i1 %3, i16 %Z, i16 %5
+  ret i16 %6
+}
+
+define i8 @icmp_fold_to_llvm_ucmp_negative_test_ptr(ptr %0, ptr %1) {
+; CHECK-LABEL: @icmp_fold_to_llvm_ucmp_negative_test_ptr(
+; CHECK-NEXT:    [[TMP3:%.*]] = load ptr, ptr [[TMP0:%.*]], align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = load ptr, ptr [[TMP1:%.*]], align 8
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult ptr [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = select i1 [[TMP5]], i8 -1, i8 1
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq ptr [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select i1 [[TMP7]], i8 0, i8 [[TMP6]]
+; CHECK-NEXT:    ret i8 [[TMP8]]
+;
+  %3 = load ptr, ptr %0, align 8
+  %4 = load ptr, ptr %1, align 8
+  %5 = icmp ult ptr %3, %4
+  %6 = select i1 %5, i8 -1, i8 1
+  %7 = icmp eq ptr %3, %4
+  %8 = select i1 %7, i8 0, i8 %6
+  ret i8 %8
+}
+
+define i32 @fold_ucmp(i32 %0, i32 %1) {
+; CHECK-LABEL: @fold_ucmp(
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.ucmp.i32.i32(i32 [[TMP0:%.*]], i32 [[TMP1:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %3 = icmp eq i32 %0, %1
+  %4 = tail call i32 @llvm.ucmp.i32.i32(i32 %0, i32 %1)
+  %5 = select i1 %3, i32 0, i32 %4
+  ret i32 %5
+}
+
+define i32 @fold_scmp(i32 %0, i32 %1) {
+; CHECK-LABEL: @fold_scmp(
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.scmp.i32.i32(i32 [[TMP0:%.*]], i32 [[TMP1:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %3 = icmp eq i32 %0, %1
+  %4 = tail call i32 @llvm.scmp.i32.i32(i32 %0, i32 %1)
+  %5 = select i1 %3, i32 0, i32 %4
+  ret i32 %5
+}
+
 declare void @use(i1)
 declare void @use.i8(i8)