Skip to content

Commit

Permalink
[AArch64] Emit FNMADD instead of FNEG(FMADD)
Browse files Browse the repository at this point in the history
Emit FNMADD instead of FNEG(FMADD) at optimization levels above -Oz,
when the fast-math flags nsz and contract on both instructions permit it.

Differential Revision: https://reviews.llvm.org/D149260
  • Loading branch information
MDevereau committed May 10, 2023
1 parent d526e2e commit 004bf17
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/MachineCombinerPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ enum class MachineCombinerPattern {

// X86 VNNI
DPWSSD,

FNMADD,
};

} // end namespace llvm
Expand Down
81 changes: 81 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5409,6 +5409,39 @@ static bool getFMULPatterns(MachineInstr &Root,
return Found;
}

/// Look for an FNEG whose single operand is produced by a scalar FMADD, so
/// the pair can be combined into one FNMADD.  The fold is only performed
/// when both instructions carry the contract and nsz fast-math flags.
static bool getFNEGPatterns(MachineInstr &Root,
                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
  MachineRegisterInfo &MRI = Root.getParent()->getParent()->getRegInfo();

  // Both the FNEG (Root) and the FMADD feeding it must allow contraction
  // and ignore signed zeros for the combined FNMADD to be legal.
  auto HasRequiredFMFs = [](const MachineInstr &MI) {
    return MI.getFlag(MachineInstr::MIFlag::FmContract) &&
           MI.getFlag(MachineInstr::MIFlag::FmNsz);
  };

  auto TryMatch = [&](unsigned FMAOpc) -> bool {
    MachineInstr *Def = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
    if (Def == nullptr || Def->getOpcode() != FMAOpc)
      return false;
    // If the FMADD result has other (non-debug) users it must be kept
    // alive anyway, so folding it into the FNEG gains nothing.
    if (!MRI.hasOneNonDBGUse(Def->getOperand(0).getReg()))
      return false;
    if (!HasRequiredFMFs(Root) || !HasRequiredFMFs(*Def))
      return false;
    Patterns.push_back(MachineCombinerPattern::FNMADD);
    return true;
  };

  switch (Root.getOpcode()) {
  case AArch64::FNEGDr:
    return TryMatch(AArch64::FMADDDrrr);
  case AArch64::FNEGSr:
    return TryMatch(AArch64::FMADDSrrr);
  default:
    return false;
  }
}

/// Return true when a code sequence can improve throughput. It
/// should be called only for instructions in loops.
/// \param Pattern - combiner pattern
Expand Down Expand Up @@ -5578,6 +5611,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
return true;
if (getFMAPatterns(Root, Patterns))
return true;
if (getFNEGPatterns(Root, Patterns))
return true;

// Other patterns
if (getMiscPatterns(Root, Patterns))
Expand Down Expand Up @@ -5668,6 +5703,47 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
return MUL;
}

