Skip to content

Commit

Permalink
[SelectionDAG] Rewrite bfloat16 softening to use the "half promotion"…
Browse files Browse the repository at this point in the history
… path

The main difference is that this path (TypeSoftPromoteHalf) preserves
intermediate rounding steps, which the other route (TypePromoteFloat) doesn't.
This aligns bfloat16 more with half floats, which use this path on most targets.

I didn't understand the difference between these softening approaches
when I first added the bfloat lowerings; it would be nice if we only
had one of them.

Based on @pengfei's D131502

Differential Revision: https://reviews.llvm.org/D133207
  • Loading branch information
d0k committed Sep 6, 2022
1 parent f2c17a1 commit c349d7f
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 141 deletions.
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2914,6 +2914,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
DAG.getConstant(16, dl,
TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
Op = DAG.getNode(ISD::BITCAST, dl, MVT::f32, Op);
// Add fp_extend in case the output is bigger than f32.
if (Node->getValueType(0) != MVT::f32)
Op = DAG.getNode(ISD::FP_EXTEND, dl, Node->getValueType(0), Op);
Results.push_back(Op);
break;
}
Expand Down
82 changes: 54 additions & 28 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2748,47 +2748,56 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FCOPYSIGN(SDNode *N) {
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT OVT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDValue Op0 = GetSoftPromotedHalf(N->getOperand(0));
SDValue Op1 = GetSoftPromotedHalf(N->getOperand(1));
SDValue Op2 = GetSoftPromotedHalf(N->getOperand(2));
SDLoc dl(N);

// Promote to the larger FP type.
Op0 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op0);
Op1 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op1);
Op2 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op2);
auto PromotionOpcode = GetPromotionOpcode(OVT, NVT);
Op0 = DAG.getNode(PromotionOpcode, dl, NVT, Op0);
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);
Op2 = DAG.getNode(PromotionOpcode, dl, NVT, Op2);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2);

// Convert back to FP16 as an integer.
return DAG.getNode(ISD::FP_TO_FP16, dl, MVT::i16, Res);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FPOWI(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT OVT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDValue Op0 = GetSoftPromotedHalf(N->getOperand(0));
SDValue Op1 = N->getOperand(1);
SDLoc dl(N);

Op0 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op0);
// Promote to the larger FP type.
Op0 = DAG.getNode(GetPromotionOpcode(OVT, NVT), dl, NVT, Op0);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1);

// Convert back to FP16 as an integer.
return DAG.getNode(ISD::FP_TO_FP16, dl, MVT::i16, Res);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
EVT RVT = N->getValueType(0);
EVT SVT = N->getOperand(0).getValueType();

if (N->isStrictFPOpcode()) {
assert(RVT == MVT::f16);
SDValue Res =
DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
{N->getOperand(0), N->getOperand(1)});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}

return DAG.getNode(ISD::FP_TO_FP16, SDLoc(N), MVT::i16, N->getOperand(0));
return DAG.getNode(GetPromotionOpcode(SVT, RVT), SDLoc(N), MVT::i16,
N->getOperand(0));
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_LOAD(SDNode *N) {
Expand Down Expand Up @@ -2823,47 +2832,51 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_SELECT_CC(SDNode *N) {
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT OVT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDLoc dl(N);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));

// Round the value to the softened type.
return DAG.getNode(ISD::FP_TO_FP16, dl, MVT::i16, Res);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_UNDEF(SDNode *N) {
return DAG.getUNDEF(MVT::i16);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_UnaryOp(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT OVT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDValue Op = GetSoftPromotedHalf(N->getOperand(0));
SDLoc dl(N);

// Promote to the larger FP type.
Op = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op);
Op = DAG.getNode(GetPromotionOpcode(OVT, NVT), dl, NVT, Op);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op);

// Convert back to FP16 as an integer.
return DAG.getNode(ISD::FP_TO_FP16, dl, MVT::i16, Res);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BinOp(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
EVT OVT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
SDValue Op0 = GetSoftPromotedHalf(N->getOperand(0));
SDValue Op1 = GetSoftPromotedHalf(N->getOperand(1));
SDLoc dl(N);

// Promote to the larger FP type.
Op0 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op0);
Op1 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op1);
auto PromotionOpcode = GetPromotionOpcode(OVT, NVT);
Op0 = DAG.getNode(PromotionOpcode, dl, NVT, Op0);
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1);

