Skip to content

Commit

Permalink
[LegalizeTypes] Improve splitting for urem/udiv by constant for some …
Browse files Browse the repository at this point in the history
…constants.

For remainder:
If (1 << (Bitwidth / 2)) % Divisor == 1, we can add the high and low halves
together and use a (Bitwidth / 2) urem. If (BitWidth / 2) is a legal integer
type, this urem will be expanded by DAGCombiner using multiply by magic
constant. We do have to take into account that adding high and low
together can produce a carry, making it a (BitWidth / 2)+1 bit number.
So we need to also add back in the carry from the first addition.

For division:
We can use the above trick to compute the remainder, subtract that
remainder from the dividend, then multiply by the multiplicative
inverse of the Divisor modulo (1 << BitWidth).

This is based on the section "Remainder by Summing Digits" in
Hacker's Delight.

The remainder trick is similar to a trick you may have learned for
determining if a decimal number is divisible by 3. You can add all the
digits together and see if the sum is divisible by 3. If you're not sure
if the sum is divisible by 3, you can add its digits together. This
can be repeated until you have a single decimal digit. If that digit
is 3, 6, or 9, then the original number is divisible by 3. This works
because 10 % 3 == 1.

gcc already does this same trick. There are additional tricks gcc
does for urem as well as srem, udiv, and sdiv that I plan to add in
future patches.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D130862
  • Loading branch information
topperc committed Sep 12, 2022
1 parent 7eead18 commit 38ffa2b
Show file tree
Hide file tree
Showing 13 changed files with 1,713 additions and 648 deletions.
20 changes: 20 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -4713,6 +4713,26 @@ class TargetLowering : public TargetLoweringBase {
SDValue LL = SDValue(), SDValue LH = SDValue(),
SDValue RL = SDValue(), SDValue RH = SDValue()) const;

/// Attempt to expand an n-bit div/rem/divrem by constant using a n/2-bit
/// urem by constant and other arithmetic ops. The n/2-bit urem by constant
/// will be expanded by DAGCombiner. This is not possible for all constant
/// divisors. Only the unsigned operations (UDIV, UREM, UDIVREM) are
/// currently expanded; the signed variants always return false.
/// \param N Node to expand
/// \param Result A vector that will be filled with the lo and high parts of
/// the results. For *DIVREM, this will be the quotient parts followed
/// by the remainder parts. The parts are appended to the vector, and
/// nothing is appended when the expansion fails.
/// \param HiLoVT The value type to use for the Lo and Hi parts. Should be
/// half of VT.
/// \param DAG The SelectionDAG in which the replacement nodes are created.
/// \param LL Low bits of the LHS of the operation. You can use this
/// parameter if you want to control how low bits are extracted from
/// the LHS. Must be supplied together with LH, or not at all.
/// \param LH High bits of the LHS of the operation. See LL for meaning.
/// \returns true if the node has been expanded, false if it has not.
bool expandDIVREMByConstant(SDNode *N, SmallVectorImpl<SDValue> &Result,
EVT HiLoVT, SelectionDAG &DAG,
SDValue LL = SDValue(),
SDValue LH = SDValue()) const;

/// Expand funnel shift.
/// \param N Node to expand
/// \returns The expansion if successful, SDValue() otherwise
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4427,6 +4427,22 @@ void DAGTypeLegalizer::ExpandIntRes_UDIV(SDNode *N,
return;
}

// Try to expand UDIV by constant.
if (isa<ConstantSDNode>(N->getOperand(1))) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
// Only if the new type is legal.
if (isTypeLegal(NVT)) {
SDValue InL, InH;
GetExpandedInteger(N->getOperand(0), InL, InH);
SmallVector<SDValue> Result;
if (TLI.expandDIVREMByConstant(N, Result, NVT, DAG, InL, InH)) {
Lo = Result[0];
Hi = Result[1];
return;
}
}
}

RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
if (VT == MVT::i16)
LC = RTLIB::UDIV_I16;
Expand Down Expand Up @@ -4454,6 +4470,22 @@ void DAGTypeLegalizer::ExpandIntRes_UREM(SDNode *N,
return;
}

// Try to expand UREM by constant.
if (isa<ConstantSDNode>(N->getOperand(1))) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
// Only if the new type is legal.
if (isTypeLegal(NVT)) {
SDValue InL, InH;
GetExpandedInteger(N->getOperand(0), InL, InH);
SmallVector<SDValue> Result;
if (TLI.expandDIVREMByConstant(N, Result, NVT, DAG, InL, InH)) {
Lo = Result[0];
Hi = Result[1];
return;
}
}
}

RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
if (VT == MVT::i16)
LC = RTLIB::UREM_I16;
Expand Down
146 changes: 146 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7150,6 +7150,152 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
return Ok;
}

// Optimize unsigned division or remainder by constants for types twice as large
// as a legal VT.
//
// If (1 << (BitWidth / 2)) % Constant == 1, then the remainder can be computed
// as:
//   Sum += __builtin_uadd_overflow(Lo, High, &Sum);
//   Remainder = Sum % Constant
// This is based on "Remainder by Summing Digits" from Hacker's Delight.
//
// For division, we can compute the remainder, subtract it from the dividend,
// and then multiply by the multiplicative inverse of the divisor modulo
// (1 << BitWidth).
//
// On success the lo/hi halves of the result(s) are appended to Result (for
// UDIVREM: quotient halves first, then remainder halves) and true is returned.
// On failure Result is left untouched and false is returned.
//
// LL/LH optionally supply the already-split low/high halves of the dividend;
// either both or neither must be provided.
bool TargetLowering::expandDIVREMByConstant(SDNode *N,
                                            SmallVectorImpl<SDValue> &Result,
                                            EVT HiLoVT, SelectionDAG &DAG,
                                            SDValue LL, SDValue LH) const {
  unsigned Opcode = N->getOpcode();
  EVT VT = N->getValueType(0);

  // TODO: Support signed division/remainder.
  if (Opcode == ISD::SREM || Opcode == ISD::SDIV || Opcode == ISD::SDIVREM)
    return false;
  assert(
      (Opcode == ISD::UREM || Opcode == ISD::UDIV || Opcode == ISD::UDIVREM) &&
      "Unexpected opcode");

  auto *CN = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!CN)
    return false;

  const APInt &Divisor = CN->getAPIntValue();
  unsigned BitWidth = Divisor.getBitWidth();
  unsigned HBitWidth = BitWidth / 2;
  assert(VT.getScalarSizeInBits() == BitWidth &&
         HiLoVT.getScalarSizeInBits() == HBitWidth && "Unexpected VTs");

  // Divisor needs to be less than (1 << HBitWidth) so that it can be truncated
  // to HiLoVT for the half-width urem below.
  APInt HalfMaxPlus1 = APInt::getOneBitSet(BitWidth, HBitWidth);
  if (Divisor.uge(HalfMaxPlus1))
    return false;

  // We depend on the UREM by constant optimization in DAGCombiner that requires
  // high multiply.
  if (!isOperationLegalOrCustom(ISD::MULHU, HiLoVT) &&
      !isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT))
    return false;

  // Don't expand if optimizing for size.
  if (DAG.shouldOptForSize())
    return false;

  // Early out for 0, 1 or even divisors. Even divisors have no multiplicative
  // inverse modulo a power of two, which the quotient computation relies on.
  if (Divisor.ule(1) || Divisor[0] == 0)
    return false;

  SDLoc dl(N);
  SDValue Sum;

  // If (1 << HBitWidth) % divisor == 1, we can add the two halves together and
  // then add in the carry.
  // TODO: If we can't split it in half, we might be able to split into 3 or
  // more pieces using a smaller bit width.
  if (HalfMaxPlus1.urem(Divisor).isOne()) {
    assert(!LL == !LH && "Expected both input halves or no input halves!");
    if (!LL) {
      // The caller did not provide the halves; extract them from the dividend.
      LL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, N->getOperand(0),
                       DAG.getIntPtrConstant(0, dl));
      LH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, N->getOperand(0),
                       DAG.getIntPtrConstant(1, dl));
    }

    // Use addcarry if we can, otherwise use a compare to detect overflow.
    EVT SetCCType =
        getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), HiLoVT);
    if (isOperationLegalOrCustom(ISD::ADDCARRY, HiLoVT)) {
      SDVTList VTList = DAG.getVTList(HiLoVT, SetCCType);
      // Sum = LL + LH, then fold the carry-out back in as Sum + 0 + carry.
      Sum = DAG.getNode(ISD::UADDO, dl, VTList, LL, LH);
      Sum = DAG.getNode(ISD::ADDCARRY, dl, VTList, Sum,
                        DAG.getConstant(0, dl, HiLoVT), Sum.getValue(1));
    } else {
      Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, LL, LH);
      // The addition wrapped iff the result is (unsigned) less than an operand.
      SDValue Carry = DAG.getSetCC(dl, SetCCType, Sum, LL, ISD::SETULT);
      // If the boolean for the target is 0 or 1, we can add the setcc result
      // directly.
      if (getBooleanContents(HiLoVT) ==
          TargetLoweringBase::ZeroOrOneBooleanContent)
        Carry = DAG.getZExtOrTrunc(Carry, dl, HiLoVT);
      else
        Carry = DAG.getSelect(dl, HiLoVT, Carry, DAG.getConstant(1, dl, HiLoVT),
                              DAG.getConstant(0, dl, HiLoVT));
      Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, Sum, Carry);
    }
  }

  // If we didn't find a sum, we can't do the expansion.
  if (!Sum)
    return false;

  // Perform a HiLoVT urem on the Sum using truncated divisor.
  SDValue RemL =
      DAG.getNode(ISD::UREM, dl, HiLoVT, Sum,
                  DAG.getConstant(Divisor.trunc(HBitWidth), dl, HiLoVT));
  // High half of the remainder is 0.
  SDValue RemH = DAG.getConstant(0, dl, HiLoVT);

  // If we only want remainder, we're done.
  if (Opcode == ISD::UREM) {
    Result.push_back(RemL);
    Result.push_back(RemH);
    return true;
  }

  // Otherwise, we need to compute the quotient.

  // Join the remainder halves.
  SDValue Rem = DAG.getNode(ISD::BUILD_PAIR, dl, VT, RemL, RemH);

  // Subtract the remainder from the input. The difference is an exact multiple
  // of the divisor.
  SDValue In = DAG.getNode(ISD::SUB, dl, VT, N->getOperand(0), Rem);

  // Multiply by the multiplicative inverse of the divisor modulo
  // (1 << BitWidth). The modulus itself needs BitWidth + 1 bits to be
  // representable, so compute the inverse at that width and truncate.
  APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
  APInt MulFactor = Divisor.zext(BitWidth + 1);
  MulFactor = MulFactor.multiplicativeInverse(Mod);
  MulFactor = MulFactor.trunc(BitWidth);

  SDValue Quotient =
      DAG.getNode(ISD::MUL, dl, VT, In, DAG.getConstant(MulFactor, dl, VT));

  // Split the quotient into low and high parts.
  SDValue QuotL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
                              DAG.getIntPtrConstant(0, dl));
  SDValue QuotH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
                              DAG.getIntPtrConstant(1, dl));
  Result.push_back(QuotL);
  Result.push_back(QuotH);
  // For DIVREM, also return the remainder parts.
  if (Opcode == ISD::UDIVREM) {
    Result.push_back(RemL);
    Result.push_back(RemH);
  }

  return true;
}

