Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAG: Fix chain mismanagement in SoftenFloatRes_FP_EXTEND #74406

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
SDTIntToFPOp, [SDNPHasChain]>;
def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
SDTIntToFPOp, [SDNPHasChain]>;

def strict_f16_to_fp : SDNode<"ISD::STRICT_FP16_TO_FP",
SDTIntToFPOp, [SDNPHasChain]>;
def strict_fp_to_f16 : SDNode<"ISD::STRICT_FP_TO_FP16",
SDTFPToIntOp, [SDNPHasChain]>;

def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;

Expand Down Expand Up @@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
[(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
(setcc node:$lhs, node:$rhs, node:$pred)]>;

def any_f16_to_fp : PatFrags<(ops node:$src),
[(f16_to_fp node:$src),
(strict_f16_to_fp node:$src)]>;
def any_fp_to_f16 : PatFrags<(ops node:$src),
[(fp_to_f16 node:$src),
(strict_fp_to_f16 node:$src)]>;

multiclass binary_atomic_op_ord {
def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
(!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
Expand Down
94 changes: 92 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,11 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
Op = GetPromotedFloat(Op);
// If the promotion did the FP_EXTEND to the destination type for us,
// there's nothing left to do here.
if (Op.getValueType() == N->getValueType(0))
if (Op.getValueType() == N->getValueType(0)) {
if (IsStrict)
ReplaceValueWith(SDValue(N, 1), Chain);
return BitConvertToInteger(Op);
}
}

// There's only a libcall for f16 -> f32 and shifting is only valid for bf16
Expand All @@ -541,8 +544,10 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
}
}

if (Op.getValueType() == MVT::bf16)
if (Op.getValueType() == MVT::bf16) {
// FIXME: Need ReplaceValueWith on chain in strict case
return SoftenFloatRes_BF16_TO_FP(N);
}

RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
Expand Down Expand Up @@ -2181,6 +2186,24 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
report_fatal_error("Attempt at an invalid promotion-related conversion");
}

static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
if (OpVT == MVT::f16)
return ISD::STRICT_FP16_TO_FP;

if (RetVT == MVT::f16)
return ISD::STRICT_FP_TO_FP16;

if (OpVT == MVT::bf16) {
// TODO: return ISD::STRICT_BF16_TO_FP;
}

if (RetVT == MVT::bf16) {
// TODO: return ISD::STRICT_FP_TO_BF16;
}

report_fatal_error("Attempt at an invalid promotion-related conversion");
}

bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
SDValue R = SDValue();
Expand Down Expand Up @@ -2214,6 +2237,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::FP_TO_UINT_SAT:
R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
case ISD::FP_EXTEND: R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
case ISD::STRICT_FP_EXTEND:
R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
break;
case ISD::SELECT_CC: R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
case ISD::SETCC: R = PromoteFloatOp_SETCC(N, OpNo); break;
case ISD::STORE: R = PromoteFloatOp_STORE(N, OpNo); break;
Expand Down Expand Up @@ -2276,6 +2302,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
}

SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
unsigned OpNo) {
assert(OpNo == 1 && "Promoting unpromotable operand");

SDValue Op = GetPromotedFloat(N->getOperand(1));
EVT VT = N->getValueType(0);

// Desired VT is same as promoted type. Use promoted float directly.
if (VT == Op->getValueType(0)) {
ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
return Op;
}

// Else, extend the promoted float value to the desired VT.
SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
N->getOperand(0), Op);
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}

// Promote the float operands used for comparison. The true- and false-
// operands have the same type as the result and are promoted, if needed, by
// PromoteFloatRes_SELECT_CC
Expand Down Expand Up @@ -2393,12 +2439,16 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
case ISD::FFREXP: R = PromoteFloatRes_FFREXP(N); break;

case ISD::FP_ROUND: R = PromoteFloatRes_FP_ROUND(N); break;
case ISD::STRICT_FP_ROUND:
R = PromoteFloatRes_STRICT_FP_ROUND(N);
break;
case ISD::LOAD: R = PromoteFloatRes_LOAD(N); break;
case ISD::SELECT: R = PromoteFloatRes_SELECT(N); break;
case ISD::SELECT_CC: R = PromoteFloatRes_SELECT_CC(N); break;

