Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ArgPromotion] Handle pointer arguments of recursive calls #78735

Merged
merged 19 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,

/// Return true if we can prove that all callees pass in a valid pointer for the
/// specified function argument.
static bool allCallersPassValidPointerForArgument(Argument *Arg,
Align NeededAlign,
uint64_t NeededDerefBytes) {
static bool allCallersPassValidPointerForArgument(
Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls,
Align NeededAlign, uint64_t NeededDerefBytes) {
Function *Callee = Arg->getParent();
const DataLayout &DL = Callee->getDataLayout();
APInt Bytes(64, NeededDerefBytes);
Expand All @@ -438,6 +438,33 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
// direct callees.
return all_of(Callee->users(), [&](User *U) {
CallBase &CB = cast<CallBase>(*U);
// In case of functions with recursive calls, this check
// (isDereferenceableAndAlignedPointer) will fail when it tries to look at
// the first caller of this function. The caller may or may not have a load,
// incase it doesn't load the pointer being passed, this check will fail.
// So, it's safe to skip the check incase we know that we are dealing with a
// recursive call. For example we have a IR given below.
//
// def fun(ptr %a) {
// ...
// %loadres = load i32, ptr %a, align 4
// %res = call i32 @fun(ptr %a)
// ...
// }
//
// def bar(ptr %x) {
// ...
// %resbar = call i32 @fun(ptr %x)
// ...
// }
//
// Since we record processed recursive calls, we check if the current
// CallBase has been processed before. If yes it means that it is a
// recursive call and we can skip the check just for this call. So, just
// return true.
if (RecursiveCalls.contains(&CB))
return true;

return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
NeededAlign, Bytes, DL);
});
Expand Down Expand Up @@ -571,6 +598,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
SmallVector<const Use *, 16> Worklist;
SmallPtrSet<const Use *, 16> Visited;
SmallVector<LoadInst *, 16> Loads;
SmallPtrSet<CallBase *, 4> RecursiveCalls;
auto AppendUses = [&](const Value *V) {
for (const Use &U : V->uses())
if (Visited.insert(&U).second)
Expand Down Expand Up @@ -611,6 +639,33 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}

auto *CB = dyn_cast<CallBase>(V);
Value *PtrArg = cast<Value>(U);
vedantparanjape-amd marked this conversation as resolved.
Show resolved Hide resolved
if (CB && PtrArg && CB->getCalledFunction() == CB->getFunction()) {
vedantparanjape-amd marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably also check that the function type of the call and function match? I wouldn't want to reason about how the arguments map if this is not the case.

Copy link
Member Author

@vedantparanjape-amd vedantparanjape-amd Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by function type, is it the thing returned by getFunctionType() ? I am not able to understand why it would be different ? Can you give an example.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The call and the function have separate types that do not match:

declare float @ret_float()

define i32 @call_as_i32() {
  %val = call i32 @ret_float()
  ret i32 %val
}

Copy link
Member Author

@vedantparanjape-amd vedantparanjape-amd Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The call and the function have separate types that do not match:

declare float @ret_float()

define i32 @call_as_i32() {
  %val = call i32 @ret_float()
  ret i32 %val
}

I didn't know this was possible, I mean it should not be valid? Okay, I will check for the return types as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created a new PR to address these changes, please take a look (#98657)

if (PtrArg != Arg) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "pointer offset is not equal to zero\n");
return false;
}

unsigned int ArgNo = Arg->getArgNo();
if (CB->getArgOperand(ArgNo) != Arg || U->getOperandNo() != ArgNo) {
vedantparanjape-amd marked this conversation as resolved.
Show resolved Hide resolved
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "arg position is different in callee\n");
return false;
}

// We limit promotion to only promoting up to a fixed number of elements
// of the aggregate.
if (MaxElements > 0 && ArgParts.size() > MaxElements) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "more than " << MaxElements << " parts\n");
return false;
}

RecursiveCalls.insert(CB);
continue;
}
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
Expand All @@ -619,7 +674,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,