// Convert back to FP16 as an integer.
return DAG.getNode(ISD::FP_TO_FP16, dl, MVT::i16, Res);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfRes_VECREDUCE(SDNode *N) {
Expand Down Expand Up @@ -2947,22 +2960,27 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FCOPYSIGN(SDNode *N,
unsigned OpNo) {
assert(OpNo == 1 && "Only Operand 1 must need promotion here");
SDValue Op1 = N->getOperand(1);
EVT RVT = Op1.getValueType();
SDLoc dl(N);

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op1.getValueType());

Op1 = GetSoftPromotedHalf(Op1);
Op1 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op1);
Op1 = DAG.getNode(GetPromotionOpcode(RVT, NVT), dl, NVT, Op1);

return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), N->getOperand(0),
Op1);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
EVT RVT = N->getValueType(0);
bool IsStrict = N->isStrictFPOpcode();
SDValue Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));
SDValue Op = N->getOperand(IsStrict ? 1 : 0);
EVT SVT = Op.getValueType();
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));

if (IsStrict) {
assert(SVT == MVT::f16);
SDValue Res =
DAG.getNode(ISD::STRICT_FP16_TO_FP, SDLoc(N),
{N->getValueType(0), MVT::Other}, {N->getOperand(0), Op});
Expand All @@ -2971,31 +2989,35 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
return SDValue();
}

return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0), Op);
return DAG.getNode(GetPromotionOpcode(SVT, RVT), SDLoc(N), RVT, Op);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT(SDNode *N) {
EVT RVT = N->getValueType(0);
SDValue Op = N->getOperand(0);
EVT SVT = Op.getValueType();
SDLoc dl(N);

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType());

Op = GetSoftPromotedHalf(Op);

SDValue Res = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op);
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);

return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Res);
}

SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT_SAT(SDNode *N) {
EVT RVT = N->getValueType(0);
SDValue Op = N->getOperand(0);
EVT SVT = Op.getValueType();
SDLoc dl(N);

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType());

Op = GetSoftPromotedHalf(Op);

SDValue Res = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op);
SDValue Res = DAG.getNode(GetPromotionOpcode(SVT, RVT), dl, NVT, Op);

return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Res,
N->getOperand(1));
Expand All @@ -3008,14 +3030,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_SELECT_CC(SDNode *N,
SDValue Op1 = N->getOperand(1);
SDLoc dl(N);

EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op0.getValueType());
EVT SVT = Op0.getValueType();
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), SVT);

Op0 = GetSoftPromotedHalf(Op0);
Op1 = GetSoftPromotedHalf(Op1);

// Promote to the larger FP type.
Op0 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op0);
Op1 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op1);
auto PromotionOpcode = GetPromotionOpcode(SVT, NVT);
Op0 = DAG.getNode(PromotionOpcode, dl, NVT, Op0);
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);

return DAG.getNode(ISD::SELECT_CC, SDLoc(N), N->getValueType(0), Op0, Op1,
N->getOperand(2), N->getOperand(3), N->getOperand(4));
Expand All @@ -3027,14 +3051,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_SETCC(SDNode *N) {
ISD::CondCode CCCode = cast<CondCodeSDNode>(N->getOperand(2))->get();
SDLoc dl(N);

EVT SVT = Op0.getValueType();
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), Op0.getValueType());

Op0 = GetSoftPromotedHalf(Op0);
Op1 = GetSoftPromotedHalf(Op1);

// Promote to the larger FP type.
Op0 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op0);
Op1 = DAG.getNode(ISD::FP16_TO_FP, dl, NVT, Op1);
auto PromotionOpcode = GetPromotionOpcode(SVT, NVT);
Op0 = DAG.getNode(PromotionOpcode, dl, NVT, Op0);
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);

return DAG.getSetCC(SDLoc(N), N->getValueType(0), Op0, Op1, CCCode);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ void TargetLoweringBase::computeRegisterProperties(
NumRegistersForVT[MVT::bf16] = NumRegistersForVT[MVT::f32];
RegisterTypeForVT[MVT::bf16] = RegisterTypeForVT[MVT::f32];
TransformToType[MVT::bf16] = MVT::f32;
ValueTypeActions.setTypeAction(MVT::bf16, TypePromoteFloat);
ValueTypeActions.setTypeAction(MVT::bf16, TypeSoftPromoteHalf);
}

// Loop over all of the vector value types to see which need transformations.
Expand Down
Loading

0 comments on commit c349d7f

Please sign in to comment.