diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 67a8ac5b6ee76..f642124e072cc 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -116,6 +116,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtZfhOrZfhmin()) addRegisterClass(MVT::f16, &RISCV::FPR16RegClass); + if (Subtarget.hasStdExtZfbfmin()) + addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass); if (Subtarget.hasStdExtF()) addRegisterClass(MVT::f32, &RISCV::FPR32RegClass); if (Subtarget.hasStdExtD()) @@ -359,6 +361,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::i16, Custom); + + if (Subtarget.hasStdExtZfbfmin()) { + setOperationAction(ISD::BITCAST, MVT::i16, Custom); + setOperationAction(ISD::BITCAST, MVT::bf16, Custom); + setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); + } if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { if (Subtarget.hasStdExtZfhOrZhinx()) { @@ -4768,6 +4779,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); return FPConv; } + if (VT == MVT::bf16 && Op0VT == MVT::i16 && + Subtarget.hasStdExtZfbfmin()) { + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); + SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0); + return FPConv; + } if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0); @@ -4931,11 +4948,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, } return SDValue(); } - case ISD::FP_EXTEND: - case ISD::FP_ROUND: + case ISD::FP_EXTEND: { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) { + SDValue FloatVal = + DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal); + } + + if (!Op.getValueType().isVector()) + return Op; + return lowerVectorFPExtendOrRoundLike(Op, DAG); + } + case ISD::FP_ROUND: { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0); + if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() && + Subtarget.hasStdExtDOrZdinx()) { + SDValue FloatVal = + DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal); + } + if (!Op.getValueType().isVector()) return Op; return lowerVectorFPExtendOrRoundLike(Op, DAG); + } case ISD::STRICT_FP_ROUND: case ISD::STRICT_FP_EXTEND: return lowerStrictFPExtendOrRoundLike(Op, DAG); @@ -9926,6 +9974,10 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); + } else if (VT == MVT::i16 && Op0VT == MVT::bf16 && + Subtarget.hasStdExtZfbfmin()) { + SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { SDValue FPConv = @@ -14867,7 +14919,8 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, // similar local variables rather than directly checking against the target // ABI. - if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) { + if (UseGPRForF16_F32 && + (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) { LocVT = XLenVT; LocInfo = CCValAssign::BCvt; } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) { @@ -14960,7 +15013,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, unsigned StoreSizeBytes = XLen / 8; Align StackAlign = Align(XLen / 8); - if (ValVT == MVT::f16 && !UseGPRForF16_F32) + if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32) Reg = State.AllocateReg(ArgFPR16s); else if (ValVT == MVT::f32 && !UseGPRForF16_F32) Reg = State.AllocateReg(ArgFPR32s); @@ -15117,8 +15170,9 @@ static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val, Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) - Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val); + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val); else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val); else @@ -15176,7 +15230,8 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val, Val = convertToScalableVector(LocVT, Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val); else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val); @@ -16196,6 +16251,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(FCVT_WU_RV64) NODE_NAME_CASE(STRICT_FCVT_W_RV64) NODE_NAME_CASE(STRICT_FCVT_WU_RV64) + NODE_NAME_CASE(FP_ROUND_BF16) + NODE_NAME_CASE(FP_EXTEND_BF16) NODE_NAME_CASE(FROUND) NODE_NAME_CASE(FPCLASS) NODE_NAME_CASE(READ_CYCLE_WIDE) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index a6c7100ddf42b..ec90e3c0cdcdd 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -111,6 +111,9 @@ enum NodeType : unsigned { FCVT_W_RV64, FCVT_WU_RV64, + FP_ROUND_BF16, + FP_EXTEND_BF16, + // Rounds an FP value to its corresponding integer in the same FP format. // First operand is the value to round, the second operand is the largest // integer that can be represented exactly in the FP format. This will be diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td index 1f423591d3dde..35f9f03f61a13 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td @@ -13,6 +13,20 @@ // //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// RISC-V specific DAG Nodes. +//===----------------------------------------------------------------------===// + +def SDT_RISCVFP_ROUND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>; +def SDT_RISCVFP_EXTEND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>; + +def riscv_fpround_bf16 + : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>; +def riscv_fpextend_bf16 + : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>; + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -23,3 +37,27 @@ def FCVT_BF16_S : FPUnaryOp_r_frm<0b0100010, 0b01000, FPR16, FPR32, "fcvt.bf16.s def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">, Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>; } // Predicates = [HasStdExtZfbfmin] + +//===----------------------------------------------------------------------===// +// Pseudo-instructions and codegen patterns +//===----------------------------------------------------------------------===// + +let Predicates = [HasStdExtZfbfmin] in { +/// Loads +def : LdPat; + +/// Stores +def : StPat; + +/// Float conversion operations +// f32 -> bf16, bf16 -> f32 +def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)), + (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>; +def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)), + (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>; + +// Moves (no conversion) +def : Pat<(bf16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>; +def : Pat<(riscv_fmv_x_anyexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>; +def : Pat<(riscv_fmv_x_signexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>; +} // Predicates = [HasStdExtZfbfmin] diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td index 3ea338d9ed20d..5dc02e5fa9f9e 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td @@ -16,9 +16,9 @@ //===----------------------------------------------------------------------===// def SDT_RISCVFMV_H_X - : SDTypeProfile<1, 1, [SDTCisVT<0, f16>, SDTCisVT<1, XLenVT>]>; + : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, XLenVT>]>; def SDT_RISCVFMV_X_EXTH - : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, f16>]>; + : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>; def riscv_fmv_h_x : SDNode<"RISCVISD::FMV_H_X", SDT_RISCVFMV_H_X>; @@ -438,7 +438,7 @@ def : Pat<(f16 (any_fpround FPR32:$rs1)), (FCVT_H_S FPR32:$rs1, FRM_DYN)>; def : Pat<(any_fpextend (f16 FPR16:$rs1)), (FCVT_S_H FPR16:$rs1)>; // Moves (no conversion) -def : Pat<(riscv_fmv_h_x GPR:$src), (FMV_H_X GPR:$src)>; +def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>; def : Pat<(riscv_fmv_x_anyexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>; def : Pat<(riscv_fmv_x_signexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>; @@ -453,7 +453,7 @@ def : Pat<(any_fpround FPR32INX:$rs1), (FCVT_H_S_INX FPR32INX:$rs1, FRM_DYN)>; def : Pat<(any_fpextend FPR16INX:$rs1), (FCVT_S_H_INX FPR16INX:$rs1)>; // Moves (no conversion) -def : Pat<(riscv_fmv_h_x GPR:$src), (COPY_TO_REGCLASS GPR:$src, GPR)>; +def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (COPY_TO_REGCLASS GPR:$src, GPR)>; def : Pat<(riscv_fmv_x_anyexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>; def : Pat<(riscv_fmv_x_signexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>; diff --git a/llvm/test/CodeGen/RISCV/zfbfmin.ll b/llvm/test/CodeGen/RISCV/zfbfmin.ll new file mode 100644 index 0000000000000..b32e6dc0b14b5 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/zfbfmin.ll @@ -0,0 +1,92 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s +; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s + +define bfloat @bitcast_bf16_i16(i16 %a) nounwind { +; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.h.x fa0, a0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast i16 %a to bfloat + ret bfloat %1 +} + +define i16 @bitcast_i16_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.x.h a0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast bfloat %a to i16 + ret i16 %1 +} + +define bfloat @fcvt_bf16_s(float %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_bf16_s: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc float %a to bfloat + ret bfloat %1 +} + +define float @fcvt_s_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_s_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to float + ret float %1 +} + +define bfloat @fcvt_bf16_d(double %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_bf16_d: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.d fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc double %a to bfloat + ret bfloat %1 +} + +define double @fcvt_d_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_d_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.d.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to double + ret double %1 +} + +define bfloat @bfloat_load(ptr %a) nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_load: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: flh fa0, 6(a0) +; CHECKIZFBFMIN-NEXT: ret + %1 = getelementptr bfloat, ptr %a, i32 3 + %2 = load bfloat, ptr %1 + ret bfloat %2 +} + +define bfloat @bfloat_imm() nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_imm: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: lui a0, %hi(.LCPI7_0) +; CHECKIZFBFMIN-NEXT: flh fa0, %lo(.LCPI7_0)(a0) +; CHECKIZFBFMIN-NEXT: ret + ret bfloat 3.0 +} + +define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_store: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fsh fa0, 0(a0) +; CHECKIZFBFMIN-NEXT: fsh fa0, 16(a0) +; CHECKIZFBFMIN-NEXT: ret + store bfloat %b, ptr %a + %1 = getelementptr bfloat, ptr %a, i32 8 + store bfloat %b, ptr %1 + ret void +}