
[AArch64][GlobalISel] Support udot lowering for vecreduce add #70784

Merged: 3 commits into llvm:main on Nov 15, 2023

Conversation

@chuongg3 (Contributor)

This combine performs the following transformations:

vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
vecreduce_add(ext) -> vecreduce_add(udot(x, 1))

Both apply to vectors with an 8-bit scalar element type whose element count is a multiple of 8; sdot is used instead of udot when the extends are sign extends.
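For illustration, the ext-only form of the pattern corresponds to IR like the following (a reconstruction based on the add_v16i8_v16i32_zext test updated below; the local value names are assumptions):

  define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
  entry:
    %xx = zext <16 x i8> %x to <16 x i32>
    %z = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %xx)
    ret i32 %z
  }

With +dotprod, GlobalISel now lowers this to a movi of all-ones bytes, a udot against a zeroed accumulator, and an addv, as the CHECK-GI lines in the diff show.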

@llvmbot (Collaborator) commented on Oct 31, 2023

@llvm/pr-subscribers-backend-aarch64

Author: None (chuongg3)

Patch is 25.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70784.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+11-1)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrGISel.td (+15)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp (+140)
  • (modified) llvm/test/CodeGen/AArch64/vecreduce-add.ll (+220-113)
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 017c4523c23a184..e17524b2c55bdd3 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -33,12 +33,22 @@ def fold_global_offset : GICombineRule<
   (apply [{ applyFoldGlobalOffset(*${root}, MRI, B, Observer, ${matchinfo});}])
 >;
 
+let Predicates = [HasDotProd] in {
+def ext_addv_to_udot_addv : GICombineRule<
+  (defs root:$root),
+  (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+         [{ return matchExtAddvToUdotAddv(*${root}, MRI); }]),
+  (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer); }])
+>;
+}
+
 def AArch64PreLegalizerCombiner: GICombiner<
   "AArch64PreLegalizerCombinerImpl", [all_combines,
                                       fconstant_to_constant,
                                       icmp_redundant_trunc,
                                       fold_global_offset,
-                                      shuffle_to_extract]> {
+                                      shuffle_to_extract,
+                                      ext_addv_to_udot_addv]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 27338bd24393325..1711360779bf74c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+def G_UDOT : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+  let hasSideEffects = 0;
+}
+
+def G_SDOT : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+  let hasSideEffects = 0;
+}
+
 // Generic instruction for the BSP pseudo. It is expanded into BSP, which
 // expands into BSL/BIT/BIF after register allocation.
 def G_BSP : AArch64GenericInstruction {
@@ -270,6 +282,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
 def : GINodeEquiv<G_UMULL, AArch64umull>;
 def : GINodeEquiv<G_SMULL, AArch64smull>;
 
+def : GINodeEquiv<G_UDOT, AArch64udot>;
+def : GINodeEquiv<G_SDOT, AArch64sdot>;
+
 def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;
 
 def : GINodeEquiv<G_PREFETCH, AArch64Prefetch>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index d9678bea214dd53..34a59839a99a97c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
+// Or vecreduce_add(ext) -> vecreduce_add(ext)
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT MidTy = MRI.getType(MidReg);
+  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+    return false;
+
+  LLT SrcTy;
+  auto I1Opc = I1->getOpcode();
+  if (I1Opc == TargetOpcode::G_MUL) {
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+    if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+      return false;
+    I1Opc = ExtMI1->getOpcode();
+    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+  } else
+    SrcTy = MRI.getType(I1->getOperand(1).getReg());
+
+  if (I1Opc != TargetOpcode::G_ZEXT && I1Opc != TargetOpcode::G_SEXT)
+    return false;
+  if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
+    return false;
+
+  return true;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            MachineIRBuilder &Builder,
+                            GISelChangeObserver &Observer) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register Ext1SrcReg, Ext2SrcReg;
+  unsigned DotOpcode;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    auto Ext1MI = getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    auto Ext2MI = getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+    Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+    DotOpcode = Ext1MI->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+                                                            : AArch64::G_SDOT;
+  } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+             I1->getOpcode() == TargetOpcode::G_SEXT) {
+    Ext1SrcReg = I1->getOperand(1).getReg();
+    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+                     ->getOperand(0)
+                     .getReg();
+    DotOpcode = I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+                                                        : AArch64::G_SDOT;
+  } else
+    return;
+
+  LLT SrcTy = MRI.getType(Ext1SrcReg);
+  LLT MidTy;
+  unsigned NumOfVecReduce;
+  if (SrcTy.getNumElements() % 16 == 0) {
+    NumOfVecReduce = SrcTy.getNumElements() / 16;
+    MidTy = LLT::fixed_vector(4, 32);
+  } else if (SrcTy.getNumElements() % 8 == 0) {
+    NumOfVecReduce = SrcTy.getNumElements() / 8;
+    MidTy = LLT::fixed_vector(2, 32);
+  } else
+    return;
+
+  // Handle case where one DOT instruction is needed
+  if (NumOfVecReduce == 1) {
+    auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+    auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
+    Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+  } else {
+    // Get the number of output vectors needed
+    SmallVector<LLT, 4> DotVecLLT;
+    auto SrcVecNum = SrcTy.getNumElements();
+    while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
+      DotVecLLT.push_back(LLT::fixed_vector(16, 8));
+      SrcVecNum = SrcVecNum - 16;
+    }
+    if (SrcVecNum == 8)
+      DotVecLLT.push_back(LLT::fixed_vector(8, 8));
+
+    // Unmerge the source vectors
+    auto Ext1Unmerge = Builder.buildUnmerge(DotVecLLT, Ext1SrcReg);
+    auto Ext2Unmerge = Builder.buildUnmerge(DotVecLLT, Ext2SrcReg);
+
+    // Build the UDOT instructions
+    SmallVector<Register, 2> DotReg;
+    unsigned NumElements = 0;
+    for (unsigned i = 0; i < DotVecLLT.size(); i++) {
+      LLT ZeroesLLT;
+      // Check if it is 16 or 8 elements. Set Zeroes to the corresponding size
+      if (MRI.getType(Ext1Unmerge.getReg(i)).getNumElements() == 16) {
+        ZeroesLLT = LLT::fixed_vector(4, 32);
+        NumElements += 4;
+      } else {
+        ZeroesLLT = LLT::fixed_vector(2, 32);
+        NumElements += 2;
+      }
+      auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
+      DotReg.push_back(Builder
+                           .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
+                                       {Zeroes, Ext1Unmerge.getReg(i),
+                                        Ext2Unmerge.getReg(i)})
+                           ->getOperand(0)
+                           .getReg());
+    }
+
+    // Merge the output
+    // auto a = MI.getOperand(1).getReg().changeNumElements(NumElements);
+    auto ConcatMI =
+        Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);
+
+    // Put it through a vector reduction
+    Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
+                              ConcatMI->getOperand(0).getReg());
+  }
+
+  // Erase the dead instructions
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI)->eraseFromParent();
+    getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI)->eraseFromParent();
+  }
+  I1->eraseFromParent();
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index a88c930d09e9b17..b4b221bf4e46461 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -440,14 +440,10 @@ define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -479,14 +475,10 @@ define i32 @add_v16i8_v16i32_sext(<16 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -514,10 +506,10 @@ define i32 @add_v8i8_v8i32_zext(<8 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -545,10 +537,10 @@ define i32 @add_v8i8_v8i32_sext(<8 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -1560,14 +1552,10 @@ define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_acc_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1603,14 +1591,10 @@ define i32 @add_v16i8_v16i32_acc_sext(<16 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_acc_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1642,10 +1626,10 @@ define i32 @add_v8i8_v8i32_acc_zext(<8 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_acc_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1677,10 +1661,10 @@ define i32 @add_v8i8_v8i32_acc_sext(<8 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_acc_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -2618,6 +2602,152 @@ entry:
   ret i32 %z
 }
 
+define i32 @test_udot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v8i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    umull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    addv s0, v2.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_udot_v8i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_udot_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = zext <8 x i8> %a to <8 x i32>
+  %1 = zext <8 x i8> %b to <8 x i32>
+  %2 = mul nuw nsw <8 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_udot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v16i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT:    umull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT:    umull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT:    umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    umlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT:    addv s0, v0.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_udot_v16i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT:    addv s0, v2.4s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_udot_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = zext <16 x i8> %a to <16 x i32>
+  %1 = zext <16 x i8> %b to <16 x i32>
+  %2 = mul nuw nsw <16 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_sdot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v8i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    smull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    addv s0, v2.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_sdot_v8i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_sdot_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = sext <8 x i8> %a to <8 x i32>
+  %1 = sext <8 x i8> %b to <8 x i32>
+  %2 = mul nuw nsw <8 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_sdot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v16i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT:    smull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT:    smull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT:    smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    smlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT:    addv s0, v0.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_sdot_v16i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT:    addv s0, v2.4s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_sdot_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = sext <16 x i8> %a to <16 x i32>
+  %1 = sext <16 x i8> %b to <16 x i32>
+  %2 = mul nuw nsw <16 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
 define zeroext i16 @add_pair_v8i16_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; CHECK-BASE-LABEL: add_pair_v8i16_v8i16:
 ; CHECK-BASE:       // %bb.0: // %entry
@@ -2990,22 +3120,13 @@ define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
 ;
 ; CHECK-GI-LABEL: add_pair_v16i8_v16i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v2.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-GI-NEXT:    ushll2 v1.8h, v1....
[truncated]
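In outline, the apply step above emits a single udot/sdot when the source fits one dot instruction (8 or 16 elements); for larger inputs it unmerges the sources into 16- and 8-element chunks, accumulates each chunk into a zeroed v4i32 or v2i32 with one dot instruction, concatenates the partial sums, and feeds the concatenation to a final G_VECREDUCE_ADD. A worked example for a v32i8 input:

  v32i8 sources        -> G_UNMERGE_VALUES -> two v16i8 pairs
  each v16i8 pair      -> G_UDOT into a zeroed v4i32 accumulator
  two v4i32 partials   -> G_CONCAT_VECTORS -> v8i32
  v8i32                -> G_VECREDUCE_ADD  -> s32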

@aemerson (Contributor) left a comment:

I didn't realize the predicates would work in the combiner tablegen files. Can you add a RUN line for GISel without +dotprod to check it doesn't produce udot in those cases? I suggest you precommit that test change and then rebase this patch on top of it.
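For illustration, such a RUN line might look like the following sketch (the check-prefix name is an assumption and would need matching CHECK lines added to the test):

  ; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -global-isel | FileCheck %s --check-prefixes=CHECK,CHECK-GI-BASE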

@chuongg3 force-pushed the GlobalISel_VECREDUCE_ADD_UDOT branch from 7a53eae to 5c03471 on November 1, 2023
@davemgreen (Collaborator) left a comment:

The tests look good to me; you could push those separately if you didn't want to add more. It might be worth making sure there are tests for 24x and 48x vector types, if there are not already.

// Get the number of output vectors needed
SmallVector<LLT, 4> DotVecLLT;
auto SrcVecNum = SrcTy.getNumElements();
while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
A Collaborator commented on the lines above:

Do you have tests for 24x vector types?
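A v24i8 test would be a useful stress case here: in the loop quoted above, SrcVecNum = 24 fails the condition on the first iteration (24 - 16 = 8, which is neither >= 16 nor 0), so DotVecLLT appears to stay empty and the later buildUnmerge would receive no result types. A sketch of such a test (value names are illustrative):

  define i32 @add_v24i8_v24i32_zext(<24 x i8> %x) {
  entry:
    %xx = zext <24 x i8> %x to <24 x i32>
    %z = call i32 @llvm.vector.reduce.add.v24i32(<24 x i32> %xx)
    ret i32 %z
  }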


// Erase the dead instructions
if (I1->getOpcode() == TargetOpcode::G_MUL) {
getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI)->eraseFromParent();
A Collaborator commented on the lines above:

What happens if these have extra uses?
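For example, if one of the extends has a second user, erasing it unconditionally would leave that user without a definition. A guard along these lines might help (a sketch only, assuming the surrounding combiner context; instructions that merely become dead can instead be left for DCE):

  // Only erase an extend whose sole user was the G_MUL being replaced.
  if (I1->getOpcode() == TargetOpcode::G_MUL) {
    for (unsigned OpIdx : {1u, 2u}) {
      MachineInstr *Ext = getDefIgnoringCopies(I1->getOperand(OpIdx).getReg(), MRI);
      if (MRI.hasOneNonDBGUse(Ext->getOperand(0).getReg()))
        Ext->eraseFromParent();
    }
  }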

I1Opc = ExtMI1->getOpcode();
SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
} else
SrcTy = MRI.getType(I1->getOperand(1).getReg());
A Member commented on the lines above:

The braces are unbalanced.
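Per the LLVM coding standards, when one branch of an if/else uses braces the other branches should too, e.g.:

  } else {
    SrcTy = MRI.getType(I1->getOperand(1).getReg());
  }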

NumOfVecReduce = SrcTy.getNumElements() / 8;
MidTy = LLT::fixed_vector(2, 32);
} else
return;
A Member commented on the lines above:

The braces are unbalanced.

@chuongg3 force-pushed the GlobalISel_VECREDUCE_ADD_UDOT branch from 5c03471 to 9cef73c on November 9, 2023

github-actions bot commented Nov 9, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@chuongg3 force-pushed the GlobalISel_VECREDUCE_ADD_UDOT branch from 9cef73c to e2a32e3 on November 10, 2023
@davemgreen (Collaborator) left a comment:

Other than the comment below, this LGTM, if no one else has any comments.

Comment on lines 367 to 368
->getOperand(0)
.getReg());
A Collaborator commented:

I think that ->getOperand(0).getReg() can be .getReg(0) from a MIBuilder. Here and 2 other places below.
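Applied to the quoted snippet, MachineInstrBuilder::getReg(0) returns the register defined by operand 0 directly, so the chain shortens to something like (sketch):

  DotReg.push_back(Builder
                       .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
                                   {Zeroes, Ext1Unmerge.getReg(i),
                                    Ext2Unmerge.getReg(i)})
                       .getReg(0));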

@chuongg3 force-pushed the GlobalISel_VECREDUCE_ADD_UDOT branch from e2a32e3 to 5259733 on November 14, 2023
@chuongg3 force-pushed the GlobalISel_VECREDUCE_ADD_UDOT branch from 5259733 to 3b8d5a9 on November 15, 2023
@chuongg3 merged commit 692fbd6 into llvm:main on Nov 15, 2023 (2 of 3 checks passed)
zahiraam pushed a commit referencing this pull request to zahiraam/llvm-project on Nov 20, 2023.