Skip to content

Commit

Permalink
[SelectionDAG] Emit calls to __divei4 and friends for division/remain…
Browse files Browse the repository at this point in the history
…der of large integers

Emit calls to __divei4 and friends for divison/remainder of large integers.

This fixes #44994.

The overall RFC is in https://discourse.llvm.org/t/rfc-add-support-for-division-of-large-bitint-builtins-selectiondag-globalisel-clang/60329

The compiler-rt part is in https://reviews.llvm.org/D120327

Differential Revision: https://reviews.llvm.org/D120329
  • Loading branch information
mgehre-amd committed Mar 16, 2022
1 parent f3cbe60 commit 09854f2
Show file tree
Hide file tree
Showing 5 changed files with 1,529 additions and 34 deletions.
14 changes: 14 additions & 0 deletions llvm/include/llvm/IR/RuntimeLibcalls.def
Expand Up @@ -47,6 +47,8 @@ HANDLE_LIBCALL(MUL_I16, "__mulhi3")
HANDLE_LIBCALL(MUL_I32, "__mulsi3")
HANDLE_LIBCALL(MUL_I64, "__muldi3")
HANDLE_LIBCALL(MUL_I128, "__multi3")
HANDLE_LIBCALL(MUL_IEXT, nullptr)

HANDLE_LIBCALL(MULO_I32, "__mulosi4")
HANDLE_LIBCALL(MULO_I64, "__mulodi4")
HANDLE_LIBCALL(MULO_I128, "__muloti4")
Expand All @@ -55,31 +57,43 @@ HANDLE_LIBCALL(SDIV_I16, "__divhi3")
HANDLE_LIBCALL(SDIV_I32, "__divsi3")
HANDLE_LIBCALL(SDIV_I64, "__divdi3")
HANDLE_LIBCALL(SDIV_I128, "__divti3")
HANDLE_LIBCALL(SDIV_IEXT, "__divei4")

HANDLE_LIBCALL(UDIV_I8, "__udivqi3")
HANDLE_LIBCALL(UDIV_I16, "__udivhi3")
HANDLE_LIBCALL(UDIV_I32, "__udivsi3")
HANDLE_LIBCALL(UDIV_I64, "__udivdi3")
HANDLE_LIBCALL(UDIV_I128, "__udivti3")
HANDLE_LIBCALL(UDIV_IEXT, "__udivei4")

HANDLE_LIBCALL(SREM_I8, "__modqi3")
HANDLE_LIBCALL(SREM_I16, "__modhi3")
HANDLE_LIBCALL(SREM_I32, "__modsi3")
HANDLE_LIBCALL(SREM_I64, "__moddi3")
HANDLE_LIBCALL(SREM_I128, "__modti3")
HANDLE_LIBCALL(SREM_IEXT, "__modei4")

HANDLE_LIBCALL(UREM_I8, "__umodqi3")
HANDLE_LIBCALL(UREM_I16, "__umodhi3")
HANDLE_LIBCALL(UREM_I32, "__umodsi3")
HANDLE_LIBCALL(UREM_I64, "__umoddi3")
HANDLE_LIBCALL(UREM_I128, "__umodti3")
HANDLE_LIBCALL(UREM_IEXT, "__umodei4")

HANDLE_LIBCALL(SDIVREM_I8, nullptr)
HANDLE_LIBCALL(SDIVREM_I16, nullptr)
HANDLE_LIBCALL(SDIVREM_I32, nullptr)
HANDLE_LIBCALL(SDIVREM_I64, nullptr)
HANDLE_LIBCALL(SDIVREM_I128, nullptr)
HANDLE_LIBCALL(SDIVREM_IEXT, nullptr)

HANDLE_LIBCALL(UDIVREM_I8, nullptr)
HANDLE_LIBCALL(UDIVREM_I16, nullptr)
HANDLE_LIBCALL(UDIVREM_I32, nullptr)
HANDLE_LIBCALL(UDIVREM_I64, nullptr)
HANDLE_LIBCALL(UDIVREM_I128, nullptr)
HANDLE_LIBCALL(UDIVREM_IEXT, nullptr)

