[LV][SVE] Recognize potential DOT sequences and use a wider VF #69587

Closed

Conversation

huntergr-arm
Collaborator

This patch extends the LoopVectorize cost model to identify when an extend->multiply->accumulate chain is suitable for the UDOT/SDOT instructions in AArch64 (SVE in particular), and to ignore the extension when determining desirable VFs.
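
For reference, here is a minimal scalar sketch (my own illustration, not code from the patch) of the kind of source loop this targets: narrow inputs widened, multiplied, and accumulated into a wider sum, mirroring the first test case added below.

```c++
// Sketch of the extend->multiply->accumulate pattern the cost model
// recognizes; mirrors the sdot_xform_example test (i16 inputs, i32 accumulator).
#include <cstdint>

int32_t dot(const int16_t *a, const int16_t *b, int64_t n) {
  int32_t acc = 0;
  for (int64_t i = 0; i < n; ++i)
    acc += int32_t{a[i]} * int32_t{b[i]}; // sext -> mul -> add chain
  return acc;
}
```
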
@llvmbot
Collaborator

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64

Author: Graham Hunter (huntergr-arm)

Changes

This patch extends the LoopVectorize cost model to identify when an extend->multiply->accumulate chain is suitable for the UDOT/SDOT instructions in AArch64 (SVE in particular), and to ignore the extension when determining desirable VFs.


Patch is 37.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69587.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+11)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+38)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll (+485)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..b11c325f31c5ccf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -772,6 +772,10 @@ class TargetTransformInfo {
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Returns true if the types are legal for DOT product instructions on
+  /// the target (extend->multiply->accumulate)
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const;
+
   /// Return true if this is an alternating opcode pattern that can be lowered
   /// to a single instruction on the target. In X86 this is for the addsub
   /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1787,6 +1791,7 @@ class TargetTransformInfo::Concept {
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalDotProd(Type *DataType, Type *ExtType) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
                                const SmallBitVector &OpcodeMask) const = 0;
@@ -2267,6 +2272,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalDotProd(Type *DataType, Type *ExtType) override {
+    return Impl.isLegalDotProd(DataType, ExtType);
+  }
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const override {
     return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..01f5af17a6f4814 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -302,6 +302,8 @@ class TargetTransformInfoImplBase {
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const { return false; }
+
   bool enableOrderedReductions() const { return false; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..fbbf8c3f5e34217 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,6 +492,10 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalDotProd(Type *DataType, Type *ExtType) const {
+  return TTIImpl->isLegalDotProd(DataType, ExtType);
+}
+
 bool TargetTransformInfo::enableOrderedReductions() const {
   return TTIImpl->enableOrderedReductions();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..6be8f2867ec1a7f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -151,6 +151,17 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  // TODO: NEON should be able to support this after... 8.3 or so?
+  // Need to make sure that the input type is either i8 or i16, and that
+  // the extended type is at most the accumulator type of the dot product
+  // instructions so that we don't lose data.
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const {
+    return ST->hasSVE() && ((DataType->isIntegerTy(8) &&
+                             ExtType->getPrimitiveSizeInBits() <= 32) ||
+                            (DataType->isIntegerTy(16) &&
+                             ExtType->getPrimitiveSizeInBits() <= 64));
+  }
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index aa435b0d47aa599..3b585cd221eda42 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -80,6 +80,7 @@
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -1921,6 +1922,9 @@ class LoopVectorizationCostModel {
 
   /// All element types found in the loop.
   SmallPtrSet<Type *, 16> ElementTypesInLoop;
+
+  /// Extends used as part of a dot-product chain; these are 'free'.
+  SmallPtrSet<Value *, 2> DotExtends;
 };
 } // end namespace llvm
 
@@ -5580,6 +5584,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {
 }
 
 void LoopVectorizationCostModel::collectElementTypesForWidening() {
+  using namespace llvm::PatternMatch;
   ElementTypesInLoop.clear();
   // For each block.
   for (BasicBlock *BB : TheLoop->blocks()) {
@@ -5607,6 +5612,34 @@ void LoopVectorizationCostModel::collectElementTypesForWidening() {
                                       RdxDesc.getRecurrenceType(),
                                       TargetTransformInfo::ReductionFlags()))
           continue;
+        // DOT Prod proto...
+        if (RdxDesc.getRecurrenceKind() == RecurKind::Add) {
+          Instruction *Sum = RdxDesc.getLoopExitInstr();
+          Value *Accum = Legal->getReductionVars().find(PN)->first;
+
+          if (!Accum->hasOneUse() || !Sum->hasNUses(2))
+            continue;
+
+          Value *Step = (Sum->getOperand(0) == Accum) ? Sum->getOperand(1)
+                                                      : Sum->getOperand(0);
+          Value *ValA = nullptr, *ValB = nullptr;
+
+          if (match(Step,
+                    m_OneUse(m_Mul(m_ZExtOrSExt(m_OneUse(m_Value(ValA))),
+                                   m_ZExtOrSExt(m_OneUse(m_Value(ValB)))))) &&
+              (ValA->getType() == ValB->getType()) &&
+              TTI.isLegalDotProd(ValA->getType(), Step->getType())) {
+            Instruction *I = cast<Instruction>(Step);
+
+            // Make sure the extends are only used by the multiply.
+            if (I->getOperand(0)->hasOneUser() &&
+                I->getOperand(1)->hasOneUser()) {
+              DotExtends.insert(I->getOperand(0));
+              DotExtends.insert(I->getOperand(1));
+              continue;
+            }
+          }
+        }
         T = RdxDesc.getRecurrenceType();
       }
 
@@ -7351,6 +7384,11 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
         CCH = ComputeCCH(Load);
     }
 
+    // Extensions used in dot product calculations are 'free', since the
+    // dot instruction performs that operation internally before multiplying
+    if (DotExtends.contains(I))
+      return 0;
+
     // We optimize the truncation of induction variables having constant
     // integer steps. The cost of these truncations is the same as the scalar
     // operation.
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll
new file mode 100644
index 000000000000000..2014bb18b11b104
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll
@@ -0,0 +1,485 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -passes=loop-vectorize,simplifycfg,instcombine -force-vector-interleave=1 -prefer-predicate-over-epilogue=predicate-dont-vectorize -S < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+;; For SVE, we want to make sure that we 'maximize bandwidth' during loop
+;; vectorization of IR patterns that roughly match an SDOT or UDOT instruction.
+;; Normally, <vscale x 8 x i32> wouldn't be considered since it takes up two
+;; registers, but since the *DOT instructions transform a given number of
+;; narrower input values into a smaller number of wider accumulation values, we
+;; won't actually use any additional registers for this case.
+;;
+;; This file just tests that the loop vectorizer sets up for the
+;; AArch64DotProdMatcher pass. To do so, it will need to identify extends
+;; that will be folded away by using the DOT instructions. For the first
+;; example below, the vectorized loop will use <vscale x 8 x i16> extended to
+;; <vscale x 8 x i32>. Normally the loop vectorizer would pick a VF of 4 since
+;; the i32 is the widest type, but since that will be folded away we want to
+;; pick a VF of 8 to maximize the number of i16s processed per iteration.
+;;
+;; The backend pass will then match this and plant a DOT intrinsic with
+;; 2 <vscale x 8 x i16>s as input and one <vscale x 2 x i64> as output.
+;;
+;; If the extend would exceed the capacity of the DOT instruction (basically
+;; if i8s were extended to i64s), then we can't perform the second part of
+;; the transformation. We then wouldn't want to perform the first part either.
+;; We also want to stop the transform if there was another use of one of the
+;; values in the chain that would be folded into the DOT instruction, since
+;; the intermediate values would never exist in a register for reuse.
+
+define i16 @sdot_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i16 @sdot_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP3]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP5]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 8 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i32> [[TMP7]], <vscale x 8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP13]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx, align 2
+  %conv = sext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, i16* %arrayidx2, align 2
+  %conv3 = sext i16 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 16
+  %phitmp14 = trunc i32 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+;; Similar to the above check, but for a zext instead of a sext.
+
+define i16 @udot_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i16 @udot_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP3]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <vscale x 8 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP5]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <vscale x 8 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nuw nsw <vscale x 8 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i32> [[TMP7]], <vscale x 8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP13]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx, align 2
+  %conv = zext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, i16* %arrayidx2, align 2
+  %conv3 = zext i16 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 16
+  %phitmp14 = trunc i32 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+;; In this case we don't want to use the maximum bandwidth since the accumulator
+;; type (i64) is wider than it would be in the sdot instruction for i8 inputs
+;; (i32).
+
+define i8 @sdot_xform_too_wide(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i8 @sdot_xform_too_wide
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 2 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0(ptr [[TMP3]], i32 2, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i8> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 2 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 2 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0(ptr [[TMP5]], i32 2, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i8> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 2 x i8> [[WIDE_MASKED_LOAD1]] to <vscale x 2 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 2 x i64> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i64> [[TMP7]], <vscale x 2 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 2 x i64> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64...
[truncated]
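
To make the bandwidth argument from the test comments concrete, here is a small worked example (an illustration only, assuming a 128-bit SVE granule) of how treating the extension as free changes the chosen VF for the first test case.

```c++
// Worked example: with i16 inputs sign-extended to i32 and a 128-bit SVE
// granule, counting the extended type caps the VF at vscale x 4, while
// treating the extend as free lets the i16 loads dictate a VF of vscale x 8.
constexpr unsigned GranuleBits   = 128;
constexpr unsigned VFWithExt     = GranuleBits / 32; // i32 is widest -> 4 lanes
constexpr unsigned VFIgnoringExt = GranuleBits / 16; // i16 is widest -> 8 lanes
static_assert(VFIgnoringExt == 2 * VFWithExt,
              "treating the extend as free doubles the per-iteration bandwidth");
```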

@huntergr-arm
Collaborator Author

This patch sets up for a target-specific pass (#69583) to generate SVE SDOT/UDOT instructions when appropriate.

This is some downstream code I wrote 4 years ago that we want to upstream -- suggestions on alternative approaches welcome.
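
For context, a rough sketch (my own illustration, not code from this PR or #69583, with a hypothetical helper name) of the kind of IRBuilder call the companion pass might use to plant the dot intrinsic described in the test comments: two <vscale x 8 x i16> inputs accumulating into <vscale x 2 x i64>.

```c++
// Hypothetical helper, not part of this PR: emits an SVE sdot that
// accumulates two <vscale x 8 x i16> inputs into a <vscale x 2 x i64> value.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsAArch64.h"

static llvm::Value *emitSDot(llvm::IRBuilder<> &Builder, llvm::Value *Acc,
                             llvm::Value *InA, llvm::Value *InB) {
  // The intrinsic is overloaded on the accumulator type; the inputs must be
  // the matching quarter-width-element vector (here nxv8i16 for nxv2i64).
  return Builder.CreateIntrinsic(llvm::Intrinsic::aarch64_sve_sdot,
                                 {Acc->getType()}, {Acc, InA, InB});
}
```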

@LittleMeepo
Contributor

I also achieved a similar result by adding a recipe to LoopVectorize, but I think directly generating AArch64 intrinsics in LoopVectorize can only serve as a local, temporary solution.

@paulwalker-arm
Collaborator

Let me know if requesting changes is a little harsh, but given that I'd like us to try a different path (https://discourse.llvm.org/t/rfc-is-a-more-expressive-way-to-represent-reductions-useful), I'm assuming this will significantly affect this PR.
