diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 7e0b88ae7197d..a52368b5704a1 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -25,7 +25,11 @@ namespace llvm::sandboxir {
 class BottomUpVec final : public FunctionPass {
   bool Change = false;
   std::unique_ptr<LegalityAnalysis> Legality;
-  void vectorizeRec(ArrayRef<Value *> Bndl);
+
+  /// Creates and returns a vector instruction that replaces the instructions
+  /// in \p Bndl. \p Operands are the already vectorized operands.
+  Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands);
+  Value *vectorizeRec(ArrayRef<Value *> Bndl);
   void tryVectorize(ArrayRef<Value *> Seeds);
 
   // The PM containing the pipeline of region passes.
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 8b64ec58da345..85229150de2b6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -66,6 +66,40 @@ class VecUtils {
     }
     return true;
   }
+
+  /// \Returns the number of vector lanes of \p Ty or 1 if not a vector.
+  /// NOTE: It asserts that \p Ty is not a scalable vector type.
+  static unsigned getNumLanes(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty) && "Expect scalar or fixed vector");
+    if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty))
+      return FixedVecTy->getNumElements();
+    return 1u;
+  }
+
+  /// \Returns the expected vector lanes of \p V or 1 if not a vector.
+  /// NOTE: It asserts that \p V is not a scalable vector.
+  static unsigned getNumLanes(Value *V) {
+    return VecUtils::getNumLanes(Utils::getExpectedType(V));
+  }
+
+  /// \Returns the total number of lanes across all values in \p Bndl.
+  static unsigned getNumLanes(ArrayRef<Value *> Bndl) {
+    unsigned Lanes = 0;
+    for (Value *V : Bndl)
+      Lanes += getNumLanes(V);
+    return Lanes;
+  }
+
+  /// \Returns <NumElts x ElemTy>.
+  /// It works for both scalar and vector \p ElemTy.
+  static Type *getWideType(Type *ElemTy, unsigned NumElts) {
+    if (ElemTy->isVectorTy()) {
+      auto *VecTy = cast<FixedVectorType>(ElemTy);
+      ElemTy = VecTy->getElementType();
+      NumElts = VecTy->getNumElements() * NumElts;
+    }
+    return FixedVectorType::get(ElemTy, NumElts);
+  }
 };
 
 } // namespace llvm::sandboxir
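The lane accounting above is the core of the VecUtils change. The stand-alone sketch below is illustrative only and is not part of the patch; it exercises the two new helpers in the same way the VecUtilsTest additions at the end of this patch do (the function name laneMathSketch is hypothetical):

// Illustrative sketch only -- not part of this patch.
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Type.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace sbir = llvm::sandboxir;

static void laneMathSketch(sbir::Context &Ctx) {
  auto *Int32Ty = sbir::Type::getInt32Ty(Ctx);
  auto *Int32X2Ty = sbir::FixedVectorType::get(Int32Ty, 2);
  // Scalars count as a single lane; fixed vectors count their elements.
  unsigned ScalarLanes = sbir::VecUtils::getNumLanes(Int32Ty);   // 1
  unsigned VectorLanes = sbir::VecUtils::getNumLanes(Int32X2Ty); // 2
  // getWideType() flattens a vector ElemTy: a bundle of 3 x <2 x i32>
  // widens to <6 x i32>, not to a vector of vectors.
  auto *WideTy = sbir::VecUtils::getWideType(Int32X2Ty, 3);      // <6 x i32>
  (void)ScalarLanes;
  (void)VectorLanes;
  (void)WideTy;
}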
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 005d2241430ff..8b36ce57e2ae8 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -7,12 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
-
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/SandboxIR/Function.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Module.h"
+#include "llvm/SandboxIR/Utils.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 
 namespace llvm::sandboxir {
 
@@ -40,15 +41,149 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
   return Operands;
 }
 
-void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+static BasicBlock::iterator
+getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
+  // TODO: Use the VecUtils function for getting the bottom instr once it lands.
+  auto *BotI = cast<Instruction>(
+      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
+        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
+      }));
+  // If Bndl contains Arguments or Constants, use the beginning of the BB.
+  return std::next(BotI->getIterator());
+}
+
+Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
+                                      ArrayRef<Value *> Operands) {
+  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
+         "Expect Instructions!");
+  auto &Ctx = Bndl[0]->getContext();
+
+  Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
+  auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
+
+  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
+
+  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
+  switch (Opcode) {
+  case Instruction::Opcode::ZExt:
+  case Instruction::Opcode::SExt:
+  case Instruction::Opcode::FPToUI:
+  case Instruction::Opcode::FPToSI:
+  case Instruction::Opcode::FPExt:
+  case Instruction::Opcode::PtrToInt:
+  case Instruction::Opcode::IntToPtr:
+  case Instruction::Opcode::SIToFP:
+  case Instruction::Opcode::UIToFP:
+  case Instruction::Opcode::Trunc:
+  case Instruction::Opcode::FPTrunc:
+  case Instruction::Opcode::BitCast: {
+    assert(Operands.size() == 1u && "Casts are unary!");
+    return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
+  }
+  case Instruction::Opcode::FCmp:
+  case Instruction::Opcode::ICmp: {
+    auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
+    assert(all_of(drop_begin(Bndl),
+                  [Pred](auto *SBV) {
+                    return cast<CmpInst>(SBV)->getPredicate() == Pred;
+                  }) &&
+           "Expected same predicate across bundle.");
+    return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
+                           "VCmp");
+  }
+  case Instruction::Opcode::Select: {
+    return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
+                              Ctx, "Vec");
+  }
+  case Instruction::Opcode::FNeg: {
+    auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
+    auto OpC = UOp0->getOpcode();
+    return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
+                                                Ctx, "Vec");
+  }
+  case Instruction::Opcode::Add:
+  case Instruction::Opcode::FAdd:
+  case Instruction::Opcode::Sub:
+  case Instruction::Opcode::FSub:
+  case Instruction::Opcode::Mul:
+  case Instruction::Opcode::FMul:
+  case Instruction::Opcode::UDiv:
+  case Instruction::Opcode::SDiv:
+  case Instruction::Opcode::FDiv:
+  case Instruction::Opcode::URem:
+  case Instruction::Opcode::SRem:
+  case Instruction::Opcode::FRem:
+  case Instruction::Opcode::Shl:
+  case Instruction::Opcode::LShr:
+  case Instruction::Opcode::AShr:
+  case Instruction::Opcode::And:
+  case Instruction::Opcode::Or:
+  case Instruction::Opcode::Xor: {
+    auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
+    auto *LHS = Operands[0];
+    auto *RHS = Operands[1];
+    return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
+                                                 BinOp0, WhereIt, Ctx, "Vec");
+  }
+  case Instruction::Opcode::Load: {
+    auto *Ld0 = cast<LoadInst>(Bndl[0]);
+    Value *Ptr = Ld0->getPointerOperand();
+    return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
+  }
+  case Instruction::Opcode::Store: {
+    auto Align = cast<StoreInst>(Bndl[0])->getAlign();
+    Value *Val = Operands[0];
+    Value *Ptr = Operands[1];
+    return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
+  }
+  case Instruction::Opcode::Br:
+  case Instruction::Opcode::Ret:
+  case Instruction::Opcode::PHI:
+  case Instruction::Opcode::AddrSpaceCast:
+  case Instruction::Opcode::Call:
+  case Instruction::Opcode::GetElementPtr:
+    llvm_unreachable("Unimplemented");
+    break;
+  default:
+    llvm_unreachable("Unimplemented");
+    break;
+  }
+  llvm_unreachable("Missing switch case!");
+  // TODO: Propagate debug info.
+}
+
+Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+  Value *NewVec = nullptr;
   const auto &LegalityRes = Legality->canVectorize(Bndl);
   switch (LegalityRes.getSubclassID()) {
   case LegalityResultID::Widen: {
     auto *I = cast<Instruction>(Bndl[0]);
-    for (auto OpIdx : seq(I->getNumOperands())) {
-      auto OperandBndl = getOperand(Bndl, OpIdx);
-      vectorizeRec(OperandBndl);
+    SmallVector<Value *, 2> VecOperands;
+    switch (I->getOpcode()) {
+    case Instruction::Opcode::Load:
+      // Don't recurse towards the pointer operand.
+      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
+      break;
+    case Instruction::Opcode::Store: {
+      // Don't recurse towards the pointer operand.
+      auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
+      VecOperands.push_back(VecOp);
+      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
+      break;
+    }
+    default:
+      // Visit all operands.
+      for (auto OpIdx : seq(I->getNumOperands())) {
+        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
+        VecOperands.push_back(VecOp);
+      }
+      break;
     }
+    NewVec = createVectorInstr(Bndl, VecOperands);
+
+    // TODO: Notify DAG/Scheduler about new instruction
+
+    // TODO: Collect potentially dead instructions.
     break;
   }
   case LegalityResultID::Pack: {
@@ -56,6 +191,7 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
     llvm_unreachable("Unimplemented");
   }
   }
+  return NewVec;
 }
 
 void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
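For review purposes, the comment block below is an informal trace (not part of the patch) of how the new Widen path handles the first test in bottomup_basic.ll, assuming the seed bundle handed to tryVectorize() is the pair of scalar stores:

// Informal trace of @store_load (illustration only, not part of the patch):
//
//   tryVectorize({store %ld0, store %ld1})
//     vectorizeRec(store bundle)       // Widen, Store case
//       vectorizeRec({%ld0, %ld1})     // Widen, Load case: the pointer
//                                      // operand is reused, not recursed into
//         -> %VecL = load <2 x float>, ptr %ptr0
//       -> store <2 x float> %VecL, ptr %ptr0
//
// The scalar loads and stores are left in place for now, since collecting
// potentially dead instructions is still a TODO in vectorizeRec().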
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
new file mode 100644
index 0000000000000..2b9aac93b7485
--- /dev/null
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=sandbox-vectorizer -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
+
+define void @store_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    store float [[LD0]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT:    store <2 x float> [[VECL]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  store float %ld0, ptr %ptr0
+  store float %ld1, ptr %ptr1
+  ret void
+}
+
+
+define void @store_fpext_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fpext_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[PTRD0:%.*]] = getelementptr double, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTRD1:%.*]] = getelementptr double, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[FPEXT0:%.*]] = fpext float [[LD0]] to double
+; CHECK-NEXT:    [[FPEXT1:%.*]] = fpext float [[LD1]] to double
+; CHECK-NEXT:    [[VCAST:%.*]] = fpext <2 x float> [[VECL]] to <2 x double>
+; CHECK-NEXT:    store double [[FPEXT0]], ptr [[PTRD0]], align 8
+; CHECK-NEXT:    store double [[FPEXT1]], ptr [[PTRD1]], align 8
+; CHECK-NEXT:    store <2 x double> [[VCAST]], ptr [[PTRD0]], align 8
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptrd0 = getelementptr double, ptr %ptr, i32 0
+  %ptrd1 = getelementptr double, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  %fpext0 = fpext float %ld0 to double
+  %fpext1 = fpext float %ld1 to double
+  store double %fpext0, ptr %ptrd0
+  store double %fpext1, ptr %ptrd1
+  ret void
+}
+
+; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
+
+; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
+
+define void @store_fneg_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fneg_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[FNEG0:%.*]] = fneg float [[LD0]]
+; CHECK-NEXT:    [[FNEG1:%.*]] = fneg float [[LD1]]
+; CHECK-NEXT:    [[VEC:%.*]] = fneg <2 x float> [[VECL]]
+; CHECK-NEXT:    store float [[FNEG0]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    store float [[FNEG1]], ptr [[PTR1]], align 4
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  %fneg0 = fneg float %ld0
+  %fneg1 = fneg float %ld1
+  store float %fneg0, ptr %ptr0
+  store float %fneg1, ptr %ptr1
+  ret void
+}
+
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 75f72ce23fbaa..654fd7dfe1776 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -368,3 +368,45 @@ define void @foo(ptr %ptr) {
   EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V3L0, V2L2, SE, DL));
   EXPECT_FALSE(sandboxir::VecUtils::areConsecutive(V2L1, L0, SE, DL));
 }
+
+TEST_F(VecUtilsTest, GetNumLanes) {
+  parseIR(R"IR(
+define <4 x float> @foo(float %v, <2 x float> %v2, <4 x float> %ret, ptr %ptr) {
+  store float %v, ptr %ptr
+  store <2 x float> %v2, ptr %ptr
+  ret <4 x float> %ret
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+
+  auto It = BB.begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S0->getValueOperand()->getType()),
+            1u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S0), 1);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S1->getValueOperand()->getType()),
+            2u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(S1), 2);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Ret->getReturnValue()->getType()),
+            4u);
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Ret), 4);
+
+  SmallVector<sandboxir::Value *> Bndl({S0, S1, Ret});
+  EXPECT_EQ(sandboxir::VecUtils::getNumLanes(Bndl), 7u);
+}
+
+TEST_F(VecUtilsTest, GetWideType) {
+  sandboxir::Context Ctx(C);
+
+  auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
+  auto *Int32X4Ty = sandboxir::FixedVectorType::get(Int32Ty, 4);
+  EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32Ty, 4), Int32X4Ty);
+  auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
+  EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
+}