/// Build a single FNMADD to replace an FNEG (Root) of an FMADD result,
/// as matched by MachineCombinerPattern::FNMADD.
///
/// Appends the new instruction to \p InsInstrs and returns the original
/// FMADD so the caller can record it (together with Root) for deletion.
/// Returns nullptr if the value is not in a scalar FPR32/FPR64 class.
static MachineInstr *
genFNegatedMAD(MachineFunction &MF, MachineRegisterInfo &MRI,
               const TargetInstrInfo *TII, MachineInstr &Root,
               SmallVectorImpl<MachineInstr *> &InsInstrs) {
  // Root is the FNEG; operand 1 is the FMADD's destination register.
  // The pattern matcher already established this def exists and is unique.
  MachineInstr *MAD = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());

  // Select the FNMADD opcode matching the operand width (single/double).
  unsigned Opc = 0;
  const TargetRegisterClass *RC = MRI.getRegClass(MAD->getOperand(0).getReg());
  if (AArch64::FPR32RegClass.hasSubClassEq(RC))
    Opc = AArch64::FNMADDSrrr;
  else if (AArch64::FPR64RegClass.hasSubClassEq(RC))
    Opc = AArch64::FNMADDDrrr;
  else
    return nullptr;

  // The FNMADD writes the FNEG's destination but reads the FMADD's three
  // source operands, preserving each source's kill state.
  Register ResultReg = Root.getOperand(0).getReg();
  Register SrcReg0 = MAD->getOperand(1).getReg();
  Register SrcReg1 = MAD->getOperand(2).getReg();
  Register SrcReg2 = MAD->getOperand(3).getReg();
  bool Src0IsKill = MAD->getOperand(1).isKill();
  bool Src1IsKill = MAD->getOperand(2).isKill();
  bool Src2IsKill = MAD->getOperand(3).isKill();
  // Constrain all virtual registers to the scalar FP class expected by
  // the FNMADD instruction definition.
  if (ResultReg.isVirtual())
    MRI.constrainRegClass(ResultReg, RC);
  if (SrcReg0.isVirtual())
    MRI.constrainRegClass(SrcReg0, RC);
  if (SrcReg1.isVirtual())
    MRI.constrainRegClass(SrcReg1, RC);
  if (SrcReg2.isVirtual())
    MRI.constrainRegClass(SrcReg2, RC);

  MachineInstrBuilder MIB =
      BuildMI(MF, MIMetadata(Root), TII->get(Opc), ResultReg)
          .addReg(SrcReg0, getKillRegState(Src0IsKill))
          .addReg(SrcReg1, getKillRegState(Src1IsKill))
          .addReg(SrcReg2, getKillRegState(Src2IsKill));
  InsInstrs.push_back(MIB);

  // NOTE(review): fast-math MIFlags from Root/MAD are not copied onto the
  // new FNMADD — confirm downstream passes don't rely on them here.
  return MAD;
}

/// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
static MachineInstr *
genIndexedMultiply(MachineInstr &Root,
Expand Down Expand Up @@ -6800,6 +6876,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
&AArch64::FPR128_loRegClass, MRI);
break;
}
case MachineCombinerPattern::FNMADD: {
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
break;
}

} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion
if (MUL)
Expand Down
153 changes: 153 additions & 0 deletions llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -verify-machineinstrs | FileCheck %s

; Combine: fmul+fadd+fneg all carry 'fast' (which includes contract and
; nsz), so the double-precision sequence folds to a single fnmadd.
define void @fnmaddd(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmaddd:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr d0, [x1]
; CHECK-NEXT: ldr d1, [x0]
; CHECK-NEXT: ldr d2, [x2]
; CHECK-NEXT: fnmadd d0, d0, d1, d2
; CHECK-NEXT: str d0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load double, ptr %a, align 8
%1 = load double, ptr %b, align 8
%mul = fmul fast double %1, %0
%2 = load double, ptr %c, align 8
%add = fadd fast double %mul, %2
%fneg = fneg fast double %add
store double %fneg, ptr %a, align 8
ret void
}

; Don't combine: no fast-math flags at all, so neither fmadd fusion nor
; the fnmadd fold is allowed — mul/add/neg stay as separate instructions.
define void @fnmaddd_no_fast(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmaddd_no_fast:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr d0, [x0]
; CHECK-NEXT: ldr d1, [x1]
; CHECK-NEXT: fmul d0, d1, d0
; CHECK-NEXT: ldr d1, [x2]
; CHECK-NEXT: fadd d0, d0, d1
; CHECK-NEXT: fneg d0, d0
; CHECK-NEXT: str d0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load double, ptr %a, align 8
%1 = load double, ptr %b, align 8
%mul = fmul double %1, %0
%2 = load double, ptr %c, align 8
%add = fadd double %mul, %2
%fneg = fneg double %add
store double %fneg, ptr %a, align 8
ret void
}

; Combine: single-precision variant of the 'fast'-flagged fold above.
define void @fnmadds(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmadds:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr s0, [x1]
; CHECK-NEXT: ldr s1, [x0]
; CHECK-NEXT: ldr s2, [x2]
; CHECK-NEXT: fnmadd s0, s0, s1, s2
; CHECK-NEXT: str s0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load float, ptr %a, align 4
%1 = load float, ptr %b, align 4
%mul = fmul fast float %1, %0
%2 = load float, ptr %c, align 4
%add = fadd fast float %mul, %2
%fneg = fneg fast float %add
store float %fneg, ptr %a, align 4
ret void
}

; Combine: the minimal flag set 'contract nsz' (without full 'fast') is
; sufficient for the fnmadd fold.
define void @fnmadds_nsz_contract(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmadds_nsz_contract:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr s0, [x1]
; CHECK-NEXT: ldr s1, [x0]
; CHECK-NEXT: ldr s2, [x2]
; CHECK-NEXT: fnmadd s0, s0, s1, s2
; CHECK-NEXT: str s0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load float, ptr %a, align 4
%1 = load float, ptr %b, align 4
%mul = fmul contract nsz float %1, %0
%2 = load float, ptr %c, align 4
%add = fadd contract nsz float %mul, %2
%fneg = fneg contract nsz float %add
store float %fneg, ptr %a, align 4
ret void
}

; Don't combine: missing nsz. 'contract' alone still permits fusing
; mul+add into fmadd, but the fneg must remain a separate instruction.
define void @fnmadds_contract(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmadds_contract:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr s0, [x1]
; CHECK-NEXT: ldr s1, [x0]
; CHECK-NEXT: ldr s2, [x2]
; CHECK-NEXT: fmadd s0, s0, s1, s2
; CHECK-NEXT: fneg s0, s0
; CHECK-NEXT: str s0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load float, ptr %a, align 4
%1 = load float, ptr %b, align 4
%mul = fmul contract float %1, %0
%2 = load float, ptr %c, align 4
%add = fadd contract float %mul, %2
%fneg = fneg contract float %add
store float %fneg, ptr %a, align 4
ret void
}

; Don't combine: missing contract. Without it even the fmadd fusion is
; disallowed, so mul/add/neg all stay separate.
define void @fnmadds_nsz(ptr %a, ptr %b, ptr %c) {
; CHECK-LABEL: fnmadds_nsz:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr s0, [x0]
; CHECK-NEXT: ldr s1, [x1]
; CHECK-NEXT: fmul s0, s1, s0
; CHECK-NEXT: ldr s1, [x2]
; CHECK-NEXT: fadd s0, s0, s1
; CHECK-NEXT: fneg s0, s0
; CHECK-NEXT: str s0, [x0]
; CHECK-NEXT: ret
entry:
%0 = load float, ptr %a, align 4
%1 = load float, ptr %b, align 4
%mul = fmul nsz float %1, %0
%2 = load float, ptr %c, align 4
%add = fadd nsz float %mul, %2
%fneg = fneg nsz float %add
store float %fneg, ptr %a, align 4
ret void
}

; Don't combine: the fmadd result (%add) has a second use (stored to %d),
; so it must stay live and the fneg cannot be folded into an fnmadd.
define void @fnmaddd_two_uses(ptr %a, ptr %b, ptr %c, ptr %d) {
; CHECK-LABEL: fnmaddd_two_uses:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr d0, [x1]
; CHECK-NEXT: ldr d1, [x0]
; CHECK-NEXT: ldr d2, [x2]
; CHECK-NEXT: fmadd d0, d0, d1, d2
; CHECK-NEXT: fneg d1, d0
; CHECK-NEXT: str d1, [x0]
; CHECK-NEXT: str d0, [x3]
; CHECK-NEXT: ret
entry:
%0 = load double, ptr %a, align 8
%1 = load double, ptr %b, align 8
%mul = fmul fast double %1, %0
%2 = load double, ptr %c, align 8
%add = fadd fast double %mul, %2
%fneg1 = fneg fast double %add
store double %fneg1, ptr %a, align 8
store double %add, ptr %d, align 8
ret void
}

0 comments on commit 004bf17

Please sign in to comment.