Skip to content

Commit

Permalink
[LegalizeTypes][VP] Add splitting support for vp.reduction.*
Browse files Browse the repository at this point in the history
Split vp.reduction.* intrinsics by splitting the vector operand into
two halves, performing the reduction operation on each half, and
accumulating the results of both operations.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D117469
  • Loading branch information
victor-eds committed Jan 18, 2022
1 parent c154f39 commit fd1dce3
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 7 deletions.
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,9 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void GetSplitVector(SDValue Op, SDValue &Lo, SDValue &Hi);
void SetSplitVector(SDValue Op, SDValue Lo, SDValue Hi);

/// Split mask operator of a VP intrinsic.
std::pair<SDValue, SDValue> SplitMask(SDValue Mask);

// Helper function for incrementing the pointer when splitting
// memory operations
void IncrementPointer(MemSDNode *N, EVT MemVT, MachinePointerInfo &MPI,
Expand Down Expand Up @@ -864,6 +867,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_VECREDUCE_SEQ(SDNode *N);
SDValue SplitVecOp_VP_REDUCE(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_UnaryOp(SDNode *N);
SDValue SplitVecOp_TruncateHelper(SDNode *N);

Expand Down
62 changes: 55 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,16 @@ void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
}
}

/// Split the mask operand of a VP intrinsic into its low and high halves.
///
/// If the mask's type has already been marked for splitting by the type
/// legalizer, reuse the halves it recorded; otherwise split the value
/// directly here.
std::pair<SDValue, SDValue> DAGTypeLegalizer::SplitMask(SDValue Mask) {
  SDValue Lo, Hi;
  if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector)
    GetSplitVector(Mask, Lo, Hi);
  else
    std::tie(Lo, Hi) = DAG.SplitVector(Mask, SDLoc(Mask));
  return {Lo, Hi};
}

void DAGTypeLegalizer::SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi) {
SDValue LHSLo, LHSHi;
GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
Expand All @@ -1135,14 +1145,8 @@ void DAGTypeLegalizer::SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi) {
assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
assert(N->isVPOpcode() && "Expected VP opcode");

// Split the mask.
SDValue MaskLo, MaskHi;
SDValue Mask = N->getOperand(2);
EVT MaskVT = Mask.getValueType();
if (getTypeAction(MaskVT) == TargetLowering::TypeSplitVector)
GetSplitVector(Mask, MaskLo, MaskHi);
else
std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, SDLoc(Mask));
std::tie(MaskLo, MaskHi) = SplitMask(N->getOperand(2));

SDValue EVLLo, EVLHi;
std::tie(EVLLo, EVLHi) =
Expand Down Expand Up @@ -2342,6 +2346,23 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VECREDUCE_SEQ_FMUL:
Res = SplitVecOp_VECREDUCE_SEQ(N);
break;
case ISD::VP_REDUCE_FADD:
case ISD::VP_REDUCE_SEQ_FADD:
case ISD::VP_REDUCE_FMUL:
case ISD::VP_REDUCE_SEQ_FMUL:
case ISD::VP_REDUCE_ADD:
case ISD::VP_REDUCE_MUL:
case ISD::VP_REDUCE_AND:
case ISD::VP_REDUCE_OR:
case ISD::VP_REDUCE_XOR:
case ISD::VP_REDUCE_SMAX:
case ISD::VP_REDUCE_SMIN:
case ISD::VP_REDUCE_UMAX:
case ISD::VP_REDUCE_UMIN:
case ISD::VP_REDUCE_FMAX:
case ISD::VP_REDUCE_FMIN:
Res = SplitVecOp_VP_REDUCE(N, OpNo);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -2438,6 +2459,33 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE_SEQ(SDNode *N) {
return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, Hi, Flags);
}

