Skip to content

Commit

Permalink
Recommit "[RISCV] Implement support for bf16 truncate/extend on hard …
Browse files Browse the repository at this point in the history
…FP targets"

Without the changes from D153598.

Original commit message:

For the same reasons as D151284, this requires custom lowering of the
truncate libcall on hard float ABIs (the normal libcall code path is
used on soft ABIs).

The extend operation is implemented by a shift just as in the standard
legalisation, but needs to be custom lowered because i32 isn't a legal
type on RV64.

This patch aims to make the minimal changes that result in correct
codegen for the bfloat.ll tests.

Differential Revision: https://reviews.llvm.org/D151663
  • Loading branch information
asb authored and topperc committed Jun 24, 2023
1 parent f2d16b3 commit 9291249
Show file tree
Hide file tree
Showing 2 changed files with 507 additions and 8 deletions.
49 changes: 43 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
setOperationAction(ISD::FP_TO_BF16, MVT::f32,
Subtarget.isSoftFPABI() ? LibCall : Custom);

if (Subtarget.hasStdExtZfa())
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
Expand Down Expand Up @@ -461,6 +464,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
setOperationAction(ISD::FP_TO_BF16, MVT::f64,
Subtarget.isSoftFPABI() ? LibCall : Custom);
}

if (Subtarget.is64Bit()) {
Expand Down Expand Up @@ -4923,6 +4929,35 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
case ISD::FP_TO_BF16: {
// Custom lower to ensure the libcall return is passed in an FPR on hard
// float ABIs.
assert(!Subtarget.isSoftFPABI() && "Unexpected custom legalization");
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
SDValue Res =
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
if (Subtarget.is64Bit())
return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
return DAG.getBitcast(MVT::i32, Res);
}
case ISD::BF16_TO_FP: {
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalization");
MVT VT = Op.getSimpleValueType();
SDLoc DL(Op);
Op = DAG.getNode(
ISD::SHL, DL, Op.getOperand(0).getValueType(), Op.getOperand(0),
DAG.getShiftAmountConstant(16, Op.getOperand(0).getValueType(), DL));
SDValue Res = Subtarget.is64Bit()
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op)
: DAG.getBitcast(MVT::f32, Op);
// fp_extend if the target VT is bigger than f32.
if (VT != MVT::f32)
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
return Res;
}
case ISD::FTRUNC:
case ISD::FCEIL:
case ISD::FFLOOR:
Expand Down Expand Up @@ -16553,9 +16588,10 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
// Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
// and cast to f32.
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
PartVT == MVT::f32) {
// Cast the [b]f16 to i16, extend to i32, pad with ones to make a float
// nan, and cast to f32.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
Expand Down Expand Up @@ -16606,13 +16642,14 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
PartVT == MVT::f32) {
SDValue Val = Parts[0];

// Cast the f32 to i32, truncate to i16, and cast back to f16.
// Cast the f32 to i32, truncate to i16, and cast back to [b]f16.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
return Val;
}

Expand Down
Loading

0 comments on commit 9291249

Please sign in to comment.