// Check that (every element of) Z is undef or not an exact multiple of BW.
static bool isNonZeroModBitWidthOrUndef(SDValue Z, unsigned BW) {
return ISD::matchUnaryPredicate(
Expand Down
26 changes: 24 additions & 2 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20417,9 +20417,22 @@ SDValue ARMTargetLowering::LowerDivRem(SDValue Op, SelectionDAG &DAG) const {
"Invalid opcode for Div/Rem lowering");
bool isSigned = (Opcode == ISD::SDIVREM);
EVT VT = Op->getValueType(0);
Type *Ty = VT.getTypeForEVT(*DAG.getContext());
SDLoc dl(Op);

if (VT == MVT::i64 && isa<ConstantSDNode>(Op.getOperand(1))) {
SmallVector<SDValue> Result;
if (expandDIVREMByConstant(Op.getNode(), Result, MVT::i32, DAG)) {
SDValue Res0 =
DAG.getNode(ISD::BUILD_PAIR, dl, VT, Result[0], Result[1]);
SDValue Res1 =
DAG.getNode(ISD::BUILD_PAIR, dl, VT, Result[2], Result[3]);
return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(),
{Res0, Res1});
}
}

Type *Ty = VT.getTypeForEVT(*DAG.getContext());

// If the target has hardware divide, use divide + multiply + subtract:
// div = a / b
// rem = a - b * div
Expand Down Expand Up @@ -20468,11 +20481,20 @@ SDValue ARMTargetLowering::LowerDivRem(SDValue Op, SelectionDAG &DAG) const {
// Lowers REM using divmod helpers
// see RTABI section 4.2/4.3
SDValue ARMTargetLowering::LowerREM(SDNode *N, SelectionDAG &DAG) const {
EVT VT = N->getValueType(0);

if (VT == MVT::i64 && isa<ConstantSDNode>(N->getOperand(1))) {
SmallVector<SDValue> Result;
if (expandDIVREMByConstant(N, Result, MVT::i32, DAG))
return DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), N->getValueType(0),
Result[0], Result[1]);
}

// Build return types (div and rem)
std::vector<Type*> RetTyParams;
Type *RetTyElement;

switch (N->getValueType(0).getSimpleVT().SimpleTy) {
switch (VT.getSimpleVT().SimpleTy) {
default: llvm_unreachable("Unexpected request for libcall!");
case MVT::i8: RetTyElement = Type::getInt8Ty(*DAG.getContext()); break;
case MVT::i16: RetTyElement = Type::getInt16Ty(*DAG.getContext()); break;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29528,6 +29528,12 @@ SDValue X86TargetLowering::LowerWin64_i128OP(SDValue Op, SelectionDAG &DAG) cons
assert(VT.isInteger() && VT.getSizeInBits() == 128 &&
"Unexpected return type for lowering");

if (isa<ConstantSDNode>(Op->getOperand(1))) {
SmallVector<SDValue> Result;
if (expandDIVREMByConstant(Op.getNode(), Result, MVT::i64, DAG))
return DAG.getNode(ISD::BUILD_PAIR, SDLoc(Op), VT, Result[0], Result[1]);
}

RTLIB::Libcall LC;
bool isSigned;
switch (Op->getOpcode()) {
Expand Down
30 changes: 30 additions & 0 deletions llvm/test/CodeGen/ARM/div.ll
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,33 @@ entry:
%tmp1 = urem i64 %a, %b ; <i64> [#uses=1]
ret i64 %tmp1
}

; Make sure we avoid a libcall for some constants.
define i64 @f7(i64 %a) {
; CHECK-SWDIV-LABEL: f7
; CHECK-SWDIV: adc
; CHECK-SWDIV: umull
; CHECK-HWDIV-LABEL: f7
; CHECK-HWDIV: adc
; CHECK-HWDIV: umull
; CHECK-EABI-LABEL: f7
; CHECK-EABI: adc
; CHECK-EABI: umull
%tmp1 = urem i64 %a, 3
ret i64 %tmp1
}

; Make sure we avoid a libcall for some constants.
define i64 @f8(i64 %a) {
; CHECK-SWDIV-LABEL: f8
; CHECK-SWDIV: adc
; CHECK-SWDIV: umull
; CHECK-HWDIV-LABEL: f8
; CHECK-HWDIV: adc
; CHECK-HWDIV: umull
; CHECK-EABI-LABEL: f8
; CHECK-EABI: adc
; CHECK-EABI: umull
%tmp1 = udiv i64 %a, 3
ret i64 %tmp1
}
27 changes: 20 additions & 7 deletions llvm/test/CodeGen/RISCV/div-by-constant.ll
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,26 @@ define i32 @udiv_constant_add(i32 %a) nounwind {
define i64 @udiv64_constant_no_add(i64 %a) nounwind {
; RV32-LABEL: udiv64_constant_no_add:
; RV32: # %bb.0:
; RV32-NEXT: addi sp, sp, -16
; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
; RV32-NEXT: li a2, 5
; RV32-NEXT: li a3, 0
; RV32-NEXT: call __udivdi3@plt
; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
; RV32-NEXT: addi sp, sp, 16
; RV32-NEXT: add a2, a0, a1
; RV32-NEXT: sltu a3, a2, a0
; RV32-NEXT: add a2, a2, a3
; RV32-NEXT: lui a3, 838861
; RV32-NEXT: addi a4, a3, -819
; RV32-NEXT: mulhu a5, a2, a4
; RV32-NEXT: srli a6, a5, 2
; RV32-NEXT: andi a5, a5, -4
; RV32-NEXT: add a5, a5, a6
; RV32-NEXT: sub a2, a2, a5
; RV32-NEXT: sub a5, a0, a2
; RV32-NEXT: addi a3, a3, -820
; RV32-NEXT: mul a3, a5, a3
; RV32-NEXT: mulhu a6, a5, a4
; RV32-NEXT: add a3, a6, a3
; RV32-NEXT: sltu a0, a0, a2
; RV32-NEXT: sub a0, a1, a0
; RV32-NEXT: mul a0, a0, a4
; RV32-NEXT: add a1, a3, a0
; RV32-NEXT: mul a0, a5, a4
; RV32-NEXT: ret
;
; RV64-LABEL: udiv64_constant_no_add:
Expand Down
27 changes: 20 additions & 7 deletions llvm/test/CodeGen/RISCV/div.ll
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,26 @@ define i64 @udiv64_constant(i64 %a) nounwind {
;
; RV32IM-LABEL: udiv64_constant:
; RV32IM: # %bb.0:
; RV32IM-NEXT: addi sp, sp, -16
; RV32IM-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
; RV32IM-NEXT: li a2, 5
; RV32IM-NEXT: li a3, 0
; RV32IM-NEXT: call __udivdi3@plt
; RV32IM-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
; RV32IM-NEXT: addi sp, sp, 16
; RV32IM-NEXT: add a2, a0, a1
; RV32IM-NEXT: sltu a3, a2, a0
; RV32IM-NEXT: add a2, a2, a3
; RV32IM-NEXT: lui a3, 838861
; RV32IM-NEXT: addi a4, a3, -819
; RV32IM-NEXT: mulhu a5, a2, a4
; RV32IM-NEXT: srli a6, a5, 2
; RV32IM-NEXT: andi a5, a5, -4
; RV32IM-NEXT: add a5, a5, a6
; RV32IM-NEXT: sub a2, a2, a5
; RV32IM-NEXT: sub a5, a0, a2
; RV32IM-NEXT: addi a3, a3, -820
; RV32IM-NEXT: mul a3, a5, a3
; RV32IM-NEXT: mulhu a6, a5, a4
; RV32IM-NEXT: add a3, a6, a3
; RV32IM-NEXT: sltu a0, a0, a2
; RV32IM-NEXT: sub a0, a1, a0
; RV32IM-NEXT: mul a0, a0, a4
; RV32IM-NEXT: add a1, a3, a0
; RV32IM-NEXT: mul a0, a5, a4
; RV32IM-NEXT: ret
;
; RV64I-LABEL: udiv64_constant:
Expand Down
Loading

0 comments on commit 38ffa2b

Please sign in to comment.