Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5416,6 +5416,45 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, A, C);

// icmp (A-B), (C-B) -> icmp A, C when A, B and C are all pointers converted
// to integers, i.e. when comparing two pointer differences that share the
// same subtrahend.
//
// This case is tricky because pointers are effectively unsigned integers,
// while the result of subtracting them is signed. These subtractions ought
// to have NSW semantics, but we cannot give that to them: the differences
// are considered signed, whereas once the subtraction is optimized away the
// underlying pointers need to be treated as unsigned. Special handling is
// therefore required: after removing the subtraction operations, the
// comparison must use the unsigned form of the original signed predicate.
if (B && D && B == D && isa<PtrToIntOperator>(A) &&
isa<PtrToIntOperator>(B) && isa<PtrToIntOperator>(C)) {
CmpInst::Predicate UnsignedPred;
switch (Pred) {
default:
// Only the signed relational predicates below are rewritten. If the
// predicate is already unsigned, the pointers were explicitly cast to
// unsigned, so we cannot optimize.
UnsignedPred = CmpInst::BAD_ICMP_PREDICATE;
break;
case ICmpInst::ICMP_SGT:
UnsignedPred = ICmpInst::ICMP_UGT;
break;
case ICmpInst::ICMP_SLT:
UnsignedPred = ICmpInst::ICMP_ULT;
break;
case ICmpInst::ICMP_SGE:
UnsignedPred = ICmpInst::ICMP_UGE;
break;
case ICmpInst::ICMP_SLE:
UnsignedPred = ICmpInst::ICMP_ULE;
break;
}
if (UnsignedPred != CmpInst::BAD_ICMP_PREDICATE) {
PtrToIntOperator *AOp = dyn_cast<PtrToIntOperator>(A);
PtrToIntOperator *COp = dyn_cast<PtrToIntOperator>(C);
return new ICmpInst(UnsignedPred, AOp->getOperand(0), COp->getOperand(0));
}
}

// icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, D, B);
Expand Down
109 changes: 109 additions & 0 deletions llvm/test/Transforms/InstCombine/icmp-ptrdiff.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define ptr @icmp_ptrdiff_gt(ptr %0, ptr %1, ptr %2) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it doesn't work in LLVM: https://alive2.llvm.org/ce/z/AYPSer

Can we set nsw for pointer difference here (we use sdiv to compute the number of elements, so it is meaningless if the sub signed wraps)?
See also

// Otherwise, this is a pointer subtraction.
// Do the raw subtraction part.
llvm::Value *LHS
= Builder.CreatePtrToInt(op.LHS, CGF.PtrDiffTy, "sub.ptr.lhs.cast");
llvm::Value *RHS
= Builder.CreatePtrToInt(op.RHS, CGF.PtrDiffTy, "sub.ptr.rhs.cast");
Value *diffInChars = Builder.CreateSub(LHS, RHS, "sub.ptr.sub");
// Okay, figure out the element size.
const BinaryOperator *expr = cast<BinaryOperator>(op.E);
QualType elementType = expr->getLHS()->getType()->getPointeeType();
llvm::Value *divisor = nullptr;
// For a variable-length array, this is going to be non-constant.
if (const VariableArrayType *vla
= CGF.getContext().getAsVariableArrayType(elementType)) {
auto VlaSize = CGF.getVLASize(vla);
elementType = VlaSize.Type;
divisor = VlaSize.NumElts;
// Scale the number of non-VLA elements by the non-VLA element size.
CharUnits eltSize = CGF.getContext().getTypeSizeInChars(elementType);
if (!eltSize.isOne())
divisor = CGF.Builder.CreateNUWMul(CGF.CGM.getSize(eltSize), divisor);
// For everything else, we can just compute it, safe in the
// assumption that Sema won't let anything through that we can't
// safely compute the size of.
} else {
CharUnits elementSize;
// Handle GCC extension for pointer arithmetic on void* and
// function pointer types.
if (elementType->isVoidType() || elementType->isFunctionType())
elementSize = CharUnits::One();
else
elementSize = CGF.getContext().getTypeSizeInChars(elementType);
// Don't even emit the divide for element size of 1.
if (elementSize.isOne())
return diffInChars;
divisor = CGF.CGM.getSize(elementSize);
}
// Otherwise, do a full sdiv. This uses the "exact" form of sdiv, since
// pointer difference in C is only defined in the case where both operands
// are pointing to elements of an array.
return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably okay, based strictly on the standard. But every time we touch this sort of thing, it breaks someone, somehow...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using NSW runs into the issue that the optimization would occur without adjusting the signed comparison into an unsigned comparison, and the pointer comparison won't behave correctly as a signed comparison. (I tried that approach initially.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it doesn't work in LLVM: https://alive2.llvm.org/ce/z/AYPSer

I believe the case listed here is UB since pointer subtraction is only defined for two pointers to the same object.

@keinflue dug this up: draft N3220 (adjacent to C23) says in §6.5.7:

When two pointers are subtracted, both shall point to elements of the same array object, or one past
the last element of the array object; the result is the difference of the subscripts of the two array
elements.

