Expand Up
@@ -611,6 +611,7 @@ namespace {
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
SDValue combineRepeatedFPDivisors(SDNode *N);
SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
Expand All
@@ -620,7 +621,10 @@ namespace {
SDValue BuildUDIV(SDNode *N);
SDValue BuildSREMPow2(SDNode *N);
SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
bool KnownNeverZero = false,
bool InexpensiveOnly = false,
std::optional<EVT> OutVT = std::nullopt);
SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
Expand Down
Expand Up
@@ -4389,12 +4393,12 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
// fold (mul x, (1 << c)) -> x << c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1) &&
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
SDValue LogBase2 = BuildLogBase2(N1, DL);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
}
}
// fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
Expand Down
Expand Up
@@ -4916,31 +4920,31 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N->getValueType(0);
// fold (udiv x, (1 << c)) -> x >>u c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1)) {
SDValue LogBase2 = BuildLogBase2(N1, DL);
AddToWorklist(LogBase2.getNode());
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
AddToWorklist(LogBase2.getNode());
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
}
}
// fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
if (N1.getOpcode() == ISD::SHL) {
SDValue N10 = N1.getOperand(0);
if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo (N10)) {
SDValue LogBase2 = BuildLogBase2(N10, DL );
AddToWorklist(LogBase2.getNode());
EVT ADDVT = N1.getOperand(1).getValueType( );
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT );
AddToWorklist(Trunc .getNode() );
SDValue Add = DAG .getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc );
AddToWorklist(Add .getNode() );
return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
if (SDValue LogBase2 = BuildLogBase2 (N10, DL )) {
AddToWorklist(LogBase2.getNode() );
EVT ADDVT = N1.getOperand(1).getValueType();
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT );
AddToWorklist(Trunc.getNode() );
SDValue Add = DAG .getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc );
AddToWorklist(Add .getNode() );
return DAG .getNode(ISD::SRL, DL, VT, N0, Add );
}
}
}
Expand Down
Expand Up
@@ -5158,14 +5162,15 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
// fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
unsigned NumEltBits = VT.getScalarSizeInBits();
SDValue LogBase2 = BuildLogBase2(N1, DL);
SDValue SRLAmt = DAG.getNode(
ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
hasOperation(ISD::SRL, VT)) {
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
unsigned NumEltBits = VT.getScalarSizeInBits();
SDValue SRLAmt = DAG.getNode(
ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
}
}
// If the type twice as wide is legal, transform the mulhu to a wider multiply
Expand Down
Expand Up
@@ -16328,6 +16333,105 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
return SDValue();
}
// Transform IEEE floats:
//   (fmul C, (uitofp Pow2))
//     -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
//   (fdiv C, (uitofp Pow2))
//     -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
//
// The rationale is that an fmul/fdiv by a power of 2 only changes the
// exponent, so there is no need for more than an add/sub.
//
// This is valid under the following circumstances:
// 1) We are dealing with IEEE floats
// 2) C is normal
// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
//
// Returns the replacement value, or an empty SDValue when any of the checks
// fail, the target hook declines the transform, or no inexpensive Log2 of the
// integer operand can be built.
//
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
  EVT VT = N->getValueType(0);

  // Filled in by GetConstAndPow2Ops when it succeeds.
  SDValue ConstOp, Pow2Op;
  int Mantissa = -1;

  // Try to match operand `ConstOpIdx` as the FP constant and the other operand
  // as an int->fp conversion. On success ConstOp/Pow2Op/Mantissa are set and
  // true is returned. On failure the captured values may have been partially
  // overwritten and must not be used.
  auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
    // For fdiv only the dividend (operand 0) may be the constant.
    if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
      return false;

    ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
    Pow2Op = N->getOperand(1 - ConstOpIdx);
    // Accept uitofp, or sitofp whose value is known to be non-negative.
    if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
        (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
         !DAG.computeKnownBits(Pow2Op).isNonNegative()))
      return false;

    // From here on Pow2Op is the integer source of the conversion.
    Pow2Op = Pow2Op.getOperand(0);

    // TODO(1): We may be able to include undefs.
    // TODO(2): We could also handle non-splat vector types.
    ConstantFPSDNode *CFP =
        isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false);
    if (CFP == nullptr)
      return false;

    const APFloat &APF = CFP->getValueAPF();

    // Make sure we have normal/ieee constant.
    if (!APF.isNormal() || !APF.isIEEE())
      return false;

    // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
    // TODO: We could use knownbits to make this bound more precise.
    int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();

    // Make sure the float's exponent is within the bounds for which this
    // transform produces a bitwise-equal value.
    int CurExp = ilogb(APF);
    // FMul by pow2 will only increase exponent.
    int MinExp =
        N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
    // FDiv by pow2 will only decrease exponent.
    int MaxExp =
        N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
    if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
        MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
      return false;

    // Finally make sure we actually know the mantissa for the float type.
    Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
    return Mantissa > 0;
  };

  // Try the constant as operand 0 first, then as operand 1.
  if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
    return SDValue();

  // Let the target veto the transform.
  if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
    return SDValue();

  // Get log2 after all other checks have taken place. This is because
  // BuildLogBase2 may create a new node.
  SDLoc DL(N);

  // Get Log2 type with same bitwidth as the float type (VT). Use the full
  // element count (not just the element number) so that a scalable FP vector
  // maps to a scalable integer vector rather than a fixed-width one; the
  // latter would not be a valid bitcast partner for VT.
  EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
  if (VT.isVector())
    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
                                VT.getVectorElementCount());

  SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
                               /*InexpensiveOnly*/ true, NewIntVT);
  if (!Log2)
    return SDValue();

  // Perform actual transform.
  SDValue MantissaShiftCnt =
      DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT));
  // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
  // `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could implement that fold by hand here to handle the casts.
  SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
  // fmul adds to the exponent field, fdiv subtracts from it.
  SDValue ResAsInt =
      DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
                  NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
  SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
  return ResAsFP;
}
SDValue DAGCombiner::visitFMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down
Expand Up
@@ -16468,6 +16572,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
return Fused;
}
// Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
// able to run.
if (SDValue R = combineFMulOrFDivWithIntPow2(N))
return R;
return SDValue();
}
Expand Down
Expand Up
@@ -16819,6 +16928,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
}
if (SDValue R = combineFMulOrFDivWithIntPow2(N))
return R;
return SDValue();
}
Expand Down
Expand Up
@@ -21861,7 +21973,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (DAG.isKnownNeverZero(Index))
return DAG.getUNDEF(ScalarVT);
// Check if the result type doesn't match the inserted element type.
// Check if the result type doesn't match the inserted element type.
// The inserted element and extracted element may have mismatched bitwidth.
// As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted vector.
SDValue InOp = VecOp.getOperand(0);
Expand Down
Expand Up
@@ -27142,10 +27254,129 @@ SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
return SDValue();
}
// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
//
// Returns the node that represents `Log2(Op)`. This may create a new node. If
// we are unable to compute `Log2(Op)` it returns `SDValue()`.
//
// All nodes will be created at `DL` and the output will be of type `VT`, which
// must be an integer type. Scalable vector `VT`s are not supported and yield
// `SDValue()`.
//
// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
// `AssumeNonZero` if this function should simply assume (rather than require
// proving) that `Op` is non-zero.
//
// `Depth` is the current recursion depth; recursion into operands stops once
// it reaches `DAG.MaxRecursionDepth`.
static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                                   SDValue Op, unsigned Depth,
                                   bool AssumeNonZero) {
  assert(VT.isInteger() && "Only integer types are supported!");

  // Strip any chain of TRUNCATE / ZERO_EXTEND nodes from the value whose log2
  // is being taken.
  auto PeekThroughCastsAndTrunc = [](SDValue V) {
    while (true) {
      switch (V.getOpcode()) {
      case ISD::TRUNCATE:
      case ISD::ZERO_EXTEND:
        V = V.getOperand(0);
        break;
      default:
        return V;
      }
    }
  };

  if (VT.isScalableVector())
    return SDValue();

  Op = PeekThroughCastsAndTrunc(Op);

  // Helper for determining whether a value is a power-2 constant scalar or a
  // vector of such elements. Every accepted constant is recorded, in visit
  // order, in Pow2Constants.
  SmallVector<APInt> Pow2Constants;
  auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
    if (C->isZero() || C->isOpaque())
      return false;
    // TODO: We may also be able to support negative powers of 2 here.
    if (C->getAPIntValue().isPowerOf2()) {
      Pow2Constants.emplace_back(C->getAPIntValue());
      return true;
    }
    return false;
  };

  if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
    if (!VT.isVector())
      return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
    // We need to create a build vector
    SmallVector<SDValue> Log2Ops;
    for (const APInt &Pow2 : Pow2Constants)
      Log2Ops.emplace_back(
          DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
    return DAG.getBuildVector(VT, DL, Log2Ops);
  }

  if (Depth >= DAG.MaxRecursionDepth)
    return SDValue();

  // Convert a shift amount to `NewVT` so it can be added to a log2 value.
  auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
    // Peek through zero extends only. We must not peek through truncates
    // here: this is called on a shift amount, and looking past a truncate
    // would re-expose high bits that the original shift amount had dropped.
    while (ToCast.getOpcode() == ISD::ZERO_EXTEND)
      ToCast = ToCast.getOperand(0);
    // Query the type only after peeling, so the comparisons below describe
    // the value that is actually returned.
    EVT CurVT = ToCast.getValueType();
    if (NewVT == CurVT)
      return ToCast;
    if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
      return DAG.getBitcast(NewVT, ToCast);
    return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
  };

  // log2(X << Y) -> log2(X) + Y
  if (Op.getOpcode() == ISD::SHL) {
    // 1 << Y and X nuw/nsw << Y are all non-zero.
    if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
        Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
      if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
                                             Depth + 1, AssumeNonZero))
        return DAG.getNode(ISD::ADD, DL, VT, LogX,
                           CastToVT(VT, Op.getOperand(1)));
  }

  // c ? X : Y -> c ? Log2(X) : Log2(Y)
  if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
      Op.hasOneUse()) {
    if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
                                           Depth + 1, AssumeNonZero))
      if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
                                             Depth + 1, AssumeNonZero))
        return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
  }

  // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
  // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
  if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
      Op.hasOneUse()) {
    // Use AssumeNonZero as false here. Otherwise we can hit case where
    // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
    if (SDValue LogX =
            takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
                                /*AssumeNonZero*/ false))
      if (SDValue LogY =
              takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
                                  /*AssumeNonZero*/ false))
        return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
  }

  return SDValue();
}
/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
EVT VT = V.getValueType();
SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
bool KnownNonZero, bool InexpensiveOnly,
std::optional<EVT> OutVT) {
EVT VT = OutVT ? *OutVT : V.getValueType();
SDValue InexpensiveLogBase2 =
takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
return InexpensiveLogBase2;
SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
Expand Down