[RISCV] Move narrowIndex to be a DAG combine over target independent nodes

In D154687, we added a transform to narrow indexed load/store indices of the
form (shl (zext), C).  We can move this into a generic transform over the
target independent nodes instead, and pick up the fixed vector cases with no
additional work required.  This is an alternative to D158163.
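
For illustration, a minimal LLVM IR sketch of the kind of input affected,
loosely modeled on the mgather_baseidx_zext tests updated below; the function
and value names here are illustrative, not taken from the commit. At the
SelectionDAG level the zero-extended index is scaled into the (shl (zext), C)
form above (here C = 2 for i32 elements), which the combine over the
target-independent gather/scatter nodes can now narrow so the gather uses a
smaller index EEW (e.g. vluxei16.v instead of vluxei64.v, as in the updated
tests):

  ; Zero-extended i8 indices scaled by the i32 element size (shift by 2).
  define <8 x i32> @gather_zext_shl(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x i32> %passthru) {
    %eidxs = zext <8 x i8> %idxs to <8 x i64>
    %ptrs = getelementptr inbounds i32, ptr %base, <8 x i64> %eidxs
    %v = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %m, <8 x i32> %passthru)
    ret <8 x i32> %v
  }
  declare <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x i32>)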

Performing this transform points out that we weren't eliminating zero_extends
via the generic DAG combine.  Adjust the (existing) callbacks so that we
do.

This change *removes* the existing transform on the target specific intrinsic
nodes.  If anyone has a use case this impacts, please speak up.

Note: Reviewed as part of a stack of changes in PR# 66405.
preames committed Sep 15, 2023
1 parent 2ff9175 commit 37aa07a
Showing 8 changed files with 375 additions and 371 deletions.
70 changes: 45 additions & 25 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11632,21 +11632,24 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
// zero-extended their indices, \p narrowIndex tries to narrow the type of index
// operand if it is matched to pattern (shl (zext x to ty), C) and bits(x) + C <
// bits(ty).
static SDValue narrowIndex(SDValue N, SelectionDAG &DAG) {
static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &DAG) {
if (isIndexTypeSigned(IndexType))
return false;

if (N.getOpcode() != ISD::SHL || !N->hasOneUse())
return SDValue();
return false;

SDValue N0 = N.getOperand(0);
if (N0.getOpcode() != ISD::ZERO_EXTEND &&
N0.getOpcode() != RISCVISD::VZEXT_VL)
return SDValue();
return false;
if (!N0->hasOneUse())
return SDValue();
return false;

APInt ShAmt;
SDValue N1 = N.getOperand(1);
if (!ISD::isConstantSplatVector(N1.getNode(), ShAmt))
return SDValue();
return false;

SDLoc DL(N);
SDValue Src = N0.getOperand(0);
@@ -11658,14 +11661,15 @@ static SDValue narrowIndex(SDValue N, SelectionDAG &DAG) {

// Skip if NewElen is not narrower than the original extended type.
if (NewElen >= N0.getValueType().getScalarSizeInBits())
return SDValue();
return false;

EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), NewElen);
EVT NewVT = SrcVT.changeVectorElementType(NewEltVT);

SDValue NewExt = DAG.getNode(N0->getOpcode(), DL, NewVT, N0->ops());
SDValue NewShAmtVec = DAG.getConstant(ShAmtV, DL, NewVT);
return DAG.getNode(ISD::SHL, DL, NewVT, NewExt, NewShAmtVec);
N = DAG.getNode(ISD::SHL, DL, NewVT, NewExt, NewShAmtVec);
return true;
}

// Replace (seteq (i64 (and X, 0xffffffff)), C1) with
@@ -13883,6 +13887,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
MGN->getBasePtr(), Index, ScaleOp},
MGN->getMemOperand(), IndexType, MGN->getExtensionType());

if (narrowIndex(Index, IndexType, DAG))
return DAG.getMaskedGather(
N->getVTList(), MGN->getMemoryVT(), DL,
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
MGN->getBasePtr(), Index, ScaleOp},
MGN->getMemOperand(), IndexType, MGN->getExtensionType());
break;
}
case ISD::MSCATTER:{
@@ -13900,6 +13911,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
Index, ScaleOp},
MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());

if (narrowIndex(Index, IndexType, DAG))
return DAG.getMaskedScatter(
N->getVTList(), MSN->getMemoryVT(), DL,
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
Index, ScaleOp},
MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
break;
}
case ISD::VP_GATHER: {
@@ -13917,6 +13935,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
ScaleOp, VPGN->getMask(),
VPGN->getVectorLength()},
VPGN->getMemOperand(), IndexType);

if (narrowIndex(Index, IndexType, DAG))
return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
{VPGN->getChain(), VPGN->getBasePtr(), Index,
ScaleOp, VPGN->getMask(),
VPGN->getVectorLength()},
VPGN->getMemOperand(), IndexType);

break;
}
case ISD::VP_SCATTER: {
@@ -13934,6 +13960,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
VPSN->getBasePtr(), Index, ScaleOp,
VPSN->getMask(), VPSN->getVectorLength()},
VPSN->getMemOperand(), IndexType);

if (narrowIndex(Index, IndexType, DAG))
return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
{VPSN->getChain(), VPSN->getValue(),
VPSN->getBasePtr(), Index, ScaleOp,
VPSN->getMask(), VPSN->getVectorLength()},
VPSN->getMemOperand(), IndexType);
break;
}
case RISCVISD::SRA_VL:
@@ -14238,23 +14271,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getConstant(-1, DL, VT);
return DAG.getConstant(0, DL, VT);
}
case Intrinsic::riscv_vloxei:
case Intrinsic::riscv_vloxei_mask:
case Intrinsic::riscv_vluxei:
case Intrinsic::riscv_vluxei_mask:
case Intrinsic::riscv_vsoxei:
case Intrinsic::riscv_vsoxei_mask:
case Intrinsic::riscv_vsuxei:
case Intrinsic::riscv_vsuxei_mask:
if (SDValue V = narrowIndex(N->getOperand(4), DAG)) {
SmallVector<SDValue, 8> Ops(N->ops());
Ops[4] = V;
const auto *MemSD = cast<MemIntrinsicSDNode>(N);
return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(),
Ops, MemSD->getMemoryVT(),
MemSD->getMemOperand());
}
return SDValue();
}
}
case ISD::BITCAST: {
@@ -17692,7 +17708,11 @@ Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic(

bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
EVT DataVT) const {
return false;
// We have indexed loads for all legal index types. Indices are always
// zero extended
return Extend.getOpcode() == ISD::ZERO_EXTEND &&
isTypeLegal(Extend.getValueType()) &&
isTypeLegal(Extend.getOperand(0).getValueType());
}

bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
132 changes: 66 additions & 66 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -1716,21 +1716,19 @@ define <8 x i16> @mgather_baseidx_sext_v8i8_v8i16(ptr %base, <8 x i8> %idxs, <8
define <8 x i16> @mgather_baseidx_zext_v8i8_v8i16(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x i16> %passthru) {
; RV32-LABEL: mgather_baseidx_zext_v8i8_v8i16:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32-NEXT: vzext.vf4 v10, v8
; RV32-NEXT: vadd.vv v10, v10, v10
; RV32-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV32-NEXT: vwaddu.vv v10, v8, v8
; RV32-NEXT: vsetvli zero, zero, e16, m1, ta, mu
; RV32-NEXT: vluxei32.v v9, (a0), v10, v0.t
; RV32-NEXT: vluxei16.v v9, (a0), v10, v0.t
; RV32-NEXT: vmv.v.v v8, v9
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8i16:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf8 v12, v8
; RV64V-NEXT: vadd.vv v12, v12, v12
; RV64V-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV64V-NEXT: vwaddu.vv v10, v8, v8
; RV64V-NEXT: vsetvli zero, zero, e16, m1, ta, mu
; RV64V-NEXT: vluxei64.v v9, (a0), v12, v0.t
; RV64V-NEXT: vluxei16.v v9, (a0), v10, v0.t
; RV64V-NEXT: vmv.v.v v8, v9
; RV64V-NEXT: ret
;
@@ -2793,20 +2791,21 @@ define <8 x i32> @mgather_baseidx_sext_v8i8_v8i32(ptr %base, <8 x i8> %idxs, <8
define <8 x i32> @mgather_baseidx_zext_v8i8_v8i32(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x i32> %passthru) {
; RV32-LABEL: mgather_baseidx_zext_v8i8_v8i32:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV32-NEXT: vzext.vf4 v12, v8
; RV32-NEXT: vsll.vi v8, v12, 2
; RV32-NEXT: vluxei32.v v10, (a0), v8, v0.t
; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV32-NEXT: vzext.vf2 v9, v8
; RV32-NEXT: vsll.vi v8, v9, 2
; RV32-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV32-NEXT: vluxei16.v v10, (a0), v8, v0.t
; RV32-NEXT: vmv.v.v v8, v10
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8i32:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf8 v12, v8
; RV64V-NEXT: vsll.vi v12, v12, 2
; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV64V-NEXT: vzext.vf2 v9, v8
; RV64V-NEXT: vsll.vi v8, v9, 2
; RV64V-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV64V-NEXT: vluxei64.v v10, (a0), v12, v0.t
; RV64V-NEXT: vluxei16.v v10, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v10
; RV64V-NEXT: ret
;
@@ -3264,11 +3263,10 @@ define <8 x i32> @mgather_baseidx_zext_v8i16_v8i32(ptr %base, <8 x i16> %idxs, <
;
; RV64V-LABEL: mgather_baseidx_zext_v8i16_v8i32:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf4 v12, v8
; RV64V-NEXT: vsll.vi v12, v12, 2
; RV64V-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV64V-NEXT: vluxei64.v v10, (a0), v12, v0.t
; RV64V-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV64V-NEXT: vzext.vf2 v12, v8
; RV64V-NEXT: vsll.vi v8, v12, 2
; RV64V-NEXT: vluxei32.v v10, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v10
; RV64V-NEXT: ret
;
@@ -4772,20 +4770,21 @@ define <8 x i64> @mgather_baseidx_sext_v8i8_v8i64(ptr %base, <8 x i8> %idxs, <8
define <8 x i64> @mgather_baseidx_zext_v8i8_v8i64(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x i64> %passthru) {
; RV32V-LABEL: mgather_baseidx_zext_v8i8_v8i64:
; RV32V: # %bb.0:
; RV32V-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32V-NEXT: vzext.vf4 v10, v8
; RV32V-NEXT: vsll.vi v8, v10, 3
; RV32V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV32V-NEXT: vzext.vf2 v9, v8
; RV32V-NEXT: vsll.vi v8, v9, 3
; RV32V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV32V-NEXT: vluxei32.v v12, (a0), v8, v0.t
; RV32V-NEXT: vluxei16.v v12, (a0), v8, v0.t
; RV32V-NEXT: vmv.v.v v8, v12
; RV32V-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8i64:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, mu
; RV64V-NEXT: vzext.vf8 v16, v8
; RV64V-NEXT: vsll.vi v8, v16, 3
; RV64V-NEXT: vluxei64.v v12, (a0), v8, v0.t
; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV64V-NEXT: vzext.vf2 v9, v8
; RV64V-NEXT: vsll.vi v8, v9, 3
; RV64V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV64V-NEXT: vluxei16.v v12, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v12
; RV64V-NEXT: ret
;
@@ -5616,10 +5615,11 @@ define <8 x i64> @mgather_baseidx_zext_v8i16_v8i64(ptr %base, <8 x i16> %idxs, <
;
; RV64V-LABEL: mgather_baseidx_zext_v8i16_v8i64:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, mu
; RV64V-NEXT: vzext.vf4 v16, v8
; RV64V-NEXT: vsll.vi v8, v16, 3
; RV64V-NEXT: vluxei64.v v12, (a0), v8, v0.t
; RV64V-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV64V-NEXT: vzext.vf2 v10, v8
; RV64V-NEXT: vsll.vi v8, v10, 3
; RV64V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV64V-NEXT: vluxei32.v v12, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v12
; RV64V-NEXT: ret
;
@@ -7645,21 +7645,19 @@ define <8 x half> @mgather_baseidx_sext_v8i8_v8f16(ptr %base, <8 x i8> %idxs, <8
define <8 x half> @mgather_baseidx_zext_v8i8_v8f16(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x half> %passthru) {
; RV32-LABEL: mgather_baseidx_zext_v8i8_v8f16:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32-NEXT: vzext.vf4 v10, v8
; RV32-NEXT: vadd.vv v10, v10, v10
; RV32-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV32-NEXT: vwaddu.vv v10, v8, v8
; RV32-NEXT: vsetvli zero, zero, e16, m1, ta, mu
; RV32-NEXT: vluxei32.v v9, (a0), v10, v0.t
; RV32-NEXT: vluxei16.v v9, (a0), v10, v0.t
; RV32-NEXT: vmv.v.v v8, v9
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8f16:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf8 v12, v8
; RV64V-NEXT: vadd.vv v12, v12, v12
; RV64V-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV64V-NEXT: vwaddu.vv v10, v8, v8
; RV64V-NEXT: vsetvli zero, zero, e16, m1, ta, mu
; RV64V-NEXT: vluxei64.v v9, (a0), v12, v0.t
; RV64V-NEXT: vluxei16.v v9, (a0), v10, v0.t
; RV64V-NEXT: vmv.v.v v8, v9
; RV64V-NEXT: ret
;
@@ -8596,20 +8594,21 @@ define <8 x float> @mgather_baseidx_sext_v8i8_v8f32(ptr %base, <8 x i8> %idxs, <
define <8 x float> @mgather_baseidx_zext_v8i8_v8f32(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x float> %passthru) {
; RV32-LABEL: mgather_baseidx_zext_v8i8_v8f32:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV32-NEXT: vzext.vf4 v12, v8
; RV32-NEXT: vsll.vi v8, v12, 2
; RV32-NEXT: vluxei32.v v10, (a0), v8, v0.t
; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV32-NEXT: vzext.vf2 v9, v8
; RV32-NEXT: vsll.vi v8, v9, 2
; RV32-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV32-NEXT: vluxei16.v v10, (a0), v8, v0.t
; RV32-NEXT: vmv.v.v v8, v10
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8f32:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf8 v12, v8
; RV64V-NEXT: vsll.vi v12, v12, 2
; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV64V-NEXT: vzext.vf2 v9, v8
; RV64V-NEXT: vsll.vi v8, v9, 2
; RV64V-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV64V-NEXT: vluxei64.v v10, (a0), v12, v0.t
; RV64V-NEXT: vluxei16.v v10, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v10
; RV64V-NEXT: ret
;
@@ -9067,11 +9066,10 @@ define <8 x float> @mgather_baseidx_zext_v8i16_v8f32(ptr %base, <8 x i16> %idxs,
;
; RV64V-LABEL: mgather_baseidx_zext_v8i16_v8f32:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64V-NEXT: vzext.vf4 v12, v8
; RV64V-NEXT: vsll.vi v12, v12, 2
; RV64V-NEXT: vsetvli zero, zero, e32, m2, ta, mu
; RV64V-NEXT: vluxei64.v v10, (a0), v12, v0.t
; RV64V-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV64V-NEXT: vzext.vf2 v12, v8
; RV64V-NEXT: vsll.vi v8, v12, 2
; RV64V-NEXT: vluxei32.v v10, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v10
; RV64V-NEXT: ret
;
@@ -10334,20 +10332,21 @@ define <8 x double> @mgather_baseidx_sext_v8i8_v8f64(ptr %base, <8 x i8> %idxs,
define <8 x double> @mgather_baseidx_zext_v8i8_v8f64(ptr %base, <8 x i8> %idxs, <8 x i1> %m, <8 x double> %passthru) {
; RV32V-LABEL: mgather_baseidx_zext_v8i8_v8f64:
; RV32V: # %bb.0:
; RV32V-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32V-NEXT: vzext.vf4 v10, v8
; RV32V-NEXT: vsll.vi v8, v10, 3
; RV32V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV32V-NEXT: vzext.vf2 v9, v8
; RV32V-NEXT: vsll.vi v8, v9, 3
; RV32V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV32V-NEXT: vluxei32.v v12, (a0), v8, v0.t
; RV32V-NEXT: vluxei16.v v12, (a0), v8, v0.t
; RV32V-NEXT: vmv.v.v v8, v12
; RV32V-NEXT: ret
;
; RV64V-LABEL: mgather_baseidx_zext_v8i8_v8f64:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, mu
; RV64V-NEXT: vzext.vf8 v16, v8
; RV64V-NEXT: vsll.vi v8, v16, 3
; RV64V-NEXT: vluxei64.v v12, (a0), v8, v0.t
; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV64V-NEXT: vzext.vf2 v9, v8
; RV64V-NEXT: vsll.vi v8, v9, 3
; RV64V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV64V-NEXT: vluxei16.v v12, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v12
; RV64V-NEXT: ret
;
@@ -11001,10 +11000,11 @@ define <8 x double> @mgather_baseidx_zext_v8i16_v8f64(ptr %base, <8 x i16> %idxs
;
; RV64V-LABEL: mgather_baseidx_zext_v8i16_v8f64:
; RV64V: # %bb.0:
; RV64V-NEXT: vsetivli zero, 8, e64, m4, ta, mu
; RV64V-NEXT: vzext.vf4 v16, v8
; RV64V-NEXT: vsll.vi v8, v16, 3
; RV64V-NEXT: vluxei64.v v12, (a0), v8, v0.t
; RV64V-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV64V-NEXT: vzext.vf2 v10, v8
; RV64V-NEXT: vsll.vi v8, v10, 3
; RV64V-NEXT: vsetvli zero, zero, e64, m4, ta, mu
; RV64V-NEXT: vluxei32.v v12, (a0), v8, v0.t
; RV64V-NEXT: vmv.v.v v8, v12
; RV64V-NEXT: ret
;
