diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h new file mode 100644 index 0000000000000..78c1c0e4c0464 --- /dev/null +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -0,0 +1,62 @@ +//===- Legality.h -----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Legality checks for the Sandbox Vectorizer. +// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H +#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H + +#include "llvm/SandboxIR/SandboxIR.h" + +namespace llvm::sandboxir { + +class LegalityAnalysis; + +enum class LegalityResultID { + Widen, ///> Vectorize by combining scalars to a vector. +}; + +/// The legality outcome is represented by a class rather than an enum class +/// because in some cases the legality checks are expensive and look for a +/// particular instruction that can be passed along to the vectorizer to avoid +/// repeating the same expensive computation. +class LegalityResult { +protected: + LegalityResultID ID; + /// Only Legality can create LegalityResults. + LegalityResult(LegalityResultID ID) : ID(ID) {} + friend class LegalityAnalysis; + +public: + LegalityResultID getSubclassID() const { return ID; } +}; + +class Widen final : public LegalityResult { + friend class LegalityAnalysis; + Widen() : LegalityResult(LegalityResultID::Widen) {} + +public: + static bool classof(const LegalityResult *From) { + return From->getSubclassID() == LegalityResultID::Widen; + } +}; + +/// Performs the legality analysis and returns a LegalityResult object. +class LegalityAnalysis { +public: + LegalityAnalysis() = default; + LegalityResult canVectorize(ArrayRef Bndl) { + // TODO: For now everything is legal. + return Widen(); + } +}; + +} // namespace llvm::sandboxir + +#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h index 5b3d1a50aa1ec..99582e3e0e023 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h @@ -12,11 +12,18 @@ #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H +#include "llvm/ADT/ArrayRef.h" #include "llvm/SandboxIR/Pass.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" namespace llvm::sandboxir { class BottomUpVec final : public FunctionPass { + bool Change = false; + LegalityAnalysis Legality; + void vectorizeRec(ArrayRef Bndl); + void tryVectorize(ArrayRef Seeds); public: BottomUpVec() : FunctionPass("bottom-up-vec") {} diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index c4870b70fd52d..0c44d05f0474d 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -7,7 +7,58 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h" +#include "llvm/ADT/SmallVector.h" using namespace llvm::sandboxir; -bool BottomUpVec::runOnFunction(Function &F) { return false; } +namespace llvm::sandboxir { +// TODO: This is a temporary function that returns some seeds. +// Replace this with SeedCollector's function when it lands. +static llvm::SmallVector collectSeeds(BasicBlock &BB) { + llvm::SmallVector Seeds; + for (auto &I : BB) + if (auto *SI = llvm::dyn_cast(&I)) + Seeds.push_back(SI); + return Seeds; +} + +static SmallVector getOperand(ArrayRef Bndl, + unsigned OpIdx) { + SmallVector Operands; + for (Value *BndlV : Bndl) { + auto *BndlI = cast(BndlV); + Operands.push_back(BndlI->getOperand(OpIdx)); + } + return Operands; +} + +} // namespace llvm::sandboxir + +void BottomUpVec::vectorizeRec(ArrayRef Bndl) { + auto LegalityRes = Legality.canVectorize(Bndl); + switch (LegalityRes.getSubclassID()) { + case LegalityResultID::Widen: { + auto *I = cast(Bndl[0]); + for (auto OpIdx : seq(I->getNumOperands())) { + auto OperandBndl = getOperand(Bndl, OpIdx); + vectorizeRec(OperandBndl); + } + break; + } + } +} + +void BottomUpVec::tryVectorize(ArrayRef Bndl) { vectorizeRec(Bndl); } + +bool BottomUpVec::runOnFunction(Function &F) { + Change = false; + // TODO: Start from innermost BBs first + for (auto &BB : F) { + // TODO: Replace with proper SeedCollector function. + auto Seeds = collectSeeds(BB); + // TODO: Slice Seeds into smaller chunks. + if (Seeds.size() >= 2) + tryVectorize(Seeds); + } + return Change; +} diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt index 488c9c2344b56..2c7bf7d7e8754 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt @@ -9,4 +9,5 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(SandboxVectorizerTests DependencyGraphTest.cpp + LegalityTest.cpp ) diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp new file mode 100644 index 0000000000000..a136be41ae363 --- /dev/null +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -0,0 +1,56 @@ +//===- LegalityTest.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct LegalityTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("LegalityTest", errs()); + } +}; + +TEST_F(LegalityTest, Legality) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr float, ptr %ptr, i32 0 + %gep1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %gep0 + %ld1 = load float, ptr %gep0 + store float %ld0, ptr %gep0 + store float %ld1, ptr %gep1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + [[maybe_unused]] auto *Gep0 = cast(&*It++); + [[maybe_unused]] auto *Gep1 = cast(&*It++); + [[maybe_unused]] auto *Ld0 = cast(&*It++); + [[maybe_unused]] auto *Ld1 = cast(&*It++); + auto *St0 = cast(&*It++); + auto *St1 = cast(&*It++); + + sandboxir::LegalityAnalysis Legality; + auto Result = Legality.canVectorize({St0, St1}); + EXPECT_TRUE(isa(Result)); +}