Skip to content

Commit

Permalink
Revert "[RISCV] Implement support for bf16 truncate/extend on hard FP…
Browse files Browse the repository at this point in the history
… targets"

This was committed with D153598 merged into it. Reverting to recommit as separate patches.

This reverts commit 690b1c8.
  • Loading branch information
topperc committed Jun 24, 2023
1 parent bef4007 commit 076759f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 549 deletions.
69 changes: 27 additions & 42 deletions llvm/lib/Support/RISCVISAInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,11 @@ static size_t findLastNonVersionCharacter(StringRef Ext) {
}

namespace {
struct LessExtName {
bool operator()(const RISCVSupportedExtension &LHS, StringRef RHS) {
return StringRef(LHS.Name) < RHS;
}
bool operator()(StringRef LHS, const RISCVSupportedExtension &RHS) {
return LHS < StringRef(RHS.Name);
struct FindByName {
FindByName(StringRef Ext) : Ext(Ext){};
StringRef Ext;
bool operator()(const RISCVSupportedExtension &ExtInfo) {
return ExtInfo.Name == Ext;
}
};
} // namespace
Expand All @@ -240,12 +239,12 @@ findDefaultVersion(StringRef ExtName) {
// TODO: We might set default version based on profile or ISA spec.
for (auto &ExtInfo : {ArrayRef(SupportedExtensions),
ArrayRef(SupportedExperimentalExtensions)}) {
auto I = llvm::lower_bound(ExtInfo, ExtName, LessExtName());
auto ExtensionInfoIterator = llvm::find_if(ExtInfo, FindByName(ExtName));

if (I == ExtInfo.end() || I->Name != ExtName)
if (ExtensionInfoIterator == ExtInfo.end()) {
continue;

return I->Version;
}
return ExtensionInfoIterator->Version;
}
return std::nullopt;
}
Expand Down Expand Up @@ -280,50 +279,37 @@ static StringRef getExtensionType(StringRef Ext) {

static std::optional<RISCVExtensionVersion>
isExperimentalExtension(StringRef Ext) {
auto I =
llvm::lower_bound(SupportedExperimentalExtensions, Ext, LessExtName());
if (I == std::end(SupportedExperimentalExtensions) || I->Name != Ext)
auto ExtIterator =
llvm::find_if(SupportedExperimentalExtensions, FindByName(Ext));
if (ExtIterator == std::end(SupportedExperimentalExtensions))
return std::nullopt;

return I->Version;
return ExtIterator->Version;
}

bool RISCVISAInfo::isSupportedExtensionFeature(StringRef Ext) {
bool IsExperimental = stripExperimentalPrefix(Ext);

ArrayRef<RISCVSupportedExtension> ExtInfo =
IsExperimental ? ArrayRef(SupportedExperimentalExtensions)
: ArrayRef(SupportedExtensions);

auto I = llvm::lower_bound(ExtInfo, Ext, LessExtName());
return I != ExtInfo.end() && I->Name == Ext;
if (IsExperimental)
return llvm::any_of(SupportedExperimentalExtensions, FindByName(Ext));
else
return llvm::any_of(SupportedExtensions, FindByName(Ext));
}

bool RISCVISAInfo::isSupportedExtension(StringRef Ext) {
verifyTables();

for (auto ExtInfo : {ArrayRef(SupportedExtensions),
ArrayRef(SupportedExperimentalExtensions)}) {
auto I = llvm::lower_bound(ExtInfo, Ext, LessExtName());
if (I != ExtInfo.end() && I->Name == Ext)
return true;
}

return false;
return llvm::any_of(SupportedExtensions, FindByName(Ext)) ||
llvm::any_of(SupportedExperimentalExtensions, FindByName(Ext));
}

bool RISCVISAInfo::isSupportedExtension(StringRef Ext, unsigned MajorVersion,
unsigned MinorVersion) {
for (auto ExtInfo : {ArrayRef(SupportedExtensions),
ArrayRef(SupportedExperimentalExtensions)}) {
auto Range =
std::equal_range(ExtInfo.begin(), ExtInfo.end(), Ext, LessExtName());
for (auto I = Range.first, E = Range.second; I != E; ++I)
if (I->Version.Major == MajorVersion && I->Version.Minor == MinorVersion)
return true;
}

return false;
auto FindByNameAndVersion = [=](const RISCVSupportedExtension &ExtInfo) {
return ExtInfo.Name == Ext && (MajorVersion == ExtInfo.Version.Major) &&
(MinorVersion == ExtInfo.Version.Minor);
};
return llvm::any_of(SupportedExtensions, FindByNameAndVersion) ||
llvm::any_of(SupportedExperimentalExtensions, FindByNameAndVersion);
}

bool RISCVISAInfo::hasExtension(StringRef Ext) const {
Expand Down Expand Up @@ -563,12 +549,11 @@ RISCVISAInfo::parseFeatures(unsigned XLen,
? ArrayRef(SupportedExperimentalExtensions)
: ArrayRef(SupportedExtensions);
auto ExtensionInfoIterator =
llvm::lower_bound(ExtensionInfos, ExtName, LessExtName());
llvm::find_if(ExtensionInfos, FindByName(ExtName));

// Not all features is related to ISA extension, like `relax` or
// `save-restore`, skip those feature.
if (ExtensionInfoIterator == ExtensionInfos.end() ||
ExtensionInfoIterator->Name != ExtName)
if (ExtensionInfoIterator == ExtensionInfos.end())
continue;

if (Add)
Expand Down
49 changes: 6 additions & 43 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
setOperationAction(ISD::FP_TO_BF16, MVT::f32,
Subtarget.isSoftFPABI() ? LibCall : Custom);

if (Subtarget.hasStdExtZfa())
setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
Expand Down Expand Up @@ -464,9 +461,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
setOperationAction(ISD::FP_TO_BF16, MVT::f64,
Subtarget.isSoftFPABI() ? LibCall : Custom);
}

if (Subtarget.is64Bit()) {
Expand Down Expand Up @@ -4929,35 +4923,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
case ISD::FP_TO_BF16: {
// Custom lower to ensure the libcall return is passed in an FPR on hard
// float ABIs.
assert(!Subtarget.isSoftFPABI() && "Unexpected custom legalization");
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
SDValue Res =
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
if (Subtarget.is64Bit())
return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
return DAG.getBitcast(MVT::i32, Res);
}
case ISD::BF16_TO_FP: {
assert(Subtarget.hasStdExtFOrZfinx() && "Unexpected custom legalization");
MVT VT = Op.getSimpleValueType();
SDLoc DL(Op);
Op = DAG.getNode(
ISD::SHL, DL, Op.getOperand(0).getValueType(), Op.getOperand(0),
DAG.getShiftAmountConstant(16, Op.getOperand(0).getValueType(), DL));
SDValue Res = Subtarget.is64Bit()
? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Op)
: DAG.getBitcast(MVT::f32, Op);
// fp_extend if the target VT is bigger than f32.
if (VT != MVT::f32)
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
return Res;
}
case ISD::FTRUNC:
case ISD::FCEIL:
case ISD::FFLOOR:
Expand Down Expand Up @@ -16588,10 +16553,9 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
PartVT == MVT::f32) {
// Cast the [b]f16 to i16, extend to i32, pad with ones to make a float
// nan, and cast to f32.
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
// Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
// and cast to f32.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
Expand Down Expand Up @@ -16642,14 +16606,13 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
PartVT == MVT::f32) {
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
SDValue Val = Parts[0];

// Cast the f32 to i32, truncate to i16, and cast back to [b]f16.
// Cast the f32 to i32, truncate to i16, and cast back to f16.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
return Val;
}

Expand Down

0 comments on commit 076759f

Please sign in to comment.