[AArch64] Add Machine InstCombiner patterns for FMUL indexed variant
This patch adds a DUP+FMUL => FMUL_indexed pattern to the machine
combiner. FMUL_indexed is normally selected during instruction
selection, but that selection cannot happen when the DUP and the FMUL
are in different basic blocks.
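
For illustration, a sketch of the transformation (register numbers and
block names here are made up, not taken from actual codegen):

  ; before: the splat and the multiply are in different blocks, so
  ; instruction selection cannot fold them into an indexed multiply
  bb.1:
    dup   v1.4s, v2.s[0]
    ...
  bb.2:
    fmul  v0.4s, v3.4s, v1.4s

  ; after the combiner pattern fires, the multiply reads the lane
  ; directly (the DUP becomes dead if it has no other uses)
  bb.2:
    fmul  v0.4s, v3.4s, v2.s[0]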

Differential Revision: https://reviews.llvm.org/D99662
asavonic committed Nov 9, 2021
1 parent 0076957 commit b702276
Showing 4 changed files with 825 additions and 3 deletions.
13 changes: 12 additions & 1 deletion llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -153,7 +153,18 @@ enum class MachineCombinerPattern {
   FMLSv4f32_OP1,
   FMLSv4f32_OP2,
   FMLSv4i32_indexed_OP1,
-  FMLSv4i32_indexed_OP2
+  FMLSv4i32_indexed_OP2,
+
+  FMULv2i32_indexed_OP1,
+  FMULv2i32_indexed_OP2,
+  FMULv2i64_indexed_OP1,
+  FMULv2i64_indexed_OP2,
+  FMULv4i16_indexed_OP1,
+  FMULv4i16_indexed_OP2,
+  FMULv4i32_indexed_OP1,
+  FMULv4i32_indexed_OP2,
+  FMULv8i16_indexed_OP1,
+  FMULv8i16_indexed_OP2,
 };
 
 } // end namespace llvm
140 changes: 139 additions & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -4917,6 +4917,55 @@ static bool getFMAPatterns(MachineInstr &Root,
   return Found;
 }
 
+static bool getFMULPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  MachineBasicBlock &MBB = *Root.getParent();
+  bool Found = false;
+
+  auto Match = [&](unsigned Opcode, int Operand,
+                   MachineCombinerPattern Pattern) -> bool {
+    MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+    MachineOperand &MO = Root.getOperand(Operand);
+    MachineInstr *MI = nullptr;
+    if (MO.isReg() && Register::isVirtualRegister(MO.getReg()))
+      MI = MRI.getUniqueVRegDef(MO.getReg());
+    if (MI && MI->getOpcode() == Opcode) {
+      Patterns.push_back(Pattern);
+      return true;
+    }
+    return false;
+  };
+
+  typedef MachineCombinerPattern MCP;
+
+  switch (Root.getOpcode()) {
+  default:
+    return false;
+  case AArch64::FMULv2f32:
+    Found = Match(AArch64::DUPv2i32lane, 1, MCP::FMULv2i32_indexed_OP1);
+    Found |= Match(AArch64::DUPv2i32lane, 2, MCP::FMULv2i32_indexed_OP2);
+    break;
+  case AArch64::FMULv2f64:
+    Found = Match(AArch64::DUPv2i64lane, 1, MCP::FMULv2i64_indexed_OP1);
+    Found |= Match(AArch64::DUPv2i64lane, 2, MCP::FMULv2i64_indexed_OP2);
+    break;
+  case AArch64::FMULv4f16:
+    Found = Match(AArch64::DUPv4i16lane, 1, MCP::FMULv4i16_indexed_OP1);
+    Found |= Match(AArch64::DUPv4i16lane, 2, MCP::FMULv4i16_indexed_OP2);
+    break;
+  case AArch64::FMULv4f32:
+    Found = Match(AArch64::DUPv4i32lane, 1, MCP::FMULv4i32_indexed_OP1);
+    Found |= Match(AArch64::DUPv4i32lane, 2, MCP::FMULv4i32_indexed_OP2);
+    break;
+  case AArch64::FMULv8f16:
+    Found = Match(AArch64::DUPv8i16lane, 1, MCP::FMULv8i16_indexed_OP1);
+    Found |= Match(AArch64::DUPv8i16lane, 2, MCP::FMULv8i16_indexed_OP2);
+    break;
+  }
+
+  return Found;
+}
+
 /// Return true when a code sequence can improve throughput. It
 /// should be called only for instructions in loops.
 /// \param Pattern - combiner pattern
@@ -4980,6 +5029,16 @@ bool AArch64InstrInfo::isThroughputPattern(
   case MachineCombinerPattern::FMLSv2f64_OP2:
   case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv4f32_OP2:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP2:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP2:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP2:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP2:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP2:
   case MachineCombinerPattern::MULADDv8i8_OP1:
   case MachineCombinerPattern::MULADDv8i8_OP2:
   case MachineCombinerPattern::MULADDv16i8_OP1:
@@ -5036,6 +5095,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getMaddPatterns(Root, Patterns))
     return true;
   // Floating point patterns
+  if (getFMULPatterns(Root, Patterns))
+    return true;
   if (getFMAPatterns(Root, Patterns))
     return true;
 
@@ -5124,6 +5185,42 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
   return MUL;
 }
 
+/// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
+static MachineInstr *
+genIndexedMultiply(MachineInstr &Root,
+                   SmallVectorImpl<MachineInstr *> &InsInstrs,
+                   unsigned IdxDupOp, unsigned MulOpc,
+                   const TargetRegisterClass *RC, MachineRegisterInfo &MRI) {
+  assert(((IdxDupOp == 1) || (IdxDupOp == 2)) &&
+         "Invalid index of FMUL operand");
+
+  MachineFunction &MF = *Root.getMF();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  MachineInstr *Dup =
+      MF.getRegInfo().getUniqueVRegDef(Root.getOperand(IdxDupOp).getReg());
+
+  Register DupSrcReg = Dup->getOperand(1).getReg();
+  MRI.clearKillFlags(DupSrcReg);
+  MRI.constrainRegClass(DupSrcReg, RC);
+
+  unsigned DupSrcLane = Dup->getOperand(2).getImm();
+
+  unsigned IdxMulOp = IdxDupOp == 1 ? 2 : 1;
+  MachineOperand &MulOp = Root.getOperand(IdxMulOp);
+
+  Register ResultReg = Root.getOperand(0).getReg();
+
+  MachineInstrBuilder MIB;
+  MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MulOpc), ResultReg)
+            .add(MulOp)
+            .addReg(DupSrcReg)
+            .addImm(DupSrcLane);
+
+  InsInstrs.push_back(MIB);
+  return &Root;
+}
+
 /// genFusedMultiplyAcc - Helper to generate fused multiply accumulate
 /// instructions.
 ///
@@ -6082,12 +6179,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     }
     break;
   }
+  case MachineCombinerPattern::FMULv2i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP2: {
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv2i32_indexed_OP1) ? 1 : 2;
+    genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i32_indexed,
+                       &AArch64::FPR128RegClass, MRI);
+    break;
+  }
+  case MachineCombinerPattern::FMULv2i64_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP2: {
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv2i64_indexed_OP1) ? 1 : 2;
+    genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i64_indexed,
+                       &AArch64::FPR128RegClass, MRI);
+    break;
+  }
+  case MachineCombinerPattern::FMULv4i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP2: {
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv4i16_indexed_OP1) ? 1 : 2;
+    genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i16_indexed,
+                       &AArch64::FPR128_loRegClass, MRI);
+    break;
+  }
+  case MachineCombinerPattern::FMULv4i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP2: {
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv4i32_indexed_OP1) ? 1 : 2;
+    genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i32_indexed,
+                       &AArch64::FPR128RegClass, MRI);
+    break;
+  }
+  case MachineCombinerPattern::FMULv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP2: {
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv8i16_indexed_OP1) ? 1 : 2;
+    genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv8i16_indexed,
+                       &AArch64::FPR128_loRegClass, MRI);
+    break;
+  }
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
   // FIXME: This assertion fails in CodeGen/AArch64/tailmerging_in_mbp.ll and
   // CodeGen/AArch64/urem-seteq-nonzero.ll.
   // assert(MUL && "MUL was never set");
-  DelInstrs.push_back(MUL);
+  if (MUL)
+    DelInstrs.push_back(MUL);
   DelInstrs.push_back(&Root);
 }
 
128 changes: 127 additions & 1 deletion llvm/test/CodeGen/AArch64/arm64-fma-combines.ll
@@ -1,4 +1,5 @@
-; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -enable-unsafe-fp-math | FileCheck %s
+; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -mattr=+fullfp16 -enable-unsafe-fp-math -verify-machineinstrs | FileCheck %s
+
 define void @foo_2d(double* %src) {
 ; CHECK-LABEL: %entry
 ; CHECK: fmul {{d[0-9]+}}, {{d[0-9]+}}, {{d[0-9]+}}
@@ -134,3 +135,128 @@ for.body: ; preds = %for.body, %entry
 for.end: ; preds = %for.body
   ret void
 }
+
+define void @indexed_2s(<2 x float> %shuf, <2 x float> %add,
+                        <2 x float>* %pmul, <2 x float>* %pret) {
+; CHECK-LABEL: %entry
+; CHECK: for.body
+; CHECK: fmla.2s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <2 x float> %shuf, <2 x float> undef, <2 x i32> zeroinitializer
+  br label %for.body
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
+  %pmul_i = getelementptr inbounds <2 x float>, <2 x float>* %pmul, i64 %i
+  %pret_i = getelementptr inbounds <2 x float>, <2 x float>* %pret, i64 %i
+
+  %mul_i = load <2 x float>, <2 x float>* %pmul_i
+
+  %mul = fmul fast <2 x float> %mul_i, %shuffle
+  %muladd = fadd fast <2 x float> %mul, %add
+
+  store <2 x float> %muladd, <2 x float>* %pret_i, align 16
+  %inext = add i64 %i, 1
+  br label %for.body
+}
+
+define void @indexed_2d(<2 x double> %shuf, <2 x double> %add,
+                        <2 x double>* %pmul, <2 x double>* %pret) {
+; CHECK-LABEL: %entry
+; CHECK: for.body
+; CHECK: fmla.2d {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <2 x double> %shuf, <2 x double> undef, <2 x i32> zeroinitializer
+  br label %for.body
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
+  %pmul_i = getelementptr inbounds <2 x double>, <2 x double>* %pmul, i64 %i
+  %pret_i = getelementptr inbounds <2 x double>, <2 x double>* %pret, i64 %i
+
+  %mul_i = load <2 x double>, <2 x double>* %pmul_i
+
+  %mul = fmul fast <2 x double> %mul_i, %shuffle
+  %muladd = fadd fast <2 x double> %mul, %add
+
+  store <2 x double> %muladd, <2 x double>* %pret_i, align 16
+  %inext = add i64 %i, 1
+  br label %for.body
+}
+
+define void @indexed_4s(<4 x float> %shuf, <4 x float> %add,
+                        <4 x float>* %pmul, <4 x float>* %pret) {
+; CHECK-LABEL: %entry
+; CHECK: for.body
+; CHECK: fmla.4s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <4 x float> %shuf, <4 x float> undef, <4 x i32> zeroinitializer
+  br label %for.body
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
+  %pmul_i = getelementptr inbounds <4 x float>, <4 x float>* %pmul, i64 %i
+  %pret_i = getelementptr inbounds <4 x float>, <4 x float>* %pret, i64 %i
+
+  %mul_i = load <4 x float>, <4 x float>* %pmul_i
+
+  %mul = fmul fast <4 x float> %mul_i, %shuffle
+  %muladd = fadd fast <4 x float> %mul, %add
+
+  store <4 x float> %muladd, <4 x float>* %pret_i, align 16
+  %inext = add i64 %i, 1
+  br label %for.body
+}
+
+define void @indexed_4h(<4 x half> %shuf, <4 x half> %add,
+                        <4 x half>* %pmul, <4 x half>* %pret) {
+; CHECK-LABEL: %entry
+; CHECK: for.body
+; CHECK: fmla.4h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <4 x half> %shuf, <4 x half> undef, <4 x i32> zeroinitializer
+  br label %for.body
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
+  %pmul_i = getelementptr inbounds <4 x half>, <4 x half>* %pmul, i64 %i
+  %pret_i = getelementptr inbounds <4 x half>, <4 x half>* %pret, i64 %i
+
+  %mul_i = load <4 x half>, <4 x half>* %pmul_i
+
+  %mul = fmul fast <4 x half> %mul_i, %shuffle
+  %muladd = fadd fast <4 x half> %mul, %add
+
+  store <4 x half> %muladd, <4 x half>* %pret_i, align 16
+  %inext = add i64 %i, 1
+  br label %for.body
+}
+
+define void @indexed_8h(<8 x half> %shuf, <8 x half> %add,
+                        <8 x half>* %pmul, <8 x half>* %pret) {
+; CHECK-LABEL: %entry
+; CHECK: for.body
+; CHECK: fmla.8h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <8 x half> %shuf, <8 x half> undef, <8 x i32> zeroinitializer
+  br label %for.body
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
+  %pmul_i = getelementptr inbounds <8 x half>, <8 x half>* %pmul, i64 %i
+  %pret_i = getelementptr inbounds <8 x half>, <8 x half>* %pret, i64 %i
+
+  %mul_i = load <8 x half>, <8 x half>* %pmul_i
+
+  %mul = fmul fast <8 x half> %mul_i, %shuffle
+  %muladd = fadd fast <8 x half> %mul, %add
+
+  store <8 x half> %muladd, <8 x half>* %pret_i, align 16
+  %inext = add i64 %i, 1
+  br label %for.body
+}
