diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2a2953c359984..8059592623713 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17053,14 +17053,28 @@ static SDValue performSVEAndCombine(SDNode *N,
 
     uint64_t ExtVal = C->getZExtValue();
 
+    auto MaskAndTypeMatch = [ExtVal](EVT VT) -> bool {
+      return ((ExtVal == 0xFF && VT == MVT::i8) ||
+              (ExtVal == 0xFFFF && VT == MVT::i16) ||
+              (ExtVal == 0xFFFFFFFF && VT == MVT::i32));
+    };
+
     // If the mask is fully covered by the unpack, we don't need to push
     // a new AND onto the operand
     EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
-    if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
-        (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
-        (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+    if (MaskAndTypeMatch(EltTy))
       return Src;
 
+    // If this is 'and (uunpklo/hi (extload MemTy -> ExtTy)), mask', then check
+    // to see if the mask is all-ones of size MemTy.
+    auto MaskedLoadOp = dyn_cast<MaskedLoadSDNode>(UnpkOp);
+    if (MaskedLoadOp && (MaskedLoadOp->getExtensionType() == ISD::ZEXTLOAD ||
+                         MaskedLoadOp->getExtensionType() == ISD::EXTLOAD)) {
+      EVT EltTy = MaskedLoadOp->getMemoryVT().getVectorElementType();
+      if (MaskAndTypeMatch(EltTy))
+        return Src;
+    }
+
     // Truncate to prevent a DUP with an over wide constant
     APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
 
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
index 55bd3833f611c..c495e983818f7 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll
@@ -52,10 +52,9 @@ define @masked_ld1b_i8_zext_i32( *%base, <
 define <vscale x 8 x i32> @masked_ld1b_nxv8i8_zext_i32(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv8i8_zext_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.h }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.s, z0.h
-; CHECK-NEXT:    and z0.h, z0.h, #0xff
-; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ld1b { z1.h }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr %a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> poison)
   %res = zext <vscale x 8 x i8> %wide.masked.load to <vscale x 8 x i32>
@@ -125,10 +124,9 @@ define @masked_ld1b_i8_zext( *%base,
 define <vscale x 4 x i64> @masked_ld1b_nxv4i8_zext_i64(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv4i8_zext_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> poison)
   %res = zext <vscale x 4 x i8> %wide.masked.load to <vscale x 4 x i64>
@@ -186,10 +184,9 @@ define @masked_ld1h_i16_zext( *%base,
 define <vscale x 4 x i64> @masked_ld1h_nxv4i16_zext(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1h_nxv4i16_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xffff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> poison)
   %res = zext <vscale x 4 x i16> %wide.masked.load to <vscale x 4 x i64>
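
Note, not part of the patch: a minimal standalone reproducer in the spirit of the updated tests, assuming a RUN line along the lines of llc -mtriple=aarch64-linux-gnu -mattr=+sve (the function name here is illustrative and not from the test file). Before this change, the low-half unpack of an extending masked load was preceded by an explicit "and z0.h, z0.h, #0xff"; with the combine now looking through the masked extending load, both halves come directly from uunpklo/uunpkhi, as the updated CHECK lines show.

; Illustrative IR only; mirrors masked_ld1b_nxv8i8_zext_i32 from the test above.
define <vscale x 8 x i32> @zext_masked_ld1b(ptr %p, <vscale x 8 x i1> %mask) {
  %ld = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr %p, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> poison)
  %ext = zext <vscale x 8 x i8> %ld to <vscale x 8 x i32>
  ret <vscale x 8 x i32> %ext
}

declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)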