diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index eb0440f500735..0622bfae2c845 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -810,6 +810,26 @@ class MinMaxIntrinsic : public IntrinsicInst {
   /// Whether the intrinsic is signed or unsigned.
   bool isSigned() const { return isSigned(getIntrinsicID()); };
 
+  /// Whether the intrinsic is a smin or umin.
+  static bool isMin(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::umin:
+    case Intrinsic::smin:
+      return true;
+    case Intrinsic::umax:
+    case Intrinsic::smax:
+      return false;
+    default:
+      llvm_unreachable("Invalid intrinsic");
+    }
+  }
+
+  /// Whether the intrinsic is a smin or a umin.
+  bool isMin() const { return isMin(getIntrinsicID()); }
+
+  /// Whether the intrinsic is a smax or a umax.
+  bool isMax() const { return !isMin(getIntrinsicID()); }
+
   /// Min/max intrinsics are monotonic, they operate on a fixed-bitwidth values,
   /// so there is a certain threshold value, upon reaching which,
   /// their value can no longer change. Return said threshold.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e4cb457499ef5..74f035e401bde 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5780,6 +5780,45 @@ Instruction *InstCombinerImpl::foldICmpWithMinMax(Instruction &I,
   return nullptr;
 }
 
+/// Match and fold patterns like:
+///   icmp eq/ne X, min(max(X, Lo), Hi)
+/// which represents a range check and can be represented as a ConstantRange.
+///
+/// For icmp eq, build ConstantRange [Lo, Hi + 1) and convert to:
+///   (X - Lo) u< (Hi + 1 - Lo)
+/// For icmp ne, build ConstantRange [Hi + 1, Lo) and convert to:
+///   (X - (Hi + 1)) u< (Lo - (Hi + 1))
+Instruction *InstCombinerImpl::foldICmpWithClamp(ICmpInst &I, Value *X,
+                                                 MinMaxIntrinsic *Min) {
+  if (!I.isEquality() || !Min->hasOneUse() || !Min->isMin())
+    return nullptr;
+
+  const APInt *Lo = nullptr, *Hi = nullptr;
+  if (Min->isSigned()) {
+    if (!match(Min->getLHS(), m_OneUse(m_SMax(m_Specific(X), m_APInt(Lo)))) ||
+        !match(Min->getRHS(), m_APInt(Hi)) || !Lo->slt(*Hi))
+      return nullptr;
+  } else {
+    if (!match(Min->getLHS(), m_OneUse(m_UMax(m_Specific(X), m_APInt(Lo)))) ||
+        !match(Min->getRHS(), m_APInt(Hi)) || !Lo->ult(*Hi))
+      return nullptr;
+  }
+
+  ConstantRange CR = ConstantRange::getNonEmpty(*Lo, *Hi + 1);
+  ICmpInst::Predicate Pred;
+  APInt C, Offset;
+  if (I.getPredicate() == ICmpInst::ICMP_EQ)
+    CR.getEquivalentICmp(Pred, C, Offset);
+  else
+    CR.inverse().getEquivalentICmp(Pred, C, Offset);
+
+  if (!Offset.isZero())
+    X = Builder.CreateAdd(X, ConstantInt::get(X->getType(), Offset));
+
+  return replaceInstUsesWith(
+      I, Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), C)));
+}
+
 // Canonicalize checking for a power-of-2-or-zero value:
 static Instruction *foldICmpPow2Test(ICmpInst &I,
                                      InstCombiner::BuilderTy &Builder) {
@@ -7467,10 +7506,14 @@ Instruction *InstCombinerImpl::foldICmpCommutative(CmpPredicate Pred,
     if (Instruction *NI = foldSelectICmp(Pred, SI, Op1, CxtI))
       return NI;
 
-  if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op0))
+  if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op0)) {
     if (Instruction *Res = foldICmpWithMinMax(CxtI, MinMax, Op1, Pred))
       return Res;
 
+    if (Instruction *Res = foldICmpWithClamp(CxtI, Op1, MinMax))
+      return Res;
+  }
+
   {
     Value *X;
     const APInt *C;
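As a sanity check on the doc comment's eq-case arithmetic, the following standalone sketch (my illustration, not part of the patch; it assumes an LLVM development tree to compile against) reproduces the ConstantRange conversion for the constants used in test_i32_eq from the new test file: X in [-95, 161) becomes the biased unsigned compare (X + 95) u< 256, via the same getEquivalentICmp API the patch calls.

// Standalone sketch, not part of the patch: reproduces the eq-case
// ConstantRange math for Lo = -95, Hi = 160 (the constants in test_i32_eq).
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include <cassert>

int main() {
  using namespace llvm;
  APInt Lo(32, -95, /*isSigned=*/true);
  APInt Hi(32, 160, /*isSigned=*/true);
  // icmp eq X, smin(smax(X, -95), 160) holds iff X is in [-95, 161).
  ConstantRange CR = ConstantRange::getNonEmpty(Lo, Hi + 1);
  CmpInst::Predicate Pred;
  APInt RHS, Offset;
  CR.getEquivalentICmp(Pred, RHS, Offset);
  // The signed (wrapped) range becomes one biased unsigned compare:
  // (X + 95) u< 256, matching the CHECK lines in icmp-clamp.ll below.
  assert(Pred == CmpInst::ICMP_ULT);
  assert(Offset == 95 && RHS == 256);
  return 0;
}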
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 4f94aa2d38541..e01c145bf5de3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -725,6 +725,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpWithMinMax(Instruction &I, MinMaxIntrinsic *MinMax,
                                   Value *Z, CmpPredicate Pred);
+  Instruction *foldICmpWithClamp(ICmpInst &Cmp, Value *X, MinMaxIntrinsic *Min);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
   Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
   Instruction *foldSignBitTest(ICmpInst &I);
diff --git a/llvm/test/Transforms/InstCombine/icmp-clamp.ll b/llvm/test/Transforms/InstCombine/icmp-clamp.ll
new file mode 100644
index 0000000000000..4866dbffb567a
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-clamp.ll
@@ -0,0 +1,295 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use(i32)
+
+define i1 @test_i32_eq(i32 %x) {
+; CHECK-LABEL: define i1 @test_i32_eq(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], 95
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], 256
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_i32_ne(i32 %x) {
+; CHECK-LABEL: define i1 @test_i32_ne(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], -161
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], -256
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp ne i32 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_i32_eq_no_add(i32 %x) {
+; CHECK-LABEL: define i1 @test_i32_eq_no_add(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[X]], 161
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 0)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_i32_ne_no_add(i32 %x) {
+; CHECK-LABEL: define i1 @test_i32_ne_no_add(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[X]], 160
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 0)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp ne i32 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_unsigned_eq(i32 %x) {
+; CHECK-LABEL: define i1 @test_unsigned_eq(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], -10
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], 91
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.umax.i32(i32 %x, i32 10)
+  %v2 = tail call i32 @llvm.umin.i32(i32 %v1, i32 100)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_unsigned_ne(i32 %x) {
+; CHECK-LABEL: define i1 @test_unsigned_ne(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], -101
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], -91
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.umax.i32(i32 %x, i32 10)
+  %v2 = tail call i32 @llvm.umin.i32(i32 %v1, i32 100)
+  %cmp = icmp ne i32 %v2, %x
+  ret i1 %cmp
+}
+
+
+; Different bit widths
+define i1 @test_i8_eq(i8 %x) {
+; CHECK-LABEL: define i1 @test_i8_eq(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[X]], 50
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[TMP1]], 101
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i8 @llvm.smax.i8(i8 %x, i8 -50)
+  %v2 = tail call i8 @llvm.smin.i8(i8 %v1, i8 50)
+  %cmp = icmp eq i8 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_i16_eq(i16 %x) {
+; CHECK-LABEL: define i1 @test_i16_eq(
+; CHECK-SAME: i16 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i16 [[X]], 1000
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[TMP1]], 2001
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i16 @llvm.smax.i16(i16 %x, i16 -1000)
+  %v2 = tail call i16 @llvm.smin.i16(i16 %v1, i16 1000)
+  %cmp = icmp eq i16 %v2, %x
+  ret i1 %cmp
+}
+
+define i1 @test_i64_eq(i64 %x) {
+; CHECK-LABEL: define i1 @test_i64_eq(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i64 [[TMP1]], -1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i64 @llvm.smax.i64(i64 %x, i64 -1)
+  %v2 = tail call i64 @llvm.smin.i64(i64 %v1, i64 9223372036854775806)
+  %cmp = icmp eq i64 %v2, %x
+  ret i1 %cmp
+}
+
+; Negative tests - wrong predicate
+define i1 @test_wrong_pred_slt(i32 %x) {
+; CHECK-LABEL: define i1 @test_wrong_pred_slt(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[X]], 160
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp slt i32 %v2, %x
+  ret i1 %cmp
+}
+
+
+; Negative tests - not a clamp pattern
+define i1 @test_not_clamp_pattern(i32 %x, i32 %y) {
+; CHECK-LABEL: define i1 @test_not_clamp_pattern(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = tail call i32 @llvm.smax.i32(i32 [[Y]], i32 -95)
+; CHECK-NEXT:    [[V2:%.*]] = tail call i32 @llvm.smin.i32(i32 [[V1]], i32 160)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[V2]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %y, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Negative tests - Lo >= Hi
+define i1 @test_invalid_range(i32 %x) {
+; CHECK-LABEL: define i1 @test_invalid_range(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X]], 50
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 100)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 50)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Negative tests - Lo is minimum signed value
+define i1 @test_lo_min_signed(i32 %x) {
+; CHECK-LABEL: define i1 @test_lo_min_signed(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[X]], 161
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -2147483648)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Negative tests - Hi is maximum signed value
+define i1 @test_hi_max_signed(i32 %x) {
+; CHECK-LABEL: define i1 @test_hi_max_signed(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[X]], -96
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 2147483647)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Negative tests - Hi is maximum unsigned value
+define i1 @test_hi_max_unsigned(i32 %x) {
+; CHECK-LABEL: define i1 @test_hi_max_unsigned(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[X]], 9
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.umax.i32(i32 %x, i32 10)
+  %v2 = tail call i32 @llvm.umin.i32(i32 %v1, i32 4294967295)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Multi-use tests - multiple uses of max
+define i1 @test_multi_use_max(i32 %x) {
+; CHECK-LABEL: define i1 @test_multi_use_max(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = tail call i32 @llvm.smax.i32(i32 [[X]], i32 -95)
+; CHECK-NEXT:    call void @use(i32 [[V1]])
+; CHECK-NEXT:    [[V2:%.*]] = tail call i32 @llvm.smin.i32(i32 [[V1]], i32 160)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[V2]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  call void @use(i32 %v1)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Multi-use tests - multiple uses of min
+define i1 @test_multi_use_min(i32 %x) {
+; CHECK-LABEL: define i1 @test_multi_use_min(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = tail call i32 @llvm.smax.i32(i32 [[X]], i32 -95)
+; CHECK-NEXT:    [[V2:%.*]] = tail call i32 @llvm.smin.i32(i32 [[V1]], i32 160)
+; CHECK-NEXT:    call void @use(i32 [[V2]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[V2]], [[X]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  call void @use(i32 %v2)
+  %cmp = icmp eq i32 %v2, %x
+  ret i1 %cmp
+}
+
+; Commuted tests
+define i1 @test_commuted_eq(i32 %x) {
+; CHECK-LABEL: define i1 @test_commuted_eq(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], 95
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[TMP1]], 256
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %v1 = tail call i32 @llvm.smax.i32(i32 %x, i32 -95)
+  %v2 = tail call i32 @llvm.smin.i32(i32 %v1, i32 160)
+  %cmp = icmp eq i32 %x, %v2
+  ret i1 %cmp
+}
+
+
+; Vector tests - splat constants
+define <2 x i1> @test_vec_splat_eq(<2 x i32> %x) {
+; CHECK-LABEL: define <2 x i1> @test_vec_splat_eq(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[X]], splat (i32 50)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[TMP1]], splat (i32 101)
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %v1 = tail call <2 x i32> @llvm.smax.v2i32(<2 x i32> %x, <2 x i32> <i32 -50, i32 -50>)
+  %v2 = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %v1, <2 x i32> <i32 50, i32 50>)
+  %cmp = icmp eq <2 x i32> %v2, %x
+  ret <2 x i1> %cmp
+}
+
+; Vector tests - poison elements
+define <2 x i1> @test_vec_poison_eq(<2 x i32> %x) {
+; CHECK-LABEL: define <2 x i1> @test_vec_poison_eq(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = tail call <2 x i32> @llvm.smax.v2i32(<2 x i32> [[X]], <2 x i32> <i32 -50, i32 poison>)
+; CHECK-NEXT:    [[V2:%.*]] = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[V1]], <2 x i32> <i32 50, i32 poison>)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[V2]], [[X]]
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %v1 = tail call <2 x i32> @llvm.smax.v2i32(<2 x i32> %x, <2 x i32> <i32 -50, i32 poison>)
+  %v2 = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %v1, <2 x i32> <i32 50, i32 poison>)
+  %cmp = icmp eq <2 x i32> %v2, %x
+  ret <2 x i1> %cmp
+}
+
+; Vector tests - non-splat
+define <2 x i1> @test_vec_non_splat_eq(<2 x i32> %x) {
+; CHECK-LABEL: define <2 x i1> @test_vec_non_splat_eq(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = tail call <2 x i32> @llvm.smax.v2i32(<2 x i32> [[X]], <2 x i32> <i32 -50, i32 -60>)
+; CHECK-NEXT:    [[V2:%.*]] = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[V1]], <2 x i32> <i32 50, i32 60>)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[V2]], [[X]]
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %v1 = tail call <2 x i32> @llvm.smax.v2i32(<2 x i32> %x, <2 x i32> <i32 -50, i32 -60>)
+  %v2 = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %v1, <2 x i32> <i32 50, i32 60>)
+  %cmp = icmp eq <2 x i32> %v2, %x
+  ret <2 x i1> %cmp
+}
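For context (my illustration, not taken from the patch): the clamp-equality pattern typically originates from source like the sketch below. InstCombine first canonicalizes the nested selects to llvm.smax/llvm.smin intrinsics, and the new foldICmpWithClamp then reduces the whole range check to one add plus one unsigned compare, as the CHECK lines above show. The function names and constants here are hypothetical, chosen to match test_i32_eq.

// Illustration only: a range check phrased as "x equals its clamped self".
// The selects are canonicalized to llvm.smax/llvm.smin, after which the new
// fold emits (x + 95) u< 256 for the equality below.
static int clamp(int X, int Lo, int Hi) {
  return X < Lo ? Lo : (X > Hi ? Hi : X);
}

bool inRange(int X) { return X == clamp(X, -95, 160); }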