diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 181ff00184b36b..5f964c08b28ecb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10448,9 +10448,10 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
 
 // Fold sext/zext of index into index type.
 bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
-                     bool Scaled, SelectionDAG &DAG) {
+                     bool Scaled, bool Signed, SelectionDAG &DAG) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
+  // It's always safe to look through zero extends.
   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
     SDValue Op = Index.getOperand(0);
     MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
@@ -10460,7 +10461,8 @@ bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
     }
   }
 
-  if (Index.getOpcode() == ISD::SIGN_EXTEND) {
+  // It's only safe to look through sign extends when Index is signed.
+  if (Index.getOpcode() == ISD::SIGN_EXTEND && Signed) {
     SDValue Op = Index.getOperand(0);
     MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
@@ -10493,7 +10495,8 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
         MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
   }
 
-  if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineIndexType(MSC, Index, MSC->isIndexScaled(), MSC->isIndexSigned(),
+                      DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(
         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10589,7 +10592,8 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
                                MGT->getExtensionType());
   }
 
-  if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineIndexType(MGT, Index, MGT->isIndexScaled(), MGT->isIndexSigned(),
+                      DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
                                MGT->getMemoryVT(), DL, Ops,
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index f4b4a033c3433e..82beb94dfcac22 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -399,14 +399,12 @@ define <vscale x 4 x i32> @masked_gather_nxv4i32_u8_offsets(i32* %base, <vscale
   ret <vscale x 4 x i32> %data
 }
 
-; TODO: The generated code is wrong because we're replicating offset[31] across
-; offset[32:63] even though the IR has explicitly zero'd those bits.
 define <vscale x 4 x i32> @masked_gather_nxv4i32_u32s8_offsets(i32* %base, <vscale x 4 x i8> %offsets, <vscale x 4 x i1> %mask) #0 {
 ; CHECK-LABEL: masked_gather_nxv4i32_u32s8_offsets:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    sxtb z0.s, p1/m, z0.s
-; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0, z0.s, uxtw #2]
 ; CHECK-NEXT:    ret
   %offsets.sext = sext <vscale x 4 x i8> %offsets to <vscale x 4 x i32>
   %offsets.sext.zext = zext <vscale x 4 x i32> %offsets.sext to <vscale x 4 x i64>
@@ -482,14 +480,12 @@ define void @masked_scatter_nxv4i32_u8_offsets(i32* %base, <vscale x 4 x i8> %of
   ret void
 }
 
-; TODO: The generated code is wrong because we're replicating offset[31] across
-; offset[32:63] even though the IR has explicitly zero'd those bits.
 define void @masked_scatter_nxv4i32_u32s8_offsets(i32* %base, <vscale x 4 x i8> %offsets, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv4i32_u32s8_offsets:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.s
 ; CHECK-NEXT:    sxtb z0.s, p1/m, z0.s
-; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
 ; CHECK-NEXT:    ret
   %offsets.sext = sext <vscale x 4 x i8> %offsets to <vscale x 4 x i32>
   %offsets.sext.zext = zext <vscale x 4 x i32> %offsets.sext to <vscale x 4 x i64>
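
Note (not part of the patch): the reasoning behind the new Signed guard, and behind the TODO comments the tests can now drop, can be seen with plain integer arithmetic. The sketch below is a minimal standalone C++ illustration, with made-up variable names; for a negative i8 offset, the IR pattern zext(sext(i8 to i32) to i64) zeroes bits 32..63, whereas treating the same i32 index as signed (sxtw addressing) would replicate bit 31 into them, so the sign extend may only be folded into the index type when the index is already interpreted as signed.

// Standalone illustration only; not LLVM code.
#include <cstdint>
#include <cstdio>

int main() {
  int8_t Offset = -1;

  // IR pattern from the tests: sext i8 -> i32, then zext i32 -> i64
  // (an unsigned index), which zeroes bits 32..63.
  uint64_t ZextOfSext = static_cast<uint64_t>(
      static_cast<uint32_t>(static_cast<int32_t>(Offset)));

  // What signed (sxtw-style) addressing would compute from the same
  // i32 value: bit 31 is replicated into bits 32..63.
  uint64_t SextAll = static_cast<uint64_t>(
      static_cast<int64_t>(static_cast<int32_t>(Offset)));

  // Prints 0x00000000ffffffff vs 0xffffffffffffffff.
  std::printf("zext(sext): 0x%016llx\n", (unsigned long long)ZextOfSext);
  std::printf("sext      : 0x%016llx\n", (unsigned long long)SextAll);
  return 0;
}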