[ARM] Combine fadd into fcmla
This is the MVE equivalent of https://reviews.llvm.org/D146407. It adds a
target DAG combine for fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d),
pushing the fadd into the accumulator operand of the vcmla, which can help
simplify away some additions.

Differential Revision: https://reviews.llvm.org/D147200
davemgreen committed Apr 5, 2023
1 parent a5e1a93 commit b4df2b2
Showing 3 changed files with 54 additions and 22 deletions.
42 changes: 41 additions & 1 deletion llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16884,6 +16884,46 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags);
 }

+static SDValue PerformFADDVCMLACombine(SDNode *N, SelectionDAG &DAG) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  if (!N->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  // Combine fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d)
+  auto ReassocComplex = [&](SDValue A, SDValue B) {
+    if (A.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+      return SDValue();
+    unsigned Opc = A.getConstantOperandVal(0);
+    if (Opc != Intrinsic::arm_mve_vcmlaq)
+      return SDValue();
+    SDValue VCMLA = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, VT, A.getOperand(0), A.getOperand(1),
+        DAG.getNode(ISD::FADD, DL, VT, A.getOperand(2), B, N->getFlags()),
+        A.getOperand(3), A.getOperand(4));
+    VCMLA->setFlags(A->getFlags());
+    return VCMLA;
+  };
+  if (SDValue R = ReassocComplex(LHS, RHS))
+    return R;
+  if (SDValue R = ReassocComplex(RHS, LHS))
+    return R;
+
+  return SDValue();
+}
+
+static SDValue PerformFADDCombine(SDNode *N, SelectionDAG &DAG,
+                                  const ARMSubtarget *Subtarget) {
+  if (SDValue S = PerformFAddVSelectCombine(N, DAG, Subtarget))
+    return S;
+  if (SDValue S = PerformFADDVCMLACombine(N, DAG))
+    return S;
+  return SDValue();
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18771,7 +18811,7 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
   case ISD::FADD:
-    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
+    return PerformFADDCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
@@ -391,16 +391,16 @@ define <4 x float> @mul_addequal(<4 x float> %a, <4 x float> %b, <4 x float> %c)
 ; CHECK-LABEL: mul_addequal:
 ; CHECK: @ %bb.0: @ %entry
 ; CHECK-NEXT: vmov d0, r0, r1
-; CHECK-NEXT: mov r1, sp
-; CHECK-NEXT: vldrw.u32 q2, [r1]
-; CHECK-NEXT: vmov d1, r2, r3
-; CHECK-NEXT: add r0, sp, #16
-; CHECK-NEXT: vcmul.f32 q3, q0, q2, #0
+; CHECK-NEXT: mov r0, sp
+; CHECK-NEXT: add r1, sp, #16
 ; CHECK-NEXT: vldrw.u32 q1, [r0]
-; CHECK-NEXT: vcmla.f32 q3, q0, q2, #90
-; CHECK-NEXT: vadd.f32 q0, q3, q1
-; CHECK-NEXT: vmov r0, r1, d0
-; CHECK-NEXT: vmov r2, r3, d1
+; CHECK-NEXT: vmov d1, r2, r3
+; CHECK-NEXT: vldrw.u32 q2, [r1]
+; CHECK-NEXT: vcmul.f32 q3, q0, q1, #0
+; CHECK-NEXT: vadd.f32 q2, q3, q2
+; CHECK-NEXT: vcmla.f32 q2, q0, q1, #90
+; CHECK-NEXT: vmov r0, r1, d4
+; CHECK-NEXT: vmov r2, r3, d5
 ; CHECK-NEXT: bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
16 changes: 4 additions & 12 deletions llvm/test/CodeGen/Thumb2/mve-vcmla.ll
@@ -10,9 +10,7 @@ declare <4 x float> @llvm.arm.mve.vcmulq.v4f32(i32, <4 x float>, <4 x float>)
 define arm_aapcs_vfpcc <4 x float> @reassoc_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_f32x4:
 ; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmov.i32 q3, #0x0
-; CHECK-NEXT: vcmla.f32 q3, q1, q2, #0
-; CHECK-NEXT: vadd.f32 q0, q3, q0
+; CHECK-NEXT: vcmla.f32 q0, q1, q2, #0
 ; CHECK-NEXT: bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -23,9 +21,7 @@ entry:
 define arm_aapcs_vfpcc <4 x float> @reassoc_c_f32x4(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: reassoc_c_f32x4:
 ; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmov.i32 q3, #0x0
-; CHECK-NEXT: vcmla.f32 q3, q1, q2, #90
-; CHECK-NEXT: vadd.f32 q0, q0, q3
+; CHECK-NEXT: vcmla.f32 q0, q1, q2, #90
 ; CHECK-NEXT: bx lr
 entry:
   %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 1, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
@@ -36,9 +32,7 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @reassoc_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_f16x4:
 ; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmov.i32 q3, #0x0
-; CHECK-NEXT: vcmla.f16 q3, q1, q2, #180
-; CHECK-NEXT: vadd.f16 q0, q3, q0
+; CHECK-NEXT: vcmla.f16 q0, q1, q2, #180
 ; CHECK-NEXT: bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 2, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
@@ -49,9 +43,7 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @reassoc_c_f16x4(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: reassoc_c_f16x4:
 ; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmov.i32 q3, #0x0
-; CHECK-NEXT: vcmla.f16 q3, q1, q2, #270
-; CHECK-NEXT: vadd.f16 q0, q0, q3
+; CHECK-NEXT: vcmla.f16 q0, q1, q2, #270
 ; CHECK-NEXT: bx lr
 entry:
   %d = tail call <8 x half> @llvm.arm.mve.vcmlaq.v8f16(i32 3, <8 x half> zeroinitializer, <8 x half> %b, <8 x half> %c)
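The deltas above show the zero-accumulator case: fadd(a, vcmla(0, b, c))
becomes vcmla(fadd(a, 0), b, c), the add of zero then folds away, and the
materialized zero vector (vmov.i32), the vcmla into a scratch register, and
the trailing vadd all collapse into a single vcmla. A minimal sketch of such
a test (the function name is made up; the combine itself only requires the
reassoc flag on the fadd, but folding the leftover add of zero also needs
nsz, so fast is used here):

declare <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32, <4 x float>, <4 x float>, <4 x float>)

define arm_aapcs_vfpcc <4 x float> @zero_acc(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
entry:
  %d = tail call <4 x float> @llvm.arm.mve.vcmlaq.v4f32(i32 0, <4 x float> zeroinitializer, <4 x float> %b, <4 x float> %c)
  %res = fadd fast <4 x float> %d, %a
  ; expected codegen with this patch: a single "vcmla.f32 q0, q1, q2, #0"
  ret <4 x float> %res
}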
