Skip to content

Commit

Permalink
[SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW
Browse files Browse the repository at this point in the history
This patch adds the following member functions:
- User::setOperand()
- User::replaceUsesOfWith()
- Value::replaceAllUsesWith()
- Value::replaceUsesWithIf()
  • Loading branch information
vporpo committed Jul 10, 2024
1 parent 797a2ec commit 044c30d
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 2 deletions.
22 changes: 22 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class Value {
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class ValueAttorney;

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -284,6 +285,11 @@ class Value {
Type *getType() const { return Val->getType(); }

Context &getContext() const { return Ctx; }

void replaceUsesWithIf(Value *OtherV,
llvm::function_ref<bool(const Use &)> ShouldReplace);
void replaceAllUsesWith(Value *Other);

#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
Expand All @@ -303,6 +309,13 @@ class Value {
#endif
};

/// Helper Attorney-Client class that gives access to the underlying IR.
class ValueAttorney {
private:
static llvm::Value *getValue(const Value *SBV) { return SBV->Val; }
friend class User;
};

/// Argument of a sandboxir::Function.
class Argument : public sandboxir::Value {
Argument(llvm::Argument *Arg, sandboxir::Context &Ctx)
Expand Down Expand Up @@ -349,6 +362,10 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()

#ifndef NDEBUG
void verifyUserOfLLVMUse(const llvm::Use &Use) const;
#endif // NDEBUG

public:
/// For isa/dyn_cast.
static bool classof(const Value *From);
Expand Down Expand Up @@ -387,6 +404,11 @@ class User : public Value {
return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
}

virtual void setOperand(unsigned OperandIdx, Value *Operand);
/// Replaces any operands that match \p FromV with \p ToV. Returns whether any
/// operands were replaced.
bool replaceUsesOfWith(Value *FromV, Value *ToV);

#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::User>(Val) && "Expected User!");
Expand Down
44 changes: 42 additions & 2 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {

unsigned Value::getNumUses() const { return range_size(Val->users()); }

void Value::replaceUsesWithIf(
Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
assert(getType() == OtherV->getType() && "Can't replace with different type");
llvm::Value *OtherVal = OtherV->Val;
Val->replaceUsesWithIf(
OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
if (DstU == nullptr)
return false;
return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
});
}

void Value::replaceAllUsesWith(Value *Other) {
assert(getType() == Other->getType() &&
"Replacing with Value of different type!");
Val->replaceAllUsesWith(Other->Val);
}

#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
Expand Down Expand Up @@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}

#ifndef NDEBUG
void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
assert(Ctx.getValue(Use.getUser()) == this &&
"Use not found in this SBUser's operands!");
}
#endif

bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
Expand All @@ -180,6 +206,19 @@ bool User::classof(const Value *From) {
}
}

void User::setOperand(unsigned OperandIdx, Value *Operand) {
if (!isa<llvm::User>(Val))
llvm_unreachable("No operands!");
cast<llvm::User>(Val)->setOperand(OperandIdx,
ValueAttorney::getValue(Operand));
}

bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
llvm::Value *FromLLVM = ValueAttorney::getValue(FromV);
llvm::Value *ToLLVM = ValueAttorney::getValue(ToV);
return cast<llvm::User>(Val)->replaceUsesOfWith(FromLLVM, ToLLVM);
}

#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
Expand Down Expand Up @@ -325,10 +364,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();

if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
return It->second.get();
return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
Expand Down
104 changes: 104 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
Instruction *LLVMRet = &*LLVMBBIt++;
Argument *LLVMArg0 = LLVMF.getArg(0);
Argument *LLVMArg1 = LLVMF.getArg(1);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
Expand Down Expand Up @@ -203,6 +206,107 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));

// Check User.setOperand().
Ret->setOperand(0, Arg0);
EXPECT_EQ(Ret->getOperand(0), Arg0);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);

Ret->setOperand(0, Arg1);
EXPECT_EQ(Ret->getOperand(0), Arg1);
EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
}

TEST_F(SandboxIRTest, RUOW) {
parseIR(C, R"IR(
declare void @bar0()
declare void @bar1()
@glob0 = global ptr @bar0
@glob1 = global ptr @bar1
define i32 @foo(i32 %v0, i32 %v1) {
%add0 = add i32 %v0, %v1
%gep1 = getelementptr i8, ptr @glob0, i32 1
%gep2 = getelementptr i8, ptr @glob1, i32 1
ret i32 %add0
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto It = BB.begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *I2 = &*It++;
auto *Ret = &*It++;

bool Replaced;
// Try to replace an operand that doesn't match.
Replaced = I0->replaceUsesOfWith(Ret, Arg1);
EXPECT_FALSE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands differ.
Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg1);
EXPECT_EQ(I0->getOperand(1), Arg1);

// Replace I0 operands when operands are the same.
Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(I0->getOperand(0), Arg0);
EXPECT_EQ(I0->getOperand(1), Arg0);

// Replace Ret operand.
Replaced = Ret->replaceUsesOfWith(I0, Arg0);
EXPECT_TRUE(Replaced);
EXPECT_EQ(Ret->getOperand(0), Arg0);

// Check RAUW on constant.
auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
auto *Glob0Op = Glob0->getOperand(0);
Glob0->replaceUsesOfWith(Glob0Op, Glob1);
EXPECT_EQ(Glob0->getOperand(0), Glob1);
}

TEST_F(SandboxIRTest, RAW_RUWIf) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%ld0 = load float, ptr %ptr
%ld1 = load float, ptr %ptr
store float %ld0, ptr %ptr
ret void
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin();

Ctx.createFunction(&LLVMF);
auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
auto It = BB0->begin();
sandboxir::Instruction *Ld0 = &*It++;
sandboxir::Instruction *Ld1 = &*It++;
sandboxir::Instruction *St0 = &*It++;
// Check RUWIf when the lambda returns false.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
EXPECT_EQ(St0->getOperand(0), Ld0);
// Check RUWIf when the lambda returns true.
Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
EXPECT_EQ(St0->getOperand(0), Ld1);
// Check RAUW.
Ld1->replaceAllUsesWith(Ld0);
EXPECT_EQ(St0->getOperand(0), Ld0);
}

// Check that the operands/users are counted correctly.
Expand Down

0 comments on commit 044c30d

Please sign in to comment.