-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ArgPromotion] Remove redundant logic from recursive argpromotion code #98657
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Vedant Paranjape (vedantparanjape-amd) ChangesThis patch further cleans up the implementation by removing some redundant checks and replacing cast<> with get() calls. It adds a check to see if function call type matches the function type. This contribution is based on the discussion in #78735 Full diff: https://github.com/llvm/llvm-project/pull/98657.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 77dbf349df0df..78805f7fc9554 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -640,8 +640,10 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
}
auto *CB = dyn_cast<CallBase>(V);
- Value *PtrArg = cast<Value>(U);
- if (CB && PtrArg && CB->getCalledFunction() == CB->getFunction()) {
+ Value *PtrArg = U->get();
+ if (CB && CB->getCalledFunction() == CB->getFunction() &&
+ CB->getCalledFunction()->getReturnType() ==
+ CB->getFunction()->getReturnType()) {
if (PtrArg != Arg) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "pointer offset is not equal to zero\n");
@@ -649,7 +651,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
}
unsigned int ArgNo = Arg->getArgNo();
- if (CB->getArgOperand(ArgNo) != Arg || U->getOperandNo() != ArgNo) {
+ if (U->getOperandNo() != ArgNo) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "arg position is different in callee\n");
return false;
diff --git a/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll b/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll
new file mode 100644
index 0000000000000..a4ee73727108a
--- /dev/null
+++ b/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
+define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define internal i32 @foo(
+; CHECK-SAME: ptr [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
+; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
+; CHECK: [[COND_TRUE]]:
+; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT: br label %[[RETURN:.*]]
+; CHECK: [[COND_FALSE]]:
+; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
+; CHECK-NEXT: [[CALLRET0:%.*]] = call float @foo(ptr [[X]], i32 [[SUBVAL]], i32 [[VAL2]])
+; CHECK-NEXT: [[CALLRET1:%.*]] = call i32 @foo(ptr [[X]], i32 [[SUBVAL]], i32 [[VAL2]])
+; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
+; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
+; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET1]], [[CALLRET2]]
+; CHECK-NEXT: br label %[[RETURN]]
+; CHECK: [[COND_NEXT:.*]]:
+; CHECK-NEXT: br label %[[RETURN]]
+; CHECK: [[RETURN]]:
+; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+entry:
+ %cmp = icmp ne i32 %n, 0
+ br i1 %cmp, label %cond_true, label %cond_false
+
+cond_true: ; preds = %entry
+ %val = load i32, ptr %x, align 4
+ br label %return
+
+cond_false: ; preds = %entry
+ %val2 = load i32, ptr %x, align 4
+ %subval = sub i32 %n, 1
+ %callret0 = call float @foo(ptr %x, i32 %subval, i32 %val2)
+ %callret1 = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
+ %subval2 = sub i32 %n, 2
+ %callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
+ %cmp2 = add i32 %callret1, %callret2
+ br label %return
+
+cond_next: ; No predecessors!
+ br label %return
+
+return: ; preds = %cond_next, %cond_false, %cond_true
+ %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
+ ret i32 %retval.0
+}
+
+define i32 @bar(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define i32 @bar(
+; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], i32 [[N]], i32 [[M]])
+; CHECK-NEXT: br label %[[RETURN:.*]]
+; CHECK: [[RETURN]]:
+; CHECK-NEXT: ret i32 [[CALLRET3]]
+;
+entry:
+ %callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m)
+ br label %return
+
+return: ; preds = %entry
+ ret i32 %callret3
+}
|
if (PtrArg != Arg) { | ||
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " | ||
<< "pointer offset is not equal to zero\n"); | ||
return false; | ||
} | ||
|
||
unsigned int ArgNo = Arg->getArgNo(); | ||
if (CB->getArgOperand(ArgNo) != Arg || U->getOperandNo() != ArgNo) { | ||
if (U->getOperandNo() != ArgNo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use isCallee?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, as I need to check the position of the argument in the call matches the position in the caller function.
CB->getCalledFunction()->getReturnType() == | ||
CB->getFunction()->getReturnType()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is the function return type important? getCalledFunction is supposed to return null if the function types mismatch already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused, wasn't I meant to add the check to see if the function call type doesn't match the function def?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah sorry, I forgot that getCalledFunction already checks the function type. Ignore my comment about that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay, cool! so I can remove that check ? and then can I get a LGTM.
db01c69
to
fb27f6f
Compare
This patch further cleans up the implementation by removing some redundant checks and replacing cast<> with get() calls.
36a7d65
to
5110a86
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/39/builds/516 Here is the relevant piece of the build log for the reference:
|
llvm#98657) This patch further cleans up the implementation by removing some redundant checks and replacing cast<> with get() calls. This contribution is based on the discussion in llvm#78735
This patch further cleans up the implementation by removing some redundant checks and replacing cast<> with get() calls.
This contribution is based on the discussion in #78735