/// Legalize a VP reduction whose vector operand needs splitting.
///
/// The vector, mask and EVL operands are each split in half.  The low half
/// is reduced first using the node's original start value (operand 0); the
/// scalar result of that reduction then serves as the start value for the
/// high-half reduction, whose result is the final answer.  Chaining through
/// the start value preserves ordering for sequential reductions as well.
SDValue DAGTypeLegalizer::SplitVecOp_VP_REDUCE(SDNode *N, unsigned OpNo) {
  assert(N->isVPOpcode() && "Expected VP opcode");
  assert(OpNo == 1 && "Can only split reduce vector operand");

  SDLoc DL(N);
  SDValue VecOp = N->getOperand(OpNo);
  EVT VecVT = VecOp.getValueType();
  assert(VecVT.isVector() && "Can only split reduce vector operand");

  SDValue Lo, Hi;
  GetSplitVector(VecOp, Lo, Hi);

  // Split the mask and the explicit vector length to match the halves.
  SDValue MaskLo, MaskHi;
  std::tie(MaskLo, MaskHi) = SplitMask(N->getOperand(2));

  SDValue EVLLo, EVLHi;
  std::tie(EVLLo, EVLHi) = DAG.SplitEVL(N->getOperand(3), VecVT, DL);

  unsigned Opcode = N->getOpcode();
  EVT ResVT = N->getValueType(0);
  const SDNodeFlags Flags = N->getFlags();

  // Reduce Lo with the original start value, then feed that partial result
  // into the Hi reduction as its start value.
  SDValue Partial = DAG.getNode(Opcode, DL, ResVT,
                                {N->getOperand(0), Lo, MaskLo, EVLLo}, Flags);
  return DAG.getNode(Opcode, DL, ResVT, {Partial, Hi, MaskHi, EVLHi}, Flags);
}

SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
// The result has a legal vector type, but the input needs splitting.
EVT ResVT = N->getValueType(0);
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,31 +373,46 @@ ISD::NodeType ISD::getVecReduceBaseOpcode(unsigned VecReduceOpcode) {
llvm_unreachable("Expected VECREDUCE opcode");
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VP_REDUCE_FADD:
case ISD::VP_REDUCE_SEQ_FADD:
return ISD::FADD;
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_SEQ_FMUL:
case ISD::VP_REDUCE_FMUL:
case ISD::VP_REDUCE_SEQ_FMUL:
return ISD::FMUL;
case ISD::VECREDUCE_ADD:
case ISD::VP_REDUCE_ADD:
return ISD::ADD;
case ISD::VECREDUCE_MUL:
case ISD::VP_REDUCE_MUL:
return ISD::MUL;
case ISD::VECREDUCE_AND:
case ISD::VP_REDUCE_AND:
return ISD::AND;
case ISD::VECREDUCE_OR:
case ISD::VP_REDUCE_OR:
return ISD::OR;
case ISD::VECREDUCE_XOR:
case ISD::VP_REDUCE_XOR:
return ISD::XOR;
case ISD::VECREDUCE_SMAX:
case ISD::VP_REDUCE_SMAX:
return ISD::SMAX;
case ISD::VECREDUCE_SMIN:
case ISD::VP_REDUCE_SMIN:
return ISD::SMIN;
case ISD::VECREDUCE_UMAX:
case ISD::VP_REDUCE_UMAX:
return ISD::UMAX;
case ISD::VECREDUCE_UMIN:
case ISD::VP_REDUCE_UMIN:
return ISD::UMIN;
case ISD::VECREDUCE_FMAX:
case ISD::VP_REDUCE_FMAX:
return ISD::FMAXNUM;
case ISD::VECREDUCE_FMIN:
case ISD::VP_REDUCE_FMIN:
return ISD::FMINNUM;
}
}
Expand Down
66 changes: 66 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,72 @@ define float @vpreduce_ord_fadd_v4f32(float %s, <4 x float> %v, <4 x i1> %m, i32
ret float %r
}

declare float @llvm.vp.reduce.fadd.v64f32(float, <64 x float>, <64 x i1>, i32)

; <64 x float> exceeds one register group here, so the reduction is legalized
; by splitting: two masked vfredusum.vs ops, with the scalar result of the
; first (ft0) becoming the start value of the second.
define float @vpreduce_fadd_v64f32(float %s, <64 x float> %v, <64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_fadd_v64f32:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a2, a0, -32
; CHECK-NEXT: li a1, 0
; CHECK-NEXT: bltu a0, a2, .LBB8_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, mu
; CHECK-NEXT: li a2, 32
; CHECK-NEXT: vslidedown.vi v24, v0, 4
; CHECK-NEXT: bltu a0, a2, .LBB8_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: .LBB8_4:
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vfmv.s.f v25, fa0
; CHECK-NEXT: vsetvli zero, a0, e32, m8, tu, mu
; CHECK-NEXT: vfredusum.vs v25, v8, v25, v0.t
; CHECK-NEXT: vfmv.f.s ft0, v25
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vfmv.s.f v8, ft0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vfredusum.vs v8, v16, v8, v0.t
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%r = call reassoc float @llvm.vp.reduce.fadd.v64f32(float %s, <64 x float> %v, <64 x i1> %m, i32 %evl)
ret float %r
}

