From 52e72eecc9dcbfcfb819210866f63060eef41306 Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Wed, 10 Jul 2024 15:41:18 -0700 Subject: [PATCH 1/2] [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW This patch adds the following member functions: - User::setOperand() - User::replaceUsesOfWith() - Value::replaceAllUsesWith() - Value::replaceUsesWithIf() --- llvm/include/llvm/SandboxIR/SandboxIR.h | 15 +++ llvm/lib/SandboxIR/SandboxIR.cpp | 40 +++++++- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 104 +++++++++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 8e87470ee1e5c..317884fe07681 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -200,6 +200,7 @@ class Value { llvm::Value *Val = nullptr; friend class Context; // For getting `Val`. + friend class User; // For getting `Val`. /// All values point to the context. Context &Ctx; @@ -284,6 +285,11 @@ class Value { Type *getType() const { return Val->getType(); } Context &getContext() const { return Ctx; } + + void replaceUsesWithIf(Value *OtherV, + llvm::function_ref ShouldReplace); + void replaceAllUsesWith(Value *Other); + #ifndef NDEBUG /// Should crash if there is something wrong with the instruction. virtual void verify() const = 0; @@ -349,6 +355,10 @@ class User : public Value { virtual unsigned getUseOperandNo(const Use &Use) const = 0; friend unsigned Use::getOperandNo() const; // For getUseOperandNo() +#ifndef NDEBUG + void verifyUserOfLLVMUse(const llvm::Use &Use) const; +#endif // NDEBUG + public: /// For isa/dyn_cast. static bool classof(const Value *From); @@ -387,6 +397,11 @@ class User : public Value { return isa(Val) ? cast(Val)->getNumOperands() : 0; } + virtual void setOperand(unsigned OperandIdx, Value *Operand); + /// Replaces any operands that match \p FromV with \p ToV. Returns whether any + /// operands were replaced. + bool replaceUsesOfWith(Value *FromV, Value *ToV); + #ifndef NDEBUG void verify() const override { assert(isa(Val) && "Expected User!"); diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 160d807738a3c..b374cbff13696 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() { unsigned Value::getNumUses() const { return range_size(Val->users()); } +void Value::replaceUsesWithIf( + Value *OtherV, llvm::function_ref ShouldReplace) { + assert(getType() == OtherV->getType() && "Can't replace with different type"); + llvm::Value *OtherVal = OtherV->Val; + Val->replaceUsesWithIf( + OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool { + User *DstU = cast_or_null(Ctx.getValue(LLVMUse.getUser())); + if (DstU == nullptr) + return false; + return ShouldReplace(Use(&LLVMUse, DstU, Ctx)); + }); +} + +void Value::replaceAllUsesWith(Value *Other) { + assert(getType() == Other->getType() && + "Replacing with Value of different type!"); + Val->replaceAllUsesWith(Other->Val); +} + #ifndef NDEBUG std::string Value::getName() const { std::stringstream SS; @@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const { return Use(LLVMUse, const_cast(this), Ctx); } +#ifndef NDEBUG +void User::verifyUserOfLLVMUse(const llvm::Use &Use) const { + assert(Ctx.getValue(Use.getUser()) == this && + "Use not found in this SBUser's operands!"); +} +#endif + bool User::classof(const Value *From) { switch (From->getSubclassID()) { #define DEF_VALUE(ID, CLASS) @@ -180,6 +206,15 @@ bool User::classof(const Value *From) { } } +void User::setOperand(unsigned OperandIdx, Value *Operand) { + assert(isa(Val) && "No operands!"); + cast(Val)->setOperand(OperandIdx, Operand->Val); +} + +bool User::replaceUsesOfWith(Value *FromV, Value *ToV) { + return cast(Val)->replaceUsesOfWith(FromV->Val, ToV->Val); +} + #ifndef NDEBUG void User::dumpCommonHeader(raw_ostream &OS) const { Value::dumpCommonHeader(OS); @@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { return It->second.get(); if (auto *C = dyn_cast(LLVMV)) { + It->second = std::unique_ptr(new Constant(C, *this)); + auto *NewC = It->second.get(); for (llvm::Value *COp : C->operands()) getOrCreateValueInternal(COp, C); - It->second = std::unique_ptr(new Constant(C, *this)); - return It->second.get(); + return NewC; } if (auto *Arg = dyn_cast(LLVMV)) { It->second = std::unique_ptr(new Argument(Arg, *this)); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 72e81bf640350..16e537efba5de 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) { BasicBlock *LLVMBB = &*LLVMF.begin(); auto LLVMBBIt = LLVMBB->begin(); Instruction *LLVMI0 = &*LLVMBBIt++; + Instruction *LLVMRet = &*LLVMBBIt++; + Argument *LLVMArg0 = LLVMF.getArg(0); + Argument *LLVMArg1 = LLVMF.getArg(1); auto &F = *Ctx.createFunction(&LLVMF); auto &BB = *F.begin(); @@ -203,6 +206,107 @@ OperandNo: 0 EXPECT_FALSE(I0->hasNUses(0u)); EXPECT_TRUE(I0->hasNUses(1u)); EXPECT_FALSE(I0->hasNUses(2u)); + + // Check User.setOperand(). + Ret->setOperand(0, Arg0); + EXPECT_EQ(Ret->getOperand(0), Arg0); + EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0); + EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0); + + Ret->setOperand(0, Arg1); + EXPECT_EQ(Ret->getOperand(0), Arg1); + EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1); + EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1); +} + +TEST_F(SandboxIRTest, RUOW) { + parseIR(C, R"IR( +declare void @bar0() +declare void @bar1() + +@glob0 = global ptr @bar0 +@glob1 = global ptr @bar1 + +define i32 @foo(i32 %v0, i32 %v1) { + %add0 = add i32 %v0, %v1 + %gep1 = getelementptr i8, ptr @glob0, i32 1 + %gep2 = getelementptr i8, ptr @glob1, i32 1 + ret i32 %add0 +} +)IR"); + llvm::Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + auto *Arg0 = F.getArg(0); + auto *Arg1 = F.getArg(1); + auto It = BB.begin(); + auto *I0 = &*It++; + auto *I1 = &*It++; + auto *I2 = &*It++; + auto *Ret = &*It++; + + bool Replaced; + // Try to replace an operand that doesn't match. + Replaced = I0->replaceUsesOfWith(Ret, Arg1); + EXPECT_FALSE(Replaced); + EXPECT_EQ(I0->getOperand(0), Arg0); + EXPECT_EQ(I0->getOperand(1), Arg1); + + // Replace I0 operands when operands differ. + Replaced = I0->replaceUsesOfWith(Arg0, Arg1); + EXPECT_TRUE(Replaced); + EXPECT_EQ(I0->getOperand(0), Arg1); + EXPECT_EQ(I0->getOperand(1), Arg1); + + // Replace I0 operands when operands are the same. + Replaced = I0->replaceUsesOfWith(Arg1, Arg0); + EXPECT_TRUE(Replaced); + EXPECT_EQ(I0->getOperand(0), Arg0); + EXPECT_EQ(I0->getOperand(1), Arg0); + + // Replace Ret operand. + Replaced = Ret->replaceUsesOfWith(I0, Arg0); + EXPECT_TRUE(Replaced); + EXPECT_EQ(Ret->getOperand(0), Arg0); + + // Check RAUW on constant. + auto *Glob0 = cast(I1->getOperand(0)); + auto *Glob1 = cast(I2->getOperand(0)); + auto *Glob0Op = Glob0->getOperand(0); + Glob0->replaceUsesOfWith(Glob0Op, Glob1); + EXPECT_EQ(Glob0->getOperand(0), Glob1); +} + +TEST_F(SandboxIRTest, RAW_RUWIf) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %ld0 = load float, ptr %ptr + %ld1 = load float, ptr %ptr + store float %ld0, ptr %ptr + ret void +} +)IR"); + llvm::Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin(); + + Ctx.createFunction(&LLVMF); + auto *BB0 = cast(Ctx.getValue(LLVMBB0)); + auto It = BB0->begin(); + sandboxir::Instruction *Ld0 = &*It++; + sandboxir::Instruction *Ld1 = &*It++; + sandboxir::Instruction *St0 = &*It++; + // Check RUWIf when the lambda returns false. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; }); + EXPECT_EQ(St0->getOperand(0), Ld0); + // Check RUWIf when the lambda returns true. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; }); + EXPECT_EQ(St0->getOperand(0), Ld1); + // Check RAUW. + Ld1->replaceAllUsesWith(Ld0); + EXPECT_EQ(St0->getOperand(0), Ld0); } // Check that the operands/users are counted correctly. From abe63b20b600b8fc76c029acd60bd158bc746c42 Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Thu, 11 Jul 2024 13:28:15 -0700 Subject: [PATCH 2/2] fixup! [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW --- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 31 +++++++++++++++++----- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 16e537efba5de..98c0052d878d8 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -227,8 +227,8 @@ declare void @bar1() @glob0 = global ptr @bar0 @glob1 = global ptr @bar1 -define i32 @foo(i32 %v0, i32 %v1) { - %add0 = add i32 %v0, %v1 +define i32 @foo(i32 %arg0, i32 %arg1) { + %add0 = add i32 %arg0, %arg1 %gep1 = getelementptr i8, ptr @glob0, i32 1 %gep2 = getelementptr i8, ptr @glob1, i32 1 ret i32 %add0 @@ -279,34 +279,53 @@ define i32 @foo(i32 %v0, i32 %v1) { EXPECT_EQ(Glob0->getOperand(0), Glob1); } -TEST_F(SandboxIRTest, RAW_RUWIf) { +TEST_F(SandboxIRTest, RAUW_RUWIf) { parseIR(C, R"IR( define void @foo(ptr %ptr) { %ld0 = load float, ptr %ptr %ld1 = load float, ptr %ptr store float %ld0, ptr %ptr + store float %ld0, ptr %ptr ret void } )IR"); llvm::Function &LLVMF = *M->getFunction("foo"); sandboxir::Context Ctx(C); - llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin(); + llvm::BasicBlock *LLVMBB = &*LLVMF.begin(); Ctx.createFunction(&LLVMF); - auto *BB0 = cast(Ctx.getValue(LLVMBB0)); - auto It = BB0->begin(); + auto *BB = cast(Ctx.getValue(LLVMBB)); + auto It = BB->begin(); sandboxir::Instruction *Ld0 = &*It++; sandboxir::Instruction *Ld1 = &*It++; sandboxir::Instruction *St0 = &*It++; + sandboxir::Instruction *St1 = &*It++; // Check RUWIf when the lambda returns false. Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; }); EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); // Check RUWIf when the lambda returns true. Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; }); EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld1); + St0->setOperand(0, Ld0); + St1->setOperand(0, Ld0); + // Check RUWIf user == St0. + Ld0->replaceUsesWithIf( + Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; }); + EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld0); + St0->setOperand(0, Ld0); + // Check RUWIf user == St1. + Ld0->replaceUsesWithIf( + Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; }); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld1); + St1->setOperand(0, Ld0); // Check RAUW. Ld1->replaceAllUsesWith(Ld0); EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); } // Check that the operands/users are counted correctly.