[MachineCombiner][RISCV] Add fmadd/fmsub/fnmsub instruction patterns
This patch adds transformation of fmul+fadd/fsub chains into fused multiply
instructions:
  * fmul+fadd -> fmadd
  * fmul+fsub -> fmsub/fnmsub

We will also try to combine these instructions when the fmul has more than one
use and therefore cannot be deleted. Removing the dependency between the fmul
and the fadd/fsub can still be profitable in that case, and we rely on the
machine combiner's scheduling approximations to judge profitability.

Differential Revision: https://reviews.llvm.org/D136764
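
For reference, a minimal IR sketch of the multi-use case described above, patterned on the test_fmadd2/test_fmsub/test_fnmsub checks updated below (the function name and the register assignments in the comments are illustrative, not part of the patch):

; The fmul feeds both the fadd and a later fdiv, so it cannot be deleted;
; fusing still removes the fadd's dependence on the separate fmul result.
define double @fmadd_multi_use(double %a0, double %a1, double %a2) {
  %t0 = fmul contract double %a0, %a1
  %t1 = fadd contract double %t0, %a2
  %t2 = fdiv double %t1, %t0
  ret double %t2
}
; With this patch the machine combiner produces, roughly:
;   fmul.d  ft0, fa0, fa1
;   fmadd.d ft1, fa0, fa1, fa2
;   fdiv.d  fa0, ft1, ft0
;   ret
; Using "fsub contract double %t0, %a2" instead would select fmsub.d, and
; "fsub contract double %a2, %t0" would select fnmsub.d (see the tests below).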
asi-sc committed Nov 17, 2022
1 parent 50f8eb0 commit b6c7907
Showing 6 changed files with 192 additions and 15 deletions.
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -169,6 +169,12 @@ enum class MachineCombinerPattern {
FMULv4i32_indexed_OP2,
FMULv8i16_indexed_OP1,
FMULv8i16_indexed_OP2,

// RISCV FMADD, FMSUB, FNMSUB patterns
FMADD_AX,
FMADD_XA,
FMSUB,
FNMSUB,
};

} // end namespace llvm
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/MachineCombiner.cpp
@@ -319,6 +319,10 @@ static CombinerObjective getCombinerObjective(MachineCombinerPattern P) {
case MachineCombinerPattern::REASSOC_XMM_AMM_BMM:
case MachineCombinerPattern::SUBADD_OP1:
case MachineCombinerPattern::SUBADD_OP2:
case MachineCombinerPattern::FMADD_AX:
case MachineCombinerPattern::FMADD_XA:
case MachineCombinerPattern::FMSUB:
case MachineCombinerPattern::FNMSUB:
return CombinerObjective::MustReduceDepth;
case MachineCombinerPattern::REASSOC_XY_BCA:
case MachineCombinerPattern::REASSOC_XY_BAC:
174 changes: 168 additions & 6 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -26,6 +26,7 @@
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/MC/MCInstBuilder.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/ErrorHandling.h"
@@ -1176,6 +1177,17 @@ static bool isFADD(unsigned Opc) {
}
}

static bool isFSUB(unsigned Opc) {
switch (Opc) {
default:
return false;
case RISCV::FSUB_H:
case RISCV::FSUB_S:
case RISCV::FSUB_D:
return true;
}
}

static bool isFMUL(unsigned Opc) {
switch (Opc) {
default:
@@ -1211,6 +1223,33 @@ static bool canReassociate(MachineInstr &Root, MachineOperand &MO) {
return RISCV::hasEqualFRM(Root, *MI);
}

static bool canCombineFPFusedMultiply(const MachineInstr &Root,
const MachineOperand &MO,
bool DoRegPressureReduce) {
if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
return false;
const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
MachineInstr *MI = MRI.getVRegDef(MO.getReg());
if (!MI || !isFMUL(MI->getOpcode()))
return false;

if (!Root.getFlag(MachineInstr::MIFlag::FmContract) ||
!MI->getFlag(MachineInstr::MIFlag::FmContract))
return false;

// Try combining even if fmul has more than one use as it eliminates
// dependency between fadd(fsub) and fmul. However, it can extend liveranges
// for fmul operands, so reject the transformation in register pressure
// reduction mode.
if (DoRegPressureReduce && !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
return false;

// Do not combine instructions from different basic blocks.
if (Root.getParent() != MI->getParent())
return false;
return RISCV::hasEqualFRM(Root, *MI);
}

static bool
getFPReassocPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
@@ -1228,25 +1267,148 @@ getFPReassocPatterns(MachineInstr &Root,
return Added;
}

static bool getFPPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
static bool
getFPFusedMultiplyPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) {
unsigned Opc = Root.getOpcode();
if (isAssociativeAndCommutativeFPOpcode(Opc))
return getFPReassocPatterns(Root, Patterns);
return false;
bool IsFAdd = isFADD(Opc);
if (!IsFAdd && !isFSUB(Opc))
return false;
bool Added = false;
if (canCombineFPFusedMultiply(Root, Root.getOperand(1),
DoRegPressureReduce)) {
Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_AX
: MachineCombinerPattern::FMSUB);
Added = true;
}
if (canCombineFPFusedMultiply(Root, Root.getOperand(2),
DoRegPressureReduce)) {
Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_XA
: MachineCombinerPattern::FNMSUB);
Added = true;
}
return Added;
}

static bool getFPPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) {
bool Added = getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
if (isAssociativeAndCommutativeFPOpcode(Root.getOpcode()))
Added |= getFPReassocPatterns(Root, Patterns);
return Added;
}

bool RISCVInstrInfo::getMachineCombinerPatterns(
MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) const {

if (getFPPatterns(Root, Patterns))
if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
return true;

return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}

static unsigned getFPFusedMultiplyOpcode(unsigned RootOpc,
MachineCombinerPattern Pattern) {
switch (RootOpc) {
default:
llvm_unreachable("Unexpected opcode");
case RISCV::FADD_H:
return RISCV::FMADD_H;
case RISCV::FADD_S:
return RISCV::FMADD_S;
case RISCV::FADD_D:
return RISCV::FMADD_D;
case RISCV::FSUB_H:
return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_H
: RISCV::FNMSUB_H;
case RISCV::FSUB_S:
return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_S
: RISCV::FNMSUB_S;
case RISCV::FSUB_D:
return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_D
: RISCV::FNMSUB_D;
}
}

static unsigned getAddendOperandIdx(MachineCombinerPattern Pattern) {
switch (Pattern) {
default:
llvm_unreachable("Unexpected pattern");
case MachineCombinerPattern::FMADD_AX:
case MachineCombinerPattern::FMSUB:
return 2;
case MachineCombinerPattern::FMADD_XA:
case MachineCombinerPattern::FNMSUB:
return 1;
}
}

static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs) {
MachineFunction *MF = Root.getMF();
MachineRegisterInfo &MRI = MF->getRegInfo();
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();

MachineOperand &Mul1 = Prev.getOperand(1);
MachineOperand &Mul2 = Prev.getOperand(2);
MachineOperand &Dst = Root.getOperand(0);
MachineOperand &Addend = Root.getOperand(getAddendOperandIdx(Pattern));

Register DstReg = Dst.getReg();
unsigned FusedOpc = getFPFusedMultiplyOpcode(Root.getOpcode(), Pattern);
auto IntersectedFlags = Root.getFlags() & Prev.getFlags();
DebugLoc MergedLoc =
DILocation::getMergedLocation(Root.getDebugLoc(), Prev.getDebugLoc());

MachineInstrBuilder MIB =
BuildMI(*MF, MergedLoc, TII->get(FusedOpc), DstReg)
.addReg(Mul1.getReg(), getKillRegState(Mul1.isKill()))
.addReg(Mul2.getReg(), getKillRegState(Mul2.isKill()))
.addReg(Addend.getReg(), getKillRegState(Addend.isKill()))
.setMIFlags(IntersectedFlags);

// Mul operands are not killed anymore.
Mul1.setIsKill(false);
Mul2.setIsKill(false);

InsInstrs.push_back(MIB);
if (MRI.hasOneNonDBGUse(Prev.getOperand(0).getReg()))
DelInstrs.push_back(&Prev);
DelInstrs.push_back(&Root);
}

void RISCVInstrInfo::genAlternativeCodeSequence(
MachineInstr &Root, MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
switch (Pattern) {
default:
TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
DelInstrs, InstrIdxForVirtReg);
return;
case MachineCombinerPattern::FMADD_AX:
case MachineCombinerPattern::FMSUB: {
MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(1).getReg());
combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
return;
}
case MachineCombinerPattern::FMADD_XA:
case MachineCombinerPattern::FNMSUB: {
MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(2).getReg());
combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
return;
}
}
}

bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
StringRef &ErrInfo) const {
MCInstrDesc const &Desc = MI.getDesc();
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -196,6 +196,12 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
finalizeInsInstrs(MachineInstr &Root, MachineCombinerPattern &P,
SmallVectorImpl<MachineInstr *> &InsInstrs) const override;

void genAlternativeCodeSequence(
MachineInstr &Root, MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;

protected:
const RISCVSubtarget &STI;
};
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
@@ -95,8 +95,8 @@ define double @test_fmadd(double %a0, double %a1, double %a2) {
; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f11_d
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f10_d
; CHECK-NEXT: [[FMUL_D:%[0-9]+]]:fpr64 = contract nofpexcept FMUL_D [[COPY2]], [[COPY1]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FADD_D [[FMUL_D]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FADD_D]], [[FMUL_D]], 7, implicit $frm
; CHECK-NEXT: [[FMADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FMADD_D [[COPY2]], [[COPY1]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FMADD_D]], [[FMUL_D]], 7, implicit $frm
; CHECK-NEXT: $f10_d = COPY [[FDIV_D]]
; CHECK-NEXT: PseudoRET implicit $f10_d
%t0 = fmul contract double %a0, %a1
13 changes: 6 additions & 7 deletions llvm/test/CodeGen/RISCV/machine-combiner.ll
@@ -188,10 +188,9 @@ define double @test_reassoc_fadd_flags_2(double %a0, double %a1, double %a2, dou
define double @test_fmadd1(double %a0, double %a1, double %a2, double %a3) {
; CHECK-LABEL: test_fmadd1:
; CHECK: # %bb.0:
; CHECK-NEXT: fmul.d ft0, fa0, fa1
; CHECK-NEXT: fadd.d ft1, ft0, fa2
; CHECK-NEXT: fadd.d ft0, fa3, ft0
; CHECK-NEXT: fadd.d fa0, ft1, ft0
; CHECK-NEXT: fmadd.d ft0, fa0, fa1, fa2
; CHECK-NEXT: fmadd.d ft1, fa0, fa1, fa3
; CHECK-NEXT: fadd.d fa0, ft0, ft1
; CHECK-NEXT: ret
%t0 = fmul contract double %a0, %a1
%t1 = fadd contract double %t0, %a2
@@ -204,7 +203,7 @@ define double @test_fmadd2(double %a0, double %a1, double %a2) {
; CHECK-LABEL: test_fmadd2:
; CHECK: # %bb.0:
; CHECK-NEXT: fmul.d ft0, fa0, fa1
; CHECK-NEXT: fadd.d ft1, ft0, fa2
; CHECK-NEXT: fmadd.d ft1, fa0, fa1, fa2
; CHECK-NEXT: fdiv.d fa0, ft1, ft0
; CHECK-NEXT: ret
%t0 = fmul contract double %a0, %a1
@@ -217,7 +216,7 @@ define double @test_fmsub(double %a0, double %a1, double %a2) {
; CHECK-LABEL: test_fmsub:
; CHECK: # %bb.0:
; CHECK-NEXT: fmul.d ft0, fa0, fa1
; CHECK-NEXT: fsub.d ft1, ft0, fa2
; CHECK-NEXT: fmsub.d ft1, fa0, fa1, fa2
; CHECK-NEXT: fdiv.d fa0, ft1, ft0
; CHECK-NEXT: ret
%t0 = fmul contract double %a0, %a1
@@ -230,7 +229,7 @@ define double @test_fnmsub(double %a0, double %a1, double %a2) {
; CHECK-LABEL: test_fnmsub:
; CHECK: # %bb.0:
; CHECK-NEXT: fmul.d ft0, fa0, fa1
; CHECK-NEXT: fsub.d ft1, fa2, ft0
; CHECK-NEXT: fnmsub.d ft1, fa0, fa1, fa2
; CHECK-NEXT: fdiv.d fa0, ft1, ft0
; CHECK-NEXT: ret
%t0 = fmul contract double %a0, %a1
