40 changes: 38 additions & 2 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {

unsigned Value::getNumUses() const { return range_size(Val->users()); }

void Value::replaceUsesWithIf(
Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
assert(getType() == OtherV->getType() && "Can't replace with different type");
llvm::Value *OtherVal = OtherV->Val;
Val->replaceUsesWithIf(
OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
if (DstU == nullptr)
return false;
return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
});
}

void Value::replaceAllUsesWith(Value *Other) {
assert(getType() == Other->getType() &&
"Replacing with Value of different type!");
Val->replaceAllUsesWith(Other->Val);
}

#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
Expand Down Expand Up @@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}

#ifndef NDEBUG
void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
assert(Ctx.getValue(Use.getUser()) == this &&
"Use not found in this SBUser's operands!");
}
#endif

bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
Expand All @@ -180,6 +206,15 @@ bool User::classof(const Value *From) {
}
}

void User::setOperand(unsigned OperandIdx, Value *Operand) {
assert(isa<llvm::User>(Val) && "No operands!");
cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
}

bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
}

#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
Expand Down Expand Up @@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();

if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
return It->second.get();
return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
unsigned PassthroughArgSize =
(F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
assert(ArgTranslations.size() == F->isVarArg() ? 5 : PassthroughArgSize);
assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));

// Translate arguments to call.
SmallVector<Value *> Args;
Expand Down
17 changes: 12 additions & 5 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1890,11 +1890,18 @@ ParseStatus RISCVAsmParser::parseCSRSystemRegister(OperandVector &Operands) {
if (CE) {
int64_t Imm = CE->getValue();
if (isUInt<12>(Imm)) {
auto SysReg = RISCVSysReg::lookupSysRegByEncoding(Imm);
// Accept an immediate representing a named or un-named Sys Reg
// if the range is valid, regardless of the required features.
Operands.push_back(
RISCVOperand::createSysReg(SysReg ? SysReg->Name : "", S, Imm));
auto Range = RISCVSysReg::lookupSysRegByEncoding(Imm);
// Accept an immediate representing a named Sys Reg if it satisfies the
// the required features.
for (auto &Reg : Range) {
if (Reg.haveRequiredFeatures(STI->getFeatureBits())) {
Operands.push_back(RISCVOperand::createSysReg(Reg.Name, S, Imm));
return ParseStatus::Success;
}
}
// Accept an immediate representing an un-named Sys Reg if the range is
// valid, regardless of the required features.
Operands.push_back(RISCVOperand::createSysReg("", S, Imm));
return ParseStatus::Success;
}
}
Expand Down
13 changes: 8 additions & 5 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ void RISCVInstPrinter::printCSRSystemRegister(const MCInst *MI, unsigned OpNo,
const MCSubtargetInfo &STI,
raw_ostream &O) {
unsigned Imm = MI->getOperand(OpNo).getImm();
auto SysReg = RISCVSysReg::lookupSysRegByEncoding(Imm);
if (SysReg && SysReg->haveRequiredFeatures(STI.getFeatureBits()))
markup(O, Markup::Register) << SysReg->Name;
else
markup(O, Markup::Register) << formatImm(Imm);
auto Range = RISCVSysReg::lookupSysRegByEncoding(Imm);
for (auto &Reg : Range) {
if (Reg.haveRequiredFeatures(STI.getFeatureBits())) {
markup(O, Markup::Register) << Reg.Name;
return;
}
}
markup(O, Markup::Register) << formatImm(Imm);
}

void RISCVInstPrinter::printFenceArg(const MCInst *MI, unsigned OpNo,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVSystemOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def SysRegsList : GenericTable {

let PrimaryKey = [ "Encoding" ];
let PrimaryKeyName = "lookupSysRegByEncoding";
let PrimaryKeyReturnRange = true;
}

def lookupSysRegByName : SearchIndex {
Expand Down
28 changes: 13 additions & 15 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) {
}

CoroBegin->eraseFromParent();
Shape.CoroBegin = nullptr;
}

// SimplifySuspendPoint needs to check that there is no calls between
Expand Down Expand Up @@ -1970,9 +1971,17 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
}

/// Remove calls to llvm.coro.end in the original function.
static void removeCoroEnds(const coro::Shape &Shape) {
for (auto *End : Shape.CoroEnds) {
replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr);
static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) {
if (Shape.ABI != coro::ABI::Switch) {
for (auto *End : Shape.CoroEnds) {
replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr);
}
} else {
for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
auto &Context = End->getContext();
End->replaceAllUsesWith(ConstantInt::getFalse(Context));
End->eraseFromParent();
}
}
}

