[AArch64][GlobalISel] Port some AArch64 target specific MUL combines from SDAG.

These combines turn, for example, a multiply by a power-of-2-plus-1 constant into a shift
and an add, which is a common pattern and is universally better than an expensive
madd instruction with a materialized constant.

I've added check lines to an existing codegen test since the code being ported
is almost identical to the SDAG version; however, the tests that multiply by negative
pow-2 constants don't yet generate the same code, because we're still missing some
generic G_MUL combines.

Differential Revision: https://reviews.llvm.org/D91125
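
As a quick sanity check of the rewrites described above, the power-of-2 plus/minus one
identities can be verified with plain unsigned 64-bit arithmetic. This is a standalone
C++ sketch, not part of the patch; x and the constants are arbitrary illustrative values.

// Standalone sketch (not from the patch): modular uint64_t arithmetic
// makes the identities used by the combine easy to verify directly.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t x = 0x1234abcd;
  // (mul x, 2^N + 1) => (add (shl x, N), x), e.g. C = 5, N = 2.
  assert(x * 5 == (x << 2) + x);
  // (mul x, 2^N - 1) => (sub (shl x, N), x), e.g. C = 7, N = 3.
  assert(x * 7 == (x << 3) - x);
  // (mul x, -(2^N - 1)) => (sub x, (shl x, N)), e.g. C = -7, N = 3.
  assert(x * uint64_t(-7) == x - (x << 3));
  // (mul x, -(2^N + 1)) => -(add (shl x, N), x), e.g. C = -9, N = 3.
  assert(x * uint64_t(-9) == uint64_t(0) - ((x << 3) + x));
  return 0;
}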
aemerson committed Nov 11, 2020
1 parent 881b4d2 commit 2262393
Showing 3 changed files with 593 additions and 112 deletions.
11 changes: 10 additions & 1 deletion llvm/lib/Target/AArch64/AArch64Combine.td
@@ -111,6 +111,14 @@ def extractvecelt_pairwise_add : GICombineRule<
  (apply [{ applyExtractVecEltPairwiseAdd(*${root}, MRI, B, ${matchinfo}); }])
>;

def mul_const_matchdata : GIDefMatchData<"std::function<void(MachineIRBuilder&, Register)>">;
def mul_const : GICombineRule<
  (defs root:$root, mul_const_matchdata:$matchinfo),
  (match (wip_match_opcode G_MUL):$root,
    [{ return matchAArch64MulConstCombine(*${root}, MRI, ${matchinfo}); }]),
  (apply [{ applyAArch64MulConstCombine(*${root}, MRI, B, ${matchinfo}); }])
>;

// Post-legalization combines which should happen at all optimization levels.
// (E.g. ones that facilitate matching for the selector) For example, matching
// pseudos.
@@ -128,6 +136,7 @@ def AArch64PostLegalizerCombinerHelper
                        sext_trunc_sextload,
                        hoist_logic_op_with_same_opcode_hands,
                        redundant_and, xor_of_and_with_same_reg,
-                       extractvecelt_pairwise_add, redundant_or]> {
+                       extractvecelt_pairwise_add, redundant_or,
+                       mul_const]> {
  let DisableRuleOption = "aarch64postlegalizercombiner-disable-rule";
}
133 changes: 133 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"
@@ -104,6 +105,138 @@ bool applyExtractVecEltPairwiseAdd(
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt &ConstValue = APInt(Ty.getSizeInBits(), Const->Value, true);
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      if (UseMI.getOpcode() == TargetOpcode::G_ADD ||
          UseMI.getOpcode() == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res,
                 B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
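
To make the shift+add+shift path in matchAArch64MulConstCombine concrete, here is a
standalone trace of how C = 6 is decomposed (again not part of the patch), using plain
64-bit integers and __builtin_ctzll in place of APInt; the operand value x is arbitrary.

// Standalone trace (not from the patch) of the shift+add+shift case,
// mirroring the variable names used in matchAArch64MulConstCombine.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t C = 6;       // 6 = (2^1 + 1) * 2^1, i.e. N = M = 1.
  const uint64_t x = 0x1234; // arbitrary operand

  // Stands in for ConstValue.countTrailingZeros(): 6 = 0b110 has one.
  unsigned TrailingZeroes = __builtin_ctzll(uint64_t(C)); // 1
  // Stands in for ConstValue.ashr(TrailingZeroes): the odd part of C.
  int64_t ShiftedConstValue = C >> TrailingZeroes;        // 3

  // ShiftedConstValue - 1 == 2 is a power of two, so the non-negative
  // branch picks G_ADD with ShiftAmt = logBase2(2) = 1.
  assert(((ShiftedConstValue - 1) & (ShiftedConstValue - 2)) == 0);
  unsigned ShiftAmt = 1;

  // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
  uint64_t Res = ((x << ShiftAmt) + x) << TrailingZeroes;
  assert(Res == x * uint64_t(C));
  return 0;
}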
(The third changed file, the updated codegen test with the new check lines, is not expanded in this view.)
