AMDGPU/SDAG: Factor out the fold (add (mul x, y), y) --> mad_[iu]64_[iu]32

Refactor to simplify a follow-up change.

No functional change intended. However, there is one rather subtle logic
change: the subsequent combines (e.g. reassociation) are now always skipped
when one of the add's operands is a mul, instead of only when mad64_32 etc.
are additionally available. This is fine because those subsequent combines
should never apply when one of the operands is a mul. (A short standalone
sketch of the arithmetic behind the fold follows the changed-files summary
below.)

Differential Revision: https://reviews.llvm.org/D123833
nhaehnle committed May 2, 2022
1 parent 4f5d525 commit deaa678
Showing 2 changed files with 49 additions and 23 deletions.
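
Before the diff, a note on what the fold produces: mad_[iu]64_[iu]32 is a
multiply-add that takes two 32-bit multiplicands and a 64-bit addend, which is
why tryFoldToMad64_32 only fires when numBitsUnsigned/numBitsSigned reports
that both multiplicands fit in 32 bits. The following standalone C++ sketch of
the unsigned case is illustrative only; it is not part of the patch, and the
helper name mad_u64_u32_ref is invented here:

#include <cassert>
#include <cstdint>

// Reference model of the unsigned variant: a 32x32->64-bit multiply whose
// result is added to a 64-bit operand.
static uint64_t mad_u64_u32_ref(uint32_t x, uint32_t y, uint64_t z) {
  return static_cast<uint64_t>(x) * static_cast<uint64_t>(y) + z;
}

int main() {
  // When x and y are known to fit in 32 bits, the generic 64-bit
  // (add (mul x, y), z) computes the same value as the MAD form.
  uint64_t x = 0xFFFFFFFFull;          // fits in 32 bits
  uint64_t y = 0x12345678ull;          // fits in 32 bits
  uint64_t z = 0x0123456789ABCDEFull;  // arbitrary 64-bit addend
  uint64_t generic = x * y + z;        // 64-bit mul followed by 64-bit add
  assert(generic == mad_u64_u32_ref(static_cast<uint32_t>(x),
                                    static_cast<uint32_t>(y), z));
  return 0;
}

The signed variant is analogous, sign-extending the operands (getSExtOrTrunc)
instead of zero-extending them.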
llvm/lib/Target/AMDGPU/SIISelLowering.cpp (48 additions, 23 deletions)
@@ -10661,39 +10661,64 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL,
   return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
 }
 
-SDValue SITargetLowering::performAddCombine(SDNode *N,
+// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z).
+SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
+  assert(N->getOpcode() == ISD::ADD);
+
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
-  if ((LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL)
-      && Subtarget->hasMad64_32() &&
-      !VT.isVector() && VT.getScalarSizeInBits() > 32 &&
-      VT.getScalarSizeInBits() <= 64) {
-    if (LHS.getOpcode() != ISD::MUL)
-      std::swap(LHS, RHS);
+  if (VT.isVector())
+    return SDValue();
 
-    SDValue MulLHS = LHS.getOperand(0);
-    SDValue MulRHS = LHS.getOperand(1);
-    SDValue AddRHS = RHS;
+  unsigned NumBits = VT.getScalarSizeInBits();
+  if (NumBits <= 32 || NumBits > 64)
+    return SDValue();
 
-    // TODO: Maybe restrict if SGPR inputs.
-    if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
-        numBitsUnsigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
-    }
+  if (LHS.getOpcode() != ISD::MUL) {
+    assert(RHS.getOpcode() == ISD::MUL);
+    std::swap(LHS, RHS);
+  }
 
+  SDValue MulLHS = LHS.getOperand(0);
+  SDValue MulRHS = LHS.getOperand(1);
+  SDValue AddRHS = RHS;
+
+  // TODO: Maybe restrict if SGPR inputs.
+  if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
+      numBitsUnsigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
+  }
+
+  if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  }
+
+  return SDValue();
+}
+
+SDValue SITargetLowering::performAddCombine(SDNode *N,
+                                            DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
-    if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  if (LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL) {
+    if (Subtarget->hasMad64_32()) {
+      if (SDValue Folded = tryFoldToMad64_32(N, DCI))
+        return Folded;
     }
 
     return SDValue();
llvm/lib/Target/AMDGPU/SIISelLowering.h (1 addition, 0 deletions)
@@ -197,6 +197,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
   SDValue reassociateScalarOps(SDNode *N, SelectionDAG &DAG) const;
   unsigned getFusedOpcode(const SelectionDAG &DAG,
                           const SDNode *N0, const SDNode *N1) const;
+  SDValue tryFoldToMad64_32(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCarrySubCarryCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performSubCombine(SDNode *N, DAGCombinerInfo &DCI) const;
