Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW #98410

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class Value {
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -284,6 +285,11 @@ class Value {
Type *getType() const { return Val->getType(); }

Context &getContext() const { return Ctx; }

void replaceUsesWithIf(Value *OtherV,
llvm::function_ref<bool(const Use &)> ShouldReplace);
void replaceAllUsesWith(Value *Other);

#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
Expand Down Expand Up @@ -349,6 +355,10 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()

#ifndef NDEBUG
void verifyUserOfLLVMUse(const llvm::Use &Use) const;
#endif // NDEBUG

public:
/// For isa/dyn_cast.
static bool classof(const Value *From);
Expand Down Expand Up @@ -387,6 +397,11 @@ class User : public Value {
return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
}

virtual void setOperand(unsigned OperandIdx, Value *Operand);
/// Replaces any operands that match \p FromV with \p ToV. Returns whether any
/// operands were replaced.
bool replaceUsesOfWith(Value *FromV, Value *ToV);

#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::User>(Val) && "Expected User!");
Expand Down
40 changes: 38 additions & 2 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {

unsigned Value::getNumUses() const { return range_size(Val->users()); }

void Value::replaceUsesWithIf(
Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
assert(getType() == OtherV->getType() && "Can't replace with different type");
llvm::Value *OtherVal = OtherV->Val;
Val->replaceUsesWithIf(
OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
if (DstU == nullptr)
return false;
return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
});
}

void Value::replaceAllUsesWith(Value *Other) {
assert(getType() == Other->getType() &&
"Replacing with Value of different type!");
Val->replaceAllUsesWith(Other->Val);
}

#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
Expand Down Expand Up @@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}

#ifndef NDEBUG
void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
assert(Ctx.getValue(Use.getUser()) == this &&
"Use not found in this SBUser's operands!");
}
#endif

bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
Expand All @@ -180,6 +206,15 @@ bool User::classof(const Value *From) {
}
}

void User::setOperand(unsigned OperandIdx, Value *Operand) {
assert(isa<llvm::User>(Val) && "No operands!");
cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
}

bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
}

#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
Expand Down Expand Up @@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();

if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
return It->second.get();
return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
Expand Down
104 changes: 104 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
Instruction *LLVMRet = &*LLVMBBIt++;
Argument *LLVMArg0 = LLVMF.getArg(0);
Argument *LLVMArg1 = LLVMF.getArg(1);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
Expand Down Expand Up @@ -203,6 +206,107 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));

// Check User.setOperand().
Ret->setOperand(0, Arg0);
EXPECT_EQ(Ret->getOperand(0), Arg0);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);

Ret->setOperand(0, Arg1);
EXPECT_EQ(Ret->getOperand(0), Arg1);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
}

TEST_F(SandboxIRTest, RUOW) {
parseIR(C, R"IR(
declare void @bar0()
declare void @bar1()

@glob0 = global ptr @bar0
@glob1 = global ptr @bar1

define i32 @foo(i32 %v0, i32 %v1) {
vporpo marked this conversation as resolved.
Show resolved Hide resolved
%add0 = add i32 %v0, %v1
%gep1 = getelementptr i8, ptr @glob0, i32 1
%gep2 = getelementptr i8, ptr @glob1, i32 1
ret i32 %add0
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto It = BB.begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *I2 = &*It++;
auto *Ret = &*It++;

bool Replaced;
// Try to replace an operand that doesn't match.
Replaced = I0->replaceUsesOfWith(Ret, Arg1);
EXPECT_FALSE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands differ.
Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg1);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands are the same.
Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg0);

// Replace Ret operand.
Replaced = Ret->replaceUsesOfWith(I0, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(Ret->getOperand(0), Arg0);

// Check RAUW on constant.
auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
auto *Glob0Op = Glob0->getOperand(0);
Glob0->replaceUsesOfWith(Glob0Op, Glob1);
EXPECT_EQ(Glob0->getOperand(0), Glob1);
}

TEST_F(SandboxIRTest, RAW_RUWIf) {
vporpo marked this conversation as resolved.
Show resolved Hide resolved
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%ld0 = load float, ptr %ptr
%ld1 = load float, ptr %ptr
store float %ld0, ptr %ptr
ret void
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin();

Ctx.createFunction(&LLVMF);
auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
auto It = BB0->begin();
sandboxir::Instruction *Ld0 = &*It++;
sandboxir::Instruction *Ld1 = &*It++;
sandboxir::Instruction *St0 = &*It++;
// Check RUWIf when the lambda returns false.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
EXPECT_EQ(St0->getOperand(0), Ld0);
// Check RUWIf when the lambda returns true.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
vporpo marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(St0->getOperand(0), Ld1);
// Check RAUW.
Ld1->replaceAllUsesWith(Ld0);
EXPECT_EQ(St0->getOperand(0), Ld0);
}

// Check that the operands/users are counted correctly.
Expand Down
Loading