diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index 294927aecb94b..60bb3ad953111 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -107,15 +107,15 @@ def HasStdExtZfhmin : Predicate<"Subtarget->hasStdExtZfhmin()">, def FeatureStdExtZfh : SubtargetFeature<"zfh", "HasStdExtZfh", "true", "'Zfh' (Half-Precision Floating-Point)", - [FeatureStdExtF]>; + [FeatureStdExtZfhmin]>; def HasStdExtZfh : Predicate<"Subtarget->hasStdExtZfh()">, AssemblerPredicate<(all_of FeatureStdExtZfh), "'Zfh' (Half-Precision Floating-Point)">; def NoStdExtZfh : Predicate<"!Subtarget->hasStdExtZfh()">; def HasStdExtZfhOrZfhmin - : Predicate<"Subtarget->hasStdExtZfhOrZfhmin()">, - AssemblerPredicate<(any_of FeatureStdExtZfh, FeatureStdExtZfhmin), + : Predicate<"Subtarget->hasStdExtZfhmin()">, + AssemblerPredicate<(all_of FeatureStdExtZfhmin), "'Zfh' (Half-Precision Floating-Point) or " "'Zfhmin' (Half-Precision Floating-Point Minimal)">; @@ -146,15 +146,15 @@ def HasStdExtZhinxmin : Predicate<"Subtarget->hasStdExtZhinxmin()">, def FeatureStdExtZhinx : SubtargetFeature<"zhinx", "HasStdExtZhinx", "true", "'Zhinx' (Half Float in Integer)", - [FeatureStdExtZfinx]>; + [FeatureStdExtZhinxmin]>; def HasStdExtZhinx : Predicate<"Subtarget->hasStdExtZhinx()">, AssemblerPredicate<(all_of FeatureStdExtZhinx), "'Zhinx' (Half Float in Integer)">; def NoStdExtZhinx : Predicate<"!Subtarget->hasStdExtZhinx()">; def HasStdExtZhinxOrZhinxmin - : Predicate<"Subtarget->hasStdExtZhinx() || Subtarget->hasStdExtZhinxmin()">, - AssemblerPredicate<(any_of FeatureStdExtZhinx, FeatureStdExtZhinxmin), + : Predicate<"Subtarget->hasStdExtZhinxmin()">, + AssemblerPredicate<(all_of FeatureStdExtZhinxmin), "'Zhinx' (Half Float in Integer) or " "'Zhinxmin' (Half Float in Integer Minimal)">; @@ -487,16 +487,16 @@ def HasStdExtZvfbfwma : Predicate<"Subtarget->hasStdExtZvfbfwma()">, def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">; -def FeatureStdExtZvfh - : SubtargetFeature<"zvfh", "HasStdExtZvfh", "true", - "'Zvfh' (Vector Half-Precision Floating-Point)", - [FeatureStdExtZve32f, FeatureStdExtZfhmin]>; - def FeatureStdExtZvfhmin : SubtargetFeature<"zvfhmin", "HasStdExtZvfhmin", "true", "'Zvfhmin' (Vector Half-Precision Floating-Point Minimal)", [FeatureStdExtZve32f]>; +def FeatureStdExtZvfh + : SubtargetFeature<"zvfh", "HasStdExtZvfh", "true", + "'Zvfh' (Vector Half-Precision Floating-Point)", + [FeatureStdExtZvfhmin, FeatureStdExtZfhmin]>; + def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">; def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minimal()">, diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 09b3ab96974c4..098a320c91533 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -915,8 +915,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { Opc = RISCV::FMV_H_X; break; case MVT::f16: - Opc = - Subtarget->hasStdExtZhinxOrZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X; + Opc = Subtarget->hasStdExtZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X; break; case MVT::f32: Opc = Subtarget->hasStdExtZfinx() ? RISCV::COPY : RISCV::FMV_W_X; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 03e994586d0c4..22c61eb20885b 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -122,7 +122,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.is64Bit() && RV64LegalI32) addRegisterClass(MVT::i32, &RISCV::GPRRegClass); - if (Subtarget.hasStdExtZfhOrZfhmin()) + if (Subtarget.hasStdExtZfhmin()) addRegisterClass(MVT::f16, &RISCV::FPR16RegClass); if (Subtarget.hasStdExtZfbfmin()) addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass); @@ -130,7 +130,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, addRegisterClass(MVT::f32, &RISCV::FPR32RegClass); if (Subtarget.hasStdExtD()) addRegisterClass(MVT::f64, &RISCV::FPR64RegClass); - if (Subtarget.hasStdExtZhinxOrZhinxmin()) + if (Subtarget.hasStdExtZhinxmin()) addRegisterClass(MVT::f16, &RISCV::GPRF16RegClass); if (Subtarget.hasStdExtZfinx()) addRegisterClass(MVT::f32, &RISCV::GPRF32RegClass); @@ -439,7 +439,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN}; - if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) + if (Subtarget.hasStdExtZfhminOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::i16, Custom); static const unsigned ZfhminZfbfminPromoteOps[] = { @@ -469,7 +469,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand); } - if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { + if (Subtarget.hasStdExtZfhminOrZhinxmin()) { if (Subtarget.hasStdExtZfhOrZhinx()) { setOperationAction(FPLegalNodeTypes, MVT::f16, Legal); setOperationAction(FPRndMode, MVT::f16, @@ -1322,7 +1322,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Custom-legalize bitcasts from fixed-length vectors to scalar types. setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64}, Custom); - if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) + if (Subtarget.hasStdExtZfhminOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::f16, Custom); if (Subtarget.hasStdExtFOrZfinx()) setOperationAction(ISD::BITCAST, MVT::f32, Custom); @@ -1388,7 +1388,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtZbkb()) setTargetDAGCombine(ISD::BITREVERSE); - if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) + if (Subtarget.hasStdExtZfhminOrZhinxmin()) setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); if (Subtarget.hasStdExtFOrZfinx()) setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT, @@ -2099,7 +2099,7 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, bool ForCodeSize) const { bool IsLegalVT = false; if (VT == MVT::f16) - IsLegalVT = Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin(); + IsLegalVT = Subtarget.hasStdExtZfhminOrZhinxmin(); else if (VT == MVT::f32) IsLegalVT = Subtarget.hasStdExtFOrZfinx(); else if (VT == MVT::f64) @@ -2171,7 +2171,7 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled. // We might still end up using a GPR but that will be decided based on ABI. if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() && - !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) + !Subtarget.hasStdExtZfhminOrZhinxmin()) return MVT::f32; MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT); @@ -2188,7 +2188,7 @@ unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled. // We might still end up using a GPR but that will be decided based on ABI. if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() && - !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) + !Subtarget.hasStdExtZfhminOrZhinxmin()) return 1; return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT); @@ -5761,7 +5761,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, EVT Op0VT = Op0.getValueType(); MVT XLenVT = Subtarget.getXLenVT(); if (VT == MVT::f16 && Op0VT == MVT::i16 && - Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { + Subtarget.hasStdExtZfhminOrZhinxmin()) { SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); return FPConv; @@ -11527,11 +11527,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, EVT Op0VT = Op0.getValueType(); MVT XLenVT = Subtarget.getXLenVT(); if (VT == MVT::i16 && Op0VT == MVT::f16 && - Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { + Subtarget.hasStdExtZfhminOrZhinxmin()) { SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); } else if (VT == MVT::i16 && Op0VT == MVT::bf16 && - Subtarget.hasStdExtZfbfmin()) { + Subtarget.hasStdExtZfbfmin()) { SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() && @@ -18632,7 +18632,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, // TODO: Support fixed vectors up to XLen for P extension? if (VT.isVector()) break; - if (VT == MVT::f16 && Subtarget.hasStdExtZhinxOrZhinxmin()) + if (VT == MVT::f16 && Subtarget.hasStdExtZhinxmin()) return std::make_pair(0U, &RISCV::GPRF16RegClass); if (VT == MVT::f32 && Subtarget.hasStdExtZfinx()) return std::make_pair(0U, &RISCV::GPRF32RegClass); @@ -18640,7 +18640,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return std::make_pair(0U, &RISCV::GPRPF64RegClass); return std::make_pair(0U, &RISCV::GPRNoX0RegClass); case 'f': - if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) + if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) return std::make_pair(0U, &RISCV::FPR16RegClass); if (Subtarget.hasStdExtF() && VT == MVT::f32) return std::make_pair(0U, &RISCV::FPR32RegClass); @@ -18753,7 +18753,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, } if (VT == MVT::f32 || VT == MVT::Other) return std::make_pair(FReg, &RISCV::FPR32RegClass); - if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) { + if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) { unsigned RegNo = FReg - RISCV::F0_F; unsigned HReg = RISCV::F0_H + RegNo; return std::make_pair(HReg, &RISCV::FPR16RegClass); @@ -19100,7 +19100,7 @@ bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT, switch (FPVT.getSimpleVT().SimpleTy) { case MVT::f16: - return Subtarget.hasStdExtZfhOrZfhmin(); + return Subtarget.hasStdExtZfhmin(); case MVT::f32: return Subtarget.hasStdExtF(); case MVT::f64: diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 23d56cfa6e4e5..7540218633bfc 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -143,16 +143,12 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo { bool hasStdExtZvl() const { return ZvlLen != 0; } bool hasStdExtFOrZfinx() const { return HasStdExtF || HasStdExtZfinx; } bool hasStdExtDOrZdinx() const { return HasStdExtD || HasStdExtZdinx; } - bool hasStdExtZfhOrZfhmin() const { return HasStdExtZfh || HasStdExtZfhmin; } bool hasStdExtZfhOrZhinx() const { return HasStdExtZfh || HasStdExtZhinx; } - bool hasStdExtZhinxOrZhinxmin() const { - return HasStdExtZhinx || HasStdExtZhinxmin; - } - bool hasStdExtZfhOrZfhminOrZhinxOrZhinxmin() const { - return hasStdExtZfhOrZfhmin() || hasStdExtZhinxOrZhinxmin(); + bool hasStdExtZfhminOrZhinxmin() const { + return HasStdExtZfhmin || HasStdExtZhinxmin; } bool hasHalfFPLoadStoreMove() const { - return hasStdExtZfhOrZfhmin() || HasStdExtZfbfmin; + return HasStdExtZfhmin || HasStdExtZfbfmin; } bool is64Bit() const { return IsRV64; } MVT getXLenVT() const { @@ -201,9 +197,7 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo { // Vector codegen related methods. bool hasVInstructions() const { return HasStdExtZve32x; } bool hasVInstructionsI64() const { return HasStdExtZve64x; } - bool hasVInstructionsF16Minimal() const { - return HasStdExtZvfhmin || HasStdExtZvfh; - } + bool hasVInstructionsF16Minimal() const { return HasStdExtZvfhmin; } bool hasVInstructionsF16() const { return HasStdExtZvfh; } bool hasVInstructionsBF16() const { return HasStdExtZvfbfmin; } bool hasVInstructionsF32() const { return HasStdExtZve32f; } diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index efc8350064a6e..96ecc771863e5 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -334,7 +334,7 @@ class RISCVTTIImpl : public BasicTTIImplBase { return RISCVRegisterClass::GPRRC; Type *ScalarTy = Ty->getScalarType(); - if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhOrZfhmin()) || + if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) || (ScalarTy->isFloatTy() && ST->hasStdExtF()) || (ScalarTy->isDoubleTy() && ST->hasStdExtD())) { return RISCVRegisterClass::FPRRC;