Skip to content

Commit

Permalink
[RISCV] Allow fractional LMUL for reduction start value
Browse files Browse the repository at this point in the history
For reductions, we need to put the start value into a source vector. For fractional LMULs, we can perform the operation at the original LMUL. For LMUL > 1, we eventually want to use a scalar insert, but that's outside the scope of this patch.

Differential Revision: https://reviews.llvm.org/D139747
  • Loading branch information
preames committed Dec 12, 2022
1 parent 81084bf commit a4b45c2
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 244 deletions.
12 changes: 11 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -5814,9 +5814,16 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
const MVT M1VT = getLMUL1VT(VecVT);
const MVT XLenVT = Subtarget.getXLenVT();

// The reduction needs an LMUL1 input; do the splat at either LMUL1
// or the original VT if fractional.
auto InnerVT = VecVT.bitsLE(M1VT) ? VecVT : M1VT;
SDValue InitialSplat =
lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
M1VT, DL, DAG, Subtarget);
InnerVT, DL, DAG, Subtarget);
if (M1VT != InnerVT)
InitialSplat = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, M1VT,
DAG.getUNDEF(M1VT),
InitialSplat, DAG.getConstant(0, DL, XLenVT));
SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
InitialSplat, Mask, VL);
Expand Down Expand Up @@ -8014,6 +8021,9 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
return SDValue();

SDValue ScalarV = Reduce.getOperand(2);
if (ScalarV.getOpcode() == ISD::INSERT_SUBVECTOR &&
ScalarV.getOperand(0)->isUndef())
ScalarV = ScalarV.getOperand(1);

// Make sure that ScalarV is a splat with VL=1.
if (ScalarV.getOpcode() != RISCVISD::VFMV_S_F_VL &&
Expand Down
12 changes: 6 additions & 6 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll
Expand Up @@ -9,7 +9,7 @@ declare half @llvm.vp.reduce.fadd.v2f16(half, <2 x half>, <2 x i1>, i32)
define half @vpreduce_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_fadd_v2f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, tu, ma
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
Expand All @@ -22,7 +22,7 @@ define half @vpreduce_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroex
define half @vpreduce_ord_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_ord_fadd_v2f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, tu, ma
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t
Expand All @@ -37,7 +37,7 @@ declare half @llvm.vp.reduce.fadd.v4f16(half, <4 x half>, <4 x i1>, i32)
define half @vpreduce_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_fadd_v4f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, tu, ma
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
Expand All @@ -50,7 +50,7 @@ define half @vpreduce_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroex
define half @vpreduce_ord_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_ord_fadd_v4f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, tu, ma
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t
Expand All @@ -65,7 +65,7 @@ declare float @llvm.vp.reduce.fadd.v2f32(float, <2 x float>, <2 x i1>, i32)
define float @vpreduce_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_fadd_v2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, tu, ma
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
Expand All @@ -78,7 +78,7 @@ define float @vpreduce_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zer
define float @vpreduce_ord_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_ord_fadd_v2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, tu, ma
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t
Expand Down
35 changes: 16 additions & 19 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp.ll
Expand Up @@ -317,11 +317,10 @@ define float @vreduce_fwadd_v1f32(<1 x half>* %x, float %s) {
define float @vreduce_ord_fwadd_v1f32(<1 x half>* %x, float %s) {
; CHECK-LABEL: vreduce_ord_fwadd_v1f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vfwredosum.vs v8, v8, v9
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand Down Expand Up @@ -365,11 +364,10 @@ define float @vreduce_ord_fadd_v2f32(<2 x float>* %x, float %s) {
define float @vreduce_fwadd_v2f32(<2 x half>* %x, float %s) {
; CHECK-LABEL: vreduce_fwadd_v2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vfwredusum.vs v8, v8, v9
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand All @@ -383,11 +381,10 @@ define float @vreduce_fwadd_v2f32(<2 x half>* %x, float %s) {
define float @vreduce_ord_fwadd_v2f32(<2 x half>* %x, float %s) {
; CHECK-LABEL: vreduce_ord_fwadd_v2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vfmv.s.f v9, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vfwredosum.vs v8, v8, v9
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
; CHECK-NEXT: vfmv.f.s fa0, v8
Expand Down Expand Up @@ -1185,7 +1182,7 @@ define half @vreduce_fmin_v2f16(<2 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI68_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI68_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vfredmin.vs v8, v8, v9
Expand All @@ -1205,7 +1202,7 @@ define half @vreduce_fmin_v4f16(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI69_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI69_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmin.vs v8, v8, v9
Expand All @@ -1223,7 +1220,7 @@ define half @vreduce_fmin_v4f16_nonans(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI70_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI70_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmin.vs v8, v8, v9
Expand All @@ -1241,7 +1238,7 @@ define half @vreduce_fmin_v4f16_nonans_noinfs(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI71_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI71_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmin.vs v8, v8, v9
Expand Down Expand Up @@ -1285,7 +1282,7 @@ define float @vreduce_fmin_v2f32(<2 x float>* %x) {
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI73_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI73_0)
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
; CHECK-NEXT: vlse32.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vfredmin.vs v8, v8, v9
Expand Down Expand Up @@ -1490,7 +1487,7 @@ define half @vreduce_fmax_v2f16(<2 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI83_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI83_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
; CHECK-NEXT: vfredmax.vs v8, v8, v9
Expand All @@ -1510,7 +1507,7 @@ define half @vreduce_fmax_v4f16(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI84_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI84_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmax.vs v8, v8, v9
Expand All @@ -1528,7 +1525,7 @@ define half @vreduce_fmax_v4f16_nonans(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI85_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI85_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmax.vs v8, v8, v9
Expand All @@ -1546,7 +1543,7 @@ define half @vreduce_fmax_v4f16_nonans_noinfs(<4 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI86_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI86_0)
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
; CHECK-NEXT: vlse16.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vfredmax.vs v8, v8, v9
Expand Down Expand Up @@ -1590,7 +1587,7 @@ define float @vreduce_fmax_v2f32(<2 x float>* %x) {
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a0, %hi(.LCPI88_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI88_0)
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
; CHECK-NEXT: vlse32.v v9, (a0), zero
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; CHECK-NEXT: vfredmax.vs v8, v8, v9
Expand Down

0 comments on commit a4b45c2

Please sign in to comment.