diff --git a/llvm/include/llvm/IR/VectorBuilder.h b/llvm/include/llvm/IR/VectorBuilder.h new file mode 100644 index 00000000000000..a24338adf3c5d2 --- /dev/null +++ b/llvm/include/llvm/IR/VectorBuilder.h @@ -0,0 +1,99 @@ +//===- llvm/VectorBuilder.h - Builder for VP Intrinsics ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the VectorBuilder class, which is used as a convenient way +// to create VP intrinsics as if they were LLVM instructions with a consistent +// and simplified interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_VECTORBUILDER_H +#define LLVM_IR_VECTORBUILDER_H + +#include +#include +#include +#include + +namespace llvm { + +class VectorBuilder { +public: + enum class Behavior { + // Abort if the requested VP intrinsic could not be created. + // This is useful for strict consistency. + ReportAndAbort = 0, + + // Return a default-initialized value if the requested VP intrinsic could + // not be created. + // This is useful for a defensive fallback to non-VP code. + SilentlyReturnNone = 1, + }; + +private: + IRBuilder<> &Builder; + Behavior ErrorHandling; + + // Explicit mask parameter. + Value *Mask; + // Explicit vector length parameter. + Value *ExplicitVectorLength; + // Compile-time vector length. + ElementCount StaticVectorLength; + + // Get mask/evl value handles for the current configuration. + Value &requestMask(); + Value &requestEVL(); + + void handleError(const char *ErrorMsg) const; + template + RetType returnWithError(const char *ErrorMsg) const { + handleError(ErrorMsg); + return RetType(); + } + +public: + VectorBuilder(IRBuilder<> &Builder, + Behavior ErrorHandling = Behavior::ReportAndAbort) + : Builder(Builder), ErrorHandling(ErrorHandling), Mask(nullptr), + ExplicitVectorLength(nullptr), + StaticVectorLength(ElementCount::getFixed(0)) {} + + Module &getModule() const; + LLVMContext &getContext() const { return Builder.getContext(); } + + // All-true mask for the currently configured explicit vector length. + Value *getAllTrueMask(); + + VectorBuilder &setMask(Value *NewMask) { + Mask = NewMask; + return *this; + } + VectorBuilder &setEVL(Value *NewExplicitVectorLength) { + ExplicitVectorLength = NewExplicitVectorLength; + return *this; + } + VectorBuilder &setStaticVL(unsigned NewFixedVL) { + StaticVectorLength = ElementCount::getFixed(NewFixedVL); + return *this; + } + // TODO: setStaticVL(ElementCount) for scalable types. + + // Emit a VP intrinsic call that mimics a regular instruction. + // This operation behaves according to the VectorBuilderBehavior. + // \p Opcode The functional instruction opcode of the emitted intrinsic. + // \p ReturnTy The return type of the operation. + // \p VecOpArray The operand list. + Value *createVectorInstruction(unsigned Opcode, Type *ReturnTy, + ArrayRef VecOpArray, + const Twine &Name = Twine()); +}; + +} // namespace llvm + +#endif // LLVM_IR_VECTORBUILDER_H diff --git a/llvm/include/llvm/module.modulemap b/llvm/include/llvm/module.modulemap index 2e6bf791330d1d..c94bece5c0985c 100644 --- a/llvm/include/llvm/module.modulemap +++ b/llvm/include/llvm/module.modulemap @@ -254,6 +254,7 @@ module LLVM_intrinsic_gen { module IR_InstrTypes { header "IR/InstrTypes.h" export * } module IR_Instructions { header "IR/Instructions.h" export * } module IR_TypeFinder { header "IR/TypeFinder.h" export * } + module IR_VectorBuilder { header "IR/VectorBuilder.h" export * } // Intrinsics.h diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt index 2c09e57c3f970d..3e542e4622fbd0 100644 --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -61,6 +61,7 @@ add_llvm_component_library(LLVMCore User.cpp Value.cpp ValueSymbolTable.cpp + VectorBuilder.cpp Verifier.cpp ADDITIONAL_HEADER_DIRS diff --git a/llvm/lib/IR/VectorBuilder.cpp b/llvm/lib/IR/VectorBuilder.cpp new file mode 100644 index 00000000000000..34d943a9b10b48 --- /dev/null +++ b/llvm/lib/IR/VectorBuilder.cpp @@ -0,0 +1,103 @@ +//===- VectorBuilder.cpp - Builder for VP Intrinsics ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the VectorBuilder class, which is used as a convenient +// way to create VP intrinsics as if they were LLVM instructions with a +// consistent and simplified interface. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +void VectorBuilder::handleError(const char *ErrorMsg) const { + if (ErrorHandling == Behavior::SilentlyReturnNone) + return; + report_fatal_error(ErrorMsg); +} + +Module &VectorBuilder::getModule() const { + return *Builder.GetInsertBlock()->getModule(); +} + +Value *VectorBuilder::getAllTrueMask() { + auto *BoolTy = Builder.getInt1Ty(); + auto *MaskTy = VectorType::get(BoolTy, StaticVectorLength); + return ConstantInt::getAllOnesValue(MaskTy); +} + +Value &VectorBuilder::requestMask() { + if (Mask) + return *Mask; + + return *getAllTrueMask(); +} + +Value &VectorBuilder::requestEVL() { + if (ExplicitVectorLength) + return *ExplicitVectorLength; + + assert(!StaticVectorLength.isScalable() && "TODO vscale lowering"); + auto *IntTy = Builder.getInt32Ty(); + return *ConstantInt::get(IntTy, StaticVectorLength.getFixedValue()); +} + +Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy, + ArrayRef InstOpArray, + const Twine &Name) { + auto VPID = VPIntrinsic::getForOpcode(Opcode); + if (VPID == Intrinsic::not_intrinsic) + return returnWithError("No VPIntrinsic for this opcode"); + + auto MaskPosOpt = VPIntrinsic::getMaskParamPos(VPID); + auto VLenPosOpt = VPIntrinsic::getVectorLengthParamPos(VPID); + size_t NumInstParams = InstOpArray.size(); + size_t NumVPParams = + NumInstParams + MaskPosOpt.hasValue() + VLenPosOpt.hasValue(); + + SmallVector IntrinParams; + + // Whether the mask and vlen parameter are at the end of the parameter list. + bool TrailingMaskAndVLen = + std::min(MaskPosOpt.getValueOr(NumInstParams), + VLenPosOpt.getValueOr(NumInstParams)) >= NumInstParams; + + if (TrailingMaskAndVLen) { + // Fast path for trailing mask, vector length. + IntrinParams.append(InstOpArray.begin(), InstOpArray.end()); + IntrinParams.resize(NumVPParams); + } else { + IntrinParams.resize(NumVPParams); + // Insert mask and evl operands in between the instruction operands. + for (size_t VPParamIdx = 0, ParamIdx = 0; VPParamIdx < NumVPParams; + ++VPParamIdx) { + if ((MaskPosOpt && MaskPosOpt.getValueOr(NumVPParams) == VPParamIdx) || + (VLenPosOpt && VLenPosOpt.getValueOr(NumVPParams) == VPParamIdx)) + continue; + assert(ParamIdx < NumInstParams); + IntrinParams[VPParamIdx] = InstOpArray[ParamIdx++]; + } + } + + if (MaskPosOpt.hasValue()) + IntrinParams[*MaskPosOpt] = &requestMask(); + if (VLenPosOpt.hasValue()) + IntrinParams[*VLenPosOpt] = &requestEVL(); + + auto *VPDecl = VPIntrinsic::getDeclarationForParams(&getModule(), VPID, + ReturnTy, IntrinParams); + return Builder.CreateCall(VPDecl, IntrinParams, Name); +} + +} // namespace llvm diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt index 42fa677b247211..e9f3d5598d142c 100644 --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -41,6 +41,7 @@ add_llvm_unittest(IRTests ValueHandleTest.cpp ValueMapTest.cpp ValueTest.cpp + VectorBuilderTest.cpp VectorTypesTest.cpp VerifierTest.cpp VPIntrinsicTest.cpp diff --git a/llvm/unittests/IR/VectorBuilderTest.cpp b/llvm/unittests/IR/VectorBuilderTest.cpp new file mode 100644 index 00000000000000..82ce045ab4b053 --- /dev/null +++ b/llvm/unittests/IR/VectorBuilderTest.cpp @@ -0,0 +1,280 @@ +//===--------- VectorBuilderTest.cpp - VectorBuilder unit tests -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/VectorBuilder.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +static unsigned VectorNumElements = 8; + +class VectorBuilderTest : public testing::Test { +protected: + LLVMContext Context; + + VectorBuilderTest() : Context() {} + + std::unique_ptr createBuilderModule(Function *&Func, BasicBlock *&BB, + Value *&Mask, Value *&EVL) { + auto Mod = std::make_unique("TestModule", Context); + auto *Int32Ty = Type::getInt32Ty(Context); + auto *Mask8Ty = + FixedVectorType::get(Type::getInt1Ty(Context), VectorNumElements); + auto *VoidFuncTy = + FunctionType::get(Type::getVoidTy(Context), {Mask8Ty, Int32Ty}, false); + Func = + Function::Create(VoidFuncTy, GlobalValue::ExternalLinkage, "bla", *Mod); + Mask = Func->getArg(0); + EVL = Func->getArg(1); + BB = BasicBlock::Create(Context, "entry", Func); + + return Mod; + } +}; + +/// Check that creating binary arithmetic VP intrinsics works. +TEST_F(VectorBuilderTest, TestCreateBinaryInstructions) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + VectorBuilder VBuild(Builder); + VBuild.setMask(Mask).setEVL(EVL); + + auto *FloatVecTy = + FixedVectorType::get(Type::getFloatTy(Context), VectorNumElements); + auto *IntVecTy = + FixedVectorType::get(Type::getInt32Ty(Context), VectorNumElements); + +#define HANDLE_BINARY_INST(NUM, OPCODE, INSTCLASS) \ + { \ + auto VPID = VPIntrinsic::getForOpcode(Instruction::OPCODE); \ + bool IsFP = (#INSTCLASS)[0] == 'F'; \ + auto *ValueTy = IsFP ? FloatVecTy : IntVecTy; \ + Value *Op = UndefValue::get(ValueTy); \ + auto *I = VBuild.createVectorInstruction(Instruction::OPCODE, ValueTy, \ + {Op, Op}); \ + ASSERT_TRUE(isa(I)); \ + auto *VPIntrin = cast(I); \ + ASSERT_EQ(VPIntrin->getIntrinsicID(), VPID); \ + ASSERT_EQ(VPIntrin->getMaskParam(), Mask); \ + ASSERT_EQ(VPIntrin->getVectorLengthParam(), EVL); \ + } +#include "llvm/IR/Instruction.def" +} + +static bool isAllTrueMask(Value *Val, unsigned NumElements) { + auto *ConstMask = dyn_cast(Val); + if (!ConstMask) + return false; + + // Structure check. + if (!ConstMask->isAllOnesValue()) + return false; + + // Type check. + auto *MaskVecTy = cast(ConstMask->getType()); + if (MaskVecTy->getNumElements() != NumElements) + return false; + + return MaskVecTy->getElementType()->isIntegerTy(1); +} + +/// Check that creating binary arithmetic VP intrinsics works. +TEST_F(VectorBuilderTest, TestCreateBinaryInstructions_FixedVector_NoMask) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + VectorBuilder VBuild(Builder); + VBuild.setEVL(EVL).setStaticVL(VectorNumElements); + + auto *FloatVecTy = + FixedVectorType::get(Type::getFloatTy(Context), VectorNumElements); + auto *IntVecTy = + FixedVectorType::get(Type::getInt32Ty(Context), VectorNumElements); + +#define HANDLE_BINARY_INST(NUM, OPCODE, INSTCLASS) \ + { \ + auto VPID = VPIntrinsic::getForOpcode(Instruction::OPCODE); \ + bool IsFP = (#INSTCLASS)[0] == 'F'; \ + Type *ValueTy = IsFP ? FloatVecTy : IntVecTy; \ + Value *Op = UndefValue::get(ValueTy); \ + auto *I = VBuild.createVectorInstruction(Instruction::OPCODE, ValueTy, \ + {Op, Op}); \ + ASSERT_TRUE(isa(I)); \ + auto *VPIntrin = cast(I); \ + ASSERT_EQ(VPIntrin->getIntrinsicID(), VPID); \ + ASSERT_TRUE(isAllTrueMask(VPIntrin->getMaskParam(), VectorNumElements)); \ + ASSERT_EQ(VPIntrin->getVectorLengthParam(), EVL); \ + } +#include "llvm/IR/Instruction.def" +} + +static bool isLegalConstEVL(Value *Val, unsigned ExpectedEVL) { + auto *ConstEVL = dyn_cast(Val); + if (!ConstEVL) + return false; + + // Value check. + if (ConstEVL->getZExtValue() != ExpectedEVL) + return false; + + // Type check. + return ConstEVL->getType()->isIntegerTy(32); +} + +/// Check that creating binary arithmetic VP intrinsics works. +TEST_F(VectorBuilderTest, TestCreateBinaryInstructions_FixedVector_NoEVL) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + VectorBuilder VBuild(Builder); + VBuild.setMask(Mask).setStaticVL(VectorNumElements); + + auto *FloatVecTy = + FixedVectorType::get(Type::getFloatTy(Context), VectorNumElements); + auto *IntVecTy = + FixedVectorType::get(Type::getInt32Ty(Context), VectorNumElements); + +#define HANDLE_BINARY_INST(NUM, OPCODE, INSTCLASS) \ + { \ + auto VPID = VPIntrinsic::getForOpcode(Instruction::OPCODE); \ + bool IsFP = (#INSTCLASS)[0] == 'F'; \ + Type *ValueTy = IsFP ? FloatVecTy : IntVecTy; \ + Value *Op = UndefValue::get(ValueTy); \ + auto *I = VBuild.createVectorInstruction(Instruction::OPCODE, ValueTy, \ + {Op, Op}); \ + ASSERT_TRUE(isa(I)); \ + auto *VPIntrin = cast(I); \ + ASSERT_EQ(VPIntrin->getIntrinsicID(), VPID); \ + ASSERT_EQ(VPIntrin->getMaskParam(), Mask); \ + ASSERT_TRUE( \ + isLegalConstEVL(VPIntrin->getVectorLengthParam(), VectorNumElements)); \ + } +#include "llvm/IR/Instruction.def" +} + +/// Check that creating binary arithmetic VP intrinsics works. +TEST_F(VectorBuilderTest, + TestCreateBinaryInstructions_FixedVector_NoMask_NoEVL) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + VectorBuilder VBuild(Builder); + VBuild.setStaticVL(VectorNumElements); + + auto *FloatVecTy = + FixedVectorType::get(Type::getFloatTy(Context), VectorNumElements); + auto *IntVecTy = + FixedVectorType::get(Type::getInt32Ty(Context), VectorNumElements); + +#define HANDLE_BINARY_INST(NUM, OPCODE, INSTCLASS) \ + { \ + auto VPID = VPIntrinsic::getForOpcode(Instruction::OPCODE); \ + bool IsFP = (#INSTCLASS)[0] == 'F'; \ + Type *ValueTy = IsFP ? FloatVecTy : IntVecTy; \ + Value *Op = UndefValue::get(ValueTy); \ + auto *I = VBuild.createVectorInstruction(Instruction::OPCODE, ValueTy, \ + {Op, Op}); \ + ASSERT_TRUE(isa(I)); \ + auto *VPIntrin = cast(I); \ + ASSERT_EQ(VPIntrin->getIntrinsicID(), VPID); \ + ASSERT_TRUE(isAllTrueMask(VPIntrin->getMaskParam(), VectorNumElements)); \ + ASSERT_TRUE( \ + isLegalConstEVL(VPIntrin->getVectorLengthParam(), VectorNumElements)); \ + } +#include "llvm/IR/Instruction.def" +} +/// Check that creating vp.load/vp.store works. +TEST_F(VectorBuilderTest, TestCreateLoadStore) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + VectorBuilder VBuild(Builder); + VBuild.setMask(Mask).setEVL(EVL); + + auto *FloatVecTy = + FixedVectorType::get(Type::getFloatTy(Context), VectorNumElements); + auto *FloatVecPtrTy = FloatVecTy->getPointerTo(); + + Value *FloatVecPtr = UndefValue::get(FloatVecPtrTy); + Value *FloatVec = UndefValue::get(FloatVecTy); + + // vp.load + auto LoadVPID = VPIntrinsic::getForOpcode(Instruction::Load); + auto *LoadIntrin = VBuild.createVectorInstruction(Instruction::Load, + FloatVecTy, {FloatVecPtr}); + ASSERT_TRUE(isa(LoadIntrin)); + auto *VPLoad = cast(LoadIntrin); + ASSERT_EQ(VPLoad->getIntrinsicID(), LoadVPID); + ASSERT_EQ(VPLoad->getMemoryPointerParam(), FloatVecPtr); + + // vp.store + auto *VoidTy = Builder.getVoidTy(); + auto StoreVPID = VPIntrinsic::getForOpcode(Instruction::Store); + auto *StoreIntrin = VBuild.createVectorInstruction(Instruction::Store, VoidTy, + {FloatVec, FloatVecPtr}); + ASSERT_TRUE(isa(LoadIntrin)); + auto *VPStore = cast(StoreIntrin); + ASSERT_EQ(VPStore->getIntrinsicID(), StoreVPID); + ASSERT_EQ(VPStore->getMemoryPointerParam(), FloatVecPtr); + ASSERT_EQ(VPStore->getMemoryDataParam(), FloatVec); +} + +/// Check that the SilentlyReturnNone error handling mode works. +TEST_F(VectorBuilderTest, TestFail_SilentlyReturnNone) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + auto *VoidTy = Builder.getVoidTy(); + VectorBuilder VBuild(Builder, VectorBuilder::Behavior::SilentlyReturnNone); + VBuild.setMask(Mask).setEVL(EVL); + auto *Val = VBuild.createVectorInstruction(Instruction::Br, VoidTy, {}); + ASSERT_EQ(Val, nullptr); +} + +/// Check that the ReportAndFail error handling mode aborts as advertised. +TEST_F(VectorBuilderTest, TestFail_ReportAndAbort) { + Function *F; + BasicBlock *BB; + Value *Mask, *EVL; + auto Mod = createBuilderModule(F, BB, Mask, EVL); + + IRBuilder<> Builder(BB); + auto *VoidTy = Builder.getVoidTy(); + VectorBuilder VBuild(Builder, VectorBuilder::Behavior::ReportAndAbort); + VBuild.setMask(Mask).setEVL(EVL); + ASSERT_DEATH({ VBuild.createVectorInstruction(Instruction::Br, VoidTy, {}); }, + "No VPIntrinsic for this opcode"); +} + +} // end anonymous namespace