case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP: R = PromoteFloatRes_XINT_TO_FP(N); break;
case ISD::STRICT_SINT_TO_FP: R = PromoteFloatRes_STRICT_XINT_TO_FP(N); break;
case ISD::UNDEF: R = PromoteFloatRes_UNDEF(N); break;
case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
case ISD::VECREDUCE_FADD:
Expand Down Expand Up @@ -2598,6 +2648,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
}

// Explicit operation to reduce precision. Reduce the value to half precision
// and promote it back to the legal type.
SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
SDLoc DL(N);

SDValue Chain = N->getOperand(0);
SDValue Op = N->getOperand(1);
EVT VT = N->getValueType(0);
EVT OpVT = Op->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());

// Round promoted float to desired precision
SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
DAG.getVTList(IVT, MVT::Other), Chain, Op);
// Promote it back to the legal output type
SDValue Res =
DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}

SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
LoadSDNode *L = cast<LoadSDNode>(N);
EVT VT = N->getValueType(0);
Expand Down Expand Up @@ -2651,6 +2724,23 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_XINT_TO_FP(SDNode *N) {
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)));
}

// Construct a SDNode that transforms the SINT or UINT operand to the promoted
// float type.
SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_XINT_TO_FP(SDNode *N) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDVTList NVTs = DAG.getVTList(NVT, MVT::Other);

SDValue NV = DAG.getNode(N->getOpcode(), DL, NVTs, N->getOperand(0), N->getOperand(1));

// Round the value to the desired precision (that of the source type).
SDValue Rounded = DAG.getNode(ISD::STRICT_FP_ROUND, DL, N->getVTList(), NV.getValue(1), NV,
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
return DAG.getNode(
ISD::STRICT_FP_EXTEND, DL, NVTs, Rounded.getValue(1), Rounded.getValue(0));
}

SDValue DAGTypeLegalizer::PromoteFloatRes_UNDEF(SDNode *N) {
return DAG.getUNDEF(TLI.getTypeToTransformTo(*DAG.getContext(),
N->getValueType(0)));
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_FP16:
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
break;

case ISD::STRICT_FP_TO_FP16:
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
break;
case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;

case ISD::AND:
Expand Down Expand Up @@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
}

SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDLoc dl(N);

SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
N->getOperand(0), N->getOperand(1));
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}

SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDLoc dl(N);
Expand Down Expand Up @@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
case ISD::STRICT_FP16_TO_FP:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just doesn't look goo imo. Why not add PromoteIntOp_STRICT_FP16_TO_FP? Or rename PromoteIntOp_STRICT_UINT_TO_FP?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The non-strict path does the same thing, so I'm just copying that. I don't see much point in renaming it

case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
SDValue PromoteIntRes_XRINT(SDNode *N);
SDValue PromoteIntRes_FREEZE(SDNode *N);
SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
Expand Down Expand Up @@ -698,20 +699,23 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteFloatRes_ExpOp(SDNode *N);
SDValue PromoteFloatRes_FFREXP(SDNode *N);
SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
SDValue PromoteFloatRes_LOAD(SDNode *N);
SDValue PromoteFloatRes_SELECT(SDNode *N);
SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
SDValue PromoteFloatRes_UnaryOp(SDNode *N);
SDValue PromoteFloatRes_UNDEF(SDNode *N);
SDValue BitcastToInt_ATOMIC_SWAP(SDNode *N);
SDValue PromoteFloatRes_XINT_TO_FP(SDNode *N);
SDValue PromoteFloatRes_STRICT_XINT_TO_FP(SDNode *N);
SDValue PromoteFloatRes_VECREDUCE(SDNode *N);
SDValue PromoteFloatRes_VECREDUCE_SEQ(SDNode *N);

bool PromoteFloatOperand(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
Expand Down
Loading
Loading