; Ordered (no reassoc) variant of the split 64-element fadd reduction: uses
; vfredosum.vs instead of vfredusum.vs, still two halves chained through ft0.
define float @vpreduce_ord_fadd_v64f32(float %s, <64 x float> %v, <64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_ord_fadd_v64f32:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a2, a0, -32
; CHECK-NEXT: li a1, 0
; CHECK-NEXT: bltu a0, a2, .LBB9_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: .LBB9_2:
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, mu
; CHECK-NEXT: li a2, 32
; CHECK-NEXT: vslidedown.vi v24, v0, 4
; CHECK-NEXT: bltu a0, a2, .LBB9_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: .LBB9_4:
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vfmv.s.f v25, fa0
; CHECK-NEXT: vsetvli zero, a0, e32, m8, tu, mu
; CHECK-NEXT: vfredosum.vs v25, v8, v25, v0.t
; CHECK-NEXT: vfmv.f.s ft0, v25
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vfmv.s.f v8, ft0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vfredosum.vs v8, v16, v8, v0.t
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%r = call float @llvm.vp.reduce.fadd.v64f32(float %s, <64 x float> %v, <64 x i1> %m, i32 %evl)
ret float %r
}

declare double @llvm.vp.reduce.fadd.v2f64(double, <2 x double>, <2 x i1>, i32)

define double @vpreduce_fadd_v2f64(double %s, <2 x double> %v, <2 x i1> %m, i32 zeroext %evl) {
Expand Down
34 changes: 34 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,40 @@ define signext i32 @vpreduce_xor_v4i32(i32 signext %s, <4 x i32> %v, <4 x i1> %m
ret i32 %r
}

declare i32 @llvm.vp.reduce.xor.v64i32(i32, <64 x i32>, <64 x i1>, i32)

; Integer counterpart of the split reduction: <64 x i32> is reduced with two
; masked vredxor.vs ops, the second starting from the first's scalar result.
define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_xor_v64i32:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a3, a1, -32
; CHECK-NEXT: li a2, 0
; CHECK-NEXT: bltu a1, a3, .LBB48_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a2, a3
; CHECK-NEXT: .LBB48_2:
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, mu
; CHECK-NEXT: li a3, 32
; CHECK-NEXT: vslidedown.vi v24, v0, 4
; CHECK-NEXT: bltu a1, a3, .LBB48_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: .LBB48_4:
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vmv.s.x v25, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, tu, mu
; CHECK-NEXT: vredxor.vs v25, v8, v25, v0.t
; CHECK-NEXT: vmv.x.s a0, v25
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
; CHECK-NEXT: vmv.s.x v8, a0
; CHECK-NEXT: vsetvli zero, a2, e32, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vredxor.vs v8, v16, v8, v0.t
; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
ret i32 %r
}

declare i64 @llvm.vp.reduce.add.v2i64(i64, <2 x i64>, <2 x i1>, i32)

