[InstCombine] Optimize icmp(sub(a, c), sub(b, c)) to icmp(a, b) if a, b, and c are pointers #161698
… b, and c are pointers

This has to be specially handled due to C language semantics for pointers.

```
If an array is so large (greater than PTRDIFF_MAX elements, but less than SIZE_MAX bytes), that the difference between two pointers may not be representable as std::ptrdiff_t, the result of subtracting two such pointers is undefined.
```

Due to this definition, pointer subtraction can be treated as NSW in this scenario. However, the comparison must be made unsigned because pointer comparisons are unsigned, even though the result of pointer subtraction is signed.

This change allows for the following optimization:

```c
int8_t* foo(int8_t* a, int8_t* b, int8_t* c) {
  a = a + 1;
  if ((a - b) > (c - b))
    a = c;
  return a;
}
```

Originally lowers to:

```
  %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
  %5 = ptrtoint ptr %4 to i64
  %6 = ptrtoint ptr %1 to i64
  %7 = sub i64 %5, %6
  %8 = ptrtoint ptr %2 to i64
  %9 = sub i64 %8, %6
  %10 = icmp sgt i64 %7, %9
  %11 = select i1 %10, ptr %2, ptr %4
  ret ptr %11
```

Now lowers to:

```
  %add.ptr = getelementptr inbounds nuw i8, ptr %a, i64 1
  %cmp = icmp ugt ptr %add.ptr, %c
  %spec.select = select i1 %cmp, ptr %c, ptr %add.ptr
  ret ptr %spec.select
```
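To make the signed-versus-unsigned point concrete, here is a rough sketch that checks the fold exhaustively in a hypothetical 8-bit address space. It is not part of the patch. The names `FoldCheck`, `asSigned`, and `checkPtrDiffFold` are invented for illustration, and the model assumes that all three pointers lie in one object of at most `INT8_MAX` bytes that does not wrap around the address space.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical 8-bit address space, used only to make an exhaustive check cheap.
struct FoldCheck {
  long checked = 0;          // (a, b, c) triples that fit in one valid object
  long unsignedMismatch = 0; // icmp ugt a, c disagrees with the original
  long signedMismatch = 0;   // icmp sgt a, c disagrees with the original
};

// Reinterpret the low 8 bits as a signed value, like reading an i8 as signed.
// (This modular conversion is guaranteed since C++20 and is what mainstream
// compilers did before that.)
std::int8_t asSigned(int bits) {
  return static_cast<std::int8_t>(static_cast<std::uint8_t>(bits));
}

FoldCheck checkPtrDiffFold() {
  FoldCheck result;
  for (int a = 0; a < 256; ++a)
    for (int b = 0; b < 256; ++b)
      for (int c = 0; c < 256; ++c) {
        int lo = std::min(a, std::min(b, c));
        int hi = std::max(a, std::max(b, c));
        // Assumption: all three pointers lie in one object that is at most
        // INT8_MAX bytes long and does not wrap around the address space.
        if (hi - lo > INT8_MAX)
          continue;
        ++result.checked;
        bool original = asSigned(a - b) > asSigned(c - b); // icmp sgt (a-b), (c-b)
        bool foldedUnsigned = a > c;                       // icmp ugt a, c
        bool foldedSigned = asSigned(a) > asSigned(c);     // icmp sgt a, c
        if (original != foldedUnsigned)
          ++result.unsignedMismatch;
        if (original != foldedSigned)
          ++result.signedMismatch;
      }
  return result;
}
```

Under these assumptions the unsigned fold never disagrees with the original comparison, while keeping the signed predicate does. For example, `a = 0x7F`, `b = 0x7F`, `c = 0x80` gives differences 0 and 1, but a signed comparison of the raw addresses sees 127 and -128.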
@llvm/pr-subscribers-llvm-transforms

Author: Ryan Buchner (bababuck)

Changes

icmp (A-B), (C-B) -> icmp A, C for comparisons of pointer subtraction, that is, if A, B and C are all ptrs converted to integers.

This case has to be specially handled due to C language semantics for pointers. Under those semantics, pointer subtraction can be treated as NSW in this scenario. However, the comparison must be made unsigned because pointer comparisons are unsigned, even though the result of pointer subtraction is signed.

This change allows for the following optimization:

```c
int8_t* foo(int8_t* a, int8_t* b, int8_t* c) {
  a = a + 1;
  if ((a - b) > (c - b))
    a = c;
  return a;
}
```
Full diff: https://github.com/llvm/llvm-project/pull/161698.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e4cb457499ef5..9d53acbe6d46a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5416,6 +5416,45 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, A, C);
+ // icmp (A-B), (C-B) -> icmp A, C for comparisons of pointer subtraction
+ // that is, if A, B and C are all ptrs converted to integers.
+ //
+ // Tricky case because pointers are effectively unsigned integers, but the
+ // result of their subtraction is signed. Also, these subtractions ought to
+ // have NSW semantics, except we cannot give that to them because the results
+ // are considered signed, but if we optimize away the subtraction, the
+ // underlying pointers need to be treated as unsigned, thus special handling
+ // is required. In this scenario, we must ensure that the comparison is
+ // unsigned after removing the subtraction operations.
+ if (B && D && B == D && isa<PtrToIntOperator>(A) &&
+ isa<PtrToIntOperator>(B) && isa<PtrToIntOperator>(C)) {
+ CmpInst::Predicate UnsignedPred;
+ switch (Pred) {
+ default:
+ // If already unsigned, explicit cast from ptr to unsigned,
+ // so cannot optimize
+ UnsignedPred = CmpInst::BAD_ICMP_PREDICATE;
+ break;
+ case ICmpInst::ICMP_SGT:
+ UnsignedPred = ICmpInst::ICMP_UGT;
+ break;
+ case ICmpInst::ICMP_SLT:
+ UnsignedPred = ICmpInst::ICMP_ULT;
+ break;
+ case ICmpInst::ICMP_SGE:
+ UnsignedPred = ICmpInst::ICMP_UGE;
+ break;
+ case ICmpInst::ICMP_SLE:
+ UnsignedPred = ICmpInst::ICMP_ULE;
+ break;
+ }
+ if (UnsignedPred != CmpInst::BAD_ICMP_PREDICATE) {
+ PtrToIntOperator *AOp = dyn_cast<PtrToIntOperator>(A);
+ PtrToIntOperator *COp = dyn_cast<PtrToIntOperator>(C);
+ return new ICmpInst(UnsignedPred, AOp->getOperand(0), COp->getOperand(0));
+ }
+ }
+
// icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, D, B);
diff --git a/llvm/test/Transforms/InstCombine/icmp-ptrdiff.ll b/llvm/test/Transforms/InstCombine/icmp-ptrdiff.ll
new file mode 100644
index 0000000000000..2512135172196
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-ptrdiff.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define ptr @icmp_ptrdiff_gt(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_gt(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: [[TMP10:%.*]] = icmp ugt ptr [[TMP4]], [[TMP2:%.*]]
+; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], ptr [[TMP2]], ptr [[TMP4]]
+; CHECK-NEXT: ret ptr [[TMP11]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp sgt i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
+define ptr @icmp_ptrdiff_lt(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_lt(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: [[TMP10:%.*]] = icmp ult ptr [[TMP4]], [[TMP2:%.*]]
+; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], ptr [[TMP2]], ptr [[TMP4]]
+; CHECK-NEXT: ret ptr [[TMP11]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp slt i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
+define ptr @icmp_ptrdiff_ge(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_ge(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: [[DOTNOT:%.*]] = icmp ult ptr [[TMP4]], [[TMP2:%.*]]
+; CHECK-NEXT: [[TMP10:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
+; CHECK-NEXT: ret ptr [[TMP10]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp sge i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
+define ptr @icmp_ptrdiff_le(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_le(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: [[DOTNOT:%.*]] = icmp ugt ptr [[TMP4]], [[TMP2:%.*]]
+; CHECK-NEXT: [[TMP10:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
+; CHECK-NEXT: ret ptr [[TMP10]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp sle i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
+define ptr @icmp_ptrdiff_eq(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_eq(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: ret ptr [[TMP4]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp eq i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
+define ptr @icmp_ptrdiff_ne(ptr %0, ptr %1, ptr %2) {
+; CHECK-LABEL: @icmp_ptrdiff_ne(
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
+; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq ptr [[TMP4]], [[TMP2:%.*]]
+; CHECK-NEXT: [[TMP5:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
+; CHECK-NEXT: ret ptr [[TMP5]]
+;
+ %4 = getelementptr inbounds nuw i8, ptr %0, i64 1
+ %5 = ptrtoint ptr %4 to i64
+ %6 = ptrtoint ptr %1 to i64
+ %7 = sub i64 %5, %6
+ %8 = ptrtoint ptr %2 to i64
+ %9 = sub i64 %8, %6
+ %10 = icmp ne i64 %7, %9
+ %11 = select i1 %10, ptr %2, ptr %4
+ ret ptr %11
+}
+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define ptr @icmp_ptrdiff_gt(ptr %0, ptr %1, ptr %2) {
Unfortunately it doesn't work in LLVM: https://alive2.llvm.org/ce/z/AYPSer
Can we set nsw for pointer difference here (we use sdiv to compute the number of elements, so it is meaningless if the sub signed wraps)?
See also llvm-project/clang/lib/CodeGen/CGExprScalar.cpp, lines 4695 to 4744 in d68f0c2:

```cpp
  // Otherwise, this is a pointer subtraction.

  // Do the raw subtraction part.
  llvm::Value *LHS
    = Builder.CreatePtrToInt(op.LHS, CGF.PtrDiffTy, "sub.ptr.lhs.cast");
  llvm::Value *RHS
    = Builder.CreatePtrToInt(op.RHS, CGF.PtrDiffTy, "sub.ptr.rhs.cast");
  Value *diffInChars = Builder.CreateSub(LHS, RHS, "sub.ptr.sub");

  // Okay, figure out the element size.
  const BinaryOperator *expr = cast<BinaryOperator>(op.E);
  QualType elementType = expr->getLHS()->getType()->getPointeeType();

  llvm::Value *divisor = nullptr;

  // For a variable-length array, this is going to be non-constant.
  if (const VariableArrayType *vla
        = CGF.getContext().getAsVariableArrayType(elementType)) {
    auto VlaSize = CGF.getVLASize(vla);
    elementType = VlaSize.Type;
    divisor = VlaSize.NumElts;

    // Scale the number of non-VLA elements by the non-VLA element size.
    CharUnits eltSize = CGF.getContext().getTypeSizeInChars(elementType);
    if (!eltSize.isOne())
      divisor = CGF.Builder.CreateNUWMul(CGF.CGM.getSize(eltSize), divisor);

  // For everything elese, we can just compute it, safe in the
  // assumption that Sema won't let anything through that we can't
  // safely compute the size of.
  } else {
    CharUnits elementSize;
    // Handle GCC extension for pointer arithmetic on void* and
    // function pointer types.
    if (elementType->isVoidType() || elementType->isFunctionType())
      elementSize = CharUnits::One();
    else
      elementSize = CGF.getContext().getTypeSizeInChars(elementType);

    // Don't even emit the divide for element size of 1.
    if (elementSize.isOne())
      return diffInChars;

    divisor = CGF.CGM.getSize(elementSize);
  }

  // Otherwise, do a full sdiv. This uses the "exact" form of sdiv, since
  // pointer difference in C is only defined in the case where both operands
  // are pointing to elements of an array.
  return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");
```
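
As a rough source-level picture of what that codegen computes, the sketch below takes the byte difference of the two addresses and then divides by the element size. It is only an illustration under the assumption of a flat address space where `std::intptr_t` is available. `elementDifference` and `demoDifference` are hypothetical names, and the `assert` stands in for the "exact" requirement on the `sdiv`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Byte difference first (the `sub` on the two ptrtoint results), then a
// division by the element size (the `sdiv exact`).
template <typename T>
std::ptrdiff_t elementDifference(const T *lhs, const T *rhs) {
  auto lhsAddr = reinterpret_cast<std::intptr_t>(lhs);
  auto rhsAddr = reinterpret_cast<std::intptr_t>(rhs);
  std::ptrdiff_t diffInChars = lhsAddr - rhsAddr;

  // Mirrors "Don't even emit the divide for element size of 1".
  if constexpr (sizeof(T) == 1)
    return diffInChars;

  auto size = static_cast<std::ptrdiff_t>(sizeof(T));
  // "exact": both pointers address elements of the same array, so the byte
  // difference is a multiple of the element size.
  assert(diffInChars % size == 0);
  return diffInChars / size;
}

// Small usage helper: difference between elements 5 and 2 of one array.
std::ptrdiff_t demoDifference() {
  static std::int32_t arr[8] = {};
  return elementDifference(&arr[5], &arr[2]);
}
```

If the byte subtraction wrapped in the signed sense, the quotient would no longer be the element count, which is the reviewer's argument for why `nsw` on the `sub` looks reasonable.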
It's probably okay, based strictly on the standard. But every time we touch this sort of thing, it breaks someone, somehow...
Using NSW runs into the issue that the existing optimization would then fire without converting the signed comparison into an unsigned one, and the pointer comparison won't behave correctly as a signed comparison. (I tried that approach initially.)
> Unfortunately it doesn't work in LLVM: https://alive2.llvm.org/ce/z/AYPSer
I believe the case listed here is UB since pointer subtraction is only defined for two pointers to the same object.
@keinflue dug this up: draft N3220 (adjacent to C23) says in §6.5.7:

> When two pointers are subtracted, both shall point to elements of the same array object, or one past the last element of the array object; the result is the difference of the subscripts of the two array elements.
@efriedma-quic is this something that you would rather not have merged in? My understanding is that this doesn't hit often (only hits once in Spec2017 (502.GCC)), so if you think this will be a headache down the line it might not be worth merging.
Right, the pointers have to point to the same object, and the size of an object is limited. (From the C standard: "If the result is not representable in an object of that type, the behavior is undefined." From LangRef: "the size of all allocated objects must be non-negative and not exceed the largest signed integer that fits into the index type.")
That said, because people with security-sensitive codebases are very sensitive to undefined behavior, we're trying to be conservative with new forms of undefined behavior: there should be a sanitizer to detect the behavior, and there should be a flag to disable it. Which we can do here, I guess, but I'm not sure it's worth the trouble.
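
To illustrate why the object size limit quoted above matters, the following sketch uses a hypothetical 8-bit address space (not LLVM code; `foldHolds` is an invented name). Once an object is larger than `INT8_MAX` bytes, a difference can wrap in the signed sense and the fold no longer preserves the original result.

```cpp
#include <cstdint>

// Returns true when `icmp sgt (a-b), (c-b)` and `icmp ugt a, c` agree for
// the given 8-bit addresses.
bool foldHolds(std::uint8_t a, std::uint8_t b, std::uint8_t c) {
  // Wrapping 8-bit subtraction, then a signed reading of the result.
  // (Modular conversion to int8_t is guaranteed since C++20 and is what
  // mainstream compilers did before that.)
  auto diff = [](std::uint8_t x, std::uint8_t y) {
    return static_cast<std::int8_t>(static_cast<std::uint8_t>(x - y));
  };
  bool original = diff(a, b) > diff(c, b); // signed compare of differences
  bool folded = a > c;                     // unsigned compare of addresses
  return original == folded;
}
```

With `a = 100, b = 0, c = 50` the object spans 100 bytes and both forms agree. With `a = 200, b = 0, c = 100` the object would span 200 bytes, `a - b` reads as -56 when interpreted as signed, and the two forms disagree; that case is exactly what the size limit rules out.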
Each use of undef can observe a different value. See https://llvm.org/docs/UndefinedBehavior.html#undef-values

Posted an updated proof.
icmp (A-B), (C-B) -> icmp A, C for comparisons of pointer subtraction, that is, if A, B and C are all ptrs converted to integers.

A couple of things I need a little help understanding:

- `ptrtoint` vs `ptrtoaddr` on this optimization. See [clang] pointer subtraction should use ptrtoaddr rather than ptrtoint #159419.
- Alive proof has some issues with `undef` that I don't quite understand, https://alive2.llvm.org/ce/z/AdZoWJ. Particularly, `%p2 = ptrtoint ptr %2 to i2` is solved differently: in `Src`, `i2 %p2 = #x3 (3, -1)`, while in `Tgt`: `i2 %p2 = #x0 (0)`.

Proof: https://alive2.llvm.org/ce/z/JxrJHj

This case has to be specially handled due to C language semantics for pointers.

> If an array is so large (greater than PTRDIFF_MAX elements, but less than SIZE_MAX bytes), that the difference between two pointers may not be representable as std::ptrdiff_t, the result of subtracting two such pointers is undefined.

Due to this definition, pointer subtraction can be treated as NSW in this scenario. However, the comparison must be made unsigned because pointer comparisons are unsigned, even though the result of pointer subtraction is signed.

This change allows for the following optimization (seen in `502.GCC` from Spec2017; GCC makes this optimization, but LLVM previously failed to).

Originally lowers to:

Now lowers to: