diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 901769c54b6ef..afdc8e3698b2d 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -71,6 +71,7 @@ FunctionPass *createAArch64PostSelectOptimize();
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
 ModulePass *createAArch64GlobalsTaggingPass();
+FunctionPass *createAArch64DotProdMatcherPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
 void initializeAArch64A57FPLoadBalancingPass(PassRegistry&);
@@ -108,6 +109,7 @@ void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
 void initializeSMEABIPass(PassRegistry &);
 void initializeSVEIntrinsicOptsPass(PassRegistry &);
+void initializeAArch64DotProdMatcherPass(PassRegistry &);
 } // end namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
new file mode 100644
index 0000000000000..44215efee75c3
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
@@ -0,0 +1,486 @@
+//===- AArch64DotProdMatcher - Matches instruction sequences to *DOT ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass recognizes and transforms IR to make use of two relatively simple
+// cases that can be implemented by the SDOT and UDOT instructions on AArch64
+// in order to increase vector unit bandwidth.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "Utils/AArch64BaseInfo.h"
+#include <cassert>
+#include <deque>
+#include <optional>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+#define DEBUG_TYPE "aarch64-dot-product-matcher"
+
+#define DOT_ACCUMULATOR_DEPTH (4)
+
+STATISTIC(NumDOTInstrs, "Number of DOT Instructions generated.");
+STATISTIC(NumSimpleDOTReplacements, "Number of simple dot patterns replaced.");
+STATISTIC(NumLoopDOTReplacements, "Number of loop dot patterns replaced.");
+
+struct LoopAccumulate {
+  Value *RVal;
+  PHINode *Phi;
+  Value *IterVals;
+  Value *Predicate;
+  Value *Mul;
+  Value *ValA;
+  Value *ValB;
+  VectorType *VTy;
+  Type *AccTy;
+  BasicBlock *LoopBlock;
+  BasicBlock *PHBlock;
+  bool IsSExt;
+
+  LoopAccumulate(Value *RVal, PHINode *Phi, Value *IterVals, Value *Predicate,
+                 Value *Mul, Value *ValA, Value *ValB, VectorType *VTy,
+                 Type *AccTy, BasicBlock *LoopBlock, BasicBlock *PHBlock,
+                 bool IsSExt)
+      : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
+        Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy),
+        LoopBlock(LoopBlock), PHBlock(PHBlock), IsSExt(IsSExt) {}
+};
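+
+// For illustration, the first (single-block) case corresponds to IR of
+// roughly the following shape -- a hand-written sketch mirroring the
+// sve_udot_i8_to_i32 test added below, not an exhaustive description of what
+// is accepted:
+//
+//   %exta = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+//   %extb = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+//   %mul  = mul <vscale x 16 x i32> %exta, %extb
+//   %acc  = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %mul)
+//
+// The second case is the same multiply-of-extends pattern feeding a vector
+// accumulator PHI inside a vectorized loop; both are handled below.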
+
+// Returns true if the instruction in question is a scalable-vector integer
+// add reduction intrinsic.
+static bool isScalableIntegerSumReduction(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  return II &&
+         II->getIntrinsicID() == Intrinsic::vector_reduce_add &&
+         isa<ScalableVectorType>(II->getOperand(0)->getType());
+}
+
+// Returns a vector type for a dot product accumulator if the element type and
+// extended element type are suitable, or nullptr if not.
+static Type *getAccumulatorType(Type *EltTy, Type *ExtEltTy, ElementCount EC) {
+  Type *AccEltTy = nullptr;
+  if (EltTy->isIntegerTy(8) && ExtEltTy->getPrimitiveSizeInBits() <= 32)
+    AccEltTy = Type::getInt32Ty(EltTy->getContext());
+  else if (EltTy->isIntegerTy(16) && ExtEltTy->getPrimitiveSizeInBits() <= 64)
+    AccEltTy = Type::getInt64Ty(EltTy->getContext());
+
+  if (AccEltTy)
+    return VectorType::get(AccEltTy, EC);
+
+  return nullptr;
+}
+
+// Returns either a pair of basic block pointers corresponding to the expected
+// two incoming values for the phi, or std::nullopt if one of the checks
+// failed.
+static std::optional<std::pair<BasicBlock *, BasicBlock *>>
+getPHIIncomingBlocks(PHINode *Phi) {
+  // Check the PHI; we only expect one incoming value from within the loop
+  // and one incoming value from a preheader.
+  if (Phi->getNumIncomingValues() != 2)
+    return std::nullopt;
+
+  BasicBlock *PHBlock = Phi->getIncomingBlock(0);
+  BasicBlock *LoopBlock = Phi->getIncomingBlock(1);
+  // If this isn't a loop, or if it's a loop with multiple blocks, we bail
+  // out for now. If needed we can improve this pass later.
+  if (Phi->getParent() != LoopBlock && Phi->getParent() != PHBlock)
+    return std::nullopt;
+
+  // Make sure we know which incoming value belongs to the loop.
+  if (PHBlock == Phi->getParent())
+    std::swap(LoopBlock, PHBlock);
+
+  // If the incoming value from the preheader isn't a zero constant, bail out
+  // for now. We may be able to do better in future.
+  Constant *Const = dyn_cast<Constant>(Phi->getIncomingValueForBlock(PHBlock));
+  if (LoopBlock != Phi->getParent() || !Const || !Const->isNullValue())
+    return std::nullopt;
+
+  return std::make_pair(LoopBlock, PHBlock);
+}
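+
+// The in-loop accumulation pattern that checkLoopAcc() below looks for is,
+// roughly (a sketch based on the sve_sdot_loop_i16_to_i32 test; the select on
+// the loop predicate is optional):
+//
+//   %vec.phi = phi <vscale x 8 x i32> [ zeroinitializer, %ph ], [ %sum, %loop ]
+//   %ea  = sext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+//   %eb  = sext <vscale x 8 x i16> %vb to <vscale x 8 x i32>
+//   %mul = mul nsw <vscale x 8 x i32> %eb, %ea
+//   %sel = select <vscale x 8 x i1> %pred, <vscale x 8 x i32> %mul,
+//                 <vscale x 8 x i32> zeroinitializer
+//   %sum = add nsw <vscale x 8 x i32> %vec.phi, %sel
+//   ...
+//   %red = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %sum)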
+
+static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
+                         SmallVectorImpl<LoopAccumulate> &Accumulators) {
+  // Check a possible loop accumulator.
+  bool IsSExt = false;
+
+  // We only expect the add in the loop to be used by the reduction and by
+  // the PHI node.
+  if (!RVal->hasNUses(2) || !is_contained(OldPHI->incoming_values(), RVal)) {
+    LLVM_DEBUG(dbgs() << "Loop sum operation doesn't have exactly two uses or "
+                         "isn't used by the accumulating PHI node.\n");
+    return false;
+  }
+
+  // Look through selects with zeroinitializer. Record the predicate so
+  // we can insert selects for the base values later.
+  Value *Predicate = nullptr, *Mul = nullptr;
+  if (!match(IterVals, m_Select(m_Value(Predicate), m_Value(Mul), m_Zero())))
+    Mul = IterVals;
+
+  Value *ValA = nullptr, *ValB = nullptr;
+  // Match the core pattern of element-wise multiplication of extended values.
+  if (match(Mul, m_OneUse(m_Mul(m_SExt(m_OneUse(m_Value(ValA))),
+                                m_SExt(m_OneUse(m_Value(ValB)))))))
+    IsSExt = true;
+  else if (!match(Mul, m_OneUse(m_Mul(m_ZExt(m_OneUse(m_Value(ValA))),
+                                      m_ZExt(m_OneUse(m_Value(ValB))))))) {
+    LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: "
+                      << *Mul << "\n");
+    return false;
+  }
+
+  // The same extended value could be used for both operands of the multiply,
+  // so we only need to check that the extends each have a single user.
+  Instruction *I = cast<Instruction>(Mul);
+  if (!I->getOperand(0)->hasOneUser() || !I->getOperand(1)->hasOneUser())
+    return false;
+
+  // Check that the vector type is one packed vector's worth of data.
+  // TODO: Do we want to allow multiples?
+  VectorType *ValTy = cast<VectorType>(ValA->getType());
+  if (ValTy->getPrimitiveSizeInBits().getKnownMinValue() !=
+      AArch64::SVEBitsPerBlock) {
+    LLVM_DEBUG(dbgs() << "Vector base size is not a packed representation.\n");
+    return false;
+  }
+
+  // Find the accumulator element type after extension and check that it isn't
+  // too large; if it is, we might lose data by converting to dot instructions.
+  // The element count needs to be 1/4th that of the input data, since the
+  // dot product instructions take four smaller elements and
+  // multiply/accumulate them into one larger element.
+  Type *AccTy =
+      getAccumulatorType(ValTy->getElementType(),
+                         Mul->getType()->getScalarType(),
+                         ValTy->getElementCount().divideCoefficientBy(4));
+
+  if (!AccTy) {
+    LLVM_DEBUG(dbgs() << "Accumulator element type too wide.\n");
+    return false;
+  }
+
+  // Validate the phi node and retrieve the incoming basic blocks for the
+  // accumulating loop itself and the preheader.
+  auto PhiBlocks = getPHIIncomingBlocks(OldPHI);
+
+  if (!PhiBlocks) {
+    LLVM_DEBUG(dbgs() << "Unable to match PHI node\n");
+    return false;
+  }
+
+  // Everything looks in order, so add it to the list of accumulators to
+  // transform.
+  Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA,
+                            ValB, ValTy, AccTy, PhiBlocks->first,
+                            PhiBlocks->second, IsSExt);
+  return true;
+}
+
+static bool
+findDOTAccumulatorsInLoop(Value *RVal,
+                          SmallVectorImpl<LoopAccumulate> &Accumulators,
+                          unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
+  // Don't recurse too far.
+  if (Depth == 0)
+    return false;
+
+  Value *V1 = nullptr, *V2 = nullptr;
+
+  // Try to match the expected pattern from a sum reduction in
+  // a vectorized loop.
+  if (match(RVal, m_Add(m_Value(V1), m_Value(V2)))) {
+    if (isa<PHINode>(V1) && !isa<PHINode>(V2) &&
+        V1->hasOneUse() && V2->hasOneUse())
+      return checkLoopAcc(RVal, cast<PHINode>(V1), V2, Accumulators);
+
+    if (!isa<PHINode>(V1) && isa<PHINode>(V2) &&
+        V1->hasOneUse() && V2->hasOneUse())
+      return checkLoopAcc(RVal, cast<PHINode>(V2), V1, Accumulators);
+
+    // Otherwise assume this is an intermediate multi-register reduction
+    // and recurse to the operands.
+    return findDOTAccumulatorsInLoop(V1, Accumulators, Depth - 1) &&
+           findDOTAccumulatorsInLoop(V2, Accumulators, Depth - 1);
+  }
+
+  return false;
+}
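+
+// When the vectorizer has interleaved the loop, the reduction operand is not
+// the accumulating add itself but a tree of partial-sum adds, e.g. (a sketch
+// based on the interleaved tests below):
+//
+//   %bin.rdx = add <vscale x 8 x i32> %sum1, %sum0
+//   %red     = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %bin.rdx)
+//
+// findDOTAccumulatorsInLoop() recurses through such adds (up to
+// DOT_ACCUMULATOR_DEPTH levels) until it reaches the PHI-plus-multiply
+// accumulators, which are then checked individually by checkLoopAcc().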
+
+namespace {
+
+class AArch64DotProdMatcher : public FunctionPass {
+public:
+  static char ID;
+  AArch64DotProdMatcher() : FunctionPass(ID) {
+    initializeAArch64DotProdMatcherPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+    bool Changed = false;
+    SmallVector<Instruction *> Reductions;
+    for (BasicBlock &Block : F)
+      // TODO: Support non-scalable dot instructions too.
+      for (Instruction &I :
+           make_filter_range(Block, isScalableIntegerSumReduction))
+        Reductions.push_back(&I);
+
+    for (auto *Rdx : Reductions)
+      Changed |= trySimpleDotReplacement(*Rdx) || tryLoopDotReplacement(*Rdx);
+
+    return Changed;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+
+  TargetTransformInfo *TTI;
+
+private:
+  bool trySimpleDotReplacement(Instruction &I);
+  bool tryLoopDotReplacement(Instruction &I);
+};
+
+} // end anonymous namespace
+
+char AArch64DotProdMatcher::ID = 0;
+INITIALIZE_PASS_BEGIN(AArch64DotProdMatcher, DEBUG_TYPE,
+                      "AArch64 Dot Product Instruction Matcher", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(AArch64DotProdMatcher, DEBUG_TYPE,
+                    "AArch64 Dot Product Instruction Matcher", false, false)
+
+FunctionPass *llvm::createAArch64DotProdMatcherPass() {
+  return new AArch64DotProdMatcher();
+}
+
+// The following method looks for a simple pattern of two values that are
+// either sign or zero extended, multiplied together, then summed. If the
+// types match the ones used by the [s|u]dot instructions (groups of 4x8 -> 32,
+// groups of 4x16 -> 64) then we can replace the extends and multiply with a
+// dot instruction and swap the reduce for one using fewer elements.
+//
+//  +-----------+     +-----------+
+//  |   ValA    |     |   ValB    |
+//  +-----+-----+     +-----+-----+
+//        |                 |
+//        |                 |
+//  +-----v-----+     +-----v-----+
+//  | [S|Z]Ext  |     | [S|Z]Ext  |
+//  +-----+-----+     +-----+-----+
+//        |                 |
+//        +--+           +--+
+//           |           |
+//          +v-----------v+
+//          |     Mul     |
+//          +------+------+
+//                 |
+//                 |
+//          +------v------+
+//          |  Reduce(+)  |
+//          +-------------+
+bool AArch64DotProdMatcher::trySimpleDotReplacement(Instruction &I) {
+  LLVM_DEBUG(dbgs() << "Looking for simple dot reduction: " << I << "\n");
+  Value *RVal = I.getOperand(0);
+  Value *ValA = nullptr, *ValB = nullptr;
+  bool IsSExt = false;
+
+  if (match(RVal, m_Mul(m_SExt(m_Value(ValA)), m_SExt(m_Value(ValB)))))
+    IsSExt = true;
+  else if (!match(RVal, m_Mul(m_ZExt(m_Value(ValA)), m_ZExt(m_Value(ValB))))) {
+    LLVM_DEBUG(dbgs() << "Unable to match simple dot pattern\n");
+    return false;
+  }
+
+  VectorType *ATy = cast<VectorType>(ValA->getType());
+  VectorType *BTy = cast<VectorType>(ValB->getType());
+  VectorType *MTy = cast<VectorType>(RVal->getType());
+  if (ATy != BTy || !((ATy->getScalarType()->isIntegerTy(8) &&
+                       MTy->getScalarType()->isIntegerTy(32)) ||
+                      (ATy->getScalarType()->isIntegerTy(16) &&
+                       MTy->getScalarType()->isIntegerTy(64)))) {
+    LLVM_DEBUG(dbgs() << "Unable to match types for simple dot pattern\n");
+    return false;
+  }
+
+  if (TTI->getRegisterBitWidth(TargetTransformInfo::RGK_ScalableVector) !=
+      ATy->getPrimitiveSizeInBits())
+    return false;
+
+  // All conditions met, proceed with replacement.
+  IRBuilder<> Builder(cast<Instruction>(RVal));
+
+  // Need a new accumulator type.
+  Type *AccTy = VectorType::get(MTy->getScalarType(),
+                                MTy->getElementCount().divideCoefficientBy(4));
+  Value *Zeroes = ConstantAggregateZero::get(AccTy);
+
+  Intrinsic::ID IntID = IsSExt ? Intrinsic::aarch64_sve_sdot :
+                                 Intrinsic::aarch64_sve_udot;
+  Value *DotProd = Builder.CreateIntrinsic(IntID, {AccTy},
+                                           {Zeroes, ValA, ValB});
+  Builder.SetInsertPoint(&I);
+  Value *Reduce = Builder.CreateAddReduce(DotProd);
+  I.replaceAllUsesWith(Reduce);
+  NumDOTInstrs++;
+  NumSimpleDOTReplacements++;
+  return true;
+}
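+
+// After this rewrite, a zext-zext-mul-reduce sequence over
+// <vscale x 16 x i8> inputs becomes roughly the following (matching the
+// CHECK lines of the sve_udot_i8_to_i32 test below):
+//
+//   %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(
+//                  <vscale x 4 x i32> zeroinitializer,
+//                  <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+//   %acc = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %dot)
+//
+// The dead extends and multiply are left behind for later cleanup (the tests
+// run -instcombine after this pass).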
+
+// This method looks for the following pattern: starting from a sum reduction,
+// it expects to find a vector add operation inside a loop, with one of the
+// operands being a PHI. The other operand can be either a select between
+// zeroes and a multiply, or just the multiply directly. The rest of the
+// pattern is the same as the simpler case -- a multiply of extends of some
+// values.
+//
+// Replacing this is a little tricky, since we need to replace the PHI node
+// and accumulator as well, and potentially add in new selects earlier, but if
+// everything checks out then the extend -> multiply -> inner loop add
+// operation is replaced by the [s|u]dot instruction.
+//
+//        +-----------+
+//        |   Zero    |
+//        +-----+-----+
+//              |
+//   +-------+  |
+//   |       |  |
+//   |    +--v--v-----+      +-----------+     +-----------+
+//   |    |  OldPHI   |      |   ValA    |     |   ValB    |
+//   |    +-----+-----+      +-----+-----+     +-----+-----+
+//   |          |                  |                 |
+//   |          |            +-----v-----+     +-----v-----+
+//   |          |            | [S|Z]Ext  |     | [S|Z]Ext  |
+//   |          |            +-----+-----+     +-----+-----+
+//   |          |                  |                 |
+//   |          |                  +--+           +--+
+//   |          |                     |           |
+//   |          |                    +v-----------v+
+//   |          |                    |     Mul     |
+//   |          |                    +------+------+
+//   |          |                           |
+//   |          |                    +------v------+
+//   |          |                    |   Select    |
+//   |          |                    +------+------+
+//   |          |                           |
+//   |       +--v---------------------------v--+
+//   |       |              Add                |
+//   |       +--+--------------+---------------+
+//   |          |              |
+//   +----------+              |
+//                       +-----v-----+
+//                       | Reduce(+) |
+//                       +-----------+
+bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
+  LLVM_DEBUG(dbgs() << "Looking for Loop DOT Reduction: " << I << "\n");
+  Value *RVal = I.getOperand(0);
+  SmallVector<LoopAccumulate> Accumulators;
+  std::deque<Value *> RdxVals;
+  IRBuilder<> Builder(&I);
+
+  // If the loop was interleaved, we may have some intermediate add
+  // instructions first before we get to the accumulators inside the
+  // loop. Gather those first, then process them.
+  if (!findDOTAccumulatorsInLoop(RVal, Accumulators)) {
+    LLVM_DEBUG(dbgs() << "Couldn't find DOT accumulators in the loop\n");
+    return false;
+  }
+
+  // All conditions met, proceed with replacement.
+  for (auto &Acc : Accumulators) {
+    Builder.SetInsertPoint(Acc.Phi);
+
+    // Plant new PHI node.
+    PHINode *DotAcc = Builder.CreatePHI(Acc.AccTy, 2, "dot.accumulate");
+    Value *Zeroes = ConstantAggregateZero::get(Acc.AccTy);
+    DotAcc->addIncoming(Zeroes, Acc.PHBlock);
+
+    // Move to the dot insertion point.
+    Builder.SetInsertPoint(cast<Instruction>(Acc.RVal));
+
+    // Generate selects for ValA and ValB if there was a select before the
+    // accumulate. Hopefully we can fold away some of the extra selects
+    // (e.g. if the data originally came from masked loads with the same
+    // predicate).
+    if (Acc.Predicate) {
+      Value *Zeroes = ConstantAggregateZero::get(Acc.VTy);
+      Acc.ValA = Builder.CreateSelect(Acc.Predicate, Acc.ValA, Zeroes);
+      Acc.ValB = Builder.CreateSelect(Acc.Predicate, Acc.ValB, Zeroes);
+    }
+
+    // Now plant the dot instruction.
+    Intrinsic::ID IntID = Acc.IsSExt ? Intrinsic::aarch64_sve_sdot :
+                                       Intrinsic::aarch64_sve_udot;
+    Value *DotProd = Builder.CreateIntrinsic(IntID, {Acc.AccTy},
+                                             {DotAcc, Acc.ValA, Acc.ValB});
+    DotAcc->addIncoming(DotProd, Acc.LoopBlock);
+
+    RdxVals.push_back(DotProd);
+
+    NumDOTInstrs++;
+  }
+
+  assert(!RdxVals.empty() &&
+         "We found accumulators but generated no RdxVals");
+
+  Builder.SetInsertPoint(cast<Instruction>(RVal));
+
+  while (RdxVals.size() > 1) {
+    RdxVals.push_back(Builder.CreateAdd(RdxVals[0], RdxVals[1]));
+    // Drop the two RdxVals we just reduced. Sadly, there's no SmallDeque
+    // with a pop_front_val() convenience method yet.
+    RdxVals.pop_front();
+    RdxVals.pop_front();
+  }
+
+  // Plant the new reduction.
+  Builder.SetInsertPoint(&I);
+  Value *Reduce = Builder.CreateAddReduce(RdxVals.front());
+  Value *Trunc = Builder.CreateTrunc(Reduce, I.getType(), "dot.trunc");
+  I.replaceAllUsesWith(Trunc);
+
+  // Delete the original reduction, since it's no longer required.
+  RecursivelyDeleteTriviallyDeadInstructions(&I);
+  NumLoopDOTReplacements++;
+  return true;
+}
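+
+// For reference, after tryLoopDotReplacement() the loop body of the
+// sve_sdot_loop_i16_to_i32 test below reduces to roughly the following
+// (the selects planted on the operands fold into the masked loads once
+// -instcombine has run):
+//
+//   %dot.accumulate = phi <vscale x 2 x i64> [ zeroinitializer, %ph ],
+//                                            [ %dot, %vector.body ]
+//   %dot = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(
+//                  <vscale x 2 x i64> %dot.accumulate,
+//                  <vscale x 8 x i16> %loadb, <vscale x 8 x i16> %loada)
+//   ...
+//   %red   = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> %dot)
+//   %trunc = trunc i64 %red to i32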
+
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 3d818c76bd4b7..4a76d2f705a5a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -165,6 +165,11 @@ static cl::opt<bool>
                                        cl::desc("Enable SVE intrinsic opts"),
                                        cl::init(true));
 
+static cl::opt<bool>
+EnableAArch64DotProdMatch("aarch64-enable-dotprodmatch", cl::Hidden,
+                          cl::desc("Enable matching dot product instructions"),
+                          cl::init(true));
+
 static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
                                          cl::init(true), cl::Hidden);
 
@@ -246,6 +251,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
   initializeAArch64LowerHomogeneousPrologEpilogPass(*PR);
   initializeAArch64DAGToDAGISelPass(*PR);
   initializeAArch64GlobalsTaggingPass(*PR);
+  initializeAArch64DotProdMatcherPass(*PR);
 }
 
 //===----------------------------------------------------------------------===//
@@ -553,6 +559,11 @@ void AArch64PassConfig::addIRPasses() {
   // ourselves.
   addPass(createAtomicExpandPass());
 
+  // Make use of SVE intrinsics in place of common vector operations that span
+  // multiple basic blocks.
+  if (TM->getOptLevel() != CodeGenOptLevel::None && EnableAArch64DotProdMatch)
+    addPass(createAArch64DotProdMatcherPass());
+
   // Expand any SVE vector library calls that we can't code generate directly.
   if (EnableSVEIntrinsicOpts &&
       TM->getOptLevel() == CodeGenOptLevel::Aggressive)
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index d97342b0829d8..b89ce94b93122 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -50,6 +50,7 @@ add_llvm_target(AArch64CodeGen
   AArch64CondBrTuning.cpp
   AArch64ConditionalCompares.cpp
   AArch64DeadRegisterDefinitionsPass.cpp
+  AArch64DotProdMatcher.cpp
   AArch64ExpandImm.cpp
   AArch64ExpandPseudoInsts.cpp
   AArch64FalkorHWPFFix.cpp
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index f5c1c3c291cb5..7d196b8579d20 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -22,6 +22,7 @@
 ; CHECK-NEXT:       Expand large div/rem
 ; CHECK-NEXT:       Expand large fp convert
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       AArch64 Dot Product Instruction Matcher
 ; CHECK-NEXT:       SVE intrinsics optimizations
 ; CHECK-NEXT:       FunctionPass Manager
 ; CHECK-NEXT:         Dominator Tree Construction
diff --git a/llvm/test/CodeGen/AArch64/dotprodmatch.ll b/llvm/test/CodeGen/AArch64/dotprodmatch.ll
new file mode 100644
index 0000000000000..a75048351b810
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/dotprodmatch.ll
@@ -0,0 +1,684 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -S -aarch64-dot-product-matcher -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define i16 @sve_sdot_loop_i16_to_i32(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define i16 @sve_sdot_loop_i16_to_i32
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+;
CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: min.iters.checked: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[PREDICATE_ENTRY:%.*]] = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[PREDICATE:%.*]] = phi [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call @llvm.masked.load.nxv8i16.p0(ptr [[TMP0]], i32 2, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD19:%.*]] = call @llvm.masked.load.nxv8i16.p0(ptr [[TMP1]], i32 2, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP2]] = call @llvm.aarch64.sve.sdot.nxv2i64( [[DOT_ACCUMULATE]], [[WIDE_MASKED_LOAD19]], [[WIDE_MASKED_LOAD]]) +; CHECK-NEXT: [[VS:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[VS_SCALED:%.*]] = shl i64 [[VS]], 3 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]] +; CHECK-NEXT: [[PREDICATE_NEXT]] = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP3:%.*]] = extractelement [[PREDICATE_NEXT]], i64 0 +; CHECK-NEXT: br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64( [[TMP2]]) +; CHECK-NEXT: [[PHITMP201:%.*]] = lshr i64 [[TMP4]], 16 +; CHECK-NEXT: [[PHITMP:%.*]] = trunc i64 [[PHITMP201]] to i16 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: ret i16 [[ACC_0_LCSSA]] +; +entry: + %cmp11 = icmp sgt i32 %N, 0 + br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup + +min.iters.checked: ; preds = %entry + %wide.trip.count = zext i32 %N to i64 + %wide.end.idx.splatinsert = insertelement undef, i64 %wide.trip.count, i32 0 + %wide.end.idx.splat = shufflevector %wide.end.idx.splatinsert, undef, zeroinitializer + %predicate.entry = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 %wide.trip.count) + br label %vector.body + +vector.body: ; preds = %vector.body, %min.iters.checked + %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ] + %predicate = phi [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ] + %0 = getelementptr inbounds i16, ptr %a, i64 %index + %wide.masked.load = call @llvm.masked.load.nxv8i16.p0(ptr %0, i32 2, %predicate, undef) + %1 = sext %wide.masked.load to + %2 = getelementptr inbounds i16, ptr %b, i64 %index + %wide.masked.load19 = call @llvm.masked.load.nxv8i16.p0(ptr %2, i32 2, %predicate, undef) + %3 = sext %wide.masked.load19 to + %4 = mul nsw %3, %1 + %5 = select %predicate, %4, zeroinitializer + %6 = add nsw %vec.phi, %5 + %vs = call i64 @llvm.vscale.i64() + %vs.scaled = mul i64 %vs, 8 + 
%index.next = add nuw i64 %index, %vs.scaled + %.splatinsert = insertelement undef, i64 %index.next, i32 0 + %.splat = shufflevector %.splatinsert, undef, zeroinitializer + %predicate.next = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count) + %7 = extractelement %predicate.next, i64 0 + br i1 %7, label %vector.body, label %middle.block + +middle.block: ; preds = %vector.body + %8 = call i32 @llvm.vector.reduce.add.nxv8i32( %6) + %phitmp20 = lshr i32 %8, 16 + %phitmp = trunc i32 %phitmp20 to i16 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %middle.block, %entry + %acc.0.lcssa = phi i16 [ 0, %entry ], [ %phitmp, %middle.block ] + ret i16 %acc.0.lcssa +} + +define dso_local i16 @sve_sdot_loop_i16_to_i32_interleavedx2_scalartail(ptr readonly %a, ptr readonly %b, i32 %N) #0 { +; CHECK-LABEL: define dso_local i16 @sve_sdot_loop_i16_to_i32_interleavedx2_scalartail +; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP9:%.*]] = icmp sgt i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP9]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: for.body.preheader: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 4 +; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ugt i64 [[TMP1]], [[WIDE_TRIP_COUNT]] +; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER17:%.*]], label [[VECTOR_PH:%.*]] +; CHECK: vector.ph: +; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP1]] +; CHECK-NEXT: [[N_VEC:%.*]] = sub nuw nsw i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]] +; CHECK-NEXT: [[TMP2:%.*]] = tail call i32 @llvm.vscale.i32() +; CHECK-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 3 +; CHECK-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64 +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE1:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load , ptr [[TMP5]], align 2 +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i64 [[TMP4]] +; CHECK-NEXT: [[WIDE_LOAD14:%.*]] = load , ptr [[TMP6]], align 2 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_LOAD15:%.*]] = load , ptr [[TMP7]], align 2 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i64 [[TMP4]] +; CHECK-NEXT: [[WIDE_LOAD16:%.*]] = load , ptr [[TMP8]], align 2 +; CHECK-NEXT: [[TMP9]] = call @llvm.aarch64.sve.sdot.nxv2i64( [[DOT_ACCUMULATE1]], [[WIDE_LOAD15]], [[WIDE_LOAD]]) +; CHECK-NEXT: [[TMP10]] = call @llvm.aarch64.sve.sdot.nxv2i64( [[DOT_ACCUMULATE]], [[WIDE_LOAD16]], [[WIDE_LOAD14]]) +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP1]] +; CHECK-NEXT: [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP12:%.*]] = add [[TMP10]], [[TMP9]] +; CHECK-NEXT: [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64( [[TMP12]]) +; CHECK-NEXT: 
[[DOT_TRUNC:%.*]] = trunc i64 [[TMP13]] to i32 +; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0 +; CHECK-NEXT: [[EXTRACT4:%.*]] = lshr i64 [[TMP13]], 16 +; CHECK-NEXT: [[EXTRACT_T:%.*]] = trunc i64 [[EXTRACT4]] to i16 +; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER17]] +; CHECK: for.body.preheader17: +; CHECK-NEXT: [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: [[ACC_010_PH:%.*]] = phi i32 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[DOT_TRUNC]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.cond.cleanup.loopexit: +; CHECK-NEXT: [[ADD_LCSSA_OFF16:%.*]] = phi i16 [ [[EXTRACT_T]], [[MIDDLE_BLOCK]] ], [ [[EXTRACT_T3:%.*]], [[FOR_BODY]] ] +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[ADD_LCSSA_OFF16]], [[FOR_COND_CLEANUP_LOOPEXIT]] ] +; CHECK-NEXT: ret i16 [[ACC_0_LCSSA]] +; CHECK: for.body: +; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER17]] ] +; CHECK-NEXT: [[ACC_010:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER17]] ] +; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDVARS_IV]] +; CHECK-NEXT: [[TMP14:%.*]] = load i16, ptr [[ARRAYIDX]], align 2 +; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[TMP14]] to i32 +; CHECK-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDVARS_IV]] +; CHECK-NEXT: [[TMP15:%.*]] = load i16, ptr [[ARRAYIDX2]], align 2 +; CHECK-NEXT: [[CONV3:%.*]] = sext i16 [[TMP15]] to i32 +; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[CONV3]], [[CONV]] +; CHECK-NEXT: [[ADD]] = add nsw i32 [[MUL]], [[ACC_010]] +; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1 +; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT]] +; CHECK-NEXT: [[EXTRACT2:%.*]] = lshr i32 [[ADD]], 16 +; CHECK-NEXT: [[EXTRACT_T3]] = trunc i32 [[EXTRACT2]] to i16 +; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]] +; +entry: + %cmp9 = icmp sgt i32 %N, 0 + br i1 %cmp9, label %for.body.preheader, label %for.cond.cleanup + +for.body.preheader: ; preds = %entry + %wide.trip.count = zext i32 %N to i64 + %0 = tail call i64 @llvm.vscale.i64() + %1 = shl nuw nsw i64 %0, 4 + %min.iters.check = icmp ugt i64 %1, %wide.trip.count + br i1 %min.iters.check, label %for.body.preheader17, label %vector.ph + +vector.ph: ; preds = %for.body.preheader + %n.mod.vf = urem i64 %wide.trip.count, %1 + %n.vec = sub nuw nsw i64 %wide.trip.count, %n.mod.vf + %2 = tail call i32 @llvm.vscale.i32() + %3 = shl nuw nsw i32 %2, 3 + %4 = zext i32 %3 to i64 + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %vector.ph ], [ %15, %vector.body ] + %vec.phi13 = phi [ zeroinitializer, %vector.ph ], [ %16, %vector.body ] + %5 = getelementptr inbounds i16, ptr %a, i64 %index + %wide.load = load , ptr %5, align 2 + %6 = getelementptr inbounds i16, ptr %5, i64 %4 + %wide.load14 = load , ptr %6, align 2 + %7 = sext %wide.load to + %8 = sext %wide.load14 to + %9 = getelementptr inbounds i16, ptr %b, i64 %index + %wide.load15 = load , ptr %9, align 2 + %10 = getelementptr inbounds i16, ptr %9, i64 %4 + %wide.load16 = load , ptr %10, align 2 + %11 
= sext %wide.load15 to + %12 = sext %wide.load16 to + %13 = mul nsw %11, %7 + %14 = mul nsw %12, %8 + %15 = add %13, %vec.phi + %16 = add %14, %vec.phi13 + %index.next = add nuw i64 %index, %1 + %17 = icmp eq i64 %index.next, %n.vec + br i1 %17, label %middle.block, label %vector.body + +middle.block: ; preds = %vector.body + %bin.rdx = add %16, %15 + %18 = tail call i32 @llvm.vector.reduce.add.nxv8i32( %bin.rdx) + %cmp.n = icmp eq i64 %n.mod.vf, 0 + br i1 %cmp.n, label %for.cond.cleanup.loopexit, label %for.body.preheader17 + +for.body.preheader17: ; preds = %for.body.preheader, %middle.block + %indvars.iv.ph = phi i64 [ 0, %for.body.preheader ], [ %n.vec, %middle.block ] + %acc.010.ph = phi i32 [ 0, %for.body.preheader ], [ %18, %middle.block ] + br label %for.body + +for.cond.cleanup.loopexit: ; preds = %for.body, %middle.block + %add.lcssa = phi i32 [ %18, %middle.block ], [ %add, %for.body ] + %19 = lshr i32 %add.lcssa, 16 + %20 = trunc i32 %19 to i16 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry + %acc.0.lcssa = phi i16 [ 0, %entry ], [ %20, %for.cond.cleanup.loopexit ] + ret i16 %acc.0.lcssa + +for.body: ; preds = %for.body.preheader17, %for.body + %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ %indvars.iv.ph, %for.body.preheader17 ] + %acc.010 = phi i32 [ %add, %for.body ], [ %acc.010.ph, %for.body.preheader17 ] + %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv + %21 = load i16, ptr %arrayidx, align 2 + %conv = sext i16 %21 to i32 + %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv + %22 = load i16, ptr %arrayidx2, align 2 + %conv3 = sext i16 %22 to i32 + %mul = mul nsw i32 %conv3, %conv + %add = add nsw i32 %mul, %acc.010 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count + br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body +} + +define i16 @sve_udot_loop_i16_to_i32(ptr readonly %a, ptr readonly %b, i32 %N) #0 { +; CHECK-LABEL: define i16 @sve_udot_loop_i16_to_i32 +; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP11_NOT:%.*]] = icmp eq i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP11_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[MIN_ITERS_CHECKED:%.*]] +; CHECK: min.iters.checked: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[PREDICATE_ENTRY:%.*]] = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[PREDICATE:%.*]] = phi [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call @llvm.masked.load.nxv8i16.p0(ptr [[TMP0]], i32 2, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD19:%.*]] = call @llvm.masked.load.nxv8i16.p0(ptr [[TMP1]], i32 2, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP2]] = call @llvm.aarch64.sve.udot.nxv2i64( [[DOT_ACCUMULATE]], [[WIDE_MASKED_LOAD19]], [[WIDE_MASKED_LOAD]]) +; 
CHECK-NEXT: [[VS:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[VS_SCALED:%.*]] = shl i64 [[VS]], 3 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]] +; CHECK-NEXT: [[PREDICATE_NEXT]] = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP3:%.*]] = extractelement [[PREDICATE_NEXT]], i64 0 +; CHECK-NEXT: br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64( [[TMP2]]) +; CHECK-NEXT: [[PHITMP201:%.*]] = lshr i64 [[TMP4]], 16 +; CHECK-NEXT: [[PHITMP:%.*]] = trunc i64 [[PHITMP201]] to i16 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: ret i16 [[ACC_0_LCSSA]] +; +entry: + %cmp11 = icmp ugt i32 %N, 0 + br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup + +min.iters.checked: ; preds = %entry + %wide.trip.count = zext i32 %N to i64 + %wide.end.idx.splatinsert = insertelement undef, i64 %wide.trip.count, i32 0 + %wide.end.idx.splat = shufflevector %wide.end.idx.splatinsert, undef, zeroinitializer + %predicate.entry = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 %wide.trip.count) + br label %vector.body + +vector.body: ; preds = %vector.body, %min.iters.checked + %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ] + %predicate = phi [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ] + %0 = getelementptr inbounds i16, ptr %a, i64 %index + %wide.masked.load = call @llvm.masked.load.nxv8i16.p0(ptr %0, i32 2, %predicate, undef) + %1 = zext %wide.masked.load to + %2 = getelementptr inbounds i16, ptr %b, i64 %index + %wide.masked.load19 = call @llvm.masked.load.nxv8i16.p0(ptr %2, i32 2, %predicate, undef) + %3 = zext %wide.masked.load19 to + %4 = mul nsw %3, %1 + %5 = select %predicate, %4, zeroinitializer + %6 = add nsw %vec.phi, %5 + %vs = call i64 @llvm.vscale.i64() + %vs.scaled = mul i64 %vs, 8 + %index.next = add nuw i64 %index, %vs.scaled + %.splatinsert = insertelement undef, i64 %index.next, i32 0 + %.splat = shufflevector %.splatinsert, undef, zeroinitializer + %predicate.next = call @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count) + %7 = extractelement %predicate.next, i64 0 + br i1 %7, label %vector.body, label %middle.block + +middle.block: ; preds = %vector.body + %8 = call i32 @llvm.vector.reduce.add.nxv8i32( %6) + %phitmp20 = lshr i32 %8, 16 + %phitmp = trunc i32 %phitmp20 to i16 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %middle.block, %entry + %acc.0.lcssa = phi i16 [ 0, %entry ], [ %phitmp, %middle.block ] + ret i16 %acc.0.lcssa +} + +define dso_local i16 @sve_udot_loop_i16_to_i32_interleavedx4_foldedtail(ptr readonly %a, ptr readonly %b, i32 %N) #0 { +; CHECK-LABEL: define dso_local i16 @sve_udot_loop_i16_to_i32_interleavedx4_foldedtail +; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP9:%.*]] = icmp sgt i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP9]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: for.body.preheader: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[TMP1:%.*]] = 
shl nuw nsw i64 [[TMP0]], 3 +; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 4 +; CHECK-NEXT: [[TMP3:%.*]] = mul nuw nsw i64 [[TMP0]], 24 +; CHECK-NEXT: [[ACTIVE_LANE_MASK_ENTRY:%.*]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_ENTRY16:%.*]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP2]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_ENTRY15:%.*]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP1]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_ENTRY17:%.*]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP3]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.vscale.i32() +; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 3 +; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64 +; CHECK-NEXT: [[TMP7:%.*]] = shl nuw nsw i32 [[TMP4]], 4 +; CHECK-NEXT: [[TMP8:%.*]] = zext i32 [[TMP7]] to i64 +; CHECK-NEXT: [[TMP9:%.*]] = mul nuw nsw i32 [[TMP4]], 24 +; CHECK-NEXT: [[TMP10:%.*]] = zext i32 [[TMP9]] to i64 +; CHECK-NEXT: [[TMP11:%.*]] = shl nuw nsw i64 [[TMP0]], 5 +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi [ [[ACTIVE_LANE_MASK_ENTRY]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[ACTIVE_LANE_MASK18:%.*]] = phi [ [[ACTIVE_LANE_MASK_ENTRY15]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT31:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[ACTIVE_LANE_MASK19:%.*]] = phi [ [[ACTIVE_LANE_MASK_ENTRY16]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT32:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[ACTIVE_LANE_MASK20:%.*]] = phi [ [[ACTIVE_LANE_MASK_ENTRY17]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT33:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE3:%.*]] = phi [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP20:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE2:%.*]] = phi [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP21:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE1:%.*]] = phi [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr [[TMP12]], i32 2, [[ACTIVE_LANE_MASK]], zeroinitializer) +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP6]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD24:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP13]], i32 2, [[ACTIVE_LANE_MASK18]], zeroinitializer) +; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP8]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD25:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP14]], i32 2, [[ACTIVE_LANE_MASK19]], zeroinitializer) +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP10]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD26:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP15]], i32 2, [[ACTIVE_LANE_MASK20]], zeroinitializer) +; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD27:%.*]] = tail call 
@llvm.masked.load.nxv8i16.p0(ptr [[TMP16]], i32 2, [[ACTIVE_LANE_MASK]], zeroinitializer) +; CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP6]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD28:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP17]], i32 2, [[ACTIVE_LANE_MASK18]], zeroinitializer) +; CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP8]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD29:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP18]], i32 2, [[ACTIVE_LANE_MASK19]], zeroinitializer) +; CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP10]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD30:%.*]] = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP19]], i32 2, [[ACTIVE_LANE_MASK20]], zeroinitializer) +; CHECK-NEXT: [[TMP20]] = call @llvm.aarch64.sve.udot.nxv2i64( [[DOT_ACCUMULATE3]], [[WIDE_MASKED_LOAD27]], [[WIDE_MASKED_LOAD]]) +; CHECK-NEXT: [[TMP21]] = call @llvm.aarch64.sve.udot.nxv2i64( [[DOT_ACCUMULATE2]], [[WIDE_MASKED_LOAD28]], [[WIDE_MASKED_LOAD24]]) +; CHECK-NEXT: [[TMP22]] = call @llvm.aarch64.sve.udot.nxv2i64( [[DOT_ACCUMULATE1]], [[WIDE_MASKED_LOAD29]], [[WIDE_MASKED_LOAD25]]) +; CHECK-NEXT: [[TMP23]] = call @llvm.aarch64.sve.udot.nxv2i64( [[DOT_ACCUMULATE]], [[WIDE_MASKED_LOAD30]], [[WIDE_MASKED_LOAD26]]) +; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]] +; CHECK-NEXT: [[TMP24:%.*]] = add i64 [[INDEX_NEXT]], [[TMP1]] +; CHECK-NEXT: [[TMP25:%.*]] = add i64 [[INDEX_NEXT]], [[TMP2]] +; CHECK-NEXT: [[TMP26:%.*]] = add i64 [[INDEX_NEXT]], [[TMP3]] +; CHECK-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_NEXT31]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP24]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_NEXT32]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP25]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[ACTIVE_LANE_MASK_NEXT33]] = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP26]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP27:%.*]] = extractelement [[ACTIVE_LANE_MASK_NEXT]], i64 0 +; CHECK-NEXT: br i1 [[TMP27]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP28:%.*]] = add [[TMP23]], [[TMP22]] +; CHECK-NEXT: [[TMP29:%.*]] = add [[TMP21]], [[TMP20]] +; CHECK-NEXT: [[TMP30:%.*]] = add [[TMP28]], [[TMP29]] +; CHECK-NEXT: [[TMP31:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64( [[TMP30]]) +; CHECK-NEXT: [[TMP32:%.*]] = lshr i64 [[TMP31]], 16 +; CHECK-NEXT: [[TMP33:%.*]] = trunc i64 [[TMP32]] to i16 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: ret i16 [[ACC_0_LCSSA]] +; +entry: + %cmp9 = icmp sgt i32 %N, 0 + br i1 %cmp9, label %for.body.preheader, label %for.cond.cleanup + +for.body.preheader: ; preds = %entry + %wide.trip.count = zext i32 %N to i64 + %0 = tail call i64 @llvm.vscale.i64() + %1 = shl nuw nsw i64 %0, 3 + %2 = shl nuw nsw i64 %0, 4 + %3 = mul nuw nsw i64 %0, 24 + %active.lane.mask.entry = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %wide.trip.count) + %active.lane.mask.entry16 = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %2, i64 %wide.trip.count) + %active.lane.mask.entry15 = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %1, i64 %wide.trip.count) + %active.lane.mask.entry17 
= tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %3, i64 %wide.trip.count) + %4 = tail call i32 @llvm.vscale.i32() + %5 = shl nuw nsw i32 %4, 3 + %6 = zext i32 %5 to i64 + %7 = shl nuw nsw i32 %4, 4 + %8 = zext i32 %7 to i64 + %9 = mul nuw nsw i32 %4, 24 + %10 = zext i32 %9 to i64 + %11 = shl nuw nsw i64 %0, 5 + br label %vector.body + +vector.body: ; preds = %vector.body, %for.body.preheader + %index = phi i64 [ 0, %for.body.preheader ], [ %index.next, %vector.body ] + %active.lane.mask = phi [ %active.lane.mask.entry, %for.body.preheader ], [ %active.lane.mask.next, %vector.body ] + %active.lane.mask18 = phi [ %active.lane.mask.entry15, %for.body.preheader ], [ %active.lane.mask.next31, %vector.body ] + %active.lane.mask19 = phi [ %active.lane.mask.entry16, %for.body.preheader ], [ %active.lane.mask.next32, %vector.body ] + %active.lane.mask20 = phi [ %active.lane.mask.entry17, %for.body.preheader ], [ %active.lane.mask.next33, %vector.body ] + %vec.phi = phi [ zeroinitializer, %for.body.preheader ], [ %33, %vector.body ] + %vec.phi21 = phi [ zeroinitializer, %for.body.preheader ], [ %35, %vector.body ] + %vec.phi22 = phi [ zeroinitializer, %for.body.preheader ], [ %37, %vector.body ] + %vec.phi23 = phi [ zeroinitializer, %for.body.preheader ], [ %39, %vector.body ] + %12 = getelementptr inbounds i16, ptr %a, i64 %index + %wide.masked.load = tail call @llvm.masked.load.nxv8i16.p0(ptr %12, i32 2, %active.lane.mask, poison) + %13 = getelementptr inbounds i16, ptr %12, i64 %6 + %wide.masked.load24 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %13, i32 2, %active.lane.mask18, poison) + %14 = getelementptr inbounds i16, ptr %12, i64 %8 + %wide.masked.load25 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %14, i32 2, %active.lane.mask19, poison) + %15 = getelementptr inbounds i16, ptr %12, i64 %10 + %wide.masked.load26 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %15, i32 2, %active.lane.mask20, poison) + %16 = zext %wide.masked.load to + %17 = zext %wide.masked.load24 to + %18 = zext %wide.masked.load25 to + %19 = zext %wide.masked.load26 to + %20 = getelementptr inbounds i16, ptr %b, i64 %index + %wide.masked.load27 = tail call @llvm.masked.load.nxv8i16.p0(ptr %20, i32 2, %active.lane.mask, poison) + %21 = getelementptr inbounds i16, ptr %20, i64 %6 + %wide.masked.load28 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %21, i32 2, %active.lane.mask18, poison) + %22 = getelementptr inbounds i16, ptr %20, i64 %8 + %wide.masked.load29 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %22, i32 2, %active.lane.mask19, poison) + %23 = getelementptr inbounds i16, ptr %20, i64 %10 + %wide.masked.load30 = tail call @llvm.masked.load.nxv8i16.p0(ptr nonnull %23, i32 2, %active.lane.mask20, poison) + %24 = zext %wide.masked.load27 to + %25 = zext %wide.masked.load28 to + %26 = zext %wide.masked.load29 to + %27 = zext %wide.masked.load30 to + %28 = mul nuw nsw %24, %16 + %29 = mul nuw nsw %25, %17 + %30 = mul nuw nsw %26, %18 + %31 = mul nuw nsw %27, %19 + %32 = select %active.lane.mask, %28, zeroinitializer + %33 = add %vec.phi, %32 + %34 = select %active.lane.mask18, %29, zeroinitializer + %35 = add %vec.phi21, %34 + %36 = select %active.lane.mask19, %30, zeroinitializer + %37 = add %vec.phi22, %36 + %38 = select %active.lane.mask20, %31, zeroinitializer + %39 = add %vec.phi23, %38 + %index.next = add i64 %index, %11 + %40 = add i64 %index.next, %1 + %41 = add i64 %index.next, %2 + %42 = add i64 %index.next, %3 + %active.lane.mask.next = tail call 
@llvm.get.active.lane.mask.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count) + %active.lane.mask.next31 = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %40, i64 %wide.trip.count) + %active.lane.mask.next32 = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %41, i64 %wide.trip.count) + %active.lane.mask.next33 = tail call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %42, i64 %wide.trip.count) + %43 = extractelement %active.lane.mask.next, i64 0 + br i1 %43, label %vector.body, label %middle.block + +middle.block: ; preds = %vector.body + %bin.rdx = add %35, %33 + %bin.rdx34 = add %37, %bin.rdx + %bin.rdx35 = add %39, %bin.rdx34 + %44 = tail call i32 @llvm.vector.reduce.add.nxv8i32( %bin.rdx35) + %45 = lshr i32 %44, 16 + %46 = trunc i32 %45 to i16 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %middle.block, %entry + %acc.0.lcssa = phi i16 [ 0, %entry ], [ %46, %middle.block ] + ret i16 %acc.0.lcssa +} + +define i8 @sve_sdot_loop_i8_to_i16(ptr readonly %a, ptr readonly %b, i32 %N) #0 { +; CHECK-LABEL: define i8 @sve_sdot_loop_i8_to_i16 +; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: min.iters.checked: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[PREDICATE_ENTRY:%.*]] = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[PREDICATE:%.*]] = phi [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call @llvm.masked.load.nxv16i8.p0(ptr [[TMP0]], i32 1, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD19:%.*]] = call @llvm.masked.load.nxv16i8.p0(ptr [[TMP1]], i32 1, [[PREDICATE]], zeroinitializer) +; CHECK-NEXT: [[TMP2]] = call @llvm.aarch64.sve.sdot.nxv4i32( [[DOT_ACCUMULATE]], [[WIDE_MASKED_LOAD19]], [[WIDE_MASKED_LOAD]]) +; CHECK-NEXT: [[VS:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[VS_SCALED:%.*]] = shl i64 [[VS]], 4 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]] +; CHECK-NEXT: [[PREDICATE_NEXT]] = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP3:%.*]] = extractelement [[PREDICATE_NEXT]], i64 0 +; CHECK-NEXT: br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32( [[TMP2]]) +; CHECK-NEXT: [[PHITMP201:%.*]] = lshr i32 [[TMP4]], 8 +; CHECK-NEXT: [[PHITMP:%.*]] = trunc i32 [[PHITMP201]] to i8 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i8 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: ret i8 [[ACC_0_LCSSA]] +; +entry: + %cmp11 = icmp sgt i32 %N, 0 + br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup + +min.iters.checked: ; preds = 
%entry + %wide.trip.count = zext i32 %N to i64 + %wide.end.idx.splatinsert = insertelement undef, i64 %wide.trip.count, i32 0 + %wide.end.idx.splat = shufflevector %wide.end.idx.splatinsert, undef, zeroinitializer + %predicate.entry = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 %wide.trip.count) + br label %vector.body + +vector.body: ; preds = %vector.body, %min.iters.checked + %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ] + %predicate = phi [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ] + %0 = getelementptr inbounds i8, ptr %a, i64 %index + %wide.masked.load = call @llvm.masked.load.nxv16i8.p0(ptr %0, i32 1, %predicate, undef) + %1 = sext %wide.masked.load to + %2 = getelementptr inbounds i8, i8* %b, i64 %index + %wide.masked.load19 = call @llvm.masked.load.nxv16i8.p0(ptr %2, i32 1, %predicate, undef) + %3 = sext %wide.masked.load19 to + %4 = mul nsw %3, %1 + %5 = select %predicate, %4, zeroinitializer + %6 = add nsw %vec.phi, %5 + %vs = call i64 @llvm.vscale.i64() + %vs.scaled = mul i64 %vs, 16 + %index.next = add nuw i64 %index, %vs.scaled + %.splatinsert = insertelement undef, i64 %index.next, i32 0 + %.splat = shufflevector %.splatinsert, undef, zeroinitializer + %predicate.next = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 %index.next, i64 %wide.trip.count) + %7 = extractelement %predicate.next, i64 0 + br i1 %7, label %vector.body, label %middle.block + +middle.block: ; preds = %vector.body + %8 = call i16 @llvm.vector.reduce.add.nxv16i16( %6) + %phitmp20 = lshr i16 %8, 8 + %phitmp = trunc i16 %phitmp20 to i8 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %middle.block, %entry + %acc.0.lcssa = phi i8 [ 0, %entry ], [ %phitmp, %middle.block ] + ret i8 %acc.0.lcssa +} + +define i8 @sve_udot_loop_i8_to_i16(ptr readonly %a, ptr readonly %b, i32 %N) #0 { +; CHECK-LABEL: define i8 @sve_udot_loop_i8_to_i16 +; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP11:%.*]] = icmp sgt i32 [[N]], 0 +; CHECK-NEXT: br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: min.iters.checked: +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64 +; CHECK-NEXT: [[PREDICATE_ENTRY:%.*]] = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[PREDICATE:%.*]] = phi [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[DOT_ACCUMULATE:%.*]] = phi [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call @llvm.masked.load.nxv16i8.p0(ptr [[TMP0]], i32 1, [[PREDICATE]], undef) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[WIDE_MASKED_LOAD19:%.*]] = call @llvm.masked.load.nxv16i8.p0(ptr [[TMP1]], i32 1, [[PREDICATE]], undef) +; CHECK-NEXT: [[TMP2]] = call @llvm.aarch64.sve.udot.nxv4i32( [[DOT_ACCUMULATE]], [[WIDE_MASKED_LOAD19]], [[WIDE_MASKED_LOAD]]) +; CHECK-NEXT: [[VS:%.*]] = call i64 @llvm.vscale.i64() +; CHECK-NEXT: [[VS_SCALED:%.*]] = shl i64 [[VS]], 4 +; 
CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]] +; CHECK-NEXT: [[PREDICATE_NEXT]] = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]]) +; CHECK-NEXT: [[TMP3:%.*]] = extractelement [[PREDICATE_NEXT]], i64 0 +; CHECK-NEXT: br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]] +; CHECK: middle.block: +; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32( [[TMP2]]) +; CHECK-NEXT: [[PHITMP201:%.*]] = lshr i32 [[TMP4]], 8 +; CHECK-NEXT: [[PHITMP:%.*]] = trunc i32 [[PHITMP201]] to i8 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i8 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ] +; CHECK-NEXT: ret i8 [[ACC_0_LCSSA]] +; +entry: + %cmp11 = icmp sgt i32 %N, 0 + br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup + +min.iters.checked: ; preds = %entry + %wide.trip.count = zext i32 %N to i64 + %wide.end.idx.splatinsert = insertelement undef, i64 %wide.trip.count, i32 0 + %wide.end.idx.splat = shufflevector %wide.end.idx.splatinsert, undef, zeroinitializer + %predicate.entry = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 %wide.trip.count) + br label %vector.body + +vector.body: ; preds = %vector.body, %min.iters.checked + %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ] + %predicate = phi [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %min.iters.checked ], [ %5, %vector.body ] + %0 = getelementptr inbounds i8, ptr %a, i64 %index + %wide.masked.load = call @llvm.masked.load.nxv16i8.p0(ptr %0, i32 1, %predicate, undef) + %1 = zext %wide.masked.load to + %2 = getelementptr inbounds i8, i8* %b, i64 %index + %wide.masked.load19 = call @llvm.masked.load.nxv16i8.p0(ptr %2, i32 1, %predicate, undef) + %3 = zext %wide.masked.load19 to + %4 = mul nsw %3, %1 + %5 = add nsw %vec.phi, %4 + %vs = call i64 @llvm.vscale.i64() + %vs.scaled = mul i64 %vs, 16 + %index.next = add nuw i64 %index, %vs.scaled + %.splatinsert = insertelement undef, i64 %index.next, i32 0 + %.splat = shufflevector %.splatinsert, undef, zeroinitializer + %predicate.next = call @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 %index.next, i64 %wide.trip.count) + %6 = extractelement %predicate.next, i64 0 + br i1 %6, label %vector.body, label %middle.block + +middle.block: ; preds = %vector.body + %7 = call i16 @llvm.vector.reduce.add.nxv16i16( %5) + %phitmp20 = lshr i16 %7, 8 + %phitmp = trunc i16 %phitmp20 to i8 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %middle.block, %entry + %acc.0.lcssa = phi i8 [ 0, %entry ], [ %phitmp, %middle.block ] + ret i8 %acc.0.lcssa +} + +define i64 @sve_sdot_i16_to_i64( %a, %b) #0 { +; CHECK-LABEL: define i64 @sve_sdot_i16_to_i64 +; CHECK-SAME: ( [[A:%.*]], [[B:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = call @llvm.aarch64.sve.sdot.nxv2i64( zeroinitializer, [[A]], [[B]]) +; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64( [[TMP0]]) +; CHECK-NEXT: ret i64 [[TMP1]] +; +entry: + %exta = sext %a to + %extb = sext %b to + %mul = mul nsw %exta, %extb + %acc = call i64 @llvm.vector.reduce.add.nxv8i64( %mul) + ret i64 %acc +} + +define i32 @sve_udot_i8_to_i32( %a, %b) #0 { +; CHECK-LABEL: define i32 @sve_udot_i8_to_i32 +; CHECK-SAME: ( [[A:%.*]], [[B:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = call @llvm.aarch64.sve.udot.nxv4i32( zeroinitializer, [[A]], [[B]]) +; 
CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32( [[TMP0]]) +; CHECK-NEXT: ret i32 [[TMP1]] +; +entry: + %exta = zext %a to + %extb = zext %b to + %mul = mul nsw %exta, %extb + %acc = call i32 @llvm.vector.reduce.add.nxv16i32( %mul) + ret i32 %acc +} + +declare @llvm.masked.load.nxv8i16.p0(ptr, i32, , ) +declare @llvm.masked.load.nxv16i8.p0(ptr, i32, , ) +declare i32 @llvm.vector.reduce.add.nxv8i32() +declare i16 @llvm.vector.reduce.add.nxv16i16() +declare i64 @llvm.vector.reduce.add.nxv8i64() +declare i32 @llvm.vector.reduce.add.nxv16i32() +declare i64 @llvm.vscale.i64() +declare i32 @llvm.vscale.i32() +declare @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64, i64) +declare @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64, i64) +declare @llvm.get.active.lane.mask.nxv8i1.i64(i64, i64) + +attributes #0 = { "target-features"="+sve" }