diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h new file mode 100644 index 0000000000000..2c4ba30f6fd05 --- /dev/null +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h @@ -0,0 +1,77 @@ +//===- InstrMaps.h ----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H +#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/SandboxIR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm::sandboxir { + +/// Maps the original instructions to the vectorized instrs and the reverse. +/// For now an original instr can only map to a single vector. +class InstrMaps { + /// A map from the original values that got combined into vectors, to the + /// vector value(s). + DenseMap OrigToVectorMap; + /// A map from the vector value to a map of the original value to its lane. + /// Please note that for constant vectors, there may be multiple original + /// values with the same lane, as they may be coming from vectorizing + /// different original values. + DenseMap> VectorToOrigLaneMap; + +public: + /// \Returns the vector value that we got from vectorizing \p Orig, or + /// nullptr if not found. + Value *getVectorForOrig(Value *Orig) const { + auto It = OrigToVectorMap.find(Orig); + return It != OrigToVectorMap.end() ? 
It->second : nullptr; + } + /// \Returns the lane of \p Orig before it got vectorized into \p Vec, or + /// nullopt if not found. + std::optional getOrigLane(Value *Vec, Value *Orig) const { + auto It1 = VectorToOrigLaneMap.find(Vec); + if (It1 == VectorToOrigLaneMap.end()) + return std::nullopt; + const auto &OrigToLaneMap = It1->second; + auto It2 = OrigToLaneMap.find(Orig); + if (It2 == OrigToLaneMap.end()) + return std::nullopt; + return It2->second; + } + /// Update the map to reflect that \p Origs got vectorized into \p Vec. + void registerVector(ArrayRef Origs, Value *Vec) { + auto &OrigToLaneMap = VectorToOrigLaneMap[Vec]; + for (auto [Lane, Orig] : enumerate(Origs)) { + auto Pair = OrigToVectorMap.try_emplace(Orig, Vec); + assert(Pair.second && "Orig already exists in the map!"); + OrigToLaneMap[Orig] = Lane; + } + } + void clear() { + OrigToVectorMap.clear(); + VectorToOrigLaneMap.clear(); + } +#ifndef NDEBUG + void print(raw_ostream &OS) const { + OS << "OrigToVectorMap:\n"; + for (auto [Orig, Vec] : OrigToVectorMap) + OS << *Orig << " : " << *Vec << "\n"; + } + LLVM_DUMP_METHOD void dump() const; +#endif +}; +} // namespace llvm::sandboxir + +#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 233cf82a1b3df..c03e7a10397ad 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -23,10 +23,12 @@ namespace llvm::sandboxir { class LegalityAnalysis; class Value; +class InstrMaps; enum class LegalityResultID { - Pack, ///> Collect scalar values. - Widen, ///> Vectorize by combining scalars to a vector. + Pack, ///> Collect scalar values. + Widen, ///> Vectorize by combining scalars to a vector. + DiamondReuse, ///> Don't generate new code, reuse existing vector. 
}; /// The reason for vectorizing or not vectorizing. @@ -50,6 +52,8 @@ struct ToStr { return "Pack"; case LegalityResultID::Widen: return "Widen"; + case LegalityResultID::DiamondReuse: + return "DiamondReuse"; } llvm_unreachable("Unknown LegalityResultID enum"); } @@ -137,6 +141,19 @@ class Widen final : public LegalityResult { } }; +class DiamondReuse final : public LegalityResult { + friend class LegalityAnalysis; + Value *Vec; + DiamondReuse(Value *Vec) + : LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {} + +public: + static bool classof(const LegalityResult *From) { + return From->getSubclassID() == LegalityResultID::DiamondReuse; + } + Value *getVector() const { return Vec; } +}; + class Pack final : public LegalityResultWithReason { Pack(ResultReason Reason) : LegalityResultWithReason(LegalityResultID::Pack, Reason) {} @@ -148,6 +165,59 @@ class Pack final : public LegalityResultWithReason { } }; +/// Describes how to collect the values needed by each lane. +class CollectDescr { +public: + /// Describes how to get a value element. If the value is a vector then it + /// also provides the index to extract it from. + class ExtractElementDescr { + Value *V; + /// The index in `V` that the value can be extracted from. + /// This is nullopt if we need to use `V` as a whole. + std::optional ExtractIdx; + + public: + ExtractElementDescr(Value *V, int ExtractIdx) + : V(V), ExtractIdx(ExtractIdx) {} + ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {} + Value *getValue() const { return V; } + bool needsExtract() const { return ExtractIdx.has_value(); } + int getExtractIdx() const { return *ExtractIdx; } + }; + + using DescrVecT = SmallVector; + DescrVecT Descrs; + +public: + CollectDescr(SmallVectorImpl &&Descrs) + : Descrs(std::move(Descrs)) {} + /// If all elements come from a single vector input, then return that vector + /// and whether we need a shuffle to get them in order. 
+ std::optional> getSingleInput() const { + const auto &Descr0 = *Descrs.begin(); + Value *V0 = Descr0.getValue(); + if (!Descr0.needsExtract()) + return std::nullopt; + bool NeedsShuffle = Descr0.getExtractIdx() != 0; + int Lane = 1; + for (const auto &Descr : drop_begin(Descrs)) { + if (!Descr.needsExtract()) + return std::nullopt; + if (Descr.getValue() != V0) + return std::nullopt; + if (Descr.getExtractIdx() != Lane++) + NeedsShuffle = true; + } + return std::make_pair(V0, NeedsShuffle); + } + bool hasVectorInputs() const { + return any_of(Descrs, [](const auto &D) { return D.needsExtract(); }); + } + const SmallVector &getDescrs() const { + return Descrs; + } +}; + /// Performs the legality analysis and returns a LegalityResult object. class LegalityAnalysis { Scheduler Sched; @@ -160,11 +230,17 @@ class LegalityAnalysis { ScalarEvolution &SE; const DataLayout &DL; + InstrMaps &IMaps; + + /// Finds how we can collect the values in \p Bndl from the vectorized or + /// non-vectorized code. It returns, for each lane, the value to use and, if + /// that value is a vector, the index to extract the element from. + CollectDescr getHowToCollectValues(ArrayRef Bndl) const; public: LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL, - Context &Ctx) - : Sched(AA, Ctx), SE(SE), DL(DL) {} + Context &Ctx, InstrMaps &IMaps) + : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {} /// A LegalityResult factory. template ResultT &createLegalityResult(ArgsT... Args) { @@ -177,7 +253,7 @@ class LegalityAnalysis { // TODO: Try to remove the SkipScheduling argument by refactoring the tests. 
const LegalityResult &canVectorize(ArrayRef Bndl, bool SkipScheduling = false); - void clear() { Sched.clear(); } + void clear(); }; } // namespace llvm::sandboxir diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h index 1a53ca6e06f5f..69cea3c4c7b53 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h @@ -18,6 +18,7 @@ #include "llvm/SandboxIR/Pass.h" #include "llvm/SandboxIR/PassManager.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" namespace llvm::sandboxir { @@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass { bool Change = false; std::unique_ptr Legality; DenseSet DeadInstrCandidates; + /// Maps scalars to vectors. + InstrMaps IMaps; /// Creates and returns a vector instruction that replaces the instructions in /// \p Bndl. \p Operands are the already vectorized operands. 
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt index d769d5100afd2..6a025652f92f8 100644 --- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt +++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt @@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize LoopVectorizationLegality.cpp LoopVectorize.cpp SandboxVectorizer/DependencyGraph.cpp + SandboxVectorizer/InstrMaps.cpp SandboxVectorizer/Interval.cpp SandboxVectorizer/Legality.cpp SandboxVectorizer/Passes/BottomUpVec.cpp diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp new file mode 100644 index 0000000000000..4df4829a04c41 --- /dev/null +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/InstrMaps.cpp @@ -0,0 +1,21 @@ +//===- InstrMaps.cpp - Maps scalars to vectors and reverse ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" +#include "llvm/Support/Debug.h" + +namespace llvm::sandboxir { + +#ifndef NDEBUG +void InstrMaps::dump() const { + print(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + +} // namespace llvm::sandboxir diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index 8c6deeb7df249..f8149c5bc6636 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -12,6 +12,7 @@ #include "llvm/SandboxIR/Utils.h" #include "llvm/SandboxIR/Value.h" #include "llvm/Support/Debug.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" namespace llvm::sandboxir { @@ -184,6 +185,22 @@ static void dumpBndl(ArrayRef Bndl) { } #endif // NDEBUG +CollectDescr +LegalityAnalysis::getHowToCollectValues(ArrayRef Bndl) const { + SmallVector Vec; + Vec.reserve(Bndl.size()); + for (auto [Lane, V] : enumerate(Bndl)) { + if (auto *VecOp = IMaps.getVectorForOrig(V)) { + // If there is a vector containing `V`, then get the lane it came from. + std::optional ExtractIdxOpt = IMaps.getOrigLane(VecOp, V); + Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1); + } else { + Vec.emplace_back(V); + } + } + return CollectDescr(std::move(Vec)); +} + const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, bool SkipScheduling) { // If Bndl contains values other than instructions, we need to Pack. 
@@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, return createLegalityResult(ResultReason::NotInstructions); } + auto CollectDescrs = getHowToCollectValues(Bndl); + if (CollectDescrs.hasVectorInputs()) { + if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { + auto [Vec, NeedsShuffle] = *ValueShuffleOpt; + if (!NeedsShuffle) + return createLegalityResult(Vec); + llvm_unreachable("TODO: Unimplemented"); + } else { + llvm_unreachable("TODO: Unimplemented"); + } + } + if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) return createLegalityResult(*ReasonOpt); - // TODO: Check for existing vectors containing values in Bndl. - if (!SkipScheduling) { // TODO: Try to remove the IBndl vector. SmallVector IBndl; @@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, return createLegalityResult(); } + +void LegalityAnalysis::clear() { + Sched.clear(); + IMaps.clear(); +} } // namespace llvm::sandboxir diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index d44199609838d..6b2032be53560 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -56,103 +56,114 @@ getInsertPointAfterInstrs(ArrayRef Instrs) { Value *BottomUpVec::createVectorInstr(ArrayRef Bndl, ArrayRef Operands) { - Change = true; - assert(all_of(Bndl, [](auto *V) { return isa(V); }) && - "Expect Instructions!"); - auto &Ctx = Bndl[0]->getContext(); + auto CreateVectorInstr = [](ArrayRef Bndl, + ArrayRef Operands) -> Value * { + assert(all_of(Bndl, [](auto *V) { return isa(V); }) && + "Expect Instructions!"); + auto &Ctx = Bndl[0]->getContext(); - Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); - auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); + Type 
*ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); + auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); - BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl); + BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl); - auto Opcode = cast(Bndl[0])->getOpcode(); - switch (Opcode) { - case Instruction::Opcode::ZExt: - case Instruction::Opcode::SExt: - case Instruction::Opcode::FPToUI: - case Instruction::Opcode::FPToSI: - case Instruction::Opcode::FPExt: - case Instruction::Opcode::PtrToInt: - case Instruction::Opcode::IntToPtr: - case Instruction::Opcode::SIToFP: - case Instruction::Opcode::UIToFP: - case Instruction::Opcode::Trunc: - case Instruction::Opcode::FPTrunc: - case Instruction::Opcode::BitCast: { - assert(Operands.size() == 1u && "Casts are unary!"); - return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast"); - } - case Instruction::Opcode::FCmp: - case Instruction::Opcode::ICmp: { - auto Pred = cast(Bndl[0])->getPredicate(); - assert(all_of(drop_begin(Bndl), - [Pred](auto *SBV) { - return cast(SBV)->getPredicate() == Pred; - }) && - "Expected same predicate across bundle."); - return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, - "VCmp"); - } - case Instruction::Opcode::Select: { - return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, - Ctx, "Vec"); - } - case Instruction::Opcode::FNeg: { - auto *UOp0 = cast(Bndl[0]); - auto OpC = UOp0->getOpcode(); - return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt, - Ctx, "Vec"); - } - case Instruction::Opcode::Add: - case Instruction::Opcode::FAdd: - case Instruction::Opcode::Sub: - case Instruction::Opcode::FSub: - case Instruction::Opcode::Mul: - case Instruction::Opcode::FMul: - case Instruction::Opcode::UDiv: - case Instruction::Opcode::SDiv: - case Instruction::Opcode::FDiv: - case Instruction::Opcode::URem: - case Instruction::Opcode::SRem: - case 
Instruction::Opcode::FRem: - case Instruction::Opcode::Shl: - case Instruction::Opcode::LShr: - case Instruction::Opcode::AShr: - case Instruction::Opcode::And: - case Instruction::Opcode::Or: - case Instruction::Opcode::Xor: { - auto *BinOp0 = cast(Bndl[0]); - auto *LHS = Operands[0]; - auto *RHS = Operands[1]; - return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS, - BinOp0, WhereIt, Ctx, "Vec"); - } - case Instruction::Opcode::Load: { - auto *Ld0 = cast(Bndl[0]); - Value *Ptr = Ld0->getPointerOperand(); - return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL"); - } - case Instruction::Opcode::Store: { - auto Align = cast(Bndl[0])->getAlign(); - Value *Val = Operands[0]; - Value *Ptr = Operands[1]; - return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); - } - case Instruction::Opcode::Br: - case Instruction::Opcode::Ret: - case Instruction::Opcode::PHI: - case Instruction::Opcode::AddrSpaceCast: - case Instruction::Opcode::Call: - case Instruction::Opcode::GetElementPtr: - llvm_unreachable("Unimplemented"); - break; - default: - llvm_unreachable("Unimplemented"); - break; + auto Opcode = cast(Bndl[0])->getOpcode(); + switch (Opcode) { + case Instruction::Opcode::ZExt: + case Instruction::Opcode::SExt: + case Instruction::Opcode::FPToUI: + case Instruction::Opcode::FPToSI: + case Instruction::Opcode::FPExt: + case Instruction::Opcode::PtrToInt: + case Instruction::Opcode::IntToPtr: + case Instruction::Opcode::SIToFP: + case Instruction::Opcode::UIToFP: + case Instruction::Opcode::Trunc: + case Instruction::Opcode::FPTrunc: + case Instruction::Opcode::BitCast: { + assert(Operands.size() == 1u && "Casts are unary!"); + return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, + "VCast"); + } + case Instruction::Opcode::FCmp: + case Instruction::Opcode::ICmp: { + auto Pred = cast(Bndl[0])->getPredicate(); + assert(all_of(drop_begin(Bndl), + [Pred](auto *SBV) { + return cast(SBV)->getPredicate() == Pred; + }) && + 
"Expected same predicate across bundle."); + return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, + "VCmp"); + } + case Instruction::Opcode::Select: { + return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, + Ctx, "Vec"); + } + case Instruction::Opcode::FNeg: { + auto *UOp0 = cast(Bndl[0]); + auto OpC = UOp0->getOpcode(); + return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, + WhereIt, Ctx, "Vec"); + } + case Instruction::Opcode::Add: + case Instruction::Opcode::FAdd: + case Instruction::Opcode::Sub: + case Instruction::Opcode::FSub: + case Instruction::Opcode::Mul: + case Instruction::Opcode::FMul: + case Instruction::Opcode::UDiv: + case Instruction::Opcode::SDiv: + case Instruction::Opcode::FDiv: + case Instruction::Opcode::URem: + case Instruction::Opcode::SRem: + case Instruction::Opcode::FRem: + case Instruction::Opcode::Shl: + case Instruction::Opcode::LShr: + case Instruction::Opcode::AShr: + case Instruction::Opcode::And: + case Instruction::Opcode::Or: + case Instruction::Opcode::Xor: { + auto *BinOp0 = cast(Bndl[0]); + auto *LHS = Operands[0]; + auto *RHS = Operands[1]; + return BinaryOperator::createWithCopiedFlags( + BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec"); + } + case Instruction::Opcode::Load: { + auto *Ld0 = cast(Bndl[0]); + Value *Ptr = Ld0->getPointerOperand(); + return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, + "VecL"); + } + case Instruction::Opcode::Store: { + auto Align = cast(Bndl[0])->getAlign(); + Value *Val = Operands[0]; + Value *Ptr = Operands[1]; + return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); + } + case Instruction::Opcode::Br: + case Instruction::Opcode::Ret: + case Instruction::Opcode::PHI: + case Instruction::Opcode::AddrSpaceCast: + case Instruction::Opcode::Call: + case Instruction::Opcode::GetElementPtr: + llvm_unreachable("Unimplemented"); + break; + default: + llvm_unreachable("Unimplemented"); + break; + } + 
llvm_unreachable("Missing switch case!"); + // TODO: Propagate debug info. + }; + + auto *VecI = CreateVectorInstr(Bndl, Operands); + if (VecI != nullptr) { + Change = true; + IMaps.registerVector(Bndl, VecI); } - llvm_unreachable("Missing switch case!"); - // TODO: Propagate debug info. + return VecI; } void BottomUpVec::tryEraseDeadInstrs() { @@ -280,6 +291,10 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl, unsigned Depth) { collectPotentiallyDeadInstrs(Bndl); break; } + case LegalityResultID::DiamondReuse: { + NewVec = cast(LegalityRes).getVector(); + break; + } case LegalityResultID::Pack: { // If we can't vectorize the seeds then just return. if (Depth == 0) @@ -300,9 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef Bndl) { } bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { + IMaps.clear(); Legality = std::make_unique( A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), - F.getContext()); + F.getContext(), IMaps); Change = false; const auto &DL = F.getParent()->getDataLayout(); unsigned VecRegBits = diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index d34c8f88e4b3c..7bc6e5ac3d760 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -201,3 +201,23 @@ define void @pack_vectors(ptr %ptr, ptr %ptr2) { store float %ld1, ptr %ptr1 ret void } + +define void @diamond(ptr %ptr) { +; CHECK-LABEL: define void @diamond( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VECL]] +; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4 +; CHECK-NEXT: ret void +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %ptr0 + %ld1 
= load float, ptr %ptr1 + %sub0 = fsub float %ld0, %ld0 + %sub1 = fsub float %ld1, %ld1 + store float %sub0, ptr %ptr0 + store float %sub1, ptr %ptr1 + ret void +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt index df689767b7724..bbfbcc730a4cb 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(SandboxVectorizerTests DependencyGraphTest.cpp + InstrMapsTest.cpp IntervalTest.cpp LegalityTest.cpp SchedulerTest.cpp diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp new file mode 100644 index 0000000000000..bcfb8db7f8674 --- /dev/null +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp @@ -0,0 +1,78 @@ +//===- InstrMapsTest.cpp --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/SandboxIR/Function.h" +#include "llvm/SandboxIR/Instruction.h" +#include "llvm/Support/SourceMgr.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct InstrMapsTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("InstrMapsTest", errs()); + } +}; + +TEST_F(InstrMapsTest, Basic) { + parseIR(C, R"IR( +define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) { + %add0 = add i8 %v0, %v0 + %add1 = add i8 %v1, %v1 + %add2 = add i8 %v2, %v2 + %add3 = add i8 %v3, %v3 + %vadd0 = add <2 x i8> %vec, %vec + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + + auto *Add0 = cast(&*It++); + auto *Add1 = cast(&*It++); + auto *Add2 = cast(&*It++); + auto *Add3 = cast(&*It++); + auto *VAdd0 = cast(&*It++); + [[maybe_unused]] auto *Ret = cast(&*It++); + + sandboxir::InstrMaps IMaps; + // Check with empty IMaps. + EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr); + EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr); + EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0)); + // Check with 1 match. 
+ IMaps.registerVector({Add0, Add1}, VAdd0); + EXPECT_EQ(IMaps.getVectorForOrig(Add0), VAdd0); + EXPECT_EQ(IMaps.getVectorForOrig(Add1), VAdd0); + EXPECT_FALSE(IMaps.getOrigLane(VAdd0, VAdd0)); // Bad Orig value + EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0)); // Bad Vector value + EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add0), 0); + EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1); + // Check when the same vector maps to different original values (which is + // common for vector constants). + IMaps.registerVector({Add2, Add3}, VAdd0); + EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add2), 0); + EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add3), 1); + // Check when we register for a second time. +#ifndef NDEBUG + EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*"); +#endif // NDEBUG +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index b5e2c302f5901..2e90462a633c1 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -18,6 +18,7 @@ #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Instruction.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" #include "gtest/gtest.h" using namespace llvm; @@ -110,7 +111,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); + llvm::sandboxir::InstrMaps IMaps; + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); const auto &Result = Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); @@ -228,7 +230,8 @@ define void @foo(ptr %ptr) { auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); + 
llvm::sandboxir::InstrMaps IMaps; + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); { // Can vectorize St0,St1. const auto &Result = Legality.canVectorize({St0, St1}); @@ -263,7 +266,8 @@ define void @foo() { }; sandboxir::Context Ctx(C); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); + llvm::sandboxir::InstrMaps IMaps; + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult( @@ -283,3 +287,68 @@ define void @foo() { "Pack Reason: DiffWrapFlags")); } #endif // NDEBUG + +TEST_F(LegalityTest, CollectDescr) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr float, ptr %ptr, i32 0 + %gep1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %gep0 + %ld1 = load float, ptr %gep1 + %vld = load <4 x float>, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + getAnalyses(*LLVMF); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + [[maybe_unused]] auto *Gep0 = cast(&*It++); + [[maybe_unused]] auto *Gep1 = cast(&*It++); + auto *Ld0 = cast(&*It++); + [[maybe_unused]] auto *Ld1 = cast(&*It++); + auto *VLd = cast(&*It++); + + sandboxir::CollectDescr::DescrVecT Descrs; + using EEDescr = sandboxir::CollectDescr::ExtractElementDescr; + + { + // Check single input, no shuffle. + Descrs.push_back(EEDescr(VLd, 0)); + Descrs.push_back(EEDescr(VLd, 1)); + sandboxir::CollectDescr CD(std::move(Descrs)); + EXPECT_TRUE(CD.getSingleInput()); + EXPECT_EQ(CD.getSingleInput()->first, VLd); + EXPECT_EQ(CD.getSingleInput()->second, false); + EXPECT_TRUE(CD.hasVectorInputs()); + } + { + // Check single input, shuffle. 
+ Descrs.push_back(EEDescr(VLd, 1)); + Descrs.push_back(EEDescr(VLd, 0)); + sandboxir::CollectDescr CD(std::move(Descrs)); + EXPECT_TRUE(CD.getSingleInput()); + EXPECT_EQ(CD.getSingleInput()->first, VLd); + EXPECT_EQ(CD.getSingleInput()->second, true); + EXPECT_TRUE(CD.hasVectorInputs()); + } + { + // Check multiple inputs. + Descrs.push_back(EEDescr(Ld0)); + Descrs.push_back(EEDescr(VLd, 0)); + Descrs.push_back(EEDescr(VLd, 1)); + sandboxir::CollectDescr CD(std::move(Descrs)); + EXPECT_FALSE(CD.getSingleInput()); + EXPECT_TRUE(CD.hasVectorInputs()); + } + { + // Check multiple inputs only scalars. + Descrs.push_back(EEDescr(Ld0)); + Descrs.push_back(EEDescr(Ld1)); + sandboxir::CollectDescr CD(std::move(Descrs)); + EXPECT_FALSE(CD.getSingleInput()); + EXPECT_FALSE(CD.hasVectorInputs()); + } +}