diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7211607fee528..038c23b5e8d50 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21002,6 +21002,12 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
                       Store->getMemOperand());
 }
 
+static bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
+  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
+         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
+         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -21043,16 +21049,16 @@ static SDValue performSTORECombine(SDNode *N,
   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
     return Store;
 
-  if (ST->isTruncatingStore())
+  if (ST->isTruncatingStore()) {
+    EVT StoreVT = ST->getMemoryVT();
+    if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
+      return SDValue();
     if (SDValue Rshrnb =
             trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
-      EVT StoreVT = ST->getMemoryVT();
-      if ((ValueVT == MVT::nxv8i16 && StoreVT == MVT::nxv8i8) ||
-          (ValueVT == MVT::nxv4i32 && StoreVT == MVT::nxv4i16) ||
-          (ValueVT == MVT::nxv2i64 && StoreVT == MVT::nxv2i32))
-        return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
-                                 StoreVT, ST->getMemOperand());
+      return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
+                               StoreVT, ST->getMemOperand());
     }
+  }
 
   return SDValue();
 }
@@ -21098,6 +21104,19 @@ static SDValue performMSTORECombine(SDNode *N,
     }
   }
 
+  if (MST->isTruncatingStore()) {
+    EVT ValueVT = Value->getValueType(0);
+    EVT MemVT = MST->getMemoryVT();
+    if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
+      return SDValue();
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
+      return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
+                                MST->getOffset(), MST->getMask(),
+                                MST->getMemoryVT(), MST->getMemOperand(),
+                                MST->getAddressingMode(), true);
+    }
+  }
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index a913177623df9..0afd11d098a00 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
   store <vscale x 2 x i16> %3, ptr %4, align 1
   ret void
 }
+
+define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_store_rshrnb:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
+  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
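
Note: trySimplifySrlAddToRshrnb matches the rounding pattern lshr(add(X, 1 << (Shift - 1)), Shift), which is why the new test adds 32 (= 1 << 5) before shifting right by 6. A minimal scalar sketch of one RSHRNB lane, for illustration only (rshrnb_lane is a hypothetical name, not part of this patch):

// Scalar model of one RSHRNB lane (i16 -> i8): add the rounding bias
// 1 << (Shift - 1), shift right by Shift, and keep the low 8 bits.
#include <cassert>
#include <cstdint>

static uint8_t rshrnb_lane(uint16_t X, unsigned Shift) {
  assert(Shift >= 1 && Shift <= 8);
  // The bias add wraps at 16 bits, like the i16 add in the IR; any carry
  // lost here would land above bit 7 after the shift, so the truncated
  // 8-bit result is unaffected -- which is what makes the fold sound.
  uint16_t Biased = static_cast<uint16_t>(X + (1u << (Shift - 1)));
  return static_cast<uint8_t>(Biased >> Shift);
}

int main() {
  // Mirrors the test above: bias 32 (= 1 << 5), shift 6.
  assert(rshrnb_lane(100, 6) == (100 + 32) >> 6);
  return 0;
}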