[GVN] Restrict equality propagation for pointers #82458

Merged on Apr 24, 2024 (8 commits)

UsmanNadeem
Contributor

Reviving https://reviews.llvm.org/D143129

This patch does the following:

Adds the following functions:

  • replaceDominatedUsesWithIf(), which takes a callback and replaces a
    dominated use only when the callback returns true for that use.

  • canReplacePointersIfEqual(...), which returns true if both pointers have
    the same underlying object, or if the replacement is a null pointer or a
    constant dereferenceable pointer.

  • canReplacePointersInUseIfEqual(...), which returns true in the cases
    above and also when the use is in an icmp/ptrtoint, or in a phi/select
    that only feeds into them.

Updates GVN to use the functions above, so that pointer replacements are
made only when this API allows them.

Change-Id: I4927ea452734458be028854ef0e5cbcd81955910
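
The replacement rules listed above can be modelled outside LLVM. What follows is a minimal sketch in standard C++, not the patch's code: `Kind`, `ToyNode`, `ToyPtr`, `exampleUseGraph` and the `...Toy` functions are hypothetical stand-ins for LLVM's `Value`/`Use` and the helpers this patch adds to Loads.cpp. It assumes the use graph is given as a vector of nodes indexed by integer, and it borrows the lookup limit of 6 from the patch.

```cpp
#include <cassert>
#include <vector>

// Hypothetical instruction kinds; only the split between address-only
// users, forwarding users (phi/select) and everything else matters here.
enum class Kind { ICmp, PtrToInt, Phi, Select, Load, Call };

// Hypothetical node in a use graph: an instruction plus the indices of
// the instructions that use its result.
struct ToyNode {
  Kind K;
  std::vector<int> Users;
};

// Hypothetical summary of a pointer value.
struct ToyPtr {
  bool IsNull;
  bool IsDereferenceableConstant;
  int UnderlyingObject; // id of the object the pointer is based on
};

// Replacement is allowed regardless of the use when `To` is null, a
// dereferenceable constant, or based on the same object as `From`.
bool isAlwaysReplaceableToy(const ToyPtr &From, const ToyPtr &To) {
  return To.IsNull || To.IsDereferenceableConstant ||
         From.UnderlyingObject == To.UnderlyingObject;
}

// True if the user only observes the pointer's address (icmp, ptrtoint),
// or is a phi/select whose own users all satisfy the same condition.
// MaxLookup bounds the walk so that phi cycles terminate.
bool isUseReplaceableToy(const std::vector<ToyNode> &G, int User,
                         int MaxLookup = 6) {
  assert(User >= 0 && User < static_cast<int>(G.size()));
  if (MaxLookup == 0)
    return false;
  const ToyNode &N = G[User];
  if (N.K == Kind::ICmp || N.K == Kind::PtrToInt)
    return true;
  if (N.K == Kind::Phi || N.K == Kind::Select) {
    for (int Next : N.Users)
      if (!isUseReplaceableToy(G, Next, MaxLookup - 1))
        return false;
    return true;
  }
  return false;
}

// Combined check: pointer-level rule first, then the use-level rule.
bool canReplaceInUseToy(const std::vector<ToyNode> &G, int User,
                        const ToyPtr &From, const ToyPtr &To) {
  return isAlwaysReplaceableToy(From, To) || isUseReplaceableToy(G, User);
}

// Example graph: 0 = icmp, 1 = phi feeding the icmp, 2 = call,
// 3 = phi feeding both the icmp and the call, 4 = phi feeding itself.
std::vector<ToyNode> exampleUseGraph() {
  return {{Kind::ICmp, {}},
          {Kind::Phi, {0}},
          {Kind::Call, {}},
          {Kind::Phi, {0, 2}},
          {Kind::Phi, {4}}};
}
```

In this model the call (node 2) keeps its original pointer unless the replacement is null, a dereferenceable constant, or based on the same object, while the self-referencing phi (node 4) is rejected once the lookup limit is exhausted.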

@llvmbot
Collaborator

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Usman Nadeem (UsmanNadeem)

Changes

Patch is 21.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82458.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Analysis/Loads.h (+8-7)
  • (modified) llvm/include/llvm/Transforms/Utils/Local.h (+12)
  • (modified) llvm/lib/Analysis/Loads.cpp (+61-17)
  • (modified) llvm/lib/Transforms/Scalar/GVN.cpp (+15-9)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+25-5)
  • (modified) llvm/test/Transforms/GVN/condprop.ll (+202-5)
  • (modified) llvm/unittests/Analysis/LoadsTest.cpp (+31-18)
diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index 0926093bba99de..e39a1d6cafad23 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -173,14 +173,15 @@ Value *findAvailablePtrLoadStore(const MemoryLocation &Loc, Type *AccessTy,
                                  unsigned MaxInstsToScan, BatchAAResults *AA,
                                  bool *IsLoadCSE, unsigned *NumScanedInst);
 
-/// Returns true if a pointer value \p A can be replace with another pointer
-/// value \B if they are deemed equal through some means (e.g. information from
+/// Returns true if a pointer value \p From can be replaced with another pointer
+/// value \To if they are deemed equal through some means (e.g. information from
 /// conditions).
-/// NOTE: the current implementations is incomplete and unsound. It does not
-/// reject all invalid cases yet, but will be made stricter in the future. In
-/// particular this means returning true means unknown if replacement is safe.
-bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                               Instruction *CtxI);
+/// NOTE: The current implementation allows replacement in Icmp and PtrToInt
+/// instructions, as well as when we are replacing with a null pointer.
+/// Additionally it also allows replacement of pointers when both pointers have
+/// the same underlying object.
+bool canReplacePointersIfEqual(const Value *From, const Value *To);
+bool canReplacePointersInUseIfEqual(const Use &U, const Value *To);
 }
 
 #endif
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index 2df3c9049c7d62..d0dbf1951fe8b6 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -432,6 +432,18 @@ unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
 /// the end of the given BasicBlock. Returns the number of replacements made.
 unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
                                   const BasicBlock *BB);
+/// Replace each use of 'From' with 'To' if that use is dominated by
+/// the given edge and the callback ShouldReplace returns true. Returns the
+/// number of replacements made.
+unsigned replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Edge,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
+/// Replace each use of 'From' with 'To' if that use is dominated by
+/// the end of the given BasicBlock and the callback ShouldReplace returns true.
+/// Returns the number of replacements made.
+unsigned replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace);
 
 /// Return true if this call calls a gc leaf function.
 ///
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 6bf0d2f56eb4eb..e7d72fbc7107d6 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -708,22 +708,66 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
   return Available;
 }
 
-bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                                     Instruction *CtxI) {
-  Type *Ty = A->getType();
-  assert(Ty == B->getType() && Ty->isPointerTy() &&
-         "values must have matching pointer types");
-
-  // NOTE: The checks in the function are incomplete and currently miss illegal
-  // cases! The current implementation is a starting point and the
-  // implementation should be made stricter over time.
-  if (auto *C = dyn_cast<Constant>(B)) {
-    // Do not allow replacing a pointer with a constant pointer, unless it is
-    // either null or at least one byte is dereferenceable.
-    APInt OneByte(DL.getPointerTypeSizeInBits(Ty), 1);
-    return C->isNullValue() ||
-           isDereferenceableAndAlignedPointer(B, Align(1), OneByte, DL, CtxI);
-  }
+// Returns true if a use is either in an ICmp/PtrToInt or a Phi/Select that only
+// feeds into them.
+static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
+  if (MaxLookup == 0)
+    return false;
+
+  const User *User = U.getUser();
+  if (isa<ICmpInst>(User))
+    return true;
+  if (isa<PtrToIntInst>(User))
+    return true;
+  if (isa<PHINode, SelectInst>(User) &&
+      all_of(User->uses(), [&](const Use &Use) {
+        return isPointerUseReplacable(Use, MaxLookup - 1);
+      }))
+    return true;
+
+  return false;
+}
+
+static const DataLayout &getDLFromVal(const Value *V) {
+  if (const Argument *A = dyn_cast<Argument>(V))
+    return A->getParent()->getParent()->getDataLayout();
+  if (const Instruction *I = dyn_cast<Instruction>(V))
+    return I->getModule()->getDataLayout();
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(V))
+    return GV->getParent()->getDataLayout();
+  llvm_unreachable("Unknown Value type");
+}
+
+// Returns true if `To` is a null pointer, constant dereferenceable pointer or
+// both pointers have the same underlying objects.
+static bool isPointerAlwaysReplacable(const Value *From, const Value *To) {
+  if (isa<ConstantPointerNull>(To))
+    return true;
+  if (isa<Constant>(To) &&
+      isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()),
+                               getDLFromVal(From)))
+    return true;
+  if (getUnderlyingObject(From) == getUnderlyingObject(To))
+    return true;
+  return false;
+}
+
+bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To) {
+  assert(U->getType() == To->getType() && "values must have matching types");
+  // Not a pointer, just return true.
+  if (!To->getType()->isPointerTy())
+    return true;
+
+  if (isPointerAlwaysReplacable(&*U, To))
+    return true;
+  return isPointerUseReplacable(U);
+}
+
+bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To) {
+  assert(From->getType() == To->getType() && "values must have matching types");
+  // Not a pointer, just return true.
+  if (!From->getType()->isPointerTy())
+    return true;
 
-  return true;
+  return isPointerAlwaysReplacable(From, To);
 }
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index dcb1ed334b6103..d4b16836d728fa 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/InstructionPrecedenceTracking.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
@@ -2443,7 +2444,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     // using the leader table is about compiling faster, not optimizing better).
     // The leader table only tracks basic blocks, not edges. Only add to if we
     // have the simple case where the edge dominates the end.
-    if (RootDominatesEnd && !isa<Instruction>(RHS))
+    if (RootDominatesEnd && !isa<Instruction>(RHS) &&
+        canReplacePointersIfEqual(LHS, RHS))
       addToLeaderTable(LVN, RHS, Root.getEnd());
 
     // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope.  As
@@ -2452,14 +2454,18 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     if (!LHS->hasOneUse()) {
       unsigned NumReplacements =
           DominatesByEdge
-              ? replaceDominatedUsesWith(LHS, RHS, *DT, Root)
-              : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart());
-
-      Changed |= NumReplacements > 0;
-      NumGVNEqProp += NumReplacements;
-      // Cached information for anything that uses LHS will be invalid.
-      if (MD)
-        MD->invalidateCachedPointerInfo(LHS);
+              ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root,
+                                           canReplacePointersInUseIfEqual)
+              : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(),
+                                           canReplacePointersInUseIfEqual);
+
+      if (NumReplacements > 0) {
+        Changed = true;
+        NumGVNEqProp += NumReplacements;
+        // Cached information for anything that uses LHS will be invalid.
+        if (MD)
+          MD->invalidateCachedPointerInfo(LHS);
+      }
     }
 
     // Now try to deduce additional equalities from this one. For example, if
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index e4aa25f7ac6ad3..dcea7bb42c54e9 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3389,15 +3389,18 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
 }
 
 template <typename RootType, typename DominatesFn>
-static unsigned replaceDominatedUsesWith(Value *From, Value *To,
-                                         const RootType &Root,
-                                         const DominatesFn &Dominates) {
+static unsigned replaceDominatedUsesWith(
+    Value *From, Value *To, const RootType &Root, const DominatesFn &Dominates,
+    std::optional<function_ref<bool(const Use &U, const Value *To)>>
+        ShouldReplace) {
   assert(From->getType() == To->getType());
 
   unsigned Count = 0;
   for (Use &U : llvm::make_early_inc_range(From->uses())) {
     if (!Dominates(Root, U))
       continue;
+    if (ShouldReplace.has_value() && !ShouldReplace.value()(U, To))
+      continue;
     LLVM_DEBUG(dbgs() << "Replace dominated use of '";
                From->printAsOperand(dbgs());
                dbgs() << "' with " << *To << " in " << *U.getUser() << "\n");
@@ -3428,7 +3431,7 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
     return DT.dominates(Root, U);
   };
-  return ::replaceDominatedUsesWith(From, To, Root, Dominates);
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, std::nullopt);
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
@@ -3437,9 +3440,26 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
     return DT.dominates(BB, U);
   };
-  return ::replaceDominatedUsesWith(From, To, BB, Dominates);
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, std::nullopt);
+}
+
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
+    return DT.dominates(Root, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, ShouldReplace);
 }
 
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
+    return DT.dominates(BB, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, ShouldReplace);
+}
 bool llvm::callsGCLeafFunction(const CallBase *Call,
                                const TargetLibraryInfo &TLI) {
   // Check if the function is specifically marked as a gc leaf function.
diff --git a/llvm/test/Transforms/GVN/condprop.ll b/llvm/test/Transforms/GVN/condprop.ll
index 6b1e4d10601099..2b202b1701150e 100644
--- a/llvm/test/Transforms/GVN/condprop.ll
+++ b/llvm/test/Transforms/GVN/condprop.ll
@@ -521,15 +521,16 @@ define i32 @test13(ptr %ptr1, ptr %ptr2) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i32, ptr [[PTR2:%.*]], i32 1
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i32, ptr [[PTR2]], i32 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[PTR1:%.*]], [[PTR2]]
-; CHECK-NEXT:    [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
 ; CHECK:       if:
+; CHECK-NEXT:    [[VAL1:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br label [[END]]
 ; CHECK:       end:
-; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR2]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL2_PRE]], [[IF]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR1]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL1]], [[IF]] ], [ 0, [[ENTRY]] ]
 ; CHECK-NEXT:    store i32 0, ptr [[PHI1]], align 4
-; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2_PRE]]
+; CHECK-NEXT:    [[VAL2:%.*]] = load i32, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
 entry:
@@ -574,7 +575,7 @@ define void @test14(ptr %ptr1, ptr noalias %ptr2) {
 ; CHECK:       if2:
 ; CHECK-NEXT:    br label [[LOOP_END]]
 ; CHECK:       loop.end:
-; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[PTR2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
+; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[GEP2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
 ; CHECK-NEXT:    [[VAL3]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    store i32 [[VAL3]], ptr [[PHI3]], align 4
 ; CHECK-NEXT:    br i1 undef, label [[LOOP]], label [[IF1]]
@@ -609,3 +610,199 @@ loop.end:
   %gep3 = getelementptr inbounds i32, ptr %ptr1, i32 1
   br i1 undef, label %loop, label %if1
 }
+
+; Make sure that the call to use_ptr does not have %p1
+define void @single_phi1(ptr %p1) {
+; CHECK-LABEL: @single_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @single_phi2(ptr %p1) {
+; CHECK-LABEL: @single_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @multiple_phi1(ptr %p1) {
+; CHECK-LABEL: @multiple_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+define void @multiple_phi2(ptr %p1) {
+; CHECK-LABEL: @multiple_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+declare void @use_bool(i1)
+declare void @use_ptr(ptr)
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
index 0111cfeefa41ae..0694889596fa23 100644
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -68,35 +68,48 @@ TEST(LoadsTest, CanReplacePointersIfEqual) {
                                       R"IR(
 @y = common global [1 x i32] zeroinitializer, align 4
 @x = common global [1 x i32] zeroinitializer, align 4
-
 declare void @use(i32*)
 
-define void @f(i32* %p) {
+define void @f(i32* %p1, i32* %p2, i64 %i) {
   call void @use(i32* getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0))
-  call void @use(i32* getelementptr inbounds (i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @x, i64 0, i64 0), i64 1))
+
+  %p1_idx = getelementptr inbounds i32, i32* %p1, i64 %i
+  call void @use(i32* %p1_idx)
+
+  %icmp = icmp eq i32* %p1, getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0)
+  %ptrInt = ptrtoint i32* %p1 to i64
   ret void
 }
 )IR");
-  const auto &DL = M->getDataLayout();
   auto *GV = M->getNamedValue("f");
   ASSERT_TRUE(GV);
   auto *F = dyn_cast<Function>(GV);
   ASSERT_TRUE(F);
 
-  // NOTE: the implementation of canReplacePointersIfEqual is incomplete.
-  // Currently the only the cases it returns false for are really sound and
-  // returning true means unknown.
-  Value *P = &*F->arg_begin();
+  Value *P1 = &*F->arg_begin();
+  Value *P2 = F->getArg(1);
+  Value *NullPtr = Constant::getNullValue(P1->getType());
   auto InstIter = F->front().begin();
-  Value *ConstDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
-  // ConstDerefPtr is a constant pointer that is provably de-referenceable. We
-  // can replace an arbitrary pointer with it.
-  EXPECT_TRUE(canReplacePointersIfEqual(P, ConstDerefPtr, DL, nullptr));
+  CallInst *UserOfY = cast<CallIns...
[truncated]
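
The `replaceDominatedUsesWithIf` overloads in the diff above add a per-use predicate on top of the dominance check. A minimal sketch of that control flow in standard C++ follows; it is not LLVM code. `ToyUse`, `replaceDominatedUsesWithIfToy` and `exampleUses` are hypothetical names, and dominance is assumed to be precomputed as a boolean per use rather than queried from a `DominatorTree`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an LLVM Use: which instruction uses which
// operand, and whether that use is dominated by the edge or block on
// which the equality is known to hold.
struct ToyUse {
  std::string UserOpcode;
  std::string Operand;
  bool Dominated;
};

// Replace `From` with `To` in every dominated use accepted by the
// predicate. Returns the number of replacements made.
unsigned replaceDominatedUsesWithIfToy(
    std::vector<ToyUse> &Uses, const std::string &From, const std::string &To,
    const std::function<bool(const ToyUse &, const std::string &)>
        &ShouldReplace) {
  assert(From != To && "replacing a value with itself is a no-op");
  unsigned Count = 0;
  for (ToyUse &U : Uses) {
    // Skip uses of other values and uses outside the dominated region.
    if (U.Operand != From || !U.Dominated)
      continue;
    // Let the caller veto individual uses.
    if (!ShouldReplace(U, To))
      continue;
    U.Operand = To;
    ++Count;
  }
  return Count;
}

// Example uses of a hypothetical pointer %p1.
std::vector<ToyUse> exampleUses() {
  return {{"icmp", "%p1", true},
          {"call", "%p1", true},
          {"ptrtoint", "%p1", false},
          {"ptrtoint", "%p1", true}};
}
```

With a predicate that accepts only `icmp` and `ptrtoint` users, the dominated `icmp` and `ptrtoint` uses are rewritten and the `call` use is left alone. As in the GVN.cpp hunk, a caller can use the returned count to decide whether follow-up work, such as invalidating cached pointer information, is needed.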

@llvmbot
Collaborator

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-llvm-analysis

@@ -3428,7 +3431,7 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
     return DT.dominates(Root, U);
   };
-  return ::replaceDominatedUsesWith(From, To, Root, Dominates);
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, std::nullopt);
 }
 
 unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
@@ -3437,9 +3440,26 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
   auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
     return DT.dominates(BB, U);
   };
-  return ::replaceDominatedUsesWith(From, To, BB, Dominates);
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, std::nullopt);
+}
+
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
+    return DT.dominates(Root, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, Root, Dominates, ShouldReplace);
 }
 
+unsigned llvm::replaceDominatedUsesWithIf(
+    Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB,
+    function_ref<bool(const Use &U, const Value *To)> ShouldReplace) {
+  auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
+    return DT.dominates(BB, U);
+  };
+  return ::replaceDominatedUsesWith(From, To, BB, Dominates, ShouldReplace);
+}
 bool llvm::callsGCLeafFunction(const CallBase *Call,
                                const TargetLibraryInfo &TLI) {
   // Check if the function is specifically marked as a gc leaf function.
diff --git a/llvm/test/Transforms/GVN/condprop.ll b/llvm/test/Transforms/GVN/condprop.ll
index 6b1e4d10601099..2b202b1701150e 100644
--- a/llvm/test/Transforms/GVN/condprop.ll
+++ b/llvm/test/Transforms/GVN/condprop.ll
@@ -521,15 +521,16 @@ define i32 @test13(ptr %ptr1, ptr %ptr2) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i32, ptr [[PTR2:%.*]], i32 1
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i32, ptr [[PTR2]], i32 2
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[PTR1:%.*]], [[PTR2]]
-; CHECK-NEXT:    [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
 ; CHECK:       if:
+; CHECK-NEXT:    [[VAL1:%.*]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    br label [[END]]
 ; CHECK:       end:
-; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR2]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL2_PRE]], [[IF]] ], [ 0, [[ENTRY]] ]
+; CHECK-NEXT:    [[PHI1:%.*]] = phi ptr [ [[PTR1]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ [[VAL1]], [[IF]] ], [ 0, [[ENTRY]] ]
 ; CHECK-NEXT:    store i32 0, ptr [[PHI1]], align 4
-; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2_PRE]]
+; CHECK-NEXT:    [[VAL2:%.*]] = load i32, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[RET:%.*]] = add i32 [[PHI2]], [[VAL2]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
 entry:
@@ -574,7 +575,7 @@ define void @test14(ptr %ptr1, ptr noalias %ptr2) {
 ; CHECK:       if2:
 ; CHECK-NEXT:    br label [[LOOP_END]]
 ; CHECK:       loop.end:
-; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[PTR2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
+; CHECK-NEXT:    [[PHI3:%.*]] = phi ptr [ [[GEP2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ]
 ; CHECK-NEXT:    [[VAL3]] = load i32, ptr [[GEP2]], align 4
 ; CHECK-NEXT:    store i32 [[VAL3]], ptr [[PHI3]], align 4
 ; CHECK-NEXT:    br i1 undef, label [[LOOP]], label [[IF1]]
@@ -609,3 +610,199 @@ loop.end:
   %gep3 = getelementptr inbounds i32, ptr %ptr1, i32 1
   br i1 undef, label %loop, label %if1
 }
+
+; Make sure that the call to use_ptr does not have %p1
+define void @single_phi1(ptr %p1) {
+; CHECK-LABEL: @single_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @single_phi2(ptr %p1) {
+; CHECK-LABEL: @single_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  call void @use_ptr(ptr %phi1)
+  ret void
+}
+
+define void @multiple_phi1(ptr %p1) {
+; CHECK-LABEL: @multiple_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  unreachable
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+define void @multiple_phi2(ptr %p1) {
+; CHECK-LABEL: @multiple_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr poison, align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[P2]], [[P1:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB4:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    switch i8 poison, label [[BB2:%.*]] [
+; CHECK-NEXT:    i8 0, label [[BB1]]
+; CHECK-NEXT:    i8 1, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[BB4]]
+; CHECK:       bb4:
+; CHECK-NEXT:    call void @use_bool(i1 [[CMP1]])
+; CHECK-NEXT:    br label [[BB5:%.*]]
+; CHECK:       bb5:
+; CHECK-NEXT:    call void @use_ptr(ptr [[P2]])
+; CHECK-NEXT:    br label [[BB5]]
+;
+entry:
+  %p2 = load ptr, ptr poison, align 8
+  %cmp1 = icmp eq ptr %p2, %p1
+  br i1 %cmp1, label %bb4, label %bb1
+
+bb1:
+  switch i8 poison, label %bb2 [
+  i8 0, label %bb1
+  i8 1, label %bb3
+  ]
+
+bb2:
+  br label %bb4
+
+bb3:
+  br label %bb4
+
+bb4:
+  %phi1 = phi ptr [ %p2, %entry ], [ %p2, %bb2 ], [ poison, %bb3 ]
+  %cmp2 = icmp eq ptr %phi1, %p1
+  call void @use_bool(i1 %cmp2)
+  br label %bb5
+
+bb5:
+  %phi2 = phi ptr [ poison, %bb5 ], [ %phi1, %bb4 ]
+  call void @use_ptr(ptr %phi2)
+  br label %bb5
+}
+
+declare void @use_bool(i1)
+declare void @use_ptr(ptr)
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
index 0111cfeefa41ae..0694889596fa23 100644
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -68,35 +68,48 @@ TEST(LoadsTest, CanReplacePointersIfEqual) {
                                       R"IR(
 @y = common global [1 x i32] zeroinitializer, align 4
 @x = common global [1 x i32] zeroinitializer, align 4
-
 declare void @use(i32*)
 
-define void @f(i32* %p) {
+define void @f(i32* %p1, i32* %p2, i64 %i) {
   call void @use(i32* getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0))
-  call void @use(i32* getelementptr inbounds (i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @x, i64 0, i64 0), i64 1))
+
+  %p1_idx = getelementptr inbounds i32, i32* %p1, i64 %i
+  call void @use(i32* %p1_idx)
+
+  %icmp = icmp eq i32* %p1, getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0)
+  %ptrInt = ptrtoint i32* %p1 to i64
   ret void
 }
 )IR");
-  const auto &DL = M->getDataLayout();
   auto *GV = M->getNamedValue("f");
   ASSERT_TRUE(GV);
   auto *F = dyn_cast<Function>(GV);
   ASSERT_TRUE(F);
 
-  // NOTE: the implementation of canReplacePointersIfEqual is incomplete.
-  // Currently the only the cases it returns false for are really sound and
-  // returning true means unknown.
-  Value *P = &*F->arg_begin();
+  Value *P1 = &*F->arg_begin();
+  Value *P2 = F->getArg(1);
+  Value *NullPtr = Constant::getNullValue(P1->getType());
   auto InstIter = F->front().begin();
-  Value *ConstDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
-  // ConstDerefPtr is a constant pointer that is provably de-referenceable. We
-  // can replace an arbitrary pointer with it.
-  EXPECT_TRUE(canReplacePointersIfEqual(P, ConstDerefPtr, DL, nullptr));
+  CallInst *UserOfY = cast<CallIns...
[truncated]
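The Local.cpp change above routes both the old and the new entry points through one helper that takes an optional predicate. A minimal standalone sketch of that pattern, outside of LLVM, might look like the following. `ToyUse`, `Predicate`, `replaceDominatedUsesToy` and the two demo functions are hypothetical names invented for illustration, and dominance is reduced to a precomputed flag rather than a `DominatorTree` query.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for an LLVM Use (hypothetical, not the real class).
struct ToyUse {
  std::string User;    // name of the using instruction, e.g. "icmp"
  std::string Operand; // value currently read by that user
  bool Dominated;      // assumed result of a dominance query against the root
};

using Predicate = std::function<bool(const ToyUse &, const std::string &To)>;

// Replace dominated uses of From with To. When ShouldReplace holds a
// predicate, only uses it accepts are rewritten; std::nullopt means
// "replace every dominated use", mirroring the unfiltered entry points.
unsigned replaceDominatedUsesToy(std::vector<ToyUse> &Uses,
                                 const std::string &From,
                                 const std::string &To,
                                 const std::optional<Predicate> &ShouldReplace) {
  assert(From != To && "nothing to replace");
  unsigned Count = 0;
  for (ToyUse &U : Uses) {
    if (U.Operand != From || !U.Dominated)
      continue;
    if (ShouldReplace && !(*ShouldReplace)(U, To))
      continue;
    U.Operand = To;
    ++Count;
  }
  return Count;
}

// Filtered: only address-observing users may be rewritten.
unsigned demoFiltered() {
  std::vector<ToyUse> Uses = {
      {"icmp", "p", true}, {"load", "p", true}, {"ptrtoint", "p", false}};
  Predicate OnlyAddressUses = [](const ToyUse &U, const std::string &) {
    return U.User == "icmp" || U.User == "ptrtoint";
  };
  return replaceDominatedUsesToy(Uses, "p", "q", OnlyAddressUses);
}

// Unfiltered: every dominated use is rewritten.
unsigned demoUnfiltered() {
  std::vector<ToyUse> Uses = {
      {"icmp", "p", true}, {"load", "p", true}, {"ptrtoint", "p", false}};
  return replaceDominatedUsesToy(Uses, "p", "q", std::nullopt);
}
```

In this sketch `demoFiltered()` rewrites only the dominated `icmp` use, while `demoUnfiltered()` also rewrites the `load`. The non-dominated `ptrtoint` is skipped in both cases.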

@nunoplopes
Member

Essentially this patch allows replacement of pointers in this situation:

if (p == q) {
  use(p);
}

Here use(p) is an icmp or a ptrtoint.
This is correct because these users care only about the pointer's address, not its provenance.
I'm not super familiar with GVN's code (only NewGVN), but the patch looks good to me.

For a follow-up patch, another condition worth considering is when both p and q are dereferenceable. That guarantees that the remaining uses (load, store, etc.) can see either of the pointers.
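The rule described in this comment, extended to phi/selects as in the PR description, can be sketched as a worklist over a toy use graph. This is only an illustration under stated assumptions: `Kind`, `Node`, `onlyAddressIsObserved` and the demo functions are hypothetical, and the real check in Loads.cpp may differ (for example, it may bound the search).

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical instruction kinds for the toy graph.
enum class Kind { ICmp, PtrToInt, Phi, Select, Load, Store, Call };

struct Node {
  Kind K;
  std::vector<const Node *> Users; // consumers of this node's result
};

// Returns true if the user, and everything reachable from it through
// phi/select nodes, only observes the pointer's address (icmp/ptrtoint).
bool onlyAddressIsObserved(const Node &User) {
  std::vector<const Node *> Worklist{&User};
  std::set<const Node *> Visited;
  while (!Worklist.empty()) {
    const Node *N = Worklist.back();
    Worklist.pop_back();
    assert(N && "toy graph must not contain null users");
    if (!Visited.insert(N).second)
      continue; // already checked; also guards against phi cycles
    if (N->K == Kind::ICmp || N->K == Kind::PtrToInt)
      continue; // address-only user, nothing more to check
    if (N->K != Kind::Phi && N->K != Kind::Select)
      return false; // load/store/call could observe provenance
    for (const Node *U : N->Users)
      Worklist.push_back(U);
  }
  return true;
}

// A phi whose only consumer is an icmp: replacement is allowed.
bool demoPhiIntoICmp() {
  Node Cmp{Kind::ICmp, {}};
  Node Phi{Kind::Phi, {&Cmp}};
  return onlyAddressIsObserved(Phi);
}

// A phi that also feeds a call: replacement must be rejected.
bool demoPhiIntoCall() {
  Node Cmp{Kind::ICmp, {}};
  Node Call{Kind::Call, {}};
  Node Phi{Kind::Phi, {&Cmp, &Call}};
  return onlyAddressIsObserved(Phi);
}
```

The second demo has the same shape as the `single_phi1` test in condprop.ll, where `%phi1` feeds both an icmp and `@use_ptr`, so the call must keep `%p2`.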

Change-Id: I7d0ddd8e6d2bab90de72d2c3f691cfbea9c3e257
@UsmanNadeem
Contributor Author

ping

@UsmanNadeem
Contributor Author

  • Use DataLayout as function argument

Collaborator

@efriedma-quic left a comment

LGTM

@aqjune
Contributor

aqjune commented Apr 22, 2024

LGTM in my eyes too. thanks.

@nikic
Contributor

nikic commented Apr 23, 2024

@dtcxzyw Could you please test this patch?

Change-Id: Ifba7e54313c5d02bcba356e0a2e5e93d521b9c4b
Change-Id: Ic203125c66a72af9439547fc592c80b459290e86
Co-authored-by: Eli Friedman <efriedma@quicinc.com>
    Visited.insert(User);
    if (isa<ICmpInst, PtrToIntInst>(User))
      continue;
    if (isa<PHINode, SelectInst>(User)) {
Contributor

nit: Early return here?

Contributor Author

We can't return true here as we still need to check other users in the list.

Contributor

What I meant here is to write:

if (!isa<PHINode, SelectInst>(User))
  return false;

But now that the if block has become so small it doesn't really matter anymore.

Change-Id: Iba33ade6b0dca3a8c1341f8bf99f1013d17ca3f3
@UsmanNadeem UsmanNadeem merged commit b10e4b8 into llvm:main Apr 24, 2024
3 of 4 checks passed
7 participants