Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DSE] Fold malloc/store pair into calloc #87048

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

XChy
Copy link
Member

@XChy XChy commented Mar 29, 2024

Resolve TODO in DSE.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 29, 2024

@llvm/pr-subscribers-llvm-transforms

Author: XChy (XChy)

Changes

Resolve TODO in DSE.


Full diff: https://github.com/llvm/llvm-project/pull/87048.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp (+38-25)
  • (added) llvm/test/Transforms/DeadStoreElimination/malloc-store.ll (+98)
  • (modified) llvm/test/Transforms/DeadStoreElimination/simple.ll (+2-4)
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index bfc8bd5970bf27..04ff6faa0db6bd 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1842,20 +1842,39 @@ struct DSEState {
     return MadeChange;
   }
 
-  /// If we have a zero initializing memset following a call to malloc,
+  /// If we have a zero initializing memset/store following a call to malloc,
   /// try folding it into a call to calloc.
   bool tryFoldIntoCalloc(MemoryDef *Def, const Value *DefUO) {
     Instruction *DefI = Def->getMemoryInst();
-    MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI);
-    if (!MemSet)
-      // TODO: Could handle zero store to small allocation as well.
+    Constant *StoredConstant;
+    Value *PtrToStore;
+
+    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
+
+    if (!Malloc)
       return false;
-    Constant *StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+
+    if (MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+      PtrToStore = MemSet->getArgOperand(0);
+      if (Malloc->getOperand(0) != MemSet->getLength())
+        return false;
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(SI->getValueOperand());
+      PtrToStore = SI->getPointerOperand();
+      if (SI->getAccessType()->isScalableTy())
+        return false;
+      uint64_t StoreSize = DL.getTypeStoreSize(SI->getAccessType());
+      if (!match(Malloc->getOperand(0), m_SpecificInt(StoreSize)))
+        return false;
+    } else
+      return false;
+
     if (!StoredConstant || !StoredConstant->isNullValue())
       return false;
 
     if (!isRemovable(DefI))
-      // The memset might be volatile..
+      // The memset/store might be volatile..
       return false;
 
     if (F.hasFnAttribute(Attribute::SanitizeMemory) ||
@@ -1863,9 +1882,7 @@ struct DSEState {
         F.hasFnAttribute(Attribute::SanitizeHWAddress) ||
         F.getName() == "calloc")
       return false;
-    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
-    if (!Malloc)
-      return false;
+
     auto *InnerCallee = Malloc->getCalledFunction();
     if (!InnerCallee)
       return false;
@@ -1878,30 +1895,27 @@ struct DSEState {
     if (!MallocDef)
       return false;
 
-    auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) {
+    auto shouldCreateCalloc = [](CallInst *Malloc, Instruction *DefI,
+                                 Value *Ptr) {
       // Check for br(icmp ptr, null), truebb, falsebb) pattern at the end
       // of malloc block
-      auto *MallocBB = Malloc->getParent(),
-        *MemsetBB = Memset->getParent();
-      if (MallocBB == MemsetBB)
+      auto *MallocBB = Malloc->getParent(), *DefBB = DefI->getParent();
+      if (MallocBB == DefBB)
         return true;
-      auto *Ptr = Memset->getArgOperand(0);
       auto *TI = MallocBB->getTerminator();
       ICmpInst::Predicate Pred;
       BasicBlock *TrueBB, *FalseBB;
       if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Ptr), m_Zero()), TrueBB,
                           FalseBB)))
         return false;
-      if (Pred != ICmpInst::ICMP_EQ || MemsetBB != FalseBB)
+      if (Pred != ICmpInst::ICMP_EQ || DefBB != FalseBB)
         return false;
       return true;
     };
 
-    if (Malloc->getOperand(0) != MemSet->getLength())
-      return false;
-    if (!shouldCreateCalloc(Malloc, MemSet) ||
-        !DT.dominates(Malloc, MemSet) ||
-        !memoryIsNotModifiedBetween(Malloc, MemSet, BatchAA, DL, &DT))
+    if (!shouldCreateCalloc(Malloc, DefI, PtrToStore) ||
+        !DT.dominates(Malloc, DefI) ||
+        !memoryIsNotModifiedBetween(Malloc, DefI, BatchAA, DL, &DT))
       return false;
     IRBuilder<> IRB(Malloc);
     Type *SizeTTy = Malloc->getArgOperand(0)->getType();
