diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp index 56025aa5c45fb..7b0215535a92c 100644 --- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp +++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp @@ -111,7 +111,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, SmallVector ScalarArgTypes; std::string ScalarName; Function *FuncToReplace = nullptr; - if (auto *CI = dyn_cast(&I)) { + auto *CI = dyn_cast(&I); + if (CI) { FuncToReplace = CI->getCalledFunction(); Intrinsic::ID IID = FuncToReplace->getIntrinsicID(); assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic"); @@ -168,12 +169,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, if (!OptInfo) return false; + // There is no guarantee that the vectorized instructions followed the VFABI + // specification when being created, this is why we need to add extra check to + // make sure that the operands of the vector function obtained via VFABI match + // the operands of the original vector instruction. + if (CI) { + for (auto VFParam : OptInfo->Shape.Parameters) { + if (VFParam.ParamKind == VFParamKind::GlobalPredicate) + continue; + + // tryDemangleForVFABI must return valid ParamPos, otherwise it could be + // a bug in the VFABI parser. + assert(VFParam.ParamPos < CI->arg_size() && + "ParamPos has invalid range."); + Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType(); + if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) { + LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName + << ". Wrong type at index " << VFParam.ParamPos + << ": " << *OrigTy << "\n"); + return false; + } + } + } + FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy); if (!VectorFTy) return false; Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy, VD->getVectorFnName(), FuncToReplace); + replaceWithTLIFunction(I, *OptInfo, TLIFunc); LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName << "` with call to `" << TLIFunc->getName() << "`.\n"); @@ -220,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F, const TargetLibraryInfo &TLI = AM.getResult(F); auto Changed = runImpl(TLI, F); if (Changed) { + LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: " + << NumCallsReplaced << "\n"); + PreservedAnalyses PA; PA.preserveSet(); PA.preserve(); diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt index 847430bf17697..e7505f2633d92 100644 --- a/llvm/unittests/Analysis/CMakeLists.txt +++ b/llvm/unittests/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES PluginInlineAdvisorAnalysisTest.cpp PluginInlineOrderAnalysisTest.cpp ProfileSummaryInfoTest.cpp + ReplaceWithVecLibTest.cpp ScalarEvolutionTest.cpp VectorFunctionABITest.cpp SparsePropagation.cpp diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp new file mode 100644 index 0000000000000..a1f0a4a894c8d --- /dev/null +++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp @@ -0,0 +1,113 @@ +//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/ReplaceWithVeclib.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +/// NOTE: Assertions must be enabled for these tests to run. +#ifndef NDEBUG + +namespace { + +static std::unique_ptr parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + std::unique_ptr Mod = parseAssemblyString(IR, Err, C); + if (!Mod) + Err.print("ReplaceWithVecLibTest", errs()); + return Mod; +} + +/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This +/// allows checking that the pass won't crash when the function to replace (from +/// the input IR) does not match the replacement function (derived from the +/// VecDesc mapping). +/// +class ReplaceWithVecLibTest : public ::testing::Test { + + std::string getLastLine(std::string Out) { + // remove any trailing '\n' + if (!Out.empty() && *(Out.cend() - 1) == '\n') + Out.pop_back(); + + size_t LastNL = Out.find_last_of('\n'); + return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1); + } + +protected: + LLVMContext Ctx; + + /// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib + /// pass. The pass should not crash even when the replacement function + /// (derived from the \p VD mapping) does not match the function to be + /// replaced (from the input \p IR). + /// + /// \returns the last line of the standard error to be compared for + /// correctness. + std::string run(const VecDesc &VD, const char *IR) { + // Create TLII and register it with FAM so it's preserved when + // ReplaceWithVeclib pass runs. + TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple()); + TLII.addVectorizableFunctions({VD}); + FunctionAnalysisManager FAM; + FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); }); + + // Register and run the pass on the 'foo' function from the input IR. + FunctionPassManager FPM; + FPM.addPass(ReplaceWithVeclib()); + std::unique_ptr M = parseIR(Ctx, IR); + PassBuilder PB; + PB.registerFunctionAnalyses(FAM); + + // Enable debugging and capture std error + llvm::DebugFlag = true; + testing::internal::CaptureStderr(); + FPM.run(*M->getFunction("foo"), FAM); + return getLastLine(testing::internal::GetCapturedStderr()); + } +}; + +} // end anonymous namespace + +static const char *IR = R"IR( +define @foo( %in){ + %call = call @llvm.powi.f32.i32( %in, i32 3) + ret %call +} + +declare @llvm.powi.f32.i32(, i32) #0 +)IR"; + +// The VFABI prefix in TLI describes signature which is matching the powi +// intrinsic declaration. +TEST_F(ReplaceWithVecLibTest, TestValidMapping) { + VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi", + ElementCount::getScalable(4), /*Masked*/ true, + "_ZGVsMxvu"}; + EXPECT_EQ(run(CorrectVD, IR), + "Instructions replaced with vector libraries: 1"); +} + +// The VFABI prefix in TLI describes signature which is not matching the powi +// intrinsic declaration. +TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) { + VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi", + ElementCount::getScalable(4), /*Masked*/ true, + "_ZGVsMxvv"}; + EXPECT_EQ(run(IncorrectVD, IR), + "replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong " + "type at index 1: i32"); +} +#endif \ No newline at end of file