-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAGCombiner] Fold select into partial.reduce.add operands. #167857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DAGCombiner] Fold select into partial.reduce.add operands. #167857
Conversation
|
@llvm/pr-subscribers-backend-risc-v @llvm/pr-subscribers-llvm-selectiondag Author: Sander de Smalen (sdesmalen-arm) ChangesThis generates more optimal codegen when using partial reductions with predication. Full diff: https://github.com/llvm/llvm-project/pull/167857.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
//
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, b, splat(0)))
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, splat(C), splat(0)))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
-
unsigned Opc = Op1->getOpcode();
+
+ // Handle predication by moving the SELECT into the operand of the MUL.
+ SDValue Pred;
+ if (Opc == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Opc = Op1->getOpcode();
+ }
+
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+ // Return 'select(P, Op, splat(0))' if P is set,
+ // or 'Op' otherwise.
+ auto tryPredicate = [&](SDValue P, SDValue Op) {
+ if (!P)
+ return Op;
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint()
+ ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ };
+
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ SDValue Constant =
+ tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ Constant);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ RHSExtOp = tryPredicate(Pred, RHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
+ SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
+ if (Op1Opcode == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Op1Opcode = Op1->getOpcode();
+ }
+
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);
+ if (Pred) {
+ SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(0, DL, UnextOp1VT)
+ : DAG.getConstant(0, DL, UnextOp1VT);
+ Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+ }
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: movi v3.16b, #127
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.16b, #1
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: movi v3.8h, #60, lsl #8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
|
|
@llvm/pr-subscribers-backend-aarch64 Author: Sander de Smalen (sdesmalen-arm) ChangesThis generates more optimal codegen when using partial reductions with predication. Full diff: https://github.com/llvm/llvm-project/pull/167857.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..c21b79225323e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
//
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, b, splat(0)))
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, splat(C), splat(0)))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
-
unsigned Opc = Op1->getOpcode();
+
+ // Handle predication by moving the SELECT into the operand of the MUL.
+ SDValue Pred;
+ if (Opc == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Opc = Op1->getOpcode();
+ }
+
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
@@ -13068,6 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+ // Return 'select(P, Op, splat(0))' if P is set,
+ // or 'Op' otherwise.
+ auto tryPredicate = [&](SDValue P, SDValue Op) {
+ if (!P)
+ return Op;
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint()
+ ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ };
+
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
@@ -13090,8 +13117,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ SDValue Constant =
+ tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ Constant);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13161,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ RHSExtOp = tryPredicate(Pred, RHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -13152,7 +13181,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
+ SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
+ if (Op1Opcode == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Op1Opcode = Op1->getOpcode();
+ }
+
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
@@ -13181,6 +13221,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);
+ if (Pred) {
+ SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(0, DL, UnextOp1VT)
+ : DAG.getConstant(0, DL, UnextOp1VT);
+ Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+ }
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: movi v3.16b, #127
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.16b, #1
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: movi v3.8h, #60, lsl #8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
14e935c to
a126fe7
Compare
RISC-V test coverage for #167857
|
I've precommitted a RISC-V test case here 2a53949, can you rebase this PR on top of it? |
This generates more optimal codegen when using partial reductions with predication. partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1)) -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b) partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1)) -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
a126fe7 to
7506ebc
Compare
| if (Opc == ISD::VSELECT) { | ||
| APInt C; | ||
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | ||
| !C.isZero()) | ||
| return SDValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (Opc == ISD::VSELECT) { | |
| APInt C; | |
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | |
| !C.isZero()) | |
| return SDValue(); | |
| if (Opc == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) { |
| if (Op1Opcode == ISD::VSELECT) { | ||
| APInt C; | ||
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | ||
| !C.isZero()) | ||
| return SDValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if (Op1Opcode == ISD::VSELECT) { | |
| APInt C; | |
| if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) || | |
| !C.isZero()) | |
| return SDValue(); | |
| if (Op1Opcode == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems isNullOrNullSplat does not recognise FP constants, so I've added a similar FP variant (similar to what was done for isOneOrOneSplatFP)
…-partial-reduce-add' into predicated-partial-reduce-add
lukel97
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
fhahn
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
This generates more optimal codegen when using partial reductions with predication.