diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index cd466dceb900f..cfc8a4243e894 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1968,6 +1968,10 @@ LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false); /// Build vector implicit truncation is allowed. LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false); +/// Return true if the value is a constant (+/-)0.0 floating-point value or a +/// splatted vector thereof (with no undefs). +LLVM_ABI bool isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs = false); + /// Return true if \p V is either a integer or FP constant. inline bool isIntOrFPConstant(SDValue V) { return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index df353c4d91b1a..59587329493fa 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -13018,22 +13018,34 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { return SDValue(); } -// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1)) +// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1)) // -> partial_reduce_*mla(acc, a, b) // -// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) -// -> partial_reduce_*mla(acc, x, C) +// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1)) +// -> partial_reduce_*mla(acc, x, splat(C)) // -// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0)) -// -> partial_reduce_fmla(acc, a, b) +// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1)) +// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b) +// +// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1)) +// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C)) SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDLoc DL(N); auto *Context
= DAG.getContext(); SDValue Acc = N->getOperand(0); SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - unsigned Opc = Op1->getOpcode(); + unsigned Opc = Op1->getOpcode(); + + // Handle predication by moving the SELECT into the operand of the MUL. + SDValue Pred; + if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) || + isZeroOrZeroSplatFP(Op1->getOperand(2)))) { + Pred = Op1->getOperand(0); + Op1 = Op1->getOperand(1); + Opc = Op1->getOpcode(); + } + if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL) return SDValue(); @@ -13068,6 +13080,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue LHSExtOp = LHS->getOperand(0); EVT LHSExtOpVT = LHSExtOp.getValueType(); + // When Pred is non-zero, set Op = select(Pred, Op, splat(0)) and freeze + // OtherOp to keep the same semantics when moving the selects into the MUL + // operands. + auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) { + if (Pred) { + EVT OpVT = Op.getValueType(); + SDValue Zero = OpVT.isFloatingPoint() ?
DAG.getConstantFP(0.0, DL, OpVT) + : DAG.getConstant(0, DL, OpVT); + Op = DAG.getSelect(DL, OpVT, Pred, Op, Zero); + OtherOp = DAG.getFreeze(OtherOp); + } + }; + // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) APInt C; @@ -13090,8 +13115,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { TLI.getTypeToTransformTo(*Context, LHSExtOpVT))) return SDValue(); - return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, - DAG.getConstant(CTrunc, DL, LHSExtOpVT)); + SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT); + ApplyPredicate(C, LHSExtOp); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C); } unsigned RHSOpcode = RHS->getOpcode(); @@ -13132,17 +13158,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { TLI.getTypeToTransformTo(*Context, LHSExtOpVT))) return SDValue(); + ApplyPredicate(RHSExtOp, LHSExtOp); return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp); } -// partial.reduce.umla(acc, zext(op), splat(1)) -// -> partial.reduce.umla(acc, op, splat(trunc(1))) -// partial.reduce.smla(acc, sext(op), splat(1)) -// -> partial.reduce.smla(acc, op, splat(trunc(1))) +// partial.reduce.*mla(acc, *ext(op), splat(1)) +// -> partial.reduce.*mla(acc, op, splat(trunc(1))) // partial.reduce.sumla(acc, sext(op), splat(1)) // -> partial.reduce.smla(acc, op, splat(trunc(1))) -// partial.reduce.fmla(acc, fpext(op), splat(1.0)) -// -> partial.reduce.fmla(acc, op, splat(1.0)) +// +// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1)) +// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1))) SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDLoc DL(N); SDValue Acc = N->getOperand(0); @@ -13152,7 +13178,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2)) return SDValue(); + SDValue Pred; unsigned Op1Opcode = Op1.getOpcode(); + if (Op1Opcode == 
ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) || + isZeroOrZeroSplatFP(Op1->getOperand(2)))) { + Pred = Op1->getOperand(0); + Op1 = Op1->getOperand(1); + Op1Opcode = Op1->getOpcode(); + } + if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND) return SDValue(); @@ -13181,6 +13215,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { ? DAG.getConstantFP(1, DL, UnextOp1VT) : DAG.getConstant(1, DL, UnextOp1VT); + if (Pred) { + SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? DAG.getConstantFP(0, DL, UnextOp1VT) + : DAG.getConstant(0, DL, UnextOp1VT); + Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero); + } return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, Constant); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index c2b4c19846316..16fdef06d6679 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -12971,6 +12971,11 @@ bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) { return C && C->isZero(); } +bool llvm::isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs) { + ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs); + return C && C->isZero(); +} + HandleSDNode::~HandleSDNode() { DropOperands(); } diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll new file mode 100644 index 0000000000000..24cdd0a852222 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll @@ -0,0 +1,159 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s | FileCheck %s + +target triple = "aarch64" + +define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 { +; CHECK-LABEL: predicate_dot_fixed_length: +; CHECK: // %bb.0: +; CHECK-NEXT: shl v1.16b, v1.16b, #7 +; 
CHECK-NEXT: cmlt v1.16b, v1.16b, #0 +; CHECK-NEXT: and v1.16b, v1.16b, v3.16b +; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b +; CHECK-NEXT: ret + %ext.1 = sext <16 x i8> %a to <16 x i32> + %ext.2 = sext <16 x i8> %b to <16 x i32> + %mul = mul nsw <16 x i32> %ext.1, %ext.2 + %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer + %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel) + ret <4 x i32> %red +} + +define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 { +; CHECK-LABEL: predicate_dot_by_C_fixed_length: +; CHECK: // %bb.0: +; CHECK-NEXT: shl v1.16b, v1.16b, #7 +; CHECK-NEXT: movi v3.16b, #127 +; CHECK-NEXT: cmlt v1.16b, v1.16b, #0 +; CHECK-NEXT: and v1.16b, v1.16b, v3.16b +; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b +; CHECK-NEXT: ret + %ext.1 = sext <16 x i8> %a to <16 x i32> + %mul = mul nsw <16 x i32> %ext.1, splat(i32 127) + %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer + %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel) + ret <4 x i32> %red +} + +define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 { +; CHECK-LABEL: predicate_dot_scalable: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v3.2d, #0000000000000000 +; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b +; CHECK-NEXT: sdot z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32> + %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32> + %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2 + %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer + %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel) + ret <vscale x 4 x i32> %red +} + +define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 { +; CHECK-LABEL: predicate_dot_by_C_scalable: +; CHECK: // %bb.0: +; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f +; CHECK-NEXT: sdot z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32> + %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127) + %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer + %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel) + ret <vscale x 4 x i32> %red +} + +define <4 x i32>
@predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 { +; CHECK-LABEL: predicate_ext_mul_fixed_length: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v3.16b, #1 +; CHECK-NEXT: and v1.16b, v1.16b, v3.16b +; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b +; CHECK-NEXT: ret + %ext = sext <16 x i8> %a to <16 x i32> + %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer + %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel) + ret <4 x i32> %red +} + +define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 { +; CHECK-LABEL: predicate_ext_mul_scalable: +; CHECK: // %bb.0: +; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1 +; CHECK-NEXT: sdot z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32> + %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer + %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel) + ret <vscale x 4 x i32> %red +} + +define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 { +; CHECK-LABEL: predicated_fdot_fixed_length: +; CHECK: // %bb.0: +; CHECK-NEXT: ushll v1.8h, v1.8b, #0 +; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0 +; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2 +; CHECK-NEXT: shl v1.8h, v1.8h, #15 +; CHECK-NEXT: cmlt v1.8h, v1.8h, #0 +; CHECK-NEXT: and v1.16b, v1.16b, v3.16b +; CHECK-NEXT: fdot z0.s, z2.h, z1.h +; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-NEXT: ret + %ext.1 = fpext <8 x half> %a to <8 x float> + %ext.2 = fpext <8 x half> %b to <8 x float> + %mul = fmul <8 x float> %ext.1, %ext.2 + %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer + %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel) + ret <4 x float> %red +} + +define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 { +; CHECK-LABEL: predicated_fdot_scalable: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v3.2d, #0000000000000000 +; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h +; CHECK-NEXT: fdot z0.s, z1.h, z2.h +;
CHECK-NEXT: ret + %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float> + %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float> + %mul = fmul <vscale x 8 x float> %ext.1, %ext.2 + %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer + %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel) + ret <vscale x 4 x float> %red +} + +define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 { +; CHECK-LABEL: predicated_fpext_fmul_fixed_length: +; CHECK: // %bb.0: +; CHECK-NEXT: ushll v1.8h, v1.8b, #0 +; CHECK-NEXT: movi v3.8h, #60, lsl #8 +; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0 +; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2 +; CHECK-NEXT: shl v1.8h, v1.8h, #15 +; CHECK-NEXT: cmlt v1.8h, v1.8h, #0 +; CHECK-NEXT: and v1.16b, v1.16b, v3.16b +; CHECK-NEXT: fdot z0.s, z2.h, z1.h +; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-NEXT: ret + %ext = fpext <8 x half> %a to <8 x float> + %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer + %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel) + ret <4 x float> %red +} + +define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 { +; CHECK-LABEL: predicated_fpext_fmul_scalable: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000 +; CHECK-NEXT: fdot z0.s, z1.h, z2.h +; CHECK-NEXT: ret + %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float> + %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer + %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel) + ret <vscale x 4 x float> %red +} + +attributes #0 = { nounwind "target-features"="+sve,+dotprod" } +attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" } diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll index 72bf1fa9a8327..d6384a6913efe 100644 --- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll @@ -996,20 +996,31 @@ entry: } define <vscale x 2 x i32> @partial_reduce_select(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b, <vscale x 8 x i1> %m) { -; CHECK-LABEL: partial_reduce_select: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetvli
a0, zero, e16, m2, ta, ma -; CHECK-NEXT: vsext.vf2 v12, v8 -; CHECK-NEXT: vsext.vf2 v14, v9 -; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma -; CHECK-NEXT: vmv.v.i v8, 0 -; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, mu -; CHECK-NEXT: vwmul.vv v8, v12, v14, v0.t -; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma -; CHECK-NEXT: vadd.vv v8, v11, v8 -; CHECK-NEXT: vadd.vv v9, v9, v10 -; CHECK-NEXT: vadd.vv v8, v9, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: partial_reduce_select: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetvli a0, zero, e16, m2, ta, ma +; NODOT-NEXT: vsext.vf2 v12, v8 +; NODOT-NEXT: vsext.vf2 v14, v9 +; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.v.i v8, 0 +; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, mu +; NODOT-NEXT: vwmul.vv v8, v12, v14, v0.t +; NODOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma +; NODOT-NEXT: vadd.vv v8, v11, v8 +; NODOT-NEXT: vadd.vv v9, v9, v10 +; NODOT-NEXT: vadd.vv v8, v9, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: partial_reduce_select: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetvli a0, zero, e8, m1, ta, ma +; DOT-NEXT: vmv.v.i v10, 0 +; DOT-NEXT: vmerge.vvm v10, v10, v9, v0 +; DOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma +; DOT-NEXT: vmv.v.i v9, 0 +; DOT-NEXT: vqdot.vv v9, v8, v10 +; DOT-NEXT: vmv.v.v v8, v9 +; DOT-NEXT: ret entry: %a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32> %b.sext = sext <vscale x 8 x i8> %b to