@@ -1911,9 +1925,8 @@ struct DSEState {
       return false;
 
     MemorySSAUpdater Updater(&MSSA);
-    auto *NewAccess =
-      Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), nullptr,
-                                      MallocDef);
+    auto *NewAccess = Updater.createMemoryAccessAfter(cast<Instruction>(Calloc),
+                                                      nullptr, MallocDef);
     auto *NewAccessMD = cast<MemoryDef>(NewAccess);
     Updater.insertDef(NewAccessMD, /*RenameUses=*/true);
     Malloc->replaceAllUsesWith(Calloc);
@@ -2288,9 +2301,9 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
       continue;
     }
 
-    // Can we form a calloc from a memset/malloc pair?
+    // Can we form a calloc from a memset/malloc or store/malloc pair?
     if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+      LLVM_DEBUG(dbgs() << "DSE: Remove memset/store after forming calloc:\n"
                         << "  DEAD: " << *KillingI << '\n');
       State.deleteDeadInstruction(KillingI);
       MadeChange = true;
diff --git a/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
new file mode 100644
index 00000000000000..70938b60df36e0
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=dse -S | FileCheck %s
+
+declare noalias ptr @malloc(i64) willreturn allockind("alloc,uninitialized")
+declare void @use(ptr)
+
+define ptr @basic() {
+; CHECK-LABEL: define ptr @basic() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @vec_type() {
+; CHECK-LABEL: define ptr @vec_type() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store <2 x i16> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber() {
+; CHECK-LABEL: define ptr @clobber() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  %l = load i8, ptr %ptr
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @wrong_size() {
+; CHECK-LABEL: define ptr @wrong_size() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i8 0, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i8 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @bigstore() {
+; CHECK-LABEL: define ptr @bigstore() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 8096)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 8096)
+  store <8096 x i8> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant1(i64 %l) {
+; CHECK-LABEL: define ptr @nonconstant1(
+; CHECK-SAME: i64 [[L:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 [[L]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 %l)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant2(i32 %v) {
+; CHECK-LABEL: define ptr @nonconstant2(
+; CHECK-SAME: i32 [[V:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[V]], ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %v, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber_fail(i32 %a) {
+; CHECK-LABEL: define ptr @clobber_fail(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[A]], ptr [[PTR]], align 4
+; CHECK-NEXT:    call void @use(ptr [[PTR]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %a, ptr %ptr
+  call void @use(ptr %ptr)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
diff --git a/llvm/test/Transforms/DeadStoreElimination/simple.ll b/llvm/test/Transforms/DeadStoreElimination/simple.ll
index e5d3dd09fa148d..4a80fc5b141c18 100644
--- a/llvm/test/Transforms/DeadStoreElimination/simple.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/simple.ll
@@ -204,9 +204,8 @@ define void @test_matrix_store(i64 %stride) {
 declare void @may_unwind()
 define ptr @test_malloc_no_escape_before_return() {
 ; CHECK-LABEL: @test_malloc_no_escape_before_return(
-; CHECK-NEXT:    [[PTR:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = tail call ptr @malloc(i64 4)
@@ -236,10 +235,9 @@ define ptr @test_custom_malloc_no_escape_before_return() {
 
 define ptr addrspace(1) @test13_addrspacecast() {
 ; CHECK-LABEL: @test13_addrspacecast(
-; CHECK-NEXT:    [[P:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[P:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    [[P_AC:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr addrspace(1) [[P_AC]], align 4
 ; CHECK-NEXT:    ret ptr addrspace(1) [[P_AC]]
 ;
   %p = tail call ptr @malloc(i64 4)

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 29, 2024
@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 29, 2024

Should we revert the transform when the memory region allocated by calloc will be overridden by other stores?

Example:

define void @test(i1 %cond, i64 %x, i64 %y) {
  %ptr = call ptr @malloc(i64 8)
  store i64 0, ptr %ptr, align 8 
  br i1 %cond, label %if.then, label %if.else

if.then:
  store i64 %x, ptr %ptr, align 8 ; It will be created in later passes
  ret void
if.else:
  store i64 %y, ptr %ptr, align 8
  ret void
}

See also dtcxzyw/llvm-opt-benchmark#461 (comment)

@XChy
Copy link
Member Author

XChy commented Mar 29, 2024

Should we revert the transform when the memory region allocated by calloc will be overridden by other stores?

Sounds reasonable. calloc can be modelled as a malloc and a following memset/store, so I guess we can eliminate the virtual memset/store as DSE has done for real memorydef. But I'm not that familiar with the code of DSE. @fhahn, could you provide any insight?

@XChy XChy requested a review from yurai007 March 29, 2024 14:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants