Skip to content

Commit

Permalink
[GISel] Rework trunc/shl combine in a generic trunc/shift combine
Browse files Browse the repository at this point in the history
This combine only handled left shifts, but now it can handle right shifts as well. It handles right shifts conservatively and only truncates them to the size returned by TLI.

For instance, AMDGPU benefits from always lowering shifts to 32 bits, but AArch64 would rather keep them at 64 bits.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D136319
  • Loading branch information
Pierre-vh committed Dec 9, 2022
1 parent 3e5f54d commit 3612d9e
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 116 deletions.
17 changes: 11 additions & 6 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Expand Up @@ -406,12 +406,17 @@ class CombinerHelper {
void applyCombineTruncOfExt(MachineInstr &MI,
std::pair<Register, unsigned> &MatchInfo);

/// Transform trunc (shl x, K) to shl (trunc x), K
///    if K < VT.getScalarSizeInBits().
///
/// Transforms trunc ([al]shr x, K) to (trunc ([al]shr (MidVT (trunc x)), K))
///    if K <= (MidVT.getScalarSizeInBits() - VT.getScalarSizeInBits())
/// MidVT is obtained by finding a legal type between the trunc's src and dst
/// types.
bool matchCombineTruncOfShift(MachineInstr &MI,
                              std::pair<MachineInstr *, LLT> &MatchInfo);
void applyCombineTruncOfShift(MachineInstr &MI,
                              std::pair<MachineInstr *, LLT> &MatchInfo);

/// Transform G_MUL(x, -1) to G_SUB(0, x)
void applyCombineMulByNegativeOne(MachineInstr &MI);
Expand Down
16 changes: 9 additions & 7 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
Expand Up @@ -642,13 +642,15 @@ def trunc_ext_fold: GICombineRule <
(apply [{ Helper.applyCombineTruncOfExt(*${root}, ${matchinfo}); }])
>;

// Under certain conditions, transform:
//  trunc (shl x, K)     -> shl (trunc x), K
//  trunc ([al]shr x, K) -> (trunc ([al]shr (trunc x), K))
def trunc_shift_matchinfo : GIDefMatchData<"std::pair<MachineInstr*, LLT>">;
def trunc_shift: GICombineRule <
  (defs root:$root, trunc_shift_matchinfo:$matchinfo),
  (match (wip_match_opcode G_TRUNC):$root,
         [{ return Helper.matchCombineTruncOfShift(*${root}, ${matchinfo}); }]),
  (apply [{ Helper.applyCombineTruncOfShift(*${root}, ${matchinfo}); }])
>;

// Transform (mul x, -1) -> (sub 0, x)
Expand Down Expand Up @@ -1076,7 +1078,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
known_bits_simplifications, ext_ext_fold,
not_cmp_fold, opt_brcond_by_inverting_cond,
unmerge_merge, unmerge_cst, unmerge_dead_to_trunc,
unmerge_zext_to_zext, merge_unmerge, trunc_ext_fold, trunc_shl,
unmerge_zext_to_zext, merge_unmerge, trunc_ext_fold, trunc_shift,
const_combines, xor_of_and_with_same_reg, ptr_add_with_zero,
shift_immed_chain, shift_of_shifted_logic_chain, load_or_combine,
truncstore_merge, div_rem_to_divrem, funnel_shift_combines,
Expand Down
125 changes: 95 additions & 30 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Expand Up @@ -2266,44 +2266,109 @@ void CombinerHelper::applyCombineTruncOfExt(
MI.eraseFromParent();
}

bool CombinerHelper::matchCombineTruncOfShl(
MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
LLT DstTy = MRI.getType(DstReg);
Register ShiftSrc;
Register ShiftAmt;

if (MRI.hasOneNonDBGUse(SrcReg) &&
mi_match(SrcReg, MRI, m_GShl(m_Reg(ShiftSrc), m_Reg(ShiftAmt))) &&
isLegalOrBeforeLegalizer(
{TargetOpcode::G_SHL,
{DstTy, getTargetLowering().getPreferredShiftAmountTy(DstTy)}})) {
KnownBits Known = KB->getKnownBits(ShiftAmt);
unsigned Size = DstTy.getSizeInBits();
if (Known.countMaxActiveBits() <= Log2_32(Size)) {
MatchInfo = std::make_pair(ShiftSrc, ShiftAmt);
return true;
}
}
return false;
static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
const unsigned TruncSize = TruncTy.getScalarSizeInBits();

// ShiftTy > 32 > TruncTy -> 32
if (ShiftSize > 32 && TruncSize < 32)
return ShiftTy.changeElementSize(32);

// TODO: We could also reduce to 16 bits, but that's more target-dependent.
// Some targets like it, some don't, some only like it under certain
// conditions/processor versions, etc.
// A TL hook might be needed for this.

// Don't combine
return ShiftTy;
}

void CombinerHelper::applyCombineTruncOfShl(
MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
bool CombinerHelper::matchCombineTruncOfShift(
MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();

if (!MRI.hasOneNonDBGUse(SrcReg))
return false;

LLT SrcTy = MRI.getType(SrcReg);
LLT DstTy = MRI.getType(DstReg);
MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);

Register ShiftSrc = MatchInfo.first;
Register ShiftAmt = MatchInfo.second;
MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI);
const auto &TL = getTargetLowering();

LLT NewShiftTy;
switch (SrcMI->getOpcode()) {
default:
return false;
case TargetOpcode::G_SHL: {
NewShiftTy = DstTy;

// Make sure new shift amount is legal.
KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits()))
return false;
break;
}
case TargetOpcode::G_LSHR:
case TargetOpcode::G_ASHR: {
// For right shifts, we conservatively do not do the transform if the TRUNC
// has any STORE users. The reason is that if we change the type of the
// shift, we may break the truncstore combine.
//
// TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
for (auto &User : MRI.use_instructions(DstReg))
if (User.getOpcode() == TargetOpcode::G_STORE)
return false;

NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy);
if (NewShiftTy == SrcTy)
return false;

// Make sure we won't lose information by truncating the high bits.
KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() -
DstTy.getScalarSizeInBits()))
return false;
break;
}
}

if (!isLegalOrBeforeLegalizer(
{SrcMI->getOpcode(),
{NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}}))
return false;

MatchInfo = std::make_pair(SrcMI, NewShiftTy);
return true;
}

/// Rewrite trunc(shift x, K) as a shift performed in MatchInfo.second,
/// followed by a trunc if that type is still wider than the destination.
///
/// (The removed-diff residue that rebuilt a G_SHL from undefined DstTy/
/// ShiftSrc/DstReg and erased MI a second time is dropped here.)
void CombinerHelper::applyCombineTruncOfShift(
    MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
  Builder.setInstrAndDebugLoc(MI);

  MachineInstr *ShiftMI = MatchInfo.first;
  LLT NewShiftTy = MatchInfo.second;

  Register Dst = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(Dst);

  Register ShiftAmt = ShiftMI->getOperand(2).getReg();
  Register ShiftSrc = ShiftMI->getOperand(1).getReg();
  // Narrow the shift's source first, then redo the same shift in NewShiftTy.
  ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);

  Register NewShift =
      Builder
          .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt})
          .getReg(0);

  // If the narrowed shift already produces DstTy, the original G_TRUNC is a
  // no-op; otherwise truncate the narrowed result down to DstTy.
  if (NewShiftTy == DstTy)
    replaceRegWith(MRI, Dst, NewShift);
  else
    Builder.buildTrunc(Dst, NewShift);

  eraseInst(MI);
}

bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) {
Expand Down

0 comments on commit 3612d9e

Please sign in to comment.