@efriedma-quic is this something that you would rather not have merged in? My understanding is that this doesn't hit often (only hits once in Spec2017 (502.GCC)), so if you think this will be a headache down the line it might not be worth merging.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the pointers have to point to the same object, and the size of an object is limited. (From the C standard: "If the result is not representable in an object of that type, the behavior is undefined." From LangRef: "the size of all allocated objects must be non-negative and not exceed the largest signed integer that fits into the index type.")

That said, because people with security-sensitive codebases are very sensitive to undefined behavior, we're trying to be conservative with new forms of undefined behavior: there should be a sanitizer to detect the behavior, and there should be a flag to disable it. Which we can do here, I guess, but I'm not sure it's worth the trouble.

; CHECK-LABEL: @icmp_ptrdiff_gt(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: [[TMP10:%.*]] = icmp ugt ptr [[TMP4]], [[TMP2:%.*]]
; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], ptr [[TMP2]], ptr [[TMP4]]
; CHECK-NEXT: ret ptr [[TMP11]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp sgt i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}

; Signed less-than of two pointer differences that share the subtrahend %1:
;   (ptrtoint(%0 + 1) - ptrtoint(%1)) slt (ptrtoint(%2) - ptrtoint(%1))
; The assertions below expect both subtractions to be gone and the compare
; to be performed directly on the pointers with the unsigned predicate ult.
define ptr @icmp_ptrdiff_lt(ptr %0, ptr %1, ptr %2) {
; CHECK-LABEL: @icmp_ptrdiff_lt(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: [[TMP10:%.*]] = icmp ult ptr [[TMP4]], [[TMP2:%.*]]
; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], ptr [[TMP2]], ptr [[TMP4]]
; CHECK-NEXT: ret ptr [[TMP11]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp slt i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}

; Signed greater-or-equal of two pointer differences that share the
; subtrahend %1:
;   (ptrtoint(%0 + 1) - ptrtoint(%1)) sge (ptrtoint(%2) - ptrtoint(%1))
; The assertions below expect a direct unsigned pointer compare. It appears
; in inverted form (ult instead of uge) with the two select arms swapped
; relative to the input.
define ptr @icmp_ptrdiff_ge(ptr %0, ptr %1, ptr %2) {
; CHECK-LABEL: @icmp_ptrdiff_ge(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp ult ptr [[TMP4]], [[TMP2:%.*]]
; CHECK-NEXT: [[TMP10:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
; CHECK-NEXT: ret ptr [[TMP10]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp sge i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}

; Signed less-or-equal of two pointer differences that share the
; subtrahend %1:
;   (ptrtoint(%0 + 1) - ptrtoint(%1)) sle (ptrtoint(%2) - ptrtoint(%1))
; The assertions below expect a direct unsigned pointer compare. It appears
; in inverted form (ugt instead of ule) with the two select arms swapped
; relative to the input.
define ptr @icmp_ptrdiff_le(ptr %0, ptr %1, ptr %2) {
; CHECK-LABEL: @icmp_ptrdiff_le(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp ugt ptr [[TMP4]], [[TMP2:%.*]]
; CHECK-NEXT: [[TMP10:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
; CHECK-NEXT: ret ptr [[TMP10]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp sle i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}

; Equality of two pointer differences that share the subtrahend %1:
;   (ptrtoint(%0 + 1) - ptrtoint(%1)) eq (ptrtoint(%2) - ptrtoint(%1))
; The select returns %2 when the differences are equal and the GEP result
; otherwise. The assertions below expect no compare or select to remain:
; the function simply returns the GEP result.
; NOTE(review): presumably the select folds away because both arms hold the
; same address whenever the compare is true -- confirm.
define ptr @icmp_ptrdiff_eq(ptr %0, ptr %1, ptr %2) {
; CHECK-LABEL: @icmp_ptrdiff_eq(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: ret ptr [[TMP4]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp eq i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}

; Inequality of two pointer differences that share the subtrahend %1:
;   (ptrtoint(%0 + 1) - ptrtoint(%1)) ne (ptrtoint(%2) - ptrtoint(%1))
; The assertions below expect a direct pointer compare. It appears in
; inverted form (eq instead of ne) with the two select arms swapped relative
; to the input.
define ptr @icmp_ptrdiff_ne(ptr %0, ptr %1, ptr %2) {
; CHECK-LABEL: @icmp_ptrdiff_ne(
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 1
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq ptr [[TMP4]], [[TMP2:%.*]]
; CHECK-NEXT: [[TMP5:%.*]] = select i1 [[DOTNOT]], ptr [[TMP4]], ptr [[TMP2]]
; CHECK-NEXT: ret ptr [[TMP5]]
;
%4 = getelementptr inbounds nuw i8, ptr %0, i64 1
%5 = ptrtoint ptr %4 to i64
%6 = ptrtoint ptr %1 to i64
%7 = sub i64 %5, %6
%8 = ptrtoint ptr %2 to i64
%9 = sub i64 %8, %6
%10 = icmp ne i64 %7, %9
%11 = select i1 %10, ptr %2, ptr %4
ret ptr %11
}