[AArch64][GlobalISel] Support udot lowering for vecreduce add #70784
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: None (chuongg3)

Changes

vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)

Handles vectors with an 8-bit scalar element type and an element count that is a multiple of 8.

Patch is 25.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70784.diff

4 Files Affected:
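As an illustrative aside (not part of the patch; the names below are invented for the sketch): the pattern being combined is the plain unsigned dot-product reduction shown here. A udot instruction computes four of these 8-bit products per 32-bit lane, so a single udot plus a vector reduction of the lanes gives the same result.

```cpp
#include <cstddef>
#include <cstdint>

// Scalar reference for vecreduce_add(mul(zext(a), zext(b))): each i8 element
// is zero-extended to i32, the pairs are multiplied, and the products summed.
uint32_t dotReduceU8(const uint8_t *a, const uint8_t *b, size_t n) {
  uint32_t acc = 0; // corresponds to the zeroed accumulator fed into udot
  for (size_t i = 0; i < n; ++i)
    acc += uint32_t(a[i]) * uint32_t(b[i]); // zext + mul + reduce-add
  return acc;
}
```

The vecreduce_add(ext) form handled by the same combine is this computation with every b[i] equal to 1, which is why the apply step materialises a constant-1 vector as the second dot-product operand.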
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 017c4523c23a184..e17524b2c55bdd3 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -33,12 +33,22 @@ def fold_global_offset : GICombineRule<
(apply [{ applyFoldGlobalOffset(*${root}, MRI, B, Observer, ${matchinfo});}])
>;
+let Predicates = [HasDotProd] in {
+def ext_addv_to_udot_addv : GICombineRule<
+ (defs root:$root),
+ (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+ [{ return matchExtAddvToUdotAddv(*${root}, MRI); }]),
+ (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer); }])
+>;
+}
+
def AArch64PreLegalizerCombiner: GICombiner<
"AArch64PreLegalizerCombinerImpl", [all_combines,
fconstant_to_constant,
icmp_redundant_trunc,
fold_global_offset,
- shuffle_to_extract]> {
+ shuffle_to_extract,
+ ext_addv_to_udot_addv]> {
let CombineAllMethodName = "tryCombineAllImpl";
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 27338bd24393325..1711360779bf74c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
let hasSideEffects = 0;
}
+def G_UDOT : AArch64GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+ let hasSideEffects = 0;
+}
+
+def G_SDOT : AArch64GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+ let hasSideEffects = 0;
+}
+
// Generic instruction for the BSP pseudo. It is expanded into BSP, which
// expands into BSL/BIT/BIF after register allocation.
def G_BSP : AArch64GenericInstruction {
@@ -270,6 +282,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
def : GINodeEquiv<G_UMULL, AArch64umull>;
def : GINodeEquiv<G_SMULL, AArch64smull>;
+def : GINodeEquiv<G_UDOT, AArch64udot>;
+def : GINodeEquiv<G_SDOT, AArch64sdot>;
+
def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;
def : GINodeEquiv<G_PREFETCH, AArch64Prefetch>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index d9678bea214dd53..34a59839a99a97c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
+// Or vecreduce_add(ext) -> vecreduce_add(ext)
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register DstReg = MI.getOperand(0).getReg();
+ Register MidReg = I1->getOperand(0).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT MidTy = MRI.getType(MidReg);
+ if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+ return false;
+
+ LLT SrcTy;
+ auto I1Opc = I1->getOpcode();
+ if (I1Opc == TargetOpcode::G_MUL) {
+ MachineInstr *ExtMI1 =
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ MachineInstr *ExtMI2 =
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+ LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+ if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+ return false;
+ I1Opc = ExtMI1->getOpcode();
+ SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+ } else
+ SrcTy = MRI.getType(I1->getOperand(1).getReg());
+
+ if (I1Opc != TargetOpcode::G_ZEXT && I1Opc != TargetOpcode::G_SEXT)
+ return false;
+ if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
+ return false;
+
+ return true;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &Builder,
+ GISelChangeObserver &Observer) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register Ext1SrcReg, Ext2SrcReg;
+ unsigned DotOpcode;
+ if (I1->getOpcode() == TargetOpcode::G_MUL) {
+ auto Ext1MI = getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ auto Ext2MI = getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+ Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+ DotOpcode = Ext1MI->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+ I1->getOpcode() == TargetOpcode::G_SEXT) {
+ Ext1SrcReg = I1->getOperand(1).getReg();
+ Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+ ->getOperand(0)
+ .getReg();
+ DotOpcode = I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else
+ return;
+
+ LLT SrcTy = MRI.getType(Ext1SrcReg);
+ LLT MidTy;
+ unsigned NumOfVecReduce;
+ if (SrcTy.getNumElements() % 16 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 16;
+ MidTy = LLT::fixed_vector(4, 32);
+ } else if (SrcTy.getNumElements() % 8 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 8;
+ MidTy = LLT::fixed_vector(2, 32);
+ } else
+ return;
+
+ // Handle case where one DOT instruction is needed
+ if (NumOfVecReduce == 1) {
+ auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+ auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+ {Zeroes, Ext1SrcReg, Ext2SrcReg});
+ Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+ } else {
+ // Get the number of output vectors needed
+ SmallVector<LLT, 4> DotVecLLT;
+ auto SrcVecNum = SrcTy.getNumElements();
+ while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
+ DotVecLLT.push_back(LLT::fixed_vector(16, 8));
+ SrcVecNum = SrcVecNum - 16;
+ }
+ if (SrcVecNum == 8)
+ DotVecLLT.push_back(LLT::fixed_vector(8, 8));
+
+ // Unmerge the source vectors
+ auto Ext1Unmerge = Builder.buildUnmerge(DotVecLLT, Ext1SrcReg);
+ auto Ext2Unmerge = Builder.buildUnmerge(DotVecLLT, Ext2SrcReg);
+
+ // Build the UDOT instructions
+ SmallVector<Register, 2> DotReg;
+ unsigned NumElements = 0;
+ for (unsigned i = 0; i < DotVecLLT.size(); i++) {
+ LLT ZeroesLLT;
+ // Check if it is 16 or 8 elements. Set Zeroes to the corresponding size
+ if (MRI.getType(Ext1Unmerge.getReg(i)).getNumElements() == 16) {
+ ZeroesLLT = LLT::fixed_vector(4, 32);
+ NumElements += 4;
+ } else {
+ ZeroesLLT = LLT::fixed_vector(2, 32);
+ NumElements += 2;
+ }
+ auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
+ DotReg.push_back(Builder
+ .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
+ {Zeroes, Ext1Unmerge.getReg(i),
+ Ext2Unmerge.getReg(i)})
+ ->getOperand(0)
+ .getReg());
+ }
+
+ // Merge the output
+ // auto a = MI.getOperand(1).getReg().changeNumElements(NumElements);
+ auto ConcatMI =
+ Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);
+
+ // Put it through a vector reduction
+ Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
+ ConcatMI->getOperand(0).getReg());
+ }
+
+ // Erase the dead instructions
+ if (I1->getOpcode() == TargetOpcode::G_MUL) {
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI)->eraseFromParent();
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI)->eraseFromParent();
+ }
+ I1->eraseFromParent();
+ MI.eraseFromParent();
+}
+
bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
CombinerHelper &Helper, GISelChangeObserver &Observer) {
// Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index a88c930d09e9b17..b4b221bf4e46461 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -440,14 +440,10 @@ define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -479,14 +475,10 @@ define i32 @add_v16i8_v16i32_sext(<16 x i8> %x) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -514,10 +506,10 @@ define i32 @add_v8i8_v8i32_zext(<8 x i8> %x) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -545,10 +537,10 @@ define i32 @add_v8i8_v8i32_sext(<8 x i8> %x) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -1560,14 +1552,10 @@ define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_acc_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1603,14 +1591,10 @@ define i32 @add_v16i8_v16i32_acc_sext(<16 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_acc_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1642,10 +1626,10 @@ define i32 @add_v8i8_v8i32_acc_zext(<8 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_acc_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1677,10 +1661,10 @@ define i32 @add_v8i8_v8i32_acc_sext(<8 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_acc_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -2618,6 +2602,152 @@ entry:
ret i32 %z
}
+define i32 @test_udot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v8i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT: umull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: addv s0, v2.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_udot_v8i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_udot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = zext <8 x i8> %a to <8 x i32>
+ %1 = zext <8 x i8> %b to <8 x i32>
+ %2 = mul nuw nsw <8 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_udot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v16i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT: umull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT: umull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT: umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: umlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT: addv s0, v0.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_udot_v16i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT: addv s0, v2.4s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_udot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = zext <16 x i8> %a to <16 x i32>
+ %1 = zext <16 x i8> %b to <16 x i32>
+ %2 = mul nuw nsw <16 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_sdot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v8i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: sshll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT: smull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: addv s0, v2.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_sdot_v8i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_sdot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = sext <8 x i8> %a to <8 x i32>
+ %1 = sext <8 x i8> %b to <8 x i32>
+ %2 = mul nuw nsw <8 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_sdot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v16i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT: smull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT: smull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT: smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: smlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT: addv s0, v0.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_sdot_v16i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT: addv s0, v2.4s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_sdot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = sext <16 x i8> %a to <16 x i32>
+ %1 = sext <16 x i8> %b to <16 x i32>
+ %2 = mul nuw nsw <16 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+ ret i32 %3
+}
+
define zeroext i16 @add_pair_v8i16_v8i16(<8 x i16> %x, <8 x i16> %y) {
; CHECK-BASE-LABEL: add_pair_v8i16_v8i16:
; CHECK-BASE: // %bb.0: // %entry
@@ -2990,22 +3120,13 @@ define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
;
; CHECK-GI-LABEL: add_pair_v16i8_v16i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-GI-NEXT: ushll2 v1.8h, v1....
[truncated]
I didn't realize the predicates would work in the combiner tablegen files. Can you add a RUN line for GISel without +dotprod to check it doesn't produce udot in those cases? I suggest you precommit that test change and then rebase this patch on top of it.
Force-pushed from 7a53eae to 5c03471.
The tests look good to me, if you want to push those separately and don't want to add more. It might be worth making sure there are tests for 24x vector types and 48x vector types, if there are not any already.
// Get the number of output vectors needed
SmallVector<LLT, 4> DotVecLLT;
auto SrcVecNum = SrcTy.getNumElements();
while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
Do you have tests for 24x vector types?
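To make the question concrete, here is a small standalone trace of the splitting loop quoted above (simplified to a plain int element count; names are invented for the sketch). In this simplified form a 24-element input yields no chunks at all, which is presumably what the missing-test concern is probing.

```cpp
#include <cstdio>
#include <vector>

// Mirrors this revision's chunking logic: peel off 16-element pieces, then
// allow one trailing 8-element piece.
static std::vector<int> splitChunks(int srcVecNum) {
  std::vector<int> chunks;
  while (srcVecNum - 16 >= 16 || srcVecNum - 16 == 0) {
    chunks.push_back(16);
    srcVecNum -= 16;
  }
  if (srcVecNum == 8)
    chunks.push_back(8);
  return chunks;
}

int main() {
  // Prints: 8 -> 8, 16 -> 16, 24 -> (nothing), 32 -> 16 16, 48 -> 16 16 16
  for (int n : {8, 16, 24, 32, 48}) {
    std::printf("%d ->", n);
    for (int w : splitChunks(n))
      std::printf(" %d", w);
    std::printf("\n");
  }
  return 0;
}
```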
// Erase the dead instructions
if (I1->getOpcode() == TargetOpcode::G_MUL) {
getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI)->eraseFromParent();
What happens if these have extra uses?
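For what it's worth, one way to address this (a sketch only, using helpers this file already relies on; not necessarily how the patch should resolve it) is to require in the match step that the reduction is the only user of the multiply and of each extend, and otherwise skip the manual erase and leave the cleanup to dead-code elimination:

```cpp
#include "llvm/CodeGen/GlobalISel/Utils.h"        // getDefIgnoringCopies
#include "llvm/CodeGen/MachineRegisterInfo.h"

using namespace llvm;

// Returns true when the G_MUL result and each extend result have exactly one
// non-debug user, so the manual eraseFromParent calls cannot strand other uses.
static bool chainHasSingleUses(MachineRegisterInfo &MRI, MachineInstr &Mul) {
  if (!MRI.hasOneNonDBGUse(Mul.getOperand(0).getReg()))
    return false; // the multiply also feeds something else
  for (unsigned OpIdx = 1; OpIdx <= 2; ++OpIdx) {
    MachineInstr *Ext =
        getDefIgnoringCopies(Mul.getOperand(OpIdx).getReg(), MRI);
    if (!MRI.hasOneNonDBGUse(Ext->getOperand(0).getReg()))
      return false; // an extend has extra uses; leave it in place
  }
  return true;
}
```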
I1Opc = ExtMI1->getOpcode();
SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
} else
SrcTy = MRI.getType(I1->getOperand(1).getReg());
The braces are unbalanced.
NumOfVecReduce = SrcTy.getNumElements() / 8;
MidTy = LLT::fixed_vector(2, 32);
} else
return;
The braces are unbalanced.
Force-pushed from 5c03471 to 9cef73c.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 9cef73c to e2a32e3.
Other than the comment below, this LGTM, if no one else has any comments.
->getOperand(0)
.getReg());
I think that ->getOperand(0).getReg() can be .getReg(0) from a MachineInstrBuilder, here and in 2 other places below.
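Concretely, the suggestion amounts to something like this (a sketch reusing the Builder and MidTy from the apply function above; the helper name is invented):

```cpp
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"

using namespace llvm;

// Builds the zeroed accumulator and returns its register.
static Register buildZeroAccumulator(MachineIRBuilder &Builder, LLT MidTy) {
  // Before: Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
  // After:  MachineInstrBuilder::getReg(0) returns the same def register
  //         without going through the MachineInstr's operand list.
  return Builder.buildConstant(MidTy, 0).getReg(0);
}
```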
Force-pushed from e2a32e3 to 5259733.
Force-pushed from 5259733 to 3b8d5a9.
[AArch64][GlobalISel] Support udot lowering for vecreduce add (#70784)

vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
vecreduce_add(ext) -> vecreduce_add(ext)
Vectors of scalar size of 8-bits with element count of multiples of 8