303 changes: 267 additions & 36 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ namespace {
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
SDValue combineRepeatedFPDivisors(SDNode *N);
SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
Expand All @@ -620,7 +621,10 @@ namespace {
SDValue BuildUDIV(SDNode *N);
SDValue BuildSREMPow2(SDNode *N);
SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
bool KnownNeverZero = false,
bool InexpensiveOnly = false,
std::optional<EVT> OutVT = std::nullopt);
SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
Expand Down Expand Up @@ -4389,12 +4393,12 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {

// fold (mul x, (1 << c)) -> x << c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1) &&
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
SDValue LogBase2 = BuildLogBase2(N1, DL);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
}
}

// fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
Expand Down Expand Up @@ -4916,31 +4920,31 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N->getValueType(0);

// fold (udiv x, (1 << c)) -> x >>u c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1)) {
SDValue LogBase2 = BuildLogBase2(N1, DL);
AddToWorklist(LogBase2.getNode());
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
AddToWorklist(LogBase2.getNode());

EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
}
}

// fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
if (N1.getOpcode() == ISD::SHL) {
SDValue N10 = N1.getOperand(0);
if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N10)) {
SDValue LogBase2 = BuildLogBase2(N10, DL);
AddToWorklist(LogBase2.getNode());

EVT ADDVT = N1.getOperand(1).getValueType();
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
AddToWorklist(Trunc.getNode());
SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
AddToWorklist(LogBase2.getNode());

EVT ADDVT = N1.getOperand(1).getValueType();
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
AddToWorklist(Trunc.getNode());
SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
}
}
}

Expand Down Expand Up @@ -5158,14 +5162,15 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {

// fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
unsigned NumEltBits = VT.getScalarSizeInBits();
SDValue LogBase2 = BuildLogBase2(N1, DL);
SDValue SRLAmt = DAG.getNode(
ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
hasOperation(ISD::SRL, VT)) {
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
unsigned NumEltBits = VT.getScalarSizeInBits();
SDValue SRLAmt = DAG.getNode(
ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
}
}

// If the type twice as wide is legal, transform the mulhu to a wider multiply
Expand Down Expand Up @@ -16328,6 +16333,105 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
return SDValue();
}

// Transform IEEE Floats:
//      (fmul C, (uitofp Pow2))
//          -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
//      (fdiv C, (uitofp Pow2))
//          -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
//
// The rationale is that fmul/fdiv by a power of 2 only changes the exponent,
// so there is no need for more than an add/sub.
//
// This is valid under the following circumstances:
// 1) We are dealing with IEEE floats
// 2) C is normal
// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
//
// Returns the replacement value, or an empty SDValue if any precondition fails,
// the target declines the transform, or no inexpensive Log2(Pow2) is available.
//
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue ConstOp, Pow2Op;

  int Mantissa = -1;
  // Tries to match operand `ConstOpIdx` as the FP constant and the other
  // operand as an int->fp conversion. On success ConstOp/Pow2Op/Mantissa are
  // populated; on failure they may hold stale values and must not be used.
  auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
    // (fdiv (uitofp Pow2), C) is not an exponent adjustment of C.
    if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
      return false;

    ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
    // ConstOp was read through bitcasts, so it may have a different FP element
    // type than the node itself. The mantissa width and exponent bounds below
    // are derived from the constant, while the integer arithmetic is done on
    // lanes as wide as VT's elements, so the two must agree.
    if (ConstOp.getValueType().getScalarType() != VT.getScalarType())
      return false;

    Pow2Op = N->getOperand(1 - ConstOpIdx);
    if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
        (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
         !DAG.computeKnownBits(Pow2Op).isNonNegative()))
      return false;

    Pow2Op = Pow2Op.getOperand(0);

    // TODO(1): We may be able to include undefs.
    // TODO(2): We could also handle non-splat vector types.
    ConstantFPSDNode *CFP =
        isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false);
    if (CFP == nullptr)
      return false;
    const APFloat &APF = CFP->getValueAPF();

    // Make sure we have normal/ieee constant.
    if (!APF.isNormal() || !APF.isIEEE())
      return false;

    // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
    // TODO: We could use knownbits to make this bound more precise.
    int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();

    // Make sure the float's exponent is within the bounds for which this
    // transform produces a bitwise-equal value.
    int CurExp = ilogb(APF);
    // FMul by pow2 will only increase exponent.
    int MinExp =
        N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
    // FDiv by pow2 will only decrease exponent.
    int MaxExp =
        N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
    if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
        MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
      return false;

    // Finally make sure we actually know the mantissa for the float type.
    Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
    return Mantissa > 0;
  };

  if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
    return SDValue();

  if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
    return SDValue();

  // Get log2 after all other checks have taken place. This is because
  // BuildLogBase2 may create a new node.
  SDLoc DL(N);
  // Get Log2 type with same bitwidth as the float type (VT). Build the vector
  // form from the element count rather than a fixed element number so that the
  // query is well-defined for every kind of vector type.
  EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
  if (VT.isVector())
    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
                                VT.getVectorElementCount());

  SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
                               /*InexpensiveOnly*/ true, NewIntVT);
  if (!Log2)
    return SDValue();

  // Perform actual transform.
  SDValue MantissaShiftCnt =
      DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT));
  // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
  // `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could implement that here by handling the casts.
  SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
  SDValue ResAsInt =
      DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
                  NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
  SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
  return ResAsFP;
}

SDValue DAGCombiner::visitFMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down Expand Up @@ -16468,6 +16572,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
return Fused;
}

// Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
// able to run.
if (SDValue R = combineFMulOrFDivWithIntPow2(N))
return R;

return SDValue();
}

Expand Down Expand Up @@ -16819,6 +16928,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
}

if (SDValue R = combineFMulOrFDivWithIntPow2(N))
return R;

return SDValue();
}

Expand Down Expand Up @@ -21861,7 +21973,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (DAG.isKnownNeverZero(Index))
return DAG.getUNDEF(ScalarVT);

// Check if the result type doesn't match the inserted element type.
// Check if the result type doesn't match the inserted element type.
// The inserted element and extracted element may have mismatched bitwidth.
// As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted vector.
SDValue InOp = VecOp.getOperand(0);
Expand Down Expand Up @@ -27142,10 +27254,129 @@ SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
return SDValue();
}

// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
//
// Returns the node that represents `Log2(Op)`. This may create a new node. If
// we are unable to compute `Log2(Op)` it returns `SDValue()`.
//
// All nodes will be created at `DL` and the output will be of type `VT`.
//
// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
// `AssumeNonZero` if this function should simply assume (not require proving)
// that `Op` is non-zero.
//
// `Depth` is the current recursion depth; the search for non-constant inputs
// stops once it reaches `DAG.MaxRecursionDepth`.
static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                                   SDValue Op, unsigned Depth,
                                   bool AssumeNonZero) {
  assert(VT.isInteger() && "Only integer types are supported!");

  // Strips any chain of TRUNCATE / ZERO_EXTEND nodes from the top of V.
  auto PeekThroughCastsAndTrunc = [](SDValue V) {
    while (true) {
      switch (V.getOpcode()) {
      case ISD::TRUNCATE:
      case ISD::ZERO_EXTEND:
        V = V.getOperand(0);
        break;
      default:
        return V;
      }
    }
  };

  if (VT.isScalableVector())
    return SDValue();

  Op = PeekThroughCastsAndTrunc(Op);

  // Helper for determining whether a value is a power-2 constant scalar or a
  // vector of such elements. Every accepted constant is recorded so that its
  // log2 can be materialized below.
  SmallVector<APInt> Pow2Constants;
  auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
    if (C->isZero() || C->isOpaque())
      return false;
    // TODO: We may also be able to support negative powers of 2 here.
    if (C->getAPIntValue().isPowerOf2()) {
      Pow2Constants.emplace_back(C->getAPIntValue());
      return true;
    }
    return false;
  };

  if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
    if (!VT.isVector())
      return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);

    // A SPLAT_VECTOR has a single scalar operand, so only one constant was
    // recorded no matter how many lanes VT has. Re-splat the log2 instead of
    // building a vector from a list with the wrong number of operands.
    if (Op.getOpcode() == ISD::SPLAT_VECTOR)
      return DAG.getSplatBuildVector(
          VT, DL,
          DAG.getConstant(Pow2Constants.back().logBase2(), DL,
                          VT.getScalarType()));

    // We need to create a build vector
    SmallVector<SDValue> Log2Ops;
    for (const APInt &Pow2 : Pow2Constants)
      Log2Ops.emplace_back(
          DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
    return DAG.getBuildVector(VT, DL, Log2Ops);
  }

  if (Depth >= DAG.MaxRecursionDepth)
    return SDValue();

  // Converts ToCast (looking through zext/trunc) to NewVT.
  auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
    // Peek first and only then sample the type: the decisions below must be
    // made against the type of the value that is actually returned/cast.
    ToCast = PeekThroughCastsAndTrunc(ToCast);
    EVT CurVT = ToCast.getValueType();
    if (NewVT == CurVT)
      return ToCast;

    if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
      return DAG.getBitcast(NewVT, ToCast);

    return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
  };

  // log2(X << Y) -> log2(X) + Y
  if (Op.getOpcode() == ISD::SHL) {
    // 1 << Y and X nuw/nsw << Y are all non-zero.
    if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
        Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
      if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
                                             Depth + 1, AssumeNonZero))
        return DAG.getNode(ISD::ADD, DL, VT, LogX,
                           CastToVT(VT, Op.getOperand(1)));
  }

  // c ? X : Y -> c ? Log2(X) : Log2(Y)
  if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
      Op.hasOneUse()) {
    if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
                                           Depth + 1, AssumeNonZero))
      if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
                                             Depth + 1, AssumeNonZero))
        return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
  }

  // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
  // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
  if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
      Op.hasOneUse()) {
    // Use AssumeNonZero as false here. Otherwise we can hit case where
    // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
    if (SDValue LogX =
            takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
                                /*AssumeNonZero*/ false))
      if (SDValue LogY =
              takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
                                  /*AssumeNonZero*/ false))
        return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
  }

  return SDValue();
}

/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
EVT VT = V.getValueType();
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
bool KnownNonZero, bool InexpensiveOnly,
std::optional<EVT> OutVT) {
EVT VT = OutVT ? *OutVT : V.getValueType();
SDValue InexpensiveLogBase2 =
takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
return InexpensiveLogBase2;

SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22445,6 +22445,24 @@ bool X86TargetLowering::isXAndYEqZeroPreferableToXAndYEqY(ISD::CondCode Cond,
return !VT.isVector() || Cond != ISD::CondCode::SETEQ;
}

/// Target hook deciding whether an fmul/fdiv of an FP constant by an
/// int->fp converted power of two should be rewritten as integer
/// shift + add/sub + bitcast.
///
/// \p N is the fmul/fdiv node and \p IntPow2 is the integer power-of-two
/// operand; the FP constant operand is not inspected. FDIV is always accepted.
/// Otherwise the rewrite is rejected only for vector results whose element
/// width differs from the element width of \p IntPow2.
bool X86TargetLowering::optimizeFMulOrFDivAsShiftAddBitcast(
    SDNode *N, SDValue, SDValue IntPow2) const {
  const bool IsDivide = N->getOpcode() == ISD::FDIV;
  if (IsDivide)
    return true;

  const EVT ResultVT = N->getValueType(0);
  // Scalar results are always accepted.
  if (!ResultVT.isVector())
    return true;

  // Mismatched element widths indicate a non-free bitcast.
  // TODO: This is probably overly conservative as we will need to scale the
  // integer vector anyways for the int->fp cast.
  const EVT Pow2VT = IntPow2.getValueType();
  return ResultVT.getScalarSizeInBits() == Pow2VT.getScalarSizeInBits();
}

/// Check if replacement of SQRT with RSQRT should be disabled.
bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1808,6 +1808,9 @@ namespace llvm {
const SDLoc &dl, SelectionDAG &DAG,
SDValue &X86CC) const;

bool optimizeFMulOrFDivAsShiftAddBitcast(SDNode *N, SDValue FPConst,
SDValue IntPow2) const override;

/// Check if replacement of SQRT with RSQRT should be disabled.
bool isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const override;

Expand Down
876 changes: 130 additions & 746 deletions llvm/test/CodeGen/AMDGPU/fold-int-pow2-with-fmul-or-fdiv.ll

Large diffs are not rendered by default.

1,403 changes: 213 additions & 1,190 deletions llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll

Large diffs are not rendered by default.