HANDLE_LIBCALL(NEG_I32, "__negsi2")
HANDLE_LIBCALL(NEG_I64, "__negdi2")
HANDLE_LIBCALL(CTLZ_I32, "__clzsi2")
Expand Down
67 changes: 33 additions & 34 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Expand Up @@ -141,12 +141,10 @@ class SelectionDAGLegalize {
RTLIB::Libcall Call_F128,
RTLIB::Libcall Call_PPCF128,
SmallVectorImpl<SDValue> &Results);
SDValue ExpandIntLibCall(SDNode *Node, bool isSigned,
RTLIB::Libcall Call_I8,
RTLIB::Libcall Call_I16,
RTLIB::Libcall Call_I32,
RTLIB::Libcall Call_I64,
RTLIB::Libcall Call_I128);
SDValue ExpandIntLibCall(SDNode *Node, bool isSigned, RTLIB::Libcall Call_I8,
RTLIB::Libcall Call_I16, RTLIB::Libcall Call_I32,
RTLIB::Libcall Call_I64, RTLIB::Libcall Call_I128,
RTLIB::Libcall Call_IEXT);
void ExpandArgFPLibCall(SDNode *Node,
RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
Expand Down Expand Up @@ -2105,15 +2103,17 @@ void SelectionDAGLegalize::ExpandFPLibCall(SDNode* Node,
ExpandFPLibCall(Node, LC, Results);
}

SDValue SelectionDAGLegalize::ExpandIntLibCall(SDNode* Node, bool isSigned,
RTLIB::Libcall Call_I8,
RTLIB::Libcall Call_I16,
RTLIB::Libcall Call_I32,
RTLIB::Libcall Call_I64,
RTLIB::Libcall Call_I128) {
SDValue SelectionDAGLegalize::ExpandIntLibCall(
SDNode *Node, bool isSigned, RTLIB::Libcall Call_I8,
RTLIB::Libcall Call_I16, RTLIB::Libcall Call_I32, RTLIB::Libcall Call_I64,
RTLIB::Libcall Call_I128, RTLIB::Libcall Call_IEXT) {
RTLIB::Libcall LC;
switch (Node->getSimpleValueType(0).SimpleTy) {
default: llvm_unreachable("Unexpected request for libcall!");

default:
LC = Call_IEXT;
break;

case MVT::i8: LC = Call_I8; break;
case MVT::i16: LC = Call_I16; break;
case MVT::i32: LC = Call_I32; break;
Expand Down Expand Up @@ -2148,7 +2148,11 @@ SelectionDAGLegalize::ExpandDivRemLibCall(SDNode *Node,

RTLIB::Libcall LC;
switch (Node->getSimpleValueType(0).SimpleTy) {
default: llvm_unreachable("Unexpected request for libcall!");

default:
LC = isSigned ? RTLIB::SDIVREM_IEXT : RTLIB::UDIVREM_IEXT;
break;

case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
Expand Down Expand Up @@ -4319,39 +4323,34 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
RTLIB::SUB_PPCF128, Results);
break;
case ISD::SREM:
Results.push_back(ExpandIntLibCall(Node, true,
RTLIB::SREM_I8,
RTLIB::SREM_I16, RTLIB::SREM_I32,
RTLIB::SREM_I64, RTLIB::SREM_I128));
Results.push_back(ExpandIntLibCall(
Node, true, RTLIB::SREM_I8, RTLIB::SREM_I16, RTLIB::SREM_I32,
RTLIB::SREM_I64, RTLIB::SREM_I128, RTLIB::SREM_IEXT));
break;
case ISD::UREM:
Results.push_back(ExpandIntLibCall(Node, false,
RTLIB::UREM_I8,
RTLIB::UREM_I16, RTLIB::UREM_I32,
RTLIB::UREM_I64, RTLIB::UREM_I128));
Results.push_back(ExpandIntLibCall(
Node, false, RTLIB::UREM_I8, RTLIB::UREM_I16, RTLIB::UREM_I32,
RTLIB::UREM_I64, RTLIB::UREM_I128, RTLIB::UREM_IEXT));
break;
case ISD::SDIV:
Results.push_back(ExpandIntLibCall(Node, true,
RTLIB::SDIV_I8,
RTLIB::SDIV_I16, RTLIB::SDIV_I32,
RTLIB::SDIV_I64, RTLIB::SDIV_I128));
Results.push_back(ExpandIntLibCall(
Node, true, RTLIB::SDIV_I8, RTLIB::SDIV_I16, RTLIB::SDIV_I32,
RTLIB::SDIV_I64, RTLIB::SDIV_I128, RTLIB::SDIV_IEXT));
break;
case ISD::UDIV:
Results.push_back(ExpandIntLibCall(Node, false,
RTLIB::UDIV_I8,
RTLIB::UDIV_I16, RTLIB::UDIV_I32,
RTLIB::UDIV_I64, RTLIB::UDIV_I128));
Results.push_back(ExpandIntLibCall(
Node, false, RTLIB::UDIV_I8, RTLIB::UDIV_I16, RTLIB::UDIV_I32,
RTLIB::UDIV_I64, RTLIB::UDIV_I128, RTLIB::UDIV_IEXT));
break;
case ISD::SDIVREM:
case ISD::UDIVREM:
// Expand into divrem libcall
ExpandDivRemLibCall(Node, Results);
break;
case ISD::MUL:
Results.push_back(ExpandIntLibCall(Node, false,
RTLIB::MUL_I8,
RTLIB::MUL_I16, RTLIB::MUL_I32,
RTLIB::MUL_I64, RTLIB::MUL_I128));
Results.push_back(ExpandIntLibCall(
Node, false, RTLIB::MUL_I8, RTLIB::MUL_I16, RTLIB::MUL_I32,
RTLIB::MUL_I64, RTLIB::MUL_I128, RTLIB::MUL_IEXT));
break;
case ISD::CTLZ_ZERO_UNDEF:
switch (Node->getSimpleValueType(0).SimpleTy) {
Expand Down
96 changes: 96 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Expand Up @@ -3914,6 +3914,70 @@ void DAGTypeLegalizer::ExpandIntRes_SADDSUBO(SDNode *Node,
ReplaceValueWith(SDValue(Node, 1), Ovf);
}

// Emit a call to __udivei4 and friends which require
// the arguments be based on the stack
// and extra argument that contains the number of bits of the operands.
// Returns the result of the call operation.
static SDValue ExpandExtIntRes_DIVREM(const TargetLowering &TLI,
const RTLIB::Libcall &LC,
SelectionDAG &DAG, SDNode *N,
const SDLoc &DL, const EVT &VT) {

SDValue InChain = DAG.getEntryNode();

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;

// The signature of __udivei4 is
// void __udivei4(unsigned int *quo, unsigned int *a, unsigned int *b,
// unsigned int bits)
EVT ArgVT = N->op_begin()->getValueType();
assert(ArgVT.isInteger() && ArgVT.getSizeInBits() > 128 &&
"Unexpected argument type for lowering");
Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());

SDValue Output = DAG.CreateStackTemporary(ArgVT);
Entry.Node = Output;
Entry.Ty = ArgTy->getPointerTo();
Entry.IsSExt = false;
Entry.IsZExt = false;
Args.push_back(Entry);

for (const llvm::SDUse &Op : N->ops()) {
SDValue StackPtr = DAG.CreateStackTemporary(ArgVT);
InChain = DAG.getStore(InChain, DL, Op, StackPtr, MachinePointerInfo());
Entry.Node = StackPtr;
Entry.Ty = ArgTy->getPointerTo();
Entry.IsSExt = false;
Entry.IsZExt = false;
Args.push_back(Entry);
}

int Bits = N->getOperand(0)
.getValueType()
.getTypeForEVT(*DAG.getContext())
->getIntegerBitWidth();
Entry.Node = DAG.getConstant(Bits, DL, TLI.getPointerTy(DAG.getDataLayout()));
Entry.Ty = Type::getInt32Ty(*DAG.getContext());
Entry.IsSExt = false;
Entry.IsZExt = true;
Args.push_back(Entry);

SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
TLI.getPointerTy(DAG.getDataLayout()));

TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL)
.setChain(InChain)
.setLibCallee(TLI.getLibcallCallingConv(LC),
Type::getVoidTy(*DAG.getContext()), Callee, std::move(Args))
.setDiscardResult();

SDValue Chain = TLI.LowerCallTo(CLI).second;

return DAG.getLoad(ArgVT, DL, Chain, Output, MachinePointerInfo());
}

void DAGTypeLegalizer::ExpandIntRes_SDIV(SDNode *N,
SDValue &Lo, SDValue &Hi) {
EVT VT = N->getValueType(0);
Expand All @@ -3935,6 +3999,14 @@ void DAGTypeLegalizer::ExpandIntRes_SDIV(SDNode *N,
LC = RTLIB::SDIV_I64;
else if (VT == MVT::i128)
LC = RTLIB::SDIV_I128;

else {
SDValue Result =
ExpandExtIntRes_DIVREM(TLI, RTLIB::SDIV_IEXT, DAG, N, dl, VT);
SplitInteger(Result, Lo, Hi);
return;
}

assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported SDIV!");

TargetLowering::MakeLibCallOptions CallOptions;
Expand Down Expand Up @@ -4126,6 +4198,14 @@ void DAGTypeLegalizer::ExpandIntRes_SREM(SDNode *N,
LC = RTLIB::SREM_I64;
else if (VT == MVT::i128)
LC = RTLIB::SREM_I128;

else {
SDValue Result =
ExpandExtIntRes_DIVREM(TLI, RTLIB::SREM_IEXT, DAG, N, dl, VT);
SplitInteger(Result, Lo, Hi);
return;
}

assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported SREM!");

TargetLowering::MakeLibCallOptions CallOptions;
Expand Down Expand Up @@ -4301,6 +4381,14 @@ void DAGTypeLegalizer::ExpandIntRes_UDIV(SDNode *N,
LC = RTLIB::UDIV_I64;
else if (VT == MVT::i128)
LC = RTLIB::UDIV_I128;

else {
SDValue Result =
ExpandExtIntRes_DIVREM(TLI, RTLIB::UDIV_IEXT, DAG, N, dl, VT);
SplitInteger(Result, Lo, Hi);
return;
}

assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported UDIV!");

TargetLowering::MakeLibCallOptions CallOptions;
Expand Down Expand Up @@ -4328,6 +4416,14 @@ void DAGTypeLegalizer::ExpandIntRes_UREM(SDNode *N,
LC = RTLIB::UREM_I64;
else if (VT == MVT::i128)
LC = RTLIB::UREM_I128;

else {
SDValue Result =
ExpandExtIntRes_DIVREM(TLI, RTLIB::UREM_IEXT, DAG, N, dl, VT);
SplitInteger(Result, Lo, Hi);
return;
}

assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported UREM!");

TargetLowering::MakeLibCallOptions CallOptions;
Expand Down

5 comments on commit 09854f2

@bridiver
Copy link

@bridiver bridiver commented on 09854f2 Jun 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this re-enable support for BigInt > 128? a6cabd9 @mgehre-amd ?

@mgehre-amd
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BitInt > 128 bit to/from float conversions still crash the backend, so we cannot re-enable the support yet.

@bridiver
Copy link

@bridiver bridiver commented on 09854f2 Jun 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgehre-amd would it be possible to add a runtime flag to support >128bit? It's very important for anyone doing crypto work (like us) and the alternatives are not great. The removal of 256bit support forced us to scramble for an alternative and having a runtime flag would have made things much easier for us.

@mgehre-amd
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I opened https://reviews.llvm.org/D127287 to add a flag for enabling larger bitints

@bridiver
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much @mgehre-amd!

Please sign in to comment.