Expand All @@ -1981,18 +1990,6 @@ static void updateCallGraphAfterCoroutineSplit(
const SmallVectorImpl<Function *> &Clones, LazyCallGraph::SCC &C,
LazyCallGraph &CG, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR,
FunctionAnalysisManager &FAM) {
if (!Shape.CoroBegin)
return;

if (Shape.ABI != coro::ABI::Switch)
removeCoroEnds(Shape);
else {
for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
auto &Context = End->getContext();
End->replaceAllUsesWith(ConstantInt::getFalse(Context));
End->eraseFromParent();
}
}

if (!Clones.empty()) {
switch (Shape.ABI) {
Expand Down Expand Up @@ -2120,6 +2117,7 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
const coro::Shape Shape =
splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
OptimizeFrame, MaterializableCallback);
removeCoroEndsFromRampFunction(Shape);
updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);

ORE.emit([&]() {
Expand Down
22 changes: 0 additions & 22 deletions llvm/test/MC/AsmParser/altmacro-arg.s

This file was deleted.

123 changes: 123 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
Instruction *LLVMRet = &*LLVMBBIt++;
Argument *LLVMArg0 = LLVMF.getArg(0);
Argument *LLVMArg1 = LLVMF.getArg(1);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
Expand Down Expand Up @@ -203,6 +206,126 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));

// Check User.setOperand().
Ret->setOperand(0, Arg0);
EXPECT_EQ(Ret->getOperand(0), Arg0);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);

Ret->setOperand(0, Arg1);
EXPECT_EQ(Ret->getOperand(0), Arg1);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
}

TEST_F(SandboxIRTest, RUOW) {
parseIR(C, R"IR(
declare void @bar0()
declare void @bar1()
@glob0 = global ptr @bar0
@glob1 = global ptr @bar1
define i32 @foo(i32 %arg0, i32 %arg1) {
%add0 = add i32 %arg0, %arg1
%gep1 = getelementptr i8, ptr @glob0, i32 1
%gep2 = getelementptr i8, ptr @glob1, i32 1
ret i32 %add0
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto It = BB.begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *I2 = &*It++;
auto *Ret = &*It++;

bool Replaced;
// Try to replace an operand that doesn't match.
Replaced = I0->replaceUsesOfWith(Ret, Arg1);
EXPECT_FALSE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands differ.
Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg1);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands are the same.
Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg0);

// Replace Ret operand.
Replaced = Ret->replaceUsesOfWith(I0, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(Ret->getOperand(0), Arg0);

// Check RAUW on constant.
auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
auto *Glob0Op = Glob0->getOperand(0);
Glob0->replaceUsesOfWith(Glob0Op, Glob1);
EXPECT_EQ(Glob0->getOperand(0), Glob1);
}

TEST_F(SandboxIRTest, RAUW_RUWIf) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%ld0 = load float, ptr %ptr
%ld1 = load float, ptr %ptr
store float %ld0, ptr %ptr
store float %ld0, ptr %ptr
ret void
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
llvm::BasicBlock *LLVMBB = &*LLVMF.begin();

Ctx.createFunction(&LLVMF);
auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
auto It = BB->begin();
sandboxir::Instruction *Ld0 = &*It++;
sandboxir::Instruction *Ld1 = &*It++;
sandboxir::Instruction *St0 = &*It++;
sandboxir::Instruction *St1 = &*It++;
// Check RUWIf when the lambda returns false.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
EXPECT_EQ(St0->getOperand(0), Ld0);
EXPECT_EQ(St1->getOperand(0), Ld0);
// Check RUWIf when the lambda returns true.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
EXPECT_EQ(St0->getOperand(0), Ld1);
EXPECT_EQ(St1->getOperand(0), Ld1);
St0->setOperand(0, Ld0);
St1->setOperand(0, Ld0);
// Check RUWIf user == St0.
Ld0->replaceUsesWithIf(
Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; });
EXPECT_EQ(St0->getOperand(0), Ld1);
EXPECT_EQ(St1->getOperand(0), Ld0);
St0->setOperand(0, Ld0);
// Check RUWIf user == St1.
Ld0->replaceUsesWithIf(
Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; });
EXPECT_EQ(St0->getOperand(0), Ld0);
EXPECT_EQ(St1->getOperand(0), Ld1);
St1->setOperand(0, Ld0);
// Check RAUW.
Ld1->replaceAllUsesWith(Ld0);
EXPECT_EQ(St0->getOperand(0), Ld0);
EXPECT_EQ(St1->getOperand(0), Ld0);
}

// Check that the operands/users are counted correctly.
Expand Down