if (NeededDerefBytes || NeededAlign > 1) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
NeededDerefBytes)) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "not dereferenceable or aligned\n");
Expand Down Expand Up @@ -700,6 +755,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
unsigned MaxElements, bool IsRecursive) {
// Due to complexity of handling cases where the SCC has more than one
// component. We want to limit argument promotion of recursive calls to
// just functions that directly call themselves.
bool IsSelfRecursive = false;
vedantparanjape-amd marked this conversation as resolved.
Show resolved Hide resolved
// Don't perform argument promotion for naked functions; otherwise we can end
// up removing parameters that are seemingly 'not used' as they are referred
// to in the assembly.
Expand Down Expand Up @@ -745,8 +804,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
if (CB->isMustTailCall())
return nullptr;

if (CB->getFunction() == F)
if (CB->getFunction() == F) {
IsRecursive = true;
IsSelfRecursive = true;
}
vedantparanjape-amd marked this conversation as resolved.
Show resolved Hide resolved
}

// Can't change signature of musttail caller
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
; RUN: opt < %s -passes=argpromotion -S | FileCheck %s

%T = type { i32, i32, i32, i32 }
@G = constant %T { i32 0, i32 0, i32 17, i32 25 }

define internal i32 @test(ptr %p) {
; CHECK-LABEL: define {{[^@]+}}@test
; CHECK-SAME: (i32 [[P_8_VAL:%.*]], i32 [[P_12_VAL:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[V:%.*]] = add i32 [[P_12_VAL]], [[P_8_VAL]]
; CHECK-NEXT: [[RET:%.*]] = call i32 @test(i32 [[P_8_VAL]], i32 [[P_12_VAL]])
; CHECK-NEXT: [[ARET:%.*]] = add i32 [[V]], [[RET]]
; CHECK-NEXT: ret i32 [[ARET]]
;
entry:
%a.gep = getelementptr %T, ptr %p, i64 0, i32 3
%b.gep = getelementptr %T, ptr %p, i64 0, i32 2
%a = load i32, ptr %a.gep
%b = load i32, ptr %b.gep
%v = add i32 %a, %b
%ret = call i32 @test(ptr %p)
%aret = add i32 %v, %ret
ret i32 %aret
}

define i32 @caller() {
; CHECK-LABEL: define {{[^@]+}}@caller() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr @G, i64 8
; CHECK-NEXT: [[G_VAL:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr @G, i64 12
; CHECK-NEXT: [[G_VAL1:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[V:%.*]] = call i32 @test(i32 [[G_VAL]], i32 [[G_VAL1]])
; CHECK-NEXT: ret i32 [[V]]
;
entry:
%v = call i32 @test(ptr @G)
ret i32 %v
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL]], i32 [[X_0_VAL]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%subval = sub i32 %n, 1
%callret = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

define i32 @bar(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(i32 [[X_VAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
define internal i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: ptr [[X:%.*]], ptr [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[VAL3:%.*]] = load i32, ptr [[Y]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], [[VAL3]]
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[Y]], ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%val3 = load i32, ptr %y, align 4
%subval = sub i32 %n, %val3
%callret = call i32 @foo(ptr %x, ptr %y, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %y, ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

define i32 @bar(ptr align(4) dereferenceable(4) %x, ptr align(4) dereferenceable(4) %y, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], ptr align 4 dereferenceable(4) [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
define internal i32 @zoo(ptr %x, i32 %m) {
; CHECK-LABEL: define internal i32 @zoo(
; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[RESZOO:%.*]] = add i32 [[X_0_VAL]], [[M]]
; CHECK-NEXT: ret i32 [[X_0_VAL]]
;
%valzoo = load i32, ptr %x, align 4
%reszoo = add i32 %valzoo, %m
ret i32 %valzoo
}

define internal i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: ptr [[X:%.*]], i32 [[Y_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], [[Y_0_VAL]]
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_0_VAL]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP1:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[CALLRETFINAL:%.*]] = call i32 @zoo(i32 [[X_VAL]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CMP1]], [[CALLRETFINAL]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%val3 = load i32, ptr %y, align 4
%subval = sub i32 %n, %val3
%callret = call i32 @foo(ptr %x, ptr %y, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, ptr %y, i32 %subval2, i32 %m)
%cmp1 = add i32 %callret, %callret2
%callretfinal = call i32 @zoo(ptr %x, i32 %m)
%cmp2 = add i32 %cmp1, %callretfinal
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

define i32 @bar(ptr align(4) dereferenceable(4) %x, ptr align(4) dereferenceable(4) %y, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], ptr align 4 dereferenceable(4) [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[Y_VAL:%.*]] = load i32, ptr [[Y]], align 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_VAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Loading