Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,10 @@ LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
/// Build vector implicit truncation is allowed.
LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);

/// Return true if the value is a constant (+/-)0.0 floating-point value or a
/// splatted vector thereof (with no undefs).
LLVM_ABI bool isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs = false);

/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
Expand Down
68 changes: 54 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13018,22 +13018,34 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}

// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, splat(C))
//
// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
// -> partial_reduce_fmla(acc, a, b)
// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
//
// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

unsigned Opc = Op1->getOpcode();

// Handle predication by moving the SELECT into the operand of the MUL.
SDValue Pred;
if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
Pred = Op1->getOperand(0);
Op1 = Op1->getOperand(1);
Opc = Op1->getOpcode();
}

if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();

Expand Down Expand Up @@ -13068,6 +13080,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();

// When Pred is non-zero, set Op = select(Pred, Op, splat(0)) and freeze
// OtherOp to keep the same semantics when moving the selects into the MUL
// operands.
auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) {
if (Pred) {
EVT OpVT = Op.getValueType();
SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
: DAG.getConstant(0, DL, OpVT);
Op = DAG.getSelect(DL, OpVT, Pred, Op, Zero);
OtherOp = DAG.getFreeze(OtherOp);
}
};

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
Expand All @@ -13090,8 +13115,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
ApplyPredicate(C, LHSExtOp);
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
}

unsigned RHSOpcode = RHS->getOpcode();
Expand Down Expand Up @@ -13132,17 +13158,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

ApplyPredicate(RHSExtOp, LHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}

// partial.reduce.umla(acc, zext(op), splat(1))
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
// partial.reduce.smla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.*mla(acc, *ext(op), splat(1))
// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.fmla(acc, fpext(op), splat(1.0))
// -> partial.reduce.fmla(acc, op, splat(1.0))
//
// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
Expand All @@ -13152,7 +13178,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();

SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
if (Op1Opcode == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
Pred = Op1->getOperand(0);
Op1 = Op1->getOperand(1);
Op1Opcode = Op1->getOpcode();
}

if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();

Expand Down Expand Up @@ -13181,6 +13215,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);

if (Pred) {
SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
? DAG.getConstantFP(0, DL, UnextOp1VT)
: DAG.getConstant(0, DL, UnextOp1VT);
Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
}
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12971,6 +12971,11 @@ bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
return C && C->isZero();
}

bool llvm::isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs) {
ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
return C && C->isZero();
}

HandleSDNode::~HandleSDNode() {
DropOperands();
}
Expand Down
159 changes: 159 additions & 0 deletions llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s | FileCheck %s

target triple = "aarch64"

define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
; CHECK-LABEL: predicate_dot_fixed_length:
; CHECK: // %bb.0:
; CHECK-NEXT: shl v1.16b, v1.16b, #7
; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
; CHECK-NEXT: ret
%ext.1 = sext <16 x i8> %a to <16 x i32>
%ext.2 = sext <16 x i8> %b to <16 x i32>
%mul = mul nsw <16 x i32> %ext.1, %ext.2
%sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
%red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
ret <4 x i32> %red
}

define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
; CHECK-LABEL: predicate_dot_by_C_fixed_length:
; CHECK: // %bb.0:
; CHECK-NEXT: shl v1.16b, v1.16b, #7
; CHECK-NEXT: movi v3.16b, #127
; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
; CHECK-NEXT: ret
%ext.1 = sext <16 x i8> %a to <16 x i32>
%mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
%sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
%red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
ret <4 x i32> %red
}

define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
; CHECK-LABEL: predicate_dot_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v3.2d, #0000000000000000
; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
%ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
%sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
%red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
ret <vscale x 4 x i32> %red
}

define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
; CHECK-LABEL: predicate_dot_by_C_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
%ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
%sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
%red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
ret <vscale x 4 x i32> %red
}

define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
; CHECK-LABEL: predicate_ext_mul_fixed_length:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v3.16b, #1
; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
; CHECK-NEXT: ret
%ext = sext <16 x i8> %a to <16 x i32>
%sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
%red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
ret <4 x i32> %red
}

define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
; CHECK-LABEL: predicate_ext_mul_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
%ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
%red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
ret <vscale x 4 x i32> %red
}

define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
; CHECK-LABEL: predicated_fdot_fixed_length:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: shl v1.8h, v1.8h, #15
; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
; CHECK-NEXT: fdot z0.s, z2.h, z1.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%ext.1 = fpext <8 x half> %a to <8 x float>
%ext.2 = fpext <8 x half> %b to <8 x float>
%mul = fmul <8 x float> %ext.1, %ext.2
%sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
%red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
ret <4 x float> %red
}

define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
; CHECK-LABEL: predicated_fdot_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v3.2d, #0000000000000000
; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
%ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
%ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
%mul = fmul <vscale x 8 x float> %ext.1, %ext.2
%sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
%red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
ret <vscale x 4 x float> %red
}

define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: movi v3.8h, #60, lsl #8
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: shl v1.8h, v1.8h, #15
; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
; CHECK-NEXT: fdot z0.s, z2.h, z1.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%ext = fpext <8 x half> %a to <8 x float>
%sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
%red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
ret <4 x float> %red
}

define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
; CHECK-LABEL: predicated_fpext_fmul_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
%ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
%sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
%red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
ret <vscale x 4 x float> %red
}

attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
39 changes: 25 additions & 14 deletions llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
Original file line number Diff line number Diff line change
Expand Up @@ -996,20 +996,31 @@ entry:
}

define <vscale x 2 x i32> @partial_reduce_select(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b, <vscale x 8 x i1> %m) {
; CHECK-LABEL: partial_reduce_select:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
; CHECK-NEXT: vsext.vf2 v12, v8
; CHECK-NEXT: vsext.vf2 v14, v9
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, mu
; CHECK-NEXT: vwmul.vv v8, v12, v14, v0.t
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v11, v8
; CHECK-NEXT: vadd.vv v9, v9, v10
; CHECK-NEXT: vadd.vv v8, v9, v8
; CHECK-NEXT: ret
; NODOT-LABEL: partial_reduce_select:
; NODOT: # %bb.0: # %entry
; NODOT-NEXT: vsetvli a0, zero, e16, m2, ta, ma
; NODOT-NEXT: vsext.vf2 v12, v8
; NODOT-NEXT: vsext.vf2 v14, v9
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; NODOT-NEXT: vmv.v.i v8, 0
; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, mu
; NODOT-NEXT: vwmul.vv v8, v12, v14, v0.t
; NODOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
; NODOT-NEXT: vadd.vv v8, v11, v8
; NODOT-NEXT: vadd.vv v9, v9, v10
; NODOT-NEXT: vadd.vv v8, v9, v8
; NODOT-NEXT: ret
;
; DOT-LABEL: partial_reduce_select:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetvli a0, zero, e8, m1, ta, ma
; DOT-NEXT: vmv.v.i v10, 0
; DOT-NEXT: vmerge.vvm v10, v10, v9, v0
; DOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v9, 0
; DOT-NEXT: vqdot.vv v9, v8, v10
; DOT-NEXT: vmv.v.v v8, v9
; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.sext = sext <vscale x 8 x i8> %b to <vscale x 8 x i32>
Expand Down
Loading