define signext i64 @vpreduce_add_v2i64(i64 signext %s, <2 x i64> %v, <2 x i1> %m, i32 zeroext %evl) {
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,41 @@ define signext i1 @vpreduce_and_v16i1(i1 signext %s, <16 x i1> %v, <16 x i1> %m,
ret i1 %r
}

declare i1 @llvm.vp.reduce.and.v256i1(i1, <256 x i1>, <256 x i1>, i32)

; Mask-vector reduction split into two 128-element halves.  Each half's AND
; reduction is lowered as vmnand + masked vcpop + seqz (presumably "no active
; false element" — compare against the i1 reduction lowering to confirm), and
; the two half-results are ANDed with the start value %s.
define signext i1 @vpreduce_and_v256i1(i1 signext %s, <256 x i1> %v, <256 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_and_v256i1:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a2, a1, -128
; CHECK-NEXT: vmv1r.v v11, v0
; CHECK-NEXT: li a3, 0
; CHECK-NEXT: bltu a1, a2, .LBB13_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: .LBB13_2:
; CHECK-NEXT: vsetvli zero, a3, e8, m8, ta, mu
; CHECK-NEXT: vmnand.mm v8, v8, v8
; CHECK-NEXT: vmv1r.v v0, v10
; CHECK-NEXT: vcpop.m a2, v8, v0.t
; CHECK-NEXT: li a3, 128
; CHECK-NEXT: seqz a2, a2
; CHECK-NEXT: bltu a1, a3, .LBB13_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a1, 128
; CHECK-NEXT: .LBB13_4:
; CHECK-NEXT: vsetvli zero, a1, e8, m8, ta, mu
; CHECK-NEXT: vmnand.mm v8, v11, v11
; CHECK-NEXT: vmv1r.v v0, v9
; CHECK-NEXT: vcpop.m a1, v8, v0.t
; CHECK-NEXT: seqz a1, a1
; CHECK-NEXT: and a0, a1, a0
; CHECK-NEXT: and a0, a2, a0
; CHECK-NEXT: neg a0, a0
; CHECK-NEXT: ret
%r = call i1 @llvm.vp.reduce.and.v256i1(i1 %s, <256 x i1> %v, <256 x i1> %m, i32 %evl)
ret i1 %r
}

declare i1 @llvm.vp.reduce.or.v16i1(i1, <16 x i1>, <16 x i1>, i32)

define signext i1 @vpreduce_or_v16i1(i1 signext %s, <16 x i1> %v, <16 x i1> %m, i32 zeroext %evl) {
Expand Down
72 changes: 72 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,78 @@ define half @vpreduce_ord_fadd_nxv4f16(half %s, <vscale x 4 x half> %v, <vscale
ret half %r
}

declare half @llvm.vp.reduce.fadd.nxv64f16(half, <vscale x 64 x half>, <vscale x 64 x i1>, i32)

; Scalable-vector version of the split reduction: half-lengths are computed
; from vlenb at runtime, then two masked vfredusum.vs ops are chained through
; the scalar partial result in ft0.
define half @vpreduce_fadd_nxv64f16(half %s, <vscale x 64 x half> %v, <vscale x 64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_fadd_nxv64f16:
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a2, vlenb
; CHECK-NEXT: srli a1, a2, 1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
; CHECK-NEXT: slli a2, a2, 2
; CHECK-NEXT: vfmv.s.f v25, fa0
; CHECK-NEXT: mv a3, a0
; CHECK-NEXT: bltu a0, a2, .LBB6_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: .LBB6_2:
; CHECK-NEXT: li a4, 0
; CHECK-NEXT: vsetvli a5, zero, e8, m1, ta, mu
; CHECK-NEXT: vslidedown.vx v24, v0, a1
; CHECK-NEXT: vsetvli zero, a3, e16, m8, tu, mu
; CHECK-NEXT: vfredusum.vs v25, v8, v25, v0.t
; CHECK-NEXT: vfmv.f.s ft0, v25
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
; CHECK-NEXT: sub a1, a0, a2
; CHECK-NEXT: vfmv.s.f v8, ft0
; CHECK-NEXT: bltu a0, a1, .LBB6_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a4, a1
; CHECK-NEXT: .LBB6_4:
; CHECK-NEXT: vsetvli zero, a4, e16, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vfredusum.vs v8, v16, v8, v0.t
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%r = call reassoc half @llvm.vp.reduce.fadd.nxv64f16(half %s, <vscale x 64 x half> %v, <vscale x 64 x i1> %m, i32 %evl)
ret half %r
}

; Ordered (no reassoc) scalable variant: identical split structure to the
; unordered test above, but lowered with vfredosum.vs.
define half @vpreduce_ord_fadd_nxv64f16(half %s, <vscale x 64 x half> %v, <vscale x 64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpreduce_ord_fadd_nxv64f16:
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a2, vlenb
; CHECK-NEXT: srli a1, a2, 1
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
; CHECK-NEXT: slli a2, a2, 2
; CHECK-NEXT: vfmv.s.f v25, fa0
; CHECK-NEXT: mv a3, a0
; CHECK-NEXT: bltu a0, a2, .LBB7_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: li a4, 0
; CHECK-NEXT: vsetvli a5, zero, e8, m1, ta, mu
; CHECK-NEXT: vslidedown.vx v24, v0, a1
; CHECK-NEXT: vsetvli zero, a3, e16, m8, tu, mu
; CHECK-NEXT: vfredosum.vs v25, v8, v25, v0.t
; CHECK-NEXT: vfmv.f.s ft0, v25
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
; CHECK-NEXT: sub a1, a0, a2
; CHECK-NEXT: vfmv.s.f v8, ft0
; CHECK-NEXT: bltu a0, a1, .LBB7_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a4, a1
; CHECK-NEXT: .LBB7_4:
; CHECK-NEXT: vsetvli zero, a4, e16, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vfredosum.vs v8, v16, v8, v0.t
; CHECK-NEXT: vfmv.f.s fa0, v8
; CHECK-NEXT: ret
%r = call half @llvm.vp.reduce.fadd.nxv64f16(half %s, <vscale x 64 x half> %v, <vscale x 64 x i1> %m, i32 %evl)
ret half %r
}

declare float @llvm.vp.reduce.fadd.nxv1f32(float, <vscale x 1 x float>, <vscale x 1 x i1>, i32)

define float @vpreduce_fadd_nxv1f32(float %s, <vscale x 1 x float> %v, <vscale x 1 x i1> %m, i32 zeroext %evl) {
Expand Down
Loading

0 comments on commit fd1dce3

Please sign in to comment.