Skip to content

Commit

Permalink
[MachineCombiner][RISCV] Enable MachineCombiner for RISCV
Browse files Browse the repository at this point in the history
Initial implementation to match basic FP reassociation patterns.

Differential Revision: https://reviews.llvm.org/D135264
  • Loading branch information
asi-sc authored and anton-afanasyev committed Oct 18, 2022
1 parent f12fb91 commit 1978b4d
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 29 deletions.
134 changes: 134 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Expand Up @@ -21,6 +21,7 @@
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/LiveVariables.h"
#include "llvm/CodeGen/MachineCombinerPattern.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
Expand Down Expand Up @@ -1125,6 +1126,127 @@ RISCVInstrInfo::isCopyInstrImpl(const MachineInstr &MI) const {
return None;
}

void RISCVInstrInfo::setSpecialOperandAttr(MachineInstr &OldMI1,
MachineInstr &OldMI2,
MachineInstr &NewMI1,
MachineInstr &NewMI2) const {
uint16_t IntersectedFlags = OldMI1.getFlags() & OldMI2.getFlags();
NewMI1.setFlags(IntersectedFlags);
NewMI2.setFlags(IntersectedFlags);
}

void RISCVInstrInfo::finalizeInsInstrs(
MachineInstr &Root, MachineCombinerPattern &P,
SmallVectorImpl<MachineInstr *> &InsInstrs) const {
int16_t FrmOpIdx =
RISCV::getNamedOperandIdx(Root.getOpcode(), RISCV::OpName::frm);
if (FrmOpIdx < 0) {
assert(all_of(InsInstrs,
[](MachineInstr *MI) {
return RISCV::getNamedOperandIdx(MI->getOpcode(),
RISCV::OpName::frm) < 0;
}) &&
"New instructions require FRM whereas the old one does not have it");
return;
}

const MachineOperand &FRM = Root.getOperand(FrmOpIdx);
MachineFunction &MF = *Root.getMF();

for (auto *NewMI : InsInstrs) {
assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
NewMI->getOpcode(), RISCV::OpName::frm)) ==
NewMI->getNumOperands() &&
"Instruction has unexpected number of operands");
MachineInstrBuilder MIB(MF, NewMI);
MIB.add(FRM);
if (FRM.getImm() == RISCVFPRndMode::DYN)
MIB.addUse(RISCV::FRM, RegState::Implicit);
}
}

static bool isFADD(unsigned Opc) {
switch (Opc) {
default:
return false;
case RISCV::FADD_H:
case RISCV::FADD_S:
case RISCV::FADD_D:
return true;
}
}

static bool isFMUL(unsigned Opc) {
switch (Opc) {
default:
return false;
case RISCV::FMUL_H:
case RISCV::FMUL_S:
case RISCV::FMUL_D:
return true;
}
}

static bool isAssociativeAndCommutativeFPOpcode(unsigned Opc) {
return isFADD(Opc) || isFMUL(Opc);
}

static bool canReassociate(MachineInstr &Root, MachineOperand &MO) {
if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
return false;
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
MachineInstr *MI = MRI.getVRegDef(MO.getReg());
if (!MI || !MRI.hasOneNonDBGUse(MO.getReg()))
return false;

if (MI->getOpcode() != Root.getOpcode())
return false;

if (!Root.getFlag(MachineInstr::MIFlag::FmReassoc) ||
!Root.getFlag(MachineInstr::MIFlag::FmNsz) ||
!MI->getFlag(MachineInstr::MIFlag::FmReassoc) ||
!MI->getFlag(MachineInstr::MIFlag::FmNsz))
return false;

return RISCV::hasEqualFRM(Root, *MI);
}

static bool
getFPReassocPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
bool Added = false;
if (canReassociate(Root, Root.getOperand(1))) {
Patterns.push_back(MachineCombinerPattern::REASSOC_AX_BY);
Patterns.push_back(MachineCombinerPattern::REASSOC_XA_BY);
Added = true;
}
if (canReassociate(Root, Root.getOperand(2))) {
Patterns.push_back(MachineCombinerPattern::REASSOC_AX_YB);
Patterns.push_back(MachineCombinerPattern::REASSOC_XA_YB);
Added = true;
}
return Added;
}

static bool getFPPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns) {
unsigned Opc = Root.getOpcode();
if (isAssociativeAndCommutativeFPOpcode(Opc))
return getFPReassocPatterns(Root, Patterns);
return false;
}

bool RISCVInstrInfo::getMachineCombinerPatterns(
MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) const {

if (getFPPatterns(Root, Patterns))
return true;

return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}

bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
StringRef &ErrInfo) const {
MCInstrDesc const &Desc = MI.getDesc();
Expand Down Expand Up @@ -2164,3 +2286,15 @@ bool RISCV::isFaultFirstLoad(const MachineInstr &MI) {
return MI.getNumExplicitDefs() == 2 && MI.modifiesRegister(RISCV::VL) &&
!MI.isInlineAsm();
}

bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) {
int16_t MI1FrmOpIdx =
RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm);
int16_t MI2FrmOpIdx =
RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm);
if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0)
return false;
MachineOperand FrmOp1 = MI1.getOperand(MI1FrmOpIdx);
MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
return FrmOp1.getImm() == FrmOp2.getImm();
}
18 changes: 18 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Expand Up @@ -182,6 +182,20 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
MachineBasicBlock::iterator II, const DebugLoc &DL, Register DestReg,
int64_t Amount, MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const;

bool useMachineCombiner() const override { return true; }

void setSpecialOperandAttr(MachineInstr &OldMI1, MachineInstr &OldMI2,
MachineInstr &NewMI1,
MachineInstr &NewMI2) const override;
bool
getMachineCombinerPatterns(MachineInstr &Root,
SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) const override;

void
finalizeInsInstrs(MachineInstr &Root, MachineCombinerPattern &P,
SmallVectorImpl<MachineInstr *> &InsInstrs) const override;

protected:
const RISCVSubtarget &STI;
};
Expand All @@ -204,6 +218,10 @@ bool isFaultFirstLoad(const MachineInstr &MI);
// Implemented in RISCVGenInstrInfo.inc
int16_t getNamedOperandIdx(uint16_t Opcode, uint16_t NamedIndex);

// Return true if both input instructions have equal rounding mode. If at least
// one of the instructions does not have rounding mode, false will be returned.
bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2);

// Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
static constexpr int64_t VLMaxSentinel = -1LL;
} // namespace RISCV
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
Expand Up @@ -47,6 +47,11 @@ static cl::opt<cl::boolOrDefault>
EnableGlobalMerge("riscv-enable-global-merge", cl::Hidden,
cl::desc("Enable the global merge pass"));

static cl::opt<bool>
EnableMachineCombiner("riscv-enable-machine-combiner",
cl::desc("Enable the machine combiner pass"),
cl::init(true), cl::Hidden);

extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
RegisterTargetMachine<RISCVTargetMachine> X(getTheRISCV32Target());
RegisterTargetMachine<RISCVTargetMachine> Y(getTheRISCV64Target());
Expand Down Expand Up @@ -263,6 +268,8 @@ void RISCVPassConfig::addPreEmitPass2() {

void RISCVPassConfig::addMachineSSAOptimization() {
TargetPassConfig::addMachineSSAOptimization();
if (TM->getOptLevel() == CodeGenOpt::Aggressive && EnableMachineCombiner)
addPass(&MachineCombinerID);

if (TM->getTargetTriple().getArch() == Triple::riscv64)
addPass(createRISCVSExtWRemovalPass());
Expand Down
3 changes: 3 additions & 0 deletions llvm/test/CodeGen/RISCV/O3-pipeline.ll
Expand Up @@ -97,6 +97,9 @@
; CHECK-NEXT: Machine code sinking
; CHECK-NEXT: Peephole Optimizations
; CHECK-NEXT: Remove dead machine instructions
; CHECK-NEXT: Machine Trace Metrics
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
; CHECK-NEXT: Machine InstCombiner
; RV64-NEXT: RISCV sext.w Removal
; CHECK-NEXT: RISCV Pre-RA pseudo instruction expansion pass
; CHECK-NEXT: RISCV Merge Base Offset
Expand Down
86 changes: 86 additions & 0 deletions llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
@@ -0,0 +1,86 @@
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
; RUN: llc -mtriple=riscv64 -mattr=+d -verify-machineinstrs -mcpu=sifive-u74 \
; RUN: -O3 -enable-unsafe-fp-math -riscv-enable-machine-combiner=true \
; RUN: -stop-after machine-combiner < %s | FileCheck %s

define double @test_reassoc_fadd1(double %a0, double %a1, double %a2, double %a3) {
; CHECK-LABEL: name: test_reassoc_fadd1
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $f10_d, $f11_d, $f12_d, $f13_d
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:fpr64 = COPY $f13_d
; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f12_d
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f11_d
; CHECK-NEXT: [[COPY3:%[0-9]+]]:fpr64 = COPY $f10_d
; CHECK-NEXT: [[FADD_D:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D [[COPY3]], [[COPY2]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D1:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D [[COPY1]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D2:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D killed [[FADD_D]], killed [[FADD_D1]], 7, implicit $frm
; CHECK-NEXT: $f10_d = COPY [[FADD_D2]]
; CHECK-NEXT: PseudoRET implicit $f10_d
%t0 = fadd nsz reassoc double %a0, %a1
%t1 = fadd nsz reassoc double %t0, %a2
%t2 = fadd nsz reassoc double %t1, %a3
ret double %t2
}

define double @test_reassoc_fmul1(double %a0, double %a1, double %a2, double %a3) {
; CHECK-LABEL: name: test_reassoc_fmul1
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $f10_d, $f11_d, $f12_d, $f13_d
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:fpr64 = COPY $f13_d
; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f12_d
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f11_d
; CHECK-NEXT: [[COPY3:%[0-9]+]]:fpr64 = COPY $f10_d
; CHECK-NEXT: [[FMUL_D:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FMUL_D [[COPY3]], [[COPY2]], 7, implicit $frm
; CHECK-NEXT: [[FMUL_D1:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FMUL_D [[COPY1]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FMUL_D2:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FMUL_D killed [[FMUL_D]], killed [[FMUL_D1]], 7, implicit $frm
; CHECK-NEXT: $f10_d = COPY [[FMUL_D2]]
; CHECK-NEXT: PseudoRET implicit $f10_d
%t0 = fmul nsz reassoc double %a0, %a1
%t1 = fmul nsz reassoc double %t0, %a2
%t2 = fmul nsz reassoc double %t1, %a3
ret double %t2
}

; Verify flags intersection
define double @test_reassoc_flags1(double %a0, double %a1, double %a2, double %a3) {
; CHECK-LABEL: name: test_reassoc_flags1
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $f10_d, $f11_d, $f12_d, $f13_d
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:fpr64 = COPY $f13_d
; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f12_d
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f11_d
; CHECK-NEXT: [[COPY3:%[0-9]+]]:fpr64 = COPY $f10_d
; CHECK-NEXT: [[FADD_D:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D [[COPY3]], [[COPY2]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D1:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D [[COPY1]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D2:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D killed [[FADD_D]], killed [[FADD_D1]], 7, implicit $frm
; CHECK-NEXT: $f10_d = COPY [[FADD_D2]]
; CHECK-NEXT: PseudoRET implicit $f10_d
%t0 = fadd nsz reassoc double %a0, %a1
%t1 = fadd contract nsz reassoc double %t0, %a2
%t2 = fadd nsz reassoc double %t1, %a3
ret double %t2
}

; Verify flags intersection
define double @test_reassoc_flags2(double %a0, double %a1, double %a2, double %a3) {
; CHECK-LABEL: name: test_reassoc_flags2
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $f10_d, $f11_d, $f12_d, $f13_d
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:fpr64 = COPY $f13_d
; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f12_d
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f11_d
; CHECK-NEXT: [[COPY3:%[0-9]+]]:fpr64 = COPY $f10_d
; CHECK-NEXT: [[FADD_D:%[0-9]+]]:fpr64 = nsz reassoc nofpexcept FADD_D [[COPY3]], [[COPY2]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D1:%[0-9]+]]:fpr64 = nsz contract reassoc nofpexcept FADD_D [[COPY1]], [[COPY]], 7, implicit $frm
; CHECK-NEXT: [[FADD_D2:%[0-9]+]]:fpr64 = nsz contract reassoc nofpexcept FADD_D killed [[FADD_D]], killed [[FADD_D1]], 7, implicit $frm
; CHECK-NEXT: $f10_d = COPY [[FADD_D2]]
; CHECK-NEXT: PseudoRET implicit $f10_d
%t0 = fadd nsz reassoc double %a0, %a1
%t1 = fadd contract nsz reassoc double %t0, %a2
%t2 = fadd contract nsz reassoc double %t1, %a3
ret double %t2
}

0 comments on commit 1978b4d

Please sign in to comment.