diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index 896eab521bfdb..29538d0f9ba1b 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -435,6 +435,8 @@ bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
   Register ExtSrcReg = ExtMI->getOperand(1).getReg();
   LLT ExtSrcTy = MRI.getType(ExtSrcReg);
   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if (ExtSrcTy.getScalarSizeInBits() * 2 > DstTy.getScalarSizeInBits())
+    return false;
   if ((DstTy.getScalarSizeInBits() == 16 &&
        ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
      (DstTy.getScalarSizeInBits() == 32 &&
@@ -492,7 +494,7 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
   unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
   LLT MidScalarLLT = LLT::scalar(MidScalarSize);
 
-  Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
+  Register ZeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
   for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
     // If the number of elements is too small to build an instruction, extend
     // its size before applying addlv
@@ -508,10 +510,10 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
 
     // Generate the {U/S}ADDLV instruction, whose output is always double of the
     // Src's Scalar size
-    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+    LLT AddlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
                                       : LLT::fixed_vector(2, 64);
-    Register addlvReg =
-        B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
+    Register AddlvReg =
+        B.buildInstr(Opc, {AddlvTy}, {WorkingRegisters[I]}).getReg(0);
 
     // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
     // v2i64 register.
@@ -520,26 +522,26 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
     //     Therefore we have to extract/truncate the the value to the right type
     if (MidScalarSize == 32 || MidScalarSize == 64) {
       WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
-                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                         {MidScalarLLT}, {AddlvReg, ZeroReg})
                                 .getReg(0);
     } else {
-      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
-                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+      Register ExtractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {AddlvReg, ZeroReg})
                                 .getReg(0);
       WorkingRegisters[I] =
-          B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
+          B.buildTrunc({MidScalarLLT}, {ExtractReg}).getReg(0);
     }
   }
 
-  Register outReg;
+  Register OutReg;
   if (WorkingRegisters.size() > 1) {
-    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+    OutReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
                  .getReg(0);
     for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
-      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
+      OutReg = B.buildAdd(MidScalarLLT, OutReg, WorkingRegisters[I]).getReg(0);
     }
   } else {
-    outReg = WorkingRegisters[0];
+    OutReg = WorkingRegisters[0];
   }
 
   if (DstTy.getScalarSizeInBits() > MidScalarSize) {
@@ -547,9 +549,9 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
     // Src's ScalarType
     B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
                                         : TargetOpcode::G_ZEXT,
-                 {DstReg}, {outReg});
+                 {DstReg}, {OutReg});
   } else {
-    B.buildCopy(DstReg, outReg);
+    B.buildCopy(DstReg, OutReg);
   }
 
   MI.eraseFromParent();
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index 2d0df562b9a4b..12c13e8337e8d 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -4808,6 +4808,38 @@ define i64 @extract_scalable(<2 x i32> %0) "target-features"="+sve2" {
   ret i64 %5
 }
 
+define i32 @vecreduce_add_from_i21_zero() {
+; CHECK-SD-LABEL: vecreduce_add_from_i21_zero:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    mov w0, wzr
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: vecreduce_add_from_i21_zero:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = zext <4 x i21> zeroinitializer to <4 x i32>
+  %1 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
+  ret i32 %1
+}
+
+define i32 @vecreduce_add_from_i21(<4 x i21> %a) {
+; CHECK-LABEL: vecreduce_add_from_i21:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v1.4s, #31, msl #16
+; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %0 = zext <4 x i21> %a to <4 x i32>
+  %1 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
+  ret i32 %1
+}
+
 declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1 immarg) #1
 declare i16 @llvm.vector.reduce.add.v32i16(<32 x i16>)
 declare i16 @llvm.vector.reduce.add.v24i16(<24 x i16>)