40 changes: 38 additions & 2 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {

/// \Returns the number of entries in the users() range of the wrapped
/// llvm::Value.
unsigned Value::getNumUses() const {
  auto LLVMUsers = Val->users();
  return range_size(LLVMUsers);
}

/// Replaces uses of this value with \p OtherV, restricted to the uses for
/// which \p ShouldReplace returns true. The work is delegated to the wrapped
/// llvm::Value; each llvm::Use is wrapped before being handed to the
/// predicate. Uses whose user has no counterpart in the context are never
/// replaced.
void Value::replaceUsesWithIf(
    Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
  assert(getType() == OtherV->getType() && "Can't replace with different type");
  // Adapter from an llvm::Use to the Use type expected by the predicate.
  auto Filter = [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
    auto *UserV = Ctx.getValue(LLVMUse.getUser());
    if (auto *DstU = cast_or_null<User>(UserV))
      return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
    // No wrapper registered for this user: leave the use untouched.
    return false;
  };
  Val->replaceUsesWithIf(OtherV->Val, Filter);
}

/// Replaces every use of this value with \p Other by forwarding to the
/// wrapped llvm::Value. Both values must have the same type.
void Value::replaceAllUsesWith(Value *Other) {
  assert(getType() == Other->getType() &&
         "Replacing with Value of different type!");
  llvm::Value *Replacement = Other->Val;
  Val->replaceAllUsesWith(Replacement);
}

#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
Expand Down Expand Up @@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}

#ifndef NDEBUG
/// Debug-only check that the user of \p Use maps back to this object in the
/// context.
void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
  auto *Owner = Ctx.getValue(Use.getUser());
  assert(Owner == this && "Use not found in this SBUser's operands!");
}
#endif

bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
Expand All @@ -180,6 +206,15 @@ bool User::classof(const Value *From) {
}
}

/// Sets operand number \p OperandIdx of the wrapped llvm::User to the LLVM
/// value wrapped by \p Operand.
void User::setOperand(unsigned OperandIdx, Value *Operand) {
  assert(isa<llvm::User>(Val) && "No operands!");
  auto *LLVMUser = cast<llvm::User>(Val);
  LLVMUser->setOperand(OperandIdx, Operand->Val);
}

/// Forwards to llvm::User::replaceUsesOfWith() on the wrapped user, using the
/// LLVM values wrapped by \p FromV and \p ToV, and returns its result.
bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
  auto *LLVMUser = cast<llvm::User>(Val);
  return LLVMUser->replaceUsesOfWith(FromV->Val, ToV->Val);
}

#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
Expand Down Expand Up @@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();

if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
return It->second.get();
return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
unsigned PassthroughArgSize =
(F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
assert(ArgTranslations.size() == F->isVarArg() ? 5 : PassthroughArgSize);
assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));

// Translate arguments to call.
SmallVector<Value *> Args;
Expand Down
17 changes: 12 additions & 5 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1890,11 +1890,18 @@ ParseStatus RISCVAsmParser::parseCSRSystemRegister(OperandVector &Operands) {
if (CE) {
int64_t Imm = CE->getValue();
if (isUInt<12>(Imm)) {
auto SysReg = RISCVSysReg::lookupSysRegByEncoding(Imm);
// Accept an immediate representing a named or un-named Sys Reg
// if the range is valid, regardless of the required features.
Operands.push_back(
RISCVOperand::createSysReg(SysReg ? SysReg->Name : "", S, Imm));
auto Range = RISCVSysReg::lookupSysRegByEncoding(Imm);
// Accept an immediate representing a named Sys Reg if it satisfies the
// required features.
for (auto &Reg : Range) {
if (Reg.haveRequiredFeatures(STI->getFeatureBits())) {
Operands.push_back(RISCVOperand::createSysReg(Reg.Name, S, Imm));
return ParseStatus::Success;
}
}
// Accept an immediate representing an un-named Sys Reg if the range is
// valid, regardless of the required features.
Operands.push_back(RISCVOperand::createSysReg("", S, Imm));
return ParseStatus::Success;
}
}
Expand Down
13 changes: 8 additions & 5 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ void RISCVInstPrinter::printCSRSystemRegister(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI,
raw_ostream &O) {
unsigned Imm = MI->getOperand(OpNo).getImm();
auto SysReg = RISCVSysReg::lookupSysRegByEncoding(Imm);
if (SysReg && SysReg->haveRequiredFeatures(STI.getFeatureBits()))
markup(O, Markup::Register) << SysReg->Name;
else
markup(O, Markup::Register) << formatImm(Imm);
auto Range = RISCVSysReg::lookupSysRegByEncoding(Imm);
for (auto &Reg : Range) {
if (Reg.haveRequiredFeatures(STI.getFeatureBits())) {
markup(O, Markup::Register) << Reg.Name;
return;
}
}
markup(O, Markup::Register) << formatImm(Imm);
}

void RISCVInstPrinter::printFenceArg(const MCInst *MI, unsigned OpNo,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVSystemOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def SysRegsList : GenericTable {

let PrimaryKey = [ "Encoding" ];
let PrimaryKeyName = "lookupSysRegByEncoding";
let PrimaryKeyReturnRange = true;
}

def lookupSysRegByName : SearchIndex {
Expand Down
28 changes: 13 additions & 15 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) {
}

CoroBegin->eraseFromParent();
Shape.CoroBegin = nullptr;
}

// SimplifySuspendPoint needs to check that there is no calls between
Expand Down Expand Up @@ -1970,9 +1971,17 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
}

/// Remove calls to llvm.coro.end in the original function.
static void removeCoroEnds(const coro::Shape &Shape) {
for (auto *End : Shape.CoroEnds) {
replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr);
static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) {
  if (Shape.ABI == coro::ABI::Switch) {
    // Switch ABI: every use of the coro.end call is replaced with the
    // constant false and the call itself is erased.
    for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
      End->replaceAllUsesWith(ConstantInt::getFalse(End->getContext()));
      End->eraseFromParent();
    }
    return;
  }

  // All other ABIs go through the shared replaceCoroEnd() helper.
  for (auto *End : Shape.CoroEnds)
    replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr);
}

Expand All @@ -1981,18 +1990,6 @@ static void updateCallGraphAfterCoroutineSplit(
const SmallVectorImpl<Function *> &Clones, LazyCallGraph::SCC &C,
LazyCallGraph &CG, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR,
FunctionAnalysisManager &FAM) {
if (!Shape.CoroBegin)
return;

if (Shape.ABI != coro::ABI::Switch)
removeCoroEnds(Shape);
else {
for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
auto &Context = End->getContext();
End->replaceAllUsesWith(ConstantInt::getFalse(Context));
End->eraseFromParent();
}
}

if (!Clones.empty()) {
switch (Shape.ABI) {
Expand Down Expand Up @@ -2120,6 +2117,7 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
const coro::Shape Shape =
splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
OptimizeFrame, MaterializableCallback);
removeCoroEndsFromRampFunction(Shape);
updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);

ORE.emit([&]() {
Expand Down
71 changes: 66 additions & 5 deletions llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,

/// Return true if we can prove that all callers pass in a valid pointer for the
/// specified function argument.
static bool allCallersPassValidPointerForArgument(Argument *Arg,
Align NeededAlign,
uint64_t NeededDerefBytes) {
static bool allCallersPassValidPointerForArgument(
Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls,
Align NeededAlign, uint64_t NeededDerefBytes) {
Function *Callee = Arg->getParent();
const DataLayout &DL = Callee->getDataLayout();
APInt Bytes(64, NeededDerefBytes);
Expand All @@ -438,6 +438,33 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
// direct callees.
return all_of(Callee->users(), [&](User *U) {
CallBase &CB = cast<CallBase>(*U);
// In case of functions with recursive calls, this check
// (isDereferenceableAndAlignedPointer) will fail when it tries to look at
// the first caller of this function. The caller may or may not have a load;
// in case it doesn't load the pointer being passed, this check will fail.
// So, it's safe to skip the check in case we know that we are dealing with
// a recursive call. For example, consider the IR given below.
//
// def fun(ptr %a) {
// ...
// %loadres = load i32, ptr %a, align 4
// %res = call i32 @fun(ptr %a)
// ...
// }
//
// def bar(ptr %x) {
// ...
// %resbar = call i32 @fun(ptr %x)
// ...
// }
//
// Since we record processed recursive calls, we check if the current
// CallBase has been processed before. If yes it means that it is a
// recursive call and we can skip the check just for this call. So, just
// return true.
if (RecursiveCalls.contains(&CB))
return true;

return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
NeededAlign, Bytes, DL);
});
Expand Down Expand Up @@ -571,6 +598,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
SmallVector<const Use *, 16> Worklist;
SmallPtrSet<const Use *, 16> Visited;
SmallVector<LoadInst *, 16> Loads;
SmallPtrSet<CallBase *, 4> RecursiveCalls;
auto AppendUses = [&](const Value *V) {
for (const Use &U : V->uses())
if (Visited.insert(&U).second)
Expand Down Expand Up @@ -611,6 +639,33 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}

auto *CB = dyn_cast<CallBase>(V);
Value *PtrArg = cast<Value>(U);
if (CB && PtrArg && CB->getCalledFunction() == CB->getFunction()) {
if (PtrArg != Arg) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "pointer offset is not equal to zero\n");
return false;
}

unsigned int ArgNo = Arg->getArgNo();
if (CB->getArgOperand(ArgNo) != Arg || U->getOperandNo() != ArgNo) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "arg position is different in callee\n");
return false;
}

// We limit promotion to only promoting up to a fixed number of elements
// of the aggregate.
if (MaxElements > 0 && ArgParts.size() > MaxElements) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "more than " << MaxElements << " parts\n");
return false;
}

RecursiveCalls.insert(CB);
continue;
}
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
Expand All @@ -619,7 +674,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,

if (NeededDerefBytes || NeededAlign > 1) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
NeededDerefBytes)) {
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "not dereferenceable or aligned\n");
Expand Down Expand Up @@ -700,6 +755,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
unsigned MaxElements, bool IsRecursive) {
// Due to the complexity of handling cases where the SCC has more than one
// component, we want to limit argument promotion of recursive calls to
// just functions that directly call themselves.
bool IsSelfRecursive = false;
// Don't perform argument promotion for naked functions; otherwise we can end
// up removing parameters that are seemingly 'not used' as they are referred
// to in the assembly.
Expand Down Expand Up @@ -745,8 +804,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
if (CB->isMustTailCall())
return nullptr;

if (CB->getFunction() == F)
if (CB->getFunction() == F) {
IsRecursive = true;
IsSelfRecursive = true;
}
}

// Can't change signature of musttail caller
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Transforms/Scalar/LoopDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,11 +965,10 @@ class LoopDistributeForLoop {

} // end anonymous namespace

/// Shared implementation between new and old PMs.
static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT,
ScalarEvolution *SE, OptimizationRemarkEmitter *ORE,
LoopAccessInfoManager &LAIs) {
// Build up a worklist of inner-loops to vectorize. This is necessary as the
// Build up a worklist of inner-loops to distribute. This is necessary as the
// act of distributing a loop creates new loops and can invalidate iterators
// across the loops.
SmallVector<Loop *, 8> Worklist;
Expand Down
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/AMDGPU/copyprop_regsequence_with_undef.mir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
# RUN: llc -mtriple=amdgcn -run-pass=machine-cse -verify-machineinstrs -o - %s | FileCheck %s

# Test to ensure that this does not crash on undefs
---
name: copyprop_regsequence_with_undef
tracksRegLiveness: true
body: |
bb.0:
; CHECK-LABEL: name: copyprop_regsequence_with_undef
; CHECK: [[DEF:%[0-9]+]]:sreg_32 = IMPLICIT_DEF
; CHECK-NEXT: [[DEF1:%[0-9]+]]:sreg_32 = IMPLICIT_DEF
; CHECK-NEXT: [[REG_SEQUENCE:%[0-9]+]]:sreg_64 = REG_SEQUENCE undef %3:sreg_32, %subreg.sub0, [[DEF]], %subreg.sub1
; CHECK-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:sreg_64 = REG_SEQUENCE undef %5:sreg_32, %subreg.sub0, [[DEF1]], %subreg.sub1
; CHECK-NEXT: [[S_ADD_I32_:%[0-9]+]]:sreg_32 = S_ADD_I32 [[REG_SEQUENCE]].sub1, [[REG_SEQUENCE1]].sub1, implicit-def $scc
%0:sreg_32 = IMPLICIT_DEF
%1:sreg_32 = IMPLICIT_DEF
%4:sreg_64 = REG_SEQUENCE undef %10:sreg_32, %subreg.sub0, %0:sreg_32, %subreg.sub1
%5:sreg_64 = REG_SEQUENCE undef %11:sreg_32, %subreg.sub0, %1:sreg_32, %subreg.sub1
%6:sreg_32 = S_ADD_I32 %4.sub1:sreg_64, %5.sub1:sreg_64, implicit-def $scc

...
33 changes: 27 additions & 6 deletions llvm/test/MC/AsmParser/altmacro-arg.s
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
## Arguments can be expanded even if they are not preceded by \
# RUN: llvm-mc -triple=x86_64 %s | FileCheck %s
# RUN: rm -rf %t && split-file %s %t && cd %t
# RUN: llvm-mc -triple=x86_64 a.s | FileCheck %s
# RUN: llvm-mc -triple=x86_64 b.s | FileCheck %s --check-prefix=CHECK1

# CHECK: 1 1 1a
# CHECK-NEXT: 1 2 1a 2b
# CHECK-NEXT: \$b \$b
#--- a.s
.altmacro
# CHECK: ja .Ltmp0
# CHECK-NEXT: xorq %rbx, %rbx
# CHECK: .data
# CHECK-NEXT: .ascii "b cc rbx"
# CHECK-NEXT: .ascii "bcc ccx rbx raxx"
.macro gen a, ra, rax
ja 1f
xorq %rax, %rax
1:
.data
.ascii "\a \ra \rax"
.ascii "a\()ra ra\()x rax raxx"
.endm
gen b, cc, rbx

#--- b.s
.altmacro
# CHECK1: 1 1 1a
# CHECK1-NEXT: 1 2 1a 2b
# CHECK1-NEXT: \$b \$b
.irp ._a,1
.print "\._a \._a& ._a&a"
.irp $b,2
Expand All @@ -13,10 +33,11 @@
.print "\$b \$b&"
.endr

# CHECK: 1 1& ._a&a
# CHECK-NEXT: \$b \$b&
# CHECK1: 1 1& ._a&a
# CHECK1-NEXT: \$b \$b&
.noaltmacro
.irp ._a,1
.print "\._a \._a& ._a&a"
.print "\$b \$b&"
.endr
.altmacro
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
; RUN: opt < %s -passes=argpromotion -S | FileCheck %s

%T = type { i32, i32, i32, i32 }
@G = constant %T { i32 0, i32 0, i32 17, i32 25 }

; Self-recursive function: loads fields 3 and 2 of %T through %p and forwards
; %p unchanged to the recursive call. The assertions below expect %p to be
; replaced by two i32 parameters.
define internal i32 @test(ptr %p) {
; CHECK-LABEL: define {{[^@]+}}@test
; CHECK-SAME: (i32 [[P_8_VAL:%.*]], i32 [[P_12_VAL:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[V:%.*]] = add i32 [[P_12_VAL]], [[P_8_VAL]]
; CHECK-NEXT: [[RET:%.*]] = call i32 @test(i32 [[P_8_VAL]], i32 [[P_12_VAL]])
; CHECK-NEXT: [[ARET:%.*]] = add i32 [[V]], [[RET]]
; CHECK-NEXT: ret i32 [[ARET]]
;
entry:
%a.gep = getelementptr %T, ptr %p, i64 0, i32 3
%b.gep = getelementptr %T, ptr %p, i64 0, i32 2
%a = load i32, ptr %a.gep
%b = load i32, ptr %b.gep
%v = add i32 %a, %b
%ret = call i32 @test(ptr %p)
%aret = add i32 %v, %ret
ret i32 %aret
}

; Non-recursive caller that passes the global @G to @test. The assertions
; below expect the two i32 loads (byte offsets 8 and 12 of @G) to be emitted
; here and passed by value.
define i32 @caller() {
; CHECK-LABEL: define {{[^@]+}}@caller() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr @G, i64 8
; CHECK-NEXT: [[G_VAL:%.*]] = load i32, ptr [[TMP0]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr @G, i64 12
; CHECK-NEXT: [[G_VAL1:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[V:%.*]] = call i32 @test(i32 [[G_VAL]], i32 [[G_VAL1]])
; CHECK-NEXT: ret i32 [[V]]
;
entry:
%v = call i32 @test(ptr @G)
ret i32 %v
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
; @foo loads from %x on both branches and passes %x unchanged as the first
; argument of both self-recursive calls. The assertions below expect %x to be
; promoted to an i32 parameter.
define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL]], i32 [[X_0_VAL]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%subval = sub i32 %n, 1
%callret = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Non-recursive caller of @foo; %x is marked align(4) dereferenceable(4).
; The assertions below expect the i32 load of %x to be emitted here and
; passed by value.
define i32 @bar(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(i32 [[X_VAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
; @foo has two self-recursive calls; the second one passes %y and %x in
; swapped positions. The assertions below expect the signature to stay
; unchanged (neither pointer is promoted).
define internal i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: ptr [[X:%.*]], ptr [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[VAL3:%.*]] = load i32, ptr [[Y]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], [[VAL3]]
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[Y]], ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%val3 = load i32, ptr %y, align 4
%subval = sub i32 %n, %val3
%callret = call i32 @foo(ptr %x, ptr %y, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %y, ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Non-recursive caller of @foo. The assertions below expect the call to keep
; passing both pointers, i.e. no loads are introduced here.
define i32 @bar(ptr align(4) dereferenceable(4) %x, ptr align(4) dereferenceable(4) %y, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], ptr align 4 dereferenceable(4) [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
; Non-recursive helper that only loads from %x. The assertions below expect
; %x to be promoted to an i32 parameter. Note that %reszoo is computed but
; the function returns %valzoo.
define internal i32 @zoo(ptr %x, i32 %m) {
; CHECK-LABEL: define internal i32 @zoo(
; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[RESZOO:%.*]] = add i32 [[X_0_VAL]], [[M]]
; CHECK-NEXT: ret i32 [[X_0_VAL]]
;
%valzoo = load i32, ptr %x, align 4
%reszoo = add i32 %valzoo, %m
ret i32 %valzoo
}

; @foo passes %x and %y unchanged to both self-recursive calls and
; additionally passes %x to @zoo. The assertions below expect only %y to be
; promoted to an i32 parameter; %x stays a pointer and is loaded before the
; call to @zoo.
define internal i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: ptr [[X:%.*]], i32 [[Y_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], [[Y_0_VAL]]
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_0_VAL]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP1:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[CALLRETFINAL:%.*]] = call i32 @zoo(i32 [[X_VAL]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CMP1]], [[CALLRETFINAL]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%val3 = load i32, ptr %y, align 4
%subval = sub i32 %n, %val3
%callret = call i32 @foo(ptr %x, ptr %y, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, ptr %y, i32 %subval2, i32 %m)
%cmp1 = add i32 %callret, %callret2
%callretfinal = call i32 @zoo(ptr %x, i32 %m)
%cmp2 = add i32 %cmp1, %callretfinal
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Non-recursive caller of @foo. The assertions below expect only %y to be
; loaded here and passed by value; %x is still passed as a pointer.
define i32 @bar(ptr align(4) dereferenceable(4) %x, ptr align(4) dereferenceable(4) %y, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], ptr align 4 dereferenceable(4) [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[Y_VAL:%.*]] = load i32, ptr [[Y]], align 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], i32 [[Y_VAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
; @foo loads from %x on both branches and passes %x unchanged to both
; self-recursive calls; its only external caller in this file (@bar) passes a
; zero-index GEP of a dereferenceable pointer. The assertions below expect %x
; to be promoted to an i32 parameter.
define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL]], i32 [[X_0_VAL]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%subval = sub i32 %n, 1
%callret = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Caller that passes a GEP of %x with index 0 to @foo. The assertions below
; expect an i32 load through that GEP to be emitted here and passed by value.
define i32 @bar(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[GEPVAL:%.*]] = getelementptr ptr, ptr [[X]], i32 0
; CHECK-NEXT: [[GEPVAL_VAL:%.*]] = load i32, ptr [[GEPVAL]], align 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(i32 [[GEPVAL_VAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%gepval = getelementptr ptr, ptr %x, i32 0
%callret3 = call i32 @foo(ptr %gepval, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}

; Same body shape as @foo above, but its only external caller in this file
; (@bar2) passes a GEP of %x with a non-zero index. The assertions below
; expect the signature to stay unchanged (%x is not promoted) -- presumably
; because that pointer cannot be proven dereferenceable; confirm against the
; pass.
define internal i32 @foo2(ptr %x, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo2(
; CHECK-SAME: ptr [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo2(ptr [[X]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo2(ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%subval = sub i32 %n, 1
%callret = call i32 @foo2(ptr %x, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo2(ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Caller that passes a GEP of %x with index 4 to @foo2. The assertions below
; expect the call to keep passing the pointer, with no load introduced here.
define i32 @bar2(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar2(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[GEPVAL:%.*]] = getelementptr ptr, ptr [[X]], i32 4
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo2(ptr [[GEPVAL]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%gepval = getelementptr ptr, ptr %x, i32 4
%callret3 = call i32 @foo2(ptr %gepval, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
; Internal, self-recursive function that loads through both pointer arguments:
; %x is loaded on both the cond_true and cond_false paths, %y only in
; cond_false. Its second recursive call passes %x in the %y position.
; The autogenerated assertions below expect the pointer parameters, the loads
; and both recursive calls to be left as written after the argpromotion run.
define internal i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m) {
; CHECK-LABEL: define internal i32 @foo(
; CHECK-SAME: ptr [[X:%.*]], ptr [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
; CHECK: [[COND_TRUE]]:
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[COND_FALSE]]:
; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[VAL3:%.*]] = load i32, ptr [[Y]], align 4
; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], [[VAL3]]
; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[SUBVAL]], i32 [[VAL2]])
; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(ptr [[X]], ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[COND_NEXT:.*]]:
; CHECK-NEXT: br label %[[RETURN]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp ne i32 %n, 0
br i1 %cmp, label %cond_true, label %cond_false

cond_true: ; preds = %entry
%val = load i32, ptr %x, align 4
br label %return

cond_false: ; preds = %entry
%val2 = load i32, ptr %x, align 4
%val3 = load i32, ptr %y, align 4
%subval = sub i32 %n, %val3
%callret = call i32 @foo(ptr %x, ptr %y, i32 %subval, i32 %val2)
%subval2 = sub i32 %n, 2
%callret2 = call i32 @foo(ptr %x, ptr %x, i32 %subval2, i32 %m)
%cmp2 = add i32 %callret, %callret2
br label %return

cond_next: ; No predecessors!
br label %return

return: ; preds = %cond_next, %cond_false, %cond_true
%retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ]
ret i32 %retval.0
}

; Externally visible caller of @foo. Both pointer arguments carry align(4) and
; dereferenceable(4) and are forwarded to @foo as-is; the autogenerated
; assertions expect the call to keep passing the two pointers unchanged.
define i32 @bar(ptr align(4) dereferenceable(4) %x, ptr align(4) dereferenceable(4) %y, i32 %n, i32 %m) {
; CHECK-LABEL: define i32 @bar(
; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], ptr align 4 dereferenceable(4) [[Y:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], ptr [[Y]], i32 [[N]], i32 [[M]])
; CHECK-NEXT: br label %[[RETURN:.*]]
; CHECK: [[RETURN]]:
; CHECK-NEXT: ret i32 [[CALLRET3]]
;
entry:
%callret3 = call i32 @foo(ptr %x, ptr %y, i32 %n, i32 %m)
br label %return

return: ; preds = %entry
ret i32 %callret3
}
123 changes: 123 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
Instruction *LLVMRet = &*LLVMBBIt++;
Argument *LLVMArg0 = LLVMF.getArg(0);
Argument *LLVMArg1 = LLVMF.getArg(1);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
Expand Down Expand Up @@ -203,6 +206,126 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));

// Check User.setOperand().
Ret->setOperand(0, Arg0);
EXPECT_EQ(Ret->getOperand(0), Arg0);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);

Ret->setOperand(0, Arg1);
EXPECT_EQ(Ret->getOperand(0), Arg1);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
}

// Checks sandboxir::User::replaceUsesOfWith() (RUOW): on an instruction whose
// operands do not match, on an instruction with one and with two matching
// operands, on the return instruction, and on a global-variable constant.
TEST_F(SandboxIRTest, RUOW) {
  parseIR(C, R"IR(
declare void @bar0()
declare void @bar1()

@glob0 = global ptr @bar0
@glob1 = global ptr @bar1

define i32 @foo(i32 %arg0, i32 %arg1) {
%add0 = add i32 %arg0, %arg1
%gep1 = getelementptr i8, ptr @glob0, i32 1
%gep2 = getelementptr i8, ptr @glob1, i32 1
ret i32 %add0
}
)IR");
  llvm::Function &LLVMF = *M->getFunction("foo");
  sandboxir::Context Ctx(C);

  auto &F = *Ctx.createFunction(&LLVMF);
  auto &BB = *F.begin();
  auto *Arg0 = F.getArg(0);
  auto *Arg1 = F.getArg(1);
  auto It = BB.begin();
  auto *I0 = &*It++;
  auto *I1 = &*It++;
  auto *I2 = &*It++;
  auto *Ret = &*It++;

  bool Replaced;
  // Try to replace an operand that doesn't match.
  Replaced = I0->replaceUsesOfWith(Ret, Arg1);
  EXPECT_FALSE(Replaced);
  EXPECT_EQ(I0->getOperand(0), Arg0);
  EXPECT_EQ(I0->getOperand(1), Arg1);

  // Replace I0 operands when operands differ.
  Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
  EXPECT_TRUE(Replaced);
  EXPECT_EQ(I0->getOperand(0), Arg1);
  EXPECT_EQ(I0->getOperand(1), Arg1);

  // Replace I0 operands when operands are the same.
  Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
  EXPECT_TRUE(Replaced);
  EXPECT_EQ(I0->getOperand(0), Arg0);
  EXPECT_EQ(I0->getOperand(1), Arg0);

  // Replace Ret operand.
  Replaced = Ret->replaceUsesOfWith(I0, Arg0);
  EXPECT_TRUE(Replaced);
  EXPECT_EQ(Ret->getOperand(0), Arg0);

  // Check RUOW on constant: swap the operand of @glob0 for @glob1. The operand
  // being replaced is taken from Glob0 itself, so a replacement must be
  // reported.
  auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
  auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
  auto *Glob0Op = Glob0->getOperand(0);
  Replaced = Glob0->replaceUsesOfWith(Glob0Op, Glob1);
  EXPECT_TRUE(Replaced);
  EXPECT_EQ(Glob0->getOperand(0), Glob1);
}

// Checks sandboxir::Value::replaceUsesWithIf() (RUWIf) with predicates that
// accept none, all, or exactly one of the uses, and then
// sandboxir::Value::replaceAllUsesWith() (RAUW) in both directions.
TEST_F(SandboxIRTest, RAUW_RUWIf) {
  parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%ld0 = load float, ptr %ptr
%ld1 = load float, ptr %ptr
store float %ld0, ptr %ptr
store float %ld0, ptr %ptr
ret void
}
)IR");
  llvm::Function &LLVMF = *M->getFunction("foo");
  sandboxir::Context Ctx(C);
  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();

  Ctx.createFunction(&LLVMF);
  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
  auto It = BB->begin();
  sandboxir::Instruction *Ld0 = &*It++;
  sandboxir::Instruction *Ld1 = &*It++;
  sandboxir::Instruction *St0 = &*It++;
  sandboxir::Instruction *St1 = &*It++;
  // Check RUWIf when the lambda returns false.
  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
  EXPECT_EQ(St0->getOperand(0), Ld0);
  EXPECT_EQ(St1->getOperand(0), Ld0);
  // Check RUWIf when the lambda returns true.
  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
  EXPECT_EQ(St0->getOperand(0), Ld1);
  EXPECT_EQ(St1->getOperand(0), Ld1);
  // Restore the original operands before the next check.
  St0->setOperand(0, Ld0);
  St1->setOperand(0, Ld0);
  // Check RUWIf user == St0.
  Ld0->replaceUsesWithIf(
      Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; });
  EXPECT_EQ(St0->getOperand(0), Ld1);
  EXPECT_EQ(St1->getOperand(0), Ld0);
  St0->setOperand(0, Ld0);
  // Check RUWIf user == St1.
  Ld0->replaceUsesWithIf(
      Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; });
  EXPECT_EQ(St0->getOperand(0), Ld0);
  EXPECT_EQ(St1->getOperand(0), Ld1);
  St1->setOperand(0, Ld0);
  // Check RAUW. Both stores use Ld0 at this point and Ld1 has no uses, so
  // replace Ld0 first: this is the direction that actually has to rewrite the
  // two store operands.
  Ld0->replaceAllUsesWith(Ld1);
  EXPECT_EQ(St0->getOperand(0), Ld1);
  EXPECT_EQ(St1->getOperand(0), Ld1);
  // Now Ld1 has two uses; RAUW back to Ld0 must rewrite them again.
  Ld1->replaceAllUsesWith(Ld0);
  EXPECT_EQ(St0->getOperand(0), Ld0);
  EXPECT_EQ(St1->getOperand(0), Ld0);
}

// Check that the operands/users are counted correctly.
Expand Down
280 changes: 280 additions & 0 deletions mlir/include/mlir/Support/CyclicReplacerCache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains helper classes for caching replacer-like functions that
// map values between two domains. They are able to handle replacer logic that
// contains self-recursion.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H

#include "mlir/IR/Visitors.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include <set>

namespace mlir {

//===----------------------------------------------------------------------===//
// CyclicReplacerCache
//===----------------------------------------------------------------------===//

/// A cache for replacer-like functions that map values between two domains. The
/// difference compared to just using a map to cache in-out pairs is that this
/// class is able to handle replacer logic that is self-recursive (and thus may
/// cause infinite recursion in the naive case).
///
/// This class provides a hook for the user to perform cycle pruning when a
/// cycle is identified, and is able to perform context-sensitive caching so
/// that the replacement result for an input that is part of a pruned cycle can
/// be distinct from the replacement result for the same input when it is not
/// part of a cycle.
///
/// In addition, this class allows deferring cycle pruning until specific inputs
/// are repeated. This is useful for cases where not all elements in a cycle can
/// perform pruning. The user still must guarantee that at least one element in
/// any given cycle can perform pruning. Even if not, an assertion will
/// eventually be tripped instead of infinite recursion (the run-time is
/// linearly bounded by the maximum cycle length of its input).
///
/// WARNING: This class works best with InT & OutT that are trivial scalar
/// types. The input/output elements will be frequently copied and hashed.
template <typename InT, typename OutT>
class CyclicReplacerCache {
public:
  /// User-provided cycle-breaking function. It is handed an input that was
  /// looked up again while its own replacement is still in progress, and
  /// returns either a replacement that breaks the cycle or std::nullopt if it
  /// cannot break the cycle at this element.
  /// The cycle-breaking function must not make any more recursive invocations
  /// to this cached replacer.
  using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;

  CyclicReplacerCache() = delete;
  CyclicReplacerCache(CycleBreakerFn cycleBreaker)
      : cycleBreaker(std::move(cycleBreaker)) {}

  /// A possibly unresolved cache entry.
  /// If unresolved, the entry must be resolved before it goes out of scope.
  struct CacheEntry {
  public:
    /// Asserts that the entry was resolved before being destroyed.
    ~CacheEntry() { assert(result && "unresolved cache entry"); }

    /// Check whether this node was repeated during recursive replacements.
    /// This only makes sense to be called after all recursive replacements are
    /// completed and the current element has resurfaced to the top of the
    /// replacement stack.
    bool wasRepeated() const {
      // If the top frame includes itself as a dependency, then it must have
      // been repeated.
      ReplacementFrame &currFrame = cache.replacementStack.back();
      size_t currFrameIndex = cache.replacementStack.size() - 1;
      return currFrame.dependentFrames.count(currFrameIndex);
    }

    /// Resolve an unresolved cache entry by providing the result to be stored
    /// in the cache. Must be called at most once, and only on an entry that
    /// was returned unresolved.
    void resolve(OutT result) {
      assert(!this->result && "cache entry already resolved");
      // Register with the owning cache first (this pops the entry's frame),
      // then remember the value locally.
      cache.finalizeReplacement(element, result);
      this->result = std::move(result);
    }

    /// Get the resolved result if one exists.
    const std::optional<OutT> &get() const { return result; }

  private:
    friend class CyclicReplacerCache;
    CacheEntry() = delete;
    /// Only the owning cache creates entries: resolved ones carry `result`,
    /// unresolved ones leave it empty.
    CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
               std::optional<OutT> result = std::nullopt)
        : cache(cache), element(std::move(element)),
          result(std::move(result)) {}

    CyclicReplacerCache<InT, OutT> &cache;
    InT element;
    std::optional<OutT> result;
  };

  /// Lookup the cache for a pre-calculated replacement for `element`.
  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
  /// unresolved CacheEntry will be returned, and the caller must resolve it
  /// with the calculated replacement so it can be registered in the cache for
  /// future use.
  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
  /// CacheEntries that are returned must be resolved in reverse order of
  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
  /// the first retrieved CacheEntry must be resolved last. This should be
  /// natural when used as a stack / inside recursion.
  CacheEntry lookupOrInit(InT element);

private:
  /// Register the replacement in the cache and update the replacementStack.
  void finalizeReplacement(InT element, OutT result);

  CycleBreakerFn cycleBreaker;
  /// Results that do not depend on any frame of the replacement stack.
  DenseMap<InT, OutT> standaloneCache;

  struct DependentReplacement {
    OutT replacement;
    /// The highest replacement frame index that this cache entry is dependent
    /// on.
    size_t highestDependentFrame;
  };
  /// Results that are only valid while the frame they depend on is still on
  /// the replacement stack.
  DenseMap<InT, DependentReplacement> dependentCache;

  struct ReplacementFrame {
    /// The set of elements that is only legal while under this current frame.
    /// They need to be removed from the cache when this frame is popped off the
    /// replacement stack.
    DenseSet<InT> dependingReplacements;
    /// The set of frame indices that this current frame's replacement is
    /// dependent on, ordered from highest to lowest.
    std::set<size_t, std::greater<size_t>> dependentFrames;
  };
  /// Every element currently in the progress of being replaced pushes a frame
  /// onto this stack.
  SmallVector<ReplacementFrame> replacementStack;
  /// Maps from each input element to its indices on the replacement stack.
  DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
  /// If set to true, we are currently asking an element to break a cycle. No
  /// more recursive invocations are allowed while this is true (the
  /// replacement stack can no longer grow).
  bool resolvingCycle = false;
};

/// Returns a resolved entry when `element` already has a usable cached
/// replacement (standalone, dependent, or freshly produced by the cycle
/// breaker); otherwise pushes a new replacement frame for `element` and returns
/// an unresolved entry.
template <typename InT, typename OutT>
typename CyclicReplacerCache<InT, OutT>::CacheEntry
CyclicReplacerCache<InT, OutT>::lookupOrInit(InT element) {
  assert(!resolvingCycle &&
         "illegal recursive invocation while breaking cycle");

  // Standalone results are consulted first.
  auto standaloneIt = standaloneCache.find(element);
  if (standaloneIt != standaloneCache.end())
    return CacheEntry(*this, element, standaloneIt->second);

  auto dependentIt = dependentCache.find(element);
  if (dependentIt != dependentCache.end()) {
    // The element that triggered this lookup owns the top frame; it inherits
    // whatever the cached entry depended on.
    const auto &cached = dependentIt->second;
    replacementStack.back().dependentFrames.insert(
        cached.highestDependentFrame);
    return CacheEntry(*this, element, cached.replacement);
  }

  auto emplaced = cyclicElementFrame.try_emplace(element);
  auto &frameIndices = emplaced.first->second;
  const bool isRepeat = !emplaced.second;
  if (isRepeat) {
    // `element` is already being replaced further down the stack: give the
    // user a chance to break the cycle right here.
    resolvingCycle = true;
    std::optional<OutT> pruned = cycleBreaker(element);
    resolvingCycle = false;

    if (pruned.has_value()) {
      // The pruned result is tied to the most recent frame of this element.
      const size_t owningFrame = frameIndices.back();
      dependentCache[element] = {*pruned, owningFrame};
      // No frame was pushed for this repeat, so there is nothing to pop later;
      // just record that the top frame now depends on the owning frame.
      replacementStack.back().dependentFrames.insert(owningFrame);
      return CacheEntry(*this, element, *pruned);
    }

    // The cycle was not broken at this element. A legal setup guarantees that
    // some element of every cycle can break it, which bounds how often one
    // element may show up again; exceeding that bound means the setup is
    // illegal.
    assert(frameIndices.size() <= 2 && "illegal 3rd repeat of input");
  }

  // Either a first sighting, or a repeat that could not break the cycle: open
  // a new frame for this element.
  frameIndices.push_back(replacementStack.size());
  replacementStack.emplace_back();
  return CacheEntry(*this, element);
}

/// Stores `result` as the replacement for `element`, pops the element's frame
/// off the replacement stack, and purges cached results that were only valid
/// while that frame was live.
template <typename InT, typename OutT>
void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
                                                         OutT result) {
  const size_t topIndex = replacementStack.size() - 1;
  ReplacementFrame &top = replacementStack.back();
  // The frame being closed no longer counts as one of its own dependencies.
  top.dependentFrames.erase(topIndex);

  const bool isOutermost = replacementStack.size() == 1;
  if (isOutermost) {
    // With no frame underneath, nothing can be left to depend on.
    assert(top.dependentFrames.empty() &&
           "internal error: top-level dependent replacement");
  }

  if (isOutermost || top.dependentFrames.empty()) {
    // Context-free result.
    standaloneCache[element] = result;
  } else {
    // Context-dependent result; the set is ordered highest-first, so its first
    // entry is the highest frame index this result depends on.
    const size_t highestFrame = *top.dependentFrames.begin();
    dependentCache[element] = {result, highestFrame};

    // The frame directly below inherits every remaining dependency.
    ReplacementFrame &parent = replacementStack[topIndex - 1];
    parent.dependentFrames.insert(top.dependentFrames.begin(),
                                  top.dependentFrames.end());

    // Let the highest dependent frame know it must purge this entry when it is
    // popped.
    replacementStack[highestFrame].dependingReplacements.insert(element);
  }

  // Everything that was only valid under the closing frame goes away now.
  for (InT stale : top.dependingReplacements)
    dependentCache.erase(stale);

  replacementStack.pop_back();

  // Drop this frame's index from the element's bookkeeping, and the element
  // itself once none of its frames remain on the stack.
  auto frameIt = cyclicElementFrame.find(element);
  auto &indices = frameIt->second;
  indices.pop_back();
  if (indices.empty())
    cyclicElementFrame.erase(frameIt);
}

//===----------------------------------------------------------------------===//
// CachedCyclicReplacer
//===----------------------------------------------------------------------===//

/// A helper class for cases where the input/output types of the replacer
/// function is identical to the types stored in the cache. This class wraps
/// the user-provided replacer function, and can be used in place of the user
/// function.
template <typename InT, typename OutT>
class CachedCyclicReplacer {
public:
  /// The wrapped replacement function.
  using ReplacerFn = std::function<OutT(InT)>;
  /// Same cycle-breaker signature as the underlying cache.
  using CycleBreakerFn =
      typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;

  CachedCyclicReplacer() = delete;
  CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}

  /// Returns the cached replacement for `element` when the cache can supply
  /// one; otherwise runs the wrapped replacer, records its result in the
  /// cache, and returns it.
  OutT operator()(InT element) {
    auto entry = cache.lookupOrInit(element);
    const std::optional<OutT> &cached = entry.get();
    if (cached.has_value())
      return *cached;

    // Cache miss: compute the value and resolve the pending entry with it.
    OutT computed = replacer(element);
    entry.resolve(computed);
    return computed;
  }

private:
  ReplacerFn replacer;
  CyclicReplacerCache<InT, OutT> cache;
};

} // namespace mlir

#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
1 change: 1 addition & 0 deletions mlir/unittests/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Sources passed to add_mlir_unittest for MLIRSupportTests, listed
# alphabetically.
add_mlir_unittest(MLIRSupportTests
CyclicReplacerCacheTest.cpp
IndentedOstreamTest.cpp
StorageUniquerTest.cpp
)
Expand Down
478 changes: 478 additions & 0 deletions mlir/unittests/Support/CyclicReplacerCacheTest.cpp

Large